From 23b0d6e6f30548c4a57cfedf56d179b0626011eb Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Wed, 19 Dec 2018 20:18:48 -0800 Subject: [PATCH] Support multiple hosts when connecting cc #399 --- tokio-postgres/src/error/mod.rs | 9 ++++ tokio-postgres/src/proto/connect.rs | 71 ++++++++++++++++++++++++---- tokio-postgres/tests/test/runtime.rs | 43 +++++++++++------ 3 files changed, 100 insertions(+), 23 deletions(-) diff --git a/tokio-postgres/src/error/mod.rs b/tokio-postgres/src/error/mod.rs index 3e5f992b..6a8d5d3d 100644 --- a/tokio-postgres/src/error/mod.rs +++ b/tokio-postgres/src/error/mod.rs @@ -354,6 +354,8 @@ enum Kind { MissingHost, #[cfg(feature = "runtime")] InvalidPort, + #[cfg(feature = "runtime")] + InvalidPortCount, } struct ErrorInner { @@ -397,6 +399,8 @@ impl fmt::Display for Error { Kind::MissingHost => "host not provided", #[cfg(feature = "runtime")] Kind::InvalidPort => "invalid port", + #[cfg(feature = "runtime")] + Kind::InvalidPortCount => "wrong number of ports provided", }; fmt.write_str(s)?; if let Some(ref cause) = self.0.cause { @@ -514,4 +518,9 @@ impl Error { pub(crate) fn invalid_port(e: ParseIntError) -> Error { Error::new(Kind::InvalidPort, Some(Box::new(e))) } + + #[cfg(feature = "runtime")] + pub(crate) fn invalid_port_count() -> Error { + Error::new(Kind::InvalidPortCount, None) + } } diff --git a/tokio-postgres/src/proto/connect.rs b/tokio-postgres/src/proto/connect.rs index 5fdd37c3..167a6a42 100644 --- a/tokio-postgres/src/proto/connect.rs +++ b/tokio-postgres/src/proto/connect.rs @@ -1,6 +1,7 @@ -use futures::{try_ready, Future, Poll}; +use futures::{try_ready, Async, Future, Poll}; use state_machine_future::{transition, RentToOwn, StateMachineFuture}; use std::collections::HashMap; +use std::vec; use crate::proto::{Client, ConnectOnceFuture, Connection}; use crate::{Error, MakeTlsMode, Socket}; @@ -20,11 +21,16 @@ where future: T::Future, host: String, port: u16, + addrs: vec::IntoIter<(String, u16)>, + make_tls_mode: T, params: HashMap, }, - #[state_machine_future(transitions(Finished))] + #[state_machine_future(transitions(MakingTlsMode, Finished))] Connecting { future: ConnectOnceFuture, + addrs: vec::IntoIter<(String, u16)>, + make_tls_mode: T, + params: HashMap, }, #[state_machine_future(ready)] Finished((Client, Connection)), @@ -43,16 +49,42 @@ where Some(host) => host, None => return Err(Error::missing_host()), }; + let mut addrs = host + .split(',') + .map(|s| (s.to_string(), 0u16)) + .collect::>(); - let port = match state.params.remove("port") { - Some(port) => port.parse::().map_err(Error::invalid_port)?, - None => 5432, - }; + let port = state.params.remove("port").unwrap_or_else(String::new); + let mut ports = port + .split(',') + .map(|s| { + if s.is_empty() { + Ok(5432) + } else { + s.parse::().map_err(Error::invalid_port) + } + }) + .collect::, _>>()?; + if ports.len() == 1 { + ports.resize(addrs.len(), ports[0]); + } + if addrs.len() != ports.len() { + return Err(Error::invalid_port_count()); + } + + for (addr, port) in addrs.iter_mut().zip(ports) { + addr.1 = port; + } + + let mut addrs = addrs.into_iter(); + let (host, port) = addrs.next().expect("addrs cannot be empty"); transition!(MakingTlsMode { future: state.make_tls_mode.make_tls_mode(&host), host, port, + addrs, + make_tls_mode: state.make_tls_mode, params: state.params, }) } @@ -64,15 +96,36 @@ where let state = state.take(); transition!(Connecting { - future: ConnectOnceFuture::new(state.host, state.port, tls_mode, state.params), + future: ConnectOnceFuture::new(state.host, state.port, tls_mode, state.params.clone()), + addrs: state.addrs, + make_tls_mode: state.make_tls_mode, + params: state.params, }) } fn poll_connecting<'a>( state: &'a mut RentToOwn<'a, Connecting>, ) -> Poll, Error> { - let r = try_ready!(state.future.poll()); - transition!(Finished(r)) + match state.future.poll() { + Ok(Async::Ready(r)) => transition!(Finished(r)), + Ok(Async::NotReady) => Ok(Async::NotReady), + Err(e) => { + let mut state = state.take(); + let (host, port) = match state.addrs.next() { + Some(addr) => addr, + None => return Err(e), + }; + + transition!(MakingTlsMode { + future: state.make_tls_mode.make_tls_mode(&host), + host, + port, + addrs: state.addrs, + make_tls_mode: state.make_tls_mode, + params: state.params, + }) + } + } } } diff --git a/tokio-postgres/tests/test/runtime.rs b/tokio-postgres/tests/test/runtime.rs index f723be7b..576ca02f 100644 --- a/tokio-postgres/tests/test/runtime.rs +++ b/tokio-postgres/tests/test/runtime.rs @@ -6,12 +6,9 @@ fn connect(s: &str) -> impl Future), Error = s.parse::().unwrap().connect(NoTls) } -#[test] -#[ignore] // FIXME doesn't work with our docker-based tests :( -fn unix_socket() { +fn smoke_test(s: &str) { let mut runtime = Runtime::new().unwrap(); - - let connect = connect("host=/var/run/postgresql port=5433 user=postgres"); + let connect = connect(s); let (mut client, connection) = runtime.block_on(connect).unwrap(); let connection = connection.map_err(|e| panic!("{}", e)); runtime.spawn(connection); @@ -20,15 +17,33 @@ fn unix_socket() { runtime.block_on(execute).unwrap(); } +#[test] +#[ignore] // FIXME doesn't work with our docker-based tests :( +fn unix_socket() { + smoke_test("host=/var/run/postgresql port=5433 user=postgres"); +} + #[test] fn tcp() { - let mut runtime = Runtime::new().unwrap(); - - let connect = connect("host=localhost port=5433 user=postgres"); - let (mut client, connection) = runtime.block_on(connect).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.spawn(connection); - - let execute = client.batch_execute("SELECT 1"); - runtime.block_on(execute).unwrap(); + smoke_test("host=localhost port=5433 user=postgres") +} + +#[test] +fn multiple_hosts_one_port() { + smoke_test("host=foobar.invalid,localhost port=5433 user=postgres"); +} + +#[test] +fn multiple_hosts_multiple_ports() { + smoke_test("host=foobar.invalid,localhost port=5432,5433 user=postgres"); +} + +#[test] +fn wrong_port_count() { + let mut runtime = Runtime::new().unwrap(); + let f = connect("host=localhost port=5433,5433 user=postgres"); + runtime.block_on(f).err().unwrap(); + + let f = connect("host=localhost,localhost,localhost port=5433,5433 user=postgres"); + runtime.block_on(f).err().unwrap(); }