Merge pull request #35 from zr40/add-unix-domain-socket

Add support for connecting through Unix sockets
This commit is contained in:
Steven Fackler 2014-04-19 12:45:48 -04:00
commit cc3db19974
3 changed files with 94 additions and 9 deletions

View File

@ -88,6 +88,7 @@ use std::io::{BufferedStream, IoResult};
use std::io::net;
use std::io::net::ip::{Port, SocketAddr};
use std::io::net::tcp::TcpStream;
use std::io::net::unix::UnixStream;
use std::mem;
use std::str;
use std::task;
@ -357,16 +358,26 @@ fn initialize_stream(host: &str, port: Port, ssl: &SslMode)
}
}
fn initialize_unix(path: &Path)
-> Result<InternalStream, PostgresConnectError> {
match UnixStream::connect(path) {
Ok(unix) => Ok(NormalUnix(unix)),
Err(err) => Err(SocketError(err))
}
}
enum InternalStream {
NormalStream(TcpStream),
SslStream(SslStream<TcpStream>)
SslStream(SslStream<TcpStream>),
NormalUnix(UnixStream)
}
impl Reader for InternalStream {
fn read(&mut self, buf: &mut [u8]) -> IoResult<uint> {
match *self {
NormalStream(ref mut s) => s.read(buf),
SslStream(ref mut s) => s.read(buf)
SslStream(ref mut s) => s.read(buf),
NormalUnix(ref mut s) => s.read(buf)
}
}
}
@ -375,14 +386,16 @@ impl Writer for InternalStream {
fn write(&mut self, buf: &[u8]) -> IoResult<()> {
match *self {
NormalStream(ref mut s) => s.write(buf),
SslStream(ref mut s) => s.write(buf)
SslStream(ref mut s) => s.write(buf),
NormalUnix(ref mut s) => s.write(buf)
}
}
fn flush(&mut self) -> IoResult<()> {
match *self {
NormalStream(ref mut s) => s.flush(),
SslStream(ref mut s) => s.flush()
SslStream(ref mut s) => s.flush(),
NormalUnix(ref mut s) => s.flush()
}
}
}
@ -434,6 +447,31 @@ impl InnerPostgresConnection {
let stream = try!(initialize_stream(host, port, ssl));
if !path.is_empty() {
// path contains the leading /
let (_, path) = path.slice_shift_char();
args.push((~"database", path.to_owned()));
}
InnerPostgresConnection::connect_finish(stream, args, user)
}
fn connect_unix(socket_dir: &Path, port: Port, user: UserInfo, database: ~str)
-> Result<InnerPostgresConnection, PostgresConnectError> {
let mut socket = socket_dir.clone();
socket.push(format!(".s.PGSQL.{}", port));
let stream = try!(initialize_unix(&socket));
let mut args = Vec::new();
args.push((~"database", database));
InnerPostgresConnection::connect_finish(stream, args, user)
}
fn connect_finish(stream: InternalStream, mut args: Vec<(~str, ~str)>, user: UserInfo)
-> Result<InnerPostgresConnection, PostgresConnectError> {
let mut conn = InnerPostgresConnection {
stream: BufferedStream::new(stream),
next_stmt_id: 0,
@ -452,11 +490,7 @@ impl InnerPostgresConnection {
args.push((~"TimeZone", ~"GMT"));
// We have to clone here since we need the user again for auth
args.push((~"user", user.user.clone()));
if !path.is_empty() {
// path contains the leading /
let (_, path) = path.slice_shift_char();
args.push((~"database", path.to_owned()));
}
try_pg_conn!(conn.write_messages([StartupMessage {
version: message::PROTOCOL_VERSION,
parameters: args.as_slice()
@ -730,6 +764,37 @@ impl PostgresConnection {
})
}
/// Creates a new connection to a Postgres database.
///
/// The path should be the directory containing the socket.
///
/// The password in the UserInfo may be omitted if not required.
///
/// # Example
///
/// ```rust,ignore
/// # extern crate url;
/// # use url::UserInfo;
/// # use postgres::PostgresConnection;
/// let path = Path::new("/tmp");
/// let port = 5432;
/// let user = UserInfo::new(~"username", None);
/// let database = ~"postgres";
/// let maybe_conn = PostgresConnection::connect_unix(&path, 5432, user, database);
/// let conn = match maybe_conn {
/// Ok(conn) => conn,
/// Err(err) => fail!("Error connecting: {}", err)
/// };
/// ```
pub fn connect_unix(path: &Path, port: Port, user: UserInfo, database: ~str)
-> Result<PostgresConnection, PostgresConnectError> {
InnerPostgresConnection::connect_unix(path, port, user, database).map(|conn| {
PostgresConnection {
conn: RefCell::new(conn)
}
})
}
/// Sets the notice handler for the connection, returning the old handler.
pub fn set_notice_handler(&self, handler: ~PostgresNoticeHandler:Send)
-> ~PostgresNoticeHandler:Send {

View File

@ -9,6 +9,7 @@ use openssl::ssl::{SslContext, Sslv3};
use std::f32;
use std::f64;
use std::io::timer;
use url::UserInfo;
use {PostgresNoticeHandler,
PostgresNotification,
@ -107,6 +108,23 @@ fn test_connection_finish() {
assert!(conn.finish().is_ok());
}
#[test]
fn test_unix_connection() {
let conn = or_fail!(PostgresConnection::connect("postgres://postgres@localhost", &NoSsl));
let stmt = or_fail!(conn.prepare("SHOW unix_socket_directories"));
let result = or_fail!(stmt.query([]));
let unix_socket_directories: ~str = result.map(|row| row[1]).next().unwrap();
if unix_socket_directories == ~"" {
fail!("can't test connect_unix; unix_socket_directories is empty");
}
let unix_socket_directory = unix_socket_directories.splitn(',', 1).next().unwrap();
let conn = or_fail!(PostgresConnection::connect_unix(&Path::new(unix_socket_directory), 5432, UserInfo::new(~"postgres", None), ~"postgres"));
assert!(conn.finish().is_ok());
}
#[test]
fn test_transaction_commit() {
let conn = or_fail!(PostgresConnection::connect("postgres://postgres@localhost", &NoSsl));

View File

@ -8,3 +8,5 @@ host all md5_user ::1/128 md5
host all postgres 127.0.0.1/32 trust
# IPv6 local connections:
host all postgres ::1/128 trust
# Unix socket connections:
local all postgres trust