rust-postgres/postgres-openssl/src/lib.rs

210 lines
5.9 KiB
Rust
Raw Normal View History

2019-04-02 01:51:17 +00:00
//! TLS support for `tokio-postgres` and `postgres` via `openssl`.
//!
//! # Examples
//!
//! ```no_run
//! use openssl::ssl::{SslConnector, SslMethod};
//! use postgres_openssl::MakeTlsConnector;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let mut builder = SslConnector::builder(SslMethod::tls())?;
//! builder.set_ca_file("database_cert.pem")?;
//! let connector = MakeTlsConnector::new(builder.build());
//!
//! let connect_future = tokio_postgres::connect(
//!     "host=localhost user=postgres sslmode=require",
//!     connector,
//! );
//!
//! // ...
//! # Ok(())
//! # }
//! ```
//!
//! ```no_run
//! use openssl::ssl::{SslConnector, SslMethod};
//! use postgres_openssl::MakeTlsConnector;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let mut builder = SslConnector::builder(SslMethod::tls())?;
//! builder.set_ca_file("database_cert.pem")?;
//! let connector = MakeTlsConnector::new(builder.build());
//!
//! let client = postgres::Client::connect(
//!     "host=localhost user=postgres sslmode=require",
//!     connector,
//! )?;
//!
//! // ...
//! # Ok(())
//! # }
//! ```
#![doc(html_root_url = "https://docs.rs/postgres-openssl/0.3")]
2019-03-06 06:01:18 +00:00
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
2018-06-27 04:00:26 +00:00
2018-12-19 05:39:05 +00:00
#[cfg(feature = "runtime")]
use openssl::error::ErrorStack;
use openssl::hash::MessageDigest;
use openssl::nid::Nid;
2018-12-19 05:39:05 +00:00
#[cfg(feature = "runtime")]
use openssl::ssl::SslConnector;
2019-08-03 03:49:22 +00:00
use openssl::ssl::{ConnectConfiguration, SslRef};
use std::fmt::Debug;
2019-08-03 03:49:22 +00:00
use std::future::Future;
use std::io;
2019-08-03 03:49:22 +00:00
use std::pin::Pin;
2018-12-19 05:39:05 +00:00
#[cfg(feature = "runtime")]
use std::sync::Arc;
2019-11-30 23:18:50 +00:00
use std::task::{Context, Poll};
2020-10-17 13:49:45 +00:00
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
2019-08-03 03:49:22 +00:00
use tokio_openssl::{HandshakeError, SslStream};
use tokio_postgres::tls;
2018-12-19 05:39:05 +00:00
#[cfg(feature = "runtime")]
2019-03-05 05:26:10 +00:00
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
2018-06-27 04:00:26 +00:00
#[cfg(test)]
mod test;
2019-03-06 06:01:18 +00:00
/// Type of the per-connection configuration callback held by [`MakeTlsConnector`].
///
/// It receives the connection's `ConnectConfiguration` and the target domain name.
#[cfg(feature = "runtime")]
type ConfigCallback =
    dyn Fn(&mut ConnectConfiguration, &str) -> Result<(), ErrorStack> + Sync + Send;

/// A `MakeTlsConnect` implementation using the `openssl` crate.
///
/// Requires the `runtime` Cargo feature (enabled by default).
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub struct MakeTlsConnector {
    // Base connector; each connection's `ConnectConfiguration` is derived from it.
    connector: SslConnector,
    // User callback applied to every per-connection configuration. It is kept behind an
    // `Arc` so that clones of this connector share one callback instance.
    config: Arc<ConfigCallback>,
}
#[cfg(feature = "runtime")]
impl MakeTlsConnector {
    /// Creates a new connector.
    ///
    /// The initial per-connection callback does nothing and always succeeds. Every
    /// connection therefore uses the configuration derived from `connector` as-is until
    /// [`MakeTlsConnector::set_callback`] installs a different callback.
    pub fn new(connector: SslConnector) -> MakeTlsConnector {
        MakeTlsConnector {
            connector,
            // No-op default callback.
            config: Arc::new(|_, _| Ok(())),
        }
    }

