Make SqlState into an opaque type rather than enum

This commit is contained in:
Steven Fackler 2017-07-08 20:52:36 -07:00
parent 6f0950b45b
commit 01a1529624
11 changed files with 1094 additions and 1131 deletions

View File

@ -7,3 +7,4 @@ authors = ["Steven Fackler <sfackler@gmail.com>"]
phf_codegen = "=0.7.21" phf_codegen = "=0.7.21"
regex = "0.1" regex = "0.1"
marksman_escape = "0.1" marksman_escape = "0.1"
linked-hash-map = "0.4"

View File

@ -1,6 +1,7 @@
extern crate phf_codegen; extern crate phf_codegen;
extern crate regex; extern crate regex;
extern crate marksman_escape; extern crate marksman_escape;
extern crate linked_hash_map;
use std::ascii::AsciiExt; use std::ascii::AsciiExt;
use std::path::Path; use std::path::Path;

View File

@ -2,29 +2,22 @@ use std::fs::File;
use std::io::{Write, BufWriter}; use std::io::{Write, BufWriter};
use std::path::Path; use std::path::Path;
use phf_codegen; use phf_codegen;
use linked_hash_map::LinkedHashMap;
use snake_to_camel;
const ERRCODES_TXT: &'static str = include_str!("errcodes.txt"); const ERRCODES_TXT: &'static str = include_str!("errcodes.txt");
struct Code {
code: String,
variant: String,
}
pub fn build(path: &Path) { pub fn build(path: &Path) {
let mut file = BufWriter::new(File::create(path.join("error/sqlstate.rs")).unwrap()); let mut file = BufWriter::new(File::create(path.join("error/sqlstate.rs")).unwrap());
let codes = parse_codes(); let codes = parse_codes();
make_header(&mut file); make_type(&mut file);
make_enum(&codes, &mut file); make_consts(&codes, &mut file);
make_map(&codes, &mut file); make_map(&codes, &mut file);
make_impl(&codes, &mut file);
} }
fn parse_codes() -> Vec<Code> { fn parse_codes() -> LinkedHashMap<String, Vec<String>> {
let mut codes = vec![]; let mut codes = LinkedHashMap::new();
for line in ERRCODES_TXT.lines() { for line in ERRCODES_TXT.lines() {
if line.starts_with("#") || line.starts_with("Section") || line.trim().is_empty() { if line.starts_with("#") || line.starts_with("Section") || line.trim().is_empty() {
@ -34,131 +27,70 @@ fn parse_codes() -> Vec<Code> {
let mut it = line.split_whitespace(); let mut it = line.split_whitespace();
let code = it.next().unwrap().to_owned(); let code = it.next().unwrap().to_owned();
it.next(); it.next();
it.next(); let name = it.next().unwrap().replace("ERRCODE_", "");
// for 2202E
let name = match it.next() {
Some(name) => name,
None => continue,
};
let variant = match variant_name(&code) {
Some(variant) => variant,
None => snake_to_camel(&name),
};
codes.push(Code { codes.entry(code).or_insert_with(Vec::new).push(name);
code: code,
variant: variant,
});
} }
codes codes
} }
fn variant_name(code: &str) -> Option<String> { fn make_type(file: &mut BufWriter<File>) {
match code {
"01004" => Some("WarningStringDataRightTruncation".to_owned()),
"22001" => Some("DataStringDataRightTruncation".to_owned()),
"2F002" => Some("SqlRoutineModifyingSqlDataNotPermitted".to_owned()),
"38002" => Some("ForeignRoutineModifyingSqlDataNotPermitted".to_owned()),
"2F003" => Some("SqlRoutineProhibitedSqlStatementAttempted".to_owned()),
"38003" => Some("ForeignRoutineProhibitedSqlStatementAttempted".to_owned()),
"2F004" => Some("SqlRoutineReadingSqlDataNotPermitted".to_owned()),
"38004" => Some("ForeignRoutineReadingSqlDataNotPermitted".to_owned()),
"22004" => Some("DataNullValueNotAllowed".to_owned()),
"39004" => Some("ExternalRoutineInvocationNullValueNotAllowed".to_owned()),
_ => None,
}
}
fn make_header(file: &mut BufWriter<File>) {
write!( write!(
file, file,
"// Autogenerated file - DO NOT EDIT "// Autogenerated file - DO NOT EDIT
use phf; use phf;
use std::borrow::Cow;
" /// A SQLSTATE error code
).unwrap();
}
fn make_enum(codes: &[Code], file: &mut BufWriter<File>) {
write!(
file,
r#"/// SQLSTATE error codes
#[derive(PartialEq, Eq, Clone, Debug)] #[derive(PartialEq, Eq, Clone, Debug)]
#[allow(enum_variant_names)] pub struct SqlState(Cow<'static, str>);
pub enum SqlState {{
"#
).unwrap();
for code in codes {
write!(
file,
" /// `{}`
{},\n",
code.code,
code.variant
).unwrap();
}
write!(
file,
" /// An unknown code
Other(String),
}}
"
).unwrap();
}
fn make_map(codes: &[Code], file: &mut BufWriter<File>) {
write!(
file,
"#[cfg_attr(rustfmt, rustfmt_skip)]
static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = "
).unwrap();
let mut builder = phf_codegen::Map::new();
for code in codes {
builder.entry(&*code.code, &format!("SqlState::{}", code.variant));
}
builder.build(file).unwrap();
write!(file, ";\n").unwrap();
}
fn make_impl(codes: &[Code], file: &mut BufWriter<File>) {
write!(
file,
r#"
impl SqlState {{ impl SqlState {{
/// Creates a `SqlState` from its error code. /// Creates a `SqlState` from its error code.
pub fn from_code(s: &str) -> SqlState {{ pub fn from_code(s: &str) -> SqlState {{
match SQLSTATE_MAP.get(s) {{ match SQLSTATE_MAP.get(s) {{
Some(state) => state.clone(), Some(state) => state.clone(),
None => SqlState::Other(s.to_owned()), None => SqlState(Cow::Owned(s.to_string())),
}} }}
}} }}
/// Returns the error code corresponding to the `SqlState`. /// Returns the error code corresponding to the `SqlState`.
pub fn code(&self) -> &str {{ pub fn code(&self) -> &str {{
match *self {{"# &self.0
).unwrap();
for code in codes {
write!(
file,
r#"
SqlState::{} => "{}","#,
code.variant,
code.code
).unwrap();
}
write!(
file,
r#"
SqlState::Other(ref s) => s,
}}
}} }}
}} }}
"# "
).unwrap(); ).unwrap();
} }
fn make_consts(codes: &LinkedHashMap<String, Vec<String>>, file: &mut BufWriter<File>) {
for (code, names) in codes {
for name in names {
write!(
file,
r#"
/// {code}
pub const {name}: SqlState = SqlState(Cow::Borrowed("{code}"));
"#,
name = name,
code = code,
).unwrap();
}
}
}
fn make_map(codes: &LinkedHashMap<String, Vec<String>>, file: &mut BufWriter<File>) {
write!(
file,
"
#[cfg_attr(rustfmt, rustfmt_skip)]
static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = "
).unwrap();
let mut builder = phf_codegen::Map::new();
for (code, names) in codes {
builder.entry(&**code, &names[0]);
}
builder.build(file).unwrap();
write!(file, ";\n").unwrap();
}

View File

@ -5,7 +5,7 @@ use std::convert::From;
use std::fmt; use std::fmt;
use std::io; use std::io;
pub use self::sqlstate::SqlState; pub use self::sqlstate::*;
mod sqlstate; mod sqlstate;

File diff suppressed because it is too large Load Diff

View File

@ -4,7 +4,8 @@ use std::io;
use std::error; use std::error;
#[doc(inline)] #[doc(inline)]
pub use postgres_shared::error::{DbError, ConnectError, ErrorPosition, Severity, SqlState}; // FIXME
pub use postgres_shared::error::*;
/// An error encountered when communicating with the Postgres server. /// An error encountered when communicating with the Postgres server.
#[derive(Debug)] #[derive(Debug)]

View File

@ -93,7 +93,7 @@ use postgres_protocol::message::backend::{self, ErrorFields};
use postgres_protocol::message::frontend; use postgres_protocol::message::frontend;
use postgres_shared::rows::RowData; use postgres_shared::rows::RowData;
use error::{Error, ConnectError, SqlState, DbError}; use error::{Error, ConnectError, DbError, UNDEFINED_COLUMN, UNDEFINED_TABLE};
use tls::TlsHandshake; use tls::TlsHandshake;
use notification::{Notifications, Notification}; use notification::{Notifications, Notification};
use params::{IntoConnectParams, User}; use params::{IntoConnectParams, User};
@ -771,7 +771,7 @@ impl InnerConnection {
) { ) {
Ok(..) => {} Ok(..) => {}
// Range types weren't added until Postgres 9.2, so pg_range may not exist // Range types weren't added until Postgres 9.2, so pg_range may not exist
Err(Error::Db(ref e)) if e.code == SqlState::UndefinedTable => { Err(Error::Db(ref e)) if e.code == UNDEFINED_TABLE => {
self.raw_prepare( self.raw_prepare(
TYPEINFO_QUERY, TYPEINFO_QUERY,
"SELECT t.typname, t.typtype, t.typelem, NULL::OID, \ "SELECT t.typname, t.typtype, t.typelem, NULL::OID, \
@ -862,7 +862,7 @@ impl InnerConnection {
) { ) {
Ok(..) => {} Ok(..) => {}
// Postgres 9.0 doesn't have enumsortorder // Postgres 9.0 doesn't have enumsortorder
Err(Error::Db(ref e)) if e.code == SqlState::UndefinedColumn => { Err(Error::Db(ref e)) if e.code == UNDEFINED_COLUMN => {
self.raw_prepare( self.raw_prepare(
TYPEINFO_ENUM_QUERY, TYPEINFO_ENUM_QUERY,
"SELECT enumlabel \ "SELECT enumlabel \

View File

@ -12,10 +12,9 @@ extern crate native_tls;
use fallible_iterator::FallibleIterator; use fallible_iterator::FallibleIterator;
use postgres::{HandleNotice, Connection, GenericConnection, TlsMode}; use postgres::{HandleNotice, Connection, GenericConnection, TlsMode};
use postgres::transaction::{self, IsolationLevel}; use postgres::transaction::{self, IsolationLevel};
use postgres::error::{Error, ConnectError, DbError}; use postgres::error::{Error, ConnectError, DbError, SYNTAX_ERROR, QUERY_CANCELED, UNDEFINED_TABLE,
INVALID_CATALOG_NAME, INVALID_PASSWORD, CARDINALITY_VIOLATION};
use postgres::types::{Oid, Type, Kind, WrongType}; use postgres::types::{Oid, Type, Kind, WrongType};
use postgres::error::SqlState::{SyntaxError, QueryCanceled, UndefinedTable, InvalidCatalogName,
InvalidPassword, CardinalityViolation};
use postgres::error::ErrorPosition::Normal; use postgres::error::ErrorPosition::Normal;
use postgres::rows::RowIndex; use postgres::rows::RowIndex;
use postgres::notification::Notification; use postgres::notification::Notification;
@ -59,7 +58,7 @@ fn test_prepare_err() {
)); ));
let stmt = conn.prepare("invalid sql database"); let stmt = conn.prepare("invalid sql database");
match stmt { match stmt {
Err(Error::Db(ref e)) if e.code == SyntaxError && e.position == Some(Normal(1)) => {} Err(Error::Db(ref e)) if e.code == SYNTAX_ERROR && e.position == Some(Normal(1)) => {}
Err(e) => panic!("Unexpected result {:?}", e), Err(e) => panic!("Unexpected result {:?}", e),
_ => panic!("Unexpected result"), _ => panic!("Unexpected result"),
} }
@ -68,7 +67,7 @@ fn test_prepare_err() {
#[test] #[test]
fn test_unknown_database() { fn test_unknown_database() {
match Connection::connect("postgres://postgres@localhost:5433/asdf", TlsMode::None) { match Connection::connect("postgres://postgres@localhost:5433/asdf", TlsMode::None) {
Err(ConnectError::Db(ref e)) if e.code == InvalidCatalogName => {} Err(ConnectError::Db(ref e)) if e.code == INVALID_CATALOG_NAME => {}
Err(resp) => panic!("Unexpected result {:?}", resp), Err(resp) => panic!("Unexpected result {:?}", resp),
_ => panic!("Unexpected result"), _ => panic!("Unexpected result"),
} }
@ -455,7 +454,7 @@ fn test_batch_execute_error() {
let stmt = conn.prepare("SELECT * FROM foo ORDER BY id"); let stmt = conn.prepare("SELECT * FROM foo ORDER BY id");
match stmt { match stmt {
Err(Error::Db(ref e)) if e.code == UndefinedTable => {} Err(Error::Db(ref e)) if e.code == UNDEFINED_TABLE => {}
Err(e) => panic!("unexpected error {:?}", e), Err(e) => panic!("unexpected error {:?}", e),
_ => panic!("unexpected success"), _ => panic!("unexpected success"),
} }
@ -520,7 +519,7 @@ FROM (SELECT gs.i
LIMIT 2) ss", LIMIT 2) ss",
)); ));
match stmt.query(&[]) { match stmt.query(&[]) {
Err(Error::Db(ref e)) if e.code == CardinalityViolation => {} Err(Error::Db(ref e)) if e.code == CARDINALITY_VIOLATION => {}
Err(err) => panic!("Unexpected error {:?}", err), Err(err) => panic!("Unexpected error {:?}", err),
Ok(_) => panic!("Expected failure"), Ok(_) => panic!("Expected failure"),
}; };
@ -917,13 +916,16 @@ fn test_cancel_query() {
let t = thread::spawn(move || { let t = thread::spawn(move || {
thread::sleep(Duration::from_millis(500)); thread::sleep(Duration::from_millis(500));
assert!( assert!(
postgres::cancel_query("postgres://postgres@localhost:5433", TlsMode::None, &cancel_data) postgres::cancel_query(
.is_ok() "postgres://postgres@localhost:5433",
TlsMode::None,
&cancel_data,
).is_ok()
); );
}); });
match conn.execute("SELECT pg_sleep(10)", &[]) { match conn.execute("SELECT pg_sleep(10)", &[]) {
Err(Error::Db(ref e)) if e.code == QueryCanceled => {} Err(Error::Db(ref e)) if e.code == QUERY_CANCELED => {}
Err(res) => panic!("Unexpected result {:?}", res), Err(res) => panic!("Unexpected result {:?}", res),
_ => panic!("Unexpected result"), _ => panic!("Unexpected result"),
} }
@ -1011,7 +1013,10 @@ fn test_plaintext_pass() {
#[test] #[test]
fn test_plaintext_pass_no_pass() { fn test_plaintext_pass_no_pass() {
let ret = Connection::connect("postgres://pass_user@localhost:5433/postgres", TlsMode::None); let ret = Connection::connect(
"postgres://pass_user@localhost:5433/postgres",
TlsMode::None,
);
match ret { match ret {
Err(ConnectError::ConnectParams(..)) => (), Err(ConnectError::ConnectParams(..)) => (),
Err(err) => panic!("Unexpected error {:?}", err), Err(err) => panic!("Unexpected error {:?}", err),
@ -1026,7 +1031,7 @@ fn test_plaintext_pass_wrong_pass() {
TlsMode::None, TlsMode::None,
); );
match ret { match ret {
Err(ConnectError::Db(ref e)) if e.code == InvalidPassword => {} Err(ConnectError::Db(ref e)) if e.code == INVALID_PASSWORD => {}
Err(err) => panic!("Unexpected error {:?}", err), Err(err) => panic!("Unexpected error {:?}", err),
_ => panic!("Expected error"), _ => panic!("Expected error"),
} }
@ -1052,9 +1057,12 @@ fn test_md5_pass_no_pass() {
#[test] #[test]
fn test_md5_pass_wrong_pass() { fn test_md5_pass_wrong_pass() {
let ret = Connection::connect("postgres://md5_user:asdf@localhost:5433/postgres", TlsMode::None); let ret = Connection::connect(
"postgres://md5_user:asdf@localhost:5433/postgres",
TlsMode::None,
);
match ret { match ret {
Err(ConnectError::Db(ref e)) if e.code == InvalidPassword => {} Err(ConnectError::Db(ref e)) if e.code == INVALID_PASSWORD => {}
Err(err) => panic!("Unexpected error {:?}", err), Err(err) => panic!("Unexpected error {:?}", err),
_ => panic!("Expected error"), _ => panic!("Expected error"),
} }
@ -1070,7 +1078,10 @@ fn test_scram_pass() {
#[test] #[test]
fn test_scram_pass_no_pass() { fn test_scram_pass_no_pass() {
let ret = Connection::connect("postgres://scram_user@localhost:5433/postgres", TlsMode::None); let ret = Connection::connect(
"postgres://scram_user@localhost:5433/postgres",
TlsMode::None,
);
match ret { match ret {
Err(ConnectError::ConnectParams(..)) => (), Err(ConnectError::ConnectParams(..)) => (),
Err(err) => panic!("Unexpected error {:?}", err), Err(err) => panic!("Unexpected error {:?}", err),
@ -1080,9 +1091,12 @@ fn test_scram_pass_no_pass() {
#[test] #[test]
fn test_scram_pass_wrong_pass() { fn test_scram_pass_wrong_pass() {
let ret = Connection::connect("postgres://scram_user:asdf@localhost:5433/postgres", TlsMode::None); let ret = Connection::connect(
"postgres://scram_user:asdf@localhost:5433/postgres",
TlsMode::None,
);
match ret { match ret {
Err(ConnectError::Db(ref e)) if e.code == InvalidPassword => {} Err(ConnectError::Db(ref e)) if e.code == INVALID_PASSWORD => {}
Err(err) => panic!("Unexpected error {:?}", err), Err(err) => panic!("Unexpected error {:?}", err),
_ => panic!("Expected error"), _ => panic!("Expected error"),
} }

View File

@ -7,7 +7,8 @@ use std::fmt;
use Connection; use Connection;
#[doc(inline)] #[doc(inline)]
pub use postgres_shared::error::{DbError, ConnectError, ErrorPosition, Severity, SqlState}; // FIXME
pub use postgres_shared::error::*;
/// A runtime error. /// A runtime error.
#[derive(Debug)] #[derive(Debug)]

View File

@ -89,7 +89,7 @@ use tokio_core::reactor::Handle;
#[doc(inline)] #[doc(inline)]
pub use postgres_shared::{params, CancelData, Notification}; pub use postgres_shared::{params, CancelData, Notification};
use error::{ConnectError, Error, DbError, SqlState}; use error::{ConnectError, Error, DbError, UNDEFINED_TABLE, UNDEFINED_COLUMN};
use params::{ConnectParams, IntoConnectParams}; use params::{ConnectParams, IntoConnectParams};
use stmt::{Statement, Column}; use stmt::{Statement, Column};
use stream::PostgresStream; use stream::PostgresStream;
@ -774,7 +774,7 @@ impl Connection {
match e { match e {
// Range types weren't added until Postgres 9.2, so pg_range may not exist // Range types weren't added until Postgres 9.2, so pg_range may not exist
Error::Db(e, c) => { Error::Db(e, c) => {
if e.code != SqlState::UndefinedTable { if e.code != UNDEFINED_TABLE {
return Either::B(Err(Error::Db(e, c)).into_future()); return Either::B(Err(Error::Db(e, c)).into_future());
} }
@ -832,7 +832,7 @@ impl Connection {
ORDER BY enumsortorder", ORDER BY enumsortorder",
).or_else(|e| match e { ).or_else(|e| match e {
Error::Db(e, c) => { Error::Db(e, c) => {
if e.code != SqlState::UndefinedColumn { if e.code != UNDEFINED_COLUMN {
return Either::B(Err(Error::Db(e, c)).into_future()); return Either::B(Err(Error::Db(e, c)).into_future());
} }

View File

@ -6,7 +6,7 @@ use std::time::Duration;
use tokio_core::reactor::{Core, Interval}; use tokio_core::reactor::{Core, Interval};
use super::*; use super::*;
use error::{Error, ConnectError, SqlState}; use error::{Error, ConnectError, INVALID_PASSWORD, INVALID_AUTHORIZATION_SPECIFICATION, QUERY_CANCELED};
use params::{ConnectParams, Host}; use params::{ConnectParams, Host};
use types::{ToSql, FromSql, Type, IsNull, Kind}; use types::{ToSql, FromSql, Type, IsNull, Kind};
@ -48,7 +48,7 @@ fn md5_user_wrong_pass() {
&handle, &handle,
); );
match l.run(done) { match l.run(done) {
Err(ConnectError::Db(ref e)) if e.code == SqlState::InvalidPassword => {} Err(ConnectError::Db(ref e)) if e.code == INVALID_PASSWORD => {}
Err(e) => panic!("unexpected error {}", e), Err(e) => panic!("unexpected error {}", e),
Ok(_) => panic!("unexpected success"), Ok(_) => panic!("unexpected success"),
} }
@ -92,7 +92,7 @@ fn pass_user_wrong_pass() {
&handle, &handle,
); );
match l.run(done) { match l.run(done) {
Err(ConnectError::Db(ref e)) if e.code == SqlState::InvalidPassword => {} Err(ConnectError::Db(ref e)) if e.code == INVALID_PASSWORD => {}
Err(e) => panic!("unexpected error {}", e), Err(e) => panic!("unexpected error {}", e),
Ok(_) => panic!("unexpected success"), Ok(_) => panic!("unexpected success"),
} }
@ -123,7 +123,7 @@ fn batch_execute_err() {
.and_then(|c| c.batch_execute("SELECT * FROM bogo")) .and_then(|c| c.batch_execute("SELECT * FROM bogo"))
.then(|r| match r { .then(|r| match r {
Err(Error::Db(e, s)) => { Err(Error::Db(e, s)) => {
assert!(e.code == SqlState::UndefinedTable); assert!(e.code == UNDEFINED_TABLE);
s.batch_execute("SELECT * FROM foo") s.batch_execute("SELECT * FROM foo")
} }
Err(e) => panic!("unexpected error: {}", e), Err(e) => panic!("unexpected error: {}", e),
@ -249,7 +249,7 @@ fn ssl_user_ssl_required() {
); );
match l.run(done) { match l.run(done) {
Err(ConnectError::Db(e)) => assert!(e.code == SqlState::InvalidAuthorizationSpecification), Err(ConnectError::Db(e)) => assert!(e.code == INVALID_AUTHORIZATION_SPECIFICATION),
Err(e) => panic!("unexpected error {}", e), Err(e) => panic!("unexpected error {}", e),
Ok(_) => panic!("unexpected success"), Ok(_) => panic!("unexpected success"),
} }
@ -437,7 +437,7 @@ fn cancel() {
let (select, cancel) = l.run(done).unwrap(); let (select, cancel) = l.run(done).unwrap();
cancel.unwrap(); cancel.unwrap();
match select { match select {
Err(Error::Db(e, _)) => assert_eq!(e.code, SqlState::QueryCanceled), Err(Error::Db(e, _)) => assert_eq!(e.code, QUERY_CANCELED),
Err(e) => panic!("unexpected error {}", e), Err(e) => panic!("unexpected error {}", e),
Ok(_) => panic!("unexpected success"), Ok(_) => panic!("unexpected success"),
} }