diff --git a/postgres-native-tls/Cargo.toml b/postgres-native-tls/Cargo.toml index 7d891acd..a204d54c 100644 --- a/postgres-native-tls/Cargo.toml +++ b/postgres-native-tls/Cargo.toml @@ -16,12 +16,12 @@ default = ["runtime"] runtime = ["tokio-postgres/runtime"] [dependencies] -futures = "0.1" +futures-preview = "0.3.0-alpha.17" native-tls = "0.2" -tokio-io = "0.1" -tokio-tls = "0.2.1" +tokio-io = { git = "https://github.com/tokio-rs/tokio" } +tokio-tls = { git = "https://github.com/tokio-rs/tokio" } tokio-postgres = { version = "0.4.0-rc.1", path = "../tokio-postgres", default-features = false } [dev-dependencies] -tokio = "0.1.7" -postgres = { version = "0.16.0-rc.1", path = "../postgres" } +tokio = { git = "https://github.com/tokio-rs/tokio" } +#postgres = { version = "0.16.0-rc.1", path = "../postgres" } diff --git a/postgres-native-tls/src/lib.rs b/postgres-native-tls/src/lib.rs index a6813038..32809f59 100644 --- a/postgres-native-tls/src/lib.rs +++ b/postgres-native-tls/src/lib.rs @@ -47,13 +47,15 @@ //! ``` #![doc(html_root_url = "https://docs.rs/postgres-native-tls/0.2.0-rc.1")] #![warn(rust_2018_idioms, clippy::all, missing_docs)] +#![feature(async_await)] -use futures::{try_ready, Async, Future, Poll}; use tokio_io::{AsyncRead, AsyncWrite}; #[cfg(feature = "runtime")] use tokio_postgres::tls::MakeTlsConnect; use tokio_postgres::tls::{ChannelBinding, TlsConnect}; -use tokio_tls::{Connect, TlsStream}; +use tokio_tls::TlsStream; +use std::pin::Pin; +use std::future::Future; #[cfg(test)] mod test; @@ -76,7 +78,7 @@ impl MakeTlsConnector { #[cfg(feature = "runtime")] impl MakeTlsConnect for MakeTlsConnector where - S: AsyncRead + AsyncWrite, + S: AsyncRead + AsyncWrite + Unpin + 'static + Send, { type Stream = TlsStream; type TlsConnect = TlsConnector; @@ -105,35 +107,22 @@ impl TlsConnector { impl TlsConnect for TlsConnector where - S: AsyncRead + AsyncWrite, + S: AsyncRead + AsyncWrite + Unpin + 'static + Send, { type Stream = TlsStream; type Error = native_tls::Error; - type Future = TlsConnectFuture; + type Future = Pin, ChannelBinding), native_tls::Error>> + Send>>; - fn connect(self, stream: S) -> TlsConnectFuture { - TlsConnectFuture(self.connector.connect(&self.domain, stream)) - } -} + fn connect(self, stream: S) -> Self::Future { + let future = async move { + let stream = self.connector.connect(&self.domain, stream).await?; -/// The future returned by `TlsConnector`. -pub struct TlsConnectFuture(Connect); + // FIXME https://github.com/tokio-rs/tokio/issues/1383 + let channel_binding = ChannelBinding::none(); -impl Future for TlsConnectFuture -where - S: AsyncRead + AsyncWrite, -{ - type Item = (TlsStream, ChannelBinding); - type Error = native_tls::Error; - - fn poll(&mut self) -> Poll<(TlsStream, ChannelBinding), native_tls::Error> { - let stream = try_ready!(self.0.poll()); - - let channel_binding = match stream.get_ref().tls_server_end_point().unwrap_or(None) { - Some(buf) => ChannelBinding::tls_server_end_point(buf), - None => ChannelBinding::none(), + Ok((stream, channel_binding)) }; - Ok(Async::Ready((stream, channel_binding))) + Box::pin(future) } } diff --git a/postgres-native-tls/src/test.rs b/postgres-native-tls/src/test.rs index 43aab31f..81f93398 100644 --- a/postgres-native-tls/src/test.rs +++ b/postgres-native-tls/src/test.rs @@ -1,44 +1,40 @@ -use futures::{Future, Stream}; use native_tls::{self, Certificate}; use tokio::net::TcpStream; -use tokio::runtime::current_thread::Runtime; use tokio_postgres::tls::TlsConnect; +use futures::{FutureExt, TryStreamExt}; #[cfg(feature = "runtime")] use crate::MakeTlsConnector; use crate::TlsConnector; -fn smoke_test(s: &str, tls: T) -where - T: TlsConnect, - T::Stream: 'static, +async fn smoke_test(s: &str, tls: T) + where + T: TlsConnect, + T::Stream: 'static + Send, { - let mut runtime = Runtime::new().unwrap(); + let stream = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap()) + .await + .unwrap(); let builder = s.parse::().unwrap(); + let (mut client, connection) = builder.connect_raw(stream, tls).await.unwrap(); - let handshake = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap()) - .map_err(|e| panic!("{}", e)) - .and_then(|s| builder.connect_raw(s, tls)); - let (mut client, connection) = runtime.block_on(handshake).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.spawn(connection); + let connection = connection.map(|r| r.unwrap()); + tokio::spawn(connection); - let prepare = client.prepare("SELECT 1::INT4"); - let statement = runtime.block_on(prepare).unwrap(); - let select = client.query(&statement, &[]).collect().map(|rows| { - assert_eq!(rows.len(), 1); - assert_eq!(rows[0].get::<_, i32>(0), 1); - }); - runtime.block_on(select).unwrap(); + let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); + let rows = client + .query(&stmt, &[&1i32]) + .try_collect::>() + .await + .unwrap(); - drop(statement); - drop(client); - runtime.run().unwrap(); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, i32>(0), 1); } -#[test] -fn require() { +#[tokio::test] +async fn require() { let connector = native_tls::TlsConnector::builder() .add_root_certificate( Certificate::from_pem(include_bytes!("../../test/server.crt")).unwrap(), @@ -48,11 +44,11 @@ fn require() { smoke_test( "user=ssl_user dbname=postgres sslmode=require", TlsConnector::new(connector, "localhost"), - ); + ).await; } -#[test] -fn prefer() { +#[tokio::test] +async fn prefer() { let connector = native_tls::TlsConnector::builder() .add_root_certificate( Certificate::from_pem(include_bytes!("../../test/server.crt")).unwrap(), @@ -62,11 +58,11 @@ fn prefer() { smoke_test( "user=ssl_user dbname=postgres", TlsConnector::new(connector, "localhost"), - ); + ).await; } -#[test] -fn scram_user() { +#[tokio::test] +async fn scram_user() { let connector = native_tls::TlsConnector::builder() .add_root_certificate( Certificate::from_pem(include_bytes!("../../test/server.crt")).unwrap(), @@ -76,14 +72,12 @@ fn scram_user() { smoke_test( "user=scram_user password=password dbname=postgres sslmode=require", TlsConnector::new(connector, "localhost"), - ); + ).await; } -#[test] +#[tokio::test] #[cfg(feature = "runtime")] -fn runtime() { - let mut runtime = Runtime::new().unwrap(); - +async fn runtime() { let connector = native_tls::TlsConnector::builder() .add_root_certificate( Certificate::from_pem(include_bytes!("../../test/server.crt")).unwrap(), @@ -92,14 +86,22 @@ fn runtime() { .unwrap(); let connector = MakeTlsConnector::new(connector); - let connect = tokio_postgres::connect( + let (mut client, connection) = tokio_postgres::connect( "host=localhost port=5433 user=postgres sslmode=require", connector, - ); - let (mut client, connection) = runtime.block_on(connect).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.spawn(connection); + ) + .await + .unwrap(); + let connection = connection.map(|r| r.unwrap()); + tokio::spawn(connection); - let execute = client.simple_query("SELECT 1").for_each(|_| Ok(())); - runtime.block_on(execute).unwrap(); + let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); + let rows = client + .query(&stmt, &[&1i32]) + .try_collect::>() + .await + .unwrap(); + + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, i32>(0), 1); }