parent
7e7ae968c1
commit
23b0d6e6f3
@ -354,6 +354,8 @@ enum Kind {
|
||||
MissingHost,
|
||||
#[cfg(feature = "runtime")]
|
||||
InvalidPort,
|
||||
#[cfg(feature = "runtime")]
|
||||
InvalidPortCount,
|
||||
}
|
||||
|
||||
struct ErrorInner {
|
||||
@ -397,6 +399,8 @@ impl fmt::Display for Error {
|
||||
Kind::MissingHost => "host not provided",
|
||||
#[cfg(feature = "runtime")]
|
||||
Kind::InvalidPort => "invalid port",
|
||||
#[cfg(feature = "runtime")]
|
||||
Kind::InvalidPortCount => "wrong number of ports provided",
|
||||
};
|
||||
fmt.write_str(s)?;
|
||||
if let Some(ref cause) = self.0.cause {
|
||||
@ -514,4 +518,9 @@ impl Error {
|
||||
pub(crate) fn invalid_port(e: ParseIntError) -> Error {
|
||||
Error::new(Kind::InvalidPort, Some(Box::new(e)))
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
pub(crate) fn invalid_port_count() -> Error {
|
||||
Error::new(Kind::InvalidPortCount, None)
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
use futures::{try_ready, Future, Poll};
|
||||
use futures::{try_ready, Async, Future, Poll};
|
||||
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
|
||||
use std::collections::HashMap;
|
||||
use std::vec;
|
||||
|
||||
use crate::proto::{Client, ConnectOnceFuture, Connection};
|
||||
use crate::{Error, MakeTlsMode, Socket};
|
||||
@ -20,11 +21,16 @@ where
|
||||
future: T::Future,
|
||||
host: String,
|
||||
port: u16,
|
||||
addrs: vec::IntoIter<(String, u16)>,
|
||||
make_tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
},
|
||||
#[state_machine_future(transitions(Finished))]
|
||||
#[state_machine_future(transitions(MakingTlsMode, Finished))]
|
||||
Connecting {
|
||||
future: ConnectOnceFuture<T::TlsMode>,
|
||||
addrs: vec::IntoIter<(String, u16)>,
|
||||
make_tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
},
|
||||
#[state_machine_future(ready)]
|
||||
Finished((Client, Connection<T::Stream>)),
|
||||
@ -43,16 +49,42 @@ where
|
||||
Some(host) => host,
|
||||
None => return Err(Error::missing_host()),
|
||||
};
|
||||
let mut addrs = host
|
||||
.split(',')
|
||||
.map(|s| (s.to_string(), 0u16))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let port = match state.params.remove("port") {
|
||||
Some(port) => port.parse::<u16>().map_err(Error::invalid_port)?,
|
||||
None => 5432,
|
||||
};
|
||||
let port = state.params.remove("port").unwrap_or_else(String::new);
|
||||
let mut ports = port
|
||||
.split(',')
|
||||
.map(|s| {
|
||||
if s.is_empty() {
|
||||
Ok(5432)
|
||||
} else {
|
||||
s.parse::<u16>().map_err(Error::invalid_port)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
if ports.len() == 1 {
|
||||
ports.resize(addrs.len(), ports[0]);
|
||||
}
|
||||
if addrs.len() != ports.len() {
|
||||
return Err(Error::invalid_port_count());
|
||||
}
|
||||
|
||||
for (addr, port) in addrs.iter_mut().zip(ports) {
|
||||
addr.1 = port;
|
||||
}
|
||||
|
||||
let mut addrs = addrs.into_iter();
|
||||
let (host, port) = addrs.next().expect("addrs cannot be empty");
|
||||
|
||||
transition!(MakingTlsMode {
|
||||
future: state.make_tls_mode.make_tls_mode(&host),
|
||||
host,
|
||||
port,
|
||||
addrs,
|
||||
make_tls_mode: state.make_tls_mode,
|
||||
params: state.params,
|
||||
})
|
||||
}
|
||||
@ -64,15 +96,36 @@ where
|
||||
let state = state.take();
|
||||
|
||||
transition!(Connecting {
|
||||
future: ConnectOnceFuture::new(state.host, state.port, tls_mode, state.params),
|
||||
future: ConnectOnceFuture::new(state.host, state.port, tls_mode, state.params.clone()),
|
||||
addrs: state.addrs,
|
||||
make_tls_mode: state.make_tls_mode,
|
||||
params: state.params,
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_connecting<'a>(
|
||||
state: &'a mut RentToOwn<'a, Connecting<T>>,
|
||||
) -> Poll<AfterConnecting<T>, Error> {
|
||||
let r = try_ready!(state.future.poll());
|
||||
transition!(Finished(r))
|
||||
match state.future.poll() {
|
||||
Ok(Async::Ready(r)) => transition!(Finished(r)),
|
||||
Ok(Async::NotReady) => Ok(Async::NotReady),
|
||||
Err(e) => {
|
||||
let mut state = state.take();
|
||||
let (host, port) = match state.addrs.next() {
|
||||
Some(addr) => addr,
|
||||
None => return Err(e),
|
||||
};
|
||||
|
||||
transition!(MakingTlsMode {
|
||||
future: state.make_tls_mode.make_tls_mode(&host),
|
||||
host,
|
||||
port,
|
||||
addrs: state.addrs,
|
||||
make_tls_mode: state.make_tls_mode,
|
||||
params: state.params,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6,12 +6,9 @@ fn connect(s: &str) -> impl Future<Item = (Client, Connection<Socket>), Error =
|
||||
s.parse::<tokio_postgres::Builder>().unwrap().connect(NoTls)
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // FIXME doesn't work with our docker-based tests :(
|
||||
fn unix_socket() {
|
||||
fn smoke_test(s: &str) {
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let connect = connect("host=/var/run/postgresql port=5433 user=postgres");
|
||||
let connect = connect(s);
|
||||
let (mut client, connection) = runtime.block_on(connect).unwrap();
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
runtime.spawn(connection);
|
||||
@ -20,15 +17,33 @@ fn unix_socket() {
|
||||
runtime.block_on(execute).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // FIXME doesn't work with our docker-based tests :(
|
||||
fn unix_socket() {
|
||||
smoke_test("host=/var/run/postgresql port=5433 user=postgres");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tcp() {
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let connect = connect("host=localhost port=5433 user=postgres");
|
||||
let (mut client, connection) = runtime.block_on(connect).unwrap();
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
runtime.spawn(connection);
|
||||
|
||||
let execute = client.batch_execute("SELECT 1");
|
||||
runtime.block_on(execute).unwrap();
|
||||
smoke_test("host=localhost port=5433 user=postgres")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiple_hosts_one_port() {
|
||||
smoke_test("host=foobar.invalid,localhost port=5433 user=postgres");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiple_hosts_multiple_ports() {
|
||||
smoke_test("host=foobar.invalid,localhost port=5432,5433 user=postgres");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_port_count() {
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
let f = connect("host=localhost port=5433,5433 user=postgres");
|
||||
runtime.block_on(f).err().unwrap();
|
||||
|
||||
let f = connect("host=localhost,localhost,localhost port=5433,5433 user=postgres");
|
||||
runtime.block_on(f).err().unwrap();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user