TLS support

This commit is contained in:
Steven Fackler 2016-12-24 12:21:26 -05:00
parent d8aed0931a
commit de097259a1
6 changed files with 250 additions and 41 deletions

View File

@ -3,6 +3,9 @@ name = "postgres-tokio"
version = "0.1.0"
authors = ["Steven Fackler <sfackler@gmail.com>"]
[features]
with-openssl = ["tokio-openssl", "openssl"]
[dependencies]
fallible-iterator = "0.1.3"
futures = "0.1.7"
@ -12,3 +15,6 @@ postgres-protocol = "0.2"
tokio-core = "0.1"
tokio-dns-unofficial = "0.1"
tokio-uds = "0.1"
tokio-openssl = { version = "0.1", optional = true }
openssl = { version = "0.9", optional = true }

View File

@ -7,6 +7,11 @@ extern crate tokio_core;
extern crate tokio_dns;
extern crate tokio_uds;
#[cfg(feature = "tokio-openssl")]
extern crate tokio_openssl;
#[cfg(feature = "openssl")]
extern crate openssl;
use fallible_iterator::FallibleIterator;
use futures::{Future, IntoFuture, BoxFuture, Stream, Sink, Poll, StartSend};
use futures::future::Either;
@ -31,13 +36,21 @@ use error::{ConnectError, Error, DbError};
use params::{ConnectParams, IntoConnectParams};
use stream::PostgresStream;
use types::{Oid, Type, ToSql, SessionInfo, IsNull, FromSql, WrongType};
use tls::Handshake;
pub mod error;
mod stream;
pub mod tls;
#[cfg(test)]
mod test;
pub enum TlsMode {
Require(Box<Handshake>),
Prefer(Box<Handshake>),
None,
}
#[derive(Debug, Copy, Clone)]
pub struct CancelData {
pub process_id: i32,
@ -119,7 +132,10 @@ impl fmt::Debug for Connection {
}
impl Connection {
pub fn connect<T>(params: T, handle: &Handle) -> BoxFuture<Connection, ConnectError>
pub fn connect<T>(params: T,
tls_mode: TlsMode,
handle: &Handle)
-> BoxFuture<Connection, ConnectError>
where T: IntoConnectParams
{
let params = match params.into_connect_params() {
@ -127,8 +143,7 @@ impl Connection {
Err(e) => return futures::failed(ConnectError::ConnectParams(e)).boxed(),
};
stream::connect(params.host(), params.port(), handle)
.map_err(ConnectError::Io)
stream::connect(params.host().clone(), params.port(), tls_mode, handle)
.map(|s| {
let (sender, receiver) = mpsc::channel();
Connection(InnerConnection {

View File

@ -1,6 +1,8 @@
use futures::{BoxFuture, Future, IntoFuture, Async};
use futures::{BoxFuture, Future, IntoFuture, Async, Sink, Stream as FuturesStream};
use futures::future::Either;
use postgres_shared::params::Host;
use postgres_protocol::message::backend::{self, ParseResult};
use postgres_protocol::message::frontend;
use std::io::{self, Read, Write};
use tokio_core::io::{Io, Codec, EasyBuf, Framed};
use tokio_core::net::TcpStream;
@ -8,68 +10,117 @@ use tokio_core::reactor::Handle;
use tokio_dns;
use tokio_uds::UnixStream;
pub type PostgresStream = Framed<InnerStream, PostgresCodec>;
use TlsMode;
use error::ConnectError;
use tls::TlsStream;
pub fn connect(host: &Host,
port: u16,
handle: &Handle)
-> BoxFuture<PostgresStream, io::Error> {
match *host {
pub type PostgresStream = Framed<Box<TlsStream>, PostgresCodec>;
pub fn connect(host: Host,
port: u16,
tls_mode: TlsMode,
handle: &Handle)
-> BoxFuture<PostgresStream, ConnectError> {
let inner = match host {
Host::Tcp(ref host) => {
tokio_dns::tcp_connect((&**host, port), handle.remote().clone())
.map(|s| InnerStream::Tcp(s).framed(PostgresCodec))
.boxed()
Either::A(tokio_dns::tcp_connect((&**host, port), handle.remote().clone())
.map(|s| Stream(InnerStream::Tcp(s))))
}
Host::Unix(ref host) => {
let addr = host.join(format!(".s.PGSQL.{}", port));
UnixStream::connect(addr, handle)
.map(|s| InnerStream::Unix(s).framed(PostgresCodec))
.into_future()
.boxed()
Either::B(UnixStream::connect(addr, handle)
.map(|s| Stream(InnerStream::Unix(s)))
.into_future())
}
}
};
let (required, mut handshaker) = match tls_mode {
TlsMode::Require(h) => (true, h),
TlsMode::Prefer(h) => (false, h),
TlsMode::None => {
return inner.map(|s| {
let s: Box<TlsStream> = Box::new(s);
s.framed(PostgresCodec)
})
.map_err(ConnectError::Io)
.boxed()
},
};
inner.map(|s| s.framed(SslCodec))
.and_then(|s| {
let mut buf = vec![];
frontend::ssl_request(&mut buf);
s.send(buf)
})
.and_then(|s| s.into_future().map_err(|e| e.0))
.map_err(ConnectError::Io)
.and_then(move |(m, s)| {
let s = s.into_inner();
match (m, required) {
(Some(b'N'), true) => {
Either::A(Err(ConnectError::Tls("the server does not support TLS".into())).into_future())
}
(Some(b'N'), false) => {
let s: Box<TlsStream> = Box::new(s);
Either::A(Ok(s).into_future())
},
(None, _) => Either::A(Err(ConnectError::Io(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF"))).into_future()),
_ => {
let host = match host {
Host::Tcp(ref host) => host,
Host::Unix(_) => unreachable!(),
};
Either::B(handshaker.handshake(host, s).map_err(ConnectError::Tls))
}
}
})
.map(|s| s.framed(PostgresCodec))
.boxed()
}
pub enum InnerStream {
pub struct Stream(InnerStream);
enum InnerStream {
Tcp(TcpStream),
Unix(UnixStream),
}
impl Read for InnerStream {
impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match *self {
match self.0 {
InnerStream::Tcp(ref mut s) => s.read(buf),
InnerStream::Unix(ref mut s) => s.read(buf),
}
}
}
impl Write for InnerStream {
impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match *self {
match self.0 {
InnerStream::Tcp(ref mut s) => s.write(buf),
InnerStream::Unix(ref mut s) => s.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match *self {
match self.0 {
InnerStream::Tcp(ref mut s) => s.flush(),
InnerStream::Unix(ref mut s) => s.flush(),
}
}
}
impl Io for InnerStream {
impl Io for Stream {
fn poll_read(&mut self) -> Async<()> {
match *self {
match self.0 {
InnerStream::Tcp(ref mut s) => s.poll_read(),
InnerStream::Unix(ref mut s) => s.poll_read(),
}
}
fn poll_write(&mut self) -> Async<()> {
match *self {
match self.0 {
InnerStream::Tcp(ref mut s) => s.poll_write(),
InnerStream::Unix(ref mut s) => s.poll_write(),
}
@ -98,3 +149,25 @@ impl Codec for PostgresCodec {
Ok(())
}
}
struct SslCodec;
impl Codec for SslCodec {
type In = u8;
type Out = Vec<u8>;
fn decode(&mut self, buf: &mut EasyBuf) -> io::Result<Option<u8>> {
if buf.as_slice().is_empty() {
Ok(None)
} else {
let byte = buf.as_slice()[0];
buf.drain_to(1);
Ok(Some(byte))
}
}
fn encode(&mut self, msg: Vec<u8>, buf: &mut Vec<u8>) -> io::Result<()> {
buf.extend_from_slice(&msg);
Ok(())
}
}

View File

@ -10,7 +10,7 @@ use params::ConnectParams;
fn basic() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect("postgres://postgres@localhost", &handle)
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &handle)
.then(|c| c.unwrap().close());
l.run(done).unwrap();
}
@ -19,7 +19,9 @@ fn basic() {
fn md5_user() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect("postgres://md5_user:password@localhost/postgres", &handle);
let done = Connection::connect("postgres://md5_user:password@localhost/postgres",
TlsMode::None,
&handle);
l.run(done).unwrap();
}
@ -27,7 +29,9 @@ fn md5_user() {
fn md5_user_no_pass() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect("postgres://md5_user@localhost/postgres", &handle);
let done = Connection::connect("postgres://md5_user@localhost/postgres",
TlsMode::None,
&handle);
match l.run(done) {
Err(ConnectError::ConnectParams(_)) => {}
Err(e) => panic!("unexpected error {}", e),
@ -39,7 +43,9 @@ fn md5_user_no_pass() {
fn md5_user_wrong_pass() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect("postgres://md5_user:foobar@localhost/postgres", &handle);
let done = Connection::connect("postgres://md5_user:foobar@localhost/postgres",
TlsMode::None,
&handle);
match l.run(done) {
Err(ConnectError::Db(ref e)) if e.code == SqlState::InvalidPassword => {}
Err(e) => panic!("unexpected error {}", e),
@ -51,7 +57,9 @@ fn md5_user_wrong_pass() {
fn pass_user() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect("postgres://pass_user:password@localhost/postgres", &handle);
let done = Connection::connect("postgres://pass_user:password@localhost/postgres",
TlsMode::None,
&handle);
l.run(done).unwrap();
}
@ -59,7 +67,9 @@ fn pass_user() {
fn pass_user_no_pass() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect("postgres://pass_user@localhost/postgres", &handle);
let done = Connection::connect("postgres://pass_user@localhost/postgres",
TlsMode::None,
&handle);
match l.run(done) {
Err(ConnectError::ConnectParams(_)) => {}
Err(e) => panic!("unexpected error {}", e),
@ -71,7 +81,9 @@ fn pass_user_no_pass() {
fn pass_user_wrong_pass() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect("postgres://pass_user:foobar@localhost/postgres", &handle);
let done = Connection::connect("postgres://pass_user:foobar@localhost/postgres",
TlsMode::None,
&handle);
match l.run(done) {
Err(ConnectError::Db(ref e)) if e.code == SqlState::InvalidPassword => {}
Err(e) => panic!("unexpected error {}", e),
@ -82,7 +94,7 @@ fn pass_user_wrong_pass() {
#[test]
fn batch_execute_ok() {
let mut l = Core::new().unwrap();
let done = Connection::connect("postgres://postgres@localhost", &l.handle())
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle())
.then(|c| c.unwrap().batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL);"));
l.run(done).unwrap();
}
@ -90,7 +102,7 @@ fn batch_execute_ok() {
#[test]
fn batch_execute_err() {
let mut l = Core::new().unwrap();
let done = Connection::connect("postgres://postgres@localhost", &l.handle())
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle())
.then(|r| r.unwrap().batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL); \
INSERT INTO foo DEFAULT VALUES;"))
.and_then(|c| c.batch_execute("SELECT * FROM bogo"))
@ -110,7 +122,7 @@ fn batch_execute_err() {
#[test]
fn prepare_execute() {
let mut l = Core::new().unwrap();
let done = Connection::connect("postgres://postgres@localhost", &l.handle())
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle())
.then(|c| {
c.unwrap().prepare("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY, name VARCHAR)")
})
@ -127,7 +139,7 @@ fn prepare_execute() {
#[test]
fn query() {
let mut l = Core::new().unwrap();
let done = Connection::connect("postgres://postgres@localhost", &l.handle())
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle())
.then(|c| {
c.unwrap().batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name VARCHAR);
INSERT INTO foo (name) VALUES ('joe'), ('bob')")
@ -149,7 +161,7 @@ fn query() {
#[test]
fn transaction() {
let mut l = Core::new().unwrap();
let done = Connection::connect("postgres://postgres@localhost", &l.handle())
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle())
.then(|c| c.unwrap().batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name VARCHAR);"))
.then(|c| c.unwrap().transaction())
.then(|t| t.unwrap().batch_execute("INSERT INTO foo (name) VALUES ('joe');"))
@ -170,7 +182,7 @@ fn transaction() {
fn unix_socket() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect("postgres://postgres@localhost", &handle)
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &handle)
.then(|c| c.unwrap().prepare("SHOW unix_socket_directories"))
.and_then(|(s, c)| c.query(&s, &[]).collect())
.then(|r| {
@ -178,8 +190,28 @@ fn unix_socket() {
let params = ConnectParams::builder()
.user("postgres", None)
.build_unix(r[0].get::<String, _>(0));
Connection::connect(params, &handle)
Connection::connect(params, TlsMode::None, &handle)
})
.then(|c| c.unwrap().batch_execute(""));
l.run(done).unwrap();
}
#[cfg(feature = "with-openssl")]
#[test]
fn openssl_required() {
use openssl::ssl::{SslMethod, SslConnectorBuilder};
use tls::openssl::OpenSsl;
let mut builder = SslConnectorBuilder::new(SslMethod::tls()).unwrap();
builder.builder_mut().set_ca_file("../.travis/server.crt").unwrap();
let negotiator = OpenSsl::from(builder.build());
let mut l = Core::new().unwrap();
let done = Connection::connect("postgres://postgres@localhost",
TlsMode::Require(Box::new(negotiator)),
&l.handle())
.then(|c| c.unwrap().prepare("SELECT 1"))
.and_then(|(s, c)| c.query(&s, &[]).collect())
.map(|(r, _)| assert_eq!(r[0].get::<i32, _>(0), 1));
l.run(done).unwrap();
}

View File

@ -0,0 +1,33 @@
use futures::BoxFuture;
use std::error::Error;
use tokio_core::io::Io;
pub use stream::Stream;
#[cfg(feature = "with-openssl")]
pub mod openssl;
pub trait TlsStream: Io + Send {
fn get_ref(&self) -> &Stream;
fn get_mut(&mut self) -> &mut Stream;
}
impl Io for Box<TlsStream> {}
impl TlsStream for Stream {
fn get_ref(&self) -> &Stream {
self
}
fn get_mut(&mut self) -> &mut Stream {
self
}
}
pub trait Handshake: 'static + Sync + Send {
fn handshake(&mut self,
host: &str,
stream: Stream)
-> BoxFuture<Box<TlsStream>, Box<Error + Sync + Send>>;
}

View File

@ -0,0 +1,50 @@
use futures::{Future, BoxFuture};
use openssl::ssl::{SslMethod, SslConnector, SslConnectorBuilder};
use openssl::error::ErrorStack;
use std::error::Error;
use tokio_openssl::{SslConnectorExt, SslStream};
use tls::{Stream, TlsStream, Handshake};
impl TlsStream for SslStream<Stream> {
fn get_ref(&self) -> &Stream {
self.get_ref().get_ref()
}
fn get_mut(&mut self) -> &mut Stream {
self.get_mut().get_mut()
}
}
pub struct OpenSsl(SslConnector);
impl OpenSsl {
pub fn new() -> Result<OpenSsl, ErrorStack> {
let connector = try!(SslConnectorBuilder::new(SslMethod::tls())).build();
Ok(OpenSsl(connector))
}
}
impl From<SslConnector> for OpenSsl {
fn from(connector: SslConnector) -> OpenSsl {
OpenSsl(connector)
}
}
impl Handshake for OpenSsl {
fn handshake(&mut self,
host: &str,
stream: Stream)
-> BoxFuture<Box<TlsStream>, Box<Error + Sync + Send>> {
self.0.connect_async(host, stream)
.map(|s| {
let s: Box<TlsStream> = Box::new(s);
s
})
.map_err(|e| {
let e: Box<Error + Sync + Send> = Box::new(e);
e
})
.boxed()
}
}