Support multiple hosts when connecting

cc #399
This commit is contained in:
Steven Fackler 2018-12-19 20:18:48 -08:00
parent 7e7ae968c1
commit 23b0d6e6f3
3 changed files with 100 additions and 23 deletions

View File

@ -354,6 +354,8 @@ enum Kind {
MissingHost, MissingHost,
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
InvalidPort, InvalidPort,
#[cfg(feature = "runtime")]
InvalidPortCount,
} }
struct ErrorInner { struct ErrorInner {
@ -397,6 +399,8 @@ impl fmt::Display for Error {
Kind::MissingHost => "host not provided", Kind::MissingHost => "host not provided",
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
Kind::InvalidPort => "invalid port", Kind::InvalidPort => "invalid port",
#[cfg(feature = "runtime")]
Kind::InvalidPortCount => "wrong number of ports provided",
}; };
fmt.write_str(s)?; fmt.write_str(s)?;
if let Some(ref cause) = self.0.cause { if let Some(ref cause) = self.0.cause {
@ -514,4 +518,9 @@ impl Error {
pub(crate) fn invalid_port(e: ParseIntError) -> Error { pub(crate) fn invalid_port(e: ParseIntError) -> Error {
Error::new(Kind::InvalidPort, Some(Box::new(e))) Error::new(Kind::InvalidPort, Some(Box::new(e)))
} }
#[cfg(feature = "runtime")]
pub(crate) fn invalid_port_count() -> Error {
Error::new(Kind::InvalidPortCount, None)
}
} }

View File

@ -1,6 +1,7 @@
use futures::{try_ready, Future, Poll}; use futures::{try_ready, Async, Future, Poll};
use state_machine_future::{transition, RentToOwn, StateMachineFuture}; use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use std::collections::HashMap; use std::collections::HashMap;
use std::vec;
use crate::proto::{Client, ConnectOnceFuture, Connection}; use crate::proto::{Client, ConnectOnceFuture, Connection};
use crate::{Error, MakeTlsMode, Socket}; use crate::{Error, MakeTlsMode, Socket};
@ -20,11 +21,16 @@ where
future: T::Future, future: T::Future,
host: String, host: String,
port: u16, port: u16,
addrs: vec::IntoIter<(String, u16)>,
make_tls_mode: T,
params: HashMap<String, String>, params: HashMap<String, String>,
}, },
#[state_machine_future(transitions(Finished))] #[state_machine_future(transitions(MakingTlsMode, Finished))]
Connecting { Connecting {
future: ConnectOnceFuture<T::TlsMode>, future: ConnectOnceFuture<T::TlsMode>,
addrs: vec::IntoIter<(String, u16)>,
make_tls_mode: T,
params: HashMap<String, String>,
}, },
#[state_machine_future(ready)] #[state_machine_future(ready)]
Finished((Client, Connection<T::Stream>)), Finished((Client, Connection<T::Stream>)),
@ -43,16 +49,42 @@ where
Some(host) => host, Some(host) => host,
None => return Err(Error::missing_host()), None => return Err(Error::missing_host()),
}; };
let mut addrs = host
.split(',')
.map(|s| (s.to_string(), 0u16))
.collect::<Vec<_>>();
let port = match state.params.remove("port") { let port = state.params.remove("port").unwrap_or_else(String::new);
Some(port) => port.parse::<u16>().map_err(Error::invalid_port)?, let mut ports = port
None => 5432, .split(',')
}; .map(|s| {
if s.is_empty() {
Ok(5432)
} else {
s.parse::<u16>().map_err(Error::invalid_port)
}
})
.collect::<Result<Vec<_>, _>>()?;
if ports.len() == 1 {
ports.resize(addrs.len(), ports[0]);
}
if addrs.len() != ports.len() {
return Err(Error::invalid_port_count());
}
for (addr, port) in addrs.iter_mut().zip(ports) {
addr.1 = port;
}
let mut addrs = addrs.into_iter();
let (host, port) = addrs.next().expect("addrs cannot be empty");
transition!(MakingTlsMode { transition!(MakingTlsMode {
future: state.make_tls_mode.make_tls_mode(&host), future: state.make_tls_mode.make_tls_mode(&host),
host, host,
port, port,
addrs,
make_tls_mode: state.make_tls_mode,
params: state.params, params: state.params,
}) })
} }
@ -64,15 +96,36 @@ where
let state = state.take(); let state = state.take();
transition!(Connecting { transition!(Connecting {
future: ConnectOnceFuture::new(state.host, state.port, tls_mode, state.params), future: ConnectOnceFuture::new(state.host, state.port, tls_mode, state.params.clone()),
addrs: state.addrs,
make_tls_mode: state.make_tls_mode,
params: state.params,
}) })
} }
fn poll_connecting<'a>( fn poll_connecting<'a>(
state: &'a mut RentToOwn<'a, Connecting<T>>, state: &'a mut RentToOwn<'a, Connecting<T>>,
) -> Poll<AfterConnecting<T>, Error> { ) -> Poll<AfterConnecting<T>, Error> {
let r = try_ready!(state.future.poll()); match state.future.poll() {
transition!(Finished(r)) Ok(Async::Ready(r)) => transition!(Finished(r)),
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(e) => {
let mut state = state.take();
let (host, port) = match state.addrs.next() {
Some(addr) => addr,
None => return Err(e),
};
transition!(MakingTlsMode {
future: state.make_tls_mode.make_tls_mode(&host),
host,
port,
addrs: state.addrs,
make_tls_mode: state.make_tls_mode,
params: state.params,
})
}
}
} }
} }

