diff --git a/postgres-tokio/Cargo.toml b/postgres-tokio/Cargo.toml index ab62f9d1..7ff9c7b7 100644 --- a/postgres-tokio/Cargo.toml +++ b/postgres-tokio/Cargo.toml @@ -3,6 +3,9 @@ name = "postgres-tokio" version = "0.1.0" authors = ["Steven Fackler "] +[features] +with-openssl = ["tokio-openssl", "openssl"] + [dependencies] fallible-iterator = "0.1.3" futures = "0.1.7" @@ -12,3 +15,6 @@ postgres-protocol = "0.2" tokio-core = "0.1" tokio-dns-unofficial = "0.1" tokio-uds = "0.1" + +tokio-openssl = { version = "0.1", optional = true } +openssl = { version = "0.9", optional = true } diff --git a/postgres-tokio/src/lib.rs b/postgres-tokio/src/lib.rs index d4570acb..5525583c 100644 --- a/postgres-tokio/src/lib.rs +++ b/postgres-tokio/src/lib.rs @@ -7,6 +7,11 @@ extern crate tokio_core; extern crate tokio_dns; extern crate tokio_uds; +#[cfg(feature = "tokio-openssl")] +extern crate tokio_openssl; +#[cfg(feature = "openssl")] +extern crate openssl; + use fallible_iterator::FallibleIterator; use futures::{Future, IntoFuture, BoxFuture, Stream, Sink, Poll, StartSend}; use futures::future::Either; @@ -31,13 +36,21 @@ use error::{ConnectError, Error, DbError}; use params::{ConnectParams, IntoConnectParams}; use stream::PostgresStream; use types::{Oid, Type, ToSql, SessionInfo, IsNull, FromSql, WrongType}; +use tls::Handshake; pub mod error; mod stream; +pub mod tls; #[cfg(test)] mod test; +pub enum TlsMode { + Require(Box), + Prefer(Box), + None, +} + #[derive(Debug, Copy, Clone)] pub struct CancelData { pub process_id: i32, @@ -119,7 +132,10 @@ impl fmt::Debug for Connection { } impl Connection { - pub fn connect(params: T, handle: &Handle) -> BoxFuture + pub fn connect(params: T, + tls_mode: TlsMode, + handle: &Handle) + -> BoxFuture where T: IntoConnectParams { let params = match params.into_connect_params() { @@ -127,8 +143,7 @@ impl Connection { Err(e) => return futures::failed(ConnectError::ConnectParams(e)).boxed(), }; - stream::connect(params.host(), params.port(), handle) - .map_err(ConnectError::Io) + stream::connect(params.host().clone(), params.port(), tls_mode, handle) .map(|s| { let (sender, receiver) = mpsc::channel(); Connection(InnerConnection { diff --git a/postgres-tokio/src/stream.rs b/postgres-tokio/src/stream.rs index daaa6083..d5ed78d3 100644 --- a/postgres-tokio/src/stream.rs +++ b/postgres-tokio/src/stream.rs @@ -1,6 +1,8 @@ -use futures::{BoxFuture, Future, IntoFuture, Async}; +use futures::{BoxFuture, Future, IntoFuture, Async, Sink, Stream as FuturesStream}; +use futures::future::Either; use postgres_shared::params::Host; use postgres_protocol::message::backend::{self, ParseResult}; +use postgres_protocol::message::frontend; use std::io::{self, Read, Write}; use tokio_core::io::{Io, Codec, EasyBuf, Framed}; use tokio_core::net::TcpStream; @@ -8,68 +10,117 @@ use tokio_core::reactor::Handle; use tokio_dns; use tokio_uds::UnixStream; -pub type PostgresStream = Framed; +use TlsMode; +use error::ConnectError; +use tls::TlsStream; -pub fn connect(host: &Host, - port: u16, - handle: &Handle) - -> BoxFuture { - match *host { +pub type PostgresStream = Framed, PostgresCodec>; + +pub fn connect(host: Host, + port: u16, + tls_mode: TlsMode, + handle: &Handle) + -> BoxFuture { + let inner = match host { Host::Tcp(ref host) => { - tokio_dns::tcp_connect((&**host, port), handle.remote().clone()) - .map(|s| InnerStream::Tcp(s).framed(PostgresCodec)) - .boxed() + Either::A(tokio_dns::tcp_connect((&**host, port), handle.remote().clone()) + .map(|s| Stream(InnerStream::Tcp(s)))) } Host::Unix(ref host) => { let addr = host.join(format!(".s.PGSQL.{}", port)); - UnixStream::connect(addr, handle) - .map(|s| InnerStream::Unix(s).framed(PostgresCodec)) - .into_future() - .boxed() + Either::B(UnixStream::connect(addr, handle) + .map(|s| Stream(InnerStream::Unix(s))) + .into_future()) } - } + }; + + let (required, mut handshaker) = match tls_mode { + TlsMode::Require(h) => (true, h), + TlsMode::Prefer(h) => (false, h), + TlsMode::None => { + return inner.map(|s| { + let s: Box = Box::new(s); + s.framed(PostgresCodec) + }) + .map_err(ConnectError::Io) + .boxed() + }, + }; + + inner.map(|s| s.framed(SslCodec)) + .and_then(|s| { + let mut buf = vec![]; + frontend::ssl_request(&mut buf); + s.send(buf) + }) + .and_then(|s| s.into_future().map_err(|e| e.0)) + .map_err(ConnectError::Io) + .and_then(move |(m, s)| { + let s = s.into_inner(); + match (m, required) { + (Some(b'N'), true) => { + Either::A(Err(ConnectError::Tls("the server does not support TLS".into())).into_future()) + } + (Some(b'N'), false) => { + let s: Box = Box::new(s); + Either::A(Ok(s).into_future()) + }, + (None, _) => Either::A(Err(ConnectError::Io(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF"))).into_future()), + _ => { + let host = match host { + Host::Tcp(ref host) => host, + Host::Unix(_) => unreachable!(), + }; + Either::B(handshaker.handshake(host, s).map_err(ConnectError::Tls)) + } + } + }) + .map(|s| s.framed(PostgresCodec)) + .boxed() } -pub enum InnerStream { +pub struct Stream(InnerStream); + +enum InnerStream { Tcp(TcpStream), Unix(UnixStream), } -impl Read for InnerStream { +impl Read for Stream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - match *self { + match self.0 { InnerStream::Tcp(ref mut s) => s.read(buf), InnerStream::Unix(ref mut s) => s.read(buf), } } } -impl Write for InnerStream { +impl Write for Stream { fn write(&mut self, buf: &[u8]) -> io::Result { - match *self { + match self.0 { InnerStream::Tcp(ref mut s) => s.write(buf), InnerStream::Unix(ref mut s) => s.write(buf), } } fn flush(&mut self) -> io::Result<()> { - match *self { + match self.0 { InnerStream::Tcp(ref mut s) => s.flush(), InnerStream::Unix(ref mut s) => s.flush(), } } } -impl Io for InnerStream { +impl Io for Stream { fn poll_read(&mut self) -> Async<()> { - match *self { + match self.0 { InnerStream::Tcp(ref mut s) => s.poll_read(), InnerStream::Unix(ref mut s) => s.poll_read(), } } fn poll_write(&mut self) -> Async<()> { - match *self { + match self.0 { InnerStream::Tcp(ref mut s) => s.poll_write(), InnerStream::Unix(ref mut s) => s.poll_write(), } @@ -98,3 +149,25 @@ impl Codec for PostgresCodec { Ok(()) } } + +struct SslCodec; + +impl Codec for SslCodec { + type In = u8; + type Out = Vec; + + fn decode(&mut self, buf: &mut EasyBuf) -> io::Result> { + if buf.as_slice().is_empty() { + Ok(None) + } else { + let byte = buf.as_slice()[0]; + buf.drain_to(1); + Ok(Some(byte)) + } + } + + fn encode(&mut self, msg: Vec, buf: &mut Vec) -> io::Result<()> { + buf.extend_from_slice(&msg); + Ok(()) + } +} diff --git a/postgres-tokio/src/test.rs b/postgres-tokio/src/test.rs index 2dea6493..3f3d6a25 100644 --- a/postgres-tokio/src/test.rs +++ b/postgres-tokio/src/test.rs @@ -10,7 +10,7 @@ use params::ConnectParams; fn basic() { let mut l = Core::new().unwrap(); let handle = l.handle(); - let done = Connection::connect("postgres://postgres@localhost", &handle) + let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &handle) .then(|c| c.unwrap().close()); l.run(done).unwrap(); } @@ -19,7 +19,9 @@ fn basic() { fn md5_user() { let mut l = Core::new().unwrap(); let handle = l.handle(); - let done = Connection::connect("postgres://md5_user:password@localhost/postgres", &handle); + let done = Connection::connect("postgres://md5_user:password@localhost/postgres", + TlsMode::None, + &handle); l.run(done).unwrap(); } @@ -27,7 +29,9 @@ fn md5_user() { fn md5_user_no_pass() { let mut l = Core::new().unwrap(); let handle = l.handle(); - let done = Connection::connect("postgres://md5_user@localhost/postgres", &handle); + let done = Connection::connect("postgres://md5_user@localhost/postgres", + TlsMode::None, + &handle); match l.run(done) { Err(ConnectError::ConnectParams(_)) => {} Err(e) => panic!("unexpected error {}", e), @@ -39,7 +43,9 @@ fn md5_user_no_pass() { fn md5_user_wrong_pass() { let mut l = Core::new().unwrap(); let handle = l.handle(); - let done = Connection::connect("postgres://md5_user:foobar@localhost/postgres", &handle); + let done = Connection::connect("postgres://md5_user:foobar@localhost/postgres", + TlsMode::None, + &handle); match l.run(done) { Err(ConnectError::Db(ref e)) if e.code == SqlState::InvalidPassword => {} Err(e) => panic!("unexpected error {}", e), @@ -51,7 +57,9 @@ fn md5_user_wrong_pass() { fn pass_user() { let mut l = Core::new().unwrap(); let handle = l.handle(); - let done = Connection::connect("postgres://pass_user:password@localhost/postgres", &handle); + let done = Connection::connect("postgres://pass_user:password@localhost/postgres", + TlsMode::None, + &handle); l.run(done).unwrap(); } @@ -59,7 +67,9 @@ fn pass_user() { fn pass_user_no_pass() { let mut l = Core::new().unwrap(); let handle = l.handle(); - let done = Connection::connect("postgres://pass_user@localhost/postgres", &handle); + let done = Connection::connect("postgres://pass_user@localhost/postgres", + TlsMode::None, + &handle); match l.run(done) { Err(ConnectError::ConnectParams(_)) => {} Err(e) => panic!("unexpected error {}", e), @@ -71,7 +81,9 @@ fn pass_user_no_pass() { fn pass_user_wrong_pass() { let mut l = Core::new().unwrap(); let handle = l.handle(); - let done = Connection::connect("postgres://pass_user:foobar@localhost/postgres", &handle); + let done = Connection::connect("postgres://pass_user:foobar@localhost/postgres", + TlsMode::None, + &handle); match l.run(done) { Err(ConnectError::Db(ref e)) if e.code == SqlState::InvalidPassword => {} Err(e) => panic!("unexpected error {}", e), @@ -82,7 +94,7 @@ fn pass_user_wrong_pass() { #[test] fn batch_execute_ok() { let mut l = Core::new().unwrap(); - let done = Connection::connect("postgres://postgres@localhost", &l.handle()) + let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle()) .then(|c| c.unwrap().batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL);")); l.run(done).unwrap(); } @@ -90,7 +102,7 @@ fn batch_execute_ok() { #[test] fn batch_execute_err() { let mut l = Core::new().unwrap(); - let done = Connection::connect("postgres://postgres@localhost", &l.handle()) + let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle()) .then(|r| r.unwrap().batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL); \ INSERT INTO foo DEFAULT VALUES;")) .and_then(|c| c.batch_execute("SELECT * FROM bogo")) @@ -110,7 +122,7 @@ fn batch_execute_err() { #[test] fn prepare_execute() { let mut l = Core::new().unwrap(); - let done = Connection::connect("postgres://postgres@localhost", &l.handle()) + let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle()) .then(|c| { c.unwrap().prepare("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY, name VARCHAR)") }) @@ -127,7 +139,7 @@ fn prepare_execute() { #[test] fn query() { let mut l = Core::new().unwrap(); - let done = Connection::connect("postgres://postgres@localhost", &l.handle()) + let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle()) .then(|c| { c.unwrap().batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name VARCHAR); INSERT INTO foo (name) VALUES ('joe'), ('bob')") @@ -149,7 +161,7 @@ fn query() { #[test] fn transaction() { let mut l = Core::new().unwrap(); - let done = Connection::connect("postgres://postgres@localhost", &l.handle()) + let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle()) .then(|c| c.unwrap().batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name VARCHAR);")) .then(|c| c.unwrap().transaction()) .then(|t| t.unwrap().batch_execute("INSERT INTO foo (name) VALUES ('joe');")) @@ -170,7 +182,7 @@ fn transaction() { fn unix_socket() { let mut l = Core::new().unwrap(); let handle = l.handle(); - let done = Connection::connect("postgres://postgres@localhost", &handle) + let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &handle) .then(|c| c.unwrap().prepare("SHOW unix_socket_directories")) .and_then(|(s, c)| c.query(&s, &[]).collect()) .then(|r| { @@ -178,8 +190,28 @@ fn unix_socket() { let params = ConnectParams::builder() .user("postgres", None) .build_unix(r[0].get::(0)); - Connection::connect(params, &handle) + Connection::connect(params, TlsMode::None, &handle) }) .then(|c| c.unwrap().batch_execute("")); l.run(done).unwrap(); } + +#[cfg(feature = "with-openssl")] +#[test] +fn openssl_required() { + use openssl::ssl::{SslMethod, SslConnectorBuilder}; + use tls::openssl::OpenSsl; + + let mut builder = SslConnectorBuilder::new(SslMethod::tls()).unwrap(); + builder.builder_mut().set_ca_file("../.travis/server.crt").unwrap(); + let negotiator = OpenSsl::from(builder.build()); + + let mut l = Core::new().unwrap(); + let done = Connection::connect("postgres://postgres@localhost", + TlsMode::Require(Box::new(negotiator)), + &l.handle()) + .then(|c| c.unwrap().prepare("SELECT 1")) + .and_then(|(s, c)| c.query(&s, &[]).collect()) + .map(|(r, _)| assert_eq!(r[0].get::(0), 1)); + l.run(done).unwrap(); +} diff --git a/postgres-tokio/src/tls/mod.rs b/postgres-tokio/src/tls/mod.rs new file mode 100644 index 00000000..f27051f3 --- /dev/null +++ b/postgres-tokio/src/tls/mod.rs @@ -0,0 +1,33 @@ +use futures::BoxFuture; +use std::error::Error; +use tokio_core::io::Io; + +pub use stream::Stream; + +#[cfg(feature = "with-openssl")] +pub mod openssl; + +pub trait TlsStream: Io + Send { + fn get_ref(&self) -> &Stream; + + fn get_mut(&mut self) -> &mut Stream; +} + +impl Io for Box {} + +impl TlsStream for Stream { + fn get_ref(&self) -> &Stream { + self + } + + fn get_mut(&mut self) -> &mut Stream { + self + } +} + +pub trait Handshake: 'static + Sync + Send { + fn handshake(&mut self, + host: &str, + stream: Stream) + -> BoxFuture, Box>; +} diff --git a/postgres-tokio/src/tls/openssl.rs b/postgres-tokio/src/tls/openssl.rs new file mode 100644 index 00000000..df5926e7 --- /dev/null +++ b/postgres-tokio/src/tls/openssl.rs @@ -0,0 +1,50 @@ +use futures::{Future, BoxFuture}; +use openssl::ssl::{SslMethod, SslConnector, SslConnectorBuilder}; +use openssl::error::ErrorStack; +use std::error::Error; +use tokio_openssl::{SslConnectorExt, SslStream}; + +use tls::{Stream, TlsStream, Handshake}; + +impl TlsStream for SslStream { + fn get_ref(&self) -> &Stream { + self.get_ref().get_ref() + } + + fn get_mut(&mut self) -> &mut Stream { + self.get_mut().get_mut() + } +} + +pub struct OpenSsl(SslConnector); + +impl OpenSsl { + pub fn new() -> Result { + let connector = try!(SslConnectorBuilder::new(SslMethod::tls())).build(); + Ok(OpenSsl(connector)) + } +} + +impl From for OpenSsl { + fn from(connector: SslConnector) -> OpenSsl { + OpenSsl(connector) + } +} + +impl Handshake for OpenSsl { + fn handshake(&mut self, + host: &str, + stream: Stream) + -> BoxFuture, Box> { + self.0.connect_async(host, stream) + .map(|s| { + let s: Box = Box::new(s); + s + }) + .map_err(|e| { + let e: Box = Box::new(e); + e + }) + .boxed() + } +} \ No newline at end of file