diff --git a/Cargo.toml b/Cargo.toml index 41d9837e..30159560 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ with-uuid = ["uuid"] [dependencies] bufstream = "0.1" byteorder = "0.5" +fallible-iterator = "0.1" hex = "0.2" log = "0.3" phf = "=0.7.15" diff --git a/src/lib.rs b/src/lib.rs index 1a680201..1288c691 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,6 +44,7 @@ extern crate bufstream; extern crate byteorder; +extern crate fallible_iterator; extern crate hex; #[macro_use] extern crate log; @@ -64,7 +65,7 @@ use postgres_protocol::message::frontend; use error::{Error, ConnectError, SqlState, DbError}; use io::TlsHandshake; -use message::{Backend, RowDescriptionEntry, ReadMessage}; +use message::{Backend, RowDescriptionEntry}; use notification::{Notifications, Notification}; use params::{ConnectParams, IntoConnectParams, UserInfo}; use priv_io::MessageStream; @@ -288,8 +289,8 @@ impl InnerConnection { loop { match try!(conn.read_message()) { Backend::BackendKeyData { process_id, secret_key } => { - conn.cancel_data.process_id = process_id as i32; - conn.cancel_data.secret_key = secret_key as i32; + conn.cancel_data.process_id = process_id; + conn.cancel_data.secret_key = secret_key; } Backend::ReadyForQuery { .. } => break, Backend::ErrorResponse { fields } => return DbError::new_connect(fields), @@ -303,7 +304,7 @@ impl InnerConnection { fn read_message_with_notification(&mut self) -> std_io::Result { debug_assert!(!self.desynchronized); loop { - match try_desync!(self, ReadMessage::read_message(&mut self.stream)) { + match try_desync!(self, self.stream.read_message()) { Backend::NoticeResponse { fields } => { if let Ok(err) = DbError::new_raw(fields) { self.notice_handler.handle_notice(err); @@ -357,9 +358,9 @@ impl InnerConnection { fn read_message(&mut self) -> std_io::Result { loop { match try!(self.read_message_with_notification()) { - Backend::NotificationResponse { pid, channel, payload } => { + Backend::NotificationResponse { process_id, channel, payload } => { self.notifications.push_back(Notification { - pid: pid, + process_id: process_id, channel: channel, payload: payload, }) diff --git a/src/message.rs b/src/message.rs index 2172fd0c..e522b853 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,11 +1,8 @@ +use fallible_iterator::FallibleIterator; +use postgres_protocol::message::backend::Message; use std::io; -use std::io::prelude::*; -use std::mem; -use std::time::Duration; -use byteorder::{BigEndian, ReadBytesExt}; use types::Oid; -use priv_io::StreamOptions; pub enum Backend { AuthenticationCleartextPassword, @@ -18,8 +15,8 @@ pub enum Backend { AuthenticationSCMCredential, AuthenticationSSPI, BackendKeyData { - process_id: u32, - secret_key: u32, + process_id: i32, + secret_key: i32, }, BindComplete, CloseComplete, @@ -50,7 +47,7 @@ pub enum Backend { fields: Vec<(u8, String)>, }, NotificationResponse { - pid: u32, + process_id: i32, channel: String, payload: String, }, @@ -71,6 +68,107 @@ pub enum Backend { }, } +impl Backend { + pub fn convert(message: Message) -> io::Result { + let ret = match message { + Message::AuthenticationCleartextPassword => Backend::AuthenticationCleartextPassword, + Message::AuthenticationGss => Backend::AuthenticationGSS, + Message::AuthenticationKerberosV5 => Backend::AuthenticationKerberosV5, + Message::AuthenticationMd55Password(body) => { + Backend::AuthenticationMD5Password { salt: body.salt() } + } + Message::AuthenticationOk => Backend::AuthenticationOk, + Message::AuthenticationScmCredential => Backend::AuthenticationSCMCredential, + Message::AuthenticationSspi => Backend::AuthenticationSSPI, + Message::BackendKeyData(body) => { + Backend::BackendKeyData { + process_id: body.process_id(), + secret_key: body.secret_key(), + } + } + Message::BindComplete => Backend::BindComplete, + Message::CloseComplete => Backend::CloseComplete, + Message::CommandComplete(body) => { + Backend::CommandComplete { + tag: body.tag().to_owned() + } + } + Message::CopyData(body) => Backend::CopyData { data: body.data().to_owned() }, + Message::CopyDone => Backend::CopyDone, + Message::CopyInResponse(body) => { + Backend::CopyInResponse { + format: body.format(), + column_formats: try!(body.column_formats().collect()), + } + } + Message::CopyOutResponse(body) => { + Backend::CopyOutResponse { + format: body.format(), + column_formats: try!(body.column_formats().collect()), + } + } + Message::DataRow(body) => { + Backend::DataRow { + row: try!(body.values().map(|r| r.map(|d| d.to_owned())).collect()), + } + } + Message::EmptyQueryResponse => Backend::EmptyQueryResponse, + Message::ErrorResponse(body) => { + Backend::ErrorResponse { + fields: try!(body.fields().map(|f| (f.type_(), f.value().to_owned())).collect()), + } + } + Message::NoData => Backend::NoData, + Message::NoticeResponse(body) => { + Backend::NoticeResponse { + fields: try!(body.fields().map(|f| (f.type_(), f.value().to_owned())).collect()), + } + } + Message::NotificationResponse(body) => { + Backend::NotificationResponse { + process_id: body.process_id(), + channel: body.channel().to_owned(), + payload: body.message().to_owned(), + } + } + Message::ParameterDescription(body) => { + Backend::ParameterDescription { + types: try!(body.parameters().collect()), + } + } + Message::ParameterStatus(body) => { + Backend::ParameterStatus { + parameter: body.name().to_owned(), + value: body.value().to_owned(), + } + } + Message::ParseComplete => Backend::ParseComplete, + Message::PortalSuspended => Backend::PortalSuspended, + Message::ReadyForQuery(body) => Backend::ReadyForQuery { _state: body.status() }, + Message::RowDescription(body) => { + let fields = body.fields() + .map(|f| { + RowDescriptionEntry { + name: f.name().to_owned(), + table_oid: f.table_oid(), + column_id: f.column_id(), + type_oid: f.type_oid(), + type_size: f.type_size(), + type_modifier: f.type_modifier(), + format: f.format(), + } + }); + Backend::RowDescription { + descriptions: try!(fields.collect()), + } + } + _ => return Err(io::Error::new(io::ErrorKind::InvalidInput, "unknown message type")), + }; + + Ok(ret) + } +} + pub struct RowDescriptionEntry { pub name: String, pub table_oid: Oid, @@ -80,254 +178,3 @@ pub struct RowDescriptionEntry { pub type_modifier: i32, pub format: i16, } - -#[doc(hidden)] -trait ReadCStr { - fn read_cstr(&mut self) -> io::Result; -} - -impl ReadCStr for R { - fn read_cstr(&mut self) -> io::Result { - let mut buf = vec![]; - try!(self.read_until(0, &mut buf)); - buf.pop(); - String::from_utf8(buf).map_err(|err| io::Error::new(io::ErrorKind::Other, err)) - } -} - -#[doc(hidden)] -pub trait ReadMessage { - fn read_message(&mut self) -> io::Result; - - fn read_message_timeout(&mut self, timeout: Duration) -> io::Result>; - - fn read_message_nonblocking(&mut self) -> io::Result>; - - fn finish_read_message(&mut self, ident: u8) -> io::Result; -} - -impl ReadMessage for R { - fn read_message(&mut self) -> io::Result { - let ident = try!(self.read_u8()); - self.finish_read_message(ident) - } - - fn read_message_timeout(&mut self, timeout: Duration) -> io::Result> { - try!(self.set_read_timeout(Some(timeout))); - let ident = self.read_u8(); - try!(self.set_read_timeout(None)); - - match ident { - Ok(ident) => self.finish_read_message(ident).map(Some), - Err(e) => { - let e: io::Error = e.into(); - if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::TimedOut { - Ok(None) - } else { - Err(e) - } - } - } - } - - fn read_message_nonblocking(&mut self) -> io::Result> { - try!(self.set_nonblocking(true)); - let ident = self.read_u8(); - try!(self.set_nonblocking(false)); - - match ident { - Ok(ident) => self.finish_read_message(ident).map(Some), - Err(e) => { - let e: io::Error = e.into(); - if e.kind() == io::ErrorKind::WouldBlock { - Ok(None) - } else { - Err(e) - } - } - } - } - - #[allow(cyclomatic_complexity)] - fn finish_read_message(&mut self, ident: u8) -> io::Result { - // subtract size of length value - let len = try!(self.read_u32::()) - mem::size_of::() as u32; - let mut rdr = self.by_ref().take(len as u64); - - let ret = match ident { - b'1' => Backend::ParseComplete, - b'2' => Backend::BindComplete, - b'3' => Backend::CloseComplete, - b'A' => { - Backend::NotificationResponse { - pid: try!(rdr.read_u32::()), - channel: try!(rdr.read_cstr()), - payload: try!(rdr.read_cstr()), - } - } - b'c' => Backend::CopyDone, - b'C' => Backend::CommandComplete { tag: try!(rdr.read_cstr()) }, - b'd' => { - let mut data = vec![]; - try!(rdr.read_to_end(&mut data)); - Backend::CopyData { data: data } - } - b'D' => try!(read_data_row(&mut rdr)), - b'E' => Backend::ErrorResponse { fields: try!(read_fields(&mut rdr)) }, - b'G' => { - let format = try!(rdr.read_u8()); - let mut column_formats = vec![]; - for _ in 0..try!(rdr.read_u16::()) { - column_formats.push(try!(rdr.read_u16::())); - } - Backend::CopyInResponse { - format: format, - column_formats: column_formats, - } - } - b'H' => { - let format = try!(rdr.read_u8()); - let mut column_formats = vec![]; - for _ in 0..try!(rdr.read_u16::()) { - column_formats.push(try!(rdr.read_u16::())); - } - Backend::CopyOutResponse { - format: format, - column_formats: column_formats, - } - } - b'I' => Backend::EmptyQueryResponse, - b'K' => { - Backend::BackendKeyData { - process_id: try!(rdr.read_u32::()), - secret_key: try!(rdr.read_u32::()), - } - } - b'n' => Backend::NoData, - b'N' => Backend::NoticeResponse { fields: try!(read_fields(&mut rdr)) }, - b'R' => try!(read_auth_message(&mut rdr)), - b's' => Backend::PortalSuspended, - b'S' => { - Backend::ParameterStatus { - parameter: try!(rdr.read_cstr()), - value: try!(rdr.read_cstr()), - } - } - b't' => try!(read_parameter_description(&mut rdr)), - b'T' => try!(read_row_description(&mut rdr)), - b'Z' => Backend::ReadyForQuery { _state: try!(rdr.read_u8()) }, - t => { - return Err(io::Error::new(io::ErrorKind::Other, - format!("unexpected message tag `{}`", t))) - } - }; - if rdr.limit() != 0 { - return Err(io::Error::new(io::ErrorKind::Other, "didn't read entire message")); - } - Ok(ret) - } -} - -fn read_fields(buf: &mut R) -> io::Result> { - let mut fields = vec![]; - loop { - let ty = try!(buf.read_u8()); - if ty == 0 { - break; - } - - fields.push((ty, try!(buf.read_cstr()))); - } - - Ok(fields) -} - -fn read_data_row(buf: &mut R) -> io::Result { - let len = try!(buf.read_u16::()) as usize; - let mut values = Vec::with_capacity(len); - - for _ in 0..len { - let val = match try!(buf.read_i32::()) { - -1 => None, - len => { - let mut data = vec![0; len as usize]; - try!(buf.read_exact(&mut data)); - Some(data) - } - }; - values.push(val); - } - - Ok(Backend::DataRow { row: values }) -} - -fn read_auth_message(buf: &mut R) -> io::Result { - Ok(match try!(buf.read_i32::()) { - 0 => Backend::AuthenticationOk, - 2 => Backend::AuthenticationKerberosV5, - 3 => Backend::AuthenticationCleartextPassword, - 5 => { - let mut salt = [0; 4]; - try!(buf.read_exact(&mut salt)); - Backend::AuthenticationMD5Password { salt: salt } - } - 6 => Backend::AuthenticationSCMCredential, - 7 => Backend::AuthenticationGSS, - 9 => Backend::AuthenticationSSPI, - t => { - return Err(io::Error::new(io::ErrorKind::Other, - format!("unexpected authentication tag `{}`", t))) - } - }) -} - -fn read_parameter_description(buf: &mut R) -> io::Result { - let len = try!(buf.read_u16::()) as usize; - let mut types = Vec::with_capacity(len); - - for _ in 0..len { - types.push(try!(buf.read_u32::())); - } - - Ok(Backend::ParameterDescription { types: types }) -} - -fn read_row_description(buf: &mut R) -> io::Result { - let len = try!(buf.read_u16::()) as usize; - let mut types = Vec::with_capacity(len); - - for _ in 0..len { - types.push(RowDescriptionEntry { - name: try!(buf.read_cstr()), - table_oid: try!(buf.read_u32::()), - column_id: try!(buf.read_i16::()), - type_oid: try!(buf.read_u32::()), - type_size: try!(buf.read_i16::()), - type_modifier: try!(buf.read_i32::()), - format: try!(buf.read_i16::()), - }) - } - - Ok(Backend::RowDescription { descriptions: types }) -} - -trait FromUsize: Sized { - fn from_usize(x: usize) -> io::Result; -} - -macro_rules! from_usize { - ($t:ty) => { - impl FromUsize for $t { - fn from_usize(x: usize) -> io::Result<$t> { - if x > <$t>::max_value() as usize { - Err(io::Error::new(io::ErrorKind::InvalidInput, "value too large to transmit")) - } else { - Ok(x as $t) - } - } - } - } -} - -from_usize!(u16); -from_usize!(i32); diff --git a/src/notification.rs b/src/notification.rs index 4f48e5ce..b8d6b343 100644 --- a/src/notification.rs +++ b/src/notification.rs @@ -11,7 +11,7 @@ use error::Error; #[derive(Clone, Debug)] pub struct Notification { /// The process ID of the notifying backend process. - pub pid: u32, + pub process_id: i32, /// The name of the channel that the notify has been raised on. pub channel: String, /// The "payload" string passed from the notifying process. @@ -110,9 +110,9 @@ impl<'a> Iterator for Iter<'a> { } match conn.read_message_with_notification_nonblocking() { - Ok(Some(Backend::NotificationResponse { pid, channel, payload })) => { + Ok(Some(Backend::NotificationResponse { process_id, channel, payload })) => { Some(Ok(Notification { - pid: pid, + process_id: process_id, channel: channel, payload: payload, })) @@ -148,9 +148,9 @@ impl<'a> Iterator for BlockingIter<'a> { } match conn.read_message_with_notification() { - Ok(Backend::NotificationResponse { pid, channel, payload }) => { + Ok(Backend::NotificationResponse { process_id, channel, payload }) => { Some(Ok(Notification { - pid: pid, + process_id: process_id, channel: channel, payload: payload, })) @@ -187,9 +187,9 @@ impl<'a> Iterator for TimeoutIter<'a> { } match conn.read_message_with_notification_timeout(self.timeout) { - Ok(Some(Backend::NotificationResponse { pid, channel, payload })) => { + Ok(Some(Backend::NotificationResponse { process_id, channel, payload })) => { Some(Ok(Notification { - pid: pid, + process_id: process_id, channel: channel, payload: payload, })) diff --git a/src/priv_io.rs b/src/priv_io.rs index 16b6ec6c..a1168c69 100644 --- a/src/priv_io.rs +++ b/src/priv_io.rs @@ -15,9 +15,10 @@ use postgres_protocol::message::frontend; use postgres_protocol::message::backend::{self, ParseResult}; use TlsMode; -use params::{ConnectParams, ConnectTarget}; use error::ConnectError; use io::TlsStream; +use message::Backend; +use params::{ConnectParams, ConnectTarget}; const DEFAULT_PORT: u16 = 5432; const MESSAGE_HEADER_SIZE: usize = 5; @@ -45,9 +46,10 @@ impl MessageStream { self.stream.write_all(&self.buf) } - pub fn read_message<'a>(&'a mut self) -> io::Result> { + fn raw_read_message<'a>(&'a mut self, b: u8) -> io::Result> { self.buf.resize(MESSAGE_HEADER_SIZE, 0); - try!(self.stream.read_exact(&mut self.buf)); + self.buf[0] = b; + try!(self.stream.read_exact(&mut self.buf[1..])); let len = match try!(backend::Message::parse(&self.buf)) { // FIXME this is dumb but an explicit return runs into borrowck issues :( @@ -66,34 +68,46 @@ impl MessageStream { } } + fn inner_read_message(&mut self, b: u8) -> io::Result { + let message = try!(self.raw_read_message(b)); + Backend::convert(message) + } + + pub fn read_message(&mut self) -> io::Result { + let b = try!(self.stream.read_u8()); + self.inner_read_message(b) + } + + pub fn read_message_timeout(&mut self, timeout: Duration) -> io::Result> { + try!(self.set_read_timeout(Some(timeout))); + let b = self.stream.read_u8(); + try!(self.set_read_timeout(None)); + + match b { + Ok(b) => self.inner_read_message(b).map(Some), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::TimedOut => { + Ok(None) + } + Err(e) => Err(e), + } + } + + pub fn read_message_nonblocking(&mut self) -> io::Result> { + try!(self.set_nonblocking(true)); + let b = self.stream.read_u8(); + try!(self.set_nonblocking(false)); + + match b { + Ok(b) => self.inner_read_message(b).map(Some), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(None), + Err(e) => Err(e), + } + } + pub fn flush(&mut self) -> io::Result<()> { self.stream.flush() } -} -impl io::Read for MessageStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.stream.read(buf) - } -} - -impl io::BufRead for MessageStream { - fn fill_buf(&mut self) -> io::Result<&[u8]> { - self.stream.fill_buf() - } - - fn consume(&mut self, amt: usize) { - self.stream.consume(amt) - } -} - -#[doc(hidden)] -pub trait StreamOptions { - fn set_read_timeout(&self, timeout: Option) -> io::Result<()>; - fn set_nonblocking(&self, nonblock: bool) -> io::Result<()>; -} - -impl StreamOptions for MessageStream { fn set_read_timeout(&self, timeout: Option) -> io::Result<()> { match self.stream.get_ref().get_ref().0 { InternalStream::Tcp(ref s) => s.set_read_timeout(timeout), diff --git a/tests/test.rs b/tests/test.rs index 9148248d..047321a3 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -576,12 +576,12 @@ fn test_notification_iterator_some() { or_panic!(conn.execute("NOTIFY test_notification_iterator_one_channel2, 'world'", &[])); check_notification(Notification { - pid: 0, + process_id: 0, channel: "test_notification_iterator_one_channel".to_string(), payload: "hello".to_string() }, it.next().unwrap().unwrap()); check_notification(Notification { - pid: 0, + process_id: 0, channel: "test_notification_iterator_one_channel2".to_string(), payload: "world".to_string() }, it.next().unwrap().unwrap()); @@ -589,7 +589,7 @@ fn test_notification_iterator_some() { or_panic!(conn.execute("NOTIFY test_notification_iterator_one_channel, '!'", &[])); check_notification(Notification { - pid: 0, + process_id: 0, channel: "test_notification_iterator_one_channel".to_string(), payload: "!".to_string() }, it.next().unwrap().unwrap()); @@ -609,7 +609,7 @@ fn test_notifications_next_block() { let notifications = conn.notifications(); check_notification(Notification { - pid: 0, + process_id: 0, channel: "test_notifications_next_block".to_string(), payload: "foo".to_string() }, or_panic!(notifications.blocking_iter().next().unwrap())); @@ -631,7 +631,7 @@ fn test_notification_next_timeout() { let notifications = conn.notifications(); let mut it = notifications.timeout_iter(Duration::from_secs(1)); check_notification(Notification { - pid: 0, + process_id: 0, channel: "test_notifications_next_timeout".to_string(), payload: "foo".to_string() }, or_panic!(it.next().unwrap()));