Wrap Builder in an Arc

The builder ends up being cloned a couple of times per connection, so
use Arc::get_mut to make that faster.
This commit is contained in:
Steven Fackler 2018-12-28 14:16:38 -05:00
parent 540bcc5556
commit 634d24a951
4 changed files with 35 additions and 17 deletions

View File

@ -3,6 +3,7 @@ use std::iter;
#[cfg(all(feature = "runtime", unix))]
use std::path::{Path, PathBuf};
use std::str::{self, FromStr};
use std::sync::Arc;
#[cfg(feature = "runtime")]
use std::time::Duration;
use tokio_io::{AsyncRead, AsyncWrite};
@ -23,7 +24,7 @@ pub(crate) enum Host {
}
#[derive(Debug, Clone, PartialEq)]
pub struct Builder {
pub(crate) struct Inner {
pub(crate) params: HashMap<String, String>,
pub(crate) password: Option<Vec<u8>>,
#[cfg(feature = "runtime")]
@ -34,6 +35,9 @@ pub struct Builder {
pub(crate) connect_timeout: Option<Duration>,
}
#[derive(Debug, Clone, PartialEq)]
pub struct Builder(pub(crate) Arc<Inner>);
impl Default for Builder {
fn default() -> Builder {
Builder::new()
@ -46,7 +50,7 @@ impl Builder {
params.insert("client_encoding".to_string(), "UTF8".to_string());
params.insert("timezone".to_string(), "GMT".to_string());
Builder {
Builder(Arc::new(Inner {
params,
password: None,
#[cfg(feature = "runtime")]
@ -55,7 +59,7 @@ impl Builder {
port: vec![],
#[cfg(feature = "runtime")]
connect_timeout: None,
}
}))
}
#[cfg(feature = "runtime")]
@ -63,12 +67,16 @@ impl Builder {
#[cfg(unix)]
{
if host.starts_with('/') {
self.host.push(Host::Unix(PathBuf::from(host)));
Arc::make_mut(&mut self.0)
.host
.push(Host::Unix(PathBuf::from(host)));
return self;
}
}
self.host.push(Host::Tcp(host.to_string()));
Arc::make_mut(&mut self.0)
.host
.push(Host::Tcp(host.to_string()));
self
}
@ -77,19 +85,21 @@ impl Builder {
where
T: AsRef<Path>,
{
self.host.push(Host::Unix(host.as_ref().to_path_buf()));
Arc::make_mut(&mut self.0)
.host
.push(Host::Unix(host.as_ref().to_path_buf()));
self
}
#[cfg(feature = "runtime")]
pub fn port(&mut self, port: u16) -> &mut Builder {
self.port.push(port);
Arc::make_mut(&mut self.0).port.push(port);
self
}
#[cfg(feature = "runtime")]
pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Builder {
self.connect_timeout = Some(connect_timeout);
Arc::make_mut(&mut self.0).connect_timeout = Some(connect_timeout);
self
}
@ -97,12 +107,14 @@ impl Builder {
where
T: AsRef<[u8]>,
{
self.password = Some(password.as_ref().to_vec());
Arc::make_mut(&mut self.0).password = Some(password.as_ref().to_vec());
self
}
pub fn param(&mut self, key: &str, value: &str) -> &mut Builder {
self.params.insert(key.to_string(), value.to_string());
Arc::make_mut(&mut self.0)
.params
.insert(key.to_string(), value.to_string());
self
}

View File

@ -38,15 +38,15 @@ where
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
let mut state = state.take();
if state.config.host.is_empty() {
if state.config.0.host.is_empty() {
return Err(Error::missing_host());
}
if state.config.port.len() > 1 && state.config.port.len() != state.config.host.len() {
if state.config.0.port.len() > 1 && state.config.0.port.len() != state.config.0.host.len() {
return Err(Error::invalid_port_count());
}
let hostname = match &state.config.host[0] {
let hostname = match &state.config.0.host[0] {
Host::Tcp(host) => &**host,
// postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
#[cfg(unix)]
@ -86,7 +86,7 @@ where
let mut state = state.take();
let idx = state.idx + 1;
let host = match state.config.host.get(idx) {
let host = match state.config.0.host.get(idx) {
Some(host) => host,
None => return Err(e),
};

View File

@ -78,17 +78,19 @@ where
let port = *state
.config
.0
.port
.get(state.idx)
.or_else(|| state.config.port.get(0))
.or_else(|| state.config.0.port.get(0))
.unwrap_or(&5432);
let timeout = state
.config
.0
.connect_timeout
.map(|d| Delay::new(Instant::now() + d));
match &state.config.host[state.idx] {
match &state.config.0.host[state.idx] {
Host::Tcp(host) => {
let host = host.clone();
transition!(ResolvingDns {

View File

@ -79,7 +79,7 @@ where
let mut buf = vec![];
frontend::startup_message(
state.config.params.iter().map(|(k, v)| {
state.config.0.params.iter().map(|(k, v)| {
// libpq uses dbname, but the backend expects database (!)
let k = if k == "dbname" { "database" } else { &**k };
(k, &**v)
@ -124,6 +124,7 @@ where
Some(Message::AuthenticationCleartextPassword) => {
let pass = state
.config
.0
.password
.as_ref()
.ok_or_else(Error::missing_password)?;
@ -136,11 +137,13 @@ where
Some(Message::AuthenticationMd5Password(body)) => {
let user = state
.config
.0
.params
.get("user")
.ok_or_else(Error::missing_user)?;
let pass = state
.config
.0
.password
.as_ref()
.ok_or_else(Error::missing_password)?;
@ -154,6 +157,7 @@ where
Some(Message::AuthenticationSasl(body)) => {
let pass = state
.config
.0
.password
.as_ref()
.ok_or_else(Error::missing_password)?;