Move the TLS mode into config
This commit is contained in:
parent
dfc614bed1
commit
2d3b9bb1c6
@ -3,7 +3,7 @@ use std::io::{self, Read};
|
||||
use tokio_postgres::types::{ToSql, Type};
|
||||
use tokio_postgres::Error;
|
||||
#[cfg(feature = "runtime")]
|
||||
use tokio_postgres::{MakeTlsMode, Socket, TlsMode};
|
||||
use tokio_postgres::{MakeTlsConnect, Socket, TlsConnect};
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::Config;
|
||||
@ -15,10 +15,10 @@ impl Client {
|
||||
#[cfg(feature = "runtime")]
|
||||
pub fn connect<T>(params: &str, tls_mode: T) -> Result<Client, Error>
|
||||
where
|
||||
T: MakeTlsMode<Socket> + 'static + Send,
|
||||
T::TlsMode: Send,
|
||||
T: MakeTlsConnect<Socket> + 'static + Send,
|
||||
T::TlsConnect: Send,
|
||||
T::Stream: Send,
|
||||
<T::TlsMode as TlsMode<Socket>>::Future: Send,
|
||||
<T::TlsConnect as TlsConnect<Socket>>::Future: Send,
|
||||
{
|
||||
params.parse::<Config>()?.connect(tls_mode)
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ use log::error;
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
use tokio_postgres::{Error, MakeTlsMode, Socket, TargetSessionAttrs, TlsMode};
|
||||
use tokio_postgres::{Error, MakeTlsConnect, Socket, TargetSessionAttrs, TlsConnect};
|
||||
|
||||
use crate::{Client, RUNTIME};
|
||||
|
||||
@ -94,10 +94,10 @@ impl Config {
|
||||
|
||||
pub fn connect<T>(&self, tls_mode: T) -> Result<Client, Error>
|
||||
where
|
||||
T: MakeTlsMode<Socket> + 'static + Send,
|
||||
T::TlsMode: Send,
|
||||
T: MakeTlsConnect<Socket> + 'static + Send,
|
||||
T::TlsConnect: Send,
|
||||
T::Stream: Send,
|
||||
<T::TlsMode as TlsMode<Socket>>::Future: Send,
|
||||
<T::TlsConnect as TlsConnect<Socket>>::Future: Send,
|
||||
{
|
||||
let connect = self.0.connect(tls_mode);
|
||||
let (client, connection) = oneshot::spawn(connect, &RUNTIME.executor()).wait()?;
|
||||
|
@ -2,13 +2,13 @@ use futures::{Future, Stream};
|
||||
use native_tls::{self, Certificate};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::runtime::current_thread::Runtime;
|
||||
use tokio_postgres::{self, PreferTls, RequireTls, TlsMode};
|
||||
use tokio_postgres::TlsConnect;
|
||||
|
||||
use crate::TlsConnector;
|
||||
|
||||
fn smoke_test<T>(s: &str, tls: T)
|
||||
where
|
||||
T: TlsMode<TcpStream>,
|
||||
T: TlsConnect<TcpStream>,
|
||||
T::Stream: 'static,
|
||||
{
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
@ -44,8 +44,8 @@ fn require() {
|
||||
.build()
|
||||
.unwrap();
|
||||
smoke_test(
|
||||
"user=ssl_user dbname=postgres",
|
||||
RequireTls(TlsConnector::with_connector(connector, "localhost")),
|
||||
"user=ssl_user dbname=postgres sslmode=require",
|
||||
TlsConnector::with_connector(connector, "localhost"),
|
||||
);
|
||||
}
|
||||
|
||||
@ -59,7 +59,7 @@ fn prefer() {
|
||||
.unwrap();
|
||||
smoke_test(
|
||||
"user=ssl_user dbname=postgres",
|
||||
PreferTls(TlsConnector::with_connector(connector, "localhost")),
|
||||
TlsConnector::with_connector(connector, "localhost"),
|
||||
);
|
||||
}
|
||||
|
||||
@ -72,7 +72,7 @@ fn scram_user() {
|
||||
.build()
|
||||
.unwrap();
|
||||
smoke_test(
|
||||
"user=scram_user password=password dbname=postgres",
|
||||
RequireTls(TlsConnector::with_connector(connector, "localhost")),
|
||||
"user=scram_user password=password dbname=postgres sslmode=require",
|
||||
TlsConnector::with_connector(connector, "localhost"),
|
||||
);
|
||||
}
|
||||
|
@ -2,13 +2,13 @@ use futures::{Future, Stream};
|
||||
use openssl::ssl::{SslConnector, SslMethod};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::runtime::current_thread::Runtime;
|
||||
use tokio_postgres::{self, PreferTls, RequireTls, TlsMode};
|
||||
use tokio_postgres::TlsConnect;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn smoke_test<T>(s: &str, tls: T)
|
||||
where
|
||||
T: TlsMode<TcpStream>,
|
||||
T: TlsConnect<TcpStream>,
|
||||
T::Stream: 'static,
|
||||
{
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
@ -41,8 +41,8 @@ fn require() {
|
||||
builder.set_ca_file("../test/server.crt").unwrap();
|
||||
let ctx = builder.build();
|
||||
smoke_test(
|
||||
"user=ssl_user dbname=postgres",
|
||||
RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
|
||||
"user=ssl_user dbname=postgres sslmode=require",
|
||||
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
|
||||
);
|
||||
}
|
||||
|
||||
@ -53,7 +53,7 @@ fn prefer() {
|
||||
let ctx = builder.build();
|
||||
smoke_test(
|
||||
"user=ssl_user dbname=postgres",
|
||||
PreferTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
|
||||
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
|
||||
);
|
||||
}
|
||||
|
||||
@ -63,8 +63,8 @@ fn scram_user() {
|
||||
builder.set_ca_file("../test/server.crt").unwrap();
|
||||
let ctx = builder.build();
|
||||
smoke_test(
|
||||
"user=scram_user password=password dbname=postgres",
|
||||
RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
|
||||
"user=scram_user password=password dbname=postgres sslmode=require",
|
||||
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
|
||||
);
|
||||
}
|
||||
|
||||
@ -78,8 +78,8 @@ fn runtime() {
|
||||
let connector = MakeTlsConnector::new(builder.build());
|
||||
|
||||
let connect = tokio_postgres::connect(
|
||||
"host=localhost port=5433 user=postgres",
|
||||
RequireTls(connector),
|
||||
"host=localhost port=5433 user=postgres sslmode=require",
|
||||
connector,
|
||||
);
|
||||
let (mut client, connection) = runtime.block_on(connect).unwrap();
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
|
@ -49,7 +49,6 @@ postgres-protocol = { version = "0.3.0", path = "../postgres-protocol" }
|
||||
state_machine_future = "0.1.7"
|
||||
tokio-codec = "0.1"
|
||||
tokio-io = "0.1"
|
||||
void = "1.0"
|
||||
|
||||
tokio-tcp = { version = "0.1", optional = true }
|
||||
futures-cpupool = { version = "0.1", optional = true }
|
||||
|
@ -19,8 +19,8 @@ use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use crate::proto::ConnectFuture;
|
||||
use crate::proto::ConnectRawFuture;
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::{Connect, MakeTlsMode, Socket};
|
||||
use crate::{ConnectRaw, Error, TlsMode};
|
||||
use crate::{Connect, MakeTlsConnect, Socket};
|
||||
use crate::{ConnectRaw, Error, TlsConnect};
|
||||
|
||||
/// Properties required of a session.
|
||||
#[cfg(feature = "runtime")]
|
||||
@ -34,6 +34,17 @@ pub enum TargetSessionAttrs {
|
||||
__NonExhaustive,
|
||||
}
|
||||
|
||||
/// TLS configuration.
|
||||
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||
pub enum SslMode {
|
||||
/// Do not use TLS.
|
||||
Disable,
|
||||
/// Attempt to connect with TLS but allow sessions without.
|
||||
Prefer,
|
||||
/// Require the use of TLS.
|
||||
Require,
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub(crate) enum Host {
|
||||
@ -49,6 +60,7 @@ pub(crate) struct Inner {
|
||||
pub(crate) dbname: Option<String>,
|
||||
pub(crate) options: Option<String>,
|
||||
pub(crate) application_name: Option<String>,
|
||||
pub(crate) ssl_mode: SslMode,
|
||||
#[cfg(feature = "runtime")]
|
||||
pub(crate) host: Vec<Host>,
|
||||
#[cfg(feature = "runtime")]
|
||||
@ -79,6 +91,8 @@ pub(crate) struct Inner {
|
||||
/// * `dbname` - The name of the database to connect to. Defaults to the username.
|
||||
/// * `options` - Command line options used to configure the server.
|
||||
/// * `application_name` - Sets the `application_name` parameter on the server.
|
||||
/// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used
|
||||
/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`.
|
||||
/// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the
|
||||
/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts
|
||||
/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting
|
||||
@ -152,6 +166,7 @@ impl Config {
|
||||
dbname: None,
|
||||
options: None,
|
||||
application_name: None,
|
||||
ssl_mode: SslMode::Prefer,
|
||||
#[cfg(feature = "runtime")]
|
||||
host: vec![],
|
||||
#[cfg(feature = "runtime")]
|
||||
@ -204,6 +219,14 @@ impl Config {
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the SSL configuration.
|
||||
///
|
||||
/// Defaults to `prefer`.
|
||||
pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
|
||||
Arc::make_mut(&mut self.0).ssl_mode = ssl_mode;
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds a host to the configuration.
|
||||
///
|
||||
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
|
||||
@ -320,6 +343,15 @@ impl Config {
|
||||
"application_name" => {
|
||||
self.application_name(&value);
|
||||
}
|
||||
"sslmode" => {
|
||||
let mode = match value {
|
||||
"disable" => SslMode::Disable,
|
||||
"prefer" => SslMode::Prefer,
|
||||
"require" => SslMode::Require,
|
||||
_ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))),
|
||||
};
|
||||
self.ssl_mode(mode);
|
||||
}
|
||||
#[cfg(feature = "runtime")]
|
||||
"host" => {
|
||||
for host in value.split(',') {
|
||||
@ -390,22 +422,22 @@ impl Config {
|
||||
///
|
||||
/// Requires the `runtime` Cargo feature (enabled by default).
|
||||
#[cfg(feature = "runtime")]
|
||||
pub fn connect<T>(&self, make_tls_mode: T) -> Connect<T>
|
||||
pub fn connect<T>(&self, tls: T) -> Connect<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
T: MakeTlsConnect<Socket>,
|
||||
{
|
||||
Connect(ConnectFuture::new(make_tls_mode, Ok(self.clone())))
|
||||
Connect(ConnectFuture::new(tls, Ok(self.clone())))
|
||||
}
|
||||
|
||||
/// Connects to a PostgreSQL database over an arbitrary stream.
|
||||
///
|
||||
/// All of the settings other than `user`, `password`, `dbname`, `options`, and `application` name are ignored.
|
||||
pub fn connect_raw<S, T>(&self, stream: S, tls_mode: T) -> ConnectRaw<S, T>
|
||||
pub fn connect_raw<S, T>(&self, stream: S, tls: T) -> ConnectRaw<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
T: TlsConnect<S>,
|
||||
{
|
||||
ConnectRaw(ConnectRawFuture::new(stream, tls_mode, self.clone(), None))
|
||||
ConnectRaw(ConnectRawFuture::new(stream, tls, self.clone(), None))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -127,11 +127,11 @@ fn next_portal() -> String {
|
||||
///
|
||||
/// [`Config`]: ./Config.t.html
|
||||
#[cfg(feature = "runtime")]
|
||||
pub fn connect<T>(config: &str, tls_mode: T) -> Connect<T>
|
||||
pub fn connect<T>(config: &str, tls: T) -> Connect<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
T: MakeTlsConnect<Socket>,
|
||||
{
|
||||
Connect(proto::ConnectFuture::new(tls_mode, config.parse()))
|
||||
Connect(proto::ConnectFuture::new(tls, config.parse()))
|
||||
}
|
||||
|
||||
/// An asynchronous PostgreSQL client.
|
||||
@ -250,7 +250,7 @@ impl Client {
|
||||
#[cfg(feature = "runtime")]
|
||||
pub fn cancel_query<T>(&mut self, make_tls_mode: T) -> CancelQuery<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
T: MakeTlsConnect<Socket>,
|
||||
{
|
||||
CancelQuery(self.0.cancel_query(make_tls_mode))
|
||||
}
|
||||
@ -260,7 +260,7 @@ impl Client {
|
||||
pub fn cancel_query_raw<S, T>(&mut self, stream: S, tls_mode: T) -> CancelQueryRaw<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
T: TlsConnect<S>,
|
||||
{
|
||||
CancelQueryRaw(self.0.cancel_query_raw(stream, tls_mode))
|
||||
}
|
||||
@ -291,11 +291,12 @@ impl Client {
|
||||
/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
|
||||
/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct Connection<S>(proto::Connection<S>);
|
||||
pub struct Connection<S, T>(proto::Connection<proto::MaybeTlsStream<S, T>>);
|
||||
|
||||
impl<S> Connection<S>
|
||||
impl<S, T> Connection<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: AsyncRead + AsyncWrite,
|
||||
{
|
||||
/// Returns the value of a runtime parameter for this connection.
|
||||
pub fn parameter(&self, name: &str) -> Option<&str> {
|
||||
@ -311,9 +312,10 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Future for Connection<S>
|
||||
impl<S, T> Future for Connection<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: AsyncRead + AsyncWrite,
|
||||
{
|
||||
type Item = ();
|
||||
type Error = Error;
|
||||
@ -342,12 +344,12 @@ pub enum AsyncMessage {
|
||||
pub struct CancelQueryRaw<S, T>(proto::CancelQueryRawFuture<S, T>)
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>;
|
||||
T: TlsConnect<S>;
|
||||
|
||||
impl<S, T> Future for CancelQueryRaw<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
T: TlsConnect<S>,
|
||||
{
|
||||
type Item = ();
|
||||
type Error = Error;
|
||||
@ -361,12 +363,12 @@ where
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct CancelQuery<T>(proto::CancelQueryFuture<T>)
|
||||
where
|
||||
T: MakeTlsMode<Socket>;
|
||||
T: MakeTlsConnect<Socket>;
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
impl<T> Future for CancelQuery<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
T: MakeTlsConnect<Socket>,
|
||||
{
|
||||
type Item = ();
|
||||
type Error = Error;
|
||||
@ -380,17 +382,17 @@ where
|
||||
pub struct ConnectRaw<S, T>(proto::ConnectRawFuture<S, T>)
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>;
|
||||
T: TlsConnect<S>;
|
||||
|
||||
impl<S, T> Future for ConnectRaw<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
T: TlsConnect<S>,
|
||||
{
|
||||
type Item = (Client, Connection<T::Stream>);
|
||||
type Item = (Client, Connection<S, T::Stream>);
|
||||
type Error = Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<(Client, Connection<T::Stream>), Error> {
|
||||
fn poll(&mut self) -> Poll<(Client, Connection<S, T::Stream>), Error> {
|
||||
let (client, connection) = try_ready!(self.0.poll());
|
||||
|
||||
Ok(Async::Ready((Client(client), Connection(connection))))
|
||||
@ -401,17 +403,17 @@ where
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct Connect<T>(proto::ConnectFuture<T>)
|
||||
where
|
||||
T: MakeTlsMode<Socket>;
|
||||
T: MakeTlsConnect<Socket>;
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
impl<T> Future for Connect<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
T: MakeTlsConnect<Socket>,
|
||||
{
|
||||
type Item = (Client, Connection<T::Stream>);
|
||||
type Item = (Client, Connection<Socket, T::Stream>);
|
||||
type Error = Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<(Client, Connection<T::Stream>), Error> {
|
||||
fn poll(&mut self) -> Poll<(Client, Connection<Socket, T::Stream>), Error> {
|
||||
let (client, connection) = try_ready!(self.0.poll());
|
||||
|
||||
Ok(Async::Ready((Client(client), Connection(connection))))
|
||||
|
@ -3,16 +3,16 @@ use state_machine_future::{transition, RentToOwn, StateMachineFuture};
|
||||
use std::io;
|
||||
|
||||
use crate::proto::{CancelQueryRawFuture, ConnectSocketFuture};
|
||||
use crate::{Config, Error, Host, MakeTlsMode, Socket};
|
||||
use crate::{Config, Error, Host, MakeTlsConnect, Socket, SslMode};
|
||||
|
||||
#[derive(StateMachineFuture)]
|
||||
pub enum CancelQuery<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
T: MakeTlsConnect<Socket>,
|
||||
{
|
||||
#[state_machine_future(start, transitions(ConnectingSocket))]
|
||||
Start {
|
||||
make_tls_mode: T,
|
||||
tls: T,
|
||||
idx: Option<usize>,
|
||||
config: Config,
|
||||
process_id: i32,
|
||||
@ -21,13 +21,14 @@ where
|
||||
#[state_machine_future(transitions(Canceling))]
|
||||
ConnectingSocket {
|
||||
future: ConnectSocketFuture,
|
||||
tls_mode: T::TlsMode,
|
||||
mode: SslMode,
|
||||
tls: T::TlsConnect,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
},
|
||||
#[state_machine_future(transitions(Finished))]
|
||||
Canceling {
|
||||
future: CancelQueryRawFuture<Socket, T::TlsMode>,
|
||||
future: CancelQueryRawFuture<Socket, T::TlsConnect>,
|
||||
},
|
||||
#[state_machine_future(ready)]
|
||||
Finished(()),
|
||||
@ -37,7 +38,7 @@ where
|
||||
|
||||
impl<T> PollCancelQuery<T> for CancelQuery<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
T: MakeTlsConnect<Socket>,
|
||||
{
|
||||
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
|
||||
let mut state = state.take();
|
||||
@ -52,14 +53,15 @@ where
|
||||
#[cfg(unix)]
|
||||
Host::Unix(_) => "",
|
||||
};
|
||||
let tls_mode = state
|
||||
.make_tls_mode
|
||||
.make_tls_mode(hostname)
|
||||
let tls = state
|
||||
.tls
|
||||
.make_tls_connect(hostname)
|
||||
.map_err(|e| Error::tls(e.into()))?;
|
||||
|
||||
transition!(ConnectingSocket {
|
||||
mode: state.config.0.ssl_mode,
|
||||
future: ConnectSocketFuture::new(state.config, idx),
|
||||
tls_mode,
|
||||
tls,
|
||||
process_id: state.process_id,
|
||||
secret_key: state.secret_key,
|
||||
})
|
||||
@ -74,7 +76,8 @@ where
|
||||
transition!(Canceling {
|
||||
future: CancelQueryRawFuture::new(
|
||||
socket,
|
||||
state.tls_mode,
|
||||
state.mode,
|
||||
state.tls,
|
||||
state.process_id,
|
||||
state.secret_key
|
||||
),
|
||||
@ -91,15 +94,15 @@ where
|
||||
|
||||
impl<T> CancelQueryFuture<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
T: MakeTlsConnect<Socket>,
|
||||
{
|
||||
pub fn new(
|
||||
make_tls_mode: T,
|
||||
tls: T,
|
||||
idx: Option<usize>,
|
||||
config: Config,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
) -> CancelQueryFuture<T> {
|
||||
CancelQuery::start(make_tls_mode, idx, config, process_id, secret_key)
|
||||
CancelQuery::start(tls, idx, config, process_id, secret_key)
|
||||
}
|
||||
}
|
||||
|
@ -5,14 +5,14 @@ use tokio_io::io::{self, Flush, WriteAll};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::proto::TlsFuture;
|
||||
use crate::TlsMode;
|
||||
use crate::proto::{MaybeTlsStream, TlsFuture};
|
||||
use crate::{SslMode, TlsConnect};
|
||||
|
||||
#[derive(StateMachineFuture)]
|
||||
pub enum CancelQueryRaw<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
T: TlsConnect<S>,
|
||||
{
|
||||
#[state_machine_future(start, transitions(SendingCancel))]
|
||||
Start {
|
||||
@ -22,10 +22,12 @@ where
|
||||
},
|
||||
#[state_machine_future(transitions(FlushingCancel))]
|
||||
SendingCancel {
|
||||
future: WriteAll<T::Stream, Vec<u8>>,
|
||||
future: WriteAll<MaybeTlsStream<S, T::Stream>, Vec<u8>>,
|
||||
},
|
||||
#[state_machine_future(transitions(Finished))]
|
||||
FlushingCancel { future: Flush<T::Stream> },
|
||||
FlushingCancel {
|
||||
future: Flush<MaybeTlsStream<S, T::Stream>>,
|
||||
},
|
||||
#[state_machine_future(ready)]
|
||||
Finished(()),
|
||||
#[state_machine_future(error)]
|
||||
@ -35,7 +37,7 @@ where
|
||||
impl<S, T> PollCancelQueryRaw<S, T> for CancelQueryRaw<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
T: TlsConnect<S>,
|
||||
{
|
||||
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<S, T>>) -> Poll<AfterStart<S, T>, Error> {
|
||||
let (stream, _) = try_ready!(state.future.poll());
|
||||
@ -69,14 +71,15 @@ where
|
||||
impl<S, T> CancelQueryRawFuture<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
T: TlsConnect<S>,
|
||||
{
|
||||
pub fn new(
|
||||
stream: S,
|
||||
tls_mode: T,
|
||||
mode: SslMode,
|
||||
tls: T,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
) -> CancelQueryRawFuture<S, T> {
|
||||
CancelQueryRaw::start(TlsFuture::new(stream, tls_mode), process_id, secret_key)
|
||||
CancelQueryRaw::start(TlsFuture::new(stream, mode, tls), process_id, secret_key)
|
||||
}
|
||||
}
|
||||
|
@ -25,9 +25,9 @@ use crate::proto::statement::Statement;
|
||||
use crate::proto::CancelQueryFuture;
|
||||
use crate::proto::CancelQueryRawFuture;
|
||||
use crate::types::{IsNull, Oid, ToSql, Type};
|
||||
use crate::{Config, Error, TlsMode};
|
||||
use crate::{Config, Error, TlsConnect};
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::{MakeTlsMode, Socket};
|
||||
use crate::{MakeTlsConnect, Socket};
|
||||
|
||||
pub struct PendingRequest(Result<(RequestMessages, IdleGuard), Error>);
|
||||
|
||||
@ -247,7 +247,7 @@ impl Client {
|
||||
#[cfg(feature = "runtime")]
|
||||
pub fn cancel_query<T>(&self, make_tls_mode: T) -> CancelQueryFuture<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
T: MakeTlsConnect<Socket>,
|
||||
{
|
||||
CancelQueryFuture::new(
|
||||
make_tls_mode,
|
||||
@ -258,12 +258,18 @@ impl Client {
|
||||
)
|
||||
}
|
||||
|
||||
pub fn cancel_query_raw<S, T>(&self, stream: S, tls_mode: T) -> CancelQueryRawFuture<S, T>
|
||||
pub fn cancel_query_raw<S, T>(&self, stream: S, mode: T) -> CancelQueryRawFuture<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
T: TlsConnect<S>,
|
||||
{
|
||||
CancelQueryRawFuture::new(stream, tls_mode, self.0.process_id, self.0.secret_key)
|
||||
CancelQueryRawFuture::new(
|
||||
stream,
|
||||
self.0.config.0.ssl_mode,
|
||||
mode,
|
||||
self.0.process_id,
|
||||
self.0.secret_key,
|
||||
)
|
||||
}
|
||||
|
||||
fn close(&self, ty: u8, name: &str) {
|
||||
|
@ -1,35 +1,35 @@
|
||||
use futures::{Async, Future, Poll};
|
||||
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
|
||||
|
||||
use crate::proto::{Client, ConnectOnceFuture, Connection};
|
||||
use crate::{Config, Error, Host, MakeTlsMode, Socket};
|
||||
use crate::proto::{Client, ConnectOnceFuture, Connection, MaybeTlsStream};
|
||||
use crate::{Config, Error, Host, MakeTlsConnect, Socket};
|
||||
|
||||
#[derive(StateMachineFuture)]
|
||||
pub enum Connect<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
T: MakeTlsConnect<Socket>,
|
||||
{
|
||||
#[state_machine_future(start, transitions(Connecting))]
|
||||
Start {
|
||||
make_tls_mode: T,
|
||||
tls: T,
|
||||
config: Result<Config, Error>,
|
||||
},
|
||||
#[state_machine_future(transitions(Finished))]
|
||||
Connecting {
|
||||
future: ConnectOnceFuture<T::TlsMode>,
|
||||
future: ConnectOnceFuture<T::TlsConnect>,
|
||||
idx: usize,
|
||||
make_tls_mode: T,
|
||||
tls: T,
|
||||
config: Config,
|
||||
},
|
||||
#[state_machine_future(ready)]
|
||||
Finished((Client, Connection<T::Stream>)),
|
||||
Finished((Client, Connection<MaybeTlsStream<Socket, T::Stream>>)),
|
||||
#[state_machine_future(error)]
|
||||
Failed(Error),
|
||||
}
|
||||
|
||||
impl<T> PollConnect<T> for Connect<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
T: MakeTlsConnect<Socket>,
|
||||
{
|
||||
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
|
||||
let mut state = state.take();
|
||||
@ -50,15 +50,15 @@ where
|
||||
#[cfg(unix)]
|
||||
Host::Unix(_) => "",
|
||||
};
|
||||
let tls_mode = state
|
||||
.make_tls_mode
|
||||
.make_tls_mode(hostname)
|
||||
let tls = state
|
||||
.tls
|
||||
.make_tls_connect(hostname)
|
||||
.map_err(|e| Error::tls(e.into()))?;
|
||||
|
||||
transition!(Connecting {
|
||||
future: ConnectOnceFuture::new(0, tls_mode, config.clone()),
|
||||
future: ConnectOnceFuture::new(0, tls, config.clone()),
|
||||
idx: 0,
|
||||
make_tls_mode: state.make_tls_mode,
|
||||
tls: state.tls,
|
||||
config,
|
||||
})
|
||||
}
|
||||
@ -84,13 +84,12 @@ where
|
||||
#[cfg(unix)]
|
||||
Host::Unix(_) => "",
|
||||
};
|
||||
let tls_mode = state
|
||||
.make_tls_mode
|
||||
.make_tls_mode(hostname)
|
||||
let tls = state
|
||||
.tls
|
||||
.make_tls_connect(hostname)
|
||||
.map_err(|e| Error::tls(e.into()))?;
|
||||
|
||||
state.future =
|
||||
ConnectOnceFuture::new(state.idx, tls_mode, state.config.clone());
|
||||
state.future = ConnectOnceFuture::new(state.idx, tls, state.config.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -99,9 +98,9 @@ where
|
||||
|
||||
impl<T> ConnectFuture<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
T: MakeTlsConnect<Socket>,
|
||||
{
|
||||
pub fn new(make_tls_mode: T, config: Result<Config, Error>) -> ConnectFuture<T> {
|
||||
Connect::start(make_tls_mode, config)
|
||||
pub fn new(tls: T, config: Result<Config, Error>) -> ConnectFuture<T> {
|
||||
Connect::start(tls, config)
|
||||
}
|
||||
}
|
||||
|
@ -4,25 +4,23 @@ use futures::{try_ready, Async, Future, Poll, Stream};
|
||||
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
|
||||
use std::io;
|
||||
|
||||
use crate::proto::{Client, ConnectRawFuture, ConnectSocketFuture, Connection, SimpleQueryStream};
|
||||
use crate::{Config, Error, Socket, TargetSessionAttrs, TlsMode};
|
||||
use crate::proto::{
|
||||
Client, ConnectRawFuture, ConnectSocketFuture, Connection, MaybeTlsStream, SimpleQueryStream,
|
||||
};
|
||||
use crate::{Config, Error, Socket, TargetSessionAttrs, TlsConnect};
|
||||
|
||||
#[derive(StateMachineFuture)]
|
||||
pub enum ConnectOnce<T>
|
||||
where
|
||||
T: TlsMode<Socket>,
|
||||
T: TlsConnect<Socket>,
|
||||
{
|
||||
#[state_machine_future(start, transitions(ConnectingSocket))]
|
||||
Start {
|
||||
idx: usize,
|
||||
tls_mode: T,
|
||||
config: Config,
|
||||
},
|
||||
Start { idx: usize, tls: T, config: Config },
|
||||
#[state_machine_future(transitions(ConnectingRaw))]
|
||||
ConnectingSocket {
|
||||
future: ConnectSocketFuture,
|
||||
idx: usize,
|
||||
tls_mode: T,
|
||||
tls: T,
|
||||
config: Config,
|
||||
},
|
||||
#[state_machine_future(transitions(CheckingSessionAttrs, Finished))]
|
||||
@ -34,17 +32,17 @@ where
|
||||
CheckingSessionAttrs {
|
||||
stream: SimpleQueryStream,
|
||||
client: Client,
|
||||
connection: Connection<T::Stream>,
|
||||
connection: Connection<MaybeTlsStream<Socket, T::Stream>>,
|
||||
},
|
||||
#[state_machine_future(ready)]
|
||||
Finished((Client, Connection<T::Stream>)),
|
||||
Finished((Client, Connection<MaybeTlsStream<Socket, T::Stream>>)),
|
||||
#[state_machine_future(error)]
|
||||
Failed(Error),
|
||||
}
|
||||
|
||||
impl<T> PollConnectOnce<T> for ConnectOnce<T>
|
||||
where
|
||||
T: TlsMode<Socket>,
|
||||
T: TlsConnect<Socket>,
|
||||
{
|
||||
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
|
||||
let state = state.take();
|
||||
@ -52,7 +50,7 @@ where
|
||||
transition!(ConnectingSocket {
|
||||
future: ConnectSocketFuture::new(state.config.clone(), state.idx),
|
||||
idx: state.idx,
|
||||
tls_mode: state.tls_mode,
|
||||
tls: state.tls,
|
||||
config: state.config,
|
||||
})
|
||||
}
|
||||
@ -65,7 +63,7 @@ where
|
||||
|
||||
transition!(ConnectingRaw {
|
||||
target_session_attrs: state.config.0.target_session_attrs,
|
||||
future: ConnectRawFuture::new(socket, state.tls_mode, state.config, Some(state.idx)),
|
||||
future: ConnectRawFuture::new(socket, state.tls, state.config, Some(state.idx)),
|
||||
})
|
||||
}
|
||||
|
||||
@ -111,9 +109,9 @@ where
|
||||
|
||||
impl<T> ConnectOnceFuture<T>
|
||||
where
|
||||
T: TlsMode<Socket>,
|
||||
T: TlsConnect<Socket>,
|
||||
{
|
||||
pub fn new(idx: usize, tls_mode: T, config: Config) -> ConnectOnceFuture<T> {
|
||||
ConnectOnce::start(idx, tls_mode, config)
|
||||
pub fn new(idx: usize, tls: T, config: Config) -> ConnectOnceFuture<T> {
|
||||
ConnectOnce::start(idx, tls, config)
|
||||
}
|
||||
}
|
||||
|
@ -11,14 +11,14 @@ use std::collections::HashMap;
|
||||
use tokio_codec::Framed;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::proto::{Client, Connection, PostgresCodec, TlsFuture};
|
||||
use crate::{ChannelBinding, Config, Error, TlsMode};
|
||||
use crate::proto::{Client, Connection, MaybeTlsStream, PostgresCodec, TlsFuture};
|
||||
use crate::{ChannelBinding, Config, Error, TlsConnect};
|
||||
|
||||
#[derive(StateMachineFuture)]
|
||||
pub enum ConnectRaw<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
T: TlsConnect<S>,
|
||||
{
|
||||
#[state_machine_future(start, transitions(SendingStartup))]
|
||||
Start {
|
||||
@ -28,47 +28,47 @@ where
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingAuth))]
|
||||
SendingStartup {
|
||||
future: sink::Send<Framed<T::Stream, PostgresCodec>>,
|
||||
future: sink::Send<Framed<MaybeTlsStream<S, T::Stream>, PostgresCodec>>,
|
||||
config: Config,
|
||||
idx: Option<usize>,
|
||||
channel_binding: ChannelBinding,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))]
|
||||
ReadingAuth {
|
||||
stream: Framed<T::Stream, PostgresCodec>,
|
||||
stream: Framed<MaybeTlsStream<S, T::Stream>, PostgresCodec>,
|
||||
config: Config,
|
||||
idx: Option<usize>,
|
||||
channel_binding: ChannelBinding,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingAuthCompletion))]
|
||||
SendingPassword {
|
||||
future: sink::Send<Framed<T::Stream, PostgresCodec>>,
|
||||
future: sink::Send<Framed<MaybeTlsStream<S, T::Stream>, PostgresCodec>>,
|
||||
config: Config,
|
||||
idx: Option<usize>,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingSasl))]
|
||||
SendingSasl {
|
||||
future: sink::Send<Framed<T::Stream, PostgresCodec>>,
|
||||
future: sink::Send<Framed<MaybeTlsStream<S, T::Stream>, PostgresCodec>>,
|
||||
scram: ScramSha256,
|
||||
config: Config,
|
||||
idx: Option<usize>,
|
||||
},
|
||||
#[state_machine_future(transitions(SendingSasl, ReadingAuthCompletion))]
|
||||
ReadingSasl {
|
||||
stream: Framed<T::Stream, PostgresCodec>,
|
||||
stream: Framed<MaybeTlsStream<S, T::Stream>, PostgresCodec>,
|
||||
scram: ScramSha256,
|
||||
config: Config,
|
||||
idx: Option<usize>,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingInfo))]
|
||||
ReadingAuthCompletion {
|
||||
stream: Framed<T::Stream, PostgresCodec>,
|
||||
stream: Framed<MaybeTlsStream<S, T::Stream>, PostgresCodec>,
|
||||
config: Config,
|
||||
idx: Option<usize>,
|
||||
},
|
||||
#[state_machine_future(transitions(Finished))]
|
||||
ReadingInfo {
|
||||
stream: Framed<T::Stream, PostgresCodec>,
|
||||
stream: Framed<MaybeTlsStream<S, T::Stream>, PostgresCodec>,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
parameters: HashMap<String, String>,
|
||||
@ -76,7 +76,7 @@ where
|
||||
idx: Option<usize>,
|
||||
},
|
||||
#[state_machine_future(ready)]
|
||||
Finished((Client, Connection<T::Stream>)),
|
||||
Finished((Client, Connection<MaybeTlsStream<S, T::Stream>>)),
|
||||
#[state_machine_future(error)]
|
||||
Failed(Error),
|
||||
}
|
||||
@ -84,7 +84,7 @@ where
|
||||
impl<S, T> PollConnectRaw<S, T> for ConnectRaw<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
T: TlsConnect<S>,
|
||||
{
|
||||
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<S, T>>) -> Poll<AfterStart<S, T>, Error> {
|
||||
let (stream, channel_binding) = try_ready!(state.future.poll());
|
||||
@ -377,14 +377,9 @@ where
|
||||
impl<S, T> ConnectRawFuture<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
T: TlsConnect<S>,
|
||||
{
|
||||
pub fn new(
|
||||
stream: S,
|
||||
tls_mode: T,
|
||||
config: Config,
|
||||
idx: Option<usize>,
|
||||
) -> ConnectRawFuture<S, T> {
|
||||
ConnectRaw::start(TlsFuture::new(stream, tls_mode), config, idx)
|
||||
pub fn new(stream: S, tls: T, config: Config, idx: Option<usize>) -> ConnectRawFuture<S, T> {
|
||||
ConnectRaw::start(TlsFuture::new(stream, config.0.ssl_mode, tls), config, idx)
|
||||
}
|
||||
}
|
||||
|
88
tokio-postgres/src/proto/maybe_tls_stream.rs
Normal file
88
tokio-postgres/src/proto/maybe_tls_stream.rs
Normal file
@ -0,0 +1,88 @@
|
||||
use bytes::{Buf, BufMut};
|
||||
use futures::Poll;
|
||||
use std::io::{self, Read, Write};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
pub enum MaybeTlsStream<T, U> {
|
||||
Raw(T),
|
||||
Tls(U),
|
||||
}
|
||||
|
||||
impl<T, U> Read for MaybeTlsStream<T, U>
|
||||
where
|
||||
T: Read,
|
||||
U: Read,
|
||||
{
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
MaybeTlsStream::Raw(s) => s.read(buf),
|
||||
MaybeTlsStream::Tls(s) => s.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> AsyncRead for MaybeTlsStream<T, U>
|
||||
where
|
||||
T: AsyncRead,
|
||||
U: AsyncRead,
|
||||
{
|
||||
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
|
||||
match self {
|
||||
MaybeTlsStream::Raw(s) => s.prepare_uninitialized_buffer(buf),
|
||||
MaybeTlsStream::Tls(s) => s.prepare_uninitialized_buffer(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
|
||||
where
|
||||
B: BufMut,
|
||||
{
|
||||
match self {
|
||||
MaybeTlsStream::Raw(s) => s.read_buf(buf),
|
||||
MaybeTlsStream::Tls(s) => s.read_buf(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> Write for MaybeTlsStream<T, U>
|
||||
where
|
||||
T: Write,
|
||||
U: Write,
|
||||
{
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
MaybeTlsStream::Raw(s) => s.write(buf),
|
||||
MaybeTlsStream::Tls(s) => s.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
MaybeTlsStream::Raw(s) => s.flush(),
|
||||
MaybeTlsStream::Tls(s) => s.flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> AsyncWrite for MaybeTlsStream<T, U>
|
||||
where
|
||||
T: AsyncWrite,
|
||||
U: AsyncWrite,
|
||||
{
|
||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
||||
match self {
|
||||
MaybeTlsStream::Raw(s) => s.shutdown(),
|
||||
MaybeTlsStream::Tls(s) => s.shutdown(),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
|
||||
where
|
||||
B: Buf,
|
||||
{
|
||||
match self {
|
||||
MaybeTlsStream::Raw(s) => s.write_buf(buf),
|
||||
MaybeTlsStream::Tls(s) => s.write_buf(buf),
|
||||
}
|
||||
}
|
||||
}
|
@ -36,6 +36,7 @@ mod copy_in;
|
||||
mod copy_out;
|
||||
mod execute;
|
||||
mod idle;
|
||||
mod maybe_tls_stream;
|
||||
mod portal;
|
||||
mod prepare;
|
||||
mod query;
|
||||
@ -64,6 +65,7 @@ pub use crate::proto::connection::Connection;
|
||||
pub use crate::proto::copy_in::CopyInFuture;
|
||||
pub use crate::proto::copy_out::CopyOutStream;
|
||||
pub use crate::proto::execute::ExecuteFuture;
|
||||
pub use crate::proto::maybe_tls_stream::MaybeTlsStream;
|
||||
pub use crate::proto::portal::Portal;
|
||||
pub use crate::proto::prepare::PrepareFuture;
|
||||
pub use crate::proto::query::QueryStream;
|
||||
|
@ -4,54 +4,65 @@ use state_machine_future::{transition, RentToOwn, StateMachineFuture};
|
||||
use tokio_io::io::{self, ReadExact, WriteAll};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::{ChannelBinding, Error, TlsMode};
|
||||
use crate::proto::MaybeTlsStream;
|
||||
use crate::tls::private::ForcePrivateApi;
|
||||
use crate::{ChannelBinding, Error, SslMode, TlsConnect};
|
||||
|
||||
#[derive(StateMachineFuture)]
|
||||
pub enum Tls<S, T>
|
||||
where
|
||||
T: TlsMode<S>,
|
||||
T: TlsConnect<S>,
|
||||
S: AsyncRead + AsyncWrite,
|
||||
{
|
||||
#[state_machine_future(start, transitions(SendingTls, ConnectingTls))]
|
||||
Start { stream: S, tls_mode: T },
|
||||
#[state_machine_future(start, transitions(SendingTls, Ready))]
|
||||
Start { stream: S, mode: SslMode, tls: T },
|
||||
#[state_machine_future(transitions(ReadingTls))]
|
||||
SendingTls {
|
||||
future: WriteAll<S, Vec<u8>>,
|
||||
tls_mode: T,
|
||||
mode: SslMode,
|
||||
tls: T,
|
||||
},
|
||||
#[state_machine_future(transitions(ConnectingTls))]
|
||||
#[state_machine_future(transitions(ConnectingTls, Ready))]
|
||||
ReadingTls {
|
||||
future: ReadExact<S, [u8; 1]>,
|
||||
tls_mode: T,
|
||||
mode: SslMode,
|
||||
tls: T,
|
||||
},
|
||||
#[state_machine_future(transitions(Ready))]
|
||||
ConnectingTls { future: T::Future },
|
||||
#[state_machine_future(ready)]
|
||||
Ready((T::Stream, ChannelBinding)),
|
||||
Ready((MaybeTlsStream<S, T::Stream>, ChannelBinding)),
|
||||
#[state_machine_future(error)]
|
||||
Failed(Error),
|
||||
}
|
||||
|
||||
impl<S, T> PollTls<S, T> for Tls<S, T>
|
||||
where
|
||||
T: TlsMode<S>,
|
||||
T: TlsConnect<S>,
|
||||
S: AsyncRead + AsyncWrite,
|
||||
{
|
||||
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<S, T>>) -> Poll<AfterStart<S, T>, Error> {
|
||||
let state = state.take();
|
||||
|
||||
if state.tls_mode.request_tls() {
|
||||
let mut buf = vec![];
|
||||
frontend::ssl_request(&mut buf);
|
||||
match state.mode {
|
||||
SslMode::Disable => transition!(Ready((
|
||||
MaybeTlsStream::Raw(state.stream),
|
||||
ChannelBinding::none()
|
||||
))),
|
||||
SslMode::Prefer if !state.tls.can_connect(ForcePrivateApi) => transition!(Ready((
|
||||
MaybeTlsStream::Raw(state.stream),
|
||||
ChannelBinding::none()
|
||||
))),
|
||||
SslMode::Prefer | SslMode::Require => {
|
||||
let mut buf = vec![];
|
||||
frontend::ssl_request(&mut buf);
|
||||
|
||||
transition!(SendingTls {
|
||||
future: io::write_all(state.stream, buf),
|
||||
tls_mode: state.tls_mode,
|
||||
})
|
||||
} else {
|
||||
transition!(ConnectingTls {
|
||||
future: state.tls_mode.handle_tls(false, state.stream),
|
||||
})
|
||||
transition!(SendingTls {
|
||||
future: io::write_all(state.stream, buf),
|
||||
mode: state.mode,
|
||||
tls: state.tls,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -62,7 +73,8 @@ where
|
||||
let state = state.take();
|
||||
transition!(ReadingTls {
|
||||
future: io::read_exact(stream, [0]),
|
||||
tls_mode: state.tls_mode,
|
||||
mode: state.mode,
|
||||
tls: state.tls,
|
||||
})
|
||||
}
|
||||
|
||||
@ -72,26 +84,32 @@ where
|
||||
let (stream, buf) = try_ready!(state.future.poll().map_err(Error::io));
|
||||
let state = state.take();
|
||||
|
||||
let use_tls = buf[0] == b'S';
|
||||
transition!(ConnectingTls {
|
||||
future: state.tls_mode.handle_tls(use_tls, stream)
|
||||
})
|
||||
if buf[0] == b'S' {
|
||||
transition!(ConnectingTls {
|
||||
future: state.tls.connect(stream),
|
||||
})
|
||||
} else if state.mode == SslMode::Require {
|
||||
Err(Error::tls("server does not support TLS".into()))
|
||||
} else {
|
||||
transition!(Ready((MaybeTlsStream::Raw(stream), ChannelBinding::none())))
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_connecting_tls<'a>(
|
||||
state: &'a mut RentToOwn<'a, ConnectingTls<S, T>>,
|
||||
) -> Poll<AfterConnectingTls<S, T>, Error> {
|
||||
let t = try_ready!(state.future.poll().map_err(|e| Error::tls(e.into())));
|
||||
transition!(Ready(t))
|
||||
let (stream, channel_binding) =
|
||||
try_ready!(state.future.poll().map_err(|e| Error::tls(e.into())));
|
||||
transition!(Ready((MaybeTlsStream::Tls(stream), channel_binding)))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T> TlsFuture<S, T>
|
||||
where
|
||||
T: TlsMode<S>,
|
||||
T: TlsConnect<S>,
|
||||
S: AsyncRead + AsyncWrite,
|
||||
{
|
||||
pub fn new(stream: S, tls_mode: T) -> TlsFuture<S, T> {
|
||||
Tls::start(stream, tls_mode)
|
||||
pub fn new(stream: S, mode: SslMode, tls: T) -> TlsFuture<S, T> {
|
||||
Tls::start(stream, mode, tls)
|
||||
}
|
||||
}
|
||||
|
@ -1,11 +1,13 @@
|
||||
use bytes::{Buf, BufMut};
|
||||
use futures::future::{self, FutureResult};
|
||||
use futures::{try_ready, Async, Future, Poll};
|
||||
use futures::{Future, Poll};
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::io::{self, Read, Write};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use void::Void;
|
||||
|
||||
pub(crate) mod private {
|
||||
pub struct ForcePrivateApi;
|
||||
}
|
||||
|
||||
pub struct ChannelBinding {
|
||||
pub(crate) tls_server_end_point: Option<Vec<u8>>,
|
||||
@ -25,25 +27,6 @@ impl ChannelBinding {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
pub trait MakeTlsMode<S> {
|
||||
type Stream: AsyncRead + AsyncWrite;
|
||||
type TlsMode: TlsMode<S, Stream = Self::Stream>;
|
||||
type Error: Into<Box<dyn Error + Sync + Send>>;
|
||||
|
||||
fn make_tls_mode(&mut self, domain: &str) -> Result<Self::TlsMode, Self::Error>;
|
||||
}
|
||||
|
||||
pub trait TlsMode<S> {
|
||||
type Stream: AsyncRead + AsyncWrite;
|
||||
type Error: Into<Box<dyn Error + Sync + Send>>;
|
||||
type Future: Future<Item = (Self::Stream, ChannelBinding), Error = Self::Error>;
|
||||
|
||||
fn request_tls(&self) -> bool;
|
||||
|
||||
fn handle_tls(self, use_tls: bool, stream: S) -> Self::Future;
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
pub trait MakeTlsConnect<S> {
|
||||
type Stream: AsyncRead + AsyncWrite;
|
||||
@ -59,271 +42,74 @@ pub trait TlsConnect<S> {
|
||||
type Future: Future<Item = (Self::Stream, ChannelBinding), Error = Self::Error>;
|
||||
|
||||
fn connect(self, stream: S) -> Self::Future;
|
||||
|
||||
#[doc(hidden)]
|
||||
fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct NoTls;
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
impl<S> MakeTlsMode<S> for NoTls
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
{
|
||||
type Stream = S;
|
||||
type TlsMode = NoTls;
|
||||
type Error = Void;
|
||||
impl<S> MakeTlsConnect<S> for NoTls where {
|
||||
type Stream = NoTlsStream;
|
||||
type TlsConnect = NoTls;
|
||||
type Error = NoTlsError;
|
||||
|
||||
fn make_tls_mode(&mut self, _: &str) -> Result<NoTls, Void> {
|
||||
fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
|
||||
Ok(NoTls)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> TlsMode<S> for NoTls
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
{
|
||||
type Stream = S;
|
||||
type Error = Void;
|
||||
type Future = FutureResult<(S, ChannelBinding), Void>;
|
||||
impl<S> TlsConnect<S> for NoTls {
|
||||
type Stream = NoTlsStream;
|
||||
type Error = NoTlsError;
|
||||
type Future = FutureResult<(NoTlsStream, ChannelBinding), NoTlsError>;
|
||||
|
||||
fn request_tls(&self) -> bool {
|
||||
fn connect(self, _: S) -> FutureResult<(NoTlsStream, ChannelBinding), NoTlsError> {
|
||||
future::err(NoTlsError(()))
|
||||
}
|
||||
|
||||
fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_tls(self, use_tls: bool, stream: S) -> FutureResult<(S, ChannelBinding), Void> {
|
||||
debug_assert!(!use_tls);
|
||||
pub enum NoTlsStream {}
|
||||
|
||||
future::ok((stream, ChannelBinding::none()))
|
||||
impl Read for NoTlsStream {
|
||||
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
|
||||
match *self {}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct PreferTls<T>(pub T);
|
||||
impl AsyncRead for NoTlsStream {}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
impl<T, S> MakeTlsMode<S> for PreferTls<T>
|
||||
where
|
||||
T: MakeTlsConnect<S>,
|
||||
S: AsyncRead + AsyncWrite,
|
||||
{
|
||||
type Stream = MaybeTlsStream<T::Stream, S>;
|
||||
type TlsMode = PreferTls<T::TlsConnect>;
|
||||
type Error = T::Error;
|
||||
|
||||
fn make_tls_mode(&mut self, domain: &str) -> Result<PreferTls<T::TlsConnect>, T::Error> {
|
||||
self.0.make_tls_connect(domain).map(PreferTls)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S> TlsMode<S> for PreferTls<T>
|
||||
where
|
||||
T: TlsConnect<S>,
|
||||
S: AsyncRead + AsyncWrite,
|
||||
{
|
||||
type Stream = MaybeTlsStream<T::Stream, S>;
|
||||
type Error = T::Error;
|
||||
type Future = PreferTlsFuture<T::Future, S>;
|
||||
|
||||
fn request_tls(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn handle_tls(self, use_tls: bool, stream: S) -> PreferTlsFuture<T::Future, S> {
|
||||
let f = if use_tls {
|
||||
PreferTlsFutureInner::Tls(self.0.connect(stream))
|
||||
} else {
|
||||
PreferTlsFutureInner::Raw(Some(stream))
|
||||
};
|
||||
|
||||
PreferTlsFuture(f)
|
||||
}
|
||||
}
|
||||
|
||||
enum PreferTlsFutureInner<F, S> {
|
||||
Tls(F),
|
||||
Raw(Option<S>),
|
||||
}
|
||||
|
||||
pub struct PreferTlsFuture<F, S>(PreferTlsFutureInner<F, S>);
|
||||
|
||||
impl<F, S, T> Future for PreferTlsFuture<F, S>
|
||||
where
|
||||
F: Future<Item = (T, ChannelBinding)>,
|
||||
{
|
||||
type Item = (MaybeTlsStream<T, S>, ChannelBinding);
|
||||
type Error = F::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<(MaybeTlsStream<T, S>, ChannelBinding), F::Error> {
|
||||
match &mut self.0 {
|
||||
PreferTlsFutureInner::Tls(f) => {
|
||||
let (stream, channel_binding) = try_ready!(f.poll());
|
||||
Ok(Async::Ready((MaybeTlsStream::Tls(stream), channel_binding)))
|
||||
}
|
||||
PreferTlsFutureInner::Raw(s) => Ok(Async::Ready((
|
||||
MaybeTlsStream::Raw(s.take().expect("future polled after completion")),
|
||||
ChannelBinding::none(),
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum MaybeTlsStream<T, U> {
|
||||
Tls(T),
|
||||
Raw(U),
|
||||
}
|
||||
|
||||
impl<T, U> Read for MaybeTlsStream<T, U>
|
||||
where
|
||||
T: Read,
|
||||
U: Read,
|
||||
{
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
MaybeTlsStream::Tls(s) => s.read(buf),
|
||||
MaybeTlsStream::Raw(s) => s.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> AsyncRead for MaybeTlsStream<T, U>
|
||||
where
|
||||
T: AsyncRead,
|
||||
U: AsyncRead,
|
||||
{
|
||||
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
|
||||
match self {
|
||||
MaybeTlsStream::Tls(s) => s.prepare_uninitialized_buffer(buf),
|
||||
MaybeTlsStream::Raw(s) => s.prepare_uninitialized_buffer(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
|
||||
where
|
||||
B: BufMut,
|
||||
{
|
||||
match self {
|
||||
MaybeTlsStream::Tls(s) => s.read_buf(buf),
|
||||
MaybeTlsStream::Raw(s) => s.read_buf(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> Write for MaybeTlsStream<T, U>
|
||||
where
|
||||
T: Write,
|
||||
U: Write,
|
||||
{
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
MaybeTlsStream::Tls(s) => s.write(buf),
|
||||
MaybeTlsStream::Raw(s) => s.write(buf),
|
||||
}
|
||||
impl Write for NoTlsStream {
|
||||
fn write(&mut self, _: &[u8]) -> io::Result<usize> {
|
||||
match *self {}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
MaybeTlsStream::Tls(s) => s.flush(),
|
||||
MaybeTlsStream::Raw(s) => s.flush(),
|
||||
}
|
||||
match *self {}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> AsyncWrite for MaybeTlsStream<T, U>
|
||||
where
|
||||
T: AsyncWrite,
|
||||
U: AsyncWrite,
|
||||
{
|
||||
impl AsyncWrite for NoTlsStream {
|
||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
||||
match self {
|
||||
MaybeTlsStream::Tls(s) => s.shutdown(),
|
||||
MaybeTlsStream::Raw(s) => s.shutdown(),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
|
||||
where
|
||||
B: Buf,
|
||||
{
|
||||
match self {
|
||||
MaybeTlsStream::Tls(s) => s.write_buf(buf),
|
||||
MaybeTlsStream::Raw(s) => s.write_buf(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct RequireTls<T>(pub T);
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
impl<T, S> MakeTlsMode<S> for RequireTls<T>
|
||||
where
|
||||
T: MakeTlsConnect<S>,
|
||||
{
|
||||
type Stream = T::Stream;
|
||||
type TlsMode = RequireTls<T::TlsConnect>;
|
||||
type Error = T::Error;
|
||||
|
||||
fn make_tls_mode(&mut self, domain: &str) -> Result<RequireTls<T::TlsConnect>, T::Error> {
|
||||
self.0.make_tls_connect(domain).map(RequireTls)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, S> TlsMode<S> for RequireTls<T>
|
||||
where
|
||||
T: TlsConnect<S>,
|
||||
{
|
||||
type Stream = T::Stream;
|
||||
type Error = Box<dyn Error + Sync + Send>;
|
||||
type Future = RequireTlsFuture<T::Future>;
|
||||
|
||||
fn request_tls(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn handle_tls(self, use_tls: bool, stream: S) -> RequireTlsFuture<T::Future> {
|
||||
let f = if use_tls {
|
||||
Ok(self.0.connect(stream))
|
||||
} else {
|
||||
Err(TlsUnsupportedError(()).into())
|
||||
};
|
||||
|
||||
RequireTlsFuture { f: Some(f) }
|
||||
match *self {}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TlsUnsupportedError(());
|
||||
pub struct NoTlsError(());
|
||||
|
||||
impl fmt::Display for TlsUnsupportedError {
|
||||
impl fmt::Display for NoTlsError {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
fmt.write_str("TLS was required but not supported by the server")
|
||||
fmt.write_str("no TLS implementation configured")
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for TlsUnsupportedError {}
|
||||
|
||||
pub struct RequireTlsFuture<T> {
|
||||
f: Option<Result<T, Box<dyn Error + Sync + Send>>>,
|
||||
}
|
||||
|
||||
impl<T> Future for RequireTlsFuture<T>
|
||||
where
|
||||
T: Future,
|
||||
T::Error: Into<Box<dyn Error + Sync + Send>>,
|
||||
{
|
||||
type Item = T::Item;
|
||||
type Error = Box<dyn Error + Sync + Send>;
|
||||
|
||||
fn poll(&mut self) -> Poll<T::Item, Box<dyn Error + Sync + Send>> {
|
||||
match self.f.take().expect("future polled after completion") {
|
||||
Ok(mut f) => match f.poll().map_err(Into::into)? {
|
||||
Async::Ready(r) => Ok(Async::Ready(r)),
|
||||
Async::NotReady => {
|
||||
self.f = Some(Ok(f));
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
},
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl Error for NoTlsError {}
|
||||
|
@ -12,7 +12,7 @@ use tokio::runtime::current_thread::Runtime;
|
||||
use tokio::timer::Delay;
|
||||
use tokio_postgres::error::SqlState;
|
||||
use tokio_postgres::types::{Kind, Type};
|
||||
use tokio_postgres::{AsyncMessage, Client, Connection, NoTls};
|
||||
use tokio_postgres::{AsyncMessage, Client, Connection, NoTls, NoTlsStream};
|
||||
|
||||
mod parse;
|
||||
#[cfg(feature = "runtime")]
|
||||
@ -21,7 +21,8 @@ mod types;
|
||||
|
||||
fn connect(
|
||||
s: &str,
|
||||
) -> impl Future<Item = (Client, Connection<TcpStream>), Error = tokio_postgres::Error> {
|
||||
) -> impl Future<Item = (Client, Connection<TcpStream, NoTlsStream>), Error = tokio_postgres::Error>
|
||||
{
|
||||
let builder = s.parse::<tokio_postgres::Config>().unwrap();
|
||||
TcpStream::connect(&"127.0.0.1:5433".parse().unwrap())
|
||||
.map_err(|e| panic!("{}", e))
|
||||
|
Loading…
Reference in New Issue
Block a user