From dc9d07e2460d43c089fa8914f80cfdf2cc38d73b Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 27 Oct 2019 14:24:25 -0700 Subject: [PATCH] Return a custom TlsStream rather than a ChannelBinding up front --- codegen/src/sqlstate.rs | 3 +- postgres-native-tls/src/lib.rs | 93 ++++++++++++++++--- postgres-openssl/src/lib.rs | 101 ++++++++++++++++++--- tokio-postgres/src/cancel_query_raw.rs | 2 +- tokio-postgres/src/connect_raw.rs | 24 +++-- tokio-postgres/src/connect_tls.rs | 14 +-- tokio-postgres/src/maybe_tls_stream.rs | 14 +++ tokio-postgres/src/tls.rs | 20 +++- tokio-postgres/tests/test/types/uuid_08.rs | 2 +- 9 files changed, 220 insertions(+), 53 deletions(-) diff --git a/codegen/src/sqlstate.rs b/codegen/src/sqlstate.rs index 79a69638..bb21be34 100644 --- a/codegen/src/sqlstate.rs +++ b/codegen/src/sqlstate.rs @@ -92,5 +92,6 @@ fn make_map(codes: &LinkedHashMap>, file: &mut BufWriter = \n{};\n", builder.build() - ).unwrap(); + ) + .unwrap(); } diff --git a/postgres-native-tls/src/lib.rs b/postgres-native-tls/src/lib.rs index f9e67fa0..add9ea8a 100644 --- a/postgres-native-tls/src/lib.rs +++ b/postgres-native-tls/src/lib.rs @@ -7,7 +7,7 @@ //! use postgres_native_tls::MakeTlsConnector; //! use std::fs; //! -//! # fn main() -> Result<(), Box> { +//! # fn main() -> Result<(), Box> { //! let cert = fs::read("database_cert.pem")?; //! let cert = Certificate::from_pem(&cert)?; //! let connector = TlsConnector::builder() @@ -30,7 +30,7 @@ //! use postgres_native_tls::MakeTlsConnector; //! use std::fs; //! -//! # fn main() -> Result<(), Box> { +//! # fn main() -> Result<(), Box> { //! let cert = fs::read("database_cert.pem")?; //! let cert = Certificate::from_pem(&cert)?; //! let connector = TlsConnector::builder() @@ -48,13 +48,16 @@ #![doc(html_root_url = "https://docs.rs/postgres-native-tls/0.3")] #![warn(rust_2018_idioms, clippy::all, missing_docs)] +use futures::task::Context; +use futures::Poll; use std::future::Future; +use std::io; use std::pin::Pin; -use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_io::{AsyncRead, AsyncWrite, Buf, BufMut}; +use tokio_postgres::tls; #[cfg(feature = "runtime")] use tokio_postgres::tls::MakeTlsConnect; use tokio_postgres::tls::{ChannelBinding, TlsConnect}; -use tokio_tls::TlsStream; #[cfg(test)] mod test; @@ -111,20 +114,88 @@ where type Stream = TlsStream; type Error = native_tls::Error; #[allow(clippy::type_complexity)] - type Future = Pin< - Box, ChannelBinding), native_tls::Error>> + Send>, - >; + type Future = Pin, native_tls::Error>> + Send>>; fn connect(self, stream: S) -> Self::Future { let future = async move { let stream = self.connector.connect(&self.domain, stream).await?; - // FIXME https://github.com/tokio-rs/tokio/issues/1383 - let channel_binding = ChannelBinding::none(); - - Ok((stream, channel_binding)) + Ok(TlsStream(stream)) }; Box::pin(future) } } + +/// The stream returned by `TlsConnector`. +pub struct TlsStream(tokio_tls::TlsStream); + +impl AsyncRead for TlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.0.prepare_uninitialized_buffer(buf) + } + + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } + + fn poll_read_buf( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> + where + Self: Sized, + { + Pin::new(&mut self.0).poll_read_buf(cx, buf) + } +} + +impl AsyncWrite for TlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } + + fn poll_write_buf( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> + where + Self: Sized, + { + Pin::new(&mut self.0).poll_write_buf(cx, buf) + } +} + +impl tls::TlsStream for TlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn channel_binding(&self) -> ChannelBinding { + // FIXME https://github.com/tokio-rs/tokio/issues/1383 + ChannelBinding::none() + } +} diff --git a/postgres-openssl/src/lib.rs b/postgres-openssl/src/lib.rs index 2be536d2..a6d27d63 100644 --- a/postgres-openssl/src/lib.rs +++ b/postgres-openssl/src/lib.rs @@ -6,7 +6,7 @@ //! use openssl::ssl::{SslConnector, SslMethod}; //! use postgres_openssl::MakeTlsConnector; //! -//! # fn main() -> Result<(), Box> { +//! # fn main() -> Result<(), Box> { //! let mut builder = SslConnector::builder(SslMethod::tls())?; //! builder.set_ca_file("database_cert.pem")?; //! let connector = MakeTlsConnector::new(builder.build()); @@ -25,7 +25,7 @@ //! use openssl::ssl::{SslConnector, SslMethod}; //! use postgres_openssl::MakeTlsConnector; //! -//! # fn main() -> Result<(), Box> { +//! # fn main() -> Result<(), Box> { //! let mut builder = SslConnector::builder(SslMethod::tls())?; //! builder.set_ca_file("database_cert.pem")?; //! let connector = MakeTlsConnector::new(builder.build()); @@ -42,6 +42,8 @@ #![doc(html_root_url = "https://docs.rs/postgres-openssl/0.3")] #![warn(rust_2018_idioms, clippy::all, missing_docs)] +use futures::task::Context; +use futures::Poll; #[cfg(feature = "runtime")] use openssl::error::ErrorStack; use openssl::hash::MessageDigest; @@ -51,11 +53,13 @@ use openssl::ssl::SslConnector; use openssl::ssl::{ConnectConfiguration, SslRef}; use std::fmt::Debug; use std::future::Future; +use std::io; use std::pin::Pin; #[cfg(feature = "runtime")] use std::sync::Arc; -use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_io::{AsyncRead, AsyncWrite, Buf, BufMut}; use tokio_openssl::{HandshakeError, SslStream}; +use tokio_postgres::tls; #[cfg(feature = "runtime")] use tokio_postgres::tls::MakeTlsConnect; use tokio_postgres::tls::{ChannelBinding, TlsConnect}; @@ -99,7 +103,7 @@ impl MakeTlsConnect for MakeTlsConnector where S: AsyncRead + AsyncWrite + Unpin + Debug + 'static + Sync + Send, { - type Stream = SslStream; + type Stream = TlsStream; type TlsConnect = TlsConnector; type Error = ErrorStack; @@ -130,29 +134,96 @@ impl TlsConnect for TlsConnector where S: AsyncRead + AsyncWrite + Unpin + Debug + 'static + Sync + Send, { - type Stream = SslStream; + type Stream = TlsStream; type Error = HandshakeError; #[allow(clippy::type_complexity)] - type Future = Pin< - Box, ChannelBinding), HandshakeError>> + Send>, - >; + type Future = Pin, HandshakeError>> + Send>>; fn connect(self, stream: S) -> Self::Future { let future = async move { let stream = tokio_openssl::connect(self.ssl, &self.domain, stream).await?; - - let channel_binding = match tls_server_end_point(stream.ssl()) { - Some(buf) => ChannelBinding::tls_server_end_point(buf), - None => ChannelBinding::none(), - }; - - Ok((stream, channel_binding)) + Ok(TlsStream(stream)) }; Box::pin(future) } } +/// The stream returned by `TlsConnector`. +pub struct TlsStream(SslStream); + +impl AsyncRead for TlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.0.prepare_uninitialized_buffer(buf) + } + + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } + + fn poll_read_buf( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> + where + Self: Sized, + { + Pin::new(&mut self.0).poll_read_buf(cx, buf) + } +} + +impl AsyncWrite for TlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } + + fn poll_write_buf( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> + where + Self: Sized, + { + Pin::new(&mut self.0).poll_write_buf(cx, buf) + } +} + +impl tls::TlsStream for TlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn channel_binding(&self) -> ChannelBinding { + match tls_server_end_point(self.0.ssl()) { + Some(buf) => ChannelBinding::tls_server_end_point(buf), + None => ChannelBinding::none(), + } + } +} + fn tls_server_end_point(ssl: &SslRef) -> Option> { let cert = ssl.peer_certificate()?; let algo_nid = cert.signature_algorithm().object().nid(); diff --git a/tokio-postgres/src/cancel_query_raw.rs b/tokio-postgres/src/cancel_query_raw.rs index 0dcdd8ba..c89dc581 100644 --- a/tokio-postgres/src/cancel_query_raw.rs +++ b/tokio-postgres/src/cancel_query_raw.rs @@ -16,7 +16,7 @@ where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - let (mut stream, _) = connect_tls::connect_tls(stream, mode, tls).await?; + let mut stream = connect_tls::connect_tls(stream, mode, tls).await?; let mut buf = BytesMut::new(); frontend::cancel_request(process_id, secret_key, &mut buf); diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index b96ced03..90fb4165 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -2,7 +2,7 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCod use crate::config::{self, Config}; use crate::connect_tls::connect_tls; use crate::maybe_tls_stream::MaybeTlsStream; -use crate::tls::{ChannelBinding, TlsConnect}; +use crate::tls::{TlsConnect, TlsStream}; use crate::{Client, Connection, Error}; use bytes::BytesMut; use fallible_iterator::FallibleIterator; @@ -86,7 +86,7 @@ where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { - let (stream, channel_binding) = connect_tls(stream, config.ssl_mode, tls).await?; + let stream = connect_tls(stream, config.ssl_mode, tls).await?; let mut stream = StartupStream { inner: Framed::new(stream, PostgresCodec), @@ -94,7 +94,7 @@ where }; startup(&mut stream, config).await?; - authenticate(&mut stream, channel_binding, config).await?; + authenticate(&mut stream, config).await?; let (process_id, secret_key, parameters) = read_info(&mut stream).await?; let (sender, receiver) = mpsc::unbounded(); @@ -132,14 +132,10 @@ where .map_err(Error::io) } -async fn authenticate( - stream: &mut StartupStream, - channel_binding: ChannelBinding, - config: &Config, -) -> Result<(), Error> +async fn authenticate(stream: &mut StartupStream, config: &Config) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin, - T: AsyncRead + AsyncWrite + Unpin, + T: TlsStream + Unpin, { match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationOk) => { @@ -172,7 +168,7 @@ where authenticate_password(stream, output.as_bytes()).await?; } Some(Message::AuthenticationSasl(body)) => { - authenticate_sasl(stream, body, channel_binding, config).await?; + authenticate_sasl(stream, body, config).await?; } Some(Message::AuthenticationKerberosV5) | Some(Message::AuthenticationScmCredential) @@ -225,12 +221,11 @@ where async fn authenticate_sasl( stream: &mut StartupStream, body: AuthenticationSaslBody, - channel_binding: ChannelBinding, config: &Config, ) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin, - T: AsyncRead + AsyncWrite + Unpin, + T: TlsStream + Unpin, { let password = config .password @@ -248,7 +243,10 @@ where } } - let channel_binding = channel_binding + let channel_binding = stream + .inner + .get_ref() + .channel_binding() .tls_server_end_point .filter(|_| config.channel_binding != config::ChannelBinding::Disable) .map(sasl::ChannelBinding::tls_server_end_point); diff --git a/tokio-postgres/src/connect_tls.rs b/tokio-postgres/src/connect_tls.rs index d03357b4..03aaa0bc 100644 --- a/tokio-postgres/src/connect_tls.rs +++ b/tokio-postgres/src/connect_tls.rs @@ -1,7 +1,7 @@ use crate::config::SslMode; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::private::ForcePrivateApi; -use crate::tls::{ChannelBinding, TlsConnect}; +use crate::tls::TlsConnect; use crate::Error; use bytes::BytesMut; use postgres_protocol::message::frontend; @@ -11,15 +11,15 @@ pub async fn connect_tls( mut stream: S, mode: SslMode, tls: T, -) -> Result<(MaybeTlsStream, ChannelBinding), Error> +) -> Result, Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { match mode { - SslMode::Disable => return Ok((MaybeTlsStream::Raw(stream), ChannelBinding::none())), + SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)), SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => { - return Ok((MaybeTlsStream::Raw(stream), ChannelBinding::none())) + return Ok(MaybeTlsStream::Raw(stream)) } SslMode::Prefer | SslMode::Require => {} SslMode::__NonExhaustive => unreachable!(), @@ -36,14 +36,14 @@ where if SslMode::Require == mode { return Err(Error::tls("server does not support TLS".into())); } else { - return Ok((MaybeTlsStream::Raw(stream), ChannelBinding::none())); + return Ok(MaybeTlsStream::Raw(stream)); } } - let (stream, channel_binding) = tls + let stream = tls .connect(stream) .await .map_err(|e| Error::tls(e.into()))?; - Ok((MaybeTlsStream::Tls(stream), channel_binding)) + Ok(MaybeTlsStream::Tls(stream)) } diff --git a/tokio-postgres/src/maybe_tls_stream.rs b/tokio-postgres/src/maybe_tls_stream.rs index 9928cef4..a8f0d3a6 100644 --- a/tokio-postgres/src/maybe_tls_stream.rs +++ b/tokio-postgres/src/maybe_tls_stream.rs @@ -1,3 +1,4 @@ +use crate::tls::{ChannelBinding, TlsStream}; use bytes::{Buf, BufMut}; use std::io; use std::pin::Pin; @@ -93,3 +94,16 @@ where } } } + +impl TlsStream for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsStream + Unpin, +{ + fn channel_binding(&self) -> ChannelBinding { + match self { + MaybeTlsStream::Raw(_) => ChannelBinding::none(), + MaybeTlsStream::Tls(s) => s.channel_binding(), + } + } +} diff --git a/tokio-postgres/src/tls.rs b/tokio-postgres/src/tls.rs index 78940f33..4e852d3f 100644 --- a/tokio-postgres/src/tls.rs +++ b/tokio-postgres/src/tls.rs @@ -38,7 +38,7 @@ impl ChannelBinding { #[cfg(feature = "runtime")] pub trait MakeTlsConnect { /// The stream type created by the `TlsConnect` implementation. - type Stream: AsyncRead + AsyncWrite + Unpin; + type Stream: TlsStream + Unpin; /// The `TlsConnect` implementation created by this type. type TlsConnect: TlsConnect; /// The error type returned by the `TlsConnect` implementation. @@ -53,11 +53,11 @@ pub trait MakeTlsConnect { /// An asynchronous function wrapping a stream in a TLS session. pub trait TlsConnect { /// The stream returned by the future. - type Stream: AsyncRead + AsyncWrite + Unpin; + type Stream: TlsStream + Unpin; /// The error returned by the future. type Error: Into>; /// The future returned by the connector. - type Future: Future>; + type Future: Future>; /// Returns a future performing a TLS handshake over the stream. fn connect(self, stream: S) -> Self::Future; @@ -68,6 +68,12 @@ pub trait TlsConnect { } } +/// A TLS-wrapped connection to a PostgreSQL database. +pub trait TlsStream: AsyncRead + AsyncWrite { + /// Returns channel binding information for the session. + fn channel_binding(&self) -> ChannelBinding; +} + /// A `MakeTlsConnect` and `TlsConnect` implementation which simply returns an error. /// /// This can be used when `sslmode` is `none` or `prefer`. @@ -103,7 +109,7 @@ impl TlsConnect for NoTls { pub struct NoTlsFuture(()); impl Future for NoTlsFuture { - type Output = Result<(NoTlsStream, ChannelBinding), NoTlsError>; + type Output = Result; fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { Poll::Ready(Err(NoTlsError(()))) @@ -139,6 +145,12 @@ impl AsyncWrite for NoTlsStream { } } +impl TlsStream for NoTlsStream { + fn channel_binding(&self) -> ChannelBinding { + match *self {} + } +} + /// The error returned by `NoTls`. #[derive(Debug)] pub struct NoTlsError(()); diff --git a/tokio-postgres/tests/test/types/uuid_08.rs b/tokio-postgres/tests/test/types/uuid_08.rs index 01b674b9..23764378 100644 --- a/tokio-postgres/tests/test/types/uuid_08.rs +++ b/tokio-postgres/tests/test/types/uuid_08.rs @@ -14,5 +14,5 @@ async fn test_uuid_params() { (None, "NULL"), ], ) - .await + .await }