tokio-postgres TLS setup
This commit is contained in:
parent
5fbe20fd25
commit
70758bcd93
@ -35,9 +35,11 @@ pub use postgres_shared::{CancelData, Notification};
|
||||
|
||||
use error::Error;
|
||||
use params::ConnectParams;
|
||||
use tls::TlsConnect;
|
||||
use types::{FromSql, ToSql, Type};
|
||||
|
||||
mod proto;
|
||||
pub mod tls;
|
||||
|
||||
static NEXT_STATEMENT_ID: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
@ -55,8 +57,14 @@ fn disconnected() -> Error {
|
||||
))
|
||||
}
|
||||
|
||||
pub fn connect(params: ConnectParams) -> Handshake {
|
||||
Handshake(proto::HandshakeFuture::new(params))
|
||||
pub enum TlsMode {
|
||||
None,
|
||||
Prefer(Box<TlsConnect>),
|
||||
Require(Box<TlsConnect>),
|
||||
}
|
||||
|
||||
pub fn connect(params: ConnectParams, tls: TlsMode) -> Handshake {
|
||||
Handshake(proto::HandshakeFuture::new(params, tls))
|
||||
}
|
||||
|
||||
pub struct Client(proto::Client);
|
||||
|
@ -9,7 +9,7 @@ use tokio_codec::Framed;
|
||||
use disconnected;
|
||||
use error::{self, Error};
|
||||
use proto::codec::PostgresCodec;
|
||||
use proto::socket::Socket;
|
||||
use tls::TlsStream;
|
||||
use {bad_response, CancelData};
|
||||
|
||||
pub struct Request {
|
||||
@ -25,7 +25,7 @@ enum State {
|
||||
}
|
||||
|
||||
pub struct Connection {
|
||||
stream: Framed<Socket, PostgresCodec>,
|
||||
stream: Framed<Box<TlsStream>, PostgresCodec>,
|
||||
cancel_data: CancelData,
|
||||
parameters: HashMap<String, String>,
|
||||
receiver: mpsc::UnboundedReceiver<Request>,
|
||||
@ -37,7 +37,7 @@ pub struct Connection {
|
||||
|
||||
impl Connection {
|
||||
pub fn new(
|
||||
stream: Framed<Socket, PostgresCodec>,
|
||||
stream: Framed<Box<TlsStream>, PostgresCodec>,
|
||||
cancel_data: CancelData,
|
||||
parameters: HashMap<String, String>,
|
||||
receiver: mpsc::UnboundedReceiver<Request>,
|
||||
|
@ -8,55 +8,84 @@ use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
use state_machine_future::RentToOwn;
|
||||
use std::collections::HashMap;
|
||||
use std::error::Error as StdError;
|
||||
use std::io;
|
||||
use tokio_codec::Framed;
|
||||
use tokio_io::io::{read_exact, write_all, ReadExact, WriteAll};
|
||||
|
||||
use error::{self, Error};
|
||||
use params::{ConnectParams, User};
|
||||
use params::{ConnectParams, Host, User};
|
||||
use proto::client::Client;
|
||||
use proto::codec::PostgresCodec;
|
||||
use proto::connection::Connection;
|
||||
use proto::socket::{ConnectFuture, Socket};
|
||||
use {bad_response, disconnected, CancelData};
|
||||
use tls::{self, TlsConnect, TlsStream};
|
||||
use {bad_response, disconnected, CancelData, TlsMode};
|
||||
|
||||
#[derive(StateMachineFuture)]
|
||||
pub enum Handshake {
|
||||
#[state_machine_future(start, transitions(SendingStartup))]
|
||||
#[state_machine_future(start, transitions(BuildingStartup, SendingSsl))]
|
||||
Start {
|
||||
future: ConnectFuture,
|
||||
params: ConnectParams,
|
||||
tls: TlsMode,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingSsl))]
|
||||
SendingSsl {
|
||||
future: WriteAll<Socket, Vec<u8>>,
|
||||
params: ConnectParams,
|
||||
connector: Box<TlsConnect>,
|
||||
required: bool,
|
||||
},
|
||||
#[state_machine_future(transitions(ConnectingTls, BuildingStartup))]
|
||||
ReadingSsl {
|
||||
future: ReadExact<Socket, [u8; 1]>,
|
||||
params: ConnectParams,
|
||||
connector: Box<TlsConnect>,
|
||||
required: bool,
|
||||
},
|
||||
#[state_machine_future(transitions(BuildingStartup))]
|
||||
ConnectingTls {
|
||||
future:
|
||||
Box<Future<Item = Box<TlsStream>, Error = Box<StdError + Sync + Send>> + Sync + Send>,
|
||||
params: ConnectParams,
|
||||
},
|
||||
#[state_machine_future(transitions(SendingStartup))]
|
||||
BuildingStartup {
|
||||
stream: Framed<Box<TlsStream>, PostgresCodec>,
|
||||
params: ConnectParams,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingAuth))]
|
||||
SendingStartup {
|
||||
future: sink::Send<Framed<Socket, PostgresCodec>>,
|
||||
future: sink::Send<Framed<Box<TlsStream>, PostgresCodec>>,
|
||||
user: User,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))]
|
||||
ReadingAuth {
|
||||
stream: Framed<Socket, PostgresCodec>,
|
||||
stream: Framed<Box<TlsStream>, PostgresCodec>,
|
||||
user: User,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingAuthCompletion))]
|
||||
SendingPassword {
|
||||
future: sink::Send<Framed<Socket, PostgresCodec>>,
|
||||
future: sink::Send<Framed<Box<TlsStream>, PostgresCodec>>,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingSasl))]
|
||||
SendingSasl {
|
||||
future: sink::Send<Framed<Socket, PostgresCodec>>,
|
||||
future: sink::Send<Framed<Box<TlsStream>, PostgresCodec>>,
|
||||
scram: ScramSha256,
|
||||
},
|
||||
#[state_machine_future(transitions(SendingSasl, ReadingAuthCompletion))]
|
||||
ReadingSasl {
|
||||
stream: Framed<Socket, PostgresCodec>,
|
||||
stream: Framed<Box<TlsStream>, PostgresCodec>,
|
||||
scram: ScramSha256,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingInfo))]
|
||||
ReadingAuthCompletion {
|
||||
stream: Framed<Socket, PostgresCodec>,
|
||||
stream: Framed<Box<TlsStream>, PostgresCodec>,
|
||||
},
|
||||
#[state_machine_future(transitions(Finished))]
|
||||
ReadingInfo {
|
||||
stream: Framed<Socket, PostgresCodec>,
|
||||
stream: Framed<Box<TlsStream>, PostgresCodec>,
|
||||
cancel_data: Option<CancelData>,
|
||||
parameters: HashMap<String, String>,
|
||||
},
|
||||
@ -71,6 +100,84 @@ impl PollHandshake for Handshake {
|
||||
let stream = try_ready!(state.future.poll());
|
||||
let state = state.take();
|
||||
|
||||
let (connector, required) = match state.tls {
|
||||
TlsMode::None => {
|
||||
transition!(BuildingStartup {
|
||||
stream: Framed::new(Box::new(stream), PostgresCodec),
|
||||
params: state.params,
|
||||
});
|
||||
}
|
||||
TlsMode::Prefer(connector) => (connector, false),
|
||||
TlsMode::Require(connector) => (connector, true),
|
||||
};
|
||||
|
||||
let mut buf = vec![];
|
||||
frontend::ssl_request(&mut buf);
|
||||
transition!(SendingSsl {
|
||||
future: write_all(stream, buf),
|
||||
params: state.params,
|
||||
connector,
|
||||
required,
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_sending_ssl<'a>(
|
||||
state: &'a mut RentToOwn<'a, SendingSsl>,
|
||||
) -> Poll<AfterSendingSsl, Error> {
|
||||
let (stream, _) = try_ready!(state.future.poll());
|
||||
let state = state.take();
|
||||
transition!(ReadingSsl {
|
||||
future: read_exact(stream, [0]),
|
||||
params: state.params,
|
||||
connector: state.connector,
|
||||
required: state.required,
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_reading_ssl<'a>(
|
||||
state: &'a mut RentToOwn<'a, ReadingSsl>,
|
||||
) -> Poll<AfterReadingSsl, Error> {
|
||||
let (stream, buf) = try_ready!(state.future.poll());
|
||||
let state = state.take();
|
||||
|
||||
match buf[0] {
|
||||
b'S' => {
|
||||
let future = match state.params.host() {
|
||||
Host::Tcp(domain) => state.connector.connect(domain, tls::Socket(stream)),
|
||||
Host::Unix(_) => {
|
||||
return Err(error::tls("TLS over unix sockets not supported".into()))
|
||||
}
|
||||
};
|
||||
transition!(ConnectingTls {
|
||||
future,
|
||||
params: state.params,
|
||||
})
|
||||
}
|
||||
b'N' if !state.required => transition!(BuildingStartup {
|
||||
stream: Framed::new(Box::new(stream), PostgresCodec),
|
||||
params: state.params,
|
||||
}),
|
||||
b'N' => Err(error::tls("TLS was required but not supported".into())),
|
||||
_ => Err(bad_response()),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_connecting_tls<'a>(
|
||||
state: &'a mut RentToOwn<'a, ConnectingTls>,
|
||||
) -> Poll<AfterConnectingTls, Error> {
|
||||
let stream = try_ready!(state.future.poll().map_err(error::tls));
|
||||
let state = state.take();
|
||||
transition!(BuildingStartup {
|
||||
stream: Framed::new(stream, PostgresCodec),
|
||||
params: state.params,
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_building_startup<'a>(
|
||||
state: &'a mut RentToOwn<'a, BuildingStartup>,
|
||||
) -> Poll<AfterBuildingStartup, Error> {
|
||||
let state = state.take();
|
||||
|
||||
let user = match state.params.user() {
|
||||
Some(user) => user.clone(),
|
||||
None => {
|
||||
@ -102,10 +209,8 @@ impl PollHandshake for Handshake {
|
||||
)?;
|
||||
}
|
||||
|
||||
let stream = Framed::new(stream, PostgresCodec);
|
||||
|
||||
transition!(SendingStartup {
|
||||
future: stream.send(buf),
|
||||
future: state.stream.send(buf),
|
||||
user,
|
||||
})
|
||||
}
|
||||
@ -298,8 +403,8 @@ impl PollHandshake for Handshake {
|
||||
}
|
||||
|
||||
impl HandshakeFuture {
|
||||
pub fn new(params: ConnectParams) -> HandshakeFuture {
|
||||
Handshake::start(Socket::connect(¶ms), params)
|
||||
pub fn new(params: ConnectParams, tls: TlsMode) -> HandshakeFuture {
|
||||
Handshake::start(Socket::connect(¶ms), params, tls)
|
||||
}
|
||||
}
|
||||
|
||||
|
63
tokio-postgres/src/tls.rs
Normal file
63
tokio-postgres/src/tls.rs
Normal file
@ -0,0 +1,63 @@
|
||||
use bytes::{Buf, BufMut};
|
||||
use futures::{Future, Poll};
|
||||
use std::error::Error;
|
||||
use std::io::{self, Read, Write};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use proto;
|
||||
|
||||
pub struct Socket(pub(crate) proto::Socket);
|
||||
|
||||
impl Read for Socket {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
self.0.read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for Socket {
|
||||
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
|
||||
self.0.prepare_uninitialized_buffer(buf)
|
||||
}
|
||||
|
||||
fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
|
||||
where
|
||||
B: BufMut,
|
||||
{
|
||||
self.0.read_buf(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for Socket {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
self.0.write(buf)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
self.0.flush()
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for Socket {
|
||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
||||
self.0.shutdown()
|
||||
}
|
||||
|
||||
fn write_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
|
||||
where
|
||||
B: Buf,
|
||||
{
|
||||
self.0.write_buf(buf)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait TlsConnect {
|
||||
fn connect(
|
||||
&self,
|
||||
domain: &str,
|
||||
socket: Socket,
|
||||
) -> Box<Future<Item = Box<TlsStream>, Error = Box<Error + Sync + Send>> + Sync + Send>;
|
||||
}
|
||||
|
||||
pub trait TlsStream: 'static + Sync + Send + AsyncRead + AsyncWrite {}
|
||||
|
||||
impl TlsStream for proto::Socket {}
|
@ -6,12 +6,13 @@ use tokio::prelude::*;
|
||||
use tokio::runtime::current_thread::Runtime;
|
||||
use tokio_postgres::error::SqlState;
|
||||
use tokio_postgres::types::Type;
|
||||
use tokio_postgres::TlsMode;
|
||||
|
||||
fn smoke_test(url: &str) {
|
||||
let _ = env_logger::try_init();
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let handshake = tokio_postgres::connect(url.parse().unwrap());
|
||||
let handshake = tokio_postgres::connect(url.parse().unwrap(), TlsMode::None);
|
||||
let (mut client, connection) = runtime.block_on(handshake).unwrap();
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
runtime.handle().spawn(connection).unwrap();
|
||||
@ -34,7 +35,10 @@ fn plain_password_missing() {
|
||||
let _ = env_logger::try_init();
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let handshake = tokio_postgres::connect("postgres://pass_user@localhost:5433".parse().unwrap());
|
||||
let handshake = tokio_postgres::connect(
|
||||
"postgres://pass_user@localhost:5433".parse().unwrap(),
|
||||
TlsMode::None,
|
||||
);
|
||||
match runtime.block_on(handshake) {
|
||||
Ok(_) => panic!("unexpected success"),
|
||||
Err(ref e) if e.as_connection().is_some() => {}
|
||||
@ -47,8 +51,10 @@ fn plain_password_wrong() {
|
||||
let _ = env_logger::try_init();
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let handshake =
|
||||
tokio_postgres::connect("postgres://pass_user:foo@localhost:5433".parse().unwrap());
|
||||
let handshake = tokio_postgres::connect(
|
||||
"postgres://pass_user:foo@localhost:5433".parse().unwrap(),
|
||||
TlsMode::None,
|
||||
);
|
||||
match runtime.block_on(handshake) {
|
||||
Ok(_) => panic!("unexpected success"),
|
||||
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
|
||||
@ -66,7 +72,10 @@ fn md5_password_missing() {
|
||||
let _ = env_logger::try_init();
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let handshake = tokio_postgres::connect("postgres://md5_user@localhost:5433".parse().unwrap());
|
||||
let handshake = tokio_postgres::connect(
|
||||
"postgres://md5_user@localhost:5433".parse().unwrap(),
|
||||
TlsMode::None,
|
||||
);
|
||||
match runtime.block_on(handshake) {
|
||||
Ok(_) => panic!("unexpected success"),
|
||||
Err(ref e) if e.as_connection().is_some() => {}
|
||||
@ -79,8 +88,10 @@ fn md5_password_wrong() {
|
||||
let _ = env_logger::try_init();
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let handshake =
|
||||
tokio_postgres::connect("postgres://md5_user:foo@localhost:5433".parse().unwrap());
|
||||
let handshake = tokio_postgres::connect(
|
||||
"postgres://md5_user:foo@localhost:5433".parse().unwrap(),
|
||||
TlsMode::None,
|
||||
);
|
||||
match runtime.block_on(handshake) {
|
||||
Ok(_) => panic!("unexpected success"),
|
||||
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
|
||||
@ -98,8 +109,10 @@ fn scram_password_missing() {
|
||||
let _ = env_logger::try_init();
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let handshake =
|
||||
tokio_postgres::connect("postgres://scram_user@localhost:5433".parse().unwrap());
|
||||
let handshake = tokio_postgres::connect(
|
||||
"postgres://scram_user@localhost:5433".parse().unwrap(),
|
||||
TlsMode::None,
|
||||
);
|
||||
match runtime.block_on(handshake) {
|
||||
Ok(_) => panic!("unexpected success"),
|
||||
Err(ref e) if e.as_connection().is_some() => {}
|
||||
@ -112,8 +125,10 @@ fn scram_password_wrong() {
|
||||
let _ = env_logger::try_init();
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let handshake =
|
||||
tokio_postgres::connect("postgres://scram_user:foo@localhost:5433".parse().unwrap());
|
||||
let handshake = tokio_postgres::connect(
|
||||
"postgres://scram_user:foo@localhost:5433".parse().unwrap(),
|
||||
TlsMode::None,
|
||||
);
|
||||
match runtime.block_on(handshake) {
|
||||
Ok(_) => panic!("unexpected success"),
|
||||
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
|
||||
@ -131,7 +146,10 @@ fn pipelined_prepare() {
|
||||
let _ = env_logger::try_init();
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let handshake = tokio_postgres::connect("postgres://postgres@localhost:5433".parse().unwrap());
|
||||
let handshake = tokio_postgres::connect(
|
||||
"postgres://postgres@localhost:5433".parse().unwrap(),
|
||||
TlsMode::None,
|
||||
);
|
||||
let (mut client, connection) = runtime.block_on(handshake).unwrap();
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
runtime.handle().spawn(connection).unwrap();
|
||||
@ -161,7 +179,10 @@ fn insert_select() {
|
||||
let _ = env_logger::try_init();
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let handshake = tokio_postgres::connect("postgres://postgres@localhost:5433".parse().unwrap());
|
||||
let handshake = tokio_postgres::connect(
|
||||
"postgres://postgres@localhost:5433".parse().unwrap(),
|
||||
TlsMode::None,
|
||||
);
|
||||
let (mut client, connection) = runtime.block_on(handshake).unwrap();
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
runtime.handle().spawn(connection).unwrap();
|
||||
|
Loading…
Reference in New Issue
Block a user