Add a convenience connect free function
This commit is contained in:
parent
af41875ea4
commit
6ae93a0634
@ -77,10 +77,7 @@ fn runtime() {
|
||||
builder.set_ca_file("../test/server.crt").unwrap();
|
||||
let connector = MakeTlsConnector::new(builder.build());
|
||||
|
||||
let connect = "host=localhost port=5433 user=postgres"
|
||||
.parse::<tokio_postgres::Builder>()
|
||||
.unwrap()
|
||||
.connect(RequireTls(connector));
|
||||
let connect = tokio_postgres::connect("host=localhost port=5433 user=postgres", RequireTls(connector));
|
||||
let (mut client, connection) = runtime.block_on(connect).unwrap();
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
runtime.spawn(connection);
|
||||
|
@ -128,7 +128,7 @@ impl Builder {
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
Connect(ConnectFuture::new(make_tls_mode, self.clone()))
|
||||
Connect(ConnectFuture::new(make_tls_mode, Ok(self.clone())))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -36,6 +36,14 @@ fn next_portal() -> String {
|
||||
format!("p{}", ID.fetch_add(1, Ordering::SeqCst))
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
pub fn connect<T>(config: &str, tls_mode: T) -> Connect<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
Connect(proto::ConnectFuture::new(tls_mode, config.parse()))
|
||||
}
|
||||
|
||||
pub fn cancel_query<S, T>(stream: S, tls_mode: T, cancel_data: CancelData) -> CancelQuery<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
|
@ -10,7 +10,10 @@ where
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
#[state_machine_future(start, transitions(MakingTlsMode))]
|
||||
Start { make_tls_mode: T, config: Builder },
|
||||
Start {
|
||||
make_tls_mode: T,
|
||||
config: Result<Builder, Error>,
|
||||
},
|
||||
#[state_machine_future(transitions(Connecting))]
|
||||
MakingTlsMode {
|
||||
future: T::Future,
|
||||
@ -38,15 +41,17 @@ where
|
||||
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
|
||||
let mut state = state.take();
|
||||
|
||||
if state.config.0.host.is_empty() {
|
||||
let config = state.config?;
|
||||
|
||||
if config.0.host.is_empty() {
|
||||
return Err(Error::missing_host());
|
||||
}
|
||||
|
||||
if state.config.0.port.len() > 1 && state.config.0.port.len() != state.config.0.host.len() {
|
||||
if config.0.port.len() > 1 && config.0.port.len() != config.0.host.len() {
|
||||
return Err(Error::invalid_port_count());
|
||||
}
|
||||
|
||||
let hostname = match &state.config.0.host[0] {
|
||||
let hostname = match &config.0.host[0] {
|
||||
Host::Tcp(host) => &**host,
|
||||
// postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
|
||||
#[cfg(unix)]
|
||||
@ -58,7 +63,7 @@ where
|
||||
future,
|
||||
idx: 0,
|
||||
make_tls_mode: state.make_tls_mode,
|
||||
config: state.config,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
@ -113,7 +118,7 @@ impl<T> ConnectFuture<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
pub fn new(make_tls_mode: T, config: Builder) -> ConnectFuture<T> {
|
||||
pub fn new(make_tls_mode: T, config: Result<Builder, Error>) -> ConnectFuture<T> {
|
||||
Connect::start(make_tls_mode, config)
|
||||
}
|
||||
}
|
||||
|
@ -1,14 +1,10 @@
|
||||
use futures::Future;
|
||||
use tokio::runtime::current_thread::Runtime;
|
||||
use tokio_postgres::{Client, Connection, Error, NoTls, Socket};
|
||||
|
||||
fn connect(s: &str) -> impl Future<Item = (Client, Connection<Socket>), Error = Error> {
|
||||
s.parse::<tokio_postgres::Builder>().unwrap().connect(NoTls)
|
||||
}
|
||||
use tokio_postgres::NoTls;
|
||||
|
||||
fn smoke_test(s: &str) {
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
let connect = connect(s);
|
||||
let connect = tokio_postgres::connect(s, NoTls);
|
||||
let (mut client, connection) = runtime.block_on(connect).unwrap();
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
runtime.spawn(connection);
|
||||
@ -41,9 +37,12 @@ fn multiple_hosts_multiple_ports() {
|
||||
#[test]
|
||||
fn wrong_port_count() {
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
let f = connect("host=localhost port=5433,5433 user=postgres");
|
||||
let f = tokio_postgres::connect("host=localhost port=5433,5433 user=postgres", NoTls);
|
||||
runtime.block_on(f).err().unwrap();
|
||||
|
||||
let f = connect("host=localhost,localhost,localhost port=5433,5433 user=postgres");
|
||||
let f = tokio_postgres::connect(
|
||||
"host=localhost,localhost,localhost port=5433,5433 user=postgres",
|
||||
NoTls,
|
||||
);
|
||||
runtime.block_on(f).err().unwrap();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user