parent
a8d945c70e
commit
6c3a4ab192
@ -12,9 +12,7 @@ where
|
||||
T: TlsConnect<TcpStream>,
|
||||
T::Stream: 'static + Send,
|
||||
{
|
||||
let stream = TcpStream::connect("127.0.0.1:5433")
|
||||
.await
|
||||
.unwrap();
|
||||
let stream = TcpStream::connect("127.0.0.1:5433").await.unwrap();
|
||||
|
||||
let builder = s.parse::<tokio_postgres::Config>().unwrap();
|
||||
let (mut client, connection) = builder.connect_raw(stream, tls).await.unwrap();
|
||||
|
@ -65,6 +65,32 @@ async fn scram_user() {
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn require_channel_binding_err() {
|
||||
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
|
||||
builder.set_ca_file("../test/server.crt").unwrap();
|
||||
let ctx = builder.build();
|
||||
let connector = TlsConnector::new(ctx.configure().unwrap(), "localhost");
|
||||
|
||||
let stream = TcpStream::connect("127.0.0.1:5433").await.unwrap();
|
||||
let builder = "user=pass_user password=password dbname=postgres channel_binding=require"
|
||||
.parse::<tokio_postgres::Config>()
|
||||
.unwrap();
|
||||
builder.connect_raw(stream, connector).await.err().unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn require_channel_binding_ok() {
|
||||
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
|
||||
builder.set_ca_file("../test/server.crt").unwrap();
|
||||
let ctx = builder.build();
|
||||
smoke_test(
|
||||
"user=scram_user password=password dbname=postgres channel_binding=require",
|
||||
TlsConnector::new(ctx.configure().unwrap(), "localhost"),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[cfg(feature = "runtime")]
|
||||
async fn runtime() {
|
||||
|
@ -14,7 +14,7 @@ use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
|
||||
use tokio_postgres::{Error, Socket};
|
||||
|
||||
#[doc(inline)]
|
||||
pub use tokio_postgres::config::{SslMode, TargetSessionAttrs};
|
||||
pub use tokio_postgres::config::{SslMode, TargetSessionAttrs, ChannelBinding};
|
||||
|
||||
use crate::{Client, RUNTIME};
|
||||
|
||||
@ -234,6 +234,14 @@ impl Config {
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the channel binding behavior.
|
||||
///
|
||||
/// Defaults to `prefer`.
|
||||
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
|
||||
self.config.channel_binding(channel_binding);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the executor used to run the connection futures.
|
||||
///
|
||||
/// Defaults to a postgres-specific tokio `Runtime`.
|
||||
|
@ -46,6 +46,19 @@ pub enum SslMode {
|
||||
__NonExhaustive,
|
||||
}
|
||||
|
||||
/// Channel binding configuration.
|
||||
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||
pub enum ChannelBinding {
|
||||
/// Do not use channel binding.
|
||||
Disable,
|
||||
/// Attempt to use channel binding but allow sessions without.
|
||||
Prefer,
|
||||
/// Require the use of channel binding.
|
||||
Require,
|
||||
#[doc(hidden)]
|
||||
__NonExhaustive,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub(crate) enum Host {
|
||||
Tcp(String),
|
||||
@ -87,6 +100,9 @@ pub(crate) enum Host {
|
||||
/// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that
|
||||
/// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server
|
||||
/// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`.
|
||||
/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel
|
||||
/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise.
|
||||
/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`.
|
||||
///
|
||||
/// ## Examples
|
||||
///
|
||||
@ -140,6 +156,7 @@ pub struct Config {
|
||||
pub(crate) keepalives: bool,
|
||||
pub(crate) keepalives_idle: Duration,
|
||||
pub(crate) target_session_attrs: TargetSessionAttrs,
|
||||
pub(crate) channel_binding: ChannelBinding,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
@ -164,6 +181,7 @@ impl Config {
|
||||
keepalives: true,
|
||||
keepalives_idle: Duration::from_secs(2 * 60 * 60),
|
||||
target_session_attrs: TargetSessionAttrs::Any,
|
||||
channel_binding: ChannelBinding::Prefer,
|
||||
}
|
||||
}
|
||||
|
||||
@ -287,6 +305,14 @@ impl Config {
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the channel binding behavior.
|
||||
///
|
||||
/// Defaults to `prefer`.
|
||||
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
|
||||
self.channel_binding = channel_binding;
|
||||
self
|
||||
}
|
||||
|
||||
fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
|
||||
match key {
|
||||
"user" => {
|
||||
@ -363,6 +389,19 @@ impl Config {
|
||||
};
|
||||
self.target_session_attrs(target_session_attrs);
|
||||
}
|
||||
"channel_binding" => {
|
||||
let channel_binding = match value {
|
||||
"disable" => ChannelBinding::Disable,
|
||||
"prefer" => ChannelBinding::Prefer,
|
||||
"require" => ChannelBinding::Require,
|
||||
_ => {
|
||||
return Err(Error::config_parse(Box::new(InvalidValue(
|
||||
"channel_binding",
|
||||
))))
|
||||
}
|
||||
};
|
||||
self.channel_binding(channel_binding);
|
||||
}
|
||||
key => {
|
||||
return Err(Error::config_parse(Box::new(UnknownOption(
|
||||
key.to_string(),
|
||||
@ -434,6 +473,7 @@ impl fmt::Debug for Config {
|
||||
.field("keepalives", &self.keepalives)
|
||||
.field("keepalives_idle", &self.keepalives_idle)
|
||||
.field("target_session_attrs", &self.target_session_attrs)
|
||||
.field("channel_binding", &self.channel_binding)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
|
||||
use crate::config::Config;
|
||||
use crate::config::{self, Config};
|
||||
use crate::connect_tls::connect_tls;
|
||||
use crate::maybe_tls_stream::MaybeTlsStream;
|
||||
use crate::tls::{ChannelBinding, TlsConnect};
|
||||
@ -141,8 +141,13 @@ where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
match stream.try_next().await.map_err(Error::io)? {
|
||||
Some(Message::AuthenticationOk) => return Ok(()),
|
||||
Some(Message::AuthenticationOk) => {
|
||||
no_channel_binding(config)?;
|
||||
return Ok(());
|
||||
}
|
||||
Some(Message::AuthenticationCleartextPassword) => {
|
||||
no_channel_binding(config)?;
|
||||
|
||||
let pass = config
|
||||
.password
|
||||
.as_ref()
|
||||
@ -151,6 +156,8 @@ where
|
||||
authenticate_password(stream, pass).await?;
|
||||
}
|
||||
Some(Message::AuthenticationMd5Password(body)) => {
|
||||
no_channel_binding(config)?;
|
||||
|
||||
let user = config
|
||||
.user
|
||||
.as_ref()
|
||||
@ -164,12 +171,7 @@ where
|
||||
authenticate_password(stream, output.as_bytes()).await?;
|
||||
}
|
||||
Some(Message::AuthenticationSasl(body)) => {
|
||||
let pass = config
|
||||
.password
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::config("password missing".into()))?;
|
||||
|
||||
authenticate_sasl(stream, body, channel_binding, pass).await?;
|
||||
authenticate_sasl(stream, body, channel_binding, config).await?;
|
||||
}
|
||||
Some(Message::AuthenticationKerberosV5)
|
||||
| Some(Message::AuthenticationScmCredential)
|
||||
@ -192,6 +194,16 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn no_channel_binding(config: &Config) -> Result<(), Error> {
|
||||
match config.channel_binding {
|
||||
config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
|
||||
config::ChannelBinding::Require => Err(Error::authentication(
|
||||
"server did not use channel binding".into(),
|
||||
)),
|
||||
config::ChannelBinding::__NonExhaustive => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn authenticate_password<S, T>(
|
||||
stream: &mut StartupStream<S, T>,
|
||||
password: &[u8],
|
||||
@ -213,12 +225,17 @@ async fn authenticate_sasl<S, T>(
|
||||
stream: &mut StartupStream<S, T>,
|
||||
body: AuthenticationSaslBody,
|
||||
channel_binding: ChannelBinding,
|
||||
password: &[u8],
|
||||
config: &Config,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let password = config
|
||||
.password
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::config("password missing".into()))?;
|
||||
|
||||
let mut has_scram = false;
|
||||
let mut has_scram_plus = false;
|
||||
let mut mechanisms = body.mechanisms();
|
||||
@ -232,6 +249,7 @@ where
|
||||
|
||||
let channel_binding = channel_binding
|
||||
.tls_server_end_point
|
||||
.filter(|_| config.channel_binding != config::ChannelBinding::Disable)
|
||||
.map(sasl::ChannelBinding::tls_server_end_point);
|
||||
|
||||
let (channel_binding, mechanism) = if has_scram_plus {
|
||||
@ -240,6 +258,8 @@ where
|
||||
None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
|
||||
}
|
||||
} else if has_scram {
|
||||
no_channel_binding(config)?;
|
||||
|
||||
match channel_binding {
|
||||
Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
|
||||
None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
|
||||
|
@ -106,7 +106,11 @@ pub fn prepare(
|
||||
}
|
||||
}
|
||||
|
||||
fn prepare_rec(client: Arc<InnerClient>, query: &str, types: &[Type]) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'static + Send>> {
|
||||
fn prepare_rec(
|
||||
client: Arc<InnerClient>,
|
||||
query: &str,
|
||||
types: &[Type],
|
||||
) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'static + Send>> {
|
||||
Box::pin(prepare(client, query, types))
|
||||
}
|
||||
|
||||
|
@ -20,8 +20,7 @@ mod types;
|
||||
async fn connect_raw(s: &str) -> Result<(Client, Connection<TcpStream, NoTlsStream>), Error> {
|
||||
let socket = TcpStream::connect("127.0.0.1:5433").await.unwrap();
|
||||
let config = s.parse::<Config>().unwrap();
|
||||
// FIXME https://github.com/rust-lang/rust/issues/64391
|
||||
async move { config.connect_raw(socket, NoTls).await }.await
|
||||
config.connect_raw(socket, NoTls).await
|
||||
}
|
||||
|
||||
async fn connect(s: &str) -> Client {
|
||||
@ -608,100 +607,20 @@ async fn query_portal() {
|
||||
assert_eq!(r3.len(), 0);
|
||||
}
|
||||
|
||||
/*
|
||||
#[test]
|
||||
fn poll_idle_running() {
|
||||
struct DelayStream(Delay);
|
||||
|
||||
impl Stream for DelayStream {
|
||||
type Item = Vec<u8>;
|
||||
type Error = tokio_postgres::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Vec<u8>>, tokio_postgres::Error> {
|
||||
try_ready!(self.0.poll().map_err(|e| panic!("{}", e)));
|
||||
QUERY_DONE.store(true, Ordering::SeqCst);
|
||||
Ok(Async::Ready(None))
|
||||
}
|
||||
}
|
||||
|
||||
struct IdleFuture(tokio_postgres::Client);
|
||||
|
||||
impl Future for IdleFuture {
|
||||
type Item = ();
|
||||
type Error = tokio_postgres::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<(), tokio_postgres::Error> {
|
||||
try_ready!(self.0.poll_idle());
|
||||
assert!(QUERY_DONE.load(Ordering::SeqCst));
|
||||
Ok(Async::Ready(()))
|
||||
}
|
||||
}
|
||||
|
||||
static QUERY_DONE: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
let _ = env_logger::try_init();
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
runtime.handle().spawn(connection).unwrap();
|
||||
|
||||
let execute = client
|
||||
.simple_query("CREATE TEMPORARY TABLE foo (id INT)")
|
||||
.for_each(|_| Ok(()));
|
||||
runtime.block_on(execute).unwrap();
|
||||
|
||||
let prepare = client.prepare("COPY foo FROM STDIN");
|
||||
let stmt = runtime.block_on(prepare).unwrap();
|
||||
let copy_in = client.copy_in(
|
||||
&stmt,
|
||||
&[],
|
||||
DelayStream(Delay::new(Instant::now() + Duration::from_millis(10))),
|
||||
);
|
||||
let copy_in = copy_in.map(|_| ()).map_err(|e| panic!("{}", e));
|
||||
runtime.spawn(copy_in);
|
||||
|
||||
let future = IdleFuture(client);
|
||||
runtime.block_on(future).unwrap();
|
||||
#[tokio::test]
|
||||
async fn require_channel_binding() {
|
||||
connect_raw("user=postgres channel_binding=require")
|
||||
.await
|
||||
.err()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn poll_idle_new() {
|
||||
struct IdleFuture {
|
||||
client: tokio_postgres::Client,
|
||||
prepare: Option<impls::Prepare>,
|
||||
}
|
||||
|
||||
impl Future for IdleFuture {
|
||||
type Item = ();
|
||||
type Error = tokio_postgres::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<(), tokio_postgres::Error> {
|
||||
match self.prepare.take() {
|
||||
Some(_future) => {
|
||||
assert!(!self.client.poll_idle().unwrap().is_ready());
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
None => {
|
||||
assert!(self.client.poll_idle().unwrap().is_ready());
|
||||
Ok(Async::Ready(()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = env_logger::try_init();
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
runtime.handle().spawn(connection).unwrap();
|
||||
|
||||
let prepare = client.prepare("");
|
||||
let future = IdleFuture {
|
||||
client,
|
||||
prepare: Some(prepare),
|
||||
};
|
||||
runtime.block_on(future).unwrap();
|
||||
#[tokio::test]
|
||||
async fn prefer_channel_binding() {
|
||||
connect("user=postgres channel_binding=prefer").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn disable_channel_binding() {
|
||||
connect("user=postgres channel_binding=disable").await;
|
||||
}
|
||||
*/
|
||||
|
Loading…
Reference in New Issue
Block a user