From 6c3a4ab19208e682e150ce31b987297b9e8bb75b Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Tue, 24 Sep 2019 17:03:37 -0700 Subject: [PATCH] Add channel_binding=disable/prefer/require to config Closes #487 --- postgres-native-tls/src/test.rs | 4 +- postgres-openssl/src/test.rs | 26 +++++++ postgres/src/config.rs | 10 ++- tokio-postgres/src/config.rs | 40 +++++++++++ tokio-postgres/src/connect_raw.rs | 38 +++++++--- tokio-postgres/src/prepare.rs | 6 +- tokio-postgres/tests/test/main.rs | 111 ++++-------------------------- 7 files changed, 125 insertions(+), 110 deletions(-) diff --git a/postgres-native-tls/src/test.rs b/postgres-native-tls/src/test.rs index d4942fe8..416a3c14 100644 --- a/postgres-native-tls/src/test.rs +++ b/postgres-native-tls/src/test.rs @@ -12,9 +12,7 @@ where T: TlsConnect, T::Stream: 'static + Send, { - let stream = TcpStream::connect("127.0.0.1:5433") - .await - .unwrap(); + let stream = TcpStream::connect("127.0.0.1:5433").await.unwrap(); let builder = s.parse::().unwrap(); let (mut client, connection) = builder.connect_raw(stream, tls).await.unwrap(); diff --git a/postgres-openssl/src/test.rs b/postgres-openssl/src/test.rs index df9054a0..e3ee454e 100644 --- a/postgres-openssl/src/test.rs +++ b/postgres-openssl/src/test.rs @@ -65,6 +65,32 @@ async fn scram_user() { .await; } +#[tokio::test] +async fn require_channel_binding_err() { + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_ca_file("../test/server.crt").unwrap(); + let ctx = builder.build(); + let connector = TlsConnector::new(ctx.configure().unwrap(), "localhost"); + + let stream = TcpStream::connect("127.0.0.1:5433").await.unwrap(); + let builder = "user=pass_user password=password dbname=postgres channel_binding=require" + .parse::() + .unwrap(); + builder.connect_raw(stream, connector).await.err().unwrap(); +} + +#[tokio::test] +async fn require_channel_binding_ok() { + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_ca_file("../test/server.crt").unwrap(); + let ctx = builder.build(); + smoke_test( + "user=scram_user password=password dbname=postgres channel_binding=require", + TlsConnector::new(ctx.configure().unwrap(), "localhost"), + ) + .await; +} + #[tokio::test] #[cfg(feature = "runtime")] async fn runtime() { diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 4a2d4509..354a99ce 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -14,7 +14,7 @@ use tokio_postgres::tls::{MakeTlsConnect, TlsConnect}; use tokio_postgres::{Error, Socket}; #[doc(inline)] -pub use tokio_postgres::config::{SslMode, TargetSessionAttrs}; +pub use tokio_postgres::config::{SslMode, TargetSessionAttrs, ChannelBinding}; use crate::{Client, RUNTIME}; @@ -234,6 +234,14 @@ impl Config { self } + /// Sets the channel binding behavior. + /// + /// Defaults to `prefer`. + pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config { + self.config.channel_binding(channel_binding); + self + } + /// Sets the executor used to run the connection futures. /// /// Defaults to a postgres-specific tokio `Runtime`. diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 19df1a35..0dc6d5bf 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -46,6 +46,19 @@ pub enum SslMode { __NonExhaustive, } +/// Channel binding configuration. +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum ChannelBinding { + /// Do not use channel binding. + Disable, + /// Attempt to use channel binding but allow sessions without. + Prefer, + /// Require the use of channel binding. + Require, + #[doc(hidden)] + __NonExhaustive, +} + #[derive(Debug, Clone, PartialEq)] pub(crate) enum Host { Tcp(String), @@ -87,6 +100,9 @@ pub(crate) enum Host { /// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that /// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server /// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`. +/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel +/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise. +/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`. /// /// ## Examples /// @@ -140,6 +156,7 @@ pub struct Config { pub(crate) keepalives: bool, pub(crate) keepalives_idle: Duration, pub(crate) target_session_attrs: TargetSessionAttrs, + pub(crate) channel_binding: ChannelBinding, } impl Default for Config { @@ -164,6 +181,7 @@ impl Config { keepalives: true, keepalives_idle: Duration::from_secs(2 * 60 * 60), target_session_attrs: TargetSessionAttrs::Any, + channel_binding: ChannelBinding::Prefer, } } @@ -287,6 +305,14 @@ impl Config { self } + /// Sets the channel binding behavior. + /// + /// Defaults to `prefer`. + pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config { + self.channel_binding = channel_binding; + self + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -363,6 +389,19 @@ impl Config { }; self.target_session_attrs(target_session_attrs); } + "channel_binding" => { + let channel_binding = match value { + "disable" => ChannelBinding::Disable, + "prefer" => ChannelBinding::Prefer, + "require" => ChannelBinding::Require, + _ => { + return Err(Error::config_parse(Box::new(InvalidValue( + "channel_binding", + )))) + } + }; + self.channel_binding(channel_binding); + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), @@ -434,6 +473,7 @@ impl fmt::Debug for Config { .field("keepalives", &self.keepalives) .field("keepalives_idle", &self.keepalives_idle) .field("target_session_attrs", &self.target_session_attrs) + .field("channel_binding", &self.channel_binding) .finish() } } diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 7b9fbd5e..cf80a91c 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -1,5 +1,5 @@ use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; -use crate::config::Config; +use crate::config::{self, Config}; use crate::connect_tls::connect_tls; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::{ChannelBinding, TlsConnect}; @@ -141,8 +141,13 @@ where T: AsyncRead + AsyncWrite + Unpin, { match stream.try_next().await.map_err(Error::io)? { - Some(Message::AuthenticationOk) => return Ok(()), + Some(Message::AuthenticationOk) => { + no_channel_binding(config)?; + return Ok(()); + } Some(Message::AuthenticationCleartextPassword) => { + no_channel_binding(config)?; + let pass = config .password .as_ref() @@ -151,6 +156,8 @@ where authenticate_password(stream, pass).await?; } Some(Message::AuthenticationMd5Password(body)) => { + no_channel_binding(config)?; + let user = config .user .as_ref() @@ -164,12 +171,7 @@ where authenticate_password(stream, output.as_bytes()).await?; } Some(Message::AuthenticationSasl(body)) => { - let pass = config - .password - .as_ref() - .ok_or_else(|| Error::config("password missing".into()))?; - - authenticate_sasl(stream, body, channel_binding, pass).await?; + authenticate_sasl(stream, body, channel_binding, config).await?; } Some(Message::AuthenticationKerberosV5) | Some(Message::AuthenticationScmCredential) @@ -192,6 +194,16 @@ where } } +fn no_channel_binding(config: &Config) -> Result<(), Error> { + match config.channel_binding { + config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()), + config::ChannelBinding::Require => Err(Error::authentication( + "server did not use channel binding".into(), + )), + config::ChannelBinding::__NonExhaustive => unreachable!(), + } +} + async fn authenticate_password( stream: &mut StartupStream, password: &[u8], @@ -213,12 +225,17 @@ async fn authenticate_sasl( stream: &mut StartupStream, body: AuthenticationSaslBody, channel_binding: ChannelBinding, - password: &[u8], + config: &Config, ) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin, { + let password = config + .password + .as_ref() + .ok_or_else(|| Error::config("password missing".into()))?; + let mut has_scram = false; let mut has_scram_plus = false; let mut mechanisms = body.mechanisms(); @@ -232,6 +249,7 @@ where let channel_binding = channel_binding .tls_server_end_point + .filter(|_| config.channel_binding != config::ChannelBinding::Disable) .map(sasl::ChannelBinding::tls_server_end_point); let (channel_binding, mechanism) = if has_scram_plus { @@ -240,6 +258,8 @@ where None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256), } } else if has_scram { + no_channel_binding(config)?; + match channel_binding { Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256), None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256), diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index 7db3a5b1..c3f70c41 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -106,7 +106,11 @@ pub fn prepare( } } -fn prepare_rec(client: Arc, query: &str, types: &[Type]) -> Pin> + 'static + Send>> { +fn prepare_rec( + client: Arc, + query: &str, + types: &[Type], +) -> Pin> + 'static + Send>> { Box::pin(prepare(client, query, types)) } diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 032400d9..802e9149 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -20,8 +20,7 @@ mod types; async fn connect_raw(s: &str) -> Result<(Client, Connection), Error> { let socket = TcpStream::connect("127.0.0.1:5433").await.unwrap(); let config = s.parse::().unwrap(); - // FIXME https://github.com/rust-lang/rust/issues/64391 - async move { config.connect_raw(socket, NoTls).await }.await + config.connect_raw(socket, NoTls).await } async fn connect(s: &str) -> Client { @@ -608,100 +607,20 @@ async fn query_portal() { assert_eq!(r3.len(), 0); } -/* -#[test] -fn poll_idle_running() { - struct DelayStream(Delay); - - impl Stream for DelayStream { - type Item = Vec; - type Error = tokio_postgres::Error; - - fn poll(&mut self) -> Poll>, tokio_postgres::Error> { - try_ready!(self.0.poll().map_err(|e| panic!("{}", e))); - QUERY_DONE.store(true, Ordering::SeqCst); - Ok(Async::Ready(None)) - } - } - - struct IdleFuture(tokio_postgres::Client); - - impl Future for IdleFuture { - type Item = (); - type Error = tokio_postgres::Error; - - fn poll(&mut self) -> Poll<(), tokio_postgres::Error> { - try_ready!(self.0.poll_idle()); - assert!(QUERY_DONE.load(Ordering::SeqCst)); - Ok(Async::Ready(())) - } - } - - static QUERY_DONE: AtomicBool = AtomicBool::new(false); - - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - let execute = client - .simple_query("CREATE TEMPORARY TABLE foo (id INT)") - .for_each(|_| Ok(())); - runtime.block_on(execute).unwrap(); - - let prepare = client.prepare("COPY foo FROM STDIN"); - let stmt = runtime.block_on(prepare).unwrap(); - let copy_in = client.copy_in( - &stmt, - &[], - DelayStream(Delay::new(Instant::now() + Duration::from_millis(10))), - ); - let copy_in = copy_in.map(|_| ()).map_err(|e| panic!("{}", e)); - runtime.spawn(copy_in); - - let future = IdleFuture(client); - runtime.block_on(future).unwrap(); +#[tokio::test] +async fn require_channel_binding() { + connect_raw("user=postgres channel_binding=require") + .await + .err() + .unwrap(); } -#[test] -fn poll_idle_new() { - struct IdleFuture { - client: tokio_postgres::Client, - prepare: Option, - } - - impl Future for IdleFuture { - type Item = (); - type Error = tokio_postgres::Error; - - fn poll(&mut self) -> Poll<(), tokio_postgres::Error> { - match self.prepare.take() { - Some(_future) => { - assert!(!self.client.poll_idle().unwrap().is_ready()); - Ok(Async::NotReady) - } - None => { - assert!(self.client.poll_idle().unwrap().is_ready()); - Ok(Async::Ready(())) - } - } - } - } - - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - let prepare = client.prepare(""); - let future = IdleFuture { - client, - prepare: Some(prepare), - }; - runtime.block_on(future).unwrap(); +#[tokio::test] +async fn prefer_channel_binding() { + connect("user=postgres channel_binding=prefer").await; +} + +#[tokio::test] +async fn disable_channel_binding() { + connect("user=postgres channel_binding=disable").await; } -*/