Prep for multi-host support

cc #399
This commit is contained in:
Steven Fackler 2018-12-18 21:39:05 -08:00
parent 56088a9a46
commit 7e7ae968c1
10 changed files with 398 additions and 129 deletions

View File

@ -4,6 +4,10 @@ version = "0.1.0"
authors = ["Steven Fackler <sfackler@gmail.com>"]
edition = "2018"
[features]
default = ["runtime"]
runtime = ["tokio-postgres/runtime"]
[dependencies]
futures = "0.1"
openssl = "0.10"

View File

@ -1,17 +1,72 @@
#![warn(rust_2018_idioms, clippy::all)]
#[cfg(feature = "runtime")]
use futures::future::{self, FutureResult};
use futures::{try_ready, Async, Future, Poll};
#[cfg(feature = "runtime")]
use openssl::error::ErrorStack;
use openssl::hash::MessageDigest;
use openssl::nid::Nid;
#[cfg(feature = "runtime")]
use openssl::ssl::SslConnector;
use openssl::ssl::{ConnectConfiguration, HandshakeError, SslRef};
use std::fmt::Debug;
#[cfg(feature = "runtime")]
use std::sync::Arc;
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_openssl::{ConnectAsync, ConnectConfigurationExt, SslStream};
#[cfg(feature = "runtime")]
use tokio_postgres::MakeTlsConnect;
use tokio_postgres::{ChannelBinding, TlsConnect};
#[cfg(test)]
mod test;
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub struct MakeTlsConnector {
connector: SslConnector,
config: Arc<dyn Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + Sync + Send>,
}
#[cfg(feature = "runtime")]
impl MakeTlsConnector {
pub fn new(connector: SslConnector) -> MakeTlsConnector {
MakeTlsConnector {
connector,
config: Arc::new(|_| Ok(())),
}
}
pub fn set_callback<F>(&mut self, f: F)
where
F: Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + 'static + Sync + Send,
{
self.config = Arc::new(f);
}
fn make_tls_connect_inner(&mut self, domain: &str) -> Result<TlsConnector, ErrorStack> {
let mut ssl = self.connector.configure()?;
(self.config)(&mut ssl)?;
Ok(TlsConnector::new(ssl, domain))
}
}
#[cfg(feature = "runtime")]
impl<S> MakeTlsConnect<S> for MakeTlsConnector
where
S: AsyncRead + AsyncWrite + Debug + 'static + Sync + Send,
{
type Stream = SslStream<S>;
type TlsConnect = TlsConnector;
type Error = ErrorStack;
type Future = FutureResult<TlsConnector, ErrorStack>;
fn make_tls_connect(&mut self, domain: &str) -> FutureResult<TlsConnector, ErrorStack> {
future::result(self.make_tls_connect_inner(domain))
}
}
pub struct TlsConnector {
ssl: ConnectConfiguration,
domain: String,

View File

@ -4,7 +4,7 @@ use tokio::net::TcpStream;
use tokio::runtime::current_thread::Runtime;
use tokio_postgres::{self, PreferTls, RequireTls, TlsMode};
use crate::TlsConnector;
use super::*;
fn smoke_test<T>(builder: &tokio_postgres::Builder, tls: T)
where
@ -72,3 +72,24 @@ fn scram_user() {
RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
);
}
#[test]
#[cfg(feature = "runtime")]
fn runtime() {
let mut runtime = Runtime::new().unwrap();
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_ca_file("../test/server.crt").unwrap();
let connector = MakeTlsConnector::new(builder.build());
let connect = "host=localhost port=5433 user=postgres"
.parse::<tokio_postgres::Builder>()
.unwrap()
.connect(RequireTls(connector));
let (mut client, connection) = runtime.block_on(connect).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
let execute = client.batch_execute("SELECT 1");
runtime.block_on(execute).unwrap();
}

View File

@ -7,7 +7,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto::ConnectFuture;
use crate::proto::HandshakeFuture;
#[cfg(feature = "runtime")]
use crate::{Connect, Socket};
use crate::{Connect, MakeTlsMode, Socket};
use crate::{Error, Handshake, TlsMode};
#[derive(Clone)]
@ -61,11 +61,11 @@ impl Builder {
}
#[cfg(feature = "runtime")]
pub fn connect<T>(&self, tls_mode: T) -> Connect<T>
pub fn connect<T>(&self, make_tls_mode: T) -> Connect<T>
where
T: TlsMode<Socket>,
T: MakeTlsMode<Socket>,
{
Connect(ConnectFuture::new(tls_mode, self.params.clone()))
Connect(ConnectFuture::new(make_tls_mode, self.params.clone()))
}
}