View File

@ -6,12 +6,9 @@ fn connect(s: &str) -> impl Future<Item = (Client, Connection<Socket>), Error =
s.parse::<tokio_postgres::Builder>().unwrap().connect(NoTls) s.parse::<tokio_postgres::Builder>().unwrap().connect(NoTls)
} }
#[test] fn smoke_test(s: &str) {
#[ignore] // FIXME doesn't work with our docker-based tests :(
fn unix_socket() {
let mut runtime = Runtime::new().unwrap(); let mut runtime = Runtime::new().unwrap();
let connect = connect(s);
let connect = connect("host=/var/run/postgresql port=5433 user=postgres");
let (mut client, connection) = runtime.block_on(connect).unwrap(); let (mut client, connection) = runtime.block_on(connect).unwrap();
let connection = connection.map_err(|e| panic!("{}", e)); let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection); runtime.spawn(connection);
@ -20,15 +17,33 @@ fn unix_socket() {
runtime.block_on(execute).unwrap(); runtime.block_on(execute).unwrap();
} }
#[test]
#[ignore] // FIXME doesn't work with our docker-based tests :(
fn unix_socket() {
smoke_test("host=/var/run/postgresql port=5433 user=postgres");
}
#[test] #[test]
fn tcp() { fn tcp() {
let mut runtime = Runtime::new().unwrap(); smoke_test("host=localhost port=5433 user=postgres")
}
let connect = connect("host=localhost port=5433 user=postgres");
let (mut client, connection) = runtime.block_on(connect).unwrap(); #[test]
let connection = connection.map_err(|e| panic!("{}", e)); fn multiple_hosts_one_port() {
runtime.spawn(connection); smoke_test("host=foobar.invalid,localhost port=5433 user=postgres");
}
let execute = client.batch_execute("SELECT 1");
runtime.block_on(execute).unwrap(); #[test]
fn multiple_hosts_multiple_ports() {
smoke_test("host=foobar.invalid,localhost port=5432,5433 user=postgres");
}
#[test]
fn wrong_port_count() {
let mut runtime = Runtime::new().unwrap();
let f = connect("host=localhost port=5433,5433 user=postgres");
runtime.block_on(f).err().unwrap();
let f = connect("host=localhost,localhost,localhost port=5433,5433 user=postgres");
runtime.block_on(f).err().unwrap();
} }