A less stringy builder
This allows us to support things like non-utf8 passwords and unix socket directories.
This commit is contained in:
parent
e80e1fcaaf
commit
635e6381b3
@ -1,7 +1,9 @@
|
||||
use futures::sync::oneshot;
|
||||
use futures::Future;
|
||||
use log::error;
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
use tokio_postgres::{Error, MakeTlsMode, Socket, TlsMode};
|
||||
|
||||
use crate::{Client, RUNTIME};
|
||||
@ -19,11 +21,43 @@ impl Builder {
|
||||
Builder(tokio_postgres::Builder::new())
|
||||
}
|
||||
|
||||
pub fn host(&mut self, host: &str) -> &mut Builder {
|
||||
self.0.host(host);
|
||||
self
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub fn host_path<T>(&mut self, host: T) -> &mut Builder
|
||||
where
|
||||
T: AsRef<Path>,
|
||||
{
|
||||
self.0.host_path(host);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn port(&mut self, port: u16) -> &mut Builder {
|
||||
self.0.port(port);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn param(&mut self, key: &str, value: &str) -> &mut Builder {
|
||||
self.0.param(key, value);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Builder {
|
||||
self.0.connect_timeout(connect_timeout);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn password<T>(&mut self, password: T) -> &mut Builder
|
||||
where
|
||||
T: AsRef<[u8]>,
|
||||
{
|
||||
self.0.password(password);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn connect<T>(&self, tls_mode: T) -> Result<Client, Error>
|
||||
where
|
||||
T: MakeTlsMode<Socket> + 'static + Send,
|
||||
|
@ -6,13 +6,15 @@ use tokio_postgres::{self, PreferTls, RequireTls, TlsMode};
|
||||
|
||||
use crate::TlsConnector;
|
||||
|
||||
fn smoke_test<T>(builder: &tokio_postgres::Builder, tls: T)
|
||||
fn smoke_test<T>(s: &str, tls: T)
|
||||
where
|
||||
T: TlsMode<TcpStream>,
|
||||
T::Stream: 'static,
|
||||
{
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let builder = s.parse::<tokio_postgres::Builder>().unwrap();
|
||||
|
||||
let handshake = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap())
|
||||
.map_err(|e| panic!("{}", e))
|
||||
.and_then(|s| builder.handshake(s, tls));
|
||||
@ -42,9 +44,7 @@ fn require() {
|
||||
.build()
|
||||
.unwrap();
|
||||
smoke_test(
|
||||
tokio_postgres::Builder::new()
|
||||
.user("ssl_user")
|
||||
.dbname("postgres"),
|
||||
"user=ssl_user dbname=postgres",
|
||||
RequireTls(TlsConnector::with_connector(connector, "localhost")),
|
||||
);
|
||||
}
|
||||
@ -58,9 +58,7 @@ fn prefer() {
|
||||
.build()
|
||||
.unwrap();
|
||||
smoke_test(
|
||||
tokio_postgres::Builder::new()
|
||||
.user("ssl_user")
|
||||
.dbname("postgres"),
|
||||
"user=ssl_user dbname=postgres",
|
||||
PreferTls(TlsConnector::with_connector(connector, "localhost")),
|
||||
);
|
||||
}
|
||||
@ -74,10 +72,7 @@ fn scram_user() {
|
||||
.build()
|
||||
.unwrap();
|
||||
smoke_test(
|
||||
tokio_postgres::Builder::new()
|
||||
.user("scram_user")
|
||||
.password("password")
|
||||
.dbname("postgres"),
|
||||
"user=scram_user password=password dbname=postgres",
|
||||
RequireTls(TlsConnector::with_connector(connector, "localhost")),
|
||||
);
|
||||
}
|
||||
|
@ -6,13 +6,15 @@ use tokio_postgres::{self, PreferTls, RequireTls, TlsMode};
|
||||
|
||||
use super::*;
|
||||
|
||||
fn smoke_test<T>(builder: &tokio_postgres::Builder, tls: T)
|
||||
fn smoke_test<T>(s: &str, tls: T)
|
||||
where
|
||||
T: TlsMode<TcpStream>,
|
||||
T::Stream: 'static,
|
||||
{
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let builder = s.parse::<tokio_postgres::Builder>().unwrap();
|
||||
|
||||
let handshake = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap())
|
||||
.map_err(|e| panic!("{}", e))
|
||||
.and_then(|s| builder.handshake(s, tls));
|
||||
@ -39,9 +41,7 @@ fn require() {
|
||||
builder.set_ca_file("../test/server.crt").unwrap();
|
||||
let ctx = builder.build();
|
||||
smoke_test(
|
||||
tokio_postgres::Builder::new()
|
||||
.user("ssl_user")
|
||||
.dbname("postgres"),
|
||||
"user=ssl_user dbname=postgres",
|
||||
RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
|
||||
);
|
||||
}
|
||||
@ -52,9 +52,7 @@ fn prefer() {
|
||||
builder.set_ca_file("../test/server.crt").unwrap();
|
||||
let ctx = builder.build();
|
||||
smoke_test(
|
||||
tokio_postgres::Builder::new()
|
||||
.user("ssl_user")
|
||||
.dbname("postgres"),
|
||||
"user=ssl_user dbname=postgres",
|
||||
PreferTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
|
||||
);
|
||||
}
|
||||
@ -65,10 +63,7 @@ fn scram_user() {
|
||||
builder.set_ca_file("../test/server.crt").unwrap();
|
||||
let ctx = builder.build();
|
||||
smoke_test(
|
||||
tokio_postgres::Builder::new()
|
||||
.user("scram_user")
|
||||
.password("password")
|
||||
.dbname("postgres"),
|
||||
"user=scram_user password=password dbname=postgres",
|
||||
RequireTls(TlsConnector::new(ctx.configure().unwrap(), "localhost")),
|
||||
);
|
||||
}
|
||||
|
@ -1,6 +1,10 @@
|
||||
use std::collections::hash_map::{self, HashMap};
|
||||
use std::iter;
|
||||
#[cfg(all(feature = "runtime", unix))]
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::str::{self, FromStr};
|
||||
#[cfg(feature = "runtime")]
|
||||
use std::time::Duration;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
@ -10,9 +14,24 @@ use crate::proto::HandshakeFuture;
|
||||
use crate::{Connect, MakeTlsMode, Socket};
|
||||
use crate::{Error, Handshake, TlsMode};
|
||||
|
||||
#[derive(Clone)]
|
||||
#[cfg(feature = "runtime")]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub(crate) enum Host {
|
||||
Tcp(String),
|
||||
#[cfg(unix)]
|
||||
Unix(PathBuf),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Builder {
|
||||
params: HashMap<String, String>,
|
||||
pub(crate) params: HashMap<String, String>,
|
||||
pub(crate) password: Option<Vec<u8>>,
|
||||
#[cfg(feature = "runtime")]
|
||||
pub(crate) host: Vec<Host>,
|
||||
#[cfg(feature = "runtime")]
|
||||
pub(crate) port: Vec<u16>,
|
||||
#[cfg(feature = "runtime")]
|
||||
pub(crate) connect_timeout: Option<Duration>,
|
||||
}
|
||||
|
||||
impl Default for Builder {
|
||||
@ -27,19 +46,59 @@ impl Builder {
|
||||
params.insert("client_encoding".to_string(), "UTF8".to_string());
|
||||
params.insert("timezone".to_string(), "GMT".to_string());
|
||||
|
||||
Builder { params }
|
||||
Builder {
|
||||
params,
|
||||
password: None,
|
||||
#[cfg(feature = "runtime")]
|
||||
host: vec![],
|
||||
#[cfg(feature = "runtime")]
|
||||
port: vec![],
|
||||
#[cfg(feature = "runtime")]
|
||||
connect_timeout: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn user(&mut self, user: &str) -> &mut Builder {
|
||||
self.param("user", user)
|
||||
#[cfg(feature = "runtime")]
|
||||
pub fn host(&mut self, host: &str) -> &mut Builder {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
if host.starts_with('/') {
|
||||
self.host.push(Host::Unix(PathBuf::from(host)));
|
||||
return self;
|
||||
}
|
||||
}
|
||||
|
||||
self.host.push(Host::Tcp(host.to_string()));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn dbname(&mut self, database: &str) -> &mut Builder {
|
||||
self.param("dbname", database)
|
||||
#[cfg(all(feature = "runtime", unix))]
|
||||
pub fn host_path<T>(&mut self, host: T) -> &mut Builder
|
||||
where
|
||||
T: AsRef<Path>,
|
||||
{
|
||||
self.host.push(Host::Unix(host.as_ref().to_path_buf()));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn password(&mut self, password: &str) -> &mut Builder {
|
||||
self.param("password", password)
|
||||
#[cfg(feature = "runtime")]
|
||||
pub fn port(&mut self, port: u16) -> &mut Builder {
|
||||
self.port.push(port);
|
||||
self
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Builder {
|
||||
self.connect_timeout = Some(connect_timeout);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn password<T>(&mut self, password: T) -> &mut Builder
|
||||
where
|
||||
T: AsRef<[u8]>,
|
||||
{
|
||||
self.password = Some(password.as_ref().to_vec());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn param(&mut self, key: &str, value: &str) -> &mut Builder {
|
||||
@ -47,17 +106,12 @@ impl Builder {
|
||||
self
|
||||
}
|
||||
|
||||
/// FIXME do we want this?
|
||||
pub fn iter(&self) -> Iter<'_> {
|
||||
Iter(self.params.iter())
|
||||
}
|
||||
|
||||
pub fn handshake<S, T>(&self, stream: S, tls_mode: T) -> Handshake<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
{
|
||||
Handshake(HandshakeFuture::new(stream, tls_mode, self.params.clone()))
|
||||
Handshake(HandshakeFuture::new(stream, tls_mode, self.clone()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
@ -65,7 +119,7 @@ impl Builder {
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
Connect(ConnectFuture::new(make_tls_mode, self.params.clone()))
|
||||
Connect(ConnectFuture::new(make_tls_mode, self.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
@ -77,7 +131,40 @@ impl FromStr for Builder {
|
||||
let mut builder = Builder::new();
|
||||
|
||||
while let Some((key, value)) = parser.parameter()? {
|
||||
builder.params.insert(key.to_string(), value);
|
||||
match key {
|
||||
"password" => {
|
||||
builder.password(value);
|
||||
}
|
||||
#[cfg(feature = "runtime")]
|
||||
"host" => {
|
||||
for host in value.split(',') {
|
||||
builder.host(host);
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "runtime")]
|
||||
"port" => {
|
||||
for port in value.split(',') {
|
||||
let port = if port.is_empty() {
|
||||
5432
|
||||
} else {
|
||||
port.parse().map_err(Error::invalid_port)?
|
||||
};
|
||||
builder.port(port);
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "runtime")]
|
||||
"connect_timeout" => {
|
||||
let timeout = value
|
||||
.parse::<i64>()
|
||||
.map_err(Error::invalid_connect_timeout)?;
|
||||
if timeout > 0 {
|
||||
builder.connect_timeout(Duration::from_secs(timeout as u64));
|
||||
}
|
||||
}
|
||||
key => {
|
||||
builder.param(key, &value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(builder)
|
||||
|
@ -1,10 +1,8 @@
|
||||
use futures::{try_ready, Async, Future, Poll};
|
||||
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
|
||||
use std::collections::HashMap;
|
||||
use std::vec;
|
||||
|
||||
use crate::proto::{Client, ConnectOnceFuture, Connection};
|
||||
use crate::{Error, MakeTlsMode, Socket};
|
||||
use crate::{Builder, Error, Host, MakeTlsMode, Socket};
|
||||
|
||||
#[derive(StateMachineFuture)]
|
||||
pub enum Connect<T>
|
||||
@ -12,25 +10,20 @@ where
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
#[state_machine_future(start, transitions(MakingTlsMode))]
|
||||
Start {
|
||||
make_tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
},
|
||||
Start { make_tls_mode: T, config: Builder },
|
||||
#[state_machine_future(transitions(Connecting))]
|
||||
MakingTlsMode {
|
||||
future: T::Future,
|
||||
host: String,
|
||||
port: u16,
|
||||
addrs: vec::IntoIter<(String, u16)>,
|
||||
idx: usize,
|
||||
make_tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
config: Builder,
|
||||
},
|
||||
#[state_machine_future(transitions(MakingTlsMode, Finished))]
|
||||
Connecting {
|
||||
future: ConnectOnceFuture<T::TlsMode>,
|
||||
addrs: vec::IntoIter<(String, u16)>,
|
||||
idx: usize,
|
||||
make_tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
config: Builder,
|
||||
},
|
||||
#[state_machine_future(ready)]
|
||||
Finished((Client, Connection<T::Stream>)),
|
||||
@ -45,47 +38,27 @@ where
|
||||
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
|
||||
let mut state = state.take();
|
||||
|
||||
let host = match state.params.remove("host") {
|
||||
Some(host) => host,
|
||||
None => return Err(Error::missing_host()),
|
||||
};
|
||||
let mut addrs = host
|
||||
.split(',')
|
||||
.map(|s| (s.to_string(), 0u16))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let port = state.params.remove("port").unwrap_or_else(String::new);
|
||||
let mut ports = port
|
||||
.split(',')
|
||||
.map(|s| {
|
||||
if s.is_empty() {
|
||||
Ok(5432)
|
||||
} else {
|
||||
s.parse::<u16>().map_err(Error::invalid_port)
|
||||
}
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
if ports.len() == 1 {
|
||||
ports.resize(addrs.len(), ports[0]);
|
||||
if state.config.host.is_empty() {
|
||||
return Err(Error::missing_host());
|
||||
}
|
||||
if addrs.len() != ports.len() {
|
||||
|
||||
if state.config.port.len() > 1 && state.config.port.len() != state.config.host.len() {
|
||||
return Err(Error::invalid_port_count());
|
||||
}
|
||||
|
||||
for (addr, port) in addrs.iter_mut().zip(ports) {
|
||||
addr.1 = port;
|
||||
}
|
||||
|
||||
let mut addrs = addrs.into_iter();
|
||||
let (host, port) = addrs.next().expect("addrs cannot be empty");
|
||||
let hostname = match &state.config.host[0] {
|
||||
Host::Tcp(host) => &**host,
|
||||
// postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
|
||||
#[cfg(unix)]
|
||||
Host::Unix(_) => "",
|
||||
};
|
||||
let future = state.make_tls_mode.make_tls_mode(hostname);
|
||||
|
||||
transition!(MakingTlsMode {
|
||||
future: state.make_tls_mode.make_tls_mode(&host),
|
||||
host,
|
||||
port,
|
||||
addrs,
|
||||
future,
|
||||
idx: 0,
|
||||
make_tls_mode: state.make_tls_mode,
|
||||
params: state.params,
|
||||
config: state.config,
|
||||
})
|
||||
}
|
||||
|
||||
@ -96,10 +69,10 @@ where
|
||||
let state = state.take();
|
||||
|
||||
transition!(Connecting {
|
||||
future: ConnectOnceFuture::new(state.host, state.port, tls_mode, state.params.clone()),
|
||||
addrs: state.addrs,
|
||||
future: ConnectOnceFuture::new(state.idx, tls_mode, state.config.clone()),
|
||||
idx: state.idx,
|
||||
make_tls_mode: state.make_tls_mode,
|
||||
params: state.params,
|
||||
config: state.config,
|
||||
})
|
||||
}
|
||||
|
||||
@ -111,18 +84,25 @@ where
|
||||
Ok(Async::NotReady) => Ok(Async::NotReady),
|
||||
Err(e) => {
|
||||
let mut state = state.take();
|
||||
let (host, port) = match state.addrs.next() {
|
||||
Some(addr) => addr,
|
||||
let idx = state.idx + 1;
|
||||
|
||||
let host = match state.config.host.get(idx) {
|
||||
Some(host) => host,
|
||||
None => return Err(e),
|
||||
};
|
||||
|
||||
let hostname = match host {
|
||||
Host::Tcp(host) => &**host,
|
||||
#[cfg(unix)]
|
||||
Host::Unix(_) => "",
|
||||
};
|
||||
let future = state.make_tls_mode.make_tls_mode(hostname);
|
||||
|
||||
transition!(MakingTlsMode {
|
||||
future: state.make_tls_mode.make_tls_mode(&host),
|
||||
host,
|
||||
port,
|
||||
addrs: state.addrs,
|
||||
future,
|
||||
idx,
|
||||
make_tls_mode: state.make_tls_mode,
|
||||
params: state.params,
|
||||
config: state.config,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -133,7 +113,7 @@ impl<T> ConnectFuture<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
pub fn new(make_tls_mode: T, params: HashMap<String, String>) -> ConnectFuture<T> {
|
||||
Connect::start(make_tls_mode, params)
|
||||
pub fn new(make_tls_mode: T, config: Builder) -> ConnectFuture<T> {
|
||||
Connect::start(make_tls_mode, config)
|
||||
}
|
||||
}
|
||||
|
@ -1,13 +1,14 @@
|
||||
#![allow(clippy::large_enum_variant)]
|
||||
|
||||
use futures::{try_ready, Async, Future, Poll};
|
||||
use futures_cpupool::{CpuFuture, CpuPool};
|
||||
use lazy_static::lazy_static;
|
||||
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use std::net::{SocketAddr, ToSocketAddrs};
|
||||
#[cfg(unix)]
|
||||
use std::path::Path;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::time::Instant;
|
||||
use std::vec;
|
||||
use tokio_tcp::TcpStream;
|
||||
use tokio_timer::Delay;
|
||||
@ -15,7 +16,7 @@ use tokio_timer::Delay;
|
||||
use tokio_uds::UnixStream;
|
||||
|
||||
use crate::proto::{Client, Connection, HandshakeFuture};
|
||||
use crate::{Error, Socket, TlsMode};
|
||||
use crate::{Builder, Error, Host, Socket, TlsMode};
|
||||
|
||||
lazy_static! {
|
||||
static ref DNS_POOL: CpuPool = futures_cpupool::Builder::new()
|
||||
@ -33,36 +34,32 @@ where
|
||||
#[cfg_attr(unix, state_machine_future(transitions(ConnectingUnix, ResolvingDns)))]
|
||||
#[cfg_attr(not(unix), state_machine_future(transitions(ConnectingTcp)))]
|
||||
Start {
|
||||
host: String,
|
||||
port: u16,
|
||||
idx: usize,
|
||||
tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
config: Builder,
|
||||
},
|
||||
#[cfg(unix)]
|
||||
#[state_machine_future(transitions(Handshaking))]
|
||||
ConnectingUnix {
|
||||
future: tokio_uds::ConnectFuture,
|
||||
connect_timeout: Option<Duration>,
|
||||
timeout: Option<Delay>,
|
||||
tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
config: Builder,
|
||||
},
|
||||
#[state_machine_future(transitions(ConnectingTcp))]
|
||||
ResolvingDns {
|
||||
future: CpuFuture<vec::IntoIter<SocketAddr>, io::Error>,
|
||||
connect_timeout: Option<Duration>,
|
||||
timeout: Option<Delay>,
|
||||
tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
config: Builder,
|
||||
},
|
||||
#[state_machine_future(transitions(Handshaking))]
|
||||
ConnectingTcp {
|
||||
future: tokio_tcp::ConnectFuture,
|
||||
addrs: vec::IntoIter<SocketAddr>,
|
||||
connect_timeout: Option<Duration>,
|
||||
timeout: Option<Delay>,
|
||||
tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
config: Builder,
|
||||
},
|
||||
#[state_machine_future(transitions(Finished))]
|
||||
Handshaking { future: HandshakeFuture<Socket, T> },
|
||||
@ -77,44 +74,41 @@ where
|
||||
T: TlsMode<Socket>,
|
||||
{
|
||||
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
|
||||
let mut state = state.take();
|
||||
let state = state.take();
|
||||
|
||||
let connect_timeout = match state.params.remove("connect_timeout") {
|
||||
Some(s) => {
|
||||
let seconds = s.parse::<i64>().map_err(Error::invalid_connect_timeout)?;
|
||||
if seconds <= 0 {
|
||||
None
|
||||
} else {
|
||||
Some(Duration::from_secs(seconds as u64))
|
||||
}
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
let timeout = connect_timeout.map(|d| Delay::new(Instant::now() + d));
|
||||
let port = *state
|
||||
.config
|
||||
.port
|
||||
.get(state.idx)
|
||||
.or_else(|| state.config.port.get(0))
|
||||
.unwrap_or(&5432);
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
if state.host.starts_with('/') {
|
||||
let path = Path::new(&state.host).join(format!(".s.PGSQL.{}", state.port));
|
||||
transition!(ConnectingUnix {
|
||||
future: UnixStream::connect(path),
|
||||
connect_timeout,
|
||||
let timeout = state
|
||||
.config
|
||||
.connect_timeout
|
||||
.map(|d| Delay::new(Instant::now() + d));
|
||||
|
||||
match &state.config.host[state.idx] {
|
||||
Host::Tcp(host) => {
|
||||
let host = host.clone();
|
||||
transition!(ResolvingDns {
|
||||
future: DNS_POOL.spawn_fn(move || (&*host, port).to_socket_addrs()),
|
||||
timeout,
|
||||
tls_mode: state.tls_mode,
|
||||
params: state.params,
|
||||
config: state.config,
|
||||
})
|
||||
}
|
||||
#[cfg(unix)]
|
||||
Host::Unix(host) => {
|
||||
let path = Path::new(host).join(format!(".s.PGSQL.{}", port));
|
||||
transition!(ConnectingUnix {
|
||||
future: UnixStream::connect(path),
|
||||
timeout,
|
||||
tls_mode: state.tls_mode,
|
||||
config: state.config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let host = state.host;
|
||||
let port = state.port;
|
||||
transition!(ResolvingDns {
|
||||
future: DNS_POOL.spawn_fn(move || (&*host, port).to_socket_addrs()),
|
||||
connect_timeout,
|
||||
timeout,
|
||||
tls_mode: state.tls_mode,
|
||||
params: state.params,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
@ -134,7 +128,7 @@ where
|
||||
let state = state.take();
|
||||
|
||||
transition!(Handshaking {
|
||||
future: HandshakeFuture::new(stream, state.tls_mode, state.params)
|
||||
future: HandshakeFuture::new(stream, state.tls_mode, state.config)
|
||||
})
|
||||
}
|
||||
|
||||
@ -165,10 +159,9 @@ where
|
||||
transition!(ConnectingTcp {
|
||||
future: TcpStream::connect(&addr),
|
||||
addrs,
|
||||
connect_timeout: state.connect_timeout,
|
||||
timeout: state.timeout,
|
||||
tls_mode: state.tls_mode,
|
||||
params: state.params,
|
||||
config: state.config,
|
||||
})
|
||||
}
|
||||
|
||||
@ -202,7 +195,7 @@ where
|
||||
let stream = Socket::new_tcp(stream);
|
||||
|
||||
transition!(Handshaking {
|
||||
future: HandshakeFuture::new(stream, state.tls_mode, state.params),
|
||||
future: HandshakeFuture::new(stream, state.tls_mode, state.config),
|
||||
})
|
||||
}
|
||||
|
||||
@ -219,12 +212,7 @@ impl<T> ConnectOnceFuture<T>
|
||||
where
|
||||
T: TlsMode<Socket>,
|
||||
{
|
||||
pub fn new(
|
||||
host: String,
|
||||
port: u16,
|
||||
tls_mode: T,
|
||||
params: HashMap<String, String>,
|
||||
) -> ConnectOnceFuture<T> {
|
||||
ConnectOnce::start(host, port, tls_mode, params)
|
||||
pub fn new(idx: usize, tls_mode: T, config: Builder) -> ConnectOnceFuture<T> {
|
||||
ConnectOnce::start(idx, tls_mode, config)
|
||||
}
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ use tokio_codec::Framed;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::proto::{Client, Connection, PostgresCodec, TlsFuture};
|
||||
use crate::{CancelData, ChannelBinding, Error, TlsMode};
|
||||
use crate::{Builder, CancelData, ChannelBinding, Error, TlsMode};
|
||||
|
||||
#[derive(StateMachineFuture)]
|
||||
pub enum Handshake<S, T>
|
||||
@ -24,20 +24,18 @@ where
|
||||
#[state_machine_future(start, transitions(SendingStartup))]
|
||||
Start {
|
||||
future: TlsFuture<S, T>,
|
||||
params: HashMap<String, String>,
|
||||
config: Builder,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingAuth))]
|
||||
SendingStartup {
|
||||
future: sink::Send<Framed<T::Stream, PostgresCodec>>,
|
||||
user: String,
|
||||
password: Option<String>,
|
||||
config: Builder,
|
||||
channel_binding: ChannelBinding,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))]
|
||||
ReadingAuth {
|
||||
stream: Framed<T::Stream, PostgresCodec>,
|
||||
user: String,
|
||||
password: Option<String>,
|
||||
config: Builder,
|
||||
channel_binding: ChannelBinding,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingAuthCompletion))]
|
||||
@ -77,31 +75,24 @@ where
|
||||
{
|
||||
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<S, T>>) -> Poll<AfterStart<S, T>, Error> {
|
||||
let (stream, channel_binding) = try_ready!(state.future.poll());
|
||||
let mut state = state.take();
|
||||
|
||||
// we don't want to send the password as a param
|
||||
let password = state.params.remove("password");
|
||||
|
||||
// libpq uses the parameter "dbname" but the protocol expects "database" (!?!)
|
||||
if let Some(dbname) = state.params.remove("dbname") {
|
||||
state.params.insert("database".to_string(), dbname);
|
||||
}
|
||||
let state = state.take();
|
||||
|
||||
let mut buf = vec![];
|
||||
frontend::startup_message(state.params.iter().map(|(k, v)| (&**k, &**v)), &mut buf)
|
||||
.map_err(Error::encode)?;
|
||||
frontend::startup_message(
|
||||
state.config.params.iter().map(|(k, v)| {
|
||||
// libpq uses dbname, but the backend expects database (!)
|
||||
let k = if k == "dbname" { "database" } else { &**k };
|
||||
(k, &**v)
|
||||
}),
|
||||
&mut buf,
|
||||
)
|
||||
.map_err(Error::encode)?;
|
||||
|
||||
let stream = Framed::new(stream, PostgresCodec);
|
||||
|
||||
let user = state
|
||||
.params
|
||||
.remove("user")
|
||||
.ok_or_else(Error::missing_user)?;
|
||||
|
||||
transition!(SendingStartup {
|
||||
future: stream.send(buf),
|
||||
user,
|
||||
password,
|
||||
config: state.config,
|
||||
channel_binding,
|
||||
})
|
||||
}
|
||||
@ -113,8 +104,7 @@ where
|
||||
let state = state.take();
|
||||
transition!(ReadingAuth {
|
||||
stream,
|
||||
user: state.user,
|
||||
password: state.password,
|
||||
config: state.config,
|
||||
channel_binding: state.channel_binding,
|
||||
})
|
||||
}
|
||||
@ -132,17 +122,29 @@ where
|
||||
parameters: HashMap::new(),
|
||||
}),
|
||||
Some(Message::AuthenticationCleartextPassword) => {
|
||||
let pass = state.password.ok_or_else(Error::missing_password)?;
|
||||
let pass = state
|
||||
.config
|
||||
.password
|
||||
.as_ref()
|
||||
.ok_or_else(Error::missing_password)?;
|
||||
let mut buf = vec![];
|
||||
frontend::password_message(pass.as_bytes(), &mut buf).map_err(Error::encode)?;
|
||||
frontend::password_message(pass, &mut buf).map_err(Error::encode)?;
|
||||
transition!(SendingPassword {
|
||||
future: state.stream.send(buf)
|
||||
})
|
||||
}
|
||||
Some(Message::AuthenticationMd5Password(body)) => {
|
||||
let pass = state.password.ok_or_else(Error::missing_password)?;
|
||||
let output =
|
||||
authentication::md5_hash(state.user.as_bytes(), pass.as_bytes(), body.salt());
|
||||
let user = state
|
||||
.config
|
||||
.params
|
||||
.get("user")
|
||||
.ok_or_else(Error::missing_user)?;
|
||||
let pass = state
|
||||
.config
|
||||
.password
|
||||
.as_ref()
|
||||
.ok_or_else(Error::missing_password)?;
|
||||
let output = authentication::md5_hash(user.as_bytes(), pass, body.salt());
|
||||
let mut buf = vec![];
|
||||
frontend::password_message(output.as_bytes(), &mut buf).map_err(Error::encode)?;
|
||||
transition!(SendingPassword {
|
||||
@ -150,7 +152,11 @@ where
|
||||
})
|
||||
}
|
||||
Some(Message::AuthenticationSasl(body)) => {
|
||||
let pass = state.password.ok_or_else(Error::missing_password)?;
|
||||
let pass = state
|
||||
.config
|
||||
.password
|
||||
.as_ref()
|
||||
.ok_or_else(Error::missing_password)?;
|
||||
|
||||
let mut has_scram = false;
|
||||
let mut has_scram_plus = false;
|
||||
@ -187,7 +193,7 @@ where
|
||||
return Err(Error::unsupported_authentication());
|
||||
};
|
||||
|
||||
let scram = ScramSha256::new(pass.as_bytes(), channel_binding);
|
||||
let scram = ScramSha256::new(pass, channel_binding);
|
||||
|
||||
let mut buf = vec![];
|
||||
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf)
|
||||
@ -324,7 +330,7 @@ where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
{
|
||||
pub fn new(stream: S, tls_mode: T, params: HashMap<String, String>) -> HandshakeFuture<S, T> {
|
||||
Handshake::start(TlsFuture::new(stream, tls_mode), params)
|
||||
pub fn new(stream: S, tls_mode: T, config: Builder) -> HandshakeFuture<S, T> {
|
||||
Handshake::start(TlsFuture::new(stream, tls_mode), config)
|
||||
}
|
||||
}
|
||||
|
@ -1,18 +1,14 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[test]
|
||||
fn pairs_ok() {
|
||||
let params = r"user=foo password=' fizz \'buzz\\ ' thing = ''"
|
||||
.parse::<tokio_postgres::Builder>()
|
||||
.unwrap();
|
||||
let params = params.iter().collect::<HashMap<_, _>>();
|
||||
|
||||
let mut expected = HashMap::new();
|
||||
expected.insert("user", "foo");
|
||||
expected.insert("password", r" fizz 'buzz\ ");
|
||||
expected.insert("thing", "");
|
||||
expected.insert("client_encoding", "UTF8");
|
||||
expected.insert("timezone", "GMT");
|
||||
let mut expected = tokio_postgres::Builder::new();
|
||||
expected
|
||||
.param("user", "foo")
|
||||
.password(r" fizz 'buzz\ ")
|
||||
.param("thing", "");
|
||||
|
||||
assert_eq!(params, expected);
|
||||
}
|
||||
@ -22,13 +18,9 @@ fn pairs_ws() {
|
||||
let params = " user\t=\r\n\x0bfoo \t password = hunter2 "
|
||||
.parse::<tokio_postgres::Builder>()
|
||||
.unwrap();;
|
||||
let params = params.iter().collect::<HashMap<_, _>>();
|
||||
|
||||
let mut expected = HashMap::new();
|
||||
expected.insert("user", "foo");
|
||||
expected.insert("password", r"hunter2");
|
||||
expected.insert("client_encoding", "UTF8");
|
||||
expected.insert("timezone", "GMT");
|
||||
let mut expected = tokio_postgres::Builder::new();
|
||||
expected.param("user", "foo").password("hunter2");
|
||||
|
||||
assert_eq!(params, expected);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user