use fallible_iterator::FallibleIterator; use futures::sink; use futures::sync::mpsc; use futures::{Future, Poll, Sink, Stream}; use postgres_protocol::authentication; use postgres_protocol::authentication::sasl::{self, ChannelBinding, ScramSha256}; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; use state_machine_future::RentToOwn; use std::collections::HashMap; use std::io; use tokio_codec::Framed; use error::{self, Error}; use params::{ConnectParams, User}; use proto::codec::PostgresCodec; use proto::connection::{Connection, Request}; use proto::socket::{ConnectFuture, Socket}; use {bad_response, disconnected, CancelData}; #[derive(StateMachineFuture)] pub enum Handshake { #[state_machine_future(start, transitions(SendingStartup))] Start { future: ConnectFuture, params: ConnectParams, }, #[state_machine_future(transitions(ReadingAuth))] SendingStartup { future: sink::Send>, user: User, }, #[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))] ReadingAuth { stream: Framed, user: User, }, #[state_machine_future(transitions(ReadingAuthCompletion))] SendingPassword { future: sink::Send>, }, #[state_machine_future(transitions(ReadingSasl))] SendingSasl { future: sink::Send>, scram: ScramSha256, }, #[state_machine_future(transitions(SendingSasl, ReadingAuthCompletion))] ReadingSasl { stream: Framed, scram: ScramSha256, }, #[state_machine_future(transitions(ReadingInfo))] ReadingAuthCompletion { stream: Framed, }, #[state_machine_future(transitions(Finished))] ReadingInfo { stream: Framed, cancel_data: Option, parameters: HashMap, }, #[state_machine_future(ready)] Finished((mpsc::Sender, Connection)), #[state_machine_future(error)] Failed(Error), } impl PollHandshake for Handshake { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll { let stream = try_ready!(state.future.poll()); let state = state.take(); let user = match state.params.user() { Some(user) => user.clone(), None => { return Err(error::connect( "user missing from connection parameters".into(), )) } }; let mut buf = vec![]; { let options = state .params .options() .iter() .map(|&(ref key, ref value)| (&**key, &**value)); let client_encoding = Some(("client_encoding", "UTF8")); let timezone = Some(("timezone", "GMT")); let user = Some(("user", user.name())); let database = state.params.database().map(|s| ("database", s)); frontend::startup_message( options .chain(client_encoding) .chain(timezone) .chain(user) .chain(database), &mut buf, )?; } let stream = Framed::new(stream, PostgresCodec); transition!(SendingStartup { future: stream.send(buf), user, }) } fn poll_sending_startup<'a>( state: &'a mut RentToOwn<'a, SendingStartup>, ) -> Poll { let stream = try_ready!(state.future.poll()); let state = state.take(); transition!(ReadingAuth { stream, user: state.user, }) } fn poll_reading_auth<'a>( state: &'a mut RentToOwn<'a, ReadingAuth>, ) -> Poll { let message = try_ready!(state.stream.poll()); let state = state.take(); match message { Some(Message::AuthenticationOk) => transition!(ReadingInfo { stream: state.stream, cancel_data: None, parameters: HashMap::new(), }), Some(Message::AuthenticationCleartextPassword) => { let pass = state.user.password().ok_or_else(missing_password)?; let mut buf = vec![]; frontend::password_message(pass, &mut buf)?; transition!(SendingPassword { future: state.stream.send(buf) }) } Some(Message::AuthenticationMd5Password(body)) => { let pass = state.user.password().ok_or_else(missing_password)?; let output = authentication::md5_hash( state.user.name().as_bytes(), pass.as_bytes(), body.salt(), ); let mut buf = vec![]; frontend::password_message(&output, &mut buf)?; transition!(SendingPassword { future: state.stream.send(buf) }) } Some(Message::AuthenticationSasl(body)) => { let pass = state.user.password().ok_or_else(missing_password)?; let mut has_scram = false; let mut mechanisms = body.mechanisms(); while let Some(mechanism) = mechanisms.next()? { match mechanism { sasl::SCRAM_SHA_256 => has_scram = true, _ => {} } } if !has_scram { return Err(io::Error::new( io::ErrorKind::Other, "unsupported SASL authentication", ).into()); } let mut scram = ScramSha256::new(pass.as_bytes(), ChannelBinding::unsupported())?; let mut buf = vec![]; frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), &mut buf)?; transition!(SendingSasl { future: state.stream.send(buf), scram, }) } Some(Message::AuthenticationKerberosV5) | Some(Message::AuthenticationScmCredential) | Some(Message::AuthenticationGss) | Some(Message::AuthenticationSspi) => Err(io::Error::new( io::ErrorKind::Other, "unsupported authentication method", ).into()), Some(Message::ErrorResponse(body)) => Err(error::__db(body)), Some(_) => Err(bad_response()), None => Err(disconnected()), } } fn poll_sending_password<'a>( state: &'a mut RentToOwn<'a, SendingPassword>, ) -> Poll { let stream = try_ready!(state.future.poll()); transition!(ReadingAuthCompletion { stream }) } fn poll_sending_sasl<'a>( state: &'a mut RentToOwn<'a, SendingSasl>, ) -> Poll { let stream = try_ready!(state.future.poll()); let state = state.take(); transition!(ReadingSasl { stream, scram: state.scram }) } fn poll_reading_sasl<'a>( state: &'a mut RentToOwn<'a, ReadingSasl>, ) -> Poll { let message = try_ready!(state.stream.poll()); let mut state = state.take(); match message { Some(Message::AuthenticationSaslContinue(body)) => { state.scram.update(body.data())?; let mut buf = vec![]; frontend::sasl_response(state.scram.message(), &mut buf)?; transition!(SendingSasl { future: state.stream.send(buf), scram: state.scram, }) } Some(Message::AuthenticationSaslFinal(body)) => { state.scram.finish(body.data())?; transition!(ReadingAuthCompletion { stream: state.stream, }) } Some(Message::ErrorResponse(body)) => Err(error::__db(body)), Some(_) => Err(bad_response()), None => Err(disconnected()), } } fn poll_reading_auth_completion<'a>( state: &'a mut RentToOwn<'a, ReadingAuthCompletion>, ) -> Poll { let message = try_ready!(state.stream.poll()); let state = state.take(); match message { Some(Message::AuthenticationOk) => transition!(ReadingInfo { stream: state.stream, cancel_data: None, parameters: HashMap::new(), }), Some(Message::ErrorResponse(body)) => Err(error::__db(body)), Some(_) => Err(bad_response()), None => Err(disconnected()), } } fn poll_reading_info<'a>( state: &'a mut RentToOwn<'a, ReadingInfo>, ) -> Poll { loop { let message = try_ready!(state.stream.poll()); match message { Some(Message::BackendKeyData(body)) => { state.cancel_data = Some(CancelData { process_id: body.process_id(), secret_key: body.secret_key(), }); } Some(Message::ParameterStatus(body)) => { state .parameters .insert(body.name()?.to_string(), body.value()?.to_string()); } Some(Message::ReadyForQuery(_)) => { let state = state.take(); let cancel_data = state.cancel_data.ok_or_else(|| { io::Error::new(io::ErrorKind::InvalidData, "BackendKeyData message missing") })?; let (sender, receiver) = mpsc::channel(0); let connection = Connection::new(state.stream, cancel_data, state.parameters, receiver); transition!(Finished((sender, connection))) } Some(Message::ErrorResponse(body)) => return Err(error::__db(body)), Some(Message::NoticeResponse(_)) => {} Some(_) => return Err(bad_response()), None => return Err(disconnected()), } } } } impl HandshakeFuture { pub fn new(params: ConnectParams) -> HandshakeFuture { Handshake::start(Socket::connect(¶ms), params) } } fn missing_password() -> Error { error::connect("a password was requested but not provided".into()) }