parent
7e7ae968c1
commit
23b0d6e6f3
@ -354,6 +354,8 @@ enum Kind {
|
|||||||
MissingHost,
|
MissingHost,
|
||||||
#[cfg(feature = "runtime")]
|
#[cfg(feature = "runtime")]
|
||||||
InvalidPort,
|
InvalidPort,
|
||||||
|
#[cfg(feature = "runtime")]
|
||||||
|
InvalidPortCount,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ErrorInner {
|
struct ErrorInner {
|
||||||
@ -397,6 +399,8 @@ impl fmt::Display for Error {
|
|||||||
Kind::MissingHost => "host not provided",
|
Kind::MissingHost => "host not provided",
|
||||||
#[cfg(feature = "runtime")]
|
#[cfg(feature = "runtime")]
|
||||||
Kind::InvalidPort => "invalid port",
|
Kind::InvalidPort => "invalid port",
|
||||||
|
#[cfg(feature = "runtime")]
|
||||||
|
Kind::InvalidPortCount => "wrong number of ports provided",
|
||||||
};
|
};
|
||||||
fmt.write_str(s)?;
|
fmt.write_str(s)?;
|
||||||
if let Some(ref cause) = self.0.cause {
|
if let Some(ref cause) = self.0.cause {
|
||||||
@ -514,4 +518,9 @@ impl Error {
|
|||||||
pub(crate) fn invalid_port(e: ParseIntError) -> Error {
|
pub(crate) fn invalid_port(e: ParseIntError) -> Error {
|
||||||
Error::new(Kind::InvalidPort, Some(Box::new(e)))
|
Error::new(Kind::InvalidPort, Some(Box::new(e)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "runtime")]
|
||||||
|
pub(crate) fn invalid_port_count() -> Error {
|
||||||
|
Error::new(Kind::InvalidPortCount, None)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
use futures::{try_ready, Future, Poll};
|
use futures::{try_ready, Async, Future, Poll};
|
||||||
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
|
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::vec;
|
||||||
|
|
||||||
use crate::proto::{Client, ConnectOnceFuture, Connection};
|
use crate::proto::{Client, ConnectOnceFuture, Connection};
|
||||||
use crate::{Error, MakeTlsMode, Socket};
|
use crate::{Error, MakeTlsMode, Socket};
|
||||||
@ -20,11 +21,16 @@ where
|
|||||||
future: T::Future,
|
future: T::Future,
|
||||||
host: String,
|
host: String,
|
||||||
port: u16,
|
port: u16,
|
||||||
|
addrs: vec::IntoIter<(String, u16)>,
|
||||||
|
make_tls_mode: T,
|
||||||
params: HashMap<String, String>,
|
params: HashMap<String, String>,
|
||||||
},
|
},
|
||||||
#[state_machine_future(transitions(Finished))]
|
#[state_machine_future(transitions(MakingTlsMode, Finished))]
|
||||||
Connecting {
|
Connecting {
|
||||||
future: ConnectOnceFuture<T::TlsMode>,
|
future: ConnectOnceFuture<T::TlsMode>,
|
||||||
|
addrs: vec::IntoIter<(String, u16)>,
|
||||||
|
make_tls_mode: T,
|
||||||
|
params: HashMap<String, String>,
|
||||||
},
|
},
|
||||||
#[state_machine_future(ready)]
|
#[state_machine_future(ready)]
|
||||||
Finished((Client, Connection<T::Stream>)),
|
Finished((Client, Connection<T::Stream>)),
|
||||||
@ -43,16 +49,42 @@ where
|
|||||||
Some(host) => host,
|
Some(host) => host,
|
||||||
None => return Err(Error::missing_host()),
|
None => return Err(Error::missing_host()),
|
||||||
};
|
};
|
||||||
|
let mut addrs = host
|
||||||
|
.split(',')
|
||||||
|
.map(|s| (s.to_string(), 0u16))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
let port = match state.params.remove("port") {
|
let port = state.params.remove("port").unwrap_or_else(String::new);
|
||||||
Some(port) => port.parse::<u16>().map_err(Error::invalid_port)?,
|
let mut ports = port
|
||||||
None => 5432,
|
.split(',')
|
||||||
};
|
.map(|s| {
|
||||||
|
if s.is_empty() {
|
||||||
|
Ok(5432)
|
||||||
|
} else {
|
||||||
|
s.parse::<u16>().map_err(Error::invalid_port)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
if ports.len() == 1 {
|
||||||
|
ports.resize(addrs.len(), ports[0]);
|
||||||
|
}
|
||||||
|
if addrs.len() != ports.len() {
|
||||||
|
return Err(Error::invalid_port_count());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (addr, port) in addrs.iter_mut().zip(ports) {
|
||||||
|
addr.1 = port;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut addrs = addrs.into_iter();
|
||||||
|
let (host, port) = addrs.next().expect("addrs cannot be empty");
|
||||||
|
|
||||||
transition!(MakingTlsMode {
|
transition!(MakingTlsMode {
|
||||||
future: state.make_tls_mode.make_tls_mode(&host),
|
future: state.make_tls_mode.make_tls_mode(&host),
|
||||||
host,
|
host,
|
||||||
port,
|
port,
|
||||||
|
addrs,
|
||||||
|
make_tls_mode: state.make_tls_mode,
|
||||||
params: state.params,
|
params: state.params,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -64,15 +96,36 @@ where
|
|||||||
let state = state.take();
|
let state = state.take();
|
||||||
|
|
||||||
transition!(Connecting {
|
transition!(Connecting {
|
||||||
future: ConnectOnceFuture::new(state.host, state.port, tls_mode, state.params),
|
future: ConnectOnceFuture::new(state.host, state.port, tls_mode, state.params.clone()),
|
||||||
|
addrs: state.addrs,
|
||||||
|
make_tls_mode: state.make_tls_mode,
|
||||||
|
params: state.params,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_connecting<'a>(
|
fn poll_connecting<'a>(
|
||||||
state: &'a mut RentToOwn<'a, Connecting<T>>,
|
state: &'a mut RentToOwn<'a, Connecting<T>>,
|
||||||
) -> Poll<AfterConnecting<T>, Error> {
|
) -> Poll<AfterConnecting<T>, Error> {
|
||||||
let r = try_ready!(state.future.poll());
|
match state.future.poll() {
|
||||||
transition!(Finished(r))
|
Ok(Async::Ready(r)) => transition!(Finished(r)),
|
||||||
|
Ok(Async::NotReady) => Ok(Async::NotReady),
|
||||||
|
Err(e) => {
|
||||||
|
let mut state = state.take();
|
||||||
|
let (host, port) = match state.addrs.next() {
|
||||||
|
Some(addr) => addr,
|
||||||
|
None => return Err(e),
|
||||||
|
};
|
||||||
|
|
||||||
|
transition!(MakingTlsMode {
|
||||||
|
future: state.make_tls_mode.make_tls_mode(&host),
|
||||||
|
host,
|
||||||
|
port,
|
||||||
|
addrs: state.addrs,
|
||||||
|
make_tls_mode: state.make_tls_mode,
|
||||||
|
params: state.params,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,12 +6,9 @@ fn connect(s: &str) -> impl Future<Item = (Client, Connection<Socket>), Error =
|
|||||||
s.parse::<tokio_postgres::Builder>().unwrap().connect(NoTls)
|
s.parse::<tokio_postgres::Builder>().unwrap().connect(NoTls)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
fn smoke_test(s: &str) {
|
||||||
#[ignore] // FIXME doesn't work with our docker-based tests :(
|
|
||||||
fn unix_socket() {
|
|
||||||
let mut runtime = Runtime::new().unwrap();
|
let mut runtime = Runtime::new().unwrap();
|
||||||
|
let connect = connect(s);
|
||||||
let connect = connect("host=/var/run/postgresql port=5433 user=postgres");
|
|
||||||
let (mut client, connection) = runtime.block_on(connect).unwrap();
|
let (mut client, connection) = runtime.block_on(connect).unwrap();
|
||||||
let connection = connection.map_err(|e| panic!("{}", e));
|
let connection = connection.map_err(|e| panic!("{}", e));
|
||||||
runtime.spawn(connection);
|
runtime.spawn(connection);
|
||||||
@ -20,15 +17,33 @@ fn unix_socket() {
|
|||||||
runtime.block_on(execute).unwrap();
|
runtime.block_on(execute).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[ignore] // FIXME doesn't work with our docker-based tests :(
|
||||||
|
fn unix_socket() {
|
||||||
|
smoke_test("host=/var/run/postgresql port=5433 user=postgres");
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn tcp() {
|
fn tcp() {
|
||||||
let mut runtime = Runtime::new().unwrap();
|
smoke_test("host=localhost port=5433 user=postgres")
|
||||||
|
}
|
||||||
let connect = connect("host=localhost port=5433 user=postgres");
|
|
||||||
let (mut client, connection) = runtime.block_on(connect).unwrap();
|
#[test]
|
||||||
let connection = connection.map_err(|e| panic!("{}", e));
|
fn multiple_hosts_one_port() {
|
||||||
runtime.spawn(connection);
|
smoke_test("host=foobar.invalid,localhost port=5433 user=postgres");
|
||||||
|
}
|
||||||
let execute = client.batch_execute("SELECT 1");
|
|
||||||
runtime.block_on(execute).unwrap();
|
#[test]
|
||||||
|
fn multiple_hosts_multiple_ports() {
|
||||||
|
smoke_test("host=foobar.invalid,localhost port=5432,5433 user=postgres");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn wrong_port_count() {
|
||||||
|
let mut runtime = Runtime::new().unwrap();
|
||||||
|
let f = connect("host=localhost port=5433,5433 user=postgres");
|
||||||
|
runtime.block_on(f).err().unwrap();
|
||||||
|
|
||||||
|
let f = connect("host=localhost,localhost,localhost port=5433,5433 user=postgres");
|
||||||
|
runtime.block_on(f).err().unwrap();
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user