Return a custom TlsStream rather than a ChannelBinding up front

This commit is contained in:
Steven Fackler 2019-10-27 14:24:25 -07:00
parent 6c77baad1b
commit dc9d07e246
9 changed files with 220 additions and 53 deletions

View File

@ -92,5 +92,6 @@ fn make_map(codes: &LinkedHashMap<String, Vec<String>>, file: &mut BufWriter<Fil
#[rustfmt::skip]
static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = \n{};\n",
builder.build()
).unwrap();
)
.unwrap();
}

View File

@ -7,7 +7,7 @@
//! use postgres_native_tls::MakeTlsConnector;
//! use std::fs;
//!
//! # fn main() -> Result<(), Box<std::error::Error>> {
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let cert = fs::read("database_cert.pem")?;
//! let cert = Certificate::from_pem(&cert)?;
//! let connector = TlsConnector::builder()
@ -30,7 +30,7 @@
//! use postgres_native_tls::MakeTlsConnector;
//! use std::fs;
//!
//! # fn main() -> Result<(), Box<std::error::Error>> {
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let cert = fs::read("database_cert.pem")?;
//! let cert = Certificate::from_pem(&cert)?;
//! let connector = TlsConnector::builder()
@ -48,13 +48,16 @@
#![doc(html_root_url = "https://docs.rs/postgres-native-tls/0.3")]
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
use futures::task::Context;
use futures::Poll;
use std::future::Future;
use std::io;
use std::pin::Pin;
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_io::{AsyncRead, AsyncWrite, Buf, BufMut};
use tokio_postgres::tls;
#[cfg(feature = "runtime")]
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
use tokio_tls::TlsStream;
#[cfg(test)]
mod test;
@ -111,20 +114,88 @@ where
type Stream = TlsStream<S>;
type Error = native_tls::Error;
#[allow(clippy::type_complexity)]
type Future = Pin<
Box<dyn Future<Output = Result<(TlsStream<S>, ChannelBinding), native_tls::Error>> + Send>,
>;
type Future = Pin<Box<dyn Future<Output = Result<TlsStream<S>, native_tls::Error>> + Send>>;
fn connect(self, stream: S) -> Self::Future {
let future = async move {
let stream = self.connector.connect(&self.domain, stream).await?;
// FIXME https://github.com/tokio-rs/tokio/issues/1383
let channel_binding = ChannelBinding::none();
Ok((stream, channel_binding))
Ok(TlsStream(stream))
};
Box::pin(future)
}
}
/// The stream returned by `TlsConnector`.
pub struct TlsStream<S>(tokio_tls::TlsStream<S>);
impl<S> AsyncRead for TlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.0.prepare_uninitialized_buffer(buf)
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
fn poll_read_buf<B: BufMut>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
{
Pin::new(&mut self.0).poll_read_buf(cx, buf)
}
}
impl<S> AsyncWrite for TlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
fn poll_write_buf<B: Buf>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
{
Pin::new(&mut self.0).poll_write_buf(cx, buf)
}
}
impl<S> tls::TlsStream for TlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn channel_binding(&self) -> ChannelBinding {
// FIXME https://github.com/tokio-rs/tokio/issues/1383
ChannelBinding::none()
}
}

View File

