Return a custom TlsStream rather than a ChannelBinding up front
This commit is contained in:
parent
6c77baad1b
commit
dc9d07e246
@ -92,5 +92,6 @@ fn make_map(codes: &LinkedHashMap<String, Vec<String>>, file: &mut BufWriter<Fil
|
|||||||
#[rustfmt::skip]
|
#[rustfmt::skip]
|
||||||
static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = \n{};\n",
|
static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = \n{};\n",
|
||||||
builder.build()
|
builder.build()
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
//! use postgres_native_tls::MakeTlsConnector;
|
//! use postgres_native_tls::MakeTlsConnector;
|
||||||
//! use std::fs;
|
//! use std::fs;
|
||||||
//!
|
//!
|
||||||
//! # fn main() -> Result<(), Box<std::error::Error>> {
|
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
//! let cert = fs::read("database_cert.pem")?;
|
//! let cert = fs::read("database_cert.pem")?;
|
||||||
//! let cert = Certificate::from_pem(&cert)?;
|
//! let cert = Certificate::from_pem(&cert)?;
|
||||||
//! let connector = TlsConnector::builder()
|
//! let connector = TlsConnector::builder()
|
||||||
@ -30,7 +30,7 @@
|
|||||||
//! use postgres_native_tls::MakeTlsConnector;
|
//! use postgres_native_tls::MakeTlsConnector;
|
||||||
//! use std::fs;
|
//! use std::fs;
|
||||||
//!
|
//!
|
||||||
//! # fn main() -> Result<(), Box<std::error::Error>> {
|
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
//! let cert = fs::read("database_cert.pem")?;
|
//! let cert = fs::read("database_cert.pem")?;
|
||||||
//! let cert = Certificate::from_pem(&cert)?;
|
//! let cert = Certificate::from_pem(&cert)?;
|
||||||
//! let connector = TlsConnector::builder()
|
//! let connector = TlsConnector::builder()
|
||||||
@ -48,13 +48,16 @@
|
|||||||
#![doc(html_root_url = "https://docs.rs/postgres-native-tls/0.3")]
|
#![doc(html_root_url = "https://docs.rs/postgres-native-tls/0.3")]
|
||||||
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
|
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
|
||||||
|
|
||||||
|
use futures::task::Context;
|
||||||
|
use futures::Poll;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
|
use std::io;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use tokio_io::{AsyncRead, AsyncWrite};
|
use tokio_io::{AsyncRead, AsyncWrite, Buf, BufMut};
|
||||||
|
use tokio_postgres::tls;
|
||||||
#[cfg(feature = "runtime")]
|
#[cfg(feature = "runtime")]
|
||||||
use tokio_postgres::tls::MakeTlsConnect;
|
use tokio_postgres::tls::MakeTlsConnect;
|
||||||
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
|
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
|
||||||
use tokio_tls::TlsStream;
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
@ -111,20 +114,88 @@ where
|
|||||||
type Stream = TlsStream<S>;
|
type Stream = TlsStream<S>;
|
||||||
type Error = native_tls::Error;
|
type Error = native_tls::Error;
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
type Future = Pin<
|
type Future = Pin<Box<dyn Future<Output = Result<TlsStream<S>, native_tls::Error>> + Send>>;
|
||||||
Box<dyn Future<Output = Result<(TlsStream<S>, ChannelBinding), native_tls::Error>> + Send>,
|
|
||||||
>;
|
|
||||||
|
|
||||||
fn connect(self, stream: S) -> Self::Future {
|
fn connect(self, stream: S) -> Self::Future {
|
||||||
let future = async move {
|
let future = async move {
|
||||||
let stream = self.connector.connect(&self.domain, stream).await?;
|
let stream = self.connector.connect(&self.domain, stream).await?;
|
||||||
|
|
||||||
// FIXME https://github.com/tokio-rs/tokio/issues/1383
|
Ok(TlsStream(stream))
|
||||||
let channel_binding = ChannelBinding::none();
|
|
||||||
|
|
||||||
Ok((stream, channel_binding))
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Box::pin(future)
|
Box::pin(future)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The stream returned by `TlsConnector`.
|
||||||
|
pub struct TlsStream<S>(tokio_tls::TlsStream<S>);
|
||||||
|
|
||||||
|
impl<S> AsyncRead for TlsStream<S>
|
||||||
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
|
||||||
|
self.0.prepare_uninitialized_buffer(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut [u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
Pin::new(&mut self.0).poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_read_buf<B: BufMut>(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut B,
|
||||||
|
) -> Poll<io::Result<usize>>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
Pin::new(&mut self.0).poll_read_buf(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> AsyncWrite for TlsStream<S>
|
||||||
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
fn poll_write(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
Pin::new(&mut self.0).poll_write(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.0).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.0).poll_shutdown(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_write_buf<B: Buf>(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut B,
|
||||||
|
) -> Poll<io::Result<usize>>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
Pin::new(&mut self.0).poll_write_buf(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> tls::TlsStream for TlsStream<S>
|
||||||
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
fn channel_binding(&self) -> ChannelBinding {
|
||||||
|
// FIXME https://github.com/tokio-rs/tokio/issues/1383
|
||||||
|
ChannelBinding::none()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
//! use openssl::ssl::{SslConnector, SslMethod};
|
//! use openssl::ssl::{SslConnector, SslMethod};
|
||||||
//! use postgres_openssl::MakeTlsConnector;
|
//! use postgres_openssl::MakeTlsConnector;
|
||||||
//!
|
//!
|
||||||
//! # fn main() -> Result<(), Box<std::error::Error>> {
|
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
//! let mut builder = SslConnector::builder(SslMethod::tls())?;
|
//! let mut builder = SslConnector::builder(SslMethod::tls())?;
|
||||||
//! builder.set_ca_file("database_cert.pem")?;
|
//! builder.set_ca_file("database_cert.pem")?;
|
||||||
//! let connector = MakeTlsConnector::new(builder.build());
|
//! let connector = MakeTlsConnector::new(builder.build());
|
||||||
@ -25,7 +25,7 @@
|
|||||||
//! use openssl::ssl::{SslConnector, SslMethod};
|
//! use openssl::ssl::{SslConnector, SslMethod};
|
||||||
//! use postgres_openssl::MakeTlsConnector;
|
//! use postgres_openssl::MakeTlsConnector;
|
||||||
//!
|
//!
|
||||||
//! # fn main() -> Result<(), Box<std::error::Error>> {
|
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
//! let mut builder = SslConnector::builder(SslMethod::tls())?;
|
//! let mut builder = SslConnector::builder(SslMethod::tls())?;
|
||||||
//! builder.set_ca_file("database_cert.pem")?;
|
//! builder.set_ca_file("database_cert.pem")?;
|
||||||
//! let connector = MakeTlsConnector::new(builder.build());
|
//! let connector = MakeTlsConnector::new(builder.build());
|
||||||
@ -42,6 +42,8 @@
|
|||||||
#![doc(html_root_url = "https://docs.rs/postgres-openssl/0.3")]
|
#![doc(html_root_url = "https://docs.rs/postgres-openssl/0.3")]
|
||||||
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
|
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
|
||||||
|
|
||||||
|
use futures::task::Context;
|
||||||
|
use futures::Poll;
|
||||||
#[cfg(feature = "runtime")]
|
#[cfg(feature = "runtime")]
|
||||||
use openssl::error::ErrorStack;
|
use openssl::error::ErrorStack;
|
||||||
use openssl::hash::MessageDigest;
|
use openssl::hash::MessageDigest;
|
||||||
@ -51,11 +53,13 @@ use openssl::ssl::SslConnector;
|
|||||||
use openssl::ssl::{ConnectConfiguration, SslRef};
|
use openssl::ssl::{ConnectConfiguration, SslRef};
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
|
use std::io;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
#[cfg(feature = "runtime")]
|
#[cfg(feature = "runtime")]
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio_io::{AsyncRead, AsyncWrite};
|
use tokio_io::{AsyncRead, AsyncWrite, Buf, BufMut};
|
||||||
use tokio_openssl::{HandshakeError, SslStream};
|
use tokio_openssl::{HandshakeError, SslStream};
|
||||||
|
use tokio_postgres::tls;
|
||||||
#[cfg(feature = "runtime")]
|
#[cfg(feature = "runtime")]
|
||||||
use tokio_postgres::tls::MakeTlsConnect;
|
use tokio_postgres::tls::MakeTlsConnect;
|
||||||
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
|
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
|
||||||
@ -99,7 +103,7 @@ impl<S> MakeTlsConnect<S> for MakeTlsConnector
|
|||||||
where
|
where
|
||||||
S: AsyncRead + AsyncWrite + Unpin + Debug + 'static + Sync + Send,
|
S: AsyncRead + AsyncWrite + Unpin + Debug + 'static + Sync + Send,
|
||||||
{
|
{
|
||||||
type Stream = SslStream<S>;
|
type Stream = TlsStream<S>;
|
||||||
type TlsConnect = TlsConnector;
|
type TlsConnect = TlsConnector;
|
||||||
type Error = ErrorStack;
|
type Error = ErrorStack;
|
||||||
|
|
||||||
@ -130,29 +134,96 @@ impl<S> TlsConnect<S> for TlsConnector
|
|||||||
where
|
where
|
||||||
S: AsyncRead + AsyncWrite + Unpin + Debug + 'static + Sync + Send,
|
S: AsyncRead + AsyncWrite + Unpin + Debug + 'static + Sync + Send,
|
||||||
{
|
{
|
||||||
type Stream = SslStream<S>;
|
type Stream = TlsStream<S>;
|
||||||
type Error = HandshakeError<S>;
|
type Error = HandshakeError<S>;
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
type Future = Pin<
|
type Future = Pin<Box<dyn Future<Output = Result<TlsStream<S>, HandshakeError<S>>> + Send>>;
|
||||||
Box<dyn Future<Output = Result<(SslStream<S>, ChannelBinding), HandshakeError<S>>> + Send>,
|
|
||||||
>;
|
|
||||||
|
|
||||||
fn connect(self, stream: S) -> Self::Future {
|
fn connect(self, stream: S) -> Self::Future {
|
||||||
let future = async move {
|
let future = async move {
|
||||||
let stream = tokio_openssl::connect(self.ssl, &self.domain, stream).await?;
|
let stream = tokio_openssl::connect(self.ssl, &self.domain, stream).await?;
|
||||||
|
Ok(TlsStream(stream))
|
||||||
let channel_binding = match tls_server_end_point(stream.ssl()) {
|
|
||||||
Some(buf) => ChannelBinding::tls_server_end_point(buf),
|
|
||||||
None => ChannelBinding::none(),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok((stream, channel_binding))
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Box::pin(future)
|
Box::pin(future)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The stream returned by `TlsConnector`.
|
||||||
|
pub struct TlsStream<S>(SslStream<S>);
|
||||||
|
|
||||||
|
impl<S> AsyncRead for TlsStream<S>
|
||||||
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
|
||||||
|
self.0.prepare_uninitialized_buffer(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut [u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
Pin::new(&mut self.0).poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_read_buf<B: BufMut>(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut B,
|
||||||
|
) -> Poll<io::Result<usize>>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
Pin::new(&mut self.0).poll_read_buf(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> AsyncWrite for TlsStream<S>
|
||||||
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
fn poll_write(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
Pin::new(&mut self.0).poll_write(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.0).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Pin::new(&mut self.0).poll_shutdown(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_write_buf<B: Buf>(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut B,
|
||||||
|
) -> Poll<io::Result<usize>>
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
Pin::new(&mut self.0).poll_write_buf(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> tls::TlsStream for TlsStream<S>
|
||||||
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
fn channel_binding(&self) -> ChannelBinding {
|
||||||
|
match tls_server_end_point(self.0.ssl()) {
|
||||||
|
Some(buf) => ChannelBinding::tls_server_end_point(buf),
|
||||||
|
None => ChannelBinding::none(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn tls_server_end_point(ssl: &SslRef) -> Option<Vec<u8>> {
|
fn tls_server_end_point(ssl: &SslRef) -> Option<Vec<u8>> {
|
||||||
let cert = ssl.peer_certificate()?;
|
let cert = ssl.peer_certificate()?;
|
||||||
let algo_nid = cert.signature_algorithm().object().nid();
|
let algo_nid = cert.signature_algorithm().object().nid();
|
||||||
|
@ -16,7 +16,7 @@ where
|
|||||||
S: AsyncRead + AsyncWrite + Unpin,
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
T: TlsConnect<S>,
|
T: TlsConnect<S>,
|
||||||
{
|
{
|
||||||
let (mut stream, _) = connect_tls::connect_tls(stream, mode, tls).await?;
|
let mut stream = connect_tls::connect_tls(stream, mode, tls).await?;
|
||||||
|
|
||||||
let mut buf = BytesMut::new();
|
let mut buf = BytesMut::new();
|
||||||
frontend::cancel_request(process_id, secret_key, &mut buf);
|
frontend::cancel_request(process_id, secret_key, &mut buf);
|
||||||
|
@ -2,7 +2,7 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCod
|
|||||||
use crate::config::{self, Config};
|
use crate::config::{self, Config};
|
||||||
use crate::connect_tls::connect_tls;
|
use crate::connect_tls::connect_tls;
|
||||||
use crate::maybe_tls_stream::MaybeTlsStream;
|
use crate::maybe_tls_stream::MaybeTlsStream;
|
||||||
use crate::tls::{ChannelBinding, TlsConnect};
|
use crate::tls::{TlsConnect, TlsStream};
|
||||||
use crate::{Client, Connection, Error};
|
use crate::{Client, Connection, Error};
|
||||||
use bytes::BytesMut;
|
use bytes::BytesMut;
|
||||||
use fallible_iterator::FallibleIterator;
|
use fallible_iterator::FallibleIterator;
|
||||||
@ -86,7 +86,7 @@ where
|
|||||||
S: AsyncRead + AsyncWrite + Unpin,
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
T: TlsConnect<S>,
|
T: TlsConnect<S>,
|
||||||
{
|
{
|
||||||
let (stream, channel_binding) = connect_tls(stream, config.ssl_mode, tls).await?;
|
let stream = connect_tls(stream, config.ssl_mode, tls).await?;
|
||||||
|
|
||||||
let mut stream = StartupStream {
|
let mut stream = StartupStream {
|
||||||
inner: Framed::new(stream, PostgresCodec),
|
inner: Framed::new(stream, PostgresCodec),
|
||||||
@ -94,7 +94,7 @@ where
|
|||||||
};
|
};
|
||||||
|
|
||||||
startup(&mut stream, config).await?;
|
startup(&mut stream, config).await?;
|
||||||
authenticate(&mut stream, channel_binding, config).await?;
|
authenticate(&mut stream, config).await?;
|
||||||
let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
|
let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
|
||||||
|
|
||||||
let (sender, receiver) = mpsc::unbounded();
|
let (sender, receiver) = mpsc::unbounded();
|
||||||
@ -132,14 +132,10 @@ where
|
|||||||
.map_err(Error::io)
|
.map_err(Error::io)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn authenticate<S, T>(
|
async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
|
||||||
stream: &mut StartupStream<S, T>,
|
|
||||||
channel_binding: ChannelBinding,
|
|
||||||
config: &Config,
|
|
||||||
) -> Result<(), Error>
|
|
||||||
where
|
where
|
||||||
S: AsyncRead + AsyncWrite + Unpin,
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
T: AsyncRead + AsyncWrite + Unpin,
|
T: TlsStream + Unpin,
|
||||||
{
|
{
|
||||||
match stream.try_next().await.map_err(Error::io)? {
|
match stream.try_next().await.map_err(Error::io)? {
|
||||||
Some(Message::AuthenticationOk) => {
|
Some(Message::AuthenticationOk) => {
|
||||||
@ -172,7 +168,7 @@ where
|
|||||||
authenticate_password(stream, output.as_bytes()).await?;
|
authenticate_password(stream, output.as_bytes()).await?;
|
||||||
}
|
}
|
||||||
Some(Message::AuthenticationSasl(body)) => {
|
Some(Message::AuthenticationSasl(body)) => {
|
||||||
authenticate_sasl(stream, body, channel_binding, config).await?;
|
authenticate_sasl(stream, body, config).await?;
|
||||||
}
|
}
|
||||||
Some(Message::AuthenticationKerberosV5)
|
Some(Message::AuthenticationKerberosV5)
|
||||||
| Some(Message::AuthenticationScmCredential)
|
| Some(Message::AuthenticationScmCredential)
|
||||||
@ -225,12 +221,11 @@ where
|
|||||||
async fn authenticate_sasl<S, T>(
|
async fn authenticate_sasl<S, T>(
|
||||||
stream: &mut StartupStream<S, T>,
|
stream: &mut StartupStream<S, T>,
|
||||||
body: AuthenticationSaslBody,
|
body: AuthenticationSaslBody,
|
||||||
channel_binding: ChannelBinding,
|
|
||||||
config: &Config,
|
config: &Config,
|
||||||
) -> Result<(), Error>
|
) -> Result<(), Error>
|
||||||
where
|
where
|
||||||
S: AsyncRead + AsyncWrite + Unpin,
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
T: AsyncRead + AsyncWrite + Unpin,
|
T: TlsStream + Unpin,
|
||||||
{
|
{
|
||||||
let password = config
|
let password = config
|
||||||
.password
|
.password
|
||||||
@ -248,7 +243,10 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let channel_binding = channel_binding
|
let channel_binding = stream
|
||||||
|
.inner
|
||||||
|
.get_ref()
|
||||||
|
.channel_binding()
|
||||||
.tls_server_end_point
|
.tls_server_end_point
|
||||||
.filter(|_| config.channel_binding != config::ChannelBinding::Disable)
|
.filter(|_| config.channel_binding != config::ChannelBinding::Disable)
|
||||||
.map(sasl::ChannelBinding::tls_server_end_point);
|
.map(sasl::ChannelBinding::tls_server_end_point);
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use crate::config::SslMode;
|
use crate::config::SslMode;
|
||||||
use crate::maybe_tls_stream::MaybeTlsStream;
|
use crate::maybe_tls_stream::MaybeTlsStream;
|
||||||
use crate::tls::private::ForcePrivateApi;
|
use crate::tls::private::ForcePrivateApi;
|
||||||
use crate::tls::{ChannelBinding, TlsConnect};
|
use crate::tls::TlsConnect;
|
||||||
use crate::Error;
|
use crate::Error;
|
||||||
use bytes::BytesMut;
|
use bytes::BytesMut;
|
||||||
use postgres_protocol::message::frontend;
|
use postgres_protocol::message::frontend;
|
||||||
@ -11,15 +11,15 @@ pub async fn connect_tls<S, T>(
|
|||||||
mut stream: S,
|
mut stream: S,
|
||||||
mode: SslMode,
|
mode: SslMode,
|
||||||
tls: T,
|
tls: T,
|
||||||
) -> Result<(MaybeTlsStream<S, T::Stream>, ChannelBinding), Error>
|
) -> Result<MaybeTlsStream<S, T::Stream>, Error>
|
||||||
where
|
where
|
||||||
S: AsyncRead + AsyncWrite + Unpin,
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
T: TlsConnect<S>,
|
T: TlsConnect<S>,
|
||||||
{
|
{
|
||||||
match mode {
|
match mode {
|
||||||
SslMode::Disable => return Ok((MaybeTlsStream::Raw(stream), ChannelBinding::none())),
|
SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
|
||||||
SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => {
|
SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => {
|
||||||
return Ok((MaybeTlsStream::Raw(stream), ChannelBinding::none()))
|
return Ok(MaybeTlsStream::Raw(stream))
|
||||||
}
|
}
|
||||||
SslMode::Prefer | SslMode::Require => {}
|
SslMode::Prefer | SslMode::Require => {}
|
||||||
SslMode::__NonExhaustive => unreachable!(),
|
SslMode::__NonExhaustive => unreachable!(),
|
||||||
@ -36,14 +36,14 @@ where
|
|||||||
if SslMode::Require == mode {
|
if SslMode::Require == mode {
|
||||||
return Err(Error::tls("server does not support TLS".into()));
|
return Err(Error::tls("server does not support TLS".into()));
|
||||||
} else {
|
} else {
|
||||||
return Ok((MaybeTlsStream::Raw(stream), ChannelBinding::none()));
|
return Ok(MaybeTlsStream::Raw(stream));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let (stream, channel_binding) = tls
|
let stream = tls
|
||||||
.connect(stream)
|
.connect(stream)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| Error::tls(e.into()))?;
|
.map_err(|e| Error::tls(e.into()))?;
|
||||||
|
|
||||||
Ok((MaybeTlsStream::Tls(stream), channel_binding))
|
Ok(MaybeTlsStream::Tls(stream))
|
||||||
}
|
}
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
use crate::tls::{ChannelBinding, TlsStream};
|
||||||
use bytes::{Buf, BufMut};
|
use bytes::{Buf, BufMut};
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
@ -93,3 +94,16 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<S, T> TlsStream for MaybeTlsStream<S, T>
|
||||||
|
where
|
||||||
|
S: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
T: TlsStream + Unpin,
|
||||||
|
{
|
||||||
|
fn channel_binding(&self) -> ChannelBinding {
|
||||||
|
match self {
|
||||||
|
MaybeTlsStream::Raw(_) => ChannelBinding::none(),
|
||||||
|
MaybeTlsStream::Tls(s) => s.channel_binding(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -38,7 +38,7 @@ impl ChannelBinding {
|
|||||||
#[cfg(feature = "runtime")]
|
#[cfg(feature = "runtime")]
|
||||||
pub trait MakeTlsConnect<S> {
|
pub trait MakeTlsConnect<S> {
|
||||||
/// The stream type created by the `TlsConnect` implementation.
|
/// The stream type created by the `TlsConnect` implementation.
|
||||||
type Stream: AsyncRead + AsyncWrite + Unpin;
|
type Stream: TlsStream + Unpin;
|
||||||
/// The `TlsConnect` implementation created by this type.
|
/// The `TlsConnect` implementation created by this type.
|
||||||
type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
|
type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
|
||||||
/// The error type returned by the `TlsConnect` implementation.
|
/// The error type returned by the `TlsConnect` implementation.
|
||||||
@ -53,11 +53,11 @@ pub trait MakeTlsConnect<S> {
|
|||||||
/// An asynchronous function wrapping a stream in a TLS session.
|
/// An asynchronous function wrapping a stream in a TLS session.
|
||||||
pub trait TlsConnect<S> {
|
pub trait TlsConnect<S> {
|
||||||
/// The stream returned by the future.
|
/// The stream returned by the future.
|
||||||
type Stream: AsyncRead + AsyncWrite + Unpin;
|
type Stream: TlsStream + Unpin;
|
||||||
/// The error returned by the future.
|
/// The error returned by the future.
|
||||||
type Error: Into<Box<dyn Error + Sync + Send>>;
|
type Error: Into<Box<dyn Error + Sync + Send>>;
|
||||||
/// The future returned by the connector.
|
/// The future returned by the connector.
|
||||||
type Future: Future<Output = Result<(Self::Stream, ChannelBinding), Self::Error>>;
|
type Future: Future<Output = Result<Self::Stream, Self::Error>>;
|
||||||
|
|
||||||
/// Returns a future performing a TLS handshake over the stream.
|
/// Returns a future performing a TLS handshake over the stream.
|
||||||
fn connect(self, stream: S) -> Self::Future;
|
fn connect(self, stream: S) -> Self::Future;
|
||||||
@ -68,6 +68,12 @@ pub trait TlsConnect<S> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A TLS-wrapped connection to a PostgreSQL database.
|
||||||
|
pub trait TlsStream: AsyncRead + AsyncWrite {
|
||||||
|
/// Returns channel binding information for the session.
|
||||||
|
fn channel_binding(&self) -> ChannelBinding;
|
||||||
|
}
|
||||||
|
|
||||||
/// A `MakeTlsConnect` and `TlsConnect` implementation which simply returns an error.
|
/// A `MakeTlsConnect` and `TlsConnect` implementation which simply returns an error.
|
||||||
///
|
///
|
||||||
/// This can be used when `sslmode` is `none` or `prefer`.
|
/// This can be used when `sslmode` is `none` or `prefer`.
|
||||||
@ -103,7 +109,7 @@ impl<S> TlsConnect<S> for NoTls {
|
|||||||
pub struct NoTlsFuture(());
|
pub struct NoTlsFuture(());
|
||||||
|
|
||||||
impl Future for NoTlsFuture {
|
impl Future for NoTlsFuture {
|
||||||
type Output = Result<(NoTlsStream, ChannelBinding), NoTlsError>;
|
type Output = Result<NoTlsStream, NoTlsError>;
|
||||||
|
|
||||||
fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
Poll::Ready(Err(NoTlsError(())))
|
Poll::Ready(Err(NoTlsError(())))
|
||||||
@ -139,6 +145,12 @@ impl AsyncWrite for NoTlsStream {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl TlsStream for NoTlsStream {
|
||||||
|
fn channel_binding(&self) -> ChannelBinding {
|
||||||
|
match *self {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// The error returned by `NoTls`.
|
/// The error returned by `NoTls`.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct NoTlsError(());
|
pub struct NoTlsError(());
|
||||||
|
@ -14,5 +14,5 @@ async fn test_uuid_params() {
|
|||||||
(None, "NULL"),
|
(None, "NULL"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user