//! TLS support for `tokio-postgres` and `postgres` via `openssl`.
//!
//! # Examples
//!
//! ```no_run
//! use openssl::ssl::{SslConnector, SslMethod};
//! use postgres_openssl::MakeTlsConnector;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let mut builder = SslConnector::builder(SslMethod::tls())?;
//! builder.set_ca_file("database_cert.pem")?;
//! let connector = MakeTlsConnector::new(builder.build());
//!
//! let connect_future = tokio_postgres::connect(
//!     "host=localhost user=postgres sslmode=require",
//!     connector,
//! );
//!
//! // ...
//! # Ok(())
//! # }
//! ```
//!
//! ```no_run
//! use openssl::ssl::{SslConnector, SslMethod};
//! use postgres_openssl::MakeTlsConnector;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let mut builder = SslConnector::builder(SslMethod::tls())?;
//! builder.set_ca_file("database_cert.pem")?;
//! let connector = MakeTlsConnector::new(builder.build());
//!
//! let client = postgres::Client::connect(
//!     "host=localhost user=postgres sslmode=require",
//!     connector,
//! )?;
//!
//! // ...
//! # Ok(())
//! # }
//! ```
#![warn(rust_2018_idioms, clippy::all, missing_docs)]

#[cfg(feature = "runtime")]
use openssl::error::ErrorStack;
use openssl::hash::MessageDigest;
use openssl::nid::Nid;
#[cfg(feature = "runtime")]
use openssl::ssl::SslConnector;
use openssl::ssl::{self, ConnectConfiguration, SslRef};
use openssl::x509::X509VerifyResult;
use std::error::Error;
use std::fmt::{self, Debug};
use std::future::Future;
use std::io;
use std::pin::Pin;
#[cfg(feature = "runtime")]
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_openssl::SslStream;
use tokio_postgres::tls;
#[cfg(feature = "runtime")]
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::tls::{ChannelBinding, TlsConnect};

#[cfg(test)]
mod test;

/// A `MakeTlsConnect` implementation using the `openssl` crate.
///
/// Requires the `runtime` Cargo feature (enabled by default).
#[cfg(feature = "runtime")] #[derive(Clone)] pub struct MakeTlsConnector { connector: SslConnector, config: Arc Result<(), ErrorStack> + Sync + Send>, } #[cfg(feature = "runtime")] impl MakeTlsConnector { /// Creates a new connector. pub fn new(connector: SslConnector) -> MakeTlsConnector { MakeTlsConnector { connector, config: Arc::new(|_, _| Ok(())), } } /// Sets a callback used to apply per-connection configuration. /// /// The the callback is provided the domain name along with the `ConnectConfiguration`. pub fn set_callback(&mut self, f: F) where F: Fn(&mut ConnectConfiguration, &str) -> Result<(), ErrorStack> + 'static + Sync + Send, { self.config = Arc::new(f); } } #[cfg(feature = "runtime")] impl MakeTlsConnect for MakeTlsConnector where S: AsyncRead + AsyncWrite + Unpin + Debug + 'static + Sync + Send, { type Stream = TlsStream; type TlsConnect = TlsConnector; type Error = ErrorStack; fn make_tls_connect(&mut self, domain: &str) -> Result { let mut ssl = self.connector.configure()?; (self.config)(&mut ssl, domain)?; Ok(TlsConnector::new(ssl, domain)) } } /// A `TlsConnect` implementation using the `openssl` crate. pub struct TlsConnector { ssl: ConnectConfiguration, domain: String, } impl TlsConnector { /// Creates a new connector configured to connect to the specified domain. 
pub fn new(ssl: ConnectConfiguration, domain: &str) -> TlsConnector { TlsConnector { ssl, domain: domain.to_string(), } } } impl TlsConnect for TlsConnector where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type Stream = TlsStream; type Error = Box; #[allow(clippy::type_complexity)] type Future = Pin, Self::Error>> + Send>>; fn connect(self, stream: S) -> Self::Future { let future = async move { let ssl = self.ssl.into_ssl(&self.domain)?; let mut stream = SslStream::new(ssl, stream)?; match Pin::new(&mut stream).connect().await { Ok(()) => Ok(TlsStream(stream)), Err(error) => Err(Box::new(ConnectError { error, verify_result: stream.ssl().verify_result(), }) as _), } }; Box::pin(future) } } #[derive(Debug)] struct ConnectError { error: ssl::Error, verify_result: X509VerifyResult, } impl fmt::Display for ConnectError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt::Display::fmt(&self.error, fmt)?; if self.verify_result != X509VerifyResult::OK { fmt.write_str(": ")?; fmt::Display::fmt(&self.verify_result, fmt)?; } Ok(()) } } impl Error for ConnectError { fn source(&self) -> Option<&(dyn Error + 'static)> { Some(&self.error) } } /// The stream returned by `TlsConnector`. 
pub struct TlsStream(SslStream); impl AsyncRead for TlsStream where S: AsyncRead + AsyncWrite + Unpin, { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { Pin::new(&mut self.0).poll_read(cx, buf) } } impl AsyncWrite for TlsStream where S: AsyncRead + AsyncWrite + Unpin, { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { Pin::new(&mut self.0).poll_write(cx, buf) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.0).poll_flush(cx) } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.0).poll_shutdown(cx) } } impl tls::TlsStream for TlsStream where S: AsyncRead + AsyncWrite + Unpin, { fn channel_binding(&self) -> ChannelBinding { match tls_server_end_point(self.0.ssl()) { Some(buf) => ChannelBinding::tls_server_end_point(buf), None => ChannelBinding::none(), } } } fn tls_server_end_point(ssl: &SslRef) -> Option> { let cert = ssl.peer_certificate()?; let algo_nid = cert.signature_algorithm().object().nid(); let signature_algorithms = algo_nid.signature_algorithms()?; let md = match signature_algorithms.digest { Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(), nid => MessageDigest::from_nid(nid)?, }; cert.digest(md).ok().map(|b| b.to_vec()) }