View File

@ -184,12 +184,12 @@ where
#[must_use = "futures do nothing unless polled"]
pub struct Connect<T>(proto::ConnectFuture<T>)
where
T: TlsMode<Socket>;
T: MakeTlsMode<Socket>;
#[cfg(feature = "runtime")]
impl<T> Future for Connect<T>
where
T: TlsMode<Socket>,
T: MakeTlsMode<Socket>,
{
type Item = (Client, Connection<T::Stream>);
type Error = Error;

View File

@ -1,61 +1,31 @@
use futures::{try_ready, Async, Future, Poll};
use futures_cpupool::{CpuFuture, CpuPool};
use lazy_static::lazy_static;
use futures::{try_ready, Future, Poll};
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use std::collections::HashMap;
use std::io;
use std::net::{SocketAddr, ToSocketAddrs};
#[cfg(unix)]
use std::path::Path;
use std::vec;
use tokio_tcp::TcpStream;
#[cfg(unix)]
use tokio_uds::UnixStream;
use crate::proto::{Client, Connection, HandshakeFuture};
use crate::{Error, Socket, TlsMode};
lazy_static! {
static ref DNS_POOL: CpuPool = futures_cpupool::Builder::new()
.name_prefix("postgres-dns-")
.pool_size(2)
.create();
}
use crate::proto::{Client, ConnectOnceFuture, Connection};
use crate::{Error, MakeTlsMode, Socket};
#[derive(StateMachineFuture)]
pub enum Connect<T>
where
T: TlsMode<Socket>,
T: MakeTlsMode<Socket>,
{
#[state_machine_future(start)]
#[cfg_attr(unix, state_machine_future(transitions(ConnectingUnix, ResolvingDns)))]
#[cfg_attr(not(unix), state_machine_future(transitions(ConnectingTcp)))]
#[state_machine_future(start, transitions(MakingTlsMode))]
Start {
tls_mode: T,
make_tls_mode: T,
params: HashMap<String, String>,
},
#[cfg(unix)]
#[state_machine_future(transitions(Handshaking))]
ConnectingUnix {
future: tokio_uds::ConnectFuture,
tls_mode: T,
params: HashMap<String, String>,
},
#[state_machine_future(transitions(ConnectingTcp))]
ResolvingDns {
future: CpuFuture<vec::IntoIter<SocketAddr>, io::Error>,
tls_mode: T,
params: HashMap<String, String>,
},
#[state_machine_future(transitions(Handshaking))]
ConnectingTcp {
future: tokio_tcp::ConnectFuture,
addrs: vec::IntoIter<SocketAddr>,
tls_mode: T,
#[state_machine_future(transitions(Connecting))]
MakingTlsMode {
future: T::Future,
host: String,
port: u16,
params: HashMap<String, String>,
},
#[state_machine_future(transitions(Finished))]
Handshaking { future: HandshakeFuture<Socket, T> },
Connecting {
future: ConnectOnceFuture<T::TlsMode>,
},
#[state_machine_future(ready)]
Finished((Client, Connection<T::Stream>)),
#[state_machine_future(error)]
@ -64,7 +34,7 @@ where
impl<T> PollConnect<T> for Connect<T>
where
T: TlsMode<Socket>,
T: MakeTlsMode<Socket>,
{
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
let mut state = state.take();
@ -79,99 +49,38 @@ where
None => 5432,
};
#[cfg(unix)]
{
if host.starts_with('/') {
let path = Path::new(&host).join(format!(".s.PGSQL.{}", port));
transition!(ConnectingUnix {
future: UnixStream::connect(path),
tls_mode: state.tls_mode,
params: state.params,
})
}
}
transition!(ResolvingDns {
future: DNS_POOL.spawn_fn(move || (&*host, port).to_socket_addrs()),
tls_mode: state.tls_mode,
transition!(MakingTlsMode {
future: state.make_tls_mode.make_tls_mode(&host),
host,
port,
params: state.params,
})
}
#[cfg(unix)]
fn poll_connecting_unix<'a>(
state: &'a mut RentToOwn<'a, ConnectingUnix<T>>,
) -> Poll<AfterConnectingUnix<T>, Error> {
let stream = try_ready!(state.future.poll().map_err(Error::connect));
let stream = Socket::new_unix(stream);
fn poll_making_tls_mode<'a>(
state: &'a mut RentToOwn<'a, MakingTlsMode<T>>,
) -> Poll<AfterMakingTlsMode<T>, Error> {
let tls_mode = try_ready!(state.future.poll().map_err(|e| Error::tls(e.into())));
let state = state.take();
transition!(Handshaking {
future: HandshakeFuture::new(stream, state.tls_mode, state.params)
transition!(Connecting {
future: ConnectOnceFuture::new(state.host, state.port, tls_mode, state.params),
})
}
fn poll_resolving_dns<'a>(
state: &'a mut RentToOwn<'a, ResolvingDns<T>>,
) -> Poll<AfterResolvingDns<T>, Error> {
let mut addrs = try_ready!(state.future.poll().map_err(Error::connect));
let state = state.take();
let addr = match addrs.next() {
Some(addr) => addr,
None => {
return Err(Error::connect(io::Error::new(
io::ErrorKind::InvalidData,
"resolved 0 addresses",
)))
}
};
transition!(ConnectingTcp {
future: TcpStream::connect(&addr),
addrs,
tls_mode: state.tls_mode,
params: state.params,
})
}
fn poll_connecting_tcp<'a>(
state: &'a mut RentToOwn<'a, ConnectingTcp<T>>,
) -> Poll<AfterConnectingTcp<T>, Error> {
let stream = loop {
match state.future.poll() {
Ok(Async::Ready(stream)) => break Socket::new_tcp(stream),
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(e) => {
let addr = match state.addrs.next() {
Some(addr) => addr,
None => return Err(Error::connect(e)),
};
state.future = TcpStream::connect(&addr);
}
}
};
let state = state.take();
transition!(Handshaking {
future: HandshakeFuture::new(stream, state.tls_mode, state.params),
})
}
fn poll_handshaking<'a>(
state: &'a mut RentToOwn<'a, Handshaking<T>>,
) -> Poll<AfterHandshaking<T>, Error> {
fn poll_connecting<'a>(
state: &'a mut RentToOwn<'a, Connecting<T>>,
) -> Poll<AfterConnecting<T>, Error> {
let r = try_ready!(state.future.poll());
transition!(Finished(r))
}
}
impl<T> ConnectFuture<T>
where
T: TlsMode<Socket>,
T: MakeTlsMode<Socket>,
{
pub fn new(tls_mode: T, params: HashMap<String, String>) -> ConnectFuture<T> {
Connect::start(tls_mode, params)
pub fn new(make_tls_mode: T, params: HashMap<String, String>) -> ConnectFuture<T> {
Connect::start(make_tls_mode, params)
}
}

View File

@ -0,0 +1,176 @@
use futures::{try_ready, Async, Future, Poll};
use futures_cpupool::{CpuFuture, CpuPool};
use lazy_static::lazy_static;
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use std::collections::HashMap;
use std::io;
use std::net::{SocketAddr, ToSocketAddrs};
#[cfg(unix)]
use std::path::Path;
use std::vec;
use tokio_tcp::TcpStream;
#[cfg(unix)]
use tokio_uds::UnixStream;
use crate::proto::{Client, Connection, HandshakeFuture};
use crate::{Error, Socket, TlsMode};
lazy_static! {
static ref DNS_POOL: CpuPool = futures_cpupool::Builder::new()
.name_prefix("postgres-dns-")
.pool_size(2)
.create();
}
#[derive(StateMachineFuture)]
pub enum ConnectOnce<T>
where
T: TlsMode<Socket>,
{
#[state_machine_future(start)]
#[cfg_attr(unix, state_machine_future(transitions(ConnectingUnix, ResolvingDns)))]
#[cfg_attr(not(unix), state_machine_future(transitions(ConnectingTcp)))]
Start {
host: String,
port: u16,
tls_mode: T,
params: HashMap<String, String>,
},
#[cfg(unix)]
#[state_machine_future(transitions(Handshaking))]
ConnectingUnix {
future: tokio_uds::ConnectFuture,
tls_mode: T,
params: HashMap<String, String>,
},
#[state_machine_future(transitions(ConnectingTcp))]
ResolvingDns {
future: CpuFuture<vec::IntoIter<SocketAddr>, io::Error>,
tls_mode: T,
params: HashMap<String, String>,
},
#[state_machine_future(transitions(Handshaking))]
ConnectingTcp {
future: tokio_tcp::ConnectFuture,
addrs: vec::IntoIter<SocketAddr>,
tls_mode: T,
params: HashMap<String, String>,
},
#[state_machine_future(transitions(Finished))]
Handshaking { future: HandshakeFuture<Socket, T> },
#[state_machine_future(ready)]
Finished((Client, Connection<T::Stream>)),
#[state_machine_future(error)]
Failed(Error),
}
impl<T> PollConnectOnce<T> for ConnectOnce<T>
where
T: TlsMode<Socket>,
{
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
let state = state.take();
#[cfg(unix)]
{
if state.host.starts_with('/') {
let path = Path::new(&state.host).join(format!(".s.PGSQL.{}", state.port));
transition!(ConnectingUnix {
future: UnixStream::connect(path),
tls_mode: state.tls_mode,
params: state.params,
})
}
}
let host = state.host;
let port = state.port;
transition!(ResolvingDns {
future: DNS_POOL.spawn_fn(move || (&*host, port).to_socket_addrs()),
tls_mode: state.tls_mode,
params: state.params,
})
}
#[cfg(unix)]
fn poll_connecting_unix<'a>(
state: &'a mut RentToOwn<'a, ConnectingUnix<T>>,
) -> Poll<AfterConnectingUnix<T>, Error> {
let stream = try_ready!(state.future.poll().map_err(Error::connect));
let stream = Socket::new_unix(stream);
let state = state.take();
transition!(Handshaking {
future: HandshakeFuture::new(stream, state.tls_mode, state.params)
})
}
fn poll_resolving_dns<'a>(
state: &'a mut RentToOwn<'a, ResolvingDns<T>>,
) -> Poll<AfterResolvingDns<T>, Error> {
let mut addrs = try_ready!(state.future.poll().map_err(Error::connect));
let state = state.take();
let addr = match addrs.next() {
Some(addr) => addr,
None => {
return Err(Error::connect(io::Error::new(
io::ErrorKind::InvalidData,
"resolved 0 addresses",
)))
}
};
transition!(ConnectingTcp {
future: TcpStream::connect(&addr),
addrs,
tls_mode: state.tls_mode,
params: state.params,
})
}
fn poll_connecting_tcp<'a>(
state: &'a mut RentToOwn<'a, ConnectingTcp<T>>,
) -> Poll<AfterConnectingTcp<T>, Error> {
let stream = loop {
match state.future.poll() {
Ok(Async::Ready(stream)) => break Socket::new_tcp(stream),
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(e) => {
let addr = match state.addrs.next() {
Some(addr) => addr,
None => return Err(Error::connect(e)),
};
state.future = TcpStream::connect(&addr);
}
}
};
let state = state.take();
transition!(Handshaking {
future: HandshakeFuture::new(stream, state.tls_mode, state.params),
})
}
fn poll_handshaking<'a>(
state: &'a mut RentToOwn<'a, Handshaking<T>>,
) -> Poll<AfterHandshaking<T>, Error> {
let r = try_ready!(state.future.poll());
transition!(Finished(r))
}
}
impl<T> ConnectOnceFuture<T>
where
T: TlsMode<Socket>,
{
pub fn new(
host: String,
port: u16,
tls_mode: T,
params: HashMap<String, String>,
) -> ConnectOnceFuture<T> {
ConnectOnce::start(host, port, tls_mode, params)
}
}

