diff --git a/.travis.yml b/.travis.yml index dd471519..eee474f5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,7 @@ before_script: - "./.travis/setup.sh" script: - cargo test -- cargo test --features "uuid rustc-serialize time unix_socket serde chrono" +- cargo test --features "uuid rustc-serialize time unix_socket serde chrono openssl" - cargo doc --no-deps --features "unix_socket" after_success: - test $TRAVIS_PULL_REQUEST == "false" && test $TRAVIS_BRANCH == "master" && test $TRAVIS_RUST_VERSION == "nightly" && ./.travis/update_docs.sh diff --git a/Cargo.toml b/Cargo.toml index d0f3ab34..e40274e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,18 +24,19 @@ path = "tests/test.rs" phf_codegen = "0.7" [dependencies] -phf = "0.7" -openssl = "0.6" -log = "0.3" -rustc-serialize = "0.3" +bufstream = "0.1" byteorder = "0.3" debug-builders = "0.1" -bufstream = "0.1" -uuid = { version = "0.1", optional = true } -unix_socket = { version = "0.3", optional = true } -time = { version = "0.1.14", optional = true } -serde = { version = "0.3", optional = true } +log = "0.3" +phf = "0.7" +rust-crypto = "0.2" +rustc-serialize = "0.3" chrono = { version = "0.2.14", optional = true } +openssl = { version = "0.6", optional = true } +serde = { version = "0.3", optional = true } +time = { version = "0.1.14", optional = true } +unix_socket = { version = "0.3", optional = true } +uuid = { version = "0.1", optional = true } [dev-dependencies] url = "0.2" diff --git a/src/error.rs b/src/error.rs index 7b491d09..62a5b778 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,7 +1,6 @@ pub use ugh_privacy::DbError; use byteorder; -use openssl::ssl::error::SslError; use phf; use std::error; use std::convert::From; @@ -29,8 +28,8 @@ pub enum ConnectError { UnsupportedAuthentication, /// The Postgres server does not support SSL encryption. NoSslSupport, - /// There was an error initializing the SSL session. - SslError(SslError), + /// There was an error initializing the SSL session + SslError(Box), /// There was an error communicating with the server. IoError(io::Error), /// The server sent an unexpected response. @@ -67,7 +66,7 @@ impl error::Error for ConnectError { fn cause(&self) -> Option<&error::Error> { match *self { ConnectError::DbError(ref err) => Some(err), - ConnectError::SslError(ref err) => Some(err), + ConnectError::SslError(ref err) => Some(&**err), ConnectError::IoError(ref err) => Some(err), _ => None } @@ -86,12 +85,6 @@ impl From for ConnectError { } } -impl From for ConnectError { - fn from(err: SslError) -> ConnectError { - ConnectError::SslError(err) - } -} - impl From for ConnectError { fn from(err: byteorder::Error) -> ConnectError { ConnectError::IoError(From::from(err)) diff --git a/src/io/mod.rs b/src/io/mod.rs new file mode 100644 index 00000000..addbe916 --- /dev/null +++ b/src/io/mod.rs @@ -0,0 +1,27 @@ +//! Types and traits for SSL adaptors. +pub use priv_io::Stream; + +use std::error::Error; +use std::io::prelude::*; + +#[cfg(feature = "openssl")] +mod openssl; + +/// A trait implemented by SSL adaptors. +pub trait StreamWrapper: Read+Write+Send { + /// Returns a reference to the underlying `Stream`. + fn get_ref(&self) -> &Stream; + + /// Returns a mutable reference to the underlying `Stream`. + fn get_mut(&mut self) -> &mut Stream; +} + +/// A trait implemented by types that can negotiate SSL over a Postgres stream. +pub trait NegotiateSsl { + /// Negotiates an SSL session, returning a wrapper around the provided + /// stream. + /// + /// The host portion of the connection parameters is provided for hostname + /// verification. + fn negotiate_ssl(&self, host: &str, stream: Stream) -> Result, Box>; +} diff --git a/src/io/openssl.rs b/src/io/openssl.rs new file mode 100644 index 00000000..999c4bab --- /dev/null +++ b/src/io/openssl.rs @@ -0,0 +1,23 @@ +extern crate openssl; + +use std::error::Error; + +use self::openssl::ssl::{SslContext, SslStream}; +use io::{StreamWrapper, Stream, NegotiateSsl}; + +impl StreamWrapper for SslStream { + fn get_ref(&self) -> &Stream { + self.get_ref() + } + + fn get_mut(&mut self) -> &mut Stream { + self.get_mut() + } +} + +impl NegotiateSsl for SslContext { + fn negotiate_ssl(&self, _: &str, stream: Stream) -> Result, Box> { + let stream = try!(SslStream::new(self, stream)); + Ok(Box::new(stream)) + } +} diff --git a/src/io_util.rs b/src/io_util.rs deleted file mode 100644 index 2661d09d..00000000 --- a/src/io_util.rs +++ /dev/null @@ -1,90 +0,0 @@ -use openssl::ssl::{SslStream, MaybeSslStream}; -use std::io; -use std::io::prelude::*; -use std::net::TcpStream; -#[cfg(feature = "unix_socket")] -use unix_socket::UnixStream; -use byteorder::ReadBytesExt; - -use {ConnectParams, SslMode, ConnectTarget, ConnectError}; -use message; -use message::WriteMessage; -use message::FrontendMessage::SslRequest; - -const DEFAULT_PORT: u16 = 5432; - -pub enum InternalStream { - Tcp(TcpStream), - #[cfg(feature = "unix_socket")] - Unix(UnixStream), -} - -impl Read for InternalStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match *self { - InternalStream::Tcp(ref mut s) => s.read(buf), - #[cfg(feature = "unix_socket")] - InternalStream::Unix(ref mut s) => s.read(buf), - } - } -} - -impl Write for InternalStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - match *self { - InternalStream::Tcp(ref mut s) => s.write(buf), - #[cfg(feature = "unix_socket")] - InternalStream::Unix(ref mut s) => s.write(buf), - } - } - - fn flush(&mut self) -> io::Result<()> { - match *self { - InternalStream::Tcp(ref mut s) => s.flush(), - #[cfg(feature = "unix_socket")] - InternalStream::Unix(ref mut s) => s.flush(), - } - } -} - -fn open_socket(params: &ConnectParams) -> Result { - let port = params.port.unwrap_or(DEFAULT_PORT); - match params.target { - ConnectTarget::Tcp(ref host) => { - Ok(try!(TcpStream::connect(&(&**host, port)).map(InternalStream::Tcp))) - } - #[cfg(feature = "unix_socket")] - ConnectTarget::Unix(ref path) => { - let mut path = path.clone(); - path.push(&format!(".s.PGSQL.{}", port)); - Ok(try!(UnixStream::connect(&path).map(InternalStream::Unix))) - } - } -} - -pub fn initialize_stream(params: &ConnectParams, ssl: &SslMode) - -> Result, ConnectError> { - let mut socket = try!(open_socket(params)); - - let (ssl_required, ctx) = match *ssl { - SslMode::None => return Ok(MaybeSslStream::Normal(socket)), - SslMode::Prefer(ref ctx) => (false, ctx), - SslMode::Require(ref ctx) => (true, ctx) - }; - - try!(socket.write_message(&SslRequest { code: message::SSL_CODE })); - try!(socket.flush()); - - if try!(socket.read_u8()) == 'N' as u8 { - if ssl_required { - return Err(ConnectError::NoSslSupport); - } else { - return Ok(MaybeSslStream::Normal(socket)); - } - } - - match SslStream::new(ctx, socket) { - Ok(stream) => Ok(MaybeSslStream::Ssl(stream)), - Err(err) => Err(ConnectError::SslError(err)) - } -} diff --git a/src/lib.rs b/src/lib.rs index 86e0f17e..b99a62f8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,9 +47,9 @@ extern crate bufstream; extern crate byteorder; +extern crate crypto; #[macro_use] extern crate log; -extern crate openssl; extern crate phf; extern crate rustc_serialize as serialize; #[cfg(feature = "unix_socket")] @@ -57,17 +57,16 @@ extern crate unix_socket; extern crate debug_builders; use bufstream::BufStream; +use crypto::digest::Digest; +use crypto::md5::Md5; use debug_builders::DebugStruct; -use openssl::crypto::hash::{self, Hasher}; -use openssl::ssl::{SslContext, MaybeSslStream}; -use serialize::hex::ToHex; use std::ascii::AsciiExt; use std::borrow::{ToOwned, Cow}; use std::cell::{Cell, RefCell}; use std::collections::{VecDeque, HashMap}; use std::fmt; use std::iter::IntoIterator; -use std::io; +use std::io as std_io; use std::io::prelude::*; use std::mem; use std::slice; @@ -80,10 +79,10 @@ use std::path::PathBuf; pub use error::{Error, ConnectError, SqlState, DbError, ErrorPosition}; #[doc(inline)] pub use types::{Oid, Type, Kind, ToSql, FromSql}; +use io::{StreamWrapper, NegotiateSsl}; use types::IsNull; #[doc(inline)] pub use types::Slice; -use io_util::InternalStream; use message::BackendMessage::*; use message::FrontendMessage::*; use message::{FrontendMessage, BackendMessage, RowDescriptionEntry}; @@ -94,9 +93,10 @@ use url::Url; mod macros; mod error; -mod io_util; +pub mod io; mod message; mod ugh_privacy; +mod priv_io; mod url; mod util; pub mod types; @@ -388,9 +388,10 @@ pub struct CancelData { /// postgres::cancel_query(url, &SslMode::None, cancel_data); /// ``` pub fn cancel_query(params: T, ssl: &SslMode, data: CancelData) - -> result::Result<(), ConnectError> where T: IntoConnectParams { + -> result::Result<(), ConnectError> + where T: IntoConnectParams { let params = try!(params.into_connect_params()); - let mut socket = try!(io_util::initialize_stream(¶ms, ssl)); + let mut socket = try!(priv_io::initialize_stream(¶ms, ssl)); try!(socket.write_message(&CancelRequest { code: message::CANCEL_CODE, @@ -456,6 +457,16 @@ impl IsolationLevel { } } +/// Specifies the SSL support requested for a new connection. +pub enum SslMode { + /// The connection will not use SSL. + None, + /// The connection will use SSL if the backend supports it. + Prefer(Box), + /// The connection must use SSL. + Require(Box), +} + #[derive(Clone)] struct CachedStatement { name: String, @@ -464,7 +475,7 @@ struct CachedStatement { } struct InnerConnection { - stream: BufStream>, + stream: BufStream>, notice_handler: Box, notifications: VecDeque, cancel_data: CancelData, @@ -489,7 +500,7 @@ impl InnerConnection { fn connect(params: T, ssl: &SslMode) -> result::Result where T: IntoConnectParams { let params = try!(params.into_connect_params()); - let stream = try!(io_util::initialize_stream(¶ms, ssl)); + let stream = try!(priv_io::initialize_stream(¶ms, ssl)); let ConnectParams { user, database, mut options, .. } = params; @@ -569,7 +580,7 @@ impl InnerConnection { } } - fn write_messages(&mut self, messages: &[FrontendMessage]) -> io::Result<()> { + fn write_messages(&mut self, messages: &[FrontendMessage]) -> std_io::Result<()> { debug_assert!(!self.desynchronized); for message in messages { try_desync!(self, self.stream.write_message(message)); @@ -577,7 +588,7 @@ impl InnerConnection { Ok(try_desync!(self, self.stream.flush())) } - fn read_one_message(&mut self) -> io::Result> { + fn read_one_message(&mut self) -> std_io::Result> { debug_assert!(!self.desynchronized); match try_desync!(self, self.stream.read_message()) { NoticeResponse { fields } => { @@ -594,7 +605,7 @@ impl InnerConnection { } } - fn read_message_with_notification(&mut self) -> io::Result { + fn read_message_with_notification(&mut self) -> std_io::Result { loop { if let Some(msg) = try!(self.read_one_message()) { return Ok(msg); @@ -602,7 +613,7 @@ impl InnerConnection { } } - fn read_message(&mut self) -> io::Result { + fn read_message(&mut self) -> std_io::Result { loop { match try!(self.read_message_with_notification()) { NotificationResponse { pid, channel, payload } => { @@ -628,13 +639,14 @@ impl InnerConnection { } AuthenticationMD5Password { salt } => { let pass = try!(user.password.ok_or(ConnectError::MissingPassword)); - let mut hasher = Hasher::new(hash::Type::MD5); - let _ = hasher.write_all(pass.as_bytes()); - let _ = hasher.write_all(user.user.as_bytes()); - let output = hasher.finish().to_hex(); - let _ = hasher.write_all(output.as_bytes()); - let _ = hasher.write_all(&salt); - let output = format!("md5{}", hasher.finish().to_hex()); + let mut hasher = Md5::new(); + let _ = hasher.input(pass.as_bytes()); + let _ = hasher.input(user.user.as_bytes()); + let output = hasher.result_str(); + hasher.reset(); + let _ = hasher.input(output.as_bytes()); + let _ = hasher.input(&salt); + let output = format!("md5{}", hasher.result_str()); try!(self.write_messages(&[PasswordMessage { password: &output }])); @@ -1131,13 +1143,6 @@ impl Connection { self.batch_execute(level.to_set_query()) } - /// # Deprecated - /// - /// Use `transaction_isolation` instead. - pub fn get_transaction_isolation(&self) -> Result { - self.transaction_isolation() - } - /// Returns the isolation level which will be used for future transactions. pub fn transaction_isolation(&self) -> Result { let mut conn = self.conn.borrow_mut(); @@ -1251,17 +1256,6 @@ impl Connection { } } -/// Specifies the SSL support requested for a new connection. -#[derive(Debug)] -pub enum SslMode { - /// The connection will not use SSL. - None, - /// The connection will use SSL if the backend supports it. - Prefer(SslContext), - /// The connection must use SSL. - Require(SslContext) -} - /// Represents a transaction on a database connection. /// /// The transaction will roll back by default. diff --git a/src/priv_io.rs b/src/priv_io.rs new file mode 100644 index 00000000..6a531936 --- /dev/null +++ b/src/priv_io.rs @@ -0,0 +1,153 @@ +use byteorder::ReadBytesExt; +use std::io; +use std::io::prelude::*; +use std::net::TcpStream; +#[cfg(feature = "unix_socket")] +use unix_socket::UnixStream; +#[cfg(unix)] +use std::os::unix::io::{AsRawFd, RawFd}; +#[cfg(windows)] +use std::os::windows::io::{AsRawSocket, RawSocket}; + +use {SslMode, ConnectError, ConnectParams, ConnectTarget}; +use io::{NegotiateSsl, StreamWrapper}; +use message::{self, WriteMessage}; +use message::FrontendMessage::SslRequest; + +const DEFAULT_PORT: u16 = 5432; + +/// A connection to the Postgres server. +/// +/// It implements `Read`, `Write` and `StreamWrapper`, as well as `AsRawFd` on +/// Unix platforms and `AsRawSocket` on Windows platforms. +pub struct Stream(InternalStream); + +impl Read for Stream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } +} + +impl Write for Stream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } +} + +impl StreamWrapper for Stream { + fn get_ref(&self) -> &Stream { + self + } + + fn get_mut(&mut self) -> &mut Stream { + self + } +} + +#[cfg(unix)] +impl AsRawFd for Stream { + fn as_raw_fd(&self) -> RawFd { + match self.0 { + InternalStream::Tcp(ref s) => s.as_raw_fd(), + #[cfg(feature = "unix_socket")] + InternalStream::Unix(ref s) => s.as_raw_fd(), + } + } +} + +#[cfg(windows)] +impl AsRawSocket for Stream { + fn as_raw_socket(&self) -> RawSocket { + // Unix sockets aren't supported on windows, so no need to match + match self.0 { + InternalStream::Tcp(ref s) => s.as_raw_socket(), + } + } +} + +enum InternalStream { + Tcp(TcpStream), + #[cfg(feature = "unix_socket")] + Unix(UnixStream), +} + +impl Read for InternalStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match *self { + InternalStream::Tcp(ref mut s) => s.read(buf), + #[cfg(feature = "unix_socket")] + InternalStream::Unix(ref mut s) => s.read(buf), + } + } +} + +impl Write for InternalStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + match *self { + InternalStream::Tcp(ref mut s) => s.write(buf), + #[cfg(feature = "unix_socket")] + InternalStream::Unix(ref mut s) => s.write(buf), + } + } + + fn flush(&mut self) -> io::Result<()> { + match *self { + InternalStream::Tcp(ref mut s) => s.flush(), + #[cfg(feature = "unix_socket")] + InternalStream::Unix(ref mut s) => s.flush(), + } + } +} + +fn open_socket(params: &ConnectParams) -> Result { + let port = params.port.unwrap_or(DEFAULT_PORT); + match params.target { + ConnectTarget::Tcp(ref host) => { + Ok(try!(TcpStream::connect(&(&**host, port)).map(InternalStream::Tcp))) + } + #[cfg(feature = "unix_socket")] + ConnectTarget::Unix(ref path) => { + let mut path = path.clone(); + path.push(&format!(".s.PGSQL.{}", port)); + Ok(try!(UnixStream::connect(&path).map(InternalStream::Unix))) + } + } +} + +pub fn initialize_stream(params: &ConnectParams, ssl: &SslMode) + -> Result, ConnectError> { + let mut socket = Stream(try!(open_socket(params))); + + let (ssl_required, negotiator) = match *ssl { + SslMode::None => return Ok(Box::new(socket)), + SslMode::Prefer(ref negotiator) => (false, negotiator), + SslMode::Require(ref negotiator) => (true, negotiator), + }; + + try!(socket.write_message(&SslRequest { code: message::SSL_CODE })); + try!(socket.flush()); + + if try!(socket.read_u8()) == 'N' as u8 { + if ssl_required { + return Err(ConnectError::NoSslSupport); + } else { + return Ok(Box::new(socket)); + } + } + + // Postgres doesn't support SSL over unix sockets + let host = match params.target { + ConnectTarget::Tcp(ref host) => host, + #[cfg(feature = "unix_socket")] + ConnectTarget::Unix(_) => return Err(ConnectError::BadResponse) + }; + + match negotiator.negotiate_ssl(host, socket) { + Ok(stream) => Ok(stream), + Err(err) => Err(ConnectError::SslError(err)) + } +} diff --git a/tests/test.rs b/tests/test.rs index df596854..4a910f74 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,10 +1,11 @@ extern crate postgres; extern crate rustc_serialize as serialize; extern crate url; +#[cfg(feature = "openssl")] extern crate openssl; -use openssl::ssl::SslContext; -use openssl::ssl::SslMethod; +#[cfg(feature = "openssl")] +use openssl::ssl::{SslContext, SslMethod}; use std::thread; use postgres::{HandleNotice, @@ -670,18 +671,20 @@ fn test_cancel_query() { } #[test] +#[cfg(feature = "openssl")] fn test_require_ssl_conn() { let ctx = SslContext::new(SslMethod::Sslv23).unwrap(); let conn = or_panic!(Connection::connect("postgres://postgres@localhost", - &SslMode::Require(ctx))); + &mut SslMode::Require(Box::new(ctx)))); or_panic!(conn.execute("SELECT 1::VARCHAR", &[])); } #[test] +#[cfg(feature = "openssl")] fn test_prefer_ssl_conn() { let ctx = SslContext::new(SslMethod::Sslv23).unwrap(); let conn = or_panic!(Connection::connect("postgres://postgres@localhost", - &SslMode::Prefer(ctx))); + &mut SslMode::Prefer(Box::new(ctx)))); or_panic!(conn.execute("SELECT 1::VARCHAR", &[])); }