From 635e6381b37f7dd59c97a1ede39b10a49ce458f6 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Fri, 28 Dec 2018 13:51:30 -0500 Subject: [PATCH] A less stringy builder This allows us to support things like non-utf8 passwords and unix socket directories. --- postgres/src/builder.rs | 34 +++++++ tokio-postgres-native-tls/src/test.rs | 17 ++-- tokio-postgres-openssl/src/test.rs | 17 ++-- tokio-postgres/src/builder.rs | 121 +++++++++++++++++++---- tokio-postgres/src/proto/connect.rs | 98 ++++++++---------- tokio-postgres/src/proto/connect_once.rs | 96 ++++++++---------- tokio-postgres/src/proto/handshake.rs | 76 +++++++------- tokio-postgres/tests/test/parse.rs | 22 ++--- 8 files changed, 279 insertions(+), 202 deletions(-) diff --git a/postgres/src/builder.rs b/postgres/src/builder.rs index 43f35e6d..4eb75526 100644 --- a/postgres/src/builder.rs +++ b/postgres/src/builder.rs @@ -1,7 +1,9 @@ use futures::sync::oneshot; use futures::Future; use log::error; +use std::path::Path; use std::str::FromStr; +use std::time::Duration; use tokio_postgres::{Error, MakeTlsMode, Socket, TlsMode}; use crate::{Client, RUNTIME}; @@ -19,11 +21,43 @@ impl Builder { Builder(tokio_postgres::Builder::new()) } + pub fn host(&mut self, host: &str) -> &mut Builder { + self.0.host(host); + self + } + + #[cfg(unix)] + pub fn host_path(&mut self, host: T) -> &mut Builder + where + T: AsRef, + { + self.0.host_path(host); + self + } + + pub fn port(&mut self, port: u16) -> &mut Builder { + self.0.port(port); + self + } + pub fn param(&mut self, key: &str, value: &str) -> &mut Builder { self.0.param(key, value); self } + pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Builder { + self.0.connect_timeout(connect_timeout); + self + } + + pub fn password(&mut self, password: T) -> &mut Builder + where + T: AsRef<[u8]>, + { + self.0.password(password); + self + } + pub fn connect(&self, tls_mode: T) -> Result where T: MakeTlsMode + 'static + Send, diff --git a/tokio-postgres-native-tls/src/test.rs b/tokio-postgres-native-tls/src/test.rs index 8e21bf0d..78e7852e 100644 --- a/tokio-postgres-native-tls/src/test.rs +++ b/tokio-postgres-native-tls/src/test.rs @@ -6,13 +6,15 @@ use tokio_postgres::{self, PreferTls, RequireTls, TlsMode}; use crate::TlsConnector; -fn smoke_test(builder: &tokio_postgres::Builder, tls: T) +fn smoke_test(s: &str, tls: T) where T: TlsMode, T::Stream: 'static, { let mut runtime = Runtime::new().unwrap(); + let builder = s.parse::().unwrap(); + let handshake = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap()) .map_err(|e| panic!("{}", e)) .and_then(|s| builder.handshake(s, tls)); @@ -42,9 +44,7 @@ fn require() { .build() .unwrap(); smoke_test( - tokio_postgres::Builder::new() - .user("ssl_user") - .dbname("postgres"), + "user=ssl_user dbname=postgres", RequireTls(TlsConnector::with_connector(connector, "localhost")), ); } @@ -58,9 +58,7 @@ fn prefer() { .build() .unwrap(); smoke_test( - tokio_postgres::Builder::new() - .user("ssl_user") - .dbname("postgres"), + "user=ssl_user dbname=postgres", PreferTls(TlsConnector::with_connector(connector, "localhost")), ); } @@ -74,10 +72,7 @@ fn scram_user() { .build() .unwrap(); smoke_test( - tokio_postgres::Builder::new() - .user("scram_user") - .password("password") - .dbname("postgres"), + "user=scram_user password=password dbname=postgres", RequireTls(TlsConnector::with_connector(connector, "localhost")), ); } diff --git a/tokio-postgres-openssl/src/test.rs b/tokio-postgres-openssl/src/test.rs index 28506735..a85cc534 100644 --- a/tokio-postgres-openssl/src/test.rs +++ b/tokio-postgres-openssl/src/test.rs @@ -6,13 +6,15 @@ use tokio_postgres::{self, PreferTls, RequireTls, TlsMode}; use super::*; -fn smoke_test(builder: &tokio_postgres::Builder, tls: T) +fn smoke_test(s: &str, tls: T) where T: TlsMode, T::Stream: 'static, { let mut runtime = Runtime::new().unwrap(); + let builder = s.parse::().unwrap(); + let handshake = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap()) .map_err(|e| panic!("{}", e)) .and_then(|s| builder.handshake(s, tls)); @@ -39,9 +41,7 @@ fn require() { builder.set_ca_file("../test/server.crt").unwrap(); let ctx = builder.build(); smoke_test( - tokio_postgres::Builder::new() - .user("ssl_user") - .dbname("postgres"), + "user=ssl_user dbname=postgres", RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")), ); } @@ -52,9 +52,7 @@ fn prefer() { builder.set_ca_file("../test/server.crt").unwrap(); let ctx = builder.build(); smoke_test( - tokio_postgres::Builder::new() - .user("ssl_user") - .dbname("postgres"), + "user=ssl_user dbname=postgres", PreferTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")), ); } @@ -65,10 +63,7 @@ fn scram_user() { builder.set_ca_file("../test/server.crt").unwrap(); let ctx = builder.build(); smoke_test( - tokio_postgres::Builder::new() - .user("scram_user") - .password("password") - .dbname("postgres"), + "user=scram_user password=password dbname=postgres", RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")), ); } diff --git a/tokio-postgres/src/builder.rs b/tokio-postgres/src/builder.rs index 69d999be..ce3792fc 100644 --- a/tokio-postgres/src/builder.rs +++ b/tokio-postgres/src/builder.rs @@ -1,6 +1,10 @@ use std::collections::hash_map::{self, HashMap}; use std::iter; +#[cfg(all(feature = "runtime", unix))] +use std::path::{Path, PathBuf}; use std::str::{self, FromStr}; +#[cfg(feature = "runtime")] +use std::time::Duration; use tokio_io::{AsyncRead, AsyncWrite}; #[cfg(feature = "runtime")] @@ -10,9 +14,24 @@ use crate::proto::HandshakeFuture; use crate::{Connect, MakeTlsMode, Socket}; use crate::{Error, Handshake, TlsMode}; -#[derive(Clone)] +#[cfg(feature = "runtime")] +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum Host { + Tcp(String), + #[cfg(unix)] + Unix(PathBuf), +} + +#[derive(Debug, Clone, PartialEq)] pub struct Builder { - params: HashMap, + pub(crate) params: HashMap, + pub(crate) password: Option>, + #[cfg(feature = "runtime")] + pub(crate) host: Vec, + #[cfg(feature = "runtime")] + pub(crate) port: Vec, + #[cfg(feature = "runtime")] + pub(crate) connect_timeout: Option, } impl Default for Builder { @@ -27,19 +46,59 @@ impl Builder { params.insert("client_encoding".to_string(), "UTF8".to_string()); params.insert("timezone".to_string(), "GMT".to_string()); - Builder { params } + Builder { + params, + password: None, + #[cfg(feature = "runtime")] + host: vec![], + #[cfg(feature = "runtime")] + port: vec![], + #[cfg(feature = "runtime")] + connect_timeout: None, + } } - pub fn user(&mut self, user: &str) -> &mut Builder { - self.param("user", user) + #[cfg(feature = "runtime")] + pub fn host(&mut self, host: &str) -> &mut Builder { + #[cfg(unix)] + { + if host.starts_with('/') { + self.host.push(Host::Unix(PathBuf::from(host))); + return self; + } + } + + self.host.push(Host::Tcp(host.to_string())); + self } - pub fn dbname(&mut self, database: &str) -> &mut Builder { - self.param("dbname", database) + #[cfg(all(feature = "runtime", unix))] + pub fn host_path(&mut self, host: T) -> &mut Builder + where + T: AsRef, + { + self.host.push(Host::Unix(host.as_ref().to_path_buf())); + self } - pub fn password(&mut self, password: &str) -> &mut Builder { - self.param("password", password) + #[cfg(feature = "runtime")] + pub fn port(&mut self, port: u16) -> &mut Builder { + self.port.push(port); + self + } + + #[cfg(feature = "runtime")] + pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Builder { + self.connect_timeout = Some(connect_timeout); + self + } + + pub fn password(&mut self, password: T) -> &mut Builder + where + T: AsRef<[u8]>, + { + self.password = Some(password.as_ref().to_vec()); + self } pub fn param(&mut self, key: &str, value: &str) -> &mut Builder { @@ -47,17 +106,12 @@ impl Builder { self } - /// FIXME do we want this? - pub fn iter(&self) -> Iter<'_> { - Iter(self.params.iter()) - } - pub fn handshake(&self, stream: S, tls_mode: T) -> Handshake where S: AsyncRead + AsyncWrite, T: TlsMode, { - Handshake(HandshakeFuture::new(stream, tls_mode, self.params.clone())) + Handshake(HandshakeFuture::new(stream, tls_mode, self.clone())) } #[cfg(feature = "runtime")] @@ -65,7 +119,7 @@ impl Builder { where T: MakeTlsMode, { - Connect(ConnectFuture::new(make_tls_mode, self.params.clone())) + Connect(ConnectFuture::new(make_tls_mode, self.clone())) } } @@ -77,7 +131,40 @@ impl FromStr for Builder { let mut builder = Builder::new(); while let Some((key, value)) = parser.parameter()? { - builder.params.insert(key.to_string(), value); + match key { + "password" => { + builder.password(value); + } + #[cfg(feature = "runtime")] + "host" => { + for host in value.split(',') { + builder.host(host); + } + } + #[cfg(feature = "runtime")] + "port" => { + for port in value.split(',') { + let port = if port.is_empty() { + 5432 + } else { + port.parse().map_err(Error::invalid_port)? + }; + builder.port(port); + } + } + #[cfg(feature = "runtime")] + "connect_timeout" => { + let timeout = value + .parse::() + .map_err(Error::invalid_connect_timeout)?; + if timeout > 0 { + builder.connect_timeout(Duration::from_secs(timeout as u64)); + } + } + key => { + builder.param(key, &value); + } + } } Ok(builder) diff --git a/tokio-postgres/src/proto/connect.rs b/tokio-postgres/src/proto/connect.rs index 167a6a42..512178cc 100644 --- a/tokio-postgres/src/proto/connect.rs +++ b/tokio-postgres/src/proto/connect.rs @@ -1,10 +1,8 @@ use futures::{try_ready, Async, Future, Poll}; use state_machine_future::{transition, RentToOwn, StateMachineFuture}; -use std::collections::HashMap; -use std::vec; use crate::proto::{Client, ConnectOnceFuture, Connection}; -use crate::{Error, MakeTlsMode, Socket}; +use crate::{Builder, Error, Host, MakeTlsMode, Socket}; #[derive(StateMachineFuture)] pub enum Connect @@ -12,25 +10,20 @@ where T: MakeTlsMode, { #[state_machine_future(start, transitions(MakingTlsMode))] - Start { - make_tls_mode: T, - params: HashMap, - }, + Start { make_tls_mode: T, config: Builder }, #[state_machine_future(transitions(Connecting))] MakingTlsMode { future: T::Future, - host: String, - port: u16, - addrs: vec::IntoIter<(String, u16)>, + idx: usize, make_tls_mode: T, - params: HashMap, + config: Builder, }, #[state_machine_future(transitions(MakingTlsMode, Finished))] Connecting { future: ConnectOnceFuture, - addrs: vec::IntoIter<(String, u16)>, + idx: usize, make_tls_mode: T, - params: HashMap, + config: Builder, }, #[state_machine_future(ready)] Finished((Client, Connection)), @@ -45,47 +38,27 @@ where fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll, Error> { let mut state = state.take(); - let host = match state.params.remove("host") { - Some(host) => host, - None => return Err(Error::missing_host()), - }; - let mut addrs = host - .split(',') - .map(|s| (s.to_string(), 0u16)) - .collect::>(); - - let port = state.params.remove("port").unwrap_or_else(String::new); - let mut ports = port - .split(',') - .map(|s| { - if s.is_empty() { - Ok(5432) - } else { - s.parse::().map_err(Error::invalid_port) - } - }) - .collect::, _>>()?; - if ports.len() == 1 { - ports.resize(addrs.len(), ports[0]); + if state.config.host.is_empty() { + return Err(Error::missing_host()); } - if addrs.len() != ports.len() { + + if state.config.port.len() > 1 && state.config.port.len() != state.config.host.len() { return Err(Error::invalid_port_count()); } - for (addr, port) in addrs.iter_mut().zip(ports) { - addr.1 = port; - } - - let mut addrs = addrs.into_iter(); - let (host, port) = addrs.next().expect("addrs cannot be empty"); + let hostname = match &state.config.host[0] { + Host::Tcp(host) => &**host, + // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter + #[cfg(unix)] + Host::Unix(_) => "", + }; + let future = state.make_tls_mode.make_tls_mode(hostname); transition!(MakingTlsMode { - future: state.make_tls_mode.make_tls_mode(&host), - host, - port, - addrs, + future, + idx: 0, make_tls_mode: state.make_tls_mode, - params: state.params, + config: state.config, }) } @@ -96,10 +69,10 @@ where let state = state.take(); transition!(Connecting { - future: ConnectOnceFuture::new(state.host, state.port, tls_mode, state.params.clone()), - addrs: state.addrs, + future: ConnectOnceFuture::new(state.idx, tls_mode, state.config.clone()), + idx: state.idx, make_tls_mode: state.make_tls_mode, - params: state.params, + config: state.config, }) } @@ -111,18 +84,25 @@ where Ok(Async::NotReady) => Ok(Async::NotReady), Err(e) => { let mut state = state.take(); - let (host, port) = match state.addrs.next() { - Some(addr) => addr, + let idx = state.idx + 1; + + let host = match state.config.host.get(idx) { + Some(host) => host, None => return Err(e), }; + let hostname = match host { + Host::Tcp(host) => &**host, + #[cfg(unix)] + Host::Unix(_) => "", + }; + let future = state.make_tls_mode.make_tls_mode(hostname); + transition!(MakingTlsMode { - future: state.make_tls_mode.make_tls_mode(&host), - host, - port, - addrs: state.addrs, + future, + idx, make_tls_mode: state.make_tls_mode, - params: state.params, + config: state.config, }) } } @@ -133,7 +113,7 @@ impl ConnectFuture where T: MakeTlsMode, { - pub fn new(make_tls_mode: T, params: HashMap) -> ConnectFuture { - Connect::start(make_tls_mode, params) + pub fn new(make_tls_mode: T, config: Builder) -> ConnectFuture { + Connect::start(make_tls_mode, config) } } diff --git a/tokio-postgres/src/proto/connect_once.rs b/tokio-postgres/src/proto/connect_once.rs index 4dba5c3c..da0f5778 100644 --- a/tokio-postgres/src/proto/connect_once.rs +++ b/tokio-postgres/src/proto/connect_once.rs @@ -1,13 +1,14 @@ +#![allow(clippy::large_enum_variant)] + use futures::{try_ready, Async, Future, Poll}; use futures_cpupool::{CpuFuture, CpuPool}; use lazy_static::lazy_static; use state_machine_future::{transition, RentToOwn, StateMachineFuture}; -use std::collections::HashMap; use std::io; use std::net::{SocketAddr, ToSocketAddrs}; #[cfg(unix)] use std::path::Path; -use std::time::{Duration, Instant}; +use std::time::Instant; use std::vec; use tokio_tcp::TcpStream; use tokio_timer::Delay; @@ -15,7 +16,7 @@ use tokio_timer::Delay; use tokio_uds::UnixStream; use crate::proto::{Client, Connection, HandshakeFuture}; -use crate::{Error, Socket, TlsMode}; +use crate::{Builder, Error, Host, Socket, TlsMode}; lazy_static! { static ref DNS_POOL: CpuPool = futures_cpupool::Builder::new() @@ -33,36 +34,32 @@ where #[cfg_attr(unix, state_machine_future(transitions(ConnectingUnix, ResolvingDns)))] #[cfg_attr(not(unix), state_machine_future(transitions(ConnectingTcp)))] Start { - host: String, - port: u16, + idx: usize, tls_mode: T, - params: HashMap, + config: Builder, }, #[cfg(unix)] #[state_machine_future(transitions(Handshaking))] ConnectingUnix { future: tokio_uds::ConnectFuture, - connect_timeout: Option, timeout: Option, tls_mode: T, - params: HashMap, + config: Builder, }, #[state_machine_future(transitions(ConnectingTcp))] ResolvingDns { future: CpuFuture, io::Error>, - connect_timeout: Option, timeout: Option, tls_mode: T, - params: HashMap, + config: Builder, }, #[state_machine_future(transitions(Handshaking))] ConnectingTcp { future: tokio_tcp::ConnectFuture, addrs: vec::IntoIter, - connect_timeout: Option, timeout: Option, tls_mode: T, - params: HashMap, + config: Builder, }, #[state_machine_future(transitions(Finished))] Handshaking { future: HandshakeFuture }, @@ -77,44 +74,41 @@ where T: TlsMode, { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll, Error> { - let mut state = state.take(); + let state = state.take(); - let connect_timeout = match state.params.remove("connect_timeout") { - Some(s) => { - let seconds = s.parse::().map_err(Error::invalid_connect_timeout)?; - if seconds <= 0 { - None - } else { - Some(Duration::from_secs(seconds as u64)) - } - } - None => None, - }; - let timeout = connect_timeout.map(|d| Delay::new(Instant::now() + d)); + let port = *state + .config + .port + .get(state.idx) + .or_else(|| state.config.port.get(0)) + .unwrap_or(&5432); - #[cfg(unix)] - { - if state.host.starts_with('/') { - let path = Path::new(&state.host).join(format!(".s.PGSQL.{}", state.port)); - transition!(ConnectingUnix { - future: UnixStream::connect(path), - connect_timeout, + let timeout = state + .config + .connect_timeout + .map(|d| Delay::new(Instant::now() + d)); + + match &state.config.host[state.idx] { + Host::Tcp(host) => { + let host = host.clone(); + transition!(ResolvingDns { + future: DNS_POOL.spawn_fn(move || (&*host, port).to_socket_addrs()), timeout, tls_mode: state.tls_mode, - params: state.params, + config: state.config, + }) + } + #[cfg(unix)] + Host::Unix(host) => { + let path = Path::new(host).join(format!(".s.PGSQL.{}", port)); + transition!(ConnectingUnix { + future: UnixStream::connect(path), + timeout, + tls_mode: state.tls_mode, + config: state.config, }) } } - - let host = state.host; - let port = state.port; - transition!(ResolvingDns { - future: DNS_POOL.spawn_fn(move || (&*host, port).to_socket_addrs()), - connect_timeout, - timeout, - tls_mode: state.tls_mode, - params: state.params, - }) } #[cfg(unix)] @@ -134,7 +128,7 @@ where let state = state.take(); transition!(Handshaking { - future: HandshakeFuture::new(stream, state.tls_mode, state.params) + future: HandshakeFuture::new(stream, state.tls_mode, state.config) }) } @@ -165,10 +159,9 @@ where transition!(ConnectingTcp { future: TcpStream::connect(&addr), addrs, - connect_timeout: state.connect_timeout, timeout: state.timeout, tls_mode: state.tls_mode, - params: state.params, + config: state.config, }) } @@ -202,7 +195,7 @@ where let stream = Socket::new_tcp(stream); transition!(Handshaking { - future: HandshakeFuture::new(stream, state.tls_mode, state.params), + future: HandshakeFuture::new(stream, state.tls_mode, state.config), }) } @@ -219,12 +212,7 @@ impl ConnectOnceFuture where T: TlsMode, { - pub fn new( - host: String, - port: u16, - tls_mode: T, - params: HashMap, - ) -> ConnectOnceFuture { - ConnectOnce::start(host, port, tls_mode, params) + pub fn new(idx: usize, tls_mode: T, config: Builder) -> ConnectOnceFuture { + ConnectOnce::start(idx, tls_mode, config) } } diff --git a/tokio-postgres/src/proto/handshake.rs b/tokio-postgres/src/proto/handshake.rs index 3f7127b4..79bbb513 100644 --- a/tokio-postgres/src/proto/handshake.rs +++ b/tokio-postgres/src/proto/handshake.rs @@ -13,7 +13,7 @@ use tokio_codec::Framed; use tokio_io::{AsyncRead, AsyncWrite}; use crate::proto::{Client, Connection, PostgresCodec, TlsFuture}; -use crate::{CancelData, ChannelBinding, Error, TlsMode}; +use crate::{Builder, CancelData, ChannelBinding, Error, TlsMode}; #[derive(StateMachineFuture)] pub enum Handshake @@ -24,20 +24,18 @@ where #[state_machine_future(start, transitions(SendingStartup))] Start { future: TlsFuture, - params: HashMap, + config: Builder, }, #[state_machine_future(transitions(ReadingAuth))] SendingStartup { future: sink::Send>, - user: String, - password: Option, + config: Builder, channel_binding: ChannelBinding, }, #[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))] ReadingAuth { stream: Framed, - user: String, - password: Option, + config: Builder, channel_binding: ChannelBinding, }, #[state_machine_future(transitions(ReadingAuthCompletion))] @@ -77,31 +75,24 @@ where { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll, Error> { let (stream, channel_binding) = try_ready!(state.future.poll()); - let mut state = state.take(); - - // we don't want to send the password as a param - let password = state.params.remove("password"); - - // libpq uses the parameter "dbname" but the protocol expects "database" (!?!) - if let Some(dbname) = state.params.remove("dbname") { - state.params.insert("database".to_string(), dbname); - } + let state = state.take(); let mut buf = vec![]; - frontend::startup_message(state.params.iter().map(|(k, v)| (&**k, &**v)), &mut buf) - .map_err(Error::encode)?; + frontend::startup_message( + state.config.params.iter().map(|(k, v)| { + // libpq uses dbname, but the backend expects database (!) + let k = if k == "dbname" { "database" } else { &**k }; + (k, &**v) + }), + &mut buf, + ) + .map_err(Error::encode)?; let stream = Framed::new(stream, PostgresCodec); - let user = state - .params - .remove("user") - .ok_or_else(Error::missing_user)?; - transition!(SendingStartup { future: stream.send(buf), - user, - password, + config: state.config, channel_binding, }) } @@ -113,8 +104,7 @@ where let state = state.take(); transition!(ReadingAuth { stream, - user: state.user, - password: state.password, + config: state.config, channel_binding: state.channel_binding, }) } @@ -132,17 +122,29 @@ where parameters: HashMap::new(), }), Some(Message::AuthenticationCleartextPassword) => { - let pass = state.password.ok_or_else(Error::missing_password)?; + let pass = state + .config + .password + .as_ref() + .ok_or_else(Error::missing_password)?; let mut buf = vec![]; - frontend::password_message(pass.as_bytes(), &mut buf).map_err(Error::encode)?; + frontend::password_message(pass, &mut buf).map_err(Error::encode)?; transition!(SendingPassword { future: state.stream.send(buf) }) } Some(Message::AuthenticationMd5Password(body)) => { - let pass = state.password.ok_or_else(Error::missing_password)?; - let output = - authentication::md5_hash(state.user.as_bytes(), pass.as_bytes(), body.salt()); + let user = state + .config + .params + .get("user") + .ok_or_else(Error::missing_user)?; + let pass = state + .config + .password + .as_ref() + .ok_or_else(Error::missing_password)?; + let output = authentication::md5_hash(user.as_bytes(), pass, body.salt()); let mut buf = vec![]; frontend::password_message(output.as_bytes(), &mut buf).map_err(Error::encode)?; transition!(SendingPassword { @@ -150,7 +152,11 @@ where }) } Some(Message::AuthenticationSasl(body)) => { - let pass = state.password.ok_or_else(Error::missing_password)?; + let pass = state + .config + .password + .as_ref() + .ok_or_else(Error::missing_password)?; let mut has_scram = false; let mut has_scram_plus = false; @@ -187,7 +193,7 @@ where return Err(Error::unsupported_authentication()); }; - let scram = ScramSha256::new(pass.as_bytes(), channel_binding); + let scram = ScramSha256::new(pass, channel_binding); let mut buf = vec![]; frontend::sasl_initial_response(mechanism, scram.message(), &mut buf) @@ -324,7 +330,7 @@ where S: AsyncRead + AsyncWrite, T: TlsMode, { - pub fn new(stream: S, tls_mode: T, params: HashMap) -> HandshakeFuture { - Handshake::start(TlsFuture::new(stream, tls_mode), params) + pub fn new(stream: S, tls_mode: T, config: Builder) -> HandshakeFuture { + Handshake::start(TlsFuture::new(stream, tls_mode), config) } } diff --git a/tokio-postgres/tests/test/parse.rs b/tokio-postgres/tests/test/parse.rs index 236d4870..c5d2e0fa 100644 --- a/tokio-postgres/tests/test/parse.rs +++ b/tokio-postgres/tests/test/parse.rs @@ -1,18 +1,14 @@ -use std::collections::HashMap; - #[test] fn pairs_ok() { let params = r"user=foo password=' fizz \'buzz\\ ' thing = ''" .parse::() .unwrap(); - let params = params.iter().collect::>(); - let mut expected = HashMap::new(); - expected.insert("user", "foo"); - expected.insert("password", r" fizz 'buzz\ "); - expected.insert("thing", ""); - expected.insert("client_encoding", "UTF8"); - expected.insert("timezone", "GMT"); + let mut expected = tokio_postgres::Builder::new(); + expected + .param("user", "foo") + .password(r" fizz 'buzz\ ") + .param("thing", ""); assert_eq!(params, expected); } @@ -22,13 +18,9 @@ fn pairs_ws() { let params = " user\t=\r\n\x0bfoo \t password = hunter2 " .parse::() .unwrap();; - let params = params.iter().collect::>(); - let mut expected = HashMap::new(); - expected.insert("user", "foo"); - expected.insert("password", r"hunter2"); - expected.insert("client_encoding", "UTF8"); - expected.insert("timezone", "GMT"); + let mut expected = tokio_postgres::Builder::new(); + expected.param("user", "foo").password("hunter2"); assert_eq!(params, expected); }