View File

@ -24,6 +24,8 @@ mod client;
mod codec;
#[cfg(feature = "runtime")]
mod connect;
#[cfg(feature = "runtime")]
mod connect_once;
mod connection;
mod copy_in;
mod copy_out;
@ -46,6 +48,8 @@ pub use crate::proto::client::Client;
pub use crate::proto::codec::PostgresCodec;
#[cfg(feature = "runtime")]
pub use crate::proto::connect::ConnectFuture;
#[cfg(feature = "runtime")]
pub use crate::proto::connect_once::ConnectOnceFuture;
pub use crate::proto::connection::Connection;
pub use crate::proto::copy_in::CopyInFuture;
pub use crate::proto::copy_out::CopyOutStream;

View File

@ -6,12 +6,14 @@ use tokio_tcp::TcpStream;
#[cfg(unix)]
use tokio_uds::UnixStream;
#[derive(Debug)]
enum Inner {
Tcp(TcpStream),
#[cfg(unix)]
Unix(UnixStream),
}
#[derive(Debug)]
pub struct Socket(Inner);
impl Socket {

View File

@ -25,6 +25,16 @@ impl ChannelBinding {
}
}
#[cfg(feature = "runtime")]
pub trait MakeTlsMode<S> {
type Stream: AsyncRead + AsyncWrite;
type TlsMode: TlsMode<S, Stream = Self::Stream>;
type Error: Into<Box<dyn Error + Sync + Send>>;
type Future: Future<Item = Self::TlsMode, Error = Self::Error>;
fn make_tls_mode(&mut self, domain: &str) -> Self::Future;
}
pub trait TlsMode<S> {
type Stream: AsyncRead + AsyncWrite;
type Error: Into<Box<dyn Error + Sync + Send>>;
@ -35,6 +45,16 @@ pub trait TlsMode<S> {
fn handle_tls(self, use_tls: bool, stream: S) -> Self::Future;
}
#[cfg(feature = "runtime")]
pub trait MakeTlsConnect<S> {
type Stream: AsyncRead + AsyncWrite;
type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
type Error: Into<Box<dyn Error + Sync + Send>>;
type Future: Future<Item = Self::TlsConnect, Error = Self::Error>;
fn make_tls_connect(&mut self, domain: &str) -> Self::Future;
}
pub trait TlsConnect<S> {
type Stream: AsyncRead + AsyncWrite;
type Error: Into<Box<dyn Error + Sync + Send>>;
@ -46,6 +66,21 @@ pub trait TlsConnect<S> {
#[derive(Debug, Copy, Clone)]
pub struct NoTls;
#[cfg(feature = "runtime")]
impl<S> MakeTlsMode<S> for NoTls
where
S: AsyncRead + AsyncWrite,
{
type Stream = S;
type TlsMode = NoTls;
type Error = Void;
type Future = FutureResult<NoTls, Void>;
fn make_tls_mode(&mut self, _: &str) -> FutureResult<NoTls, Void> {
future::ok(NoTls)
}
}
impl<S> TlsMode<S> for NoTls
where
S: AsyncRead + AsyncWrite,
@ -68,6 +103,38 @@ where
#[derive(Debug, Copy, Clone)]
pub struct PreferTls<T>(pub T);
#[cfg(feature = "runtime")]
impl<T, S> MakeTlsMode<S> for PreferTls<T>
where
T: MakeTlsConnect<S>,
S: AsyncRead + AsyncWrite,
{
type Stream = MaybeTlsStream<T::Stream, S>;
type TlsMode = PreferTls<T::TlsConnect>;
type Error = T::Error;
type Future = MakePreferTlsFuture<T::Future>;
fn make_tls_mode(&mut self, domain: &str) -> MakePreferTlsFuture<T::Future> {
MakePreferTlsFuture(self.0.make_tls_connect(domain))
}
}
#[cfg(feature = "runtime")]
pub struct MakePreferTlsFuture<F>(F);
#[cfg(feature = "runtime")]
impl<F> Future for MakePreferTlsFuture<F>
where
F: Future,
{
type Item = PreferTls<F::Item>;
type Error = F::Error;
fn poll(&mut self) -> Poll<PreferTls<F::Item>, F::Error> {
self.0.poll().map(|f| f.map(PreferTls))
}
}
impl<T, S> TlsMode<S> for PreferTls<T>
where
T: TlsConnect<S>,
@ -207,6 +274,37 @@ where
#[derive(Debug, Copy, Clone)]
pub struct RequireTls<T>(pub T);
#[cfg(feature = "runtime")]
impl<T, S> MakeTlsMode<S> for RequireTls<T>
where
T: MakeTlsConnect<S>,
{
type Stream = T::Stream;
type TlsMode = RequireTls<T::TlsConnect>;
type Error = T::Error;
type Future = MakeRequireTlsFuture<T::Future>;
fn make_tls_mode(&mut self, domain: &str) -> MakeRequireTlsFuture<T::Future> {
MakeRequireTlsFuture(self.0.make_tls_connect(domain))
}
}
#[cfg(feature = "runtime")]
pub struct MakeRequireTlsFuture<F>(F);
#[cfg(feature = "runtime")]
impl<F> Future for MakeRequireTlsFuture<F>
where
F: Future,
{
type Item = RequireTls<F::Item>;
type Error = F::Error;
fn poll(&mut self) -> Poll<RequireTls<F::Item>, F::Error> {
self.0.poll().map(|f| f.map(RequireTls))
}
}
impl<T, S> TlsMode<S> for RequireTls<T>
where
T: TlsConnect<S>,