use futures::{try_ready, Future, Poll}; use postgres_protocol::message::frontend; use state_machine_future::{transition, RentToOwn, StateMachineFuture}; use tokio_io::io::{self, ReadExact, WriteAll}; use tokio_io::{AsyncRead, AsyncWrite}; use crate::{ChannelBinding, Error, TlsMode}; #[derive(StateMachineFuture)] pub enum Tls where T: TlsMode, S: AsyncRead + AsyncWrite, { #[state_machine_future(start, transitions(SendingTls, ConnectingTls))] Start { stream: S, tls_mode: T }, #[state_machine_future(transitions(ReadingTls))] SendingTls { future: WriteAll>, tls_mode: T, }, #[state_machine_future(transitions(ConnectingTls))] ReadingTls { future: ReadExact, tls_mode: T, }, #[state_machine_future(transitions(Ready))] ConnectingTls { future: T::Future }, #[state_machine_future(ready)] Ready((T::Stream, ChannelBinding)), #[state_machine_future(error)] Failed(Error), } impl PollTls for Tls where T: TlsMode, S: AsyncRead + AsyncWrite, { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll, Error> { let state = state.take(); if state.tls_mode.request_tls() { let mut buf = vec![]; frontend::ssl_request(&mut buf); transition!(SendingTls { future: io::write_all(state.stream, buf), tls_mode: state.tls_mode, }) } else { transition!(ConnectingTls { future: state.tls_mode.handle_tls(false, state.stream), }) } } fn poll_sending_tls<'a>( state: &'a mut RentToOwn<'a, SendingTls>, ) -> Poll, Error> { let (stream, _) = try_ready!(state.future.poll().map_err(Error::io)); let state = state.take(); transition!(ReadingTls { future: io::read_exact(stream, [0]), tls_mode: state.tls_mode, }) } fn poll_reading_tls<'a>( state: &'a mut RentToOwn<'a, ReadingTls>, ) -> Poll, Error> { let (stream, buf) = try_ready!(state.future.poll().map_err(Error::io)); let state = state.take(); let use_tls = buf[0] == b'S'; transition!(ConnectingTls { future: state.tls_mode.handle_tls(use_tls, stream) }) } fn poll_connecting_tls<'a>( state: &'a mut RentToOwn<'a, ConnectingTls>, ) -> Poll, Error> { let t = try_ready!(state.future.poll().map_err(|e| Error::tls(e.into()))); transition!(Ready(t)) } } impl TlsFuture where T: TlsMode, S: AsyncRead + AsyncWrite, { pub fn new(stream: S, tls_mode: T) -> TlsFuture { Tls::start(stream, tls_mode) } }