TLS support
This commit is contained in:
parent
d8aed0931a
commit
de097259a1
@ -3,6 +3,9 @@ name = "postgres-tokio"
|
||||
version = "0.1.0"
|
||||
authors = ["Steven Fackler <sfackler@gmail.com>"]
|
||||
|
||||
[features]
|
||||
with-openssl = ["tokio-openssl", "openssl"]
|
||||
|
||||
[dependencies]
|
||||
fallible-iterator = "0.1.3"
|
||||
futures = "0.1.7"
|
||||
@ -12,3 +15,6 @@ postgres-protocol = "0.2"
|
||||
tokio-core = "0.1"
|
||||
tokio-dns-unofficial = "0.1"
|
||||
tokio-uds = "0.1"
|
||||
|
||||
tokio-openssl = { version = "0.1", optional = true }
|
||||
openssl = { version = "0.9", optional = true }
|
||||
|
@ -7,6 +7,11 @@ extern crate tokio_core;
|
||||
extern crate tokio_dns;
|
||||
extern crate tokio_uds;
|
||||
|
||||
#[cfg(feature = "tokio-openssl")]
|
||||
extern crate tokio_openssl;
|
||||
#[cfg(feature = "openssl")]
|
||||
extern crate openssl;
|
||||
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures::{Future, IntoFuture, BoxFuture, Stream, Sink, Poll, StartSend};
|
||||
use futures::future::Either;
|
||||
@ -31,13 +36,21 @@ use error::{ConnectError, Error, DbError};
|
||||
use params::{ConnectParams, IntoConnectParams};
|
||||
use stream::PostgresStream;
|
||||
use types::{Oid, Type, ToSql, SessionInfo, IsNull, FromSql, WrongType};
|
||||
use tls::Handshake;
|
||||
|
||||
pub mod error;
|
||||
mod stream;
|
||||
pub mod tls;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
pub enum TlsMode {
|
||||
Require(Box<Handshake>),
|
||||
Prefer(Box<Handshake>),
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct CancelData {
|
||||
pub process_id: i32,
|
||||
@ -119,7 +132,10 @@ impl fmt::Debug for Connection {
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
pub fn connect<T>(params: T, handle: &Handle) -> BoxFuture<Connection, ConnectError>
|
||||
pub fn connect<T>(params: T,
|
||||
tls_mode: TlsMode,
|
||||
handle: &Handle)
|
||||
-> BoxFuture<Connection, ConnectError>
|
||||
where T: IntoConnectParams
|
||||
{
|
||||
let params = match params.into_connect_params() {
|
||||
@ -127,8 +143,7 @@ impl Connection {
|
||||
Err(e) => return futures::failed(ConnectError::ConnectParams(e)).boxed(),
|
||||
};
|
||||
|
||||
stream::connect(params.host(), params.port(), handle)
|
||||
.map_err(ConnectError::Io)
|
||||
stream::connect(params.host().clone(), params.port(), tls_mode, handle)
|
||||
.map(|s| {
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
Connection(InnerConnection {
|
||||
|
@ -1,6 +1,8 @@
|
||||
use futures::{BoxFuture, Future, IntoFuture, Async};
|
||||
use futures::{BoxFuture, Future, IntoFuture, Async, Sink, Stream as FuturesStream};
|
||||
use futures::future::Either;
|
||||
use postgres_shared::params::Host;
|
||||
use postgres_protocol::message::backend::{self, ParseResult};
|
||||
use postgres_protocol::message::frontend;
|
||||
use std::io::{self, Read, Write};
|
||||
use tokio_core::io::{Io, Codec, EasyBuf, Framed};
|
||||
use tokio_core::net::TcpStream;
|
||||
@ -8,68 +10,117 @@ use tokio_core::reactor::Handle;
|
||||
use tokio_dns;
|
||||
use tokio_uds::UnixStream;
|
||||
|
||||
pub type PostgresStream = Framed<InnerStream, PostgresCodec>;
|
||||
use TlsMode;
|
||||
use error::ConnectError;
|
||||
use tls::TlsStream;
|
||||
|
||||
pub fn connect(host: &Host,
|
||||
port: u16,
|
||||
handle: &Handle)
|
||||
-> BoxFuture<PostgresStream, io::Error> {
|
||||
match *host {
|
||||
pub type PostgresStream = Framed<Box<TlsStream>, PostgresCodec>;
|
||||
|
||||
pub fn connect(host: Host,
|
||||
port: u16,
|
||||
tls_mode: TlsMode,
|
||||
handle: &Handle)
|
||||
-> BoxFuture<PostgresStream, ConnectError> {
|
||||
let inner = match host {
|
||||
Host::Tcp(ref host) => {
|
||||
tokio_dns::tcp_connect((&**host, port), handle.remote().clone())
|
||||
.map(|s| InnerStream::Tcp(s).framed(PostgresCodec))
|
||||
.boxed()
|
||||
Either::A(tokio_dns::tcp_connect((&**host, port), handle.remote().clone())
|
||||
.map(|s| Stream(InnerStream::Tcp(s))))
|
||||
}
|
||||
Host::Unix(ref host) => {
|
||||
let addr = host.join(format!(".s.PGSQL.{}", port));
|
||||
UnixStream::connect(addr, handle)
|
||||
.map(|s| InnerStream::Unix(s).framed(PostgresCodec))
|
||||
.into_future()
|
||||
.boxed()
|
||||
Either::B(UnixStream::connect(addr, handle)
|
||||
.map(|s| Stream(InnerStream::Unix(s)))
|
||||
.into_future())
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let (required, mut handshaker) = match tls_mode {
|
||||
TlsMode::Require(h) => (true, h),
|
||||
TlsMode::Prefer(h) => (false, h),
|
||||
TlsMode::None => {
|
||||
return inner.map(|s| {
|
||||
let s: Box<TlsStream> = Box::new(s);
|
||||
s.framed(PostgresCodec)
|
||||
})
|
||||
.map_err(ConnectError::Io)
|
||||
.boxed()
|
||||
},
|
||||
};
|
||||
|
||||
inner.map(|s| s.framed(SslCodec))
|
||||
.and_then(|s| {
|
||||
let mut buf = vec![];
|
||||
frontend::ssl_request(&mut buf);
|
||||
s.send(buf)
|
||||
})
|
||||
.and_then(|s| s.into_future().map_err(|e| e.0))
|
||||
.map_err(ConnectError::Io)
|
||||
.and_then(move |(m, s)| {
|
||||
let s = s.into_inner();
|
||||
match (m, required) {
|
||||
(Some(b'N'), true) => {
|
||||
Either::A(Err(ConnectError::Tls("the server does not support TLS".into())).into_future())
|
||||
}
|
||||
(Some(b'N'), false) => {
|
||||
let s: Box<TlsStream> = Box::new(s);
|
||||
Either::A(Ok(s).into_future())
|
||||
},
|
||||
(None, _) => Either::A(Err(ConnectError::Io(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF"))).into_future()),
|
||||
_ => {
|
||||
let host = match host {
|
||||
Host::Tcp(ref host) => host,
|
||||
Host::Unix(_) => unreachable!(),
|
||||
};
|
||||
Either::B(handshaker.handshake(host, s).map_err(ConnectError::Tls))
|
||||
}
|
||||
}
|
||||
})
|
||||
.map(|s| s.framed(PostgresCodec))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
pub enum InnerStream {
|
||||
pub struct Stream(InnerStream);
|
||||
|
||||
enum InnerStream {
|
||||
Tcp(TcpStream),
|
||||
Unix(UnixStream),
|
||||
}
|
||||
|
||||
impl Read for InnerStream {
|
||||
impl Read for Stream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match *self {
|
||||
match self.0 {
|
||||
InnerStream::Tcp(ref mut s) => s.read(buf),
|
||||
InnerStream::Unix(ref mut s) => s.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for InnerStream {
|
||||
impl Write for Stream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match *self {
|
||||
match self.0 {
|
||||
InnerStream::Tcp(ref mut s) => s.write(buf),
|
||||
InnerStream::Unix(ref mut s) => s.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match *self {
|
||||
match self.0 {
|
||||
InnerStream::Tcp(ref mut s) => s.flush(),
|
||||
InnerStream::Unix(ref mut s) => s.flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Io for InnerStream {
|
||||
impl Io for Stream {
|
||||
fn poll_read(&mut self) -> Async<()> {
|
||||
match *self {
|
||||
match self.0 {
|
||||
InnerStream::Tcp(ref mut s) => s.poll_read(),
|
||||
InnerStream::Unix(ref mut s) => s.poll_read(),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_write(&mut self) -> Async<()> {
|
||||
match *self {
|
||||
match self.0 {
|
||||
InnerStream::Tcp(ref mut s) => s.poll_write(),
|
||||
InnerStream::Unix(ref mut s) => s.poll_write(),
|
||||
}
|
||||
@ -98,3 +149,25 @@ impl Codec for PostgresCodec {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct SslCodec;
|
||||
|
||||
impl Codec for SslCodec {
|
||||
type In = u8;
|
||||
type Out = Vec<u8>;
|
||||
|
||||
fn decode(&mut self, buf: &mut EasyBuf) -> io::Result<Option<u8>> {
|
||||
if buf.as_slice().is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
let byte = buf.as_slice()[0];
|
||||
buf.drain_to(1);
|
||||
Ok(Some(byte))
|
||||
}
|
||||
}
|
||||
|
||||
fn encode(&mut self, msg: Vec<u8>, buf: &mut Vec<u8>) -> io::Result<()> {
|
||||
buf.extend_from_slice(&msg);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ use params::ConnectParams;
|
||||
fn basic() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
let done = Connection::connect("postgres://postgres@localhost", &handle)
|
||||
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &handle)
|
||||
.then(|c| c.unwrap().close());
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
@ -19,7 +19,9 @@ fn basic() {
|
||||
fn md5_user() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
let done = Connection::connect("postgres://md5_user:password@localhost/postgres", &handle);
|
||||
let done = Connection::connect("postgres://md5_user:password@localhost/postgres",
|
||||
TlsMode::None,
|
||||
&handle);
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
|
||||
@ -27,7 +29,9 @@ fn md5_user() {
|
||||
fn md5_user_no_pass() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
let done = Connection::connect("postgres://md5_user@localhost/postgres", &handle);
|
||||
let done = Connection::connect("postgres://md5_user@localhost/postgres",
|
||||
TlsMode::None,
|
||||
&handle);
|
||||
match l.run(done) {
|
||||
Err(ConnectError::ConnectParams(_)) => {}
|
||||
Err(e) => panic!("unexpected error {}", e),
|
||||
@ -39,7 +43,9 @@ fn md5_user_no_pass() {
|
||||
fn md5_user_wrong_pass() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
let done = Connection::connect("postgres://md5_user:foobar@localhost/postgres", &handle);
|
||||
let done = Connection::connect("postgres://md5_user:foobar@localhost/postgres",
|
||||
TlsMode::None,
|
||||
&handle);
|
||||
match l.run(done) {
|
||||
Err(ConnectError::Db(ref e)) if e.code == SqlState::InvalidPassword => {}
|
||||
Err(e) => panic!("unexpected error {}", e),
|
||||
@ -51,7 +57,9 @@ fn md5_user_wrong_pass() {
|
||||
fn pass_user() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
let done = Connection::connect("postgres://pass_user:password@localhost/postgres", &handle);
|
||||
let done = Connection::connect("postgres://pass_user:password@localhost/postgres",
|
||||
TlsMode::None,
|
||||
&handle);
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
|
||||
@ -59,7 +67,9 @@ fn pass_user() {
|
||||
fn pass_user_no_pass() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
let done = Connection::connect("postgres://pass_user@localhost/postgres", &handle);
|
||||
let done = Connection::connect("postgres://pass_user@localhost/postgres",
|
||||
TlsMode::None,
|
||||
&handle);
|
||||
match l.run(done) {
|
||||
Err(ConnectError::ConnectParams(_)) => {}
|
||||
Err(e) => panic!("unexpected error {}", e),
|
||||
@ -71,7 +81,9 @@ fn pass_user_no_pass() {
|
||||
fn pass_user_wrong_pass() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
let done = Connection::connect("postgres://pass_user:foobar@localhost/postgres", &handle);
|
||||
let done = Connection::connect("postgres://pass_user:foobar@localhost/postgres",
|
||||
TlsMode::None,
|
||||
&handle);
|
||||
match l.run(done) {
|
||||
Err(ConnectError::Db(ref e)) if e.code == SqlState::InvalidPassword => {}
|
||||
Err(e) => panic!("unexpected error {}", e),
|
||||
@ -82,7 +94,7 @@ fn pass_user_wrong_pass() {
|
||||
#[test]
|
||||
fn batch_execute_ok() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let done = Connection::connect("postgres://postgres@localhost", &l.handle())
|
||||
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle())
|
||||
.then(|c| c.unwrap().batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL);"));
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
@ -90,7 +102,7 @@ fn batch_execute_ok() {
|
||||
#[test]
|
||||
fn batch_execute_err() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let done = Connection::connect("postgres://postgres@localhost", &l.handle())
|
||||
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle())
|
||||
.then(|r| r.unwrap().batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL); \
|
||||
INSERT INTO foo DEFAULT VALUES;"))
|
||||
.and_then(|c| c.batch_execute("SELECT * FROM bogo"))
|
||||
@ -110,7 +122,7 @@ fn batch_execute_err() {
|
||||
#[test]
|
||||
fn prepare_execute() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let done = Connection::connect("postgres://postgres@localhost", &l.handle())
|
||||
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle())
|
||||
.then(|c| {
|
||||
c.unwrap().prepare("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY, name VARCHAR)")
|
||||
})
|
||||
@ -127,7 +139,7 @@ fn prepare_execute() {
|
||||
#[test]
|
||||
fn query() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let done = Connection::connect("postgres://postgres@localhost", &l.handle())
|
||||
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle())
|
||||
.then(|c| {
|
||||
c.unwrap().batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name VARCHAR);
|
||||
INSERT INTO foo (name) VALUES ('joe'), ('bob')")
|
||||
@ -149,7 +161,7 @@ fn query() {
|
||||
#[test]
|
||||
fn transaction() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let done = Connection::connect("postgres://postgres@localhost", &l.handle())
|
||||
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle())
|
||||
.then(|c| c.unwrap().batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name VARCHAR);"))
|
||||
.then(|c| c.unwrap().transaction())
|
||||
.then(|t| t.unwrap().batch_execute("INSERT INTO foo (name) VALUES ('joe');"))
|
||||
@ -170,7 +182,7 @@ fn transaction() {
|
||||
fn unix_socket() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
let done = Connection::connect("postgres://postgres@localhost", &handle)
|
||||
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &handle)
|
||||
.then(|c| c.unwrap().prepare("SHOW unix_socket_directories"))
|
||||
.and_then(|(s, c)| c.query(&s, &[]).collect())
|
||||
.then(|r| {
|
||||
@ -178,8 +190,28 @@ fn unix_socket() {
|
||||
let params = ConnectParams::builder()
|
||||
.user("postgres", None)
|
||||
.build_unix(r[0].get::<String, _>(0));
|
||||
Connection::connect(params, &handle)
|
||||
Connection::connect(params, TlsMode::None, &handle)
|
||||
})
|
||||
.then(|c| c.unwrap().batch_execute(""));
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
|
||||
#[cfg(feature = "with-openssl")]
|
||||
#[test]
|
||||
fn openssl_required() {
|
||||
use openssl::ssl::{SslMethod, SslConnectorBuilder};
|
||||
use tls::openssl::OpenSsl;
|
||||
|
||||
let mut builder = SslConnectorBuilder::new(SslMethod::tls()).unwrap();
|
||||
builder.builder_mut().set_ca_file("../.travis/server.crt").unwrap();
|
||||
let negotiator = OpenSsl::from(builder.build());
|
||||
|
||||
let mut l = Core::new().unwrap();
|
||||
let done = Connection::connect("postgres://postgres@localhost",
|
||||
TlsMode::Require(Box::new(negotiator)),
|
||||
&l.handle())
|
||||
.then(|c| c.unwrap().prepare("SELECT 1"))
|
||||
.and_then(|(s, c)| c.query(&s, &[]).collect())
|
||||
.map(|(r, _)| assert_eq!(r[0].get::<i32, _>(0), 1));
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
|
33
postgres-tokio/src/tls/mod.rs
Normal file
33
postgres-tokio/src/tls/mod.rs
Normal file
@ -0,0 +1,33 @@
|
||||
use futures::BoxFuture;
|
||||
use std::error::Error;
|
||||
use tokio_core::io::Io;
|
||||
|
||||
pub use stream::Stream;
|
||||
|
||||
#[cfg(feature = "with-openssl")]
|
||||
pub mod openssl;
|
||||
|
||||
pub trait TlsStream: Io + Send {
|
||||
fn get_ref(&self) -> &Stream;
|
||||
|
||||
fn get_mut(&mut self) -> &mut Stream;
|
||||
}
|
||||
|
||||
impl Io for Box<TlsStream> {}
|
||||
|
||||
impl TlsStream for Stream {
|
||||
fn get_ref(&self) -> &Stream {
|
||||
self
|
||||
}
|
||||
|
||||
fn get_mut(&mut self) -> &mut Stream {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Handshake: 'static + Sync + Send {
|
||||
fn handshake(&mut self,
|
||||
host: &str,
|
||||
stream: Stream)
|
||||
-> BoxFuture<Box<TlsStream>, Box<Error + Sync + Send>>;
|
||||
}
|
50
postgres-tokio/src/tls/openssl.rs
Normal file
50
postgres-tokio/src/tls/openssl.rs
Normal file
@ -0,0 +1,50 @@
|
||||
use futures::{Future, BoxFuture};
|
||||
use openssl::ssl::{SslMethod, SslConnector, SslConnectorBuilder};
|
||||
use openssl::error::ErrorStack;
|
||||
use std::error::Error;
|
||||
use tokio_openssl::{SslConnectorExt, SslStream};
|
||||
|
||||
use tls::{Stream, TlsStream, Handshake};
|
||||
|
||||
impl TlsStream for SslStream<Stream> {
|
||||
fn get_ref(&self) -> &Stream {
|
||||
self.get_ref().get_ref()
|
||||
}
|
||||
|
||||
fn get_mut(&mut self) -> &mut Stream {
|
||||
self.get_mut().get_mut()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct OpenSsl(SslConnector);
|
||||
|
||||
impl OpenSsl {
|
||||
pub fn new() -> Result<OpenSsl, ErrorStack> {
|
||||
let connector = try!(SslConnectorBuilder::new(SslMethod::tls())).build();
|
||||
Ok(OpenSsl(connector))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SslConnector> for OpenSsl {
|
||||
fn from(connector: SslConnector) -> OpenSsl {
|
||||
OpenSsl(connector)
|
||||
}
|
||||
}
|
||||
|
||||
impl Handshake for OpenSsl {
|
||||
fn handshake(&mut self,
|
||||
host: &str,
|
||||
stream: Stream)
|
||||
-> BoxFuture<Box<TlsStream>, Box<Error + Sync + Send>> {
|
||||
self.0.connect_async(host, stream)
|
||||
.map(|s| {
|
||||
let s: Box<TlsStream> = Box::new(s);
|
||||
s
|
||||
})
|
||||
.map_err(|e| {
|
||||
let e: Box<Error + Sync + Send> = Box::new(e);
|
||||
e
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user