@ -6,7 +6,7 @@
//! use openssl::ssl::{SslConnector, SslMethod};
//! use postgres_openssl::MakeTlsConnector;
//!
//! # fn main() -> Result<(), Box<std::error::Error>> {
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let mut builder = SslConnector::builder(SslMethod::tls())?;
//! builder.set_ca_file("database_cert.pem")?;
//! let connector = MakeTlsConnector::new(builder.build());
@ -25,7 +25,7 @@
//! use openssl::ssl::{SslConnector, SslMethod};
//! use postgres_openssl::MakeTlsConnector;
//!
//! # fn main() -> Result<(), Box<std::error::Error>> {
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let mut builder = SslConnector::builder(SslMethod::tls())?;
//! builder.set_ca_file("database_cert.pem")?;
//! let connector = MakeTlsConnector::new(builder.build());
@ -42,6 +42,8 @@
#![doc(html_root_url = "https://docs.rs/postgres-openssl/0.3")]
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
use futures::task::Context;
use futures::Poll;
#[cfg(feature = "runtime")]
use openssl::error::ErrorStack;
use openssl::hash::MessageDigest;
@ -51,11 +53,13 @@ use openssl::ssl::SslConnector;
use openssl::ssl::{ConnectConfiguration, SslRef};
use std::fmt::Debug;
use std::future::Future;
use std::io;
use std::pin::Pin;
#[cfg(feature = "runtime")]
use std::sync::Arc;
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_io::{AsyncRead, AsyncWrite, Buf, BufMut};
use tokio_openssl::{HandshakeError, SslStream};
use tokio_postgres::tls;
#[cfg(feature = "runtime")]
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
@ -99,7 +103,7 @@ impl<S> MakeTlsConnect<S> for MakeTlsConnector
where
S: AsyncRead + AsyncWrite + Unpin + Debug + 'static + Sync + Send,
{
type Stream = SslStream<S>;
type Stream = TlsStream<S>;
type TlsConnect = TlsConnector;
type Error = ErrorStack;
@ -130,29 +134,96 @@ impl<S> TlsConnect<S> for TlsConnector
where
S: AsyncRead + AsyncWrite + Unpin + Debug + 'static + Sync + Send,
{
type Stream = SslStream<S>;
type Stream = TlsStream<S>;
type Error = HandshakeError<S>;
#[allow(clippy::type_complexity)]
type Future = Pin<
Box<dyn Future<Output = Result<(SslStream<S>, ChannelBinding), HandshakeError<S>>> + Send>,
>;
type Future = Pin<Box<dyn Future<Output = Result<TlsStream<S>, HandshakeError<S>>> + Send>>;
fn connect(self, stream: S) -> Self::Future {
let future = async move {
let stream = tokio_openssl::connect(self.ssl, &self.domain, stream).await?;
let channel_binding = match tls_server_end_point(stream.ssl()) {
Some(buf) => ChannelBinding::tls_server_end_point(buf),
None => ChannelBinding::none(),
};
Ok((stream, channel_binding))
Ok(TlsStream(stream))
};
Box::pin(future)
}
}
/// The stream returned by `TlsConnector`.
pub struct TlsStream<S>(SslStream<S>);
impl<S> AsyncRead for TlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.0.prepare_uninitialized_buffer(buf)
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
fn poll_read_buf<B: BufMut>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
{
Pin::new(&mut self.0).poll_read_buf(cx, buf)
}
}
impl<S> AsyncWrite for TlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
fn poll_write_buf<B: Buf>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
{
Pin::new(&mut self.0).poll_write_buf(cx, buf)
}
}
impl<S> tls::TlsStream for TlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn channel_binding(&self) -> ChannelBinding {
match tls_server_end_point(self.0.ssl()) {
Some(buf) => ChannelBinding::tls_server_end_point(buf),
None => ChannelBinding::none(),
}
}
}
fn tls_server_end_point(ssl: &SslRef) -> Option<Vec<u8>> {
let cert = ssl.peer_certificate()?;
let algo_nid = cert.signature_algorithm().object().nid();

View File

@ -16,7 +16,7 @@ where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
let (mut stream, _) = connect_tls::connect_tls(stream, mode, tls).await?;
let mut stream = connect_tls::connect_tls(stream, mode, tls).await?;
let mut buf = BytesMut::new();
frontend::cancel_request(process_id, secret_key, &mut buf);

View File

@ -2,7 +2,7 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCod
use crate::config::{self, Config};
use crate::connect_tls::connect_tls;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::{ChannelBinding, TlsConnect};
use crate::tls::{TlsConnect, TlsStream};
use crate::{Client, Connection, Error};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
@ -86,7 +86,7 @@ where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
let (stream, channel_binding) = connect_tls(stream, config.ssl_mode, tls).await?;
let stream = connect_tls(stream, config.ssl_mode, tls).await?;
let mut stream = StartupStream {
inner: Framed::new(stream, PostgresCodec),
@ -94,7 +94,7 @@ where
};
startup(&mut stream, config).await?;
authenticate(&mut stream, channel_binding, config).await?;
authenticate(&mut stream, config).await?;
let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
let (sender, receiver) = mpsc::unbounded();
@ -132,14 +132,10 @@ where
.map_err(Error::io)
}
async fn authenticate<S, T>(
stream: &mut StartupStream<S, T>,
channel_binding: ChannelBinding,
config: &Config,
) -> Result<(), Error>
async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
T: TlsStream + Unpin,
{
match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationOk) => {
@ -172,7 +168,7 @@ where
authenticate_password(stream, output.as_bytes()).await?;
}
Some(Message::AuthenticationSasl(body)) => {
authenticate_sasl(stream, body, channel_binding, config).await?;
authenticate_sasl(stream, body, config).await?;
}
Some(Message::AuthenticationKerberosV5)
| Some(Message::AuthenticationScmCredential)
@ -225,12 +221,11 @@ where
async fn authenticate_sasl<S, T>(
stream: &mut StartupStream<S, T>,
body: AuthenticationSaslBody,
channel_binding: ChannelBinding,
config: &Config,
) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
T: TlsStream + Unpin,
{
let password = config
.password
@ -248,7 +243,10 @@ where
}
}
let channel_binding = channel_binding
let channel_binding = stream
.inner
.get_ref()
.channel_binding()
.tls_server_end_point
.filter(|_| config.channel_binding != config::ChannelBinding::Disable)
.map(sasl::ChannelBinding::tls_server_end_point);

