Move the TLS mode into config

This commit is contained in:
Steven Fackler 2019-01-13 14:53:19 -08:00
parent dfc614bed1
commit 2d3b9bb1c6
18 changed files with 358 additions and 426 deletions

View File

@ -3,7 +3,7 @@ use std::io::{self, Read};
use tokio_postgres::types::{ToSql, Type};
use tokio_postgres::Error;
#[cfg(feature = "runtime")]
use tokio_postgres::{MakeTlsMode, Socket, TlsMode};
use tokio_postgres::{MakeTlsConnect, Socket, TlsConnect};
#[cfg(feature = "runtime")]
use crate::Config;
@ -15,10 +15,10 @@ impl Client {
#[cfg(feature = "runtime")]
pub fn connect<T>(params: &str, tls_mode: T) -> Result<Client, Error>
where
T: MakeTlsMode<Socket> + 'static + Send,
T::TlsMode: Send,
T: MakeTlsConnect<Socket> + 'static + Send,
T::TlsConnect: Send,
T::Stream: Send,
<T::TlsMode as TlsMode<Socket>>::Future: Send,
<T::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
params.parse::<Config>()?.connect(tls_mode)
}

View File

@ -4,7 +4,7 @@ use log::error;
use std::path::Path;
use std::str::FromStr;
use std::time::Duration;
use tokio_postgres::{Error, MakeTlsMode, Socket, TargetSessionAttrs, TlsMode};
use tokio_postgres::{Error, MakeTlsConnect, Socket, TargetSessionAttrs, TlsConnect};
use crate::{Client, RUNTIME};
@ -94,10 +94,10 @@ impl Config {
pub fn connect<T>(&self, tls_mode: T) -> Result<Client, Error>
where
T: MakeTlsMode<Socket> + 'static + Send,
T::TlsMode: Send,
T: MakeTlsConnect<Socket> + 'static + Send,
T::TlsConnect: Send,
T::Stream: Send,
<T::TlsMode as TlsMode<Socket>>::Future: Send,
<T::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
let connect = self.0.connect(tls_mode);
let (client, connection) = oneshot::spawn(connect, &RUNTIME.executor()).wait()?;

View File

@ -2,13 +2,13 @@ use futures::{Future, Stream};
use native_tls::{self, Certificate};
use tokio::net::TcpStream;
use tokio::runtime::current_thread::Runtime;
use tokio_postgres::{self, PreferTls, RequireTls, TlsMode};
use tokio_postgres::TlsConnect;
use crate::TlsConnector;
fn smoke_test<T>(s: &str, tls: T)
where
T: TlsMode<TcpStream>,
T: TlsConnect<TcpStream>,
T::Stream: 'static,
{
let mut runtime = Runtime::new().unwrap();
@ -44,8 +44,8 @@ fn require() {
.build()
.unwrap();
smoke_test(
"user=ssl_user dbname=postgres",
RequireTls(TlsConnector::with_connector(connector, "localhost")),
"user=ssl_user dbname=postgres sslmode=require",
TlsConnector::with_connector(connector, "localhost"),
);
}
@ -59,7 +59,7 @@ fn prefer() {
.unwrap();
smoke_test(
"user=ssl_user dbname=postgres",
PreferTls(TlsConnector::with_connector(connector, "localhost")),
TlsConnector::with_connector(connector, "localhost"),
);
}
@ -72,7 +72,7 @@ fn scram_user() {
.build()
.unwrap();
smoke_test(
"user=scram_user password=password dbname=postgres",
RequireTls(TlsConnector::with_connector(connector, "localhost")),
"user=scram_user password=password dbname=postgres sslmode=require",
TlsConnector::with_connector(connector, "localhost"),
);
}

View File

@ -2,13 +2,13 @@ use futures::{Future, Stream};
use openssl::ssl::{SslConnector, SslMethod};
use tokio::net::TcpStream;
use tokio::runtime::current_thread::Runtime;
use tokio_postgres::{self, PreferTls, RequireTls, TlsMode};
use tokio_postgres::TlsConnect;
use super::*;
fn smoke_test<T>(s: &str, tls: T)
where
T: TlsMode<TcpStream>,
T: TlsConnect<TcpStream>,
T::Stream: 'static,
{
let mut runtime = Runtime::new().unwrap();
@ -41,8 +41,8 @@ fn require() {
builder.set_ca_file("../test/server.crt").unwrap();
let ctx = builder.build();
smoke_test(
"user=ssl_user dbname=postgres",
RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
"user=ssl_user dbname=postgres sslmode=require",
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
);
}
@ -53,7 +53,7 @@ fn prefer() {
let ctx = builder.build();
smoke_test(
"user=ssl_user dbname=postgres",
PreferTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
);
}
@ -63,8 +63,8 @@ fn scram_user() {
builder.set_ca_file("../test/server.crt").unwrap();
let ctx = builder.build();
smoke_test(
"user=scram_user password=password dbname=postgres",
RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
"user=scram_user password=password dbname=postgres sslmode=require",
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
);
}
@ -78,8 +78,8 @@ fn runtime() {
let connector = MakeTlsConnector::new(builder.build());
let connect = tokio_postgres::connect(
"host=localhost port=5433 user=postgres",
RequireTls(connector),
"host=localhost port=5433 user=postgres sslmode=require",
connector,
);
let (mut client, connection) = runtime.block_on(connect).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));

View File

@ -49,7 +49,6 @@ postgres-protocol = { version = "0.3.0", path = "../postgres-protocol" }
state_machine_future = "0.1.7"
tokio-codec = "0.1"
tokio-io = "0.1"
void = "1.0"
tokio-tcp = { version = "0.1", optional = true }
futures-cpupool = { version = "0.1", optional = true }

View File

@ -19,8 +19,8 @@ use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto::ConnectFuture;
use crate::proto::ConnectRawFuture;
#[cfg(feature = "runtime")]
use crate::{Connect, MakeTlsMode, Socket};
use crate::{ConnectRaw, Error, TlsMode};
use crate::{Connect, MakeTlsConnect, Socket};
use crate::{ConnectRaw, Error, TlsConnect};
/// Properties required of a session.
#[cfg(feature = "runtime")]
@ -34,6 +34,17 @@ pub enum TargetSessionAttrs {
__NonExhaustive,
}
/// TLS configuration.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum SslMode {
/// Do not use TLS.
Disable,
/// Attempt to connect with TLS but allow sessions without.
Prefer,
/// Require the use of TLS.
Require,
}
#[cfg(feature = "runtime")]
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum Host {
@ -49,6 +60,7 @@ pub(crate) struct Inner {
pub(crate) dbname: Option<String>,
pub(crate) options: Option<String>,
pub(crate) application_name: Option<String>,
pub(crate) ssl_mode: SslMode,
#[cfg(feature = "runtime")]
pub(crate) host: Vec<Host>,
#[cfg(feature = "runtime")]
@ -79,6 +91,8 @@ pub(crate) struct Inner {
/// * `dbname` - The name of the database to connect to. Defaults to the username.
/// * `options` - Command line options used to configure the server.
/// * `application_name` - Sets the `application_name` parameter on the server.
/// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used
/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`.
/// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the
/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts
/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting
@ -152,6 +166,7 @@ impl Config {
dbname: None,
options: None,
application_name: None,
ssl_mode: SslMode::Prefer,
#[cfg(feature = "runtime")]
host: vec![],
#[cfg(feature = "runtime")]
@ -204,6 +219,14 @@ impl Config {
self
}
/// Sets the SSL configuration.
///
/// Defaults to `prefer`.
pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
Arc::make_mut(&mut self.0).ssl_mode = ssl_mode;
self
}
/// Adds a host to the configuration.
///
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
@ -320,6 +343,15 @@ impl Config {
"application_name" => {
self.application_name(&value);
}
"sslmode" => {
let mode = match value {
"disable" => SslMode::Disable,
"prefer" => SslMode::Prefer,
"require" => SslMode::Require,
_ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))),
};
self.ssl_mode(mode);
}
#[cfg(feature = "runtime")]
"host" => {
for host in value.split(',') {
@ -390,22 +422,22 @@ impl Config {
///
/// Requires the `runtime` Cargo feature (enabled by default).
#[cfg(feature = "runtime")]
pub fn connect<T>(&self, make_tls_mode: T) -> Connect<T>
pub fn connect<T>(&self, tls: T) -> Connect<T>
where
T: MakeTlsMode<Socket>,
T: MakeTlsConnect<Socket>,
{
Connect(ConnectFuture::new(make_tls_mode, Ok(self.clone())))
Connect(ConnectFuture::new(tls, Ok(self.clone())))
}
/// Connects to a PostgreSQL database over an arbitrary stream.
///
/// All of the settings other than `user`, `password`, `dbname`, `options`, and `application` name are ignored.
pub fn connect_raw<S, T>(&self, stream: S, tls_mode: T) -> ConnectRaw<S, T>
pub fn connect_raw<S, T>(&self, stream: S, tls: T) -> ConnectRaw<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
T: TlsConnect<S>,
{
ConnectRaw(ConnectRawFuture::new(stream, tls_mode, self.clone(), None))
ConnectRaw(ConnectRawFuture::new(stream, tls, self.clone(), None))
}
}

View File

@ -127,11 +127,11 @@ fn next_portal() -> String {
///
/// [`Config`]: ./Config.t.html
#[cfg(feature = "runtime")]
pub fn connect<T>(config: &str, tls_mode: T) -> Connect<T>
pub fn connect<T>(config: &str, tls: T) -> Connect<T>
where
T: MakeTlsMode<Socket>,
T: MakeTlsConnect<Socket>,
{
Connect(proto::ConnectFuture::new(tls_mode, config.parse()))
Connect(proto::ConnectFuture::new(tls, config.parse()))
}
/// An asynchronous PostgreSQL client.
@ -250,7 +250,7 @@ impl Client {
#[cfg(feature = "runtime")]
pub fn cancel_query<T>(&mut self, make_tls_mode: T) -> CancelQuery<T>
where
T: MakeTlsMode<Socket>,
T: MakeTlsConnect<Socket>,
{
CancelQuery(self.0.cancel_query(make_tls_mode))
}
@ -260,7 +260,7 @@ impl Client {
pub fn cancel_query_raw<S, T>(&mut self, stream: S, tls_mode: T) -> CancelQueryRaw<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
T: TlsConnect<S>,
{
CancelQueryRaw(self.0.cancel_query_raw(stream, tls_mode))
}
@ -291,11 +291,12 @@ impl Client {
/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
#[must_use = "futures do nothing unless polled"]
pub struct Connection<S>(proto::Connection<S>);
pub struct Connection<S, T>(proto::Connection<proto::MaybeTlsStream<S, T>>);
impl<S> Connection<S>
impl<S, T> Connection<S, T>
where
S: AsyncRead + AsyncWrite,
T: AsyncRead + AsyncWrite,
{
/// Returns the value of a runtime parameter for this connection.
pub fn parameter(&self, name: &str) -> Option<&str> {
@ -311,9 +312,10 @@ where
}
}
impl<S> Future for Connection<S>
impl<S, T> Future for Connection<S, T>
where
S: AsyncRead + AsyncWrite,
T: AsyncRead + AsyncWrite,
{
type Item = ();
type Error = Error;
@ -342,12 +344,12 @@ pub enum AsyncMessage {
pub struct CancelQueryRaw<S, T>(proto::CancelQueryRawFuture<S, T>)
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>;
T: TlsConnect<S>;
impl<S, T> Future for CancelQueryRaw<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
T: TlsConnect<S>,
{
type Item = ();
type Error = Error;
@ -361,12 +363,12 @@ where
#[must_use = "futures do nothing unless polled"]
pub struct CancelQuery<T>(proto::CancelQueryFuture<T>)
where
T: MakeTlsMode<Socket>;
T: MakeTlsConnect<Socket>;
#[cfg(feature = "runtime")]
impl<T> Future for CancelQuery<T>
where
T: MakeTlsMode<Socket>,
T: MakeTlsConnect<Socket>,
{
type Item = ();
type Error = Error;
@ -380,17 +382,17 @@ where
pub struct ConnectRaw<S, T>(proto::ConnectRawFuture<S, T>)
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>;
T: TlsConnect<S>;
impl<S, T> Future for ConnectRaw<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
T: TlsConnect<S>,
{
type Item = (Client, Connection<T::Stream>);
type Item = (Client, Connection<S, T::Stream>);
type Error = Error;
fn poll(&mut self) -> Poll<(Client, Connection<T::Stream>), Error> {
fn poll(&mut self) -> Poll<(Client, Connection<S, T::Stream>), Error> {
let (client, connection) = try_ready!(self.0.poll());
Ok(Async::Ready((Client(client), Connection(connection))))
@ -401,17 +403,17 @@ where
#[must_use = "futures do nothing unless polled"]
pub struct Connect<T>(proto::ConnectFuture<T>)
where
T: MakeTlsMode<Socket>;
T: MakeTlsConnect<Socket>;
#[cfg(feature = "runtime")]
impl<T> Future for Connect<T>
where
T: MakeTlsMode<Socket>,
T: MakeTlsConnect<Socket>,
{
type Item = (Client, Connection<T::Stream>);
type Item = (Client, Connection<Socket, T::Stream>);
type Error = Error;
fn poll(&mut self) -> Poll<(Client, Connection<T::Stream>), Error> {
fn poll(&mut self) -> Poll<(Client, Connection<Socket, T::Stream>), Error> {
let (client, connection) = try_ready!(self.0.poll());
Ok(Async::Ready((Client(client), Connection(connection))))

View File

@ -3,16 +3,16 @@ use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use std::io;
use crate::proto::{CancelQueryRawFuture, ConnectSocketFuture};
use crate::{Config, Error, Host, MakeTlsMode, Socket};
use crate::{Config, Error, Host, MakeTlsConnect, Socket, SslMode};
#[derive(StateMachineFuture)]
pub enum CancelQuery<T>
where
T: MakeTlsMode<Socket>,
T: MakeTlsConnect<Socket>,
{
#[state_machine_future(start, transitions(ConnectingSocket))]
Start {
make_tls_mode: T,
tls: T,
idx: Option<usize>,
config: Config,
process_id: i32,
@ -21,13 +21,14 @@ where
#[state_machine_future(transitions(Canceling))]
ConnectingSocket {
future: ConnectSocketFuture,
tls_mode: T::TlsMode,
mode: SslMode,
tls: T::TlsConnect,
process_id: i32,
secret_key: i32,
},
#[state_machine_future(transitions(Finished))]
Canceling {
future: CancelQueryRawFuture<Socket, T::TlsMode>,
future: CancelQueryRawFuture<Socket, T::TlsConnect>,
},
#[state_machine_future(ready)]
Finished(()),
@ -37,7 +38,7 @@ where
impl<T> PollCancelQuery<T> for CancelQuery<T>
where
T: MakeTlsMode<Socket>,
T: MakeTlsConnect<Socket>,
{
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
let mut state = state.take();
@ -52,14 +53,15 @@ where
#[cfg(unix)]
Host::Unix(_) => "",
};
let tls_mode = state
.make_tls_mode
.make_tls_mode(hostname)
let tls = state
.tls
.make_tls_connect(hostname)
.map_err(|e| Error::tls(e.into()))?;
transition!(ConnectingSocket {
mode: state.config.0.ssl_mode,
future: ConnectSocketFuture::new(state.config, idx),
tls_mode,
tls,
process_id: state.process_id,
secret_key: state.secret_key,
})
@ -74,7 +76,8 @@ where
transition!(Canceling {
future: CancelQueryRawFuture::new(
socket,
state.tls_mode,
state.mode,
state.tls,
state.process_id,
state.secret_key
),
@ -91,15 +94,15 @@ where
impl<T> CancelQueryFuture<T>
where
T: MakeTlsMode<Socket>,
T: MakeTlsConnect<Socket>,
{
pub fn new(
make_tls_mode: T,
tls: T,
idx: Option<usize>,
config: Config,
process_id: i32,
secret_key: i32,
) -> CancelQueryFuture<T> {
CancelQuery::start(make_tls_mode, idx, config, process_id, secret_key)
CancelQuery::start(tls, idx, config, process_id, secret_key)
}
}

View File

@ -5,14 +5,14 @@ use tokio_io::io::{self, Flush, WriteAll};
use tokio_io::{AsyncRead, AsyncWrite};
use crate::error::Error;
use crate::proto::TlsFuture;
use crate::TlsMode;
use crate::proto::{MaybeTlsStream, TlsFuture};
use crate::{SslMode, TlsConnect};
#[derive(StateMachineFuture)]
pub enum CancelQueryRaw<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
T: TlsConnect<S>,
{
#[state_machine_future(start, transitions(SendingCancel))]
Start {
@ -22,10 +22,12 @@ where
},
#[state_machine_future(transitions(FlushingCancel))]
SendingCancel {
future: WriteAll<T::Stream, Vec<u8>>,
future: WriteAll<MaybeTlsStream<S, T::Stream>, Vec<u8>>,
},
#[state_machine_future(transitions(Finished))]
FlushingCancel { future: Flush<T::Stream> },
FlushingCancel {
future: Flush<MaybeTlsStream<S, T::Stream>>,
},
#[state_machine_future(ready)]
Finished(()),
#[state_machine_future(error)]
@ -35,7 +37,7 @@ where
impl<S, T> PollCancelQueryRaw<S, T> for CancelQueryRaw<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
T: TlsConnect<S>,
{
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<S, T>>) -> Poll<AfterStart<S, T>, Error> {
let (stream, _) = try_ready!(state.future.poll());
@ -69,14 +71,15 @@ where
impl<S, T> CancelQueryRawFuture<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
T: TlsConnect<S>,
{
pub fn new(
stream: S,
tls_mode: T,
mode: SslMode,
tls: T,
process_id: i32,
secret_key: i32,
) -> CancelQueryRawFuture<S, T> {
CancelQueryRaw::start(TlsFuture::new(stream, tls_mode), process_id, secret_key)
CancelQueryRaw::start(TlsFuture::new(stream, mode, tls), process_id, secret_key)
}
}

View File

@ -25,9 +25,9 @@ use crate::proto::statement::Statement;
use crate::proto::CancelQueryFuture;
use crate::proto::CancelQueryRawFuture;
use crate::types::{IsNull, Oid, ToSql, Type};
use crate::{Config, Error, TlsMode};
use crate::{Config, Error, TlsConnect};
#[cfg(feature = "runtime")]
use crate::{MakeTlsMode, Socket};
use crate::{MakeTlsConnect, Socket};
pub struct PendingRequest(Result<(RequestMessages, IdleGuard), Error>);
@ -247,7 +247,7 @@ impl Client {
#[cfg(feature = "runtime")]
pub fn cancel_query<T>(&self, make_tls_mode: T) -> CancelQueryFuture<T>
where
T: MakeTlsMode<Socket>,
T: MakeTlsConnect<Socket>,
{
CancelQueryFuture::new(
make_tls_mode,
@ -258,12 +258,18 @@ impl Client {
)
}
pub fn cancel_query_raw<S, T>(&self, stream: S, tls_mode: T) -> CancelQueryRawFuture<S, T>
pub fn cancel_query_raw<S, T>(&self, stream: S, mode: T) -> CancelQueryRawFuture<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
T: TlsConnect<S>,
{
CancelQueryRawFuture::new(stream, tls_mode, self.0.process_id, self.0.secret_key)
CancelQueryRawFuture::new(
stream,
self.0.config.0.ssl_mode,
mode,
self.0.process_id,
self.0.secret_key,
)
}
fn close(&self, ty: u8, name: &str) {

View File

@ -1,35 +1,35 @@
use futures::{Async, Future, Poll};
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use crate::proto::{Client, ConnectOnceFuture, Connection};
use crate::{Config, Error, Host, MakeTlsMode, Socket};
use crate::proto::{Client, ConnectOnceFuture, Connection, MaybeTlsStream};
use crate::{Config, Error, Host, MakeTlsConnect, Socket};
#[derive(StateMachineFuture)]
pub enum Connect<T>
where
T: MakeTlsMode<Socket>,
T: MakeTlsConnect<Socket>,
{
#[state_machine_future(start, transitions(Connecting))]
Start {
make_tls_mode: T,
tls: T,
config: Result<Config, Error>,
},
#[state_machine_future(transitions(Finished))]
Connecting {
future: ConnectOnceFuture<T::TlsMode>,
future: ConnectOnceFuture<T::TlsConnect>,
idx: usize,
make_tls_mode: T,
tls: T,
config: Config,
},
#[state_machine_future(ready)]
Finished((Client, Connection<T::Stream>)),
Finished((Client, Connection<MaybeTlsStream<Socket, T::Stream>>)),
#[state_machine_future(error)]
Failed(Error),
}
impl<T> PollConnect<T> for Connect<T>
where
T: MakeTlsMode<Socket>,
T: MakeTlsConnect<Socket>,
{
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
let mut state = state.take();
@ -50,15 +50,15 @@ where
#[cfg(unix)]
Host::Unix(_) => "",
};
let tls_mode = state
.make_tls_mode
.make_tls_mode(hostname)
let tls = state
.tls
.make_tls_connect(hostname)
.map_err(|e| Error::tls(e.into()))?;
transition!(Connecting {
future: ConnectOnceFuture::new(0, tls_mode, config.clone()),
future: ConnectOnceFuture::new(0, tls, config.clone()),
idx: 0,
make_tls_mode: state.make_tls_mode,
tls: state.tls,
config,
})
}
@ -84,13 +84,12 @@ where
#[cfg(unix)]
Host::Unix(_) => "",
};
let tls_mode = state
.make_tls_mode
.make_tls_mode(hostname)
let tls = state
.tls
.make_tls_connect(hostname)
.map_err(|e| Error::tls(e.into()))?;
state.future =
ConnectOnceFuture::new(state.idx, tls_mode, state.config.clone());
state.future = ConnectOnceFuture::new(state.idx, tls, state.config.clone());
}
}
}
@ -99,9 +98,9 @@ where
impl<T> ConnectFuture<T>
where
T: MakeTlsMode<Socket>,
T: MakeTlsConnect<Socket>,
{
pub fn new(make_tls_mode: T, config: Result<Config, Error>) -> ConnectFuture<T> {
Connect::start(make_tls_mode, config)
pub fn new(tls: T, config: Result<Config, Error>) -> ConnectFuture<T> {
Connect::start(tls, config)
}
}

View File

@ -4,25 +4,23 @@ use futures::{try_ready, Async, Future, Poll, Stream};
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use std::io;
use crate::proto::{Client, ConnectRawFuture, ConnectSocketFuture, Connection, SimpleQueryStream};
use crate::{Config, Error, Socket, TargetSessionAttrs, TlsMode};
use crate::proto::{
Client, ConnectRawFuture, ConnectSocketFuture, Connection, MaybeTlsStream, SimpleQueryStream,
};
use crate::{Config, Error, Socket, TargetSessionAttrs, TlsConnect};
#[derive(StateMachineFuture)]
pub enum ConnectOnce<T>
where
T: TlsMode<Socket>,
T: TlsConnect<Socket>,
{
#[state_machine_future(start, transitions(ConnectingSocket))]
Start {
idx: usize,
tls_mode: T,
config: Config,
},
Start { idx: usize, tls: T, config: Config },
#[state_machine_future(transitions(ConnectingRaw))]
ConnectingSocket {
future: ConnectSocketFuture,
idx: usize,
tls_mode: T,
tls: T,
config: Config,
},
#[state_machine_future(transitions(CheckingSessionAttrs, Finished))]
@ -34,17 +32,17 @@ where
CheckingSessionAttrs {
stream: SimpleQueryStream,
client: Client,
connection: Connection<T::Stream>,
connection: Connection<MaybeTlsStream<Socket, T::Stream>>,
},
#[state_machine_future(ready)]
Finished((Client, Connection<T::Stream>)),
Finished((Client, Connection<MaybeTlsStream<Socket, T::Stream>>)),
#[state_machine_future(error)]
Failed(Error),
}
impl<T> PollConnectOnce<T> for ConnectOnce<T>
where
T: TlsMode<Socket>,
T: TlsConnect<Socket>,
{
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
let state = state.take();
@ -52,7 +50,7 @@ where
transition!(ConnectingSocket {
future: ConnectSocketFuture::new(state.config.clone(), state.idx),
idx: state.idx,
tls_mode: state.tls_mode,
tls: state.tls,
config: state.config,
})
}
@ -65,7 +63,7 @@ where
transition!(ConnectingRaw {
target_session_attrs: state.config.0.target_session_attrs,
future: ConnectRawFuture::new(socket, state.tls_mode, state.config, Some(state.idx)),
future: ConnectRawFuture::new(socket, state.tls, state.config, Some(state.idx)),
})
}
@ -111,9 +109,9 @@ where
impl<T> ConnectOnceFuture<T>
where
T: TlsMode<Socket>,
T: TlsConnect<Socket>,
{
pub fn new(idx: usize, tls_mode: T, config: Config) -> ConnectOnceFuture<T> {
ConnectOnce::start(idx, tls_mode, config)
pub fn new(idx: usize, tls: T, config: Config) -> ConnectOnceFuture<T> {
ConnectOnce::start(idx, tls, config)
}
}

View File

@ -11,14 +11,14 @@ use std::collections::HashMap;
use tokio_codec::Framed;
use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto::{Client, Connection, PostgresCodec, TlsFuture};
use crate::{ChannelBinding, Config, Error, TlsMode};
use crate::proto::{Client, Connection, MaybeTlsStream, PostgresCodec, TlsFuture};
use crate::{ChannelBinding, Config, Error, TlsConnect};
#[derive(StateMachineFuture)]
pub enum ConnectRaw<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
T: TlsConnect<S>,
{
#[state_machine_future(start, transitions(SendingStartup))]
Start {
@ -28,47 +28,47 @@ where
},
#[state_machine_future(transitions(ReadingAuth))]
SendingStartup {
future: sink::Send<Framed<T::Stream, PostgresCodec>>,
future: sink::Send<Framed<MaybeTlsStream<S, T::Stream>, PostgresCodec>>,
config: Config,
idx: Option<usize>,
channel_binding: ChannelBinding,
},
#[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))]
ReadingAuth {
stream: Framed<T::Stream, PostgresCodec>,
stream: Framed<MaybeTlsStream<S, T::Stream>, PostgresCodec>,
config: Config,
idx: Option<usize>,
channel_binding: ChannelBinding,
},
#[state_machine_future(transitions(ReadingAuthCompletion))]
SendingPassword {
future: sink::Send<Framed<T::Stream, PostgresCodec>>,
future: sink::Send<Framed<MaybeTlsStream<S, T::Stream>, PostgresCodec>>,
config: Config,
idx: Option<usize>,
},
#[state_machine_future(transitions(ReadingSasl))]
SendingSasl {
future: sink::Send<Framed<T::Stream, PostgresCodec>>,
future: sink::Send<Framed<MaybeTlsStream<S, T::Stream>, PostgresCodec>>,
scram: ScramSha256,
config: Config,
idx: Option<usize>,
},
#[state_machine_future(transitions(SendingSasl, ReadingAuthCompletion))]
ReadingSasl {
stream: Framed<T::Stream, PostgresCodec>,
stream: Framed<MaybeTlsStream<S, T::Stream>, PostgresCodec>,
scram: ScramSha256,
config: Config,
idx: Option<usize>,
},
#[state_machine_future(transitions(ReadingInfo))]
ReadingAuthCompletion {
stream: Framed<T::Stream, PostgresCodec>,
stream: Framed<MaybeTlsStream<S, T::Stream>, PostgresCodec>,
config: Config,
idx: Option<usize>,
},
#[state_machine_future(transitions(Finished))]
ReadingInfo {
stream: Framed<T::Stream, PostgresCodec>,
stream: Framed<MaybeTlsStream<S, T::Stream>, PostgresCodec>,
process_id: i32,
secret_key: i32,
parameters: HashMap<String, String>,
@ -76,7 +76,7 @@ where
idx: Option<usize>,
},
#[state_machine_future(ready)]
Finished((Client, Connection<T::Stream>)),
Finished((Client, Connection<MaybeTlsStream<S, T::Stream>>)),
#[state_machine_future(error)]
Failed(Error),
}
@ -84,7 +84,7 @@ where
impl<S, T> PollConnectRaw<S, T> for ConnectRaw<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
T: TlsConnect<S>,
{
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<S, T>>) -> Poll<AfterStart<S, T>, Error> {
let (stream, channel_binding) = try_ready!(state.future.poll());
@ -377,14 +377,9 @@ where
impl<S, T> ConnectRawFuture<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
T: TlsConnect<S>,
{
pub fn new(
stream: S,
tls_mode: T,
config: Config,
idx: Option<usize>,
) -> ConnectRawFuture<S, T> {
ConnectRaw::start(TlsFuture::new(stream, tls_mode), config, idx)
pub fn new(stream: S, tls: T, config: Config, idx: Option<usize>) -> ConnectRawFuture<S, T> {
ConnectRaw::start(TlsFuture::new(stream, config.0.ssl_mode, tls), config, idx)
}
}

View File

@ -0,0 +1,88 @@
use bytes::{Buf, BufMut};
use futures::Poll;
use std::io::{self, Read, Write};
use tokio_io::{AsyncRead, AsyncWrite};
pub enum MaybeTlsStream<T, U> {
Raw(T),
Tls(U),
}
impl<T, U> Read for MaybeTlsStream<T, U>
where
T: Read,
U: Read,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
MaybeTlsStream::Raw(s) => s.read(buf),
MaybeTlsStream::Tls(s) => s.read(buf),
}
}
}
impl<T, U> AsyncRead for MaybeTlsStream<T, U>
where
T: AsyncRead,
U: AsyncRead,
{
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
match self {
MaybeTlsStream::Raw(s) => s.prepare_uninitialized_buffer(buf),
MaybeTlsStream::Tls(s) => s.prepare_uninitialized_buffer(buf),
}
}
fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
where
B: BufMut,
{
match self {
MaybeTlsStream::Raw(s) => s.read_buf(buf),
MaybeTlsStream::Tls(s) => s.read_buf(buf),
}
}
}
impl<T, U> Write for MaybeTlsStream<T, U>
where
T: Write,
U: Write,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
MaybeTlsStream::Raw(s) => s.write(buf),
MaybeTlsStream::Tls(s) => s.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
MaybeTlsStream::Raw(s) => s.flush(),
MaybeTlsStream::Tls(s) => s.flush(),
}
}
}
impl<T, U> AsyncWrite for MaybeTlsStream<T, U>
where
T: AsyncWrite,
U: AsyncWrite,
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
match self {
MaybeTlsStream::Raw(s) => s.shutdown(),
MaybeTlsStream::Tls(s) => s.shutdown(),
}
}
fn write_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
where
B: Buf,
{
match self {
MaybeTlsStream::Raw(s) => s.write_buf(buf),
MaybeTlsStream::Tls(s) => s.write_buf(buf),
}
}
}

View File

@ -36,6 +36,7 @@ mod copy_in;
mod copy_out;
mod execute;
mod idle;
mod maybe_tls_stream;
mod portal;
mod prepare;
mod query;
@ -64,6 +65,7 @@ pub use crate::proto::connection::Connection;
pub use crate::proto::copy_in::CopyInFuture;
pub use crate::proto::copy_out::CopyOutStream;
pub use crate::proto::execute::ExecuteFuture;
pub use crate::proto::maybe_tls_stream::MaybeTlsStream;
pub use crate::proto::portal::Portal;
pub use crate::proto::prepare::PrepareFuture;
pub use crate::proto::query::QueryStream;

View File

@ -4,54 +4,65 @@ use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use tokio_io::io::{self, ReadExact, WriteAll};
use tokio_io::{AsyncRead, AsyncWrite};
use crate::{ChannelBinding, Error, TlsMode};
use crate::proto::MaybeTlsStream;
use crate::tls::private::ForcePrivateApi;
use crate::{ChannelBinding, Error, SslMode, TlsConnect};
#[derive(StateMachineFuture)]
pub enum Tls<S, T>
where
T: TlsMode<S>,
T: TlsConnect<S>,
S: AsyncRead + AsyncWrite,
{
#[state_machine_future(start, transitions(SendingTls, ConnectingTls))]
Start { stream: S, tls_mode: T },
#[state_machine_future(start, transitions(SendingTls, Ready))]
Start { stream: S, mode: SslMode, tls: T },
#[state_machine_future(transitions(ReadingTls))]
SendingTls {
future: WriteAll<S, Vec<u8>>,
tls_mode: T,
mode: SslMode,
tls: T,
},
#[state_machine_future(transitions(ConnectingTls))]
#[state_machine_future(transitions(ConnectingTls, Ready))]
ReadingTls {
future: ReadExact<S, [u8; 1]>,
tls_mode: T,
mode: SslMode,
tls: T,
},
#[state_machine_future(transitions(Ready))]
ConnectingTls { future: T::Future },
#[state_machine_future(ready)]
Ready((T::Stream, ChannelBinding)),
Ready((MaybeTlsStream<S, T::Stream>, ChannelBinding)),
#[state_machine_future(error)]
Failed(Error),
}
impl<S, T> PollTls<S, T> for Tls<S, T>
where
T: TlsMode<S>,
T: TlsConnect<S>,
S: AsyncRead + AsyncWrite,
{
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<S, T>>) -> Poll<AfterStart<S, T>, Error> {
let state = state.take();
if state.tls_mode.request_tls() {
let mut buf = vec![];
frontend::ssl_request(&mut buf);
match state.mode {
SslMode::Disable => transition!(Ready((
MaybeTlsStream::Raw(state.stream),
ChannelBinding::none()
))),
SslMode::Prefer if !state.tls.can_connect(ForcePrivateApi) => transition!(Ready((
MaybeTlsStream::Raw(state.stream),
ChannelBinding::none()
))),
SslMode::Prefer | SslMode::Require => {
let mut buf = vec![];
frontend::ssl_request(&mut buf);
transition!(SendingTls {
future: io::write_all(state.stream, buf),
tls_mode: state.tls_mode,
})
} else {
transition!(ConnectingTls {
future: state.tls_mode.handle_tls(false, state.stream),
})
transition!(SendingTls {
future: io::write_all(state.stream, buf),
mode: state.mode,
tls: state.tls,
})
}
}
}
@ -62,7 +73,8 @@ where
let state = state.take();
transition!(ReadingTls {
future: io::read_exact(stream, [0]),
tls_mode: state.tls_mode,
mode: state.mode,
tls: state.tls,
})
}
@ -72,26 +84,32 @@ where
let (stream, buf) = try_ready!(state.future.poll().map_err(Error::io));
let state = state.take();
let use_tls = buf[0] == b'S';
transition!(ConnectingTls {
future: state.tls_mode.handle_tls(use_tls, stream)
})
if buf[0] == b'S' {
transition!(ConnectingTls {
future: state.tls.connect(stream),
})
} else if state.mode == SslMode::Require {
Err(Error::tls("server does not support TLS".into()))
} else {
transition!(Ready((MaybeTlsStream::Raw(stream), ChannelBinding::none())))
}
}
fn poll_connecting_tls<'a>(
state: &'a mut RentToOwn<'a, ConnectingTls<S, T>>,
) -> Poll<AfterConnectingTls<S, T>, Error> {
let t = try_ready!(state.future.poll().map_err(|e| Error::tls(e.into())));
transition!(Ready(t))
let (stream, channel_binding) =
try_ready!(state.future.poll().map_err(|e| Error::tls(e.into())));
transition!(Ready((MaybeTlsStream::Tls(stream), channel_binding)))
}
}
impl<S, T> TlsFuture<S, T>
where
T: TlsMode<S>,
T: TlsConnect<S>,
S: AsyncRead + AsyncWrite,
{
pub fn new(stream: S, tls_mode: T) -> TlsFuture<S, T> {
Tls::start(stream, tls_mode)
pub fn new(stream: S, mode: SslMode, tls: T) -> TlsFuture<S, T> {
Tls::start(stream, mode, tls)
}
}

View File

@ -1,11 +1,13 @@
use bytes::{Buf, BufMut};
use futures::future::{self, FutureResult};
use futures::{try_ready, Async, Future, Poll};
use futures::{Future, Poll};
use std::error::Error;
use std::fmt;
use std::io::{self, Read, Write};
use tokio_io::{AsyncRead, AsyncWrite};
use void::Void;
pub(crate) mod private {
pub struct ForcePrivateApi;
}
pub struct ChannelBinding {
pub(crate) tls_server_end_point: Option<Vec<u8>>,
@ -25,25 +27,6 @@ impl ChannelBinding {
}
}
#[cfg(feature = "runtime")]
pub trait MakeTlsMode<S> {
type Stream: AsyncRead + AsyncWrite;
type TlsMode: TlsMode<S, Stream = Self::Stream>;
type Error: Into<Box<dyn Error + Sync + Send>>;
fn make_tls_mode(&mut self, domain: &str) -> Result<Self::TlsMode, Self::Error>;
}
pub trait TlsMode<S> {
type Stream: AsyncRead + AsyncWrite;
type Error: Into<Box<dyn Error + Sync + Send>>;
type Future: Future<Item = (Self::Stream, ChannelBinding), Error = Self::Error>;
fn request_tls(&self) -> bool;
fn handle_tls(self, use_tls: bool, stream: S) -> Self::Future;
}
#[cfg(feature = "runtime")]
pub trait MakeTlsConnect<S> {
type Stream: AsyncRead + AsyncWrite;
@ -59,271 +42,74 @@ pub trait TlsConnect<S> {
type Future: Future<Item = (Self::Stream, ChannelBinding), Error = Self::Error>;
fn connect(self, stream: S) -> Self::Future;
#[doc(hidden)]
fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
true
}
}
#[derive(Debug, Copy, Clone)]
pub struct NoTls;
#[cfg(feature = "runtime")]
impl<S> MakeTlsMode<S> for NoTls
where
S: AsyncRead + AsyncWrite,
{
type Stream = S;
type TlsMode = NoTls;
type Error = Void;
impl<S> MakeTlsConnect<S> for NoTls where {
type Stream = NoTlsStream;
type TlsConnect = NoTls;
type Error = NoTlsError;
fn make_tls_mode(&mut self, _: &str) -> Result<NoTls, Void> {
fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
Ok(NoTls)
}
}
impl<S> TlsMode<S> for NoTls
where
S: AsyncRead + AsyncWrite,
{
type Stream = S;
type Error = Void;
type Future = FutureResult<(S, ChannelBinding), Void>;
impl<S> TlsConnect<S> for NoTls {
type Stream = NoTlsStream;
type Error = NoTlsError;
type Future = FutureResult<(NoTlsStream, ChannelBinding), NoTlsError>;
fn request_tls(&self) -> bool {
fn connect(self, _: S) -> FutureResult<(NoTlsStream, ChannelBinding), NoTlsError> {
future::err(NoTlsError(()))
}
fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
false
}
}
fn handle_tls(self, use_tls: bool, stream: S) -> FutureResult<(S, ChannelBinding), Void> {
debug_assert!(!use_tls);
pub enum NoTlsStream {}
future::ok((stream, ChannelBinding::none()))
impl Read for NoTlsStream {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
match *self {}
}
}
#[derive(Debug, Copy, Clone)]
pub struct PreferTls<T>(pub T);
impl AsyncRead for NoTlsStream {}
#[cfg(feature = "runtime")]
impl<T, S> MakeTlsMode<S> for PreferTls<T>
where
T: MakeTlsConnect<S>,
S: AsyncRead + AsyncWrite,
{
type Stream = MaybeTlsStream<T::Stream, S>;
type TlsMode = PreferTls<T::TlsConnect>;
type Error = T::Error;
fn make_tls_mode(&mut self, domain: &str) -> Result<PreferTls<T::TlsConnect>, T::Error> {
self.0.make_tls_connect(domain).map(PreferTls)
}
}
impl<T, S> TlsMode<S> for PreferTls<T>
where
T: TlsConnect<S>,
S: AsyncRead + AsyncWrite,
{
type Stream = MaybeTlsStream<T::Stream, S>;
type Error = T::Error;
type Future = PreferTlsFuture<T::Future, S>;
fn request_tls(&self) -> bool {
true
}
fn handle_tls(self, use_tls: bool, stream: S) -> PreferTlsFuture<T::Future, S> {
let f = if use_tls {
PreferTlsFutureInner::Tls(self.0.connect(stream))
} else {
PreferTlsFutureInner::Raw(Some(stream))
};
PreferTlsFuture(f)
}
}
enum PreferTlsFutureInner<F, S> {
Tls(F),
Raw(Option<S>),
}
pub struct PreferTlsFuture<F, S>(PreferTlsFutureInner<F, S>);
impl<F, S, T> Future for PreferTlsFuture<F, S>
where
F: Future<Item = (T, ChannelBinding)>,
{
type Item = (MaybeTlsStream<T, S>, ChannelBinding);
type Error = F::Error;
fn poll(&mut self) -> Poll<(MaybeTlsStream<T, S>, ChannelBinding), F::Error> {
match &mut self.0 {
PreferTlsFutureInner::Tls(f) => {
let (stream, channel_binding) = try_ready!(f.poll());
Ok(Async::Ready((MaybeTlsStream::Tls(stream), channel_binding)))
}
PreferTlsFutureInner::Raw(s) => Ok(Async::Ready((
MaybeTlsStream::Raw(s.take().expect("future polled after completion")),
ChannelBinding::none(),
))),
}
}
}
pub enum MaybeTlsStream<T, U> {
Tls(T),
Raw(U),
}
impl<T, U> Read for MaybeTlsStream<T, U>
where
T: Read,
U: Read,
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
MaybeTlsStream::Tls(s) => s.read(buf),
MaybeTlsStream::Raw(s) => s.read(buf),
}
}
}
impl<T, U> AsyncRead for MaybeTlsStream<T, U>
where
T: AsyncRead,
U: AsyncRead,
{
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
match self {
MaybeTlsStream::Tls(s) => s.prepare_uninitialized_buffer(buf),
MaybeTlsStream::Raw(s) => s.prepare_uninitialized_buffer(buf),
}
}
fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
where
B: BufMut,
{
match self {
MaybeTlsStream::Tls(s) => s.read_buf(buf),
MaybeTlsStream::Raw(s) => s.read_buf(buf),
}
}
}
impl<T, U> Write for MaybeTlsStream<T, U>
where
T: Write,
U: Write,
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
MaybeTlsStream::Tls(s) => s.write(buf),
MaybeTlsStream::Raw(s) => s.write(buf),
}
impl Write for NoTlsStream {
fn write(&mut self, _: &[u8]) -> io::Result<usize> {
match *self {}
}
fn flush(&mut self) -> io::Result<()> {
match self {
MaybeTlsStream::Tls(s) => s.flush(),
MaybeTlsStream::Raw(s) => s.flush(),
}
match *self {}
}
}
impl<T, U> AsyncWrite for MaybeTlsStream<T, U>
where
T: AsyncWrite,
U: AsyncWrite,
{
impl AsyncWrite for NoTlsStream {
fn shutdown(&mut self) -> Poll<(), io::Error> {
match self {
MaybeTlsStream::Tls(s) => s.shutdown(),
MaybeTlsStream::Raw(s) => s.shutdown(),
}
}
fn write_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
where
B: Buf,
{
match self {
MaybeTlsStream::Tls(s) => s.write_buf(buf),
MaybeTlsStream::Raw(s) => s.write_buf(buf),
}
}
}
#[derive(Debug, Copy, Clone)]
pub struct RequireTls<T>(pub T);
#[cfg(feature = "runtime")]
impl<T, S> MakeTlsMode<S> for RequireTls<T>
where
T: MakeTlsConnect<S>,
{
type Stream = T::Stream;
type TlsMode = RequireTls<T::TlsConnect>;
type Error = T::Error;
fn make_tls_mode(&mut self, domain: &str) -> Result<RequireTls<T::TlsConnect>, T::Error> {
self.0.make_tls_connect(domain).map(RequireTls)
}
}
impl<T, S> TlsMode<S> for RequireTls<T>
where
T: TlsConnect<S>,
{
type Stream = T::Stream;
type Error = Box<dyn Error + Sync + Send>;
type Future = RequireTlsFuture<T::Future>;
fn request_tls(&self) -> bool {
true
}
fn handle_tls(self, use_tls: bool, stream: S) -> RequireTlsFuture<T::Future> {
let f = if use_tls {
Ok(self.0.connect(stream))
} else {
Err(TlsUnsupportedError(()).into())
};
RequireTlsFuture { f: Some(f) }
match *self {}
}
}
#[derive(Debug)]
pub struct TlsUnsupportedError(());
pub struct NoTlsError(());
impl fmt::Display for TlsUnsupportedError {
impl fmt::Display for NoTlsError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.write_str("TLS was required but not supported by the server")
fmt.write_str("no TLS implementation configured")
}
}
impl Error for TlsUnsupportedError {}
pub struct RequireTlsFuture<T> {
f: Option<Result<T, Box<dyn Error + Sync + Send>>>,
}
impl<T> Future for RequireTlsFuture<T>
where
T: Future,
T::Error: Into<Box<dyn Error + Sync + Send>>,
{
type Item = T::Item;
type Error = Box<dyn Error + Sync + Send>;
fn poll(&mut self) -> Poll<T::Item, Box<dyn Error + Sync + Send>> {
match self.f.take().expect("future polled after completion") {
Ok(mut f) => match f.poll().map_err(Into::into)? {
Async::Ready(r) => Ok(Async::Ready(r)),
Async::NotReady => {
self.f = Some(Ok(f));
Ok(Async::NotReady)
}
},
Err(e) => Err(e),
}
}
}
impl Error for NoTlsError {}

View File

@ -12,7 +12,7 @@ use tokio::runtime::current_thread::Runtime;
use tokio::timer::Delay;
use tokio_postgres::error::SqlState;
use tokio_postgres::types::{Kind, Type};
use tokio_postgres::{AsyncMessage, Client, Connection, NoTls};
use tokio_postgres::{AsyncMessage, Client, Connection, NoTls, NoTlsStream};
mod parse;
#[cfg(feature = "runtime")]
@ -21,7 +21,8 @@ mod types;
fn connect(
s: &str,
) -> impl Future<Item = (Client, Connection<TcpStream>), Error = tokio_postgres::Error> {
) -> impl Future<Item = (Client, Connection<TcpStream, NoTlsStream>), Error = tokio_postgres::Error>
{
let builder = s.parse::<tokio_postgres::Config>().unwrap();
TcpStream::connect(&"127.0.0.1:5433".parse().unwrap())
.map_err(|e| panic!("{}", e))