Support multiple hosts when connecting

cc #399
This commit is contained in:
Steven Fackler 2018-12-19 20:18:48 -08:00
parent 7e7ae968c1
commit 23b0d6e6f3
3 changed files with 100 additions and 23 deletions

View File

@ -354,6 +354,8 @@ enum Kind {
MissingHost,
#[cfg(feature = "runtime")]
InvalidPort,
#[cfg(feature = "runtime")]
InvalidPortCount,
}
struct ErrorInner {
@ -397,6 +399,8 @@ impl fmt::Display for Error {
Kind::MissingHost => "host not provided",
#[cfg(feature = "runtime")]
Kind::InvalidPort => "invalid port",
#[cfg(feature = "runtime")]
Kind::InvalidPortCount => "wrong number of ports provided",
};
fmt.write_str(s)?;
if let Some(ref cause) = self.0.cause {
@ -514,4 +518,9 @@ impl Error {
pub(crate) fn invalid_port(e: ParseIntError) -> Error {
Error::new(Kind::InvalidPort, Some(Box::new(e)))
}
#[cfg(feature = "runtime")]
pub(crate) fn invalid_port_count() -> Error {
Error::new(Kind::InvalidPortCount, None)
}
}

View File

@ -1,6 +1,7 @@
use futures::{try_ready, Future, Poll};
use futures::{try_ready, Async, Future, Poll};
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use std::collections::HashMap;
use std::vec;
use crate::proto::{Client, ConnectOnceFuture, Connection};
use crate::{Error, MakeTlsMode, Socket};
@ -20,11 +21,16 @@ where
future: T::Future,
host: String,
port: u16,
addrs: vec::IntoIter<(String, u16)>,
make_tls_mode: T,
params: HashMap<String, String>,
},
#[state_machine_future(transitions(Finished))]
#[state_machine_future(transitions(MakingTlsMode, Finished))]
Connecting {
future: ConnectOnceFuture<T::TlsMode>,
addrs: vec::IntoIter<(String, u16)>,
make_tls_mode: T,
params: HashMap<String, String>,
},
#[state_machine_future(ready)]
Finished((Client, Connection<T::Stream>)),
@ -43,16 +49,42 @@ where
Some(host) => host,
None => return Err(Error::missing_host()),
};
let mut addrs = host
.split(',')
.map(|s| (s.to_string(), 0u16))
.collect::<Vec<_>>();
let port = match state.params.remove("port") {
Some(port) => port.parse::<u16>().map_err(Error::invalid_port)?,
None => 5432,
};
let port = state.params.remove("port").unwrap_or_else(String::new);
let mut ports = port
.split(',')
.map(|s| {
if s.is_empty() {
Ok(5432)
} else {
s.parse::<u16>().map_err(Error::invalid_port)
}
})
.collect::<Result<Vec<_>, _>>()?;
if ports.len() == 1 {
ports.resize(addrs.len(), ports[0]);
}
if addrs.len() != ports.len() {
return Err(Error::invalid_port_count());
}
for (addr, port) in addrs.iter_mut().zip(ports) {
addr.1 = port;
}
let mut addrs = addrs.into_iter();
let (host, port) = addrs.next().expect("addrs cannot be empty");
transition!(MakingTlsMode {
future: state.make_tls_mode.make_tls_mode(&host),
host,
port,
addrs,
make_tls_mode: state.make_tls_mode,
params: state.params,
})
}
@ -64,15 +96,36 @@ where
let state = state.take();
transition!(Connecting {
future: ConnectOnceFuture::new(state.host, state.port, tls_mode, state.params),
future: ConnectOnceFuture::new(state.host, state.port, tls_mode, state.params.clone()),
addrs: state.addrs,
make_tls_mode: state.make_tls_mode,
params: state.params,
})
}
fn poll_connecting<'a>(
state: &'a mut RentToOwn<'a, Connecting<T>>,
) -> Poll<AfterConnecting<T>, Error> {
let r = try_ready!(state.future.poll());
transition!(Finished(r))
match state.future.poll() {
Ok(Async::Ready(r)) => transition!(Finished(r)),
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(e) => {
let mut state = state.take();
let (host, port) = match state.addrs.next() {
Some(addr) => addr,
None => return Err(e),
};
transition!(MakingTlsMode {
future: state.make_tls_mode.make_tls_mode(&host),
host,
port,
addrs: state.addrs,
make_tls_mode: state.make_tls_mode,
params: state.params,
})
}
}
}
}

View File

@ -6,12 +6,9 @@ fn connect(s: &str) -> impl Future<Item = (Client, Connection<Socket>), Error =
s.parse::<tokio_postgres::Builder>().unwrap().connect(NoTls)
}
#[test]
#[ignore] // FIXME doesn't work with our docker-based tests :(
fn unix_socket() {
fn smoke_test(s: &str) {
let mut runtime = Runtime::new().unwrap();
let connect = connect("host=/var/run/postgresql port=5433 user=postgres");
let connect = connect(s);
let (mut client, connection) = runtime.block_on(connect).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
@ -20,15 +17,33 @@ fn unix_socket() {
runtime.block_on(execute).unwrap();
}
#[test]
#[ignore] // FIXME doesn't work with our docker-based tests :(
fn unix_socket() {
smoke_test("host=/var/run/postgresql port=5433 user=postgres");
}
#[test]
fn tcp() {
let mut runtime = Runtime::new().unwrap();
let connect = connect("host=localhost port=5433 user=postgres");
let (mut client, connection) = runtime.block_on(connect).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
let execute = client.batch_execute("SELECT 1");
runtime.block_on(execute).unwrap();
smoke_test("host=localhost port=5433 user=postgres")
}
#[test]
fn multiple_hosts_one_port() {
smoke_test("host=foobar.invalid,localhost port=5433 user=postgres");
}
#[test]
fn multiple_hosts_multiple_ports() {
smoke_test("host=foobar.invalid,localhost port=5432,5433 user=postgres");
}
#[test]
fn wrong_port_count() {
let mut runtime = Runtime::new().unwrap();
let f = connect("host=localhost port=5433,5433 user=postgres");
runtime.block_on(f).err().unwrap();
let f = connect("host=localhost,localhost,localhost port=5433,5433 user=postgres");
runtime.block_on(f).err().unwrap();
}