Add channel_binding=disable/prefer/require to config

Closes #487
This commit is contained in:
Steven Fackler 2019-09-24 17:03:37 -07:00
parent a8d945c70e
commit 6c3a4ab192
7 changed files with 125 additions and 110 deletions

View File

@ -12,9 +12,7 @@ where
T: TlsConnect<TcpStream>,
T::Stream: 'static + Send,
{
let stream = TcpStream::connect("127.0.0.1:5433")
.await
.unwrap();
let stream = TcpStream::connect("127.0.0.1:5433").await.unwrap();
let builder = s.parse::<tokio_postgres::Config>().unwrap();
let (mut client, connection) = builder.connect_raw(stream, tls).await.unwrap();

View File

@ -65,6 +65,32 @@ async fn scram_user() {
.await;
}
#[tokio::test]
async fn require_channel_binding_err() {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_ca_file("../test/server.crt").unwrap();
let ctx = builder.build();
let connector = TlsConnector::new(ctx.configure().unwrap(), "localhost");
let stream = TcpStream::connect("127.0.0.1:5433").await.unwrap();
let builder = "user=pass_user password=password dbname=postgres channel_binding=require"
.parse::<tokio_postgres::Config>()
.unwrap();
builder.connect_raw(stream, connector).await.err().unwrap();
}
#[tokio::test]
async fn require_channel_binding_ok() {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_ca_file("../test/server.crt").unwrap();
let ctx = builder.build();
smoke_test(
"user=scram_user password=password dbname=postgres channel_binding=require",
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
)
.await;
}
#[tokio::test]
#[cfg(feature = "runtime")]
async fn runtime() {

View File

@ -14,7 +14,7 @@ use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
use tokio_postgres::{Error, Socket};
#[doc(inline)]
pub use tokio_postgres::config::{SslMode, TargetSessionAttrs};
pub use tokio_postgres::config::{SslMode, TargetSessionAttrs, ChannelBinding};
use crate::{Client, RUNTIME};
@ -234,6 +234,14 @@ impl Config {
self
}
/// Sets the channel binding behavior.
///
/// Defaults to `prefer`.
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
self.config.channel_binding(channel_binding);
self
}
/// Sets the executor used to run the connection futures.
///
/// Defaults to a postgres-specific tokio `Runtime`.

View File

@ -46,6 +46,19 @@ pub enum SslMode {
__NonExhaustive,
}
/// Channel binding configuration.
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum ChannelBinding {
/// Do not use channel binding.
Disable,
/// Attempt to use channel binding but allow sessions without.
Prefer,
/// Require the use of channel binding.
Require,
#[doc(hidden)]
__NonExhaustive,
}
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum Host {
Tcp(String),
@ -87,6 +100,9 @@ pub(crate) enum Host {
/// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that
/// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server
/// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`.
/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel
/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise.
/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`.
///
/// ## Examples
///
@ -140,6 +156,7 @@ pub struct Config {
pub(crate) keepalives: bool,
pub(crate) keepalives_idle: Duration,
pub(crate) target_session_attrs: TargetSessionAttrs,
pub(crate) channel_binding: ChannelBinding,
}
impl Default for Config {
@ -164,6 +181,7 @@ impl Config {
keepalives: true,
keepalives_idle: Duration::from_secs(2 * 60 * 60),
target_session_attrs: TargetSessionAttrs::Any,
channel_binding: ChannelBinding::Prefer,
}
}
@ -287,6 +305,14 @@ impl Config {
self
}
/// Sets the channel binding behavior.
///
/// Defaults to `prefer`.
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
self.channel_binding = channel_binding;
self
}
fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
match key {
"user" => {
@ -363,6 +389,19 @@ impl Config {
};
self.target_session_attrs(target_session_attrs);
}
"channel_binding" => {
let channel_binding = match value {
"disable" => ChannelBinding::Disable,
"prefer" => ChannelBinding::Prefer,
"require" => ChannelBinding::Require,
_ => {
return Err(Error::config_parse(Box::new(InvalidValue(
"channel_binding",
))))
}
};
self.channel_binding(channel_binding);
}
key => {
return Err(Error::config_parse(Box::new(UnknownOption(
key.to_string(),
@ -434,6 +473,7 @@ impl fmt::Debug for Config {
.field("keepalives", &self.keepalives)
.field("keepalives_idle", &self.keepalives_idle)
.field("target_session_attrs", &self.target_session_attrs)
.field("channel_binding", &self.channel_binding)
.finish()
}
}

View File

@ -1,5 +1,5 @@
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::config::Config;
use crate::config::{self, Config};
use crate::connect_tls::connect_tls;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::{ChannelBinding, TlsConnect};
@ -141,8 +141,13 @@ where
T: AsyncRead + AsyncWrite + Unpin,
{
match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationOk) => return Ok(()),
Some(Message::AuthenticationOk) => {
no_channel_binding(config)?;
return Ok(());
}
Some(Message::AuthenticationCleartextPassword) => {
no_channel_binding(config)?;
let pass = config
.password
.as_ref()
@ -151,6 +156,8 @@ where
authenticate_password(stream, pass).await?;
}
Some(Message::AuthenticationMd5Password(body)) => {
no_channel_binding(config)?;
let user = config
.user
.as_ref()
@ -164,12 +171,7 @@ where
authenticate_password(stream, output.as_bytes()).await?;
}
Some(Message::AuthenticationSasl(body)) => {
let pass = config
.password
.as_ref()
.ok_or_else(|| Error::config("password missing".into()))?;
authenticate_sasl(stream, body, channel_binding, pass).await?;
authenticate_sasl(stream, body, channel_binding, config).await?;
}
Some(Message::AuthenticationKerberosV5)
| Some(Message::AuthenticationScmCredential)
@ -192,6 +194,16 @@ where
}
}
fn no_channel_binding(config: &Config) -> Result<(), Error> {
match config.channel_binding {
config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
config::ChannelBinding::Require => Err(Error::authentication(
"server did not use channel binding".into(),
)),
config::ChannelBinding::__NonExhaustive => unreachable!(),
}
}
async fn authenticate_password<S, T>(
stream: &mut StartupStream<S, T>,
password: &[u8],
@ -213,12 +225,17 @@ async fn authenticate_sasl<S, T>(
stream: &mut StartupStream<S, T>,
body: AuthenticationSaslBody,
channel_binding: ChannelBinding,
password: &[u8],
config: &Config,
) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
let password = config
.password
.as_ref()
.ok_or_else(|| Error::config("password missing".into()))?;
let mut has_scram = false;
let mut has_scram_plus = false;
let mut mechanisms = body.mechanisms();
@ -232,6 +249,7 @@ where
let channel_binding = channel_binding
.tls_server_end_point
.filter(|_| config.channel_binding != config::ChannelBinding::Disable)
.map(sasl::ChannelBinding::tls_server_end_point);
let (channel_binding, mechanism) = if has_scram_plus {
@ -240,6 +258,8 @@ where
None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
}
} else if has_scram {
no_channel_binding(config)?;
match channel_binding {
Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),

View File

@ -106,7 +106,11 @@ pub fn prepare(
}
}
fn prepare_rec(client: Arc<InnerClient>, query: &str, types: &[Type]) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'static + Send>> {
fn prepare_rec(
client: Arc<InnerClient>,
query: &str,
types: &[Type],
) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'static + Send>> {
Box::pin(prepare(client, query, types))
}

View File

@ -20,8 +20,7 @@ mod types;
async fn connect_raw(s: &str) -> Result<(Client, Connection<TcpStream, NoTlsStream>), Error> {
let socket = TcpStream::connect("127.0.0.1:5433").await.unwrap();
let config = s.parse::<Config>().unwrap();
// FIXME https://github.com/rust-lang/rust/issues/64391
async move { config.connect_raw(socket, NoTls).await }.await
config.connect_raw(socket, NoTls).await
}
async fn connect(s: &str) -> Client {
@ -608,100 +607,20 @@ async fn query_portal() {
assert_eq!(r3.len(), 0);
}
/*
#[test]
fn poll_idle_running() {
struct DelayStream(Delay);
impl Stream for DelayStream {
type Item = Vec<u8>;
type Error = tokio_postgres::Error;
fn poll(&mut self) -> Poll<Option<Vec<u8>>, tokio_postgres::Error> {
try_ready!(self.0.poll().map_err(|e| panic!("{}", e)));
QUERY_DONE.store(true, Ordering::SeqCst);
Ok(Async::Ready(None))
}
}
struct IdleFuture(tokio_postgres::Client);
impl Future for IdleFuture {
type Item = ();
type Error = tokio_postgres::Error;
fn poll(&mut self) -> Poll<(), tokio_postgres::Error> {
try_ready!(self.0.poll_idle());
assert!(QUERY_DONE.load(Ordering::SeqCst));
Ok(Async::Ready(()))
}
}
static QUERY_DONE: AtomicBool = AtomicBool::new(false);
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
let execute = client
.simple_query("CREATE TEMPORARY TABLE foo (id INT)")
.for_each(|_| Ok(()));
runtime.block_on(execute).unwrap();
let prepare = client.prepare("COPY foo FROM STDIN");
let stmt = runtime.block_on(prepare).unwrap();
let copy_in = client.copy_in(
&stmt,
&[],
DelayStream(Delay::new(Instant::now() + Duration::from_millis(10))),
);
let copy_in = copy_in.map(|_| ()).map_err(|e| panic!("{}", e));
runtime.spawn(copy_in);
let future = IdleFuture(client);
runtime.block_on(future).unwrap();
#[tokio::test]
async fn require_channel_binding() {
connect_raw("user=postgres channel_binding=require")
.await
.err()
.unwrap();
}
#[test]
fn poll_idle_new() {
struct IdleFuture {
client: tokio_postgres::Client,
prepare: Option<impls::Prepare>,
}
impl Future for IdleFuture {
type Item = ();
type Error = tokio_postgres::Error;
fn poll(&mut self) -> Poll<(), tokio_postgres::Error> {
match self.prepare.take() {
Some(_future) => {
assert!(!self.client.poll_idle().unwrap().is_ready());
Ok(Async::NotReady)
}
None => {
assert!(self.client.poll_idle().unwrap().is_ready());
Ok(Async::Ready(()))
}
}
}
}
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
let prepare = client.prepare("");
let future = IdleFuture {
client,
prepare: Some(prepare),
};
runtime.block_on(future).unwrap();
#[tokio::test]
async fn prefer_channel_binding() {
connect("user=postgres channel_binding=prefer").await;
}
#[tokio::test]
async fn disable_channel_binding() {
connect("user=postgres channel_binding=disable").await;
}
*/