    /// Sets a callback used to apply per-connection configuration.
    ///
    /// The callback is provided the domain name along with the `ConnectConfiguration`.
    /// It replaces any previously set callback. An `Err` returned by the callback is
    /// propagated out of `make_tls_connect`, so connection setup fails.
    pub fn set_callback<F>(&mut self, f: F)
    where
        F: Fn(&mut ConnectConfiguration, &str) -> Result<(), ErrorStack> + 'static + Sync + Send,
    {
        self.config = Arc::new(f);
    }
}
#[cfg(feature = "runtime")]
impl<S> MakeTlsConnect<S> for MakeTlsConnector
where
    S: AsyncRead + AsyncWrite + Unpin + Debug + 'static + Sync + Send,
{
    type Stream = TlsStream<S>;
    type TlsConnect = TlsConnector;
    type Error = ErrorStack;

    /// Produces a `TlsConnector` for `domain`.
    ///
    /// A fresh `ConnectConfiguration` is derived from the shared `SslConnector`, passed
    /// (with `domain`) to the user callback, and then wrapped. A failure in either step
    /// is returned as an `ErrorStack`.
    fn make_tls_connect(&mut self, domain: &str) -> Result<TlsConnector, ErrorStack> {
        let mut per_connection = self.connector.configure()?;

        // Borrow the callback out of its `Arc` and let it adjust this connection only.
        let customize = &*self.config;
        customize(&mut per_connection, domain)?;

        Ok(TlsConnector::new(per_connection, domain))
    }
}
2019-03-06 06:01:18 +00:00
/// A `TlsConnect` implementation using the `openssl` crate.
2018-06-27 04:00:26 +00:00
pub struct TlsConnector {
ssl: ConnectConfiguration,
domain: String,
2018-06-27 04:00:26 +00:00
}
impl TlsConnector {
2019-03-06 06:01:18 +00:00
/// Creates a new connector configured to connect to the specified domain.
pub fn new(ssl: ConnectConfiguration, domain: &str) -> TlsConnector {
2018-06-27 04:00:26 +00:00
TlsConnector {
ssl,
domain: domain.to_string(),
2018-06-27 04:00:26 +00:00
}
}
}
impl<S> TlsConnect<S> for TlsConnector
where
2019-08-03 03:49:22 +00:00
S: AsyncRead + AsyncWrite + Unpin + Debug + 'static + Sync + Send,
{
type Stream = TlsStream<S>;
type Error = HandshakeError<S>;
2019-10-09 22:20:23 +00:00
#[allow(clippy::type_complexity)]
type Future = Pin<Box<dyn Future<Output = Result<TlsStream<S>, HandshakeError<S>>> + Send>>;
2018-06-27 04:00:26 +00:00
2019-08-03 03:49:22 +00:00
fn connect(self, stream: S) -> Self::Future {
let future = async move {
let stream = tokio_openssl::connect(self.ssl, &self.domain, stream).await?;
Ok(TlsStream(stream))
2018-11-30 05:30:02 +00:00
};
2019-08-03 03:49:22 +00:00
Box::pin(future)
}
}
/// The stream returned by `TlsConnector`.
pub struct TlsStream<S>(SslStream<S>);
impl<S> AsyncRead for TlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
2020-10-17 13:49:45 +00:00
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}
impl<S> AsyncWrite for TlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
impl<S> tls::TlsStream for TlsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Returns `tls-server-end-point` channel-binding data when a hash of the peer
    /// certificate can be computed. Otherwise it returns an empty binding.
    fn channel_binding(&self) -> ChannelBinding {
        let session = self.0.ssl();
        tls_server_end_point(session)
            .map_or_else(ChannelBinding::none, ChannelBinding::tls_server_end_point)
    }
}
/// Computes the `tls-server-end-point` hash of the peer certificate.
///
/// The hash is taken with the digest from the certificate's signature algorithm. MD5 and
/// SHA-1 are upgraded to SHA-256, following RFC 5929, section 4.1.
///
/// Returns `None` in any of these cases:
/// - there is no peer certificate;
/// - the signature algorithm cannot be mapped to a digest;
/// - the digest is unavailable;
/// - hashing fails.
fn tls_server_end_point(ssl: &SslRef) -> Option<Vec<u8>> {
    let peer_cert = ssl.peer_certificate()?;
    let sig_nid = peer_cert.signature_algorithm().object().nid();
    let digest_nid = sig_nid.signature_algorithms()?.digest;

    let hash = if digest_nid == Nid::MD5 || digest_nid == Nid::SHA1 {
        MessageDigest::sha256()
    } else {
        MessageDigest::from_nid(digest_nid)?
    };

    peer_cert.digest(hash).map(|bytes| bytes.to_vec()).ok()
}