Wrap Builder in an Arc

The builder ends up being cloned a couple of times per connection, so
use Arc::get_mut to make that faster.
This commit is contained in:
Steven Fackler 2018-12-28 14:16:38 -05:00
parent 540bcc5556
commit 634d24a951
4 changed files with 35 additions and 17 deletions

View File

@ -3,6 +3,7 @@ use std::iter;
#[cfg(all(feature = "runtime", unix))] #[cfg(all(feature = "runtime", unix))]
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::str::{self, FromStr}; use std::str::{self, FromStr};
use std::sync::Arc;
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
use std::time::Duration; use std::time::Duration;
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
@ -23,7 +24,7 @@ pub(crate) enum Host {
} }
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct Builder { pub(crate) struct Inner {
pub(crate) params: HashMap<String, String>, pub(crate) params: HashMap<String, String>,
pub(crate) password: Option<Vec<u8>>, pub(crate) password: Option<Vec<u8>>,
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
@ -34,6 +35,9 @@ pub struct Builder {
pub(crate) connect_timeout: Option<Duration>, pub(crate) connect_timeout: Option<Duration>,
} }
#[derive(Debug, Clone, PartialEq)]
pub struct Builder(pub(crate) Arc<Inner>);
impl Default for Builder { impl Default for Builder {
fn default() -> Builder { fn default() -> Builder {
Builder::new() Builder::new()
@ -46,7 +50,7 @@ impl Builder {
params.insert("client_encoding".to_string(), "UTF8".to_string()); params.insert("client_encoding".to_string(), "UTF8".to_string());
params.insert("timezone".to_string(), "GMT".to_string()); params.insert("timezone".to_string(), "GMT".to_string());
Builder { Builder(Arc::new(Inner {
params, params,
password: None, password: None,
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
@ -55,7 +59,7 @@ impl Builder {
port: vec![], port: vec![],
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
connect_timeout: None, connect_timeout: None,
} }))
} }
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
@ -63,12 +67,16 @@ impl Builder {
#[cfg(unix)] #[cfg(unix)]
{ {
if host.starts_with('/') { if host.starts_with('/') {
self.host.push(Host::Unix(PathBuf::from(host))); Arc::make_mut(&mut self.0)
.host
.push(Host::Unix(PathBuf::from(host)));
return self; return self;
} }
} }
self.host.push(Host::Tcp(host.to_string())); Arc::make_mut(&mut self.0)
.host
.push(Host::Tcp(host.to_string()));
self self
} }
@ -77,19 +85,21 @@ impl Builder {
where where
T: AsRef<Path>, T: AsRef<Path>,
{ {
self.host.push(Host::Unix(host.as_ref().to_path_buf())); Arc::make_mut(&mut self.0)
.host
.push(Host::Unix(host.as_ref().to_path_buf()));
self self
} }
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
pub fn port(&mut self, port: u16) -> &mut Builder { pub fn port(&mut self, port: u16) -> &mut Builder {
self.port.push(port); Arc::make_mut(&mut self.0).port.push(port);
self self
} }
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Builder { pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Builder {
self.connect_timeout = Some(connect_timeout); Arc::make_mut(&mut self.0).connect_timeout = Some(connect_timeout);
self self
} }
@ -97,12 +107,14 @@ impl Builder {
where where
T: AsRef<[u8]>, T: AsRef<[u8]>,
{ {
self.password = Some(password.as_ref().to_vec()); Arc::make_mut(&mut self.0).password = Some(password.as_ref().to_vec());
self self
} }
pub fn param(&mut self, key: &str, value: &str) -> &mut Builder { pub fn param(&mut self, key: &str, value: &str) -> &mut Builder {
self.params.insert(key.to_string(), value.to_string()); Arc::make_mut(&mut self.0)
.params
.insert(key.to_string(), value.to_string());
self self
} }

View File

@ -38,15 +38,15 @@ where
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
let mut state = state.take(); let mut state = state.take();
if state.config.host.is_empty() { if state.config.0.host.is_empty() {
return Err(Error::missing_host()); return Err(Error::missing_host());
} }
if state.config.port.len() > 1 && state.config.port.len() != state.config.host.len() { if state.config.0.port.len() > 1 && state.config.0.port.len() != state.config.0.host.len() {
return Err(Error::invalid_port_count()); return Err(Error::invalid_port_count());
} }
let hostname = match &state.config.host[0] { let hostname = match &state.config.0.host[0] {
Host::Tcp(host) => &**host, Host::Tcp(host) => &**host,
// postgres doesn't support TLS over unix sockets, so the choice here doesn't matter // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
#[cfg(unix)] #[cfg(unix)]
@ -86,7 +86,7 @@ where
let mut state = state.take(); let mut state = state.take();
let idx = state.idx + 1; let idx = state.idx + 1;
let host = match state.config.host.get(idx) { let host = match state.config.0.host.get(idx) {
Some(host) => host, Some(host) => host,
None => return Err(e), None => return Err(e),
}; };

View File

@ -78,17 +78,19 @@ where
let port = *state let port = *state
.config .config
.0
.port .port
.get(state.idx) .get(state.idx)
.or_else(|| state.config.port.get(0)) .or_else(|| state.config.0.port.get(0))
.unwrap_or(&5432); .unwrap_or(&5432);
let timeout = state let timeout = state
.config .config
.0
.connect_timeout .connect_timeout
.map(|d| Delay::new(Instant::now() + d)); .map(|d| Delay::new(Instant::now() + d));
match &state.config.host[state.idx] { match &state.config.0.host[state.idx] {
Host::Tcp(host) => { Host::Tcp(host) => {
let host = host.clone(); let host = host.clone();
transition!(ResolvingDns { transition!(ResolvingDns {

View File

@ -79,7 +79,7 @@ where
let mut buf = vec![]; let mut buf = vec![];
frontend::startup_message( frontend::startup_message(
state.config.params.iter().map(|(k, v)| { state.config.0.params.iter().map(|(k, v)| {
// libpq uses dbname, but the backend expects database (!) // libpq uses dbname, but the backend expects database (!)
let k = if k == "dbname" { "database" } else { &**k }; let k = if k == "dbname" { "database" } else { &**k };
(k, &**v) (k, &**v)
@ -124,6 +124,7 @@ where
Some(Message::AuthenticationCleartextPassword) => { Some(Message::AuthenticationCleartextPassword) => {
let pass = state let pass = state
.config .config
.0
.password .password
.as_ref() .as_ref()
.ok_or_else(Error::missing_password)?; .ok_or_else(Error::missing_password)?;
@ -136,11 +137,13 @@ where
Some(Message::AuthenticationMd5Password(body)) => { Some(Message::AuthenticationMd5Password(body)) => {
let user = state let user = state
.config .config
.0
.params .params
.get("user") .get("user")
.ok_or_else(Error::missing_user)?; .ok_or_else(Error::missing_user)?;
let pass = state let pass = state
.config .config
.0
.password .password
.as_ref() .as_ref()
.ok_or_else(Error::missing_password)?; .ok_or_else(Error::missing_password)?;
@ -154,6 +157,7 @@ where
Some(Message::AuthenticationSasl(body)) => { Some(Message::AuthenticationSasl(body)) => {
let pass = state let pass = state
.config .config
.0
.password .password
.as_ref() .as_ref()
.ok_or_else(Error::missing_password)?; .ok_or_else(Error::missing_password)?;