Remove future from MakeTlsMode

It's unlikely to be useful in practice, and just introduces more
complexity.
This commit is contained in:
Steven Fackler 2019-01-05 22:03:13 -08:00
parent 0ae7670e05
commit 940cbb8d4b
5 changed files with 44 additions and 107 deletions

View File

@ -18,7 +18,6 @@ impl Client {
T: MakeTlsMode<Socket> + 'static + Send,
T::TlsMode: Send,
T::Stream: Send,
T::Future: Send,
<T::TlsMode as TlsMode<Socket>>::Future: Send,
{
params.parse::<Config>()?.connect(tls_mode)

View File

@ -97,7 +97,6 @@ impl Config {
T: MakeTlsMode<Socket> + 'static + Send,
T::TlsMode: Send,
T::Stream: Send,
T::Future: Send,
<T::TlsMode as TlsMode<Socket>>::Future: Send,
{
let connect = self.0.connect(tls_mode);

View File

@ -1,7 +1,5 @@
#![warn(rust_2018_idioms, clippy::all)]
#[cfg(feature = "runtime")]
use futures::future::{self, FutureResult};
use futures::{try_ready, Async, Future, Poll};
#[cfg(feature = "runtime")]
use openssl::error::ErrorStack;
@ -44,12 +42,6 @@ impl MakeTlsConnector {
{
self.config = Arc::new(f);
}
fn make_tls_connect_inner(&mut self, domain: &str) -> Result<TlsConnector, ErrorStack> {
let mut ssl = self.connector.configure()?;
(self.config)(&mut ssl)?;
Ok(TlsConnector::new(ssl, domain))
}
}
#[cfg(feature = "runtime")]
@ -60,10 +52,11 @@ where
type Stream = SslStream<S>;
type TlsConnect = TlsConnector;
type Error = ErrorStack;
type Future = FutureResult<TlsConnector, ErrorStack>;
fn make_tls_connect(&mut self, domain: &str) -> FutureResult<TlsConnector, ErrorStack> {
future::result(self.make_tls_connect_inner(domain))
fn make_tls_connect(&mut self, domain: &str) -> Result<TlsConnector, ErrorStack> {
let mut ssl = self.connector.configure()?;
(self.config)(&mut ssl)?;
Ok(TlsConnector::new(ssl, domain))
}
}

View File

@ -1,4 +1,4 @@
use futures::{try_ready, Async, Future, Poll};
use futures::{Async, Future, Poll};
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use crate::proto::{Client, ConnectOnceFuture, Connection};
@ -9,19 +9,12 @@ pub enum Connect<T>
where
T: MakeTlsMode<Socket>,
{
#[state_machine_future(start, transitions(MakingTlsMode))]
#[state_machine_future(start, transitions(Connecting))]
Start {
make_tls_mode: T,
config: Result<Config, Error>,
},
#[state_machine_future(transitions(Connecting))]
MakingTlsMode {
future: T::Future,
idx: usize,
make_tls_mode: T,
config: Config,
},
#[state_machine_future(transitions(MakingTlsMode, Finished))]
#[state_machine_future(transitions(Finished))]
Connecting {
future: ConnectOnceFuture<T::TlsMode>,
idx: usize,
@ -57,58 +50,48 @@ where
#[cfg(unix)]
Host::Unix(_) => "",
};
let future = state.make_tls_mode.make_tls_mode(hostname);
let tls_mode = state
.make_tls_mode
.make_tls_mode(hostname)
.map_err(|e| Error::tls(e.into()))?;
transition!(MakingTlsMode {
future,
transition!(Connecting {
future: ConnectOnceFuture::new(0, tls_mode, config.clone()),
idx: 0,
make_tls_mode: state.make_tls_mode,
config,
})
}
fn poll_making_tls_mode<'a>(
state: &'a mut RentToOwn<'a, MakingTlsMode<T>>,
) -> Poll<AfterMakingTlsMode<T>, Error> {
let tls_mode = try_ready!(state.future.poll().map_err(|e| Error::tls(e.into())));
let state = state.take();
transition!(Connecting {
future: ConnectOnceFuture::new(state.idx, tls_mode, state.config.clone()),
idx: state.idx,
make_tls_mode: state.make_tls_mode,
config: state.config,
})
}
fn poll_connecting<'a>(
state: &'a mut RentToOwn<'a, Connecting<T>>,
) -> Poll<AfterConnecting<T>, Error> {
match state.future.poll() {
Ok(Async::Ready(r)) => transition!(Finished(r)),
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(e) => {
let mut state = state.take();
let idx = state.idx + 1;
loop {
match state.future.poll() {
Ok(Async::Ready(r)) => transition!(Finished(r)),
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(e) => {
let state = &mut **state;
state.idx += 1;
let host = match state.config.0.host.get(idx) {
Some(host) => host,
None => return Err(e),
};
let host = match state.config.0.host.get(state.idx) {
Some(host) => host,
None => return Err(e),
};
let hostname = match host {
Host::Tcp(host) => &**host,
#[cfg(unix)]
Host::Unix(_) => "",
};
let future = state.make_tls_mode.make_tls_mode(hostname);
let hostname = match host {
Host::Tcp(host) => &**host,
#[cfg(unix)]
Host::Unix(_) => "",
};
let tls_mode = state
.make_tls_mode
.make_tls_mode(hostname)
.map_err(|e| Error::tls(e.into()))?;
transition!(MakingTlsMode {
future,
idx,
make_tls_mode: state.make_tls_mode,
config: state.config,
})
state.future =
ConnectOnceFuture::new(state.idx, tls_mode, state.config.clone());
}
}
}
}

View File

@ -30,9 +30,8 @@ pub trait MakeTlsMode<S> {
type Stream: AsyncRead + AsyncWrite;
type TlsMode: TlsMode<S, Stream = Self::Stream>;
type Error: Into<Box<dyn Error + Sync + Send>>;
type Future: Future<Item = Self::TlsMode, Error = Self::Error>;
fn make_tls_mode(&mut self, domain: &str) -> Self::Future;
fn make_tls_mode(&mut self, domain: &str) -> Result<Self::TlsMode, Self::Error>;
}
pub trait TlsMode<S> {
@ -50,9 +49,8 @@ pub trait MakeTlsConnect<S> {
type Stream: AsyncRead + AsyncWrite;
type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
type Error: Into<Box<dyn Error + Sync + Send>>;
type Future: Future<Item = Self::TlsConnect, Error = Self::Error>;
fn make_tls_connect(&mut self, domain: &str) -> Self::Future;
fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
}
pub trait TlsConnect<S> {
@ -74,10 +72,9 @@ where
type Stream = S;
type TlsMode = NoTls;
type Error = Void;
type Future = FutureResult<NoTls, Void>;
fn make_tls_mode(&mut self, _: &str) -> FutureResult<NoTls, Void> {
future::ok(NoTls)
fn make_tls_mode(&mut self, _: &str) -> Result<NoTls, Void> {
Ok(NoTls)
}
}
@ -112,26 +109,9 @@ where
type Stream = MaybeTlsStream<T::Stream, S>;
type TlsMode = PreferTls<T::TlsConnect>;
type Error = T::Error;
type Future = MakePreferTlsFuture<T::Future>;
fn make_tls_mode(&mut self, domain: &str) -> MakePreferTlsFuture<T::Future> {
MakePreferTlsFuture(self.0.make_tls_connect(domain))
}
}
#[cfg(feature = "runtime")]
pub struct MakePreferTlsFuture<F>(F);
#[cfg(feature = "runtime")]
impl<F> Future for MakePreferTlsFuture<F>
where
F: Future,
{
type Item = PreferTls<F::Item>;
type Error = F::Error;
fn poll(&mut self) -> Poll<PreferTls<F::Item>, F::Error> {
self.0.poll().map(|f| f.map(PreferTls))
fn make_tls_mode(&mut self, domain: &str) -> Result<PreferTls<T::TlsConnect>, T::Error> {
self.0.make_tls_connect(domain).map(PreferTls)
}
}
@ -282,26 +262,9 @@ where
type Stream = T::Stream;
type TlsMode = RequireTls<T::TlsConnect>;
type Error = T::Error;
type Future = MakeRequireTlsFuture<T::Future>;
fn make_tls_mode(&mut self, domain: &str) -> MakeRequireTlsFuture<T::Future> {
MakeRequireTlsFuture(self.0.make_tls_connect(domain))
}
}
#[cfg(feature = "runtime")]
pub struct MakeRequireTlsFuture<F>(F);
#[cfg(feature = "runtime")]
impl<F> Future for MakeRequireTlsFuture<F>
where
F: Future,
{
type Item = RequireTls<F::Item>;
type Error = F::Error;
fn poll(&mut self) -> Poll<RequireTls<F::Item>, F::Error> {
self.0.poll().map(|f| f.map(RequireTls))
fn make_tls_mode(&mut self, domain: &str) -> Result<RequireTls<T::TlsConnect>, T::Error> {
self.0.make_tls_connect(domain).map(RequireTls)
}
}