View File

@ -1,7 +1,7 @@
use crate::config::SslMode;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::private::ForcePrivateApi;
use crate::tls::{ChannelBinding, TlsConnect};
use crate::tls::TlsConnect;
use crate::Error;
use bytes::BytesMut;
use postgres_protocol::message::frontend;
@ -11,15 +11,15 @@ pub async fn connect_tls<S, T>(
mut stream: S,
mode: SslMode,
tls: T,
) -> Result<(MaybeTlsStream<S, T::Stream>, ChannelBinding), Error>
) -> Result<MaybeTlsStream<S, T::Stream>, Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
match mode {
SslMode::Disable => return Ok((MaybeTlsStream::Raw(stream), ChannelBinding::none())),
SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => {
return Ok((MaybeTlsStream::Raw(stream), ChannelBinding::none()))
return Ok(MaybeTlsStream::Raw(stream))
}
SslMode::Prefer | SslMode::Require => {}
SslMode::__NonExhaustive => unreachable!(),
@ -36,14 +36,14 @@ where
if SslMode::Require == mode {
return Err(Error::tls("server does not support TLS".into()));
} else {
return Ok((MaybeTlsStream::Raw(stream), ChannelBinding::none()));
return Ok(MaybeTlsStream::Raw(stream));
}
}
let (stream, channel_binding) = tls
let stream = tls
.connect(stream)
.await
.map_err(|e| Error::tls(e.into()))?;
Ok((MaybeTlsStream::Tls(stream), channel_binding))
Ok(MaybeTlsStream::Tls(stream))
}

View File

@ -1,3 +1,4 @@
use crate::tls::{ChannelBinding, TlsStream};
use bytes::{Buf, BufMut};
use std::io;
use std::pin::Pin;
@ -93,3 +94,16 @@ where
}
}
}
impl<S, T> TlsStream for MaybeTlsStream<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsStream + Unpin,
{
fn channel_binding(&self) -> ChannelBinding {
match self {
MaybeTlsStream::Raw(_) => ChannelBinding::none(),
MaybeTlsStream::Tls(s) => s.channel_binding(),
}
}
}

View File

@ -38,7 +38,7 @@ impl ChannelBinding {
#[cfg(feature = "runtime")]
pub trait MakeTlsConnect<S> {
/// The stream type created by the `TlsConnect` implementation.
type Stream: AsyncRead + AsyncWrite + Unpin;
type Stream: TlsStream + Unpin;
/// The `TlsConnect` implementation created by this type.
type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
/// The error type returned by the `TlsConnect` implementation.
@ -53,11 +53,11 @@ pub trait MakeTlsConnect<S> {
/// An asynchronous function wrapping a stream in a TLS session.
pub trait TlsConnect<S> {
/// The stream returned by the future.
type Stream: AsyncRead + AsyncWrite + Unpin;
type Stream: TlsStream + Unpin;
/// The error returned by the future.
type Error: Into<Box<dyn Error + Sync + Send>>;
/// The future returned by the connector.
type Future: Future<Output = Result<(Self::Stream, ChannelBinding), Self::Error>>;
type Future: Future<Output = Result<Self::Stream, Self::Error>>;
/// Returns a future performing a TLS handshake over the stream.
fn connect(self, stream: S) -> Self::Future;
@ -68,6 +68,12 @@ pub trait TlsConnect<S> {
}
}
/// A TLS-wrapped connection to a PostgreSQL database.
pub trait TlsStream: AsyncRead + AsyncWrite {
/// Returns channel binding information for the session.
fn channel_binding(&self) -> ChannelBinding;
}
/// A `MakeTlsConnect` and `TlsConnect` implementation which simply returns an error.
///
/// This can be used when `sslmode` is `none` or `prefer`.
@ -103,7 +109,7 @@ impl<S> TlsConnect<S> for NoTls {
pub struct NoTlsFuture(());
impl Future for NoTlsFuture {
type Output = Result<(NoTlsStream, ChannelBinding), NoTlsError>;
type Output = Result<NoTlsStream, NoTlsError>;
fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
Poll::Ready(Err(NoTlsError(())))
@ -139,6 +145,12 @@ impl AsyncWrite for NoTlsStream {
}
}
impl TlsStream for NoTlsStream {
fn channel_binding(&self) -> ChannelBinding {
match *self {}
}
}
/// The error returned by `NoTls`.
#[derive(Debug)]
pub struct NoTlsError(());

View File

@ -14,5 +14,5 @@ async fn test_uuid_params() {
(None, "NULL"),
],
)
.await
.await
}