Unify Unix and TCP connection creation

Not all Unix socket setups can be configured via a URL connection string
since paths need not be UTF8, so it's possible to directly pass a
`PostgresConnectParams` type into `connect`.

In addition, an SSL encrypted connection can be used via Unix sockets.
This commit is contained in:
Steven Fackler 2014-04-20 22:27:55 -07:00
parent 11193628e9
commit 5c866ba20b
3 changed files with 250 additions and 171 deletions

View File

@ -84,7 +84,7 @@ use openssl::ssl::{SslStream, SslContext};
use serialize::hex::ToHex;
use std::cell::{Cell, RefCell};
use std::from_str::FromStr;
use std::io::{BufferedStream, IoResult};
use std::io::{Stream, BufferedStream, IoResult};
use std::io::net;
use std::io::net::ip::{Port, SocketAddr};
use std::io::net::tcp::TcpStream;
@ -207,6 +207,101 @@ static CANARY: u32 = 0xdeadbeef;
/// A typedef of the result returned by many methods.
pub type PostgresResult<T> = Result<T, PostgresError>;
/// Specifies the target server to connect to.
#[deriving(Clone)]
pub enum PostgresConnectTarget {
/// Connect via TCP to the specified host.
TargetTcp(~str),
/// Connect via a Unix domain socket in the specified directory.
TargetUnix(Path)
}
/// Information necessary to open a new connection to a Postgres server.
#[deriving(Clone)]
pub struct PostgresConnectParams {
/// The target server
pub target: PostgresConnectTarget,
/// The target port
pub port: Port,
/// The user to login as.
///
/// `PostgresConnection::connect` requires a user but `cancel_query` does
/// not.
pub user: Option<~str>,
/// An optional password used for authentication
pub password: Option<~str>,
/// The database to connect to. Defaults the value of `user`.
pub database: Option<~str>,
/// Runtime parameters to be passed to the Postgres backend.
pub options: Vec<(~str, ~str)>,
}
/// A trait implemented by types that can be converted into a
/// `PostgresConnectParams`.
pub trait IntoConnectParams {
/// Converts the value of `self` into a `PostgresConnectParams`.
fn into_connect_params(self) -> Result<PostgresConnectParams,
PostgresConnectError>;
}
impl IntoConnectParams for PostgresConnectParams {
fn into_connect_params(self) -> Result<PostgresConnectParams,
PostgresConnectError> {
Ok(self)
}
}
impl<'a> IntoConnectParams for &'a str {
fn into_connect_params(self) -> Result<PostgresConnectParams,
PostgresConnectError> {
let Url {
host,
port,
user,
path,
query: options,
..
}: Url = match FromStr::from_str(self) {
Some(url) => url,
None => return Err(InvalidUrl)
};
let host = url::decode_component(host);
let host = if host.starts_with("/") {
TargetUnix(Path::new(host))
} else {
TargetTcp(host)
};
let (user, pass) = match user {
Some(UserInfo { user, pass}) => (Some(user), pass),
None => (None, None),
};
let port = match port {
Some(port) => FromStr::from_str(port).unwrap(),
None => DEFAULT_PORT
};
let database = if !path.is_empty() {
// path contains the leading /
let (_, path) = path.slice_shift_char();
Some(path.to_owned())
} else {
None
};
Ok(PostgresConnectParams {
target: host,
port: port,
user: user,
password: pass,
database: database,
options: options,
})
}
}
/// Trait for types that can handle Postgres notice messages
pub trait PostgresNoticeHandler {
/// Handle a Postgres notice message
@ -269,7 +364,7 @@ pub struct PostgresCancelData {
/// `PostgresConnection::cancel_data`. The object can cancel any query made on
/// that connection.
///
/// Only the host and port of the URL are used.
/// Only the host and port of the connetion info are used.
///
/// # Example
///
@ -284,18 +379,12 @@ pub struct PostgresCancelData {
/// # let _ =
/// postgres::cancel_query(url, &NoSsl, cancel_data);
/// ```
pub fn cancel_query(url: &str, ssl: &SslMode, data: PostgresCancelData)
-> Result<(), PostgresConnectError> {
let Url { host, port, .. }: Url = match FromStr::from_str(url) {
Some(url) => url,
None => return Err(InvalidUrl)
};
let port = match port {
Some(port) => FromStr::from_str(port).unwrap(),
None => DEFAULT_PORT
};
pub fn cancel_query<T: IntoConnectParams>(params: T, ssl: &SslMode,
data: PostgresCancelData)
-> Result<(), PostgresConnectError> {
let params = try!(params.into_connect_params());
let mut socket = match initialize_stream(host, port, ssl) {
let mut socket = match initialize_stream(&params, ssl) {
Ok(socket) => socket,
Err(err) => return Err(err)
};
@ -310,8 +399,8 @@ pub fn cancel_query(url: &str, ssl: &SslMode, data: PostgresCancelData)
Ok(())
}
fn open_socket(host: &str, port: Port)
-> Result<TcpStream, PostgresConnectError> {
fn open_tcp_socket(host: &str, port: Port)
-> Result<TcpStream, PostgresConnectError> {
let addrs = match net::get_host_addresses(host) {
Ok(addrs) => addrs,
Err(err) => return Err(DnsError(err))
@ -328,12 +417,93 @@ fn open_socket(host: &str, port: Port)
Err(SocketError(err.unwrap()))
}
fn initialize_stream(host: &str, port: Port, ssl: &SslMode)
-> Result<InternalStream, PostgresConnectError> {
let mut socket = match open_socket(host, port) {
Ok(socket) => socket,
Err(err) => return Err(err)
};
fn open_unix_socket(path: &Path, port: Port) -> Result<UnixStream,
PostgresConnectError> {
let mut socket = path.clone();
socket.push(format!(".s.PGSQL.{}", port));
match UnixStream::connect(&socket) {
Ok(unix) => Ok(unix),
// FIXME bad error variant
Err(err) => Err(SocketError(err))
}
}
enum MaybeSslStream<S> {
SslStream(SslStream<S>),
NormalStream(S),
}
impl<S: Stream> Reader for MaybeSslStream<S> {
fn read(&mut self, buf: &mut [u8]) -> IoResult<uint> {
match *self {
SslStream(ref mut s) => s.read(buf),
NormalStream(ref mut s) => s.read(buf),
}
}
}
impl<S: Stream> Writer for MaybeSslStream<S> {
fn write(&mut self, buf: &[u8]) -> IoResult<()> {
match *self {
SslStream(ref mut s) => s.write(buf),
NormalStream(ref mut s) => s.write(buf),
}
}
fn flush(&mut self) -> IoResult<()> {
match *self {
NormalStream(ref mut s) => s.flush(),
SslStream(ref mut s) => s.flush(),
}
}
}
enum InternalStream {
TcpStream(TcpStream),
UnixStream(UnixStream),
}
impl Reader for InternalStream {
fn read(&mut self, buf: &mut [u8]) -> IoResult<uint> {
match *self {
TcpStream(ref mut s) => s.read(buf),
UnixStream(ref mut s) => s.read(buf),
}
}
}
impl Writer for InternalStream {
fn write(&mut self, buf: &[u8]) -> IoResult<()> {
match *self {
TcpStream(ref mut s) => s.write(buf),
UnixStream(ref mut s) => s.write(buf),
}
}
fn flush(&mut self) -> IoResult<()> {
match *self {
TcpStream(ref mut s) => s.flush(),
UnixStream(ref mut s) => s.flush(),
}
}
}
fn open_socket(params: &PostgresConnectParams)
-> Result<InternalStream, PostgresConnectError> {
match params.target {
TargetTcp(ref host) => open_tcp_socket(host.as_slice(), params.port)
.map(|s| TcpStream(s)),
TargetUnix(ref path) => open_unix_socket(path, params.port)
.map(|s| UnixStream(s)),
}
}
fn initialize_stream(params: &PostgresConnectParams, ssl: &SslMode)
-> Result<MaybeSslStream<InternalStream>,
PostgresConnectError> {
let mut socket = try!(open_socket(params));
let (ssl_required, ctx) = match *ssl {
NoSsl => return Ok(NormalStream(socket)),
@ -358,50 +528,8 @@ fn initialize_stream(host: &str, port: Port, ssl: &SslMode)
}
}
fn initialize_unix(path: &Path)
-> Result<InternalStream, PostgresConnectError> {
match UnixStream::connect(path) {
Ok(unix) => Ok(NormalUnix(unix)),
Err(err) => Err(SocketError(err))
}
}
enum InternalStream {
NormalStream(TcpStream),
SslStream(SslStream<TcpStream>),
NormalUnix(UnixStream)
}
impl Reader for InternalStream {
fn read(&mut self, buf: &mut [u8]) -> IoResult<uint> {
match *self {
NormalStream(ref mut s) => s.read(buf),
SslStream(ref mut s) => s.read(buf),
NormalUnix(ref mut s) => s.read(buf)
}
}
}
impl Writer for InternalStream {
fn write(&mut self, buf: &[u8]) -> IoResult<()> {
match *self {
NormalStream(ref mut s) => s.write(buf),
SslStream(ref mut s) => s.write(buf),
NormalUnix(ref mut s) => s.write(buf)
}
}
fn flush(&mut self) -> IoResult<()> {
match *self {
NormalStream(ref mut s) => s.flush(),
SslStream(ref mut s) => s.flush(),
NormalUnix(ref mut s) => s.flush()
}
}
}
struct InnerPostgresConnection {
stream: BufferedStream<InternalStream>,
stream: BufferedStream<MaybeSslStream<InternalStream>>,
next_stmt_id: uint,
notice_handler: ~PostgresNoticeHandler:Send,
notifications: RingBuf<PostgresNotification>,
@ -421,56 +549,11 @@ impl Drop for InnerPostgresConnection {
}
impl InnerPostgresConnection {
fn connect(url: &str, ssl: &SslMode) -> Result<InnerPostgresConnection,
PostgresConnectError> {
let Url {
host,
port,
user,
path,
query: mut args,
..
}: Url = match FromStr::from_str(url) {
Some(url) => url,
None => return Err(InvalidUrl)
};
let user = match user {
Some(user) => user,
None => return Err(MissingUser)
};
let port = match port {
Some(port) => FromStr::from_str(port).unwrap(),
None => DEFAULT_PORT
};
let stream = try!(initialize_stream(host, port, ssl));
if !path.is_empty() {
// path contains the leading /
let (_, path) = path.slice_shift_char();
args.push((~"database", path.to_owned()));
}
InnerPostgresConnection::connect_finish(stream, args, user)
}
fn connect_unix(socket_dir: &Path, port: Port, user: UserInfo, database: ~str)
-> Result<InnerPostgresConnection, PostgresConnectError> {
let mut socket = socket_dir.clone();
socket.push(format!(".s.PGSQL.{}", port));
let stream = try!(initialize_unix(&socket));
let mut args = Vec::new();
args.push((~"database", database));
InnerPostgresConnection::connect_finish(stream, args, user)
}
fn connect_finish(stream: InternalStream, mut args: Vec<(~str, ~str)>, user: UserInfo)
-> Result<InnerPostgresConnection, PostgresConnectError> {
fn connect<T: IntoConnectParams>(params: T, ssl: &SslMode)
-> Result<InnerPostgresConnection,
PostgresConnectError> {
let params = try!(params.into_connect_params());
let stream = try!(initialize_stream(&params, ssl));
let mut conn = InnerPostgresConnection {
stream: BufferedStream::new(stream),
@ -484,19 +567,36 @@ impl InnerPostgresConnection {
canary: CANARY,
};
args.push((~"client_encoding", ~"UTF8"));
let PostgresConnectParams {
user,
password,
database,
mut options,
..
} = params;
let user = match user {
Some(user) => user,
None => return Err(MissingUser),
};
options.push((~"client_encoding", ~"UTF8"));
// Postgres uses the value of TimeZone as the time zone for TIMESTAMP
// WITH TIME ZONE values. Timespec converts to GMT internally.
args.push((~"TimeZone", ~"GMT"));
options.push((~"TimeZone", ~"GMT"));
// We have to clone here since we need the user again for auth
args.push((~"user", user.user.clone()));
options.push((~"user", user.clone()));
match database {
Some(database) => options.push((~"database", database)),
None => {}
}
try_pg_conn!(conn.write_messages([StartupMessage {
version: message::PROTOCOL_VERSION,
parameters: args.as_slice()
parameters: options.as_slice()
}]));
try!(conn.handle_auth(user));
try!(conn.handle_auth(user, password));
loop {
match try_pg_conn!(conn.read_message()) {
@ -542,12 +642,12 @@ impl InnerPostgresConnection {
}
}
fn handle_auth(&mut self, user: UserInfo)
fn handle_auth(&mut self, user: ~str, pass: Option<~str>)
-> Result<(), PostgresConnectError> {
match try_pg_conn!(self.read_message()) {
AuthenticationOk => return Ok(()),
AuthenticationCleartextPassword => {
let pass = match user.pass {
let pass = match pass {
Some(pass) => pass,
None => return Err(MissingPassword)
};
@ -556,7 +656,6 @@ impl InnerPostgresConnection {
}]));
}
AuthenticationMD5Password { salt } => {
let UserInfo { user, pass } = user;
let pass = match pass {
Some(pass) => pass,
None => return Err(MissingPassword)
@ -734,67 +833,44 @@ pub struct PostgresConnection {
impl PostgresConnection {
/// Creates a new connection to a Postgres database.
///
/// The URL should be provided in the normal format:
/// Most applications can use a URL string in the normal format:
///
/// ```notrust
/// postgres://user[:password]@host[:port][/database][?param1=val1[[&param2=val2]...]]
/// postgresql://user[:password]@host[:port][/database][?param1=val1[[&param2=val2]...]]
/// ```
///
/// The password may be omitted if not required. The default Postgres port
/// (5432) is used if none is specified. The database name defaults to the
/// username if not specified.
///
/// # Example
/// To connect to the server via Unix sockets, `host` should be set to the
/// path to the directory containing the socket file. Since `/` is a
/// reserved character in URLs, the path should be URL encoded. If the
/// path contains non-UTF 8 characters, a `PostgresConnectParams` struct
/// should be created manually and passed in.
///
/// # Examples
///
/// ```rust,no_run
/// # use postgres::{PostgresConnection, NoSsl};
/// let url = "postgres://postgres:hunter2@localhost:2994/foodb";
/// let url = "postgresql://postgres:hunter2@localhost:2994/foodb";
/// let maybe_conn = PostgresConnection::connect(url, &NoSsl);
/// let conn = match maybe_conn {
/// Ok(conn) => conn,
/// Err(err) => fail!("Error connecting: {}", err)
/// };
/// ```
pub fn connect(url: &str, ssl: &SslMode) -> Result<PostgresConnection,
PostgresConnectError> {
InnerPostgresConnection::connect(url, ssl).map(|conn| {
PostgresConnection {
conn: RefCell::new(conn)
}
})
}
/// Creates a new connection to a Postgres database.
///
/// The path should be the directory containing the socket.
///
/// The password in the UserInfo may be omitted if not required.
///
/// # Example
///
/// ```rust,no_run
/// # extern crate url;
/// # extern crate postgres;
/// # use url::UserInfo;
/// # use postgres::PostgresConnection;
/// # fn main() {
/// let path = Path::new("/tmp");
/// let port = 5432;
/// let user = UserInfo::new(~"username", None);
/// let database = ~"postgres";
/// let maybe_conn = PostgresConnection::connect_unix(&path, 5432, user, database);
/// let conn = match maybe_conn {
/// Ok(conn) => conn,
/// Err(err) => fail!("Error connecting: {}", err)
/// };
/// # }
/// # use postgres::{PostgresConnection, NoSsl};
/// let url = "postgresql://postgres@%2Ftmp";
/// let maybe_conn = PostgresConnection::connect(url, &NoSsl);
/// ```
pub fn connect_unix(path: &Path, port: Port, user: UserInfo, database: ~str)
-> Result<PostgresConnection, PostgresConnectError> {
InnerPostgresConnection::connect_unix(path, port, user, database).map(|conn| {
PostgresConnection {
conn: RefCell::new(conn)
}
pub fn connect<T: IntoConnectParams>(params: T, ssl: &SslMode)
-> Result<PostgresConnection,
PostgresConnectError> {
InnerPostgresConnection::connect(params, ssl).map(|conn| {
PostgresConnection { conn: RefCell::new(conn) }
})
}

View File

@ -4,6 +4,8 @@ use std::cast;
use sync::{Arc, Mutex};
use {PostgresNotifications,
PostgresConnectParams,
IntoConnectParams,
PostgresResult,
PostgresCancelData,
PostgresConnection,
@ -14,7 +16,7 @@ use error::PostgresConnectError;
use types::ToSql;
struct InnerConnectionPool {
url: ~str,
params: PostgresConnectParams,
ssl: SslMode,
// Actually Vec<~PostgresConnection>
pool: Vec<*()>,
@ -36,7 +38,7 @@ impl Drop for InnerConnectionPool {
impl InnerConnectionPool {
fn add_connection(&mut self) -> Result<(), PostgresConnectError> {
PostgresConnection::connect(self.url, &self.ssl)
PostgresConnection::connect(self.params.clone(), &self.ssl)
.map(|c| unsafe { self.pool.push(cast::transmute(~c)); })
}
}
@ -50,7 +52,7 @@ impl InnerConnectionPool {
/// ```rust,no_run
/// # use postgres::NoSsl;
/// # use postgres::pool::PostgresConnectionPool;
/// let pool = PostgresConnectionPool::new(~"postgres://postgres@localhost",
/// let pool = PostgresConnectionPool::new("postgres://postgres@localhost",
/// NoSsl, 5).unwrap();
/// for _ in range(0, 10) {
/// let pool = pool.clone();
@ -70,10 +72,10 @@ impl PostgresConnectionPool {
///
/// Returns an error if the specified number of connections cannot be
/// created.
pub fn new(url: ~str, ssl: SslMode, pool_size: uint)
pub fn new<T: IntoConnectParams>(params: T, ssl: SslMode, pool_size: uint)
-> Result<PostgresConnectionPool, PostgresConnectError> {
let mut pool = InnerConnectionPool {
url: url,
params: try!(params.into_connect_params()),
ssl: ssl,
pool: Vec::new(),
};

View File

@ -9,7 +9,7 @@ use openssl::ssl::{SslContext, Sslv3};
use std::f32;
use std::f64;
use std::io::timer;
use url::UserInfo;
use url;
use {PostgresNoticeHandler,
PostgresNotification,
@ -50,7 +50,7 @@ macro_rules! or_fail(
#[test]
// Make sure we can take both connections at once and can still get one after
fn test_pool() {
let pool = or_fail!(PostgresConnectionPool::new(~"postgres://postgres@localhost",
let pool = or_fail!(PostgresConnectionPool::new("postgres://postgres@localhost",
NoSsl, 2));
let (stream1, stream2) = sync::duplex();
@ -121,7 +121,8 @@ fn test_unix_connection() {
let unix_socket_directory = unix_socket_directories.splitn(',', 1).next().unwrap();
let conn = or_fail!(PostgresConnection::connect_unix(&Path::new(unix_socket_directory), 5432, UserInfo::new(~"postgres", None), ~"postgres"));
let url = format!("postgres://postgres@{}", url::encode_component(unix_socket_directory));
let conn = or_fail!(PostgresConnection::connect(url.as_slice(), &NoSsl));
assert!(conn.finish().is_ok());
}