parent
56088a9a46
commit
7e7ae968c1
@ -4,6 +4,10 @@ version = "0.1.0"
|
||||
authors = ["Steven Fackler <sfackler@gmail.com>"]
|
||||
edition = "2018"
|
||||
|
||||
[features]
|
||||
default = ["runtime"]
|
||||
runtime = ["tokio-postgres/runtime"]
|
||||
|
||||
[dependencies]
|
||||
futures = "0.1"
|
||||
openssl = "0.10"
|
||||
|
@ -1,17 +1,72 @@
|
||||
#![warn(rust_2018_idioms, clippy::all)]
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
use futures::future::{self, FutureResult};
|
||||
use futures::{try_ready, Async, Future, Poll};
|
||||
#[cfg(feature = "runtime")]
|
||||
use openssl::error::ErrorStack;
|
||||
use openssl::hash::MessageDigest;
|
||||
use openssl::nid::Nid;
|
||||
#[cfg(feature = "runtime")]
|
||||
use openssl::ssl::SslConnector;
|
||||
use openssl::ssl::{ConnectConfiguration, HandshakeError, SslRef};
|
||||
use std::fmt::Debug;
|
||||
#[cfg(feature = "runtime")]
|
||||
use std::sync::Arc;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use tokio_openssl::{ConnectAsync, ConnectConfigurationExt, SslStream};
|
||||
#[cfg(feature = "runtime")]
|
||||
use tokio_postgres::MakeTlsConnect;
|
||||
use tokio_postgres::{ChannelBinding, TlsConnect};
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
#[derive(Clone)]
|
||||
pub struct MakeTlsConnector {
|
||||
connector: SslConnector,
|
||||
config: Arc<dyn Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + Sync + Send>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
impl MakeTlsConnector {
|
||||
pub fn new(connector: SslConnector) -> MakeTlsConnector {
|
||||
MakeTlsConnector {
|
||||
connector,
|
||||
config: Arc::new(|_| Ok(())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_callback<F>(&mut self, f: F)
|
||||
where
|
||||
F: Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + 'static + Sync + Send,
|
||||
{
|
||||
self.config = Arc::new(f);
|
||||
}
|
||||
|
||||
fn make_tls_connect_inner(&mut self, domain: &str) -> Result<TlsConnector, ErrorStack> {
|
||||
let mut ssl = self.connector.configure()?;
|
||||
(self.config)(&mut ssl)?;
|
||||
Ok(TlsConnector::new(ssl, domain))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
impl<S> MakeTlsConnect<S> for MakeTlsConnector
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Debug + 'static + Sync + Send,
|
||||
{
|
||||
type Stream = SslStream<S>;
|
||||
type TlsConnect = TlsConnector;
|
||||
type Error = ErrorStack;
|
||||
type Future = FutureResult<TlsConnector, ErrorStack>;
|
||||
|
||||
fn make_tls_connect(&mut self, domain: &str) -> FutureResult<TlsConnector, ErrorStack> {
|
||||
future::result(self.make_tls_connect_inner(domain))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TlsConnector {
|
||||
ssl: ConnectConfiguration,
|
||||
domain: String,
|
||||
|
@ -4,7 +4,7 @@ use tokio::net::TcpStream;
|
||||
use tokio::runtime::current_thread::Runtime;
|
||||
use tokio_postgres::{self, PreferTls, RequireTls, TlsMode};
|
||||
|
||||
use crate::TlsConnector;
|
||||
use super::*;
|
||||
|
||||
fn smoke_test<T>(builder: &tokio_postgres::Builder, tls: T)
|
||||
where
|
||||
@ -72,3 +72,24 @@ fn scram_user() {
|
||||
RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "runtime")]
|
||||
fn runtime() {
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
|
||||
builder.set_ca_file("../test/server.crt").unwrap();
|
||||
let connector = MakeTlsConnector::new(builder.build());
|
||||
|
||||
let connect = "host=localhost port=5433 user=postgres"
|
||||
.parse::<tokio_postgres::Builder>()
|
||||
.unwrap()
|
||||
.connect(RequireTls(connector));
|
||||
let (mut client, connection) = runtime.block_on(connect).unwrap();
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
runtime.spawn(connection);
|
||||
|
||||
let execute = client.batch_execute("SELECT 1");
|
||||
runtime.block_on(execute).unwrap();
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use crate::proto::ConnectFuture;
|
||||
use crate::proto::HandshakeFuture;
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::{Connect, Socket};
|
||||
use crate::{Connect, MakeTlsMode, Socket};
|
||||
use crate::{Error, Handshake, TlsMode};
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -61,11 +61,11 @@ impl Builder {
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
pub fn connect<T>(&self, tls_mode: T) -> Connect<T>
|
||||
pub fn connect<T>(&self, make_tls_mode: T) -> Connect<T>
|
||||
where
|
||||
T: TlsMode<Socket>,
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
Connect(ConnectFuture::new(tls_mode, self.params.clone()))
|
||||
Connect(ConnectFuture::new(make_tls_mode, self.params.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -184,12 +184,12 @@ where
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct Connect<T>(proto::ConnectFuture<T>)
|
||||
where
|
||||
T: TlsMode<Socket>;
|
||||
T: MakeTlsMode<Socket>;
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
impl<T> Future for Connect<T>
|
||||
where
|
||||
T: TlsMode<Socket>,
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
type Item = (Client, Connection<T::Stream>);
|
||||
type Error = Error;
|
||||
|
@ -1,61 +1,31 @@
|
||||
use futures::{try_ready, Async, Future, Poll};
|
||||
use futures_cpupool::{CpuFuture, CpuPool};
|
||||
use lazy_static::lazy_static;
|
||||
use futures::{try_ready, Future, Poll};
|
||||
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::net::{SocketAddr, ToSocketAddrs};
|
||||
#[cfg(unix)]
|
||||
use std::path::Path;
|
||||
use std::vec;
|
||||
use tokio_tcp::TcpStream;
|
||||
#[cfg(unix)]
|
||||
use tokio_uds::UnixStream;
|
||||
|
||||
use crate::proto::{Client, Connection, HandshakeFuture};
|
||||
use crate::{Error, Socket, TlsMode};
|
||||
|
||||
lazy_static! {
|
||||
static ref DNS_POOL: CpuPool = futures_cpupool::Builder::new()
|
||||
.name_prefix("postgres-dns-")
|
||||
.pool_size(2)
|
||||
.create();
|
||||
}
|
||||
use crate::proto::{Client, ConnectOnceFuture, Connection};
|
||||
use crate::{Error, MakeTlsMode, Socket};
|
||||
|
||||
#[derive(StateMachineFuture)]
|
||||
pub enum Connect<T>
|
||||
where
|
||||
T: TlsMode<Socket>,
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
#[state_machine_future(start)]
|
||||
#[cfg_attr(unix, state_machine_future(transitions(ConnectingUnix, ResolvingDns)))]
|
||||
#[cfg_attr(not(unix), state_machine_future(transitions(ConnectingTcp)))]
|
||||
#[state_machine_future(start, transitions(MakingTlsMode))]
|
||||
Start {
|
||||
tls_mode: T,
|
||||
make_tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
},
|
||||
#[cfg(unix)]
|
||||
#[state_machine_future(transitions(Handshaking))]
|
||||
ConnectingUnix {
|
||||
future: tokio_uds::ConnectFuture,
|
||||
tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
},
|
||||
#[state_machine_future(transitions(ConnectingTcp))]
|
||||
ResolvingDns {
|
||||
future: CpuFuture<vec::IntoIter<SocketAddr>, io::Error>,
|
||||
tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
},
|
||||
#[state_machine_future(transitions(Handshaking))]
|
||||
ConnectingTcp {
|
||||
future: tokio_tcp::ConnectFuture,
|
||||
addrs: vec::IntoIter<SocketAddr>,
|
||||
tls_mode: T,
|
||||
#[state_machine_future(transitions(Connecting))]
|
||||
MakingTlsMode {
|
||||
future: T::Future,
|
||||
host: String,
|
||||
port: u16,
|
||||
params: HashMap<String, String>,
|
||||
},
|
||||
#[state_machine_future(transitions(Finished))]
|
||||
Handshaking { future: HandshakeFuture<Socket, T> },
|
||||
Connecting {
|
||||
future: ConnectOnceFuture<T::TlsMode>,
|
||||
},
|
||||
#[state_machine_future(ready)]
|
||||
Finished((Client, Connection<T::Stream>)),
|
||||
#[state_machine_future(error)]
|
||||
@ -64,7 +34,7 @@ where
|
||||
|
||||
impl<T> PollConnect<T> for Connect<T>
|
||||
where
|
||||
T: TlsMode<Socket>,
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
|
||||
let mut state = state.take();
|
||||
@ -79,99 +49,38 @@ where
|
||||
None => 5432,
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
if host.starts_with('/') {
|
||||
let path = Path::new(&host).join(format!(".s.PGSQL.{}", port));
|
||||
transition!(ConnectingUnix {
|
||||
future: UnixStream::connect(path),
|
||||
tls_mode: state.tls_mode,
|
||||
params: state.params,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
transition!(ResolvingDns {
|
||||
future: DNS_POOL.spawn_fn(move || (&*host, port).to_socket_addrs()),
|
||||
tls_mode: state.tls_mode,
|
||||
transition!(MakingTlsMode {
|
||||
future: state.make_tls_mode.make_tls_mode(&host),
|
||||
host,
|
||||
port,
|
||||
params: state.params,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn poll_connecting_unix<'a>(
|
||||
state: &'a mut RentToOwn<'a, ConnectingUnix<T>>,
|
||||
) -> Poll<AfterConnectingUnix<T>, Error> {
|
||||
let stream = try_ready!(state.future.poll().map_err(Error::connect));
|
||||
let stream = Socket::new_unix(stream);
|
||||
fn poll_making_tls_mode<'a>(
|
||||
state: &'a mut RentToOwn<'a, MakingTlsMode<T>>,
|
||||
) -> Poll<AfterMakingTlsMode<T>, Error> {
|
||||
let tls_mode = try_ready!(state.future.poll().map_err(|e| Error::tls(e.into())));
|
||||
let state = state.take();
|
||||
|
||||
transition!(Handshaking {
|
||||
future: HandshakeFuture::new(stream, state.tls_mode, state.params)
|
||||
transition!(Connecting {
|
||||
future: ConnectOnceFuture::new(state.host, state.port, tls_mode, state.params),
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_resolving_dns<'a>(
|
||||
state: &'a mut RentToOwn<'a, ResolvingDns<T>>,
|
||||
) -> Poll<AfterResolvingDns<T>, Error> {
|
||||
let mut addrs = try_ready!(state.future.poll().map_err(Error::connect));
|
||||
let state = state.take();
|
||||
|
||||
let addr = match addrs.next() {
|
||||
Some(addr) => addr,
|
||||
None => {
|
||||
return Err(Error::connect(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"resolved 0 addresses",
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
transition!(ConnectingTcp {
|
||||
future: TcpStream::connect(&addr),
|
||||
addrs,
|
||||
tls_mode: state.tls_mode,
|
||||
params: state.params,
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_connecting_tcp<'a>(
|
||||
state: &'a mut RentToOwn<'a, ConnectingTcp<T>>,
|
||||
) -> Poll<AfterConnectingTcp<T>, Error> {
|
||||
let stream = loop {
|
||||
match state.future.poll() {
|
||||
Ok(Async::Ready(stream)) => break Socket::new_tcp(stream),
|
||||
Ok(Async::NotReady) => return Ok(Async::NotReady),
|
||||
Err(e) => {
|
||||
let addr = match state.addrs.next() {
|
||||
Some(addr) => addr,
|
||||
None => return Err(Error::connect(e)),
|
||||
};
|
||||
state.future = TcpStream::connect(&addr);
|
||||
}
|
||||
}
|
||||
};
|
||||
let state = state.take();
|
||||
|
||||
transition!(Handshaking {
|
||||
future: HandshakeFuture::new(stream, state.tls_mode, state.params),
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_handshaking<'a>(
|
||||
state: &'a mut RentToOwn<'a, Handshaking<T>>,
|
||||
) -> Poll<AfterHandshaking<T>, Error> {
|
||||
fn poll_connecting<'a>(
|
||||
state: &'a mut RentToOwn<'a, Connecting<T>>,
|
||||
) -> Poll<AfterConnecting<T>, Error> {
|
||||
let r = try_ready!(state.future.poll());
|
||||
|
||||
transition!(Finished(r))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ConnectFuture<T>
|
||||
where
|
||||
T: TlsMode<Socket>,
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
pub fn new(tls_mode: T, params: HashMap<String, String>) -> ConnectFuture<T> {
|
||||
Connect::start(tls_mode, params)
|
||||
pub fn new(make_tls_mode: T, params: HashMap<String, String>) -> ConnectFuture<T> {
|
||||
Connect::start(make_tls_mode, params)
|
||||
}
|
||||
}
|
||||
|
176
tokio-postgres/src/proto/connect_once.rs
Normal file
176
tokio-postgres/src/proto/connect_once.rs
Normal file
@ -0,0 +1,176 @@
|
||||
use futures::{try_ready, Async, Future, Poll};
|
||||
use futures_cpupool::{CpuFuture, CpuPool};
|
||||
use lazy_static::lazy_static;
|
||||
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::net::{SocketAddr, ToSocketAddrs};
|
||||
#[cfg(unix)]
|
||||
use std::path::Path;
|
||||
use std::vec;
|
||||
use tokio_tcp::TcpStream;
|
||||
#[cfg(unix)]
|
||||
use tokio_uds::UnixStream;
|
||||
|
||||
use crate::proto::{Client, Connection, HandshakeFuture};
|
||||
use crate::{Error, Socket, TlsMode};
|
||||
|
||||
lazy_static! {
|
||||
static ref DNS_POOL: CpuPool = futures_cpupool::Builder::new()
|
||||
.name_prefix("postgres-dns-")
|
||||
.pool_size(2)
|
||||
.create();
|
||||
}
|
||||
|
||||
#[derive(StateMachineFuture)]
|
||||
pub enum ConnectOnce<T>
|
||||
where
|
||||
T: TlsMode<Socket>,
|
||||
{
|
||||
#[state_machine_future(start)]
|
||||
#[cfg_attr(unix, state_machine_future(transitions(ConnectingUnix, ResolvingDns)))]
|
||||
#[cfg_attr(not(unix), state_machine_future(transitions(ConnectingTcp)))]
|
||||
Start {
|
||||
host: String,
|
||||
port: u16,
|
||||
tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
},
|
||||
#[cfg(unix)]
|
||||
#[state_machine_future(transitions(Handshaking))]
|
||||
ConnectingUnix {
|
||||
future: tokio_uds::ConnectFuture,
|
||||
tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
},
|
||||
#[state_machine_future(transitions(ConnectingTcp))]
|
||||
ResolvingDns {
|
||||
future: CpuFuture<vec::IntoIter<SocketAddr>, io::Error>,
|
||||
tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
},
|
||||
#[state_machine_future(transitions(Handshaking))]
|
||||
ConnectingTcp {
|
||||
future: tokio_tcp::ConnectFuture,
|
||||
addrs: vec::IntoIter<SocketAddr>,
|
||||
tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
},
|
||||
#[state_machine_future(transitions(Finished))]
|
||||
Handshaking { future: HandshakeFuture<Socket, T> },
|
||||
#[state_machine_future(ready)]
|
||||
Finished((Client, Connection<T::Stream>)),
|
||||
#[state_machine_future(error)]
|
||||
Failed(Error),
|
||||
}
|
||||
|
||||
impl<T> PollConnectOnce<T> for ConnectOnce<T>
|
||||
where
|
||||
T: TlsMode<Socket>,
|
||||
{
|
||||
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
|
||||
let state = state.take();
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
if state.host.starts_with('/') {
|
||||
let path = Path::new(&state.host).join(format!(".s.PGSQL.{}", state.port));
|
||||
transition!(ConnectingUnix {
|
||||
future: UnixStream::connect(path),
|
||||
tls_mode: state.tls_mode,
|
||||
params: state.params,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let host = state.host;
|
||||
let port = state.port;
|
||||
transition!(ResolvingDns {
|
||||
future: DNS_POOL.spawn_fn(move || (&*host, port).to_socket_addrs()),
|
||||
tls_mode: state.tls_mode,
|
||||
params: state.params,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn poll_connecting_unix<'a>(
|
||||
state: &'a mut RentToOwn<'a, ConnectingUnix<T>>,
|
||||
) -> Poll<AfterConnectingUnix<T>, Error> {
|
||||
let stream = try_ready!(state.future.poll().map_err(Error::connect));
|
||||
let stream = Socket::new_unix(stream);
|
||||
let state = state.take();
|
||||
|
||||
transition!(Handshaking {
|
||||
future: HandshakeFuture::new(stream, state.tls_mode, state.params)
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_resolving_dns<'a>(
|
||||
state: &'a mut RentToOwn<'a, ResolvingDns<T>>,
|
||||
) -> Poll<AfterResolvingDns<T>, Error> {
|
||||
let mut addrs = try_ready!(state.future.poll().map_err(Error::connect));
|
||||
let state = state.take();
|
||||
|
||||
let addr = match addrs.next() {
|
||||
Some(addr) => addr,
|
||||
None => {
|
||||
return Err(Error::connect(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"resolved 0 addresses",
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
transition!(ConnectingTcp {
|
||||
future: TcpStream::connect(&addr),
|
||||
addrs,
|
||||
tls_mode: state.tls_mode,
|
||||
params: state.params,
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_connecting_tcp<'a>(
|
||||
state: &'a mut RentToOwn<'a, ConnectingTcp<T>>,
|
||||
) -> Poll<AfterConnectingTcp<T>, Error> {
|
||||
let stream = loop {
|
||||
match state.future.poll() {
|
||||
Ok(Async::Ready(stream)) => break Socket::new_tcp(stream),
|
||||
Ok(Async::NotReady) => return Ok(Async::NotReady),
|
||||
Err(e) => {
|
||||
let addr = match state.addrs.next() {
|
||||
Some(addr) => addr,
|
||||
None => return Err(Error::connect(e)),
|
||||
};
|
||||
state.future = TcpStream::connect(&addr);
|
||||
}
|
||||
}
|
||||
};
|
||||
let state = state.take();
|
||||
|
||||
transition!(Handshaking {
|
||||
future: HandshakeFuture::new(stream, state.tls_mode, state.params),
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_handshaking<'a>(
|
||||
state: &'a mut RentToOwn<'a, Handshaking<T>>,
|
||||
) -> Poll<AfterHandshaking<T>, Error> {
|
||||
let r = try_ready!(state.future.poll());
|
||||
|
||||
transition!(Finished(r))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ConnectOnceFuture<T>
|
||||
where
|
||||
T: TlsMode<Socket>,
|
||||
{
|
||||
pub fn new(
|
||||
host: String,
|
||||
port: u16,
|
||||
tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
) -> ConnectOnceFuture<T> {
|
||||
ConnectOnce::start(host, port, tls_mode, params)
|
||||
}
|
||||
}
|
@ -24,6 +24,8 @@ mod client;
|
||||
mod codec;
|
||||
#[cfg(feature = "runtime")]
|
||||
mod connect;
|
||||
#[cfg(feature = "runtime")]
|
||||
mod connect_once;
|
||||
mod connection;
|
||||
mod copy_in;
|
||||
mod copy_out;
|
||||
@ -46,6 +48,8 @@ pub use crate::proto::client::Client;
|
||||
pub use crate::proto::codec::PostgresCodec;
|
||||
#[cfg(feature = "runtime")]
|
||||
pub use crate::proto::connect::ConnectFuture;
|
||||
#[cfg(feature = "runtime")]
|
||||
pub use crate::proto::connect_once::ConnectOnceFuture;
|
||||
pub use crate::proto::connection::Connection;
|
||||
pub use crate::proto::copy_in::CopyInFuture;
|
||||
pub use crate::proto::copy_out::CopyOutStream;
|
||||
|
@ -6,12 +6,14 @@ use tokio_tcp::TcpStream;
|
||||
#[cfg(unix)]
|
||||
use tokio_uds::UnixStream;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Inner {
|
||||
Tcp(TcpStream),
|
||||
#[cfg(unix)]
|
||||
Unix(UnixStream),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Socket(Inner);
|
||||
|
||||
impl Socket {
|
||||
|
@ -25,6 +25,16 @@ impl ChannelBinding {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
pub trait MakeTlsMode<S> {
|
||||
type Stream: AsyncRead + AsyncWrite;
|
||||
type TlsMode: TlsMode<S, Stream = Self::Stream>;
|
||||
type Error: Into<Box<dyn Error + Sync + Send>>;
|
||||
type Future: Future<Item = Self::TlsMode, Error = Self::Error>;
|
||||
|
||||
fn make_tls_mode(&mut self, domain: &str) -> Self::Future;
|
||||
}
|
||||
|
||||
pub trait TlsMode<S> {
|
||||
type Stream: AsyncRead + AsyncWrite;
|
||||
type Error: Into<Box<dyn Error + Sync + Send>>;
|
||||
@ -35,6 +45,16 @@ pub trait TlsMode<S> {
|
||||
fn handle_tls(self, use_tls: bool, stream: S) -> Self::Future;
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
pub trait MakeTlsConnect<S> {
|
||||
type Stream: AsyncRead + AsyncWrite;
|
||||
type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
|
||||
type Error: Into<Box<dyn Error + Sync + Send>>;
|
||||
type Future: Future<Item = Self::TlsConnect, Error = Self::Error>;
|
||||
|
||||
fn make_tls_connect(&mut self, domain: &str) -> Self::Future;
|
||||
}
|
||||
|
||||
pub trait TlsConnect<S> {
|
||||
type Stream: AsyncRead + AsyncWrite;
|
||||
type Error: Into<Box<dyn Error + Sync + Send>>;
|
||||
@ -46,6 +66,21 @@ pub trait TlsConnect<S> {
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct NoTls;
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
impl<S> MakeTlsMode<S> for NoTls
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
{
|
||||
type Stream = S;
|
||||
type TlsMode = NoTls;
|
||||
type Error = Void;
|
||||
type Future = FutureResult<NoTls, Void>;
|
||||
|
||||
fn make_tls_mode(&mut self, _: &str) -> FutureResult<NoTls, Void> {
|
||||
future::ok(NoTls)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> TlsMode<S> for NoTls
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
@ -68,6 +103,38 @@ where
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct PreferTls<T>(pub T);
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
impl<T, S> MakeTlsMode<S> for PreferTls<T>
|
||||
where
|
||||
T: MakeTlsConnect<S>,
|
||||
S: AsyncRead + AsyncWrite,
|
||||
{
|
||||
type Stream = MaybeTlsStream<T::Stream, S>;
|
||||
type TlsMode = PreferTls<T::TlsConnect>;
|
||||
type Error = T::Error;
|
||||
type Future = MakePreferTlsFuture<T::Future>;
|
||||
|
||||
fn make_tls_mode(&mut self, domain: &str) -> MakePreferTlsFuture<T::Future> {
|
||||
MakePreferTlsFuture(self.0.make_tls_connect(domain))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
pub struct MakePreferTlsFuture<F>(F);
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
impl<F> Future for MakePreferTlsFuture<F>
|
||||
where
|
||||
F: Future,
|
||||
{
|
||||
type Item = PreferTls<F::Item>;
|
||||
type Error = F::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<PreferTls<F::Item>, F::Error> {
|
||||
self.0.poll().map(|f| f.map(PreferTls))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S> TlsMode<S> for PreferTls<T>
|
||||
where
|
||||
T: TlsConnect<S>,
|
||||
@ -207,6 +274,37 @@ where
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct RequireTls<T>(pub T);
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
impl<T, S> MakeTlsMode<S> for RequireTls<T>
|
||||
where
|
||||
T: MakeTlsConnect<S>,
|
||||
{
|
||||
type Stream = T::Stream;
|
||||
type TlsMode = RequireTls<T::TlsConnect>;
|
||||
type Error = T::Error;
|
||||
type Future = MakeRequireTlsFuture<T::Future>;
|
||||
|
||||
fn make_tls_mode(&mut self, domain: &str) -> MakeRequireTlsFuture<T::Future> {
|
||||
MakeRequireTlsFuture(self.0.make_tls_connect(domain))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
pub struct MakeRequireTlsFuture<F>(F);
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
impl<F> Future for MakeRequireTlsFuture<F>
|
||||
where
|
||||
F: Future,
|
||||
{
|
||||
type Item = RequireTls<F::Item>;
|
||||
type Error = F::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<RequireTls<F::Item>, F::Error> {
|
||||
self.0.poll().map(|f| f.map(RequireTls))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S> TlsMode<S> for RequireTls<T>
|
||||
where
|
||||
T: TlsConnect<S>,
|
||||
|
Loading…
Reference in New Issue
Block a user