// 142 lines, 3.6 KiB, Rust
extern crate bytes;
|
|
extern crate futures;
|
|
extern crate openssl;
|
|
extern crate tokio_io;
|
|
extern crate tokio_openssl;
|
|
extern crate tokio_postgres;
|
|
|
|
#[cfg(test)]
|
|
extern crate tokio;
|
|
|
|
use bytes::{Buf, BufMut};
|
|
use futures::{Future, IntoFuture, Poll};
|
|
use openssl::error::ErrorStack;
|
|
use openssl::ssl::{ConnectConfiguration, SslConnector, SslMethod, SslRef};
|
|
use std::error::Error;
|
|
use std::io::{self, Read, Write};
|
|
use tokio_io::{AsyncRead, AsyncWrite};
|
|
use tokio_openssl::ConnectConfigurationExt;
|
|
use tokio_postgres::tls::{Socket, TlsConnect, TlsStream};
|
|
|
|
#[cfg(test)]
|
|
mod test;
|
|
|
|
pub struct TlsConnector {
|
|
connector: SslConnector,
|
|
callback: Box<Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + Sync + Send>,
|
|
}
|
|
|
|
impl TlsConnector {
|
|
pub fn new() -> Result<TlsConnector, ErrorStack> {
|
|
let connector = SslConnector::builder(SslMethod::tls())?.build();
|
|
Ok(TlsConnector::with_connector(connector))
|
|
}
|
|
|
|
pub fn with_connector(connector: SslConnector) -> TlsConnector {
|
|
TlsConnector {
|
|
connector,
|
|
callback: Box::new(|_| Ok(())),
|
|
}
|
|
}
|
|
|
|
pub fn set_callback<F>(&mut self, f: F)
|
|
where
|
|
F: Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + 'static + Sync + Send,
|
|
{
|
|
self.callback = Box::new(f);
|
|
}
|
|
}
|
|
|
|
impl TlsConnect for TlsConnector {
|
|
fn connect(
|
|
&self,
|
|
domain: &str,
|
|
socket: Socket,
|
|
) -> Box<Future<Item = Box<TlsStream>, Error = Box<Error + Sync + Send>> + Sync + Send> {
|
|
let f = self
|
|
.connector
|
|
.configure()
|
|
.and_then(|mut ssl| (self.callback)(&mut ssl).map(|_| ssl))
|
|
.map_err(|e| {
|
|
let e: Box<Error + Sync + Send> = Box::new(e);
|
|
e
|
|
})
|
|
.into_future()
|
|
.and_then({
|
|
let domain = domain.to_string();
|
|
move |ssl| {
|
|
ssl.connect_async(&domain, socket)
|
|
.map(|s| {
|
|
let s: Box<TlsStream> = Box::new(SslStream(s));
|
|
s
|
|
})
|
|
.map_err(|e| {
|
|
let e: Box<Error + Sync + Send> = Box::new(e);
|
|
e
|
|
})
|
|
}
|
|
});
|
|
Box::new(f)
|
|
}
|
|
}
|
|
|
|
/// Private newtype around the OpenSSL stream over a `Socket`; the `Read`,
/// `Write`, `AsyncRead`, `AsyncWrite` and `TlsStream` impls in this file are
/// written against it.
struct SslStream(tokio_openssl::SslStream<Socket>);
|
|
|
|
impl Read for SslStream {
|
|
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
|
self.0.read(buf)
|
|
}
|
|
}
|
|
|
|
impl AsyncRead for SslStream {
|
|
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
|
|
self.0.prepare_uninitialized_buffer(buf)
|
|
}
|
|
|
|
fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
|
|
where
|
|
B: BufMut,
|
|
{
|
|
self.0.read_buf(buf)
|
|
}
|
|
}
|
|
|
|
impl Write for SslStream {
|
|
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
|
self.0.write(buf)
|
|
}
|
|
|
|
fn flush(&mut self) -> io::Result<()> {
|
|
self.0.flush()
|
|
}
|
|
}
|
|
|
|
impl AsyncWrite for SslStream {
|
|
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
|
self.0.shutdown()
|
|
}
|
|
|
|
fn write_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
|
|
where
|
|
B: Buf,
|
|
{
|
|
self.0.write_buf(buf)
|
|
}
|
|
}
|
|
|
|
impl TlsStream for SslStream {
|
|
fn tls_unique(&self) -> Option<Vec<u8>> {
|
|
let f = if self.0.get_ref().ssl().session_reused() {
|
|
SslRef::peer_finished
|
|
} else {
|
|
SslRef::finished
|
|
};
|
|
|
|
let len = f(self.0.get_ref().ssl(), &mut []);
|
|
let mut buf = vec![0; len];
|
|
f(self.0.get_ref().ssl(), &mut buf);
|
|
|
|
Some(buf)
|
|
}
|
|
}
|