diff --git a/tokio-postgres-openssl/src/test.rs b/tokio-postgres-openssl/src/test.rs index a85cc534..ee37202d 100644 --- a/tokio-postgres-openssl/src/test.rs +++ b/tokio-postgres-openssl/src/test.rs @@ -77,10 +77,7 @@ fn runtime() { builder.set_ca_file("../test/server.crt").unwrap(); let connector = MakeTlsConnector::new(builder.build()); - let connect = "host=localhost port=5433 user=postgres" - .parse::() - .unwrap() - .connect(RequireTls(connector)); + let connect = tokio_postgres::connect("host=localhost port=5433 user=postgres", RequireTls(connector)); let (mut client, connection) = runtime.block_on(connect).unwrap(); let connection = connection.map_err(|e| panic!("{}", e)); runtime.spawn(connection); diff --git a/tokio-postgres/src/builder.rs b/tokio-postgres/src/builder.rs index c64ad729..998e4843 100644 --- a/tokio-postgres/src/builder.rs +++ b/tokio-postgres/src/builder.rs @@ -128,7 +128,7 @@ impl Builder { where T: MakeTlsMode, { - Connect(ConnectFuture::new(make_tls_mode, self.clone())) + Connect(ConnectFuture::new(make_tls_mode, Ok(self.clone()))) } } diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 660dc4db..2b160a3e 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -36,6 +36,14 @@ fn next_portal() -> String { format!("p{}", ID.fetch_add(1, Ordering::SeqCst)) } +#[cfg(feature = "runtime")] +pub fn connect(config: &str, tls_mode: T) -> Connect +where + T: MakeTlsMode, +{ + Connect(proto::ConnectFuture::new(tls_mode, config.parse())) +} + pub fn cancel_query(stream: S, tls_mode: T, cancel_data: CancelData) -> CancelQuery where S: AsyncRead + AsyncWrite, diff --git a/tokio-postgres/src/proto/connect.rs b/tokio-postgres/src/proto/connect.rs index f99cd87c..d272db16 100644 --- a/tokio-postgres/src/proto/connect.rs +++ b/tokio-postgres/src/proto/connect.rs @@ -10,7 +10,10 @@ where T: MakeTlsMode, { #[state_machine_future(start, transitions(MakingTlsMode))] - Start { make_tls_mode: T, config: Builder }, + Start { + make_tls_mode: T, + config: Result, + }, #[state_machine_future(transitions(Connecting))] MakingTlsMode { future: T::Future, @@ -38,15 +41,17 @@ where fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll, Error> { let mut state = state.take(); - if state.config.0.host.is_empty() { + let config = state.config?; + + if config.0.host.is_empty() { return Err(Error::missing_host()); } - if state.config.0.port.len() > 1 && state.config.0.port.len() != state.config.0.host.len() { + if config.0.port.len() > 1 && config.0.port.len() != config.0.host.len() { return Err(Error::invalid_port_count()); } - let hostname = match &state.config.0.host[0] { + let hostname = match &config.0.host[0] { Host::Tcp(host) => &**host, // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter #[cfg(unix)] @@ -58,7 +63,7 @@ where future, idx: 0, make_tls_mode: state.make_tls_mode, - config: state.config, + config, }) } @@ -113,7 +118,7 @@ impl ConnectFuture where T: MakeTlsMode, { - pub fn new(make_tls_mode: T, config: Builder) -> ConnectFuture { + pub fn new(make_tls_mode: T, config: Result) -> ConnectFuture { Connect::start(make_tls_mode, config) } } diff --git a/tokio-postgres/tests/test/runtime.rs b/tokio-postgres/tests/test/runtime.rs index 576ca02f..67246876 100644 --- a/tokio-postgres/tests/test/runtime.rs +++ b/tokio-postgres/tests/test/runtime.rs @@ -1,14 +1,10 @@ use futures::Future; use tokio::runtime::current_thread::Runtime; -use tokio_postgres::{Client, Connection, Error, NoTls, Socket}; - -fn connect(s: &str) -> impl Future), Error = Error> { - s.parse::().unwrap().connect(NoTls) -} +use tokio_postgres::NoTls; fn smoke_test(s: &str) { let mut runtime = Runtime::new().unwrap(); - let connect = connect(s); + let connect = tokio_postgres::connect(s, NoTls); let (mut client, connection) = runtime.block_on(connect).unwrap(); let connection = connection.map_err(|e| panic!("{}", e)); runtime.spawn(connection); @@ -41,9 +37,12 @@ fn multiple_hosts_multiple_ports() { #[test] fn wrong_port_count() { let mut runtime = Runtime::new().unwrap(); - let f = connect("host=localhost port=5433,5433 user=postgres"); + let f = tokio_postgres::connect("host=localhost port=5433,5433 user=postgres", NoTls); runtime.block_on(f).err().unwrap(); - let f = connect("host=localhost,localhost,localhost port=5433,5433 user=postgres"); + let f = tokio_postgres::connect( + "host=localhost,localhost,localhost port=5433,5433 user=postgres", + NoTls, + ); runtime.block_on(f).err().unwrap(); }