Add a convenience connect free function

This commit is contained in:
Steven Fackler 2018-12-29 13:28:38 -08:00
parent af41875ea4
commit 6ae93a0634
5 changed files with 28 additions and 19 deletions

View File

@ -77,10 +77,7 @@ fn runtime() {
builder.set_ca_file("../test/server.crt").unwrap();
let connector = MakeTlsConnector::new(builder.build());
let connect = "host=localhost port=5433 user=postgres"
.parse::<tokio_postgres::Builder>()
.unwrap()
.connect(RequireTls(connector));
let connect = tokio_postgres::connect("host=localhost port=5433 user=postgres", RequireTls(connector));
let (mut client, connection) = runtime.block_on(connect).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);

View File

@ -128,7 +128,7 @@ impl Builder {
where
T: MakeTlsMode<Socket>,
{
Connect(ConnectFuture::new(make_tls_mode, self.clone()))
Connect(ConnectFuture::new(make_tls_mode, Ok(self.clone())))
}
}

View File

@ -36,6 +36,14 @@ fn next_portal() -> String {
format!("p{}", ID.fetch_add(1, Ordering::SeqCst))
}
#[cfg(feature = "runtime")]
pub fn connect<T>(config: &str, tls_mode: T) -> Connect<T>
where
T: MakeTlsMode<Socket>,
{
Connect(proto::ConnectFuture::new(tls_mode, config.parse()))
}
pub fn cancel_query<S, T>(stream: S, tls_mode: T, cancel_data: CancelData) -> CancelQuery<S, T>
where
S: AsyncRead + AsyncWrite,

View File

@ -10,7 +10,10 @@ where
T: MakeTlsMode<Socket>,
{
#[state_machine_future(start, transitions(MakingTlsMode))]
Start { make_tls_mode: T, config: Builder },
Start {
make_tls_mode: T,
config: Result<Builder, Error>,
},
#[state_machine_future(transitions(Connecting))]
MakingTlsMode {
future: T::Future,
@ -38,15 +41,17 @@ where
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
let mut state = state.take();
if state.config.0.host.is_empty() {
let config = state.config?;
if config.0.host.is_empty() {
return Err(Error::missing_host());
}
if state.config.0.port.len() > 1 && state.config.0.port.len() != state.config.0.host.len() {
if config.0.port.len() > 1 && config.0.port.len() != config.0.host.len() {
return Err(Error::invalid_port_count());
}
let hostname = match &state.config.0.host[0] {
let hostname = match &config.0.host[0] {
Host::Tcp(host) => &**host,
// postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
#[cfg(unix)]
@ -58,7 +63,7 @@ where
future,
idx: 0,
make_tls_mode: state.make_tls_mode,
config: state.config,
config,
})
}
@ -113,7 +118,7 @@ impl<T> ConnectFuture<T>
where
T: MakeTlsMode<Socket>,
{
pub fn new(make_tls_mode: T, config: Builder) -> ConnectFuture<T> {
pub fn new(make_tls_mode: T, config: Result<Builder, Error>) -> ConnectFuture<T> {
Connect::start(make_tls_mode, config)
}
}

View File

@ -1,14 +1,10 @@
use futures::Future;
use tokio::runtime::current_thread::Runtime;
use tokio_postgres::{Client, Connection, Error, NoTls, Socket};
fn connect(s: &str) -> impl Future<Item = (Client, Connection<Socket>), Error = Error> {
s.parse::<tokio_postgres::Builder>().unwrap().connect(NoTls)
}
use tokio_postgres::NoTls;
fn smoke_test(s: &str) {
let mut runtime = Runtime::new().unwrap();
let connect = connect(s);
let connect = tokio_postgres::connect(s, NoTls);
let (mut client, connection) = runtime.block_on(connect).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
@ -41,9 +37,12 @@ fn multiple_hosts_multiple_ports() {
#[test]
fn wrong_port_count() {
let mut runtime = Runtime::new().unwrap();
let f = connect("host=localhost port=5433,5433 user=postgres");
let f = tokio_postgres::connect("host=localhost port=5433,5433 user=postgres", NoTls);
runtime.block_on(f).err().unwrap();
let f = connect("host=localhost,localhost,localhost port=5433,5433 user=postgres");
let f = tokio_postgres::connect(
"host=localhost,localhost,localhost port=5433,5433 user=postgres",
NoTls,
);
runtime.block_on(f).err().unwrap();
}