Unify Unix and TCP connection creation

Not all Unix socket setups can be configured via a URL connection string
since paths need not be UTF8, so it's possible to directly pass a
`PostgresConnectParams` type into `connect`.

In addition, an SSL encrypted connection can be used via Unix sockets.
This commit is contained in:
Steven Fackler 2014-04-20 22:27:55 -07:00
parent 11193628e9
commit 5c866ba20b
3 changed files with 250 additions and 171 deletions

View File

@ -84,7 +84,7 @@ use openssl::ssl::{SslStream, SslContext};
use serialize::hex::ToHex; use serialize::hex::ToHex;
use std::cell::{Cell, RefCell}; use std::cell::{Cell, RefCell};
use std::from_str::FromStr; use std::from_str::FromStr;
use std::io::{BufferedStream, IoResult}; use std::io::{Stream, BufferedStream, IoResult};
use std::io::net; use std::io::net;
use std::io::net::ip::{Port, SocketAddr}; use std::io::net::ip::{Port, SocketAddr};
use std::io::net::tcp::TcpStream; use std::io::net::tcp::TcpStream;
@ -207,6 +207,101 @@ static CANARY: u32 = 0xdeadbeef;
/// A typedef of the result returned by many methods. /// A typedef of the result returned by many methods.
pub type PostgresResult<T> = Result<T, PostgresError>; pub type PostgresResult<T> = Result<T, PostgresError>;
/// Specifies the target server to connect to.
#[deriving(Clone)]
pub enum PostgresConnectTarget {
/// Connect via TCP to the specified host.
TargetTcp(~str),
/// Connect via a Unix domain socket in the specified directory.
TargetUnix(Path)
}
/// Information necessary to open a new connection to a Postgres server.
#[deriving(Clone)]
pub struct PostgresConnectParams {
/// The target server
pub target: PostgresConnectTarget,
/// The target port
pub port: Port,
/// The user to login as.
///
/// `PostgresConnection::connect` requires a user but `cancel_query` does
/// not.
pub user: Option<~str>,
/// An optional password used for authentication
pub password: Option<~str>,
/// The database to connect to. Defaults the value of `user`.
pub database: Option<~str>,
/// Runtime parameters to be passed to the Postgres backend.
pub options: Vec<(~str, ~str)>,
}
/// A trait implemented by types that can be converted into a
/// `PostgresConnectParams`.
pub trait IntoConnectParams {
/// Converts the value of `self` into a `PostgresConnectParams`.
fn into_connect_params(self) -> Result<PostgresConnectParams,
PostgresConnectError>;
}
impl IntoConnectParams for PostgresConnectParams {
fn into_connect_params(self) -> Result<PostgresConnectParams,
PostgresConnectError> {
Ok(self)
}
}
impl<'a> IntoConnectParams for &'a str {
fn into_connect_params(self) -> Result<PostgresConnectParams,
PostgresConnectError> {
let Url {
host,
port,
user,
path,
query: options,
..
}: Url = match FromStr::from_str(self) {
Some(url) => url,
None => return Err(InvalidUrl)
};
let host = url::decode_component(host);
let host = if host.starts_with("/") {
TargetUnix(Path::new(host))
} else {
TargetTcp(host)
};
let (user, pass) = match user {
Some(UserInfo { user, pass}) => (Some(user), pass),
None => (None, None),
};
let port = match port {
Some(port) => FromStr::from_str(port).unwrap(),
None => DEFAULT_PORT
};
let database = if !path.is_empty() {
// path contains the leading /
let (_, path) = path.slice_shift_char();
Some(path.to_owned())
} else {
None
};
Ok(PostgresConnectParams {
target: host,
port: port,
user: user,
password: pass,
database: database,
options: options,
})
}
}
/// Trait for types that can handle Postgres notice messages /// Trait for types that can handle Postgres notice messages
pub trait PostgresNoticeHandler { pub trait PostgresNoticeHandler {
/// Handle a Postgres notice message /// Handle a Postgres notice message
@ -269,7 +364,7 @@ pub struct PostgresCancelData {
/// `PostgresConnection::cancel_data`. The object can cancel any query made on /// `PostgresConnection::cancel_data`. The object can cancel any query made on
/// that connection. /// that connection.
/// ///
/// Only the host and port of the URL are used. /// Only the host and port of the connetion info are used.
/// ///
/// # Example /// # Example
/// ///
@ -284,18 +379,12 @@ pub struct PostgresCancelData {
/// # let _ = /// # let _ =
/// postgres::cancel_query(url, &NoSsl, cancel_data); /// postgres::cancel_query(url, &NoSsl, cancel_data);
/// ``` /// ```
pub fn cancel_query(url: &str, ssl: &SslMode, data: PostgresCancelData) pub fn cancel_query<T: IntoConnectParams>(params: T, ssl: &SslMode,
-> Result<(), PostgresConnectError> { data: PostgresCancelData)
let Url { host, port, .. }: Url = match FromStr::from_str(url) { -> Result<(), PostgresConnectError> {
Some(url) => url, let params = try!(params.into_connect_params());
None => return Err(InvalidUrl)
};
let port = match port {
Some(port) => FromStr::from_str(port).unwrap(),
None => DEFAULT_PORT
};
let mut socket = match initialize_stream(host, port, ssl) { let mut socket = match initialize_stream(&params, ssl) {
Ok(socket) => socket, Ok(socket) => socket,
Err(err) => return Err(err) Err(err) => return Err(err)
}; };
@ -310,8 +399,8 @@ pub fn cancel_query(url: &str, ssl: &SslMode, data: PostgresCancelData)
Ok(()) Ok(())
} }
fn open_socket(host: &str, port: Port) fn open_tcp_socket(host: &str, port: Port)
-> Result<TcpStream, PostgresConnectError> { -> Result<TcpStream, PostgresConnectError> {
let addrs = match net::get_host_addresses(host) { let addrs = match net::get_host_addresses(host) {
Ok(addrs) => addrs, Ok(addrs) => addrs,
Err(err) => return Err(DnsError(err)) Err(err) => return Err(DnsError(err))
@ -328,12 +417,93 @@ fn open_socket(host: &str, port: Port)
Err(SocketError(err.unwrap())) Err(SocketError(err.unwrap()))
} }
fn initialize_stream(host: &str, port: Port, ssl: &SslMode) fn open_unix_socket(path: &Path, port: Port) -> Result<UnixStream,
-> Result<InternalStream, PostgresConnectError> { PostgresConnectError> {
let mut socket = match open_socket(host, port) { let mut socket = path.clone();
Ok(socket) => socket, socket.push(format!(".s.PGSQL.{}", port));
Err(err) => return Err(err)
}; match UnixStream::connect(&socket) {
Ok(unix) => Ok(unix),
// FIXME bad error variant
Err(err) => Err(SocketError(err))
}
}
enum MaybeSslStream<S> {
SslStream(SslStream<S>),
NormalStream(S),
}
impl<S: Stream> Reader for MaybeSslStream<S> {
fn read(&mut self, buf: &mut [u8]) -> IoResult<uint> {
match *self {
SslStream(ref mut s) => s.read(buf),
NormalStream(ref mut s) => s.read(buf),
}
}
}
impl<S: Stream> Writer for MaybeSslStream<S> {
fn write(&mut self, buf: &[u8]) -> IoResult<()> {
match *self {
SslStream(ref mut s) => s.write(buf),
NormalStream(ref mut s) => s.write(buf),
}
}
fn flush(&mut self) -> IoResult<()> {
match *self {
NormalStream(ref mut s) => s.flush(),
SslStream(ref mut s) => s.flush(),
}
}
}
enum InternalStream {
TcpStream(TcpStream),
UnixStream(UnixStream),
}
impl Reader for InternalStream {
fn read(&mut self, buf: &mut [u8]) -> IoResult<uint> {
match *self {
TcpStream(ref mut s) => s.read(buf),
UnixStream(ref mut s) => s.read(buf),
}
}
}
impl Writer for InternalStream {
fn write(&mut self, buf: &[u8]) -> IoResult<()> {
match *self {
TcpStream(ref mut s) => s.write(buf),
UnixStream(ref mut s) => s.write(buf),
}
}
fn flush(&mut self) -> IoResult<()> {
match *self {
TcpStream(ref mut s) => s.flush(),
UnixStream(ref mut s) => s.flush(),
}
}
}
fn open_socket(params: &PostgresConnectParams)
-> Result<InternalStream, PostgresConnectError> {
match params.target {
TargetTcp(ref host) => open_tcp_socket(host.as_slice(), params.port)
.map(|s| TcpStream(s)),
TargetUnix(ref path) => open_unix_socket(path, params.port)
.map(|s| UnixStream(s)),
}
}
fn initialize_stream(params: &PostgresConnectParams, ssl: &SslMode)
-> Result<MaybeSslStream<InternalStream>,
PostgresConnectError> {
let mut socket = try!(open_socket(params));
let (ssl_required, ctx) = match *ssl { let (ssl_required, ctx) = match *ssl {
NoSsl => return Ok(NormalStream(socket)), NoSsl => return Ok(NormalStream(socket)),
@ -358,50 +528,8 @@ fn initialize_stream(host: &str, port: Port, ssl: &SslMode)
} }
} }
fn initialize_unix(path: &Path)
-> Result<InternalStream, PostgresConnectError> {
match UnixStream::connect(path) {
Ok(unix) => Ok(NormalUnix(unix)),
Err(err) => Err(SocketError(err))
}
}
enum InternalStream {
NormalStream(TcpStream),
SslStream(SslStream<TcpStream>),
NormalUnix(UnixStream)
}
impl Reader for InternalStream {
fn read(&mut self, buf: &mut [u8]) -> IoResult<uint> {
match *self {
NormalStream(ref mut s) => s.read(buf),
SslStream(ref mut s) => s.read(buf),
NormalUnix(ref mut s) => s.read(buf)
}
}
}
impl Writer for InternalStream {
fn write(&mut self, buf: &[u8]) -> IoResult<()> {
match *self {
NormalStream(ref mut s) => s.write(buf),
SslStream(ref mut s) => s.write(buf),
NormalUnix(ref mut s) => s.write(buf)
}
}
fn flush(&mut self) -> IoResult<()> {
match *self {
NormalStream(ref mut s) => s.flush(),
SslStream(ref mut s) => s.flush(),
NormalUnix(ref mut s) => s.flush()
}
}
}
struct InnerPostgresConnection { struct InnerPostgresConnection {
stream: BufferedStream<InternalStream>, stream: BufferedStream<MaybeSslStream<InternalStream>>,
next_stmt_id: uint, next_stmt_id: uint,
notice_handler: ~PostgresNoticeHandler:Send, notice_handler: ~PostgresNoticeHandler:Send,
notifications: RingBuf<PostgresNotification>, notifications: RingBuf<PostgresNotification>,
@ -421,56 +549,11 @@ impl Drop for InnerPostgresConnection {
} }
impl InnerPostgresConnection { impl InnerPostgresConnection {
fn connect(url: &str, ssl: &SslMode) -> Result<InnerPostgresConnection, fn connect<T: IntoConnectParams>(params: T, ssl: &SslMode)
PostgresConnectError> { -> Result<InnerPostgresConnection,
let Url { PostgresConnectError> {
host, let params = try!(params.into_connect_params());
port, let stream = try!(initialize_stream(&params, ssl));
user,
path,
query: mut args,
..
}: Url = match FromStr::from_str(url) {
Some(url) => url,
None => return Err(InvalidUrl)
};
let user = match user {
Some(user) => user,
None => return Err(MissingUser)
};
let port = match port {
Some(port) => FromStr::from_str(port).unwrap(),
None => DEFAULT_PORT
};
let stream = try!(initialize_stream(host, port, ssl));
if !path.is_empty() {
// path contains the leading /
let (_, path) = path.slice_shift_char();
args.push((~"database", path.to_owned()));
}
InnerPostgresConnection::connect_finish(stream, args, user)
}
fn connect_unix(socket_dir: &Path, port: Port, user: UserInfo, database: ~str)
-> Result<InnerPostgresConnection, PostgresConnectError> {
let mut socket = socket_dir.clone();
socket.push(format!(".s.PGSQL.{}", port));
let stream = try!(initialize_unix(&socket));
let mut args = Vec::new();
args.push((~"database", database));
InnerPostgresConnection::connect_finish(stream, args, user)
}
fn connect_finish(stream: InternalStream, mut args: Vec<(~str, ~str)>, user: UserInfo)
-> Result<InnerPostgresConnection, PostgresConnectError> {
let mut conn = InnerPostgresConnection { let mut conn = InnerPostgresConnection {
stream: BufferedStream::new(stream), stream: BufferedStream::new(stream),
@ -484,19 +567,36 @@ impl InnerPostgresConnection {
canary: CANARY, canary: CANARY,
}; };
args.push((~"client_encoding", ~"UTF8")); let PostgresConnectParams {
user,
password,
database,
mut options,
..
} = params;
let user = match user {
Some(user) => user,
None => return Err(MissingUser),
};
options.push((~"client_encoding", ~"UTF8"));
// Postgres uses the value of TimeZone as the time zone for TIMESTAMP // Postgres uses the value of TimeZone as the time zone for TIMESTAMP
// WITH TIME ZONE values. Timespec converts to GMT internally. // WITH TIME ZONE values. Timespec converts to GMT internally.
args.push((~"TimeZone", ~"GMT")); options.push((~"TimeZone", ~"GMT"));
// We have to clone here since we need the user again for auth // We have to clone here since we need the user again for auth
args.push((~"user", user.user.clone())); options.push((~"user", user.clone()));
match database {
Some(database) => options.push((~"database", database)),
None => {}
}
try_pg_conn!(conn.write_messages([StartupMessage { try_pg_conn!(conn.write_messages([StartupMessage {
version: message::PROTOCOL_VERSION, version: message::PROTOCOL_VERSION,
parameters: args.as_slice() parameters: options.as_slice()
}])); }]));
try!(conn.handle_auth(user)); try!(conn.handle_auth(user, password));
loop { loop {
match try_pg_conn!(conn.read_message()) { match try_pg_conn!(conn.read_message()) {
@ -542,12 +642,12 @@ impl InnerPostgresConnection {
} }
} }
fn handle_auth(&mut self, user: UserInfo) fn handle_auth(&mut self, user: ~str, pass: Option<~str>)
-> Result<(), PostgresConnectError> { -> Result<(), PostgresConnectError> {
match try_pg_conn!(self.read_message()) { match try_pg_conn!(self.read_message()) {
AuthenticationOk => return Ok(()), AuthenticationOk => return Ok(()),
AuthenticationCleartextPassword => { AuthenticationCleartextPassword => {
let pass = match user.pass { let pass = match pass {
Some(pass) => pass, Some(pass) => pass,
None => return Err(MissingPassword) None => return Err(MissingPassword)
}; };
@ -556,7 +656,6 @@ impl InnerPostgresConnection {
}])); }]));
} }
AuthenticationMD5Password { salt } => { AuthenticationMD5Password { salt } => {
let UserInfo { user, pass } = user;
let pass = match pass { let pass = match pass {
Some(pass) => pass, Some(pass) => pass,
None => return Err(MissingPassword) None => return Err(MissingPassword)
@ -734,67 +833,44 @@ pub struct PostgresConnection {
impl PostgresConnection { impl PostgresConnection {
/// Creates a new connection to a Postgres database. /// Creates a new connection to a Postgres database.
/// ///
/// The URL should be provided in the normal format: /// Most applications can use a URL string in the normal format:
/// ///
/// ```notrust /// ```notrust
/// postgres://user[:password]@host[:port][/database][?param1=val1[[&param2=val2]...]] /// postgresql://user[:password]@host[:port][/database][?param1=val1[[&param2=val2]...]]
/// ``` /// ```
/// ///
/// The password may be omitted if not required. The default Postgres port /// The password may be omitted if not required. The default Postgres port
/// (5432) is used if none is specified. The database name defaults to the /// (5432) is used if none is specified. The database name defaults to the
/// username if not specified. /// username if not specified.
/// ///
/// # Example /// To connect to the server via Unix sockets, `host` should be set to the
/// path to the directory containing the socket file. Since `/` is a
/// reserved character in URLs, the path should be URL encoded. If the
/// path contains non-UTF 8 characters, a `PostgresConnectParams` struct
/// should be created manually and passed in.
///
/// # Examples
/// ///
/// ```rust,no_run /// ```rust,no_run
/// # use postgres::{PostgresConnection, NoSsl}; /// # use postgres::{PostgresConnection, NoSsl};
/// let url = "postgres://postgres:hunter2@localhost:2994/foodb"; /// let url = "postgresql://postgres:hunter2@localhost:2994/foodb";
/// let maybe_conn = PostgresConnection::connect(url, &NoSsl); /// let maybe_conn = PostgresConnection::connect(url, &NoSsl);
/// let conn = match maybe_conn { /// let conn = match maybe_conn {
/// Ok(conn) => conn, /// Ok(conn) => conn,
/// Err(err) => fail!("Error connecting: {}", err) /// Err(err) => fail!("Error connecting: {}", err)
/// }; /// };
/// ``` /// ```
pub fn connect(url: &str, ssl: &SslMode) -> Result<PostgresConnection,
PostgresConnectError> {
InnerPostgresConnection::connect(url, ssl).map(|conn| {
PostgresConnection {
conn: RefCell::new(conn)
}
})
}
/// Creates a new connection to a Postgres database.
///
/// The path should be the directory containing the socket.
///
/// The password in the UserInfo may be omitted if not required.
///
/// # Example
/// ///
/// ```rust,no_run /// ```rust,no_run
/// # extern crate url; /// # use postgres::{PostgresConnection, NoSsl};
/// # extern crate postgres; /// let url = "postgresql://postgres@%2Ftmp";
/// # use url::UserInfo; /// let maybe_conn = PostgresConnection::connect(url, &NoSsl);
/// # use postgres::PostgresConnection;
/// # fn main() {
/// let path = Path::new("/tmp");
/// let port = 5432;
/// let user = UserInfo::new(~"username", None);
/// let database = ~"postgres";
/// let maybe_conn = PostgresConnection::connect_unix(&path, 5432, user, database);
/// let conn = match maybe_conn {
/// Ok(conn) => conn,
/// Err(err) => fail!("Error connecting: {}", err)
/// };
/// # }
/// ``` /// ```
pub fn connect_unix(path: &Path, port: Port, user: UserInfo, database: ~str) pub fn connect<T: IntoConnectParams>(params: T, ssl: &SslMode)
-> Result<PostgresConnection, PostgresConnectError> { -> Result<PostgresConnection,
InnerPostgresConnection::connect_unix(path, port, user, database).map(|conn| { PostgresConnectError> {
PostgresConnection { InnerPostgresConnection::connect(params, ssl).map(|conn| {
conn: RefCell::new(conn) PostgresConnection { conn: RefCell::new(conn) }
}
}) })
} }

View File

@ -4,6 +4,8 @@ use std::cast;
use sync::{Arc, Mutex}; use sync::{Arc, Mutex};
use {PostgresNotifications, use {PostgresNotifications,
PostgresConnectParams,
IntoConnectParams,
PostgresResult, PostgresResult,
PostgresCancelData, PostgresCancelData,
PostgresConnection, PostgresConnection,
@ -14,7 +16,7 @@ use error::PostgresConnectError;
use types::ToSql; use types::ToSql;
struct InnerConnectionPool { struct InnerConnectionPool {
url: ~str, params: PostgresConnectParams,
ssl: SslMode, ssl: SslMode,
// Actually Vec<~PostgresConnection> // Actually Vec<~PostgresConnection>
pool: Vec<*()>, pool: Vec<*()>,
@ -36,7 +38,7 @@ impl Drop for InnerConnectionPool {
impl InnerConnectionPool { impl InnerConnectionPool {
fn add_connection(&mut self) -> Result<(), PostgresConnectError> { fn add_connection(&mut self) -> Result<(), PostgresConnectError> {
PostgresConnection::connect(self.url, &self.ssl) PostgresConnection::connect(self.params.clone(), &self.ssl)
.map(|c| unsafe { self.pool.push(cast::transmute(~c)); }) .map(|c| unsafe { self.pool.push(cast::transmute(~c)); })
} }
} }
@ -50,7 +52,7 @@ impl InnerConnectionPool {
/// ```rust,no_run /// ```rust,no_run
/// # use postgres::NoSsl; /// # use postgres::NoSsl;
/// # use postgres::pool::PostgresConnectionPool; /// # use postgres::pool::PostgresConnectionPool;
/// let pool = PostgresConnectionPool::new(~"postgres://postgres@localhost", /// let pool = PostgresConnectionPool::new("postgres://postgres@localhost",
/// NoSsl, 5).unwrap(); /// NoSsl, 5).unwrap();
/// for _ in range(0, 10) { /// for _ in range(0, 10) {
/// let pool = pool.clone(); /// let pool = pool.clone();
@ -70,10 +72,10 @@ impl PostgresConnectionPool {
/// ///
/// Returns an error if the specified number of connections cannot be /// Returns an error if the specified number of connections cannot be
/// created. /// created.
pub fn new(url: ~str, ssl: SslMode, pool_size: uint) pub fn new<T: IntoConnectParams>(params: T, ssl: SslMode, pool_size: uint)
-> Result<PostgresConnectionPool, PostgresConnectError> { -> Result<PostgresConnectionPool, PostgresConnectError> {
let mut pool = InnerConnectionPool { let mut pool = InnerConnectionPool {
url: url, params: try!(params.into_connect_params()),
ssl: ssl, ssl: ssl,
pool: Vec::new(), pool: Vec::new(),
}; };

View File

@ -9,7 +9,7 @@ use openssl::ssl::{SslContext, Sslv3};
use std::f32; use std::f32;
use std::f64; use std::f64;
use std::io::timer; use std::io::timer;
use url::UserInfo; use url;
use {PostgresNoticeHandler, use {PostgresNoticeHandler,
PostgresNotification, PostgresNotification,
@ -50,7 +50,7 @@ macro_rules! or_fail(
#[test] #[test]
// Make sure we can take both connections at once and can still get one after // Make sure we can take both connections at once and can still get one after
fn test_pool() { fn test_pool() {
let pool = or_fail!(PostgresConnectionPool::new(~"postgres://postgres@localhost", let pool = or_fail!(PostgresConnectionPool::new("postgres://postgres@localhost",
NoSsl, 2)); NoSsl, 2));
let (stream1, stream2) = sync::duplex(); let (stream1, stream2) = sync::duplex();
@ -121,7 +121,8 @@ fn test_unix_connection() {
let unix_socket_directory = unix_socket_directories.splitn(',', 1).next().unwrap(); let unix_socket_directory = unix_socket_directories.splitn(',', 1).next().unwrap();
let conn = or_fail!(PostgresConnection::connect_unix(&Path::new(unix_socket_directory), 5432, UserInfo::new(~"postgres", None), ~"postgres")); let url = format!("postgres://postgres@{}", url::encode_component(unix_socket_directory));
let conn = or_fail!(PostgresConnection::connect(url.as_slice(), &NoSsl));
assert!(conn.finish().is_ok()); assert!(conn.finish().is_ok());
} }