Refactor connectparams

This commit is contained in:
Steven Fackler 2017-02-15 21:42:27 -08:00
parent 4c91a68dcc
commit c59799e376
7 changed files with 179 additions and 104 deletions

View File

@ -1,50 +1,142 @@
//! Connection parameters
use std::error::Error;
use std::path::PathBuf;
use std::mem;
use params::url::Url;
mod url;
/// Specifies the target server to connect to.
/// The host.
#[derive(Clone, Debug)]
pub enum ConnectTarget {
/// Connect via TCP to the specified host.
pub enum Host {
/// A TCP hostname.
Tcp(String),
/// Connect via a Unix domain socket in the specified directory.
///
/// Unix sockets are only supported on Unixy platforms (i.e. not Windows).
/// The path to a directory containing the server's Unix socket.
Unix(PathBuf),
}
/// Authentication information.
#[derive(Clone, Debug)]
pub struct UserInfo {
pub struct User {
name: String,
password: Option<String>,
}
impl User {
/// The username.
pub user: String,
pub fn name(&self) -> &str {
&self.name
}
/// An optional password.
pub password: Option<String>,
pub fn password(&self) -> Option<&str> {
self.password.as_ref().map(|p| &**p)
}
}
/// Information necessary to open a new connection to a Postgres server.
#[derive(Clone, Debug)]
pub struct ConnectParams {
/// The target server.
pub target: ConnectTarget,
host: Host,
port: u16,
user: Option<User>,
database: Option<String>,
options: Vec<(String, String)>,
}
impl ConnectParams {
/// Returns a new builder.
pub fn builder() -> Builder {
Builder::new()
}
/// The target host.
pub fn host(&self) -> &Host {
&self.host
}
/// The target port.
///
/// Defaults to 5432 if not specified.
pub port: Option<u16>,
/// The user to login as.
/// Defaults to 5432.
pub fn port(&self) -> u16 {
self.port
}
/// The user to log in as.
///
/// `Connection::connect` requires a user but `cancel_query` does not.
pub user: Option<UserInfo>,
/// A user is required to open a new connection but not to cancel a query.
pub fn user(&self) -> Option<&User> {
self.user.as_ref()
}
/// The database to connect to.
///
/// Defaults the value of `user`.
pub database: Option<String>,
pub fn database(&self) -> Option<&str> {
self.database.as_ref().map(|d| &**d)
}
/// Runtime parameters to be passed to the Postgres backend.
pub options: Vec<(String, String)>,
pub fn options(&self) -> &[(String, String)] {
&self.options
}
}
/// A builder for `ConnectParams`.
pub struct Builder {
port: u16,
user: Option<User>,
database: Option<String>,
options: Vec<(String, String)>,
}
impl Builder {
/// Creates a new builder.
pub fn new() -> Builder {
Builder {
port: 5432,
user: None,
database: None,
options: vec![],
}
}
/// Sets the port.
pub fn port(&mut self, port: u16) -> &mut Builder {
self.port = port;
self
}
/// Sets the user.
pub fn user(&mut self, name: &str, password: Option<&str>) -> &mut Builder {
self.user = Some(User {
name: name.to_string(),
password: password.map(ToString::to_string),
});
self
}
/// Sets the database.
pub fn database(&mut self, database: &str) -> &mut Builder {
self.database = Some(database.to_string());
self
}
/// Adds a runtime parameter.
pub fn option(&mut self, name: &str, value: &str) -> &mut Builder {
self.options.push((name.to_string(), value.to_string()));
self
}
/// Constructs a `ConnectParams` from the builder.
pub fn build(&mut self, host: Host) -> ConnectParams {
ConnectParams {
host: host,
port: self.port,
user: self.user.take(),
database: self.database.take(),
options: mem::replace(&mut self.options, vec![]),
}
}
}
/// A trait implemented by types that can be converted into a `ConnectParams`.
@ -78,35 +170,33 @@ impl IntoConnectParams for Url {
fn into_connect_params(self) -> Result<ConnectParams, Box<Error + Sync + Send>> {
let Url { host, port, user, path: url::Path { mut path, query: options, .. }, .. } = self;
let maybe_path = url::decode_component(&host)?;
let target = if maybe_path.starts_with('/') {
ConnectTarget::Unix(PathBuf::from(maybe_path))
} else {
ConnectTarget::Tcp(host)
};
let mut builder = ConnectParams::builder();
let user = user.map(|url::UserInfo { user, pass }| {
UserInfo {
user: user,
password: pass,
}
});
if let Some(port) = port {
builder.port(port);
}
let database = if path.is_empty() {
None
} else {
if let Some(info) = user {
builder.user(&info.user, info.pass.as_ref().map(|p| &**p));
}
if !path.is_empty() {
// path contains the leading /
path.remove(0);
Some(path)
builder.database(&path[1..]);
}
for (name, value) in options {
builder.option(&name, &value);
}
let maybe_path = url::decode_component(&host)?;
let host = if maybe_path.starts_with('/') {
Host::Unix(maybe_path.into())
} else {
Host::Tcp(maybe_path)
};
Ok(ConnectParams {
target: target,
port: port,
user: user,
database: database,
options: options,
})
Ok(builder.build(host))
}
}

View File

@ -96,7 +96,7 @@ use postgres_shared::rows::RowData;
use error::{Error, ConnectError, SqlState, DbError};
use tls::TlsHandshake;
use notification::{Notifications, Notification};
use params::{ConnectParams, IntoConnectParams, UserInfo};
use params::{ConnectParams, IntoConnectParams, User};
use priv_io::MessageStream;
use rows::{Rows, LazyRows};
use stmt::{Statement, Column};
@ -255,9 +255,7 @@ impl InnerConnection {
let params = params.into_connect_params().map_err(ConnectError::ConnectParams)?;
let stream = priv_io::initialize_stream(&params, tls)?;
let ConnectParams { user, database, mut options, .. } = params;
let user = match user {
let user = match params.user() {
Some(user) => user,
None => {
return Err(ConnectError::ConnectParams("User missing from connection parameters"
@ -285,14 +283,15 @@ impl InnerConnection {
has_typeinfo_composite_query: false,
};
let mut options = params.options().to_owned();
options.push(("client_encoding".to_owned(), "UTF8".to_owned()));
// Postgres uses the value of TimeZone as the time zone for TIMESTAMP
// WITH TIME ZONE values. Timespec converts to GMT internally.
options.push(("timezone".to_owned(), "GMT".to_owned()));
// We have to clone here since we need the user again for auth
options.push(("user".to_owned(), user.user.clone()));
if let Some(database) = database {
options.push(("database".to_owned(), database));
options.push(("user".to_owned(), user.name().to_owned()));
if let Some(database) = params.database() {
options.push(("database".to_owned(), database.to_owned()));
}
let options = options.iter().map(|&(ref a, ref b)| (&**a, &**b));
@ -390,21 +389,21 @@ impl InnerConnection {
}
}
fn handle_auth(&mut self, user: UserInfo) -> result::Result<(), ConnectError> {
fn handle_auth(&mut self, user: &User) -> result::Result<(), ConnectError> {
match self.read_message()? {
backend::Message::AuthenticationOk => return Ok(()),
backend::Message::AuthenticationCleartextPassword => {
let pass = user.password.ok_or_else(|| {
let pass = user.password().ok_or_else(|| {
ConnectError::ConnectParams("a password was requested but not provided".into())
})?;
self.stream.write_message(|buf| frontend::password_message(&pass, buf))?;
self.stream.write_message(|buf| frontend::password_message(pass, buf))?;
self.stream.flush()?;
}
backend::Message::AuthenticationMd5Password(body) => {
let pass = user.password.ok_or_else(|| {
let pass = user.password().ok_or_else(|| {
ConnectError::ConnectParams("a password was requested but not provided".into())
})?;
let output = authentication::md5_hash(user.user.as_bytes(),
let output = authentication::md5_hash(user.name().as_bytes(),
pass.as_bytes(),
body.salt());
self.stream.write_message(|buf| frontend::password_message(&output, buf))?;
@ -932,22 +931,15 @@ impl Connection {
///
/// ```rust,no_run
/// use postgres::{Connection, TlsMode};
/// use postgres::params::{UserInfo, ConnectParams, ConnectTarget};
/// use postgres::params::{ConnectParams, Host};
/// # use std::path::PathBuf;
///
/// # #[cfg(unix)]
/// # fn f() {
/// # let some_crazy_path = PathBuf::new();
/// let params = ConnectParams {
/// target: ConnectTarget::Unix(some_crazy_path),
/// port: None,
/// user: Some(UserInfo {
/// user: "postgres".to_owned(),
/// password: None
/// }),
/// database: None,
/// options: vec![],
/// };
/// let params = ConnectParams::builder()
/// .user("postgres", None)
/// .build(Host::Unix(some_crazy_path));
/// let conn = Connection::connect(params, TlsMode::None).unwrap();
/// # }
/// ```

View File

@ -16,7 +16,7 @@ use postgres_protocol::message::backend::{self, ParseResult};
use TlsMode;
use error::ConnectError;
use tls::TlsStream;
use params::{ConnectParams, ConnectTarget};
use params::{ConnectParams, Host};
const DEFAULT_PORT: u16 = 5432;
const MESSAGE_HEADER_SIZE: usize = 5;
@ -221,18 +221,18 @@ impl Write for InternalStream {
}
fn open_socket(params: &ConnectParams) -> Result<InternalStream, ConnectError> {
let port = params.port.unwrap_or(DEFAULT_PORT);
match params.target {
ConnectTarget::Tcp(ref host) => {
let port = params.port();
match *params.host() {
Host::Tcp(ref host) => {
Ok(TcpStream::connect(&(&**host, port)).map(InternalStream::Tcp)?)
}
#[cfg(unix)]
ConnectTarget::Unix(ref path) => {
Host::Unix(ref path) => {
let path = path.join(&format!(".s.PGSQL.{}", port));
Ok(UnixStream::connect(&path).map(InternalStream::Unix)?)
}
#[cfg(not(unix))]
ConnectTarget::Unix(..) => {
Host::Unix(..) => {
Err(ConnectError::Io(io::Error::new(io::ErrorKind::InvalidInput,
"unix sockets are not supported on this system")))
}
@ -265,10 +265,10 @@ pub fn initialize_stream(params: &ConnectParams,
}
}
let host = match params.target {
ConnectTarget::Tcp(ref host) => host,
let host = match *params.host() {
Host::Tcp(ref host) => host,
// Postgres doesn't support TLS over unix sockets
ConnectTarget::Unix(_) => return Err(ConnectError::Io(::bad_response())),
Host::Unix(_) => return Err(ConnectError::Io(::bad_response())),
};
handshaker.tls_handshake(host, socket).map_err(ConnectError::Tls)

View File

@ -978,8 +978,8 @@ fn url_unencoded_password() {
#[test]
fn url_encoded_password() {
let params = "postgresql://username%7b%7c:password%7b%7c@localhost".into_connect_params().unwrap();
assert_eq!("username{|", &params.user.as_ref().unwrap().user[..]);
assert_eq!("password{|", &params.user.as_ref().unwrap().password.as_ref().unwrap()[..]);
assert_eq!("username{|", params.user().unwrap().name());
assert_eq!("password{|", params.user().unwrap().password().unwrap());
}
#[test]

View File

@ -142,8 +142,8 @@ pub fn cancel_query<T>(params: T,
{
let params = match params.into_connect_params() {
Ok(params) => {
Either::A(stream::connect(params.target.clone(),
params.port.unwrap_or(5432),
Either::A(stream::connect(params.host().clone(),
params.port(),
tls_mode,
handle))
}
@ -264,8 +264,8 @@ impl Connection {
{
let fut = match params.into_connect_params() {
Ok(params) => {
Either::A(stream::connect(params.target.clone(),
params.port.unwrap_or(5432),
Either::A(stream::connect(params.host().clone(),
params.port(),
tls_mode,
handle)
.map(|s| (s, params)))
@ -301,9 +301,9 @@ impl Connection {
let result = {
let options = [("client_encoding", "UTF8"), ("timezone", "GMT")];
let options = options.iter().cloned();
let options = options.chain(params.user.as_ref().map(|u| ("user", &*u.user)));
let options = options.chain(params.database.as_ref().map(|d| ("database", &**d)));
let options = options.chain(params.options.iter().map(|e| (&*e.0, &*e.1)));
let options = options.chain(params.user().map(|u| ("user", u.name())));
let options = options.chain(params.database().map(|d| ("database", d)));
let options = options.chain(params.options().iter().map(|e| (&*e.0, &*e.1)));
frontend::startup_message(options, &mut buf)
};
@ -323,7 +323,7 @@ impl Connection {
let response = match m {
backend::Message::AuthenticationOk => Ok(None),
backend::Message::AuthenticationCleartextPassword => {
match params.user.as_ref().and_then(|u| u.password.as_ref()) {
match params.user().and_then(|u| u.password()) {
Some(pass) => {
let mut buf = vec![];
frontend::password_message(pass, &mut buf)
@ -337,7 +337,7 @@ impl Connection {
}
}
backend::Message::AuthenticationMd5Password(body) => {
match params.user.as_ref().and_then(|u| u.password.as_ref().map(|p| (&u.user, p))) {
match params.user().and_then(|u| u.password().map(|p| (u.name(), p))) {
Some((user, pass)) => {
let pass = authentication::md5_hash(user.as_bytes(),
pass.as_bytes(),

View File

@ -1,6 +1,6 @@
use futures::{BoxFuture, Future, IntoFuture, Async, Sink, Stream as FuturesStream};
use futures::future::Either;
use postgres_shared::params::ConnectTarget;
use postgres_shared::params::Host;
use postgres_protocol::message::backend::{self, ParseResult};
use postgres_protocol::message::frontend;
use std::io::{self, Read, Write};
@ -16,17 +16,17 @@ use tls::TlsStream;
pub type PostgresStream = Framed<Box<TlsStream>, PostgresCodec>;
pub fn connect(host: ConnectTarget,
pub fn connect(host: Host,
port: u16,
tls_mode: TlsMode,
handle: &Handle)
-> BoxFuture<PostgresStream, ConnectError> {
let inner = match host {
ConnectTarget::Tcp(ref host) => {
Host::Tcp(ref host) => {
Either::A(tokio_dns::tcp_connect((&**host, port), handle.remote().clone())
.map(|s| Stream(InnerStream::Tcp(s))))
}
ConnectTarget::Unix(ref host) => {
Host::Unix(ref host) => {
let addr = host.join(format!(".s.PGSQL.{}", port));
Either::B(UnixStream::connect(addr, handle)
.map(|s| Stream(InnerStream::Unix(s)))
@ -68,8 +68,8 @@ pub fn connect(host: ConnectTarget,
(None, _) => Either::A(Err(ConnectError::Io(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF"))).into_future()),
_ => {
let host = match host {
ConnectTarget::Tcp(ref host) => host,
ConnectTarget::Unix(_) => unreachable!(),
Host::Tcp(ref host) => host,
Host::Unix(_) => unreachable!(),
};
Either::B(handshaker.handshake(host, s).map_err(ConnectError::Tls))
}

View File

@ -7,7 +7,7 @@ use tokio_core::reactor::{Core, Interval};
use super::*;
use error::{Error, ConnectError, SqlState};
use params::{ConnectParams, ConnectTarget, UserInfo};
use params::{ConnectParams, Host};
use types::{ToSql, FromSql, Type, IsNull, Kind};
#[test]
@ -182,16 +182,9 @@ fn unix_socket() {
.and_then(|(s, c)| c.query(&s, &[]).collect())
.then(|r| {
let r = r.unwrap().0;
let params = ConnectParams {
target: ConnectTarget::Unix(PathBuf::from(r[0].get::<String, _>(0))),
port: None,
user: Some(UserInfo {
user: "postgres".to_owned(),
password: None,
}),
database: None,
options: vec![],
};
let params = ConnectParams::builder()
.user("postgres", None)
.build(Host::Unix(PathBuf::from(r[0].get::<String, _>(0))));
Connection::connect(params, TlsMode::None, &handle)
})
.then(|c| c.unwrap().batch_execute(""));