This commit is contained in:
Steven Fackler 2016-09-12 21:48:49 -07:00
parent bcb104793b
commit fbcdd6b547
12 changed files with 119 additions and 196 deletions

View File

@ -207,7 +207,8 @@ impl error::Error for ConnectError {
fn cause(&self) -> Option<&error::Error> { fn cause(&self) -> Option<&error::Error> {
match *self { match *self {
ConnectError::ConnectParams(ref err) | ConnectError::Ssl(ref err) => Some(&**err), ConnectError::ConnectParams(ref err) |
ConnectError::Ssl(ref err) => Some(&**err),
ConnectError::Db(ref err) => Some(&**err), ConnectError::Db(ref err) => Some(&**err),
ConnectError::Io(ref err) => Some(err), ConnectError::Io(ref err) => Some(err),
} }

View File

@ -245,7 +245,8 @@ impl InnerConnection {
let user = match user { let user = match user {
Some(user) => user, Some(user) => user,
None => { None => {
return Err(ConnectError::ConnectParams("User missing from connection parameters".into())); return Err(ConnectError::ConnectParams("User missing from connection parameters"
.into()));
} }
}; };
@ -337,8 +338,7 @@ impl InnerConnection {
} }
} }
fn read_message_with_notification_nonblocking(&mut self) fn read_message_with_notification_nonblocking(&mut self) -> std::io::Result<Option<Backend>> {
-> std::io::Result<Option<Backend>> {
debug_assert!(!self.desynchronized); debug_assert!(!self.desynchronized);
loop { loop {
match try_desync!(self, self.stream.read_message_nonblocking()) { match try_desync!(self, self.stream.read_message_nonblocking()) {
@ -471,7 +471,8 @@ impl InnerConnection {
let more_rows; let more_rows;
loop { loop {
match try!(self.read_message()) { match try!(self.read_message()) {
Backend::EmptyQueryResponse | Backend::CommandComplete { .. } => { Backend::EmptyQueryResponse |
Backend::CommandComplete { .. } => {
more_rows = false; more_rows = false;
break; break;
} }
@ -813,9 +814,7 @@ impl InnerConnection {
Backend::ReadyForQuery { .. } => break, Backend::ReadyForQuery { .. } => break,
Backend::DataRow { row } => { Backend::DataRow { row } => {
result.push(row.into_iter() result.push(row.into_iter()
.map(|opt| { .map(|opt| opt.map(|b| String::from_utf8_lossy(&b).into_owned()))
opt.map(|b| String::from_utf8_lossy(&b).into_owned())
})
.collect()); .collect());
} }
Backend::CopyInResponse { .. } => { Backend::CopyInResponse { .. } => {

View File

@ -8,24 +8,15 @@ pub enum Backend {
AuthenticationCleartextPassword, AuthenticationCleartextPassword,
AuthenticationGSS, AuthenticationGSS,
AuthenticationKerberosV5, AuthenticationKerberosV5,
AuthenticationMD5Password { AuthenticationMD5Password { salt: [u8; 4] },
salt: [u8; 4],
},
AuthenticationOk, AuthenticationOk,
AuthenticationSCMCredential, AuthenticationSCMCredential,
AuthenticationSSPI, AuthenticationSSPI,
BackendKeyData { BackendKeyData { process_id: i32, secret_key: i32 },
process_id: i32,
secret_key: i32,
},
BindComplete, BindComplete,
CloseComplete, CloseComplete,
CommandComplete { CommandComplete { tag: String },
tag: String, CopyData { data: Vec<u8> },
},
CopyData {
data: Vec<u8>,
},
CopyDone, CopyDone,
CopyInResponse { CopyInResponse {
format: u8, format: u8,
@ -35,37 +26,22 @@ pub enum Backend {
format: u8, format: u8,
column_formats: Vec<u16>, column_formats: Vec<u16>,
}, },
DataRow { DataRow { row: Vec<Option<Vec<u8>>> },
row: Vec<Option<Vec<u8>>>,
},
EmptyQueryResponse, EmptyQueryResponse,
ErrorResponse { ErrorResponse { fields: Vec<(u8, String)> },
fields: Vec<(u8, String)>,
},
NoData, NoData,
NoticeResponse { NoticeResponse { fields: Vec<(u8, String)> },
fields: Vec<(u8, String)>,
},
NotificationResponse { NotificationResponse {
process_id: i32, process_id: i32,
channel: String, channel: String,
payload: String, payload: String,
}, },
ParameterDescription { ParameterDescription { types: Vec<Oid> },
types: Vec<Oid>, ParameterStatus { parameter: String, value: String },
},
ParameterStatus {
parameter: String,
value: String,
},
ParseComplete, ParseComplete,
PortalSuspended, PortalSuspended,
ReadyForQuery { ReadyForQuery { _state: u8 },
_state: u8, RowDescription { descriptions: Vec<RowDescriptionEntry>, },
},
RowDescription {
descriptions: Vec<RowDescriptionEntry>,
},
} }
impl Backend { impl Backend {
@ -89,9 +65,7 @@ impl Backend {
Message::BindComplete => Backend::BindComplete, Message::BindComplete => Backend::BindComplete,
Message::CloseComplete => Backend::CloseComplete, Message::CloseComplete => Backend::CloseComplete,
Message::CommandComplete(body) => { Message::CommandComplete(body) => {
Backend::CommandComplete { Backend::CommandComplete { tag: body.tag().to_owned() }
tag: body.tag().to_owned()
}
} }
Message::CopyData(body) => Backend::CopyData { data: body.data().to_owned() }, Message::CopyData(body) => Backend::CopyData { data: body.data().to_owned() },
Message::CopyDone => Backend::CopyDone, Message::CopyDone => Backend::CopyDone,
@ -115,13 +89,17 @@ impl Backend {
Message::EmptyQueryResponse => Backend::EmptyQueryResponse, Message::EmptyQueryResponse => Backend::EmptyQueryResponse,
Message::ErrorResponse(body) => { Message::ErrorResponse(body) => {
Backend::ErrorResponse { Backend::ErrorResponse {
fields: try!(body.fields().map(|f| (f.type_(), f.value().to_owned())).collect()), fields: try!(body.fields()
.map(|f| (f.type_(), f.value().to_owned()))
.collect()),
} }
} }
Message::NoData => Backend::NoData, Message::NoData => Backend::NoData,
Message::NoticeResponse(body) => { Message::NoticeResponse(body) => {
Backend::NoticeResponse { Backend::NoticeResponse {
fields: try!(body.fields().map(|f| (f.type_(), f.value().to_owned())).collect()), fields: try!(body.fields()
.map(|f| (f.type_(), f.value().to_owned()))
.collect()),
} }
} }
Message::NotificationResponse(body) => { Message::NotificationResponse(body) => {
@ -132,9 +110,7 @@ impl Backend {
} }
} }
Message::ParameterDescription(body) => { Message::ParameterDescription(body) => {
Backend::ParameterDescription { Backend::ParameterDescription { types: try!(body.parameters().collect()) }
types: try!(body.parameters().collect()),
}
} }
Message::ParameterStatus(body) => { Message::ParameterStatus(body) => {
Backend::ParameterStatus { Backend::ParameterStatus {
@ -158,9 +134,7 @@ impl Backend {
format: f.format(), format: f.format(),
} }
}); });
Backend::RowDescription { Backend::RowDescription { descriptions: try!(fields.collect()) }
descriptions: try!(fields.collect()),
}
} }
_ => return Err(io::Error::new(io::ErrorKind::InvalidInput, "unknown message type")), _ => return Err(io::Error::new(io::ErrorKind::InvalidInput, "unknown message type")),
}; };

View File

@ -85,9 +85,8 @@ impl MessageStream {
match b { match b {
Ok(b) => self.inner_read_message(b).map(Some), Ok(b) => self.inner_read_message(b).map(Some),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::TimedOut => { Err(ref e) if e.kind() == io::ErrorKind::WouldBlock ||
Ok(None) e.kind() == io::ErrorKind::TimedOut => Ok(None),
}
Err(e) => Err(e), Err(e) => Err(e),
} }
} }

View File

@ -397,11 +397,7 @@ impl<'trans, 'stmt> FallibleIterator for LazyRows<'trans, 'stmt> {
fn size_hint(&self) -> (usize, Option<usize>) { fn size_hint(&self) -> (usize, Option<usize>) {
let lower = self.data.len(); let lower = self.data.len();
let upper = if self.more_rows { let upper = if self.more_rows { None } else { Some(lower) };
None
} else {
Some(lower)
};
(lower, upper) (lower, upper)
} }
} }

View File

@ -12,8 +12,8 @@ use types::{SessionInfo, Type, ToSql};
use message::Backend; use message::Backend;
use rows::{Rows, LazyRows}; use rows::{Rows, LazyRows};
use transaction::Transaction; use transaction::Transaction;
use {bad_response, Connection, StatementInternals, Result, RowsNew, InnerConnection, SessionInfoNew, use {bad_response, Connection, StatementInternals, Result, RowsNew, InnerConnection,
LazyRowsNew, DbErrorNew, ColumnNew, StatementInfo, TransactionInternals}; SessionInfoNew, LazyRowsNew, DbErrorNew, ColumnNew, StatementInfo, TransactionInternals};
/// A prepared statement. /// A prepared statement.
pub struct Statement<'conn> { pub struct Statement<'conn> {
@ -417,7 +417,8 @@ impl<'conn> Statement<'conn> {
Ok(n) => data = &data[n..], Ok(n) => data = &data[n..],
Err(e) => { Err(e) => {
loop { loop {
if let Backend::ReadyForQuery { .. } = try!(info.conn.read_message()) { if let Backend::ReadyForQuery { .. } =
try!(info.conn.read_message()) {
return Err(Error::Io(e)); return Err(Error::Io(e));
} }
} }

View File

@ -4,8 +4,7 @@ use std::cell::Cell;
use std::fmt; use std::fmt;
use std::ascii::AsciiExt; use std::ascii::AsciiExt;
use {bad_response, Result, Connection, TransactionInternals, ConfigInternals, use {bad_response, Result, Connection, TransactionInternals, ConfigInternals, IsolationLevelNew};
IsolationLevelNew};
use error::Error; use error::Error;
use rows::Rows; use rows::Rows;
use stmt::Statement; use stmt::Statement;
@ -199,15 +198,13 @@ impl<'conn> Transaction<'conn> {
debug_assert!(self.depth == conn.trans_depth); debug_assert!(self.depth == conn.trans_depth);
conn.trans_depth -= 1; conn.trans_depth -= 1;
match (self.commit.get(), &self.savepoint_name) { match (self.commit.get(), &self.savepoint_name) {
(false, &Some(ref savepoint_name)) => { (false, &Some(ref sp)) => try!(conn.quick_query(&format!("ROLLBACK TO {}", sp))),
conn.quick_query(&format!("ROLLBACK TO {}", savepoint_name)) (false, &None) => try!(conn.quick_query("ROLLBACK")),
} (true, &Some(ref sp)) => try!(conn.quick_query(&format!("RELEASE {}", sp))),
(false, &None) => conn.quick_query("ROLLBACK"), (true, &None) => try!(conn.quick_query("COMMIT")),
(true, &Some(ref savepoint_name)) => { };
conn.quick_query(&format!("RELEASE {}", savepoint_name))
} Ok(())
(true, &None) => conn.quick_query("COMMIT"),
}.map(|_| ())
} }
/// Like `Connection::prepare`. /// Like `Connection::prepare`.

View File

@ -634,7 +634,8 @@ impl ToSql for bool {
impl<'a, T: ToSql> ToSql for &'a [T] { impl<'a, T: ToSql> ToSql for &'a [T] {
to_sql_checked!(); to_sql_checked!();
fn to_sql<W: Write + ?Sized>(&self, ty: &Type, fn to_sql<W: Write + ?Sized>(&self,
ty: &Type,
mut w: &mut W, mut w: &mut W,
ctx: &SessionInfo) ctx: &SessionInfo)
-> Result<IsNull> { -> Result<IsNull> {

View File

@ -39,7 +39,11 @@ impl<T: FromSql> FromSql for Date<T> {
} }
} }
impl<T: ToSql> ToSql for Date<T> { impl<T: ToSql> ToSql for Date<T> {
fn to_sql<W: Write+?Sized>(&self, ty: &Type, out: &mut W, ctx: &SessionInfo) -> Result<IsNull> { fn to_sql<W: Write + ?Sized>(&self,
ty: &Type,
out: &mut W,
ctx: &SessionInfo)
-> Result<IsNull> {
if *ty != Type::Date { if *ty != Type::Date {
return Err(Error::Conversion("expected date type".into())); return Err(Error::Conversion("expected date type".into()));
} }
@ -95,7 +99,11 @@ impl<T: FromSql> FromSql for Timestamp<T> {
} }
impl<T: ToSql> ToSql for Timestamp<T> { impl<T: ToSql> ToSql for Timestamp<T> {
fn to_sql<W: Write+?Sized>(&self, ty: &Type, out: &mut W, ctx: &SessionInfo) -> Result<IsNull> { fn to_sql<W: Write + ?Sized>(&self,
ty: &Type,
out: &mut W,
ctx: &SessionInfo)
-> Result<IsNull> {
if *ty != Type::Timestamp && *ty != Type::Timestamptz { if *ty != Type::Timestamp && *ty != Type::Timestamptz {
return Err(Error::Conversion("expected timestamp or timestamptz type".into())); return Err(Error::Conversion("expected timestamp or timestamptz type".into()));
} }

View File

@ -126,7 +126,8 @@ fn decode_inner(c: &str, full_url: bool) -> DecodeResult<String> {
(Some(one), Some(two)) => [one, two], (Some(one), Some(two)) => [one, two],
_ => { _ => {
return Err("Malformed input: found '%' without two \ return Err("Malformed input: found '%' without two \
trailing bytes".to_owned()) trailing bytes"
.to_owned())
} }
}; };
@ -135,31 +136,16 @@ fn decode_inner(c: &str, full_url: bool) -> DecodeResult<String> {
_ => { _ => {
return Err("Malformed input: found '%' followed by \ return Err("Malformed input: found '%' followed by \
invalid hex values. Character '%' must \ invalid hex values. Character '%' must \
escaped.".to_owned()) escaped."
.to_owned())
} }
}; };
// Only decode some characters if full_url: // Only decode some characters if full_url:
match bytes_from_hex[0] as char { match bytes_from_hex[0] as char {
// gen-delims: // gen-delims:
':' | ':' | '/' | '?' | '#' | '[' | ']' | '@' | '!' | '$' | '&' | '"' |
'/' | '(' | ')' | '*' | '+' | ',' | ';' | '=' if full_url => {
'?' |
'#' |
'[' |
']' |
'@' |
'!' |
'$' |
'&' |
'"' |
'(' |
')' |
'*' |
'+' |
',' |
';' |
'=' if full_url => {
out.push('%'); out.push('%');
out.push(bytes[0] as char); out.push(bytes[0] as char);
out.push(bytes[1] as char); out.push(bytes[1] as char);
@ -201,8 +187,7 @@ fn query_from_str(rawquery: &str) -> DecodeResult<Query> {
pub fn get_scheme(rawurl: &str) -> DecodeResult<(&str, &str)> { pub fn get_scheme(rawurl: &str) -> DecodeResult<(&str, &str)> {
for (i, c) in rawurl.chars().enumerate() { for (i, c) in rawurl.chars().enumerate() {
let result = match c { let result = match c {
'A'...'Z' | 'A'...'Z' | 'a'...'z' => continue,
'a'...'z' => continue,
'0'...'9' | '+' | '-' | '.' => { '0'...'9' | '+' | '-' | '.' => {
if i != 0 { if i != 0 {
continue; continue;
@ -268,29 +253,13 @@ fn get_authority(rawurl: &str) -> DecodeResult<(Option<UserInfo>, &str, Option<u
// deal with input class first // deal with input class first
match c { match c {
'0'...'9' => (), '0'...'9' => (),
'A'...'F' | 'A'...'F' | 'a'...'f' => {
'a'...'f' => {
if input == Input::Digit { if input == Input::Digit {
input = Input::Hex; input = Input::Hex;
} }
} }
'G'...'Z' | 'G'...'Z' | 'g'...'z' | '-' | '.' | '_' | '~' | '%' | '&' | '\'' | '(' | ')' |
'g'...'z' | '+' | '!' | '*' | ',' | ';' | '=' => input = Input::Unreserved,
'-' |
'.' |
'_' |
'~' |
'%' |
'&' |
'\'' |
'(' |
')' |
'+' |
'!' |
'*' |
',' |
';' |
'=' => input = Input::Unreserved,
':' | '@' | '?' | '#' | '/' => { ':' | '@' | '?' | '#' | '/' => {
// separators, don't change anything // separators, don't change anything
} }
@ -372,17 +341,14 @@ fn get_authority(rawurl: &str) -> DecodeResult<(Option<UserInfo>, &str, Option<u
// finish up // finish up
match st { match st {
State::PassHostPort | State::PassHostPort | State::Ip6Port => {
State::Ip6Port => {
if input != Input::Digit { if input != Input::Digit {
return Err("Non-digit characters in port.".to_owned()); return Err("Non-digit characters in port.".to_owned());
} }
host = &rawurl[begin..pos]; host = &rawurl[begin..pos];
port = Some(&rawurl[pos + 1..end]); port = Some(&rawurl[pos + 1..end]);
} }
State::Ip6Host | State::Ip6Host | State::InHost | State::Start => host = &rawurl[begin..end],
State::InHost |
State::Start => host = &rawurl[begin..end],
State::InPort => { State::InPort => {
if input != Input::Digit { if input != Input::Digit {
return Err("Non-digit characters in port.".to_owned()); return Err("Non-digit characters in port.".to_owned());
@ -413,27 +379,8 @@ fn get_path(rawurl: &str, is_authority: bool) -> DecodeResult<(String, &str)> {
let mut end = len; let mut end = len;
for (i, c) in rawurl.chars().enumerate() { for (i, c) in rawurl.chars().enumerate() {
match c { match c {
'A'...'Z' | 'A'...'Z' | 'a'...'z' | '0'...'9' | '&' | '\'' | '(' | ')' | '.' | '@' | ':' |
'a'...'z' | '%' | '/' | '+' | '!' | '*' | ',' | ';' | '=' | '_' | '-' | '~' => continue,
'0'...'9' |
'&' |
'\'' |
'(' |
')' |
'.' |
'@' |
':' |
'%' |
'/' |
'+' |
'!' |
'*' |
',' |
';' |
'=' |
'_' |
'-' |
'~' => continue,
'?' | '#' => { '?' | '#' => {
end = i; end = i;
break; break;