tokio-postgres TLS setup

This commit is contained in:
Steven Fackler 2018-06-25 21:16:18 -07:00
parent 5fbe20fd25
commit 70758bcd93
5 changed files with 230 additions and 33 deletions

View File

@ -35,9 +35,11 @@ pub use postgres_shared::{CancelData, Notification};
use error::Error;
use params::ConnectParams;
use tls::TlsConnect;
use types::{FromSql, ToSql, Type};
mod proto;
pub mod tls;
static NEXT_STATEMENT_ID: AtomicUsize = AtomicUsize::new(0);
@ -55,8 +57,14 @@ fn disconnected() -> Error {
))
}
pub fn connect(params: ConnectParams) -> Handshake {
Handshake(proto::HandshakeFuture::new(params))
pub enum TlsMode {
None,
Prefer(Box<TlsConnect>),
Require(Box<TlsConnect>),
}
pub fn connect(params: ConnectParams, tls: TlsMode) -> Handshake {
Handshake(proto::HandshakeFuture::new(params, tls))
}
pub struct Client(proto::Client);

View File

@ -9,7 +9,7 @@ use tokio_codec::Framed;
use disconnected;
use error::{self, Error};
use proto::codec::PostgresCodec;
use proto::socket::Socket;
use tls::TlsStream;
use {bad_response, CancelData};
pub struct Request {
@ -25,7 +25,7 @@ enum State {
}
pub struct Connection {
stream: Framed<Socket, PostgresCodec>,
stream: Framed<Box<TlsStream>, PostgresCodec>,
cancel_data: CancelData,
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
@ -37,7 +37,7 @@ pub struct Connection {
impl Connection {
pub fn new(
stream: Framed<Socket, PostgresCodec>,
stream: Framed<Box<TlsStream>, PostgresCodec>,
cancel_data: CancelData,
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,

View File

@ -8,55 +8,84 @@ use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use state_machine_future::RentToOwn;
use std::collections::HashMap;
use std::error::Error as StdError;
use std::io;
use tokio_codec::Framed;
use tokio_io::io::{read_exact, write_all, ReadExact, WriteAll};
use error::{self, Error};
use params::{ConnectParams, User};
use params::{ConnectParams, Host, User};
use proto::client::Client;
use proto::codec::PostgresCodec;
use proto::connection::Connection;
use proto::socket::{ConnectFuture, Socket};
use {bad_response, disconnected, CancelData};
use tls::{self, TlsConnect, TlsStream};
use {bad_response, disconnected, CancelData, TlsMode};
#[derive(StateMachineFuture)]
pub enum Handshake {
#[state_machine_future(start, transitions(SendingStartup))]
#[state_machine_future(start, transitions(BuildingStartup, SendingSsl))]
Start {
future: ConnectFuture,
params: ConnectParams,
tls: TlsMode,
},
#[state_machine_future(transitions(ReadingSsl))]
SendingSsl {
future: WriteAll<Socket, Vec<u8>>,
params: ConnectParams,
connector: Box<TlsConnect>,
required: bool,
},
#[state_machine_future(transitions(ConnectingTls, BuildingStartup))]
ReadingSsl {
future: ReadExact<Socket, [u8; 1]>,
params: ConnectParams,
connector: Box<TlsConnect>,
required: bool,
},
#[state_machine_future(transitions(BuildingStartup))]
ConnectingTls {
future:
Box<Future<Item = Box<TlsStream>, Error = Box<StdError + Sync + Send>> + Sync + Send>,
params: ConnectParams,
},
#[state_machine_future(transitions(SendingStartup))]
BuildingStartup {
stream: Framed<Box<TlsStream>, PostgresCodec>,
params: ConnectParams,
},
#[state_machine_future(transitions(ReadingAuth))]
SendingStartup {
future: sink::Send<Framed<Socket, PostgresCodec>>,
future: sink::Send<Framed<Box<TlsStream>, PostgresCodec>>,
user: User,
},
#[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))]
ReadingAuth {
stream: Framed<Socket, PostgresCodec>,
stream: Framed<Box<TlsStream>, PostgresCodec>,
user: User,
},
#[state_machine_future(transitions(ReadingAuthCompletion))]
SendingPassword {
future: sink::Send<Framed<Socket, PostgresCodec>>,
future: sink::Send<Framed<Box<TlsStream>, PostgresCodec>>,
},
#[state_machine_future(transitions(ReadingSasl))]
SendingSasl {
future: sink::Send<Framed<Socket, PostgresCodec>>,
future: sink::Send<Framed<Box<TlsStream>, PostgresCodec>>,
scram: ScramSha256,
},
#[state_machine_future(transitions(SendingSasl, ReadingAuthCompletion))]
ReadingSasl {
stream: Framed<Socket, PostgresCodec>,
stream: Framed<Box<TlsStream>, PostgresCodec>,
scram: ScramSha256,
},
#[state_machine_future(transitions(ReadingInfo))]
ReadingAuthCompletion {
stream: Framed<Socket, PostgresCodec>,
stream: Framed<Box<TlsStream>, PostgresCodec>,
},
#[state_machine_future(transitions(Finished))]
ReadingInfo {
stream: Framed<Socket, PostgresCodec>,
stream: Framed<Box<TlsStream>, PostgresCodec>,
cancel_data: Option<CancelData>,
parameters: HashMap<String, String>,
},
@ -71,6 +100,84 @@ impl PollHandshake for Handshake {
let stream = try_ready!(state.future.poll());
let state = state.take();
let (connector, required) = match state.tls {
TlsMode::None => {
transition!(BuildingStartup {
stream: Framed::new(Box::new(stream), PostgresCodec),
params: state.params,
});
}
TlsMode::Prefer(connector) => (connector, false),
TlsMode::Require(connector) => (connector, true),
};
let mut buf = vec![];
frontend::ssl_request(&mut buf);
transition!(SendingSsl {
future: write_all(stream, buf),
params: state.params,
connector,
required,
})
}
fn poll_sending_ssl<'a>(
state: &'a mut RentToOwn<'a, SendingSsl>,
) -> Poll<AfterSendingSsl, Error> {
let (stream, _) = try_ready!(state.future.poll());
let state = state.take();
transition!(ReadingSsl {
future: read_exact(stream, [0]),
params: state.params,
connector: state.connector,
required: state.required,
})
}
fn poll_reading_ssl<'a>(
state: &'a mut RentToOwn<'a, ReadingSsl>,
) -> Poll<AfterReadingSsl, Error> {
let (stream, buf) = try_ready!(state.future.poll());
let state = state.take();
match buf[0] {
b'S' => {
let future = match state.params.host() {
Host::Tcp(domain) => state.connector.connect(domain, tls::Socket(stream)),
Host::Unix(_) => {
return Err(error::tls("TLS over unix sockets not supported".into()))
}
};
transition!(ConnectingTls {
future,
params: state.params,
})
}
b'N' if !state.required => transition!(BuildingStartup {
stream: Framed::new(Box::new(stream), PostgresCodec),
params: state.params,
}),
b'N' => Err(error::tls("TLS was required but not supported".into())),
_ => Err(bad_response()),
}
}
fn poll_connecting_tls<'a>(
state: &'a mut RentToOwn<'a, ConnectingTls>,
) -> Poll<AfterConnectingTls, Error> {
let stream = try_ready!(state.future.poll().map_err(error::tls));
let state = state.take();
transition!(BuildingStartup {
stream: Framed::new(stream, PostgresCodec),
params: state.params,
})
}
fn poll_building_startup<'a>(
state: &'a mut RentToOwn<'a, BuildingStartup>,
) -> Poll<AfterBuildingStartup, Error> {
let state = state.take();
let user = match state.params.user() {
Some(user) => user.clone(),
None => {
@ -102,10 +209,8 @@ impl PollHandshake for Handshake {
)?;
}
let stream = Framed::new(stream, PostgresCodec);
transition!(SendingStartup {
future: stream.send(buf),
future: state.stream.send(buf),
user,
})
}
@ -298,8 +403,8 @@ impl PollHandshake for Handshake {
}
impl HandshakeFuture {
pub fn new(params: ConnectParams) -> HandshakeFuture {
Handshake::start(Socket::connect(&params), params)
pub fn new(params: ConnectParams, tls: TlsMode) -> HandshakeFuture {
Handshake::start(Socket::connect(&params), params, tls)
}
}

63
tokio-postgres/src/tls.rs Normal file
View File

@ -0,0 +1,63 @@
use bytes::{Buf, BufMut};
use futures::{Future, Poll};
use std::error::Error;
use std::io::{self, Read, Write};
use tokio_io::{AsyncRead, AsyncWrite};
use proto;
pub struct Socket(pub(crate) proto::Socket);
impl Read for Socket {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl AsyncRead for Socket {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.0.prepare_uninitialized_buffer(buf)
}
fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
where
B: BufMut,
{
self.0.read_buf(buf)
}
}
impl Write for Socket {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
impl AsyncWrite for Socket {
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.0.shutdown()
}
fn write_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
where
B: Buf,
{
self.0.write_buf(buf)
}
}
pub trait TlsConnect {
fn connect(
&self,
domain: &str,
socket: Socket,
) -> Box<Future<Item = Box<TlsStream>, Error = Box<Error + Sync + Send>> + Sync + Send>;
}
pub trait TlsStream: 'static + Sync + Send + AsyncRead + AsyncWrite {}
impl TlsStream for proto::Socket {}

View File

@ -6,12 +6,13 @@ use tokio::prelude::*;
use tokio::runtime::current_thread::Runtime;
use tokio_postgres::error::SqlState;
use tokio_postgres::types::Type;
use tokio_postgres::TlsMode;
fn smoke_test(url: &str) {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake = tokio_postgres::connect(url.parse().unwrap());
let handshake = tokio_postgres::connect(url.parse().unwrap(), TlsMode::None);
let (mut client, connection) = runtime.block_on(handshake).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
@ -34,7 +35,10 @@ fn plain_password_missing() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake = tokio_postgres::connect("postgres://pass_user@localhost:5433".parse().unwrap());
let handshake = tokio_postgres::connect(
"postgres://pass_user@localhost:5433".parse().unwrap(),
TlsMode::None,
);
match runtime.block_on(handshake) {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.as_connection().is_some() => {}
@ -47,8 +51,10 @@ fn plain_password_wrong() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake =
tokio_postgres::connect("postgres://pass_user:foo@localhost:5433".parse().unwrap());
let handshake = tokio_postgres::connect(
"postgres://pass_user:foo@localhost:5433".parse().unwrap(),
TlsMode::None,
);
match runtime.block_on(handshake) {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
@ -66,7 +72,10 @@ fn md5_password_missing() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake = tokio_postgres::connect("postgres://md5_user@localhost:5433".parse().unwrap());
let handshake = tokio_postgres::connect(
"postgres://md5_user@localhost:5433".parse().unwrap(),
TlsMode::None,
);
match runtime.block_on(handshake) {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.as_connection().is_some() => {}
@ -79,8 +88,10 @@ fn md5_password_wrong() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake =
tokio_postgres::connect("postgres://md5_user:foo@localhost:5433".parse().unwrap());
let handshake = tokio_postgres::connect(
"postgres://md5_user:foo@localhost:5433".parse().unwrap(),
TlsMode::None,
);
match runtime.block_on(handshake) {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
@ -98,8 +109,10 @@ fn scram_password_missing() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake =
tokio_postgres::connect("postgres://scram_user@localhost:5433".parse().unwrap());
let handshake = tokio_postgres::connect(
"postgres://scram_user@localhost:5433".parse().unwrap(),
TlsMode::None,
);
match runtime.block_on(handshake) {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.as_connection().is_some() => {}
@ -112,8 +125,10 @@ fn scram_password_wrong() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake =
tokio_postgres::connect("postgres://scram_user:foo@localhost:5433".parse().unwrap());
let handshake = tokio_postgres::connect(
"postgres://scram_user:foo@localhost:5433".parse().unwrap(),
TlsMode::None,
);
match runtime.block_on(handshake) {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
@ -131,7 +146,10 @@ fn pipelined_prepare() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake = tokio_postgres::connect("postgres://postgres@localhost:5433".parse().unwrap());
let handshake = tokio_postgres::connect(
"postgres://postgres@localhost:5433".parse().unwrap(),
TlsMode::None,
);
let (mut client, connection) = runtime.block_on(handshake).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
@ -161,7 +179,10 @@ fn insert_select() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake = tokio_postgres::connect("postgres://postgres@localhost:5433".parse().unwrap());
let handshake = tokio_postgres::connect(
"postgres://postgres@localhost:5433".parse().unwrap(),
TlsMode::None,
);
let (mut client, connection) = runtime.block_on(handshake).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();