diff --git a/src/io_util.rs b/src/io_util.rs index 5db9aeb3..07311ace 100644 --- a/src/io_util.rs +++ b/src/io_util.rs @@ -14,30 +14,57 @@ use message::FrontendMessage::SslRequest; const DEFAULT_PORT: u16 = 5432; -pub trait StreamWrapper: Read+Write+Send { - fn get_ref(&self) -> &S; - fn get_mut(&mut self) -> &mut S; +pub struct Stream(InternalStream); + +impl Read for Stream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } } -impl StreamWrapper for SslStream { - fn get_ref(&self) -> &S { +impl Write for Stream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } +} + +impl StreamWrapper for Stream { + fn get_ref(&self) -> &Stream { + self + } + + fn get_mut(&mut self) -> &mut Stream { + self + } +} + +pub trait StreamWrapper: Read+Write+Send { + fn get_ref(&self) -> &Stream; + fn get_mut(&mut self) -> &mut Stream; +} + +impl StreamWrapper for SslStream { + fn get_ref(&self) -> &Stream { self.get_ref() } - fn get_mut(&mut self) -> &mut S { + fn get_mut(&mut self) -> &mut Stream { self.get_mut() } } pub trait NegotiateSsl { - fn negotiate_ssl(&mut self, host: &str, stream: S) - -> Result>, Box> - where S: Read+Write+Send+'static; + fn negotiate_ssl(&mut self, host: &str, stream: Stream) + -> Result, Box>; } impl NegotiateSsl for SslContext { - fn negotiate_ssl(&mut self, _: &str, stream: S) -> Result>, Box> - where S: Read+Write+Send+'static { + fn negotiate_ssl(&mut self, _: &str, stream: Stream) + -> Result, Box> { let stream = try!(SslStream::new(self, stream)); Ok(Box::new(stream)) } @@ -56,8 +83,7 @@ pub enum SslMode { pub enum NoSsl {} impl NegotiateSsl for NoSsl { - fn negotiate_ssl(&mut self, _: &str, _: S) - -> Result>, Box> { + fn negotiate_ssl(&mut self, _: &str, _: Stream) -> Result, Box> { match *self {} } } @@ -68,16 +94,6 @@ pub enum InternalStream { Unix(UnixStream), } -impl StreamWrapper for InternalStream { - fn get_ref(&self) -> &InternalStream { - self - } - - fn get_mut(&mut self) -> &mut InternalStream { - self - } -} - impl Read for InternalStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { match *self { @@ -122,9 +138,9 @@ fn open_socket(params: &ConnectParams) -> Result { } pub fn initialize_stream(params: &ConnectParams, ssl: &mut SslMode) - -> Result>, ConnectError> + -> Result, ConnectError> where N: NegotiateSsl { - let mut socket = try!(open_socket(params)); + let mut socket = Stream(try!(open_socket(params))); let (ssl_required, negotiator) = match *ssl { SslMode::None => return Ok(Box::new(socket)), diff --git a/src/lib.rs b/src/lib.rs index e9a1fd8c..b47122cd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -79,11 +79,10 @@ use std::path::PathBuf; pub use error::{Error, ConnectError, SqlState, DbError, ErrorPosition}; #[doc(inline)] pub use types::{Oid, Type, Kind, ToSql, FromSql}; -pub use io_util::{SslMode, NegotiateSsl, StreamWrapper, NoSsl}; +pub use io_util::{SslMode, NegotiateSsl, StreamWrapper, NoSsl, Stream}; use types::IsNull; #[doc(inline)] pub use types::Slice; -use io_util::InternalStream; use message::BackendMessage::*; use message::FrontendMessage::*; use message::{FrontendMessage, BackendMessage, RowDescriptionEntry}; @@ -465,7 +464,7 @@ struct CachedStatement { } struct InnerConnection { - stream: BufStream>>, + stream: BufStream>, notice_handler: Box, notifications: VecDeque, cancel_data: CancelData,