Update to protocol 0.2

This commit is contained in:
Steven Fackler 2016-12-18 16:11:38 -08:00
parent 8e5f28eb70
commit 9d53e677ff
9 changed files with 231 additions and 158 deletions

5
.gitignore vendored
View File

@ -1,5 +1,6 @@
target/ target
Cargo.lock Cargo.lock
.cargo/ .cargo
.idea .idea
*.iml *.iml
.vscode

View File

@ -41,7 +41,7 @@ fallible-iterator = "0.1.3"
hex = "0.2" hex = "0.2"
log = "0.3" log = "0.3"
phf = "=0.7.20" phf = "=0.7.20"
postgres-protocol = "0.1" postgres-protocol = "0.2"
bit-vec = { version = "0.4", optional = true } bit-vec = { version = "0.4", optional = true }
chrono = { version = "0.2.14", optional = true } chrono = { version = "0.2.14", optional = true }
eui48 = { version = "0.1", optional = true } eui48 = { version = "0.1", optional = true }

View File

@ -120,10 +120,10 @@ fn make_impl(codes: &[Code], file: &mut BufWriter<File>) {
write!(file, r#" write!(file, r#"
impl SqlState {{ impl SqlState {{
/// Creates a `SqlState` from its error code. /// Creates a `SqlState` from its error code.
pub fn from_code(s: String) -> SqlState {{ pub fn from_code(s: &str) -> SqlState {{
match SQLSTATE_MAP.get(&*s) {{ match SQLSTATE_MAP.get(s) {{
Some(state) => state.clone(), Some(state) => state.clone(),
None => SqlState::Other(s), None => SqlState::Other(s.to_owned()),
}} }}
}} }}

View File

@ -1,11 +1,12 @@
//! Error types. //! Error types.
use fallible_iterator::FallibleIterator;
use postgres_protocol::message::backend::ErrorFields;
use std::error; use std::error;
use std::convert::From; use std::convert::From;
use std::fmt; use std::fmt;
use std::io; use std::io;
use std::result; use std::result;
use std::collections::HashMap;
pub use self::sqlstate::SqlState; pub use self::sqlstate::SqlState;
use {Result, DbErrorNew}; use {Result, DbErrorNew};
@ -85,52 +86,92 @@ pub struct DbError {
} }
impl DbErrorNew for DbError { impl DbErrorNew for DbError {
fn new_raw(fields: Vec<(u8, String)>) -> result::Result<DbError, ()> { fn new_raw(fields: &mut ErrorFields) -> io::Result<DbError> {
let mut map: HashMap<_, _> = fields.into_iter().collect(); let mut severity = None;
let mut code = None;
let mut message = None;
let mut detail = None;
let mut hint = None;
let mut normal_position = None;
let mut internal_position = None;
let mut internal_query = None;
let mut where_ = None;
let mut schema = None;
let mut table = None;
let mut column = None;
let mut datatype = None;
let mut constraint = None;
let mut file = None;
let mut line = None;
let mut routine = None;
while let Some(field) = try!(fields.next()) {
match field.type_() {
b'S' => severity = Some(field.value().to_owned()),
b'C' => code = Some(SqlState::from_code(field.value())),
b'M' => message = Some(field.value().to_owned()),
b'D' => detail = Some(field.value().to_owned()),
b'H' => hint = Some(field.value().to_owned()),
b'P' => normal_position = Some(try!(field.value().parse::<u32>().map_err(|_| ::bad_response()))),
b'p' => internal_position = Some(try!(field.value().parse::<u32>().map_err(|_| ::bad_response()))),
b'q' => internal_query = Some(field.value().to_owned()),
b'W' => where_ = Some(field.value().to_owned()),
b's' => schema = Some(field.value().to_owned()),
b't' => table = Some(field.value().to_owned()),
b'c' => column = Some(field.value().to_owned()),
b'd' => datatype = Some(field.value().to_owned()),
b'n' => constraint = Some(field.value().to_owned()),
b'F' => file = Some(field.value().to_owned()),
b'L' => line = Some(try!(field.value().parse::<u32>().map_err(|_| ::bad_response()))),
b'R' => routine = Some(field.value().to_owned()),
_ => {},
}
}
Ok(DbError { Ok(DbError {
severity: try!(map.remove(&b'S').ok_or(())), severity: try!(severity.ok_or_else(|| ::bad_response())),
code: SqlState::from_code(try!(map.remove(&b'C').ok_or(()))), code: try!(code.ok_or_else(|| ::bad_response())),
message: try!(map.remove(&b'M').ok_or(())), message: try!(message.ok_or_else(|| ::bad_response())),
detail: map.remove(&b'D'), detail: detail,
hint: map.remove(&b'H'), hint: hint,
position: match map.remove(&b'P') { position: match normal_position {
Some(pos) => Some(ErrorPosition::Normal(try!(pos.parse().map_err(|_| ())))), Some(position) => Some(ErrorPosition::Normal(position)),
None => { None => {
match map.remove(&b'p') { match internal_position {
Some(pos) => { Some(position) => {
Some(ErrorPosition::Internal { Some(ErrorPosition::Internal {
position: try!(pos.parse().map_err(|_| ())), position: position,
query: try!(map.remove(&b'q').ok_or(())), query: try!(internal_query.ok_or_else(|| ::bad_response())),
}) })
} }
None => None, None => None,
} }
} }
}, },
where_: map.remove(&b'W'), where_: where_,
schema: map.remove(&b's'), schema: schema,
table: map.remove(&b't'), table: table,
column: map.remove(&b'c'), column: column,
datatype: map.remove(&b'd'), datatype: datatype,
constraint: map.remove(&b'n'), constraint: constraint,
file: map.remove(&b'F'), file: file,
line: map.remove(&b'L').and_then(|l| l.parse().ok()), line: line,
routine: map.remove(&b'R'), routine: routine,
_p: (), _p: (),
}) })
} }
fn new_connect<T>(fields: Vec<(u8, String)>) -> result::Result<T, ConnectError> { fn new_connect<T>(fields: &mut ErrorFields) -> result::Result<T, ConnectError> {
match DbError::new_raw(fields) { match DbError::new_raw(fields) {
Ok(err) => Err(ConnectError::Db(Box::new(err))), Ok(err) => Err(ConnectError::Db(Box::new(err))),
Err(()) => Err(ConnectError::Io(::bad_response())), Err(e) => Err(ConnectError::Io(e)),
} }
} }
fn new<T>(fields: Vec<(u8, String)>) -> Result<T> { fn new<T>(fields: &mut ErrorFields) -> Result<T> {
match DbError::new_raw(fields) { match DbError::new_raw(fields) {
Ok(err) => Err(Error::Db(Box::new(err))), Ok(err) => Err(Error::Db(Box::new(err))),
Err(()) => Err(Error::Io(::bad_response())), Err(e) => Err(Error::Io(e)),
} }
} }
} }

View File

@ -782,10 +782,10 @@ static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = ::phf::Map {
impl SqlState { impl SqlState {
/// Creates a `SqlState` from its error code. /// Creates a `SqlState` from its error code.
pub fn from_code(s: String) -> SqlState { pub fn from_code(s: &str) -> SqlState {
match SQLSTATE_MAP.get(&*s) { match SQLSTATE_MAP.get(s) {
Some(state) => state.clone(), Some(state) => state.clone(),
None => SqlState::Other(s), None => SqlState::Other(s.to_owned()),
} }
} }

View File

@ -79,6 +79,7 @@ extern crate log;
extern crate phf; extern crate phf;
extern crate postgres_protocol; extern crate postgres_protocol;
use fallible_iterator::FallibleIterator;
use std::cell::{Cell, RefCell}; use std::cell::{Cell, RefCell};
use std::collections::{VecDeque, HashMap}; use std::collections::{VecDeque, HashMap};
use std::fmt; use std::fmt;
@ -88,7 +89,7 @@ use std::result;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use postgres_protocol::authentication; use postgres_protocol::authentication;
use postgres_protocol::message::backend::{self, RowDescriptionEntry}; use postgres_protocol::message::backend::{self, ErrorFields};
use postgres_protocol::message::frontend; use postgres_protocol::message::frontend;
use error::{Error, ConnectError, SqlState, DbError}; use error::{Error, ConnectError, SqlState, DbError};
@ -307,12 +308,14 @@ impl InnerConnection {
loop { loop {
match try!(conn.read_message()) { match try!(conn.read_message()) {
backend::Message::BackendKeyData { process_id, secret_key } => { backend::Message::BackendKeyData(body) => {
conn.cancel_data.process_id = process_id; conn.cancel_data.process_id = body.process_id();
conn.cancel_data.secret_key = secret_key; conn.cancel_data.secret_key = body.secret_key();
}
backend::Message::ReadyForQuery(_) => break,
backend::Message::ErrorResponse(body) => {
return DbError::new_connect(&mut body.fields())
} }
backend::Message::ReadyForQuery { .. } => break,
backend::Message::ErrorResponse { fields } => return DbError::new_connect(fields),
_ => return Err(ConnectError::Io(bad_response())), _ => return Err(ConnectError::Io(bad_response())),
} }
} }
@ -320,17 +323,18 @@ impl InnerConnection {
Ok(conn) Ok(conn)
} }
fn read_message_with_notification(&mut self) -> io::Result<backend::Message> { fn read_message_with_notification(&mut self) -> io::Result<backend::Message<Vec<u8>>> {
debug_assert!(!self.desynchronized); debug_assert!(!self.desynchronized);
loop { loop {
match try_desync!(self, self.stream.read_message()) { match try_desync!(self, self.stream.read_message()) {
backend::Message::NoticeResponse { fields } => { backend::Message::NoticeResponse(body) => {
if let Ok(err) = DbError::new_raw(fields) { if let Ok(err) = DbError::new_raw(&mut body.fields()) {
self.notice_handler.handle_notice(err); self.notice_handler.handle_notice(err);
} }
} }
backend::Message::ParameterStatus { parameter, value } => { backend::Message::ParameterStatus(body) => {
self.parameters.insert(parameter, value); self.parameters.insert(try!(body.name()).to_owned(),
try!(body.value()).to_owned());
} }
val => return Ok(val), val => return Ok(val),
} }
@ -339,17 +343,18 @@ impl InnerConnection {
fn read_message_with_notification_timeout(&mut self, fn read_message_with_notification_timeout(&mut self,
timeout: Duration) timeout: Duration)
-> io::Result<Option<backend::Message>> { -> io::Result<Option<backend::Message<Vec<u8>>>> {
debug_assert!(!self.desynchronized); debug_assert!(!self.desynchronized);
loop { loop {
match try_desync!(self, self.stream.read_message_timeout(timeout)) { match try_desync!(self, self.stream.read_message_timeout(timeout)) {
Some(backend::Message::NoticeResponse { fields }) => { Some(backend::Message::NoticeResponse(body)) => {
if let Ok(err) = DbError::new_raw(fields) { if let Ok(err) = DbError::new_raw(&mut body.fields()) {
self.notice_handler.handle_notice(err); self.notice_handler.handle_notice(err);
} }
} }
Some(backend::Message::ParameterStatus { parameter, value }) => { Some(backend::Message::ParameterStatus(body)) => {
self.parameters.insert(parameter, value); self.parameters.insert(try!(body.name()).to_owned(),
try!(body.value()).to_owned());
} }
val => return Ok(val), val => return Ok(val),
} }
@ -357,31 +362,32 @@ impl InnerConnection {
} }
fn read_message_with_notification_nonblocking(&mut self) fn read_message_with_notification_nonblocking(&mut self)
-> io::Result<Option<backend::Message>> { -> io::Result<Option<backend::Message<Vec<u8>>>> {
debug_assert!(!self.desynchronized); debug_assert!(!self.desynchronized);
loop { loop {
match try_desync!(self, self.stream.read_message_nonblocking()) { match try_desync!(self, self.stream.read_message_nonblocking()) {
Some(backend::Message::NoticeResponse { fields }) => { Some(backend::Message::NoticeResponse(body)) => {
if let Ok(err) = DbError::new_raw(fields) { if let Ok(err) = DbError::new_raw(&mut body.fields()) {
self.notice_handler.handle_notice(err); self.notice_handler.handle_notice(err);
} }
} }
Some(backend::Message::ParameterStatus { parameter, value }) => { Some(backend::Message::ParameterStatus(body)) => {
self.parameters.insert(parameter, value); self.parameters.insert(try!(body.name()).to_owned(),
try!(body.value()).to_owned());
} }
val => return Ok(val), val => return Ok(val),
} }
} }
} }
fn read_message(&mut self) -> io::Result<backend::Message> { fn read_message(&mut self) -> io::Result<backend::Message<Vec<u8>>> {
loop { loop {
match try!(self.read_message_with_notification()) { match try!(self.read_message_with_notification()) {
backend::Message::NotificationResponse { process_id, channel, payload } => { backend::Message::NotificationResponse(body) => {
self.notifications.push_back(Notification { self.notifications.push_back(Notification {
process_id: process_id, process_id: body.process_id(),
channel: channel, channel: try!(body.channel()).to_owned(),
payload: payload, payload: try!(body.message()).to_owned(),
}) })
} }
val => return Ok(val), val => return Ok(val),
@ -399,28 +405,30 @@ impl InnerConnection {
try!(self.stream.write_message(|buf| frontend::password_message(&pass, buf))); try!(self.stream.write_message(|buf| frontend::password_message(&pass, buf)));
try!(self.stream.flush()); try!(self.stream.flush());
} }
backend::Message::AuthenticationMD5Password { salt } => { backend::Message::AuthenticationMd5Password(body) => {
let pass = try!(user.password.ok_or_else(|| { let pass = try!(user.password.ok_or_else(|| {
ConnectError::ConnectParams("a password was requested but not provided".into()) ConnectError::ConnectParams("a password was requested but not provided".into())
})); }));
let output = authentication::md5_hash(user.user.as_bytes(), pass.as_bytes(), salt); let output = authentication::md5_hash(user.user.as_bytes(),
pass.as_bytes(),
body.salt());
try!(self.stream.write_message(|buf| frontend::password_message(&output, buf))); try!(self.stream.write_message(|buf| frontend::password_message(&output, buf)));
try!(self.stream.flush()); try!(self.stream.flush());
} }
backend::Message::AuthenticationKerberosV5 | backend::Message::AuthenticationKerberosV5 |
backend::Message::AuthenticationSCMCredential | backend::Message::AuthenticationScmCredential |
backend::Message::AuthenticationGSS | backend::Message::AuthenticationGss |
backend::Message::AuthenticationSSPI => { backend::Message::AuthenticationSspi => {
return Err(ConnectError::Io(io::Error::new(io::ErrorKind::Other, return Err(ConnectError::Io(io::Error::new(io::ErrorKind::Other,
"unsupported authentication"))) "unsupported authentication")))
} }
backend::Message::ErrorResponse { fields } => return DbError::new_connect(fields), backend::Message::ErrorResponse(body) => return DbError::new_connect(&mut body.fields()),
_ => return Err(ConnectError::Io(bad_response())), _ => return Err(ConnectError::Io(bad_response())),
} }
match try!(self.read_message()) { match try!(self.read_message()) {
backend::Message::AuthenticationOk => Ok(()), backend::Message::AuthenticationOk => Ok(()),
backend::Message::ErrorResponse { fields } => DbError::new_connect(fields), backend::Message::ErrorResponse(body) => DbError::new_connect(&mut body.fields()),
_ => Err(ConnectError::Io(bad_response())), _ => Err(ConnectError::Io(bad_response())),
} }
} }
@ -439,35 +447,43 @@ impl InnerConnection {
match try!(self.read_message()) { match try!(self.read_message()) {
backend::Message::ParseComplete => {} backend::Message::ParseComplete => {}
backend::Message::ErrorResponse { fields } => { backend::Message::ErrorResponse(body) => {
try!(self.wait_for_ready()); try!(self.wait_for_ready());
return DbError::new(fields); return DbError::new(&mut body.fields());
} }
_ => bad_response!(self), _ => bad_response!(self),
} }
let raw_param_types = match try!(self.read_message()) { let raw_param_types = match try!(self.read_message()) {
backend::Message::ParameterDescription { types } => types, backend::Message::ParameterDescription(body) => body,
_ => bad_response!(self), _ => bad_response!(self),
}; };
let raw_columns = match try!(self.read_message()) { let raw_columns = match try!(self.read_message()) {
backend::Message::RowDescription { descriptions } => descriptions, backend::Message::RowDescription(body) => Some(body),
backend::Message::NoData => vec![], backend::Message::NoData => None,
_ => bad_response!(self), _ => bad_response!(self),
}; };
try!(self.wait_for_ready()); try!(self.wait_for_ready());
let mut param_types = vec![]; let param_types = try!(raw_param_types
for oid in raw_param_types { .parameters()
param_types.push(try!(self.get_type(oid))); .map_err(Into::into)
} .and_then(|oid| self.get_type(oid))
.collect());
let mut columns = vec![]; let columns = match raw_columns {
for RowDescriptionEntry { name, type_oid, .. } in raw_columns { Some(body) => {
columns.push(Column::new(name, try!(self.get_type(type_oid)))); try!(body.fields()
} .and_then(|field| {
Ok(Column::new(field.name().to_owned(),
try!(self.get_type(field.type_oid()))))
})
.collect())
}
None => vec![],
};
Ok((param_types, columns)) Ok((param_types, columns))
} }
@ -477,7 +493,7 @@ impl InnerConnection {
loop { loop {
match try!(self.read_message()) { match try!(self.read_message()) {
backend::Message::EmptyQueryResponse | backend::Message::EmptyQueryResponse |
backend::Message::CommandComplete { .. } => { backend::Message::CommandComplete(_) => {
more_rows = false; more_rows = false;
break; break;
} }
@ -485,12 +501,15 @@ impl InnerConnection {
more_rows = true; more_rows = true;
break; break;
} }
backend::Message::DataRow { row } => buf.push_back(row), backend::Message::DataRow(body) => {
backend::Message::ErrorResponse { fields } => { let row = try!(body.values().map(|v| v.map(ToOwned::to_owned)).collect());
try!(self.wait_for_ready()); buf.push_back(row);
return DbError::new(fields);
} }
backend::Message::CopyInResponse { .. } => { backend::Message::ErrorResponse(body) => {
try!(self.wait_for_ready());
return DbError::new(&mut body.fields());
}
backend::Message::CopyInResponse(_) => {
try!(self.stream.write_message(|buf| { try!(self.stream.write_message(|buf| {
frontend::copy_fail("COPY queries cannot be directly executed", buf) frontend::copy_fail("COPY queries cannot be directly executed", buf)
})); }));
@ -498,9 +517,9 @@ impl InnerConnection {
.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))); .write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf))));
try!(self.stream.flush()); try!(self.stream.flush());
} }
backend::Message::CopyOutResponse { .. } => { backend::Message::CopyOutResponse(_) => {
loop { loop {
if let backend::Message::ReadyForQuery { .. } = try!(self.read_message()) { if let backend::Message::ReadyForQuery(_) = try!(self.read_message()) {
break; break;
} }
} }
@ -563,9 +582,9 @@ impl InnerConnection {
match try!(self.read_message()) { match try!(self.read_message()) {
backend::Message::BindComplete => Ok(()), backend::Message::BindComplete => Ok(()),
backend::Message::ErrorResponse { fields } => { backend::Message::ErrorResponse(body) => {
try!(self.wait_for_ready()); try!(self.wait_for_ready());
DbError::new(fields) DbError::new(&mut body.fields())
} }
_ => { _ => {
self.desynchronized = true; self.desynchronized = true;
@ -618,7 +637,7 @@ impl InnerConnection {
try!(self.stream.flush()); try!(self.stream.flush());
let resp = match try!(self.read_message()) { let resp = match try!(self.read_message()) {
backend::Message::CloseComplete => Ok(()), backend::Message::CloseComplete => Ok(()),
backend::Message::ErrorResponse { fields } => DbError::new(fields), backend::Message::ErrorResponse(body) => DbError::new(&mut body.fields()),
_ => bad_response!(self), _ => bad_response!(self),
}; };
try!(self.wait_for_ready()); try!(self.wait_for_ready());
@ -813,7 +832,7 @@ impl InnerConnection {
#[allow(needless_return)] #[allow(needless_return)]
fn wait_for_ready(&mut self) -> Result<()> { fn wait_for_ready(&mut self) -> Result<()> {
match try!(self.read_message()) { match try!(self.read_message()) {
backend::Message::ReadyForQuery { .. } => Ok(()), backend::Message::ReadyForQuery(_) => Ok(()),
_ => bad_response!(self), _ => bad_response!(self),
} }
} }
@ -827,13 +846,14 @@ impl InnerConnection {
let mut result = vec![]; let mut result = vec![];
loop { loop {
match try!(self.read_message()) { match try!(self.read_message()) {
backend::Message::ReadyForQuery { .. } => break, backend::Message::ReadyForQuery(_) => break,
backend::Message::DataRow { row } => { backend::Message::DataRow(body) => {
result.push(row.into_iter() let row = try!(body.values()
.map(|opt| opt.map(|b| String::from_utf8_lossy(&b).into_owned())) .map(|v| v.map(|v| String::from_utf8_lossy(v).into_owned()))
.collect()); .collect());
result.push(row);
} }
backend::Message::CopyInResponse { .. } => { backend::Message::CopyInResponse(_) => {
try!(self.stream.write_message(|buf| { try!(self.stream.write_message(|buf| {
frontend::copy_fail("COPY queries cannot be directly executed", buf) frontend::copy_fail("COPY queries cannot be directly executed", buf)
})); }));
@ -841,9 +861,9 @@ impl InnerConnection {
.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))); .write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf))));
try!(self.stream.flush()); try!(self.stream.flush());
} }
backend::Message::ErrorResponse { fields } => { backend::Message::ErrorResponse(body) => {
try!(self.wait_for_ready()); try!(self.wait_for_ready());
return DbError::new(fields); return DbError::new(&mut body.fields());
} }
_ => {} _ => {}
} }
@ -1314,9 +1334,9 @@ trait OtherNew {
} }
trait DbErrorNew { trait DbErrorNew {
fn new_raw(fields: Vec<(u8, String)>) -> result::Result<DbError, ()>; fn new_raw(fields: &mut ErrorFields) -> io::Result<DbError>;
fn new_connect<T>(fields: Vec<(u8, String)>) -> result::Result<T, ConnectError>; fn new_connect<T>(fields: &mut ErrorFields) -> result::Result<T, ConnectError>;
fn new<T>(fields: Vec<(u8, String)>) -> Result<T>; fn new<T>(fields: &mut ErrorFields) -> Result<T>;
} }
trait RowsNew<'a> { trait RowsNew<'a> {

View File

@ -113,11 +113,11 @@ impl<'a> FallibleIterator for Iter<'a> {
} }
match conn.read_message_with_notification_nonblocking() { match conn.read_message_with_notification_nonblocking() {
Ok(Some(backend::Message::NotificationResponse { process_id, channel, payload })) => { Ok(Some(backend::Message::NotificationResponse(body))) => {
Ok(Some(Notification { Ok(Some(Notification {
process_id: process_id, process_id: body.process_id(),
channel: channel, channel: try!(body.channel()).to_owned(),
payload: payload, payload: try!(body.message()).to_owned(),
})) }))
} }
Ok(None) => Ok(None), Ok(None) => Ok(None),
@ -152,11 +152,11 @@ impl<'a> FallibleIterator for BlockingIter<'a> {
} }
match conn.read_message_with_notification() { match conn.read_message_with_notification() {
Ok(backend::Message::NotificationResponse { process_id, channel, payload }) => { Ok(backend::Message::NotificationResponse(body)) => {
Ok(Some(Notification { Ok(Some(Notification {
process_id: process_id, process_id: body.process_id(),
channel: channel, channel: try!(body.channel()).to_owned(),
payload: payload, payload: try!(body.message()).to_owned(),
})) }))
} }
Err(err) => Err(Error::Io(err)), Err(err) => Err(Error::Io(err)),
@ -188,11 +188,11 @@ impl<'a> FallibleIterator for TimeoutIter<'a> {
} }
match conn.read_message_with_notification_timeout(self.timeout) { match conn.read_message_with_notification_timeout(self.timeout) {
Ok(Some(backend::Message::NotificationResponse { process_id, channel, payload })) => { Ok(Some(backend::Message::NotificationResponse(body))) => {
Ok(Some(Notification { Ok(Some(Notification {
process_id: process_id, process_id: body.process_id(),
channel: channel, channel: try!(body.channel()).to_owned(),
payload: payload, payload: try!(body.message()).to_owned(),
})) }))
} }
Ok(None) => Ok(None), Ok(None) => Ok(None),

View File

@ -47,12 +47,12 @@ impl MessageStream {
self.stream.write_all(&self.buf).map_err(From::from) self.stream.write_all(&self.buf).map_err(From::from)
} }
fn inner_read_message(&mut self, b: u8) -> io::Result<backend::Message> { fn inner_read_message(&mut self, b: u8) -> io::Result<backend::Message<Vec<u8>>> {
self.buf.resize(MESSAGE_HEADER_SIZE, 0); self.buf.resize(MESSAGE_HEADER_SIZE, 0);
self.buf[0] = b; self.buf[0] = b;
try!(self.stream.read_exact(&mut self.buf[1..])); try!(self.stream.read_exact(&mut self.buf[1..]));
let len = match try!(backend::Message::parse(&self.buf)) { let len = match try!(backend::Message::parse_owned(&self.buf)) {
ParseResult::Complete { message, .. } => return Ok(message), ParseResult::Complete { message, .. } => return Ok(message),
ParseResult::Incomplete { required_size } => Some(required_size.unwrap()), ParseResult::Incomplete { required_size } => Some(required_size.unwrap()),
}; };
@ -62,13 +62,13 @@ impl MessageStream {
try!(self.stream.read_exact(&mut self.buf[MESSAGE_HEADER_SIZE..])); try!(self.stream.read_exact(&mut self.buf[MESSAGE_HEADER_SIZE..]));
}; };
match try!(backend::Message::parse(&self.buf)) { match try!(backend::Message::parse_owned(&self.buf)) {
ParseResult::Complete { message, .. } => Ok(message), ParseResult::Complete { message, .. } => Ok(message),
ParseResult::Incomplete { .. } => unreachable!(), ParseResult::Incomplete { .. } => unreachable!(),
} }
} }
pub fn read_message(&mut self) -> io::Result<backend::Message> { pub fn read_message(&mut self) -> io::Result<backend::Message<Vec<u8>>> {
let mut b = [0; 1]; let mut b = [0; 1];
try!(self.stream.read_exact(&mut b)); try!(self.stream.read_exact(&mut b));
self.inner_read_message(b[0]) self.inner_read_message(b[0])
@ -76,7 +76,7 @@ impl MessageStream {
pub fn read_message_timeout(&mut self, pub fn read_message_timeout(&mut self,
timeout: Duration) timeout: Duration)
-> io::Result<Option<backend::Message>> { -> io::Result<Option<backend::Message<Vec<u8>>>> {
try!(self.set_read_timeout(Some(timeout))); try!(self.set_read_timeout(Some(timeout)));
let mut b = [0; 1]; let mut b = [0; 1];
let r = self.stream.read_exact(&mut b); let r = self.stream.read_exact(&mut b);
@ -90,7 +90,7 @@ impl MessageStream {
} }
} }
pub fn read_message_nonblocking(&mut self) -> io::Result<Option<backend::Message>> { pub fn read_message_nonblocking(&mut self) -> io::Result<Option<backend::Message<Vec<u8>>>> {
try!(self.set_nonblocking(true)); try!(self.set_nonblocking(true));
let mut b = [0; 1]; let mut b = [0; 1];
let r = self.stream.read_exact(&mut b); let r = self.stream.read_exact(&mut b);

View File

@ -1,5 +1,6 @@
//! Prepared statements //! Prepared statements
use fallible_iterator::FallibleIterator;
use std::cell::{Cell, RefMut}; use std::cell::{Cell, RefMut};
use std::collections::VecDeque; use std::collections::VecDeque;
use std::fmt; use std::fmt;
@ -132,20 +133,20 @@ impl<'conn> Statement<'conn> {
let num; let num;
loop { loop {
match try!(conn.read_message()) { match try!(conn.read_message()) {
backend::Message::DataRow { .. } => {} backend::Message::DataRow(_) => {}
backend::Message::ErrorResponse { fields } => { backend::Message::ErrorResponse(body) => {
try!(conn.wait_for_ready()); try!(conn.wait_for_ready());
return DbError::new(fields); return DbError::new(&mut body.fields());
} }
backend::Message::CommandComplete { tag } => { backend::Message::CommandComplete(body) => {
num = parse_update_count(tag); num = parse_update_count(try!(body.tag()));
break; break;
} }
backend::Message::EmptyQueryResponse => { backend::Message::EmptyQueryResponse => {
num = 0; num = 0;
break; break;
} }
backend::Message::CopyInResponse { .. } => { backend::Message::CopyInResponse(_) => {
try!(conn.stream.write_message(|buf| { try!(conn.stream.write_message(|buf| {
frontend::copy_fail("COPY queries cannot be directly executed", buf) frontend::copy_fail("COPY queries cannot be directly executed", buf)
})); }));
@ -153,13 +154,13 @@ impl<'conn> Statement<'conn> {
.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))); .write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf))));
try!(conn.stream.flush()); try!(conn.stream.flush());
} }
backend::Message::CopyOutResponse { .. } => { backend::Message::CopyOutResponse(_) => {
loop { loop {
match try!(conn.read_message()) { match try!(conn.read_message()) {
backend::Message::CopyDone => break, backend::Message::CopyDone => break,
backend::Message::ErrorResponse { fields } => { backend::Message::ErrorResponse(body) => {
try!(conn.wait_for_ready()); try!(conn.wait_for_ready());
return DbError::new(fields); return DbError::new(&mut body.fields());
} }
_ => {} _ => {}
} }
@ -269,14 +270,20 @@ impl<'conn> Statement<'conn> {
try!(conn.raw_execute(&self.info.name, "", 0, self.param_types(), params)); try!(conn.raw_execute(&self.info.name, "", 0, self.param_types(), params));
let (format, column_formats) = match try!(conn.read_message()) { let (format, column_formats) = match try!(conn.read_message()) {
backend::Message::CopyInResponse { format, column_formats } => (format, column_formats), backend::Message::CopyInResponse(body) => {
backend::Message::ErrorResponse { fields } => { let format = body.format();
let column_formats = try!(body.column_formats()
.map(|f| Format::from_u16(f))
.collect());
(format, column_formats)
}
backend::Message::ErrorResponse(body) => {
try!(conn.wait_for_ready()); try!(conn.wait_for_ready());
return DbError::new(fields); return DbError::new(&mut body.fields());
} }
_ => { _ => {
loop { loop {
if let backend::Message::ReadyForQuery { .. } = try!(conn.read_message()) { if let backend::Message::ReadyForQuery(_) = try!(conn.read_message()) {
return Err(Error::Io(io::Error::new(io::ErrorKind::InvalidInput, return Err(Error::Io(io::Error::new(io::ErrorKind::InvalidInput,
"called `copy_in` on a \ "called `copy_in` on a \
non-`COPY FROM STDIN` \ non-`COPY FROM STDIN` \
@ -289,7 +296,7 @@ impl<'conn> Statement<'conn> {
let mut info = CopyInfo { let mut info = CopyInfo {
conn: conn, conn: conn,
format: Format::from_u16(format as u16), format: Format::from_u16(format as u16),
column_formats: column_formats.iter().map(|&f| Format::from_u16(f)).collect(), column_formats: column_formats,
}; };
let mut buf = [0; 16 * 1024]; let mut buf = [0; 16 * 1024];
@ -311,7 +318,7 @@ impl<'conn> Statement<'conn> {
.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))); .write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf))));
try!(info.conn.stream.flush()); try!(info.conn.stream.flush());
match try!(info.conn.read_message()) { match try!(info.conn.read_message()) {
backend::Message::ErrorResponse { .. } => { backend::Message::ErrorResponse(_) => {
// expected from the CopyFail // expected from the CopyFail
} }
_ => { _ => {
@ -330,10 +337,10 @@ impl<'conn> Statement<'conn> {
try!(info.conn.stream.flush()); try!(info.conn.stream.flush());
let num = match try!(info.conn.read_message()) { let num = match try!(info.conn.read_message()) {
backend::Message::CommandComplete { tag } => parse_update_count(tag), backend::Message::CommandComplete(body) => parse_update_count(try!(body.tag())),
backend::Message::ErrorResponse { fields } => { backend::Message::ErrorResponse(body) => {
try!(info.conn.wait_for_ready()); try!(info.conn.wait_for_ready());
return DbError::new(fields); return DbError::new(&mut body.fields());
} }
_ => { _ => {
info.conn.desynchronized = true; info.conn.desynchronized = true;
@ -372,17 +379,21 @@ impl<'conn> Statement<'conn> {
try!(conn.raw_execute(&self.info.name, "", 0, self.param_types(), params)); try!(conn.raw_execute(&self.info.name, "", 0, self.param_types(), params));
let (format, column_formats) = match try!(conn.read_message()) { let (format, column_formats) = match try!(conn.read_message()) {
backend::Message::CopyOutResponse { format, column_formats } => { backend::Message::CopyOutResponse(body) => {
let format = body.format();
let column_formats = try!(body.column_formats()
.map(|f| Format::from_u16(f))
.collect());
(format, column_formats) (format, column_formats)
} }
backend::Message::CopyInResponse { .. } => { backend::Message::CopyInResponse(_) => {
try!(conn.stream.write_message(|buf| frontend::copy_fail("", buf))); try!(conn.stream.write_message(|buf| frontend::copy_fail("", buf)));
try!(conn.stream try!(conn.stream
.write_message(|buf| Ok::<(), io::Error>(frontend::copy_done(buf)))); .write_message(|buf| Ok::<(), io::Error>(frontend::copy_done(buf))));
try!(conn.stream.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))); try!(conn.stream.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf))));
try!(conn.stream.flush()); try!(conn.stream.flush());
match try!(conn.read_message()) { match try!(conn.read_message()) {
backend::Message::ErrorResponse { .. } => { backend::Message::ErrorResponse(_) => {
// expected from the CopyFail // expected from the CopyFail
} }
_ => { _ => {
@ -395,13 +406,13 @@ impl<'conn> Statement<'conn> {
"called `copy_out` on a non-`COPY TO \ "called `copy_out` on a non-`COPY TO \
STDOUT` statement"))); STDOUT` statement")));
} }
backend::Message::ErrorResponse { fields } => { backend::Message::ErrorResponse(body) => {
try!(conn.wait_for_ready()); try!(conn.wait_for_ready());
return DbError::new(fields); return DbError::new(&mut body.fields());
} }
_ => { _ => {
loop { loop {
if let backend::Message::ReadyForQuery { .. } = try!(conn.read_message()) { if let backend::Message::ReadyForQuery(_) = try!(conn.read_message()) {
return Err(Error::Io(io::Error::new(io::ErrorKind::InvalidInput, return Err(Error::Io(io::Error::new(io::ErrorKind::InvalidInput,
"called `copy_out` on a \ "called `copy_out` on a \
non-`COPY TO STDOUT` statement"))); non-`COPY TO STDOUT` statement")));
@ -413,20 +424,20 @@ impl<'conn> Statement<'conn> {
let mut info = CopyInfo { let mut info = CopyInfo {
conn: conn, conn: conn,
format: Format::from_u16(format as u16), format: Format::from_u16(format as u16),
column_formats: column_formats.iter().map(|&f| Format::from_u16(f)).collect(), column_formats: column_formats,
}; };
let count; let count;
loop { loop {
match try!(info.conn.read_message()) { match try!(info.conn.read_message()) {
backend::Message::CopyData { data } => { backend::Message::CopyData(body) => {
let mut data = &data[..]; let mut data = body.data();
while !data.is_empty() { while !data.is_empty() {
match w.write_with_info(data, &info) { match w.write_with_info(data, &info) {
Ok(n) => data = &data[n..], Ok(n) => data = &data[n..],
Err(e) => { Err(e) => {
loop { loop {
if let backend::Message::ReadyForQuery { .. } = if let backend::Message::ReadyForQuery(_) =
try!(info.conn.read_message()) { try!(info.conn.read_message()) {
return Err(Error::Io(e)); return Err(Error::Io(e));
} }
@ -436,21 +447,21 @@ impl<'conn> Statement<'conn> {
} }
} }
backend::Message::CopyDone => {} backend::Message::CopyDone => {}
backend::Message::CommandComplete { tag } => { backend::Message::CommandComplete(body) => {
count = parse_update_count(tag); count = parse_update_count(try!(body.tag()));
break; break;
} }
backend::Message::ErrorResponse { fields } => { backend::Message::ErrorResponse(body) => {
loop { loop {
if let backend::Message::ReadyForQuery { .. } = if let backend::Message::ReadyForQuery(_) =
try!(info.conn.read_message()) { try!(info.conn.read_message()) {
return DbError::new(fields); return DbError::new(&mut body.fields());
} }
} }
} }
_ => { _ => {
loop { loop {
if let backend::Message::ReadyForQuery { .. } = if let backend::Message::ReadyForQuery(_) =
try!(info.conn.read_message()) { try!(info.conn.read_message()) {
return Err(Error::Io(bad_response())); return Err(Error::Io(bad_response()));
} }
@ -586,6 +597,6 @@ impl Format {
} }
} }
fn parse_update_count(tag: String) -> u64 { fn parse_update_count(tag: &str) -> u64 {
tag.split(' ').last().unwrap().parse().unwrap_or(0) tag.split(' ').last().unwrap().parse().unwrap_or(0)
} }