diff --git a/Cargo.toml b/Cargo.toml index 254b755c..568b3135 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "postgres-openssl", "postgres-native-tls", "tokio-postgres", + "tokio-postgres-openssl", ] [patch.crates-io] diff --git a/tokio-postgres-openssl/Cargo.toml b/tokio-postgres-openssl/Cargo.toml new file mode 100644 index 00000000..95534920 --- /dev/null +++ b/tokio-postgres-openssl/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "tokio-postgres-openssl" +version = "0.1.0" +authors = ["Steven Fackler "] + +[dependencies] +bytes = "0.4" +futures = "0.1" +openssl = "0.10" +tokio-io = "0.1" +tokio-openssl = "0.2" +tokio-postgres = { version = "0.3", path = "../tokio-postgres" } + +[dev-dependencies] +tokio = "0.1.7" diff --git a/tokio-postgres-openssl/src/lib.rs b/tokio-postgres-openssl/src/lib.rs new file mode 100644 index 00000000..9382f977 --- /dev/null +++ b/tokio-postgres-openssl/src/lib.rs @@ -0,0 +1,127 @@ +extern crate bytes; +extern crate futures; +extern crate openssl; +extern crate tokio_io; +extern crate tokio_openssl; +extern crate tokio_postgres; + +#[cfg(test)] +extern crate tokio; + +use bytes::{Buf, BufMut}; +use futures::{Future, IntoFuture, Poll}; +use openssl::error::ErrorStack; +use openssl::ssl::{ConnectConfiguration, SslConnector, SslMethod}; +use std::error::Error; +use std::io::{self, Read, Write}; +use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_openssl::ConnectConfigurationExt; +use tokio_postgres::tls::{Socket, TlsConnect, TlsStream}; + +#[cfg(test)] +mod test; + +pub struct TlsConnector { + connector: SslConnector, + callback: Box Result<(), ErrorStack> + Sync + Send>, +} + +impl TlsConnector { + pub fn new() -> Result { + let connector = SslConnector::builder(SslMethod::tls())?.build(); + Ok(TlsConnector::with_connector(connector)) + } + + pub fn with_connector(connector: SslConnector) -> TlsConnector { + TlsConnector { + connector, + callback: Box::new(|_| Ok(())), + } + } + + pub fn set_callback(&mut self, f: F) + where + F: Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + 'static + Sync + Send, + { + self.callback = Box::new(f); + } +} + +impl TlsConnect for TlsConnector { + fn connect( + &self, + domain: &str, + socket: Socket, + ) -> Box, Error = Box> + Sync + Send> { + let f = self + .connector + .configure() + .and_then(|mut ssl| (self.callback)(&mut ssl).map(|_| ssl)) + .map_err(|e| { + let e: Box = Box::new(e); + e + }) + .into_future() + .and_then({ + let domain = domain.to_string(); + move |ssl| { + ssl.connect_async(&domain, socket) + .map(|s| { + let s: Box = Box::new(SslStream(s)); + s + }) + .map_err(|e| { + let e: Box = Box::new(e); + e + }) + } + }); + Box::new(f) + } +} + +struct SslStream(tokio_openssl::SslStream); + +impl Read for SslStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } +} + +impl AsyncRead for SslStream { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.0.prepare_uninitialized_buffer(buf) + } + + fn read_buf(&mut self, buf: &mut B) -> Poll + where + B: BufMut, + { + self.0.read_buf(buf) + } +} + +impl Write for SslStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } +} + +impl AsyncWrite for SslStream { + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.0.shutdown() + } + + fn write_buf(&mut self, buf: &mut B) -> Poll + where + B: Buf, + { + self.0.write_buf(buf) + } +} + +impl TlsStream for SslStream {} diff --git a/tokio-postgres-openssl/src/test.rs b/tokio-postgres-openssl/src/test.rs new file mode 100644 index 00000000..f5999d14 --- /dev/null +++ b/tokio-postgres-openssl/src/test.rs @@ -0,0 +1,60 @@ +use futures::{Future, Stream}; +use openssl::ssl::{SslConnector, SslMethod}; +use tokio::runtime::current_thread::Runtime; +use tokio_postgres::{self, TlsMode}; + +use TlsConnector; + +fn smoke_test(url: &str, tls: TlsMode) { + let mut runtime = Runtime::new().unwrap(); + + let handshake = tokio_postgres::connect(url.parse().unwrap(), tls); + let (mut client, connection) = runtime.block_on(handshake).unwrap(); + let connection = connection.map_err(|e| panic!("{}", e)); + runtime.handle().spawn(connection).unwrap(); + + let prepare = client.prepare("SELECT 1::INT4"); + let statement = runtime.block_on(prepare).unwrap(); + let select = client.query(&statement, &[]).collect().map(|rows| { + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, i32>(0), 1); + }); + runtime.block_on(select).unwrap(); + + drop(statement); + drop(client); + runtime.run().unwrap(); +} + +#[test] +fn require() { + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_ca_file("../test/server.crt").unwrap(); + let connector = TlsConnector::with_connector(builder.build()); + smoke_test( + "postgres://ssl_user@localhost:5433/postgres", + TlsMode::Require(Box::new(connector)), + ); +} + +#[test] +fn prefer() { + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_ca_file("../test/server.crt").unwrap(); + let connector = TlsConnector::with_connector(builder.build()); + smoke_test( + "postgres://ssl_user@localhost:5433/postgres", + TlsMode::Prefer(Box::new(connector)), + ); +} + +#[test] +fn scram_user() { + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_ca_file("../test/server.crt").unwrap(); + let connector = TlsConnector::with_connector(builder.build()); + smoke_test( + "postgres://scram_user:password@localhost:5433/postgres", + TlsMode::Require(Box::new(connector)), + ); +}