From 707e7ccfa43736f450f4abfd993224c6d5d54216 Mon Sep 17 00:00:00 2001 From: Matthijs van der Vleuten Date: Fri, 18 Apr 2014 23:29:51 +0200 Subject: [PATCH 1/2] Add support for connecting through Unix sockets Includes connection test (assumes socket is in /tmp, the default location). --- src/lib.rs | 97 ++++++++++++++++++++++++++++++++++++++++++++++++++--- src/test.rs | 7 ++++ 2 files changed, 99 insertions(+), 5 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index db2d360b..44c65268 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -88,6 +88,7 @@ use std::io::{BufferedStream, IoResult}; use std::io::net; use std::io::net::ip::{Port, SocketAddr}; use std::io::net::tcp::TcpStream; +use std::io::net::unix::UnixStream; use std::mem; use std::str; use std::task; @@ -327,6 +328,13 @@ fn open_socket(host: &str, port: Port) Err(SocketError(err.unwrap())) } +fn open_unix(path: &Path) -> Result { + match UnixStream::connect(path) { + Ok(unix) => Ok(unix), + Err(err) => Err(SocketError(err)) + } +} + fn initialize_stream(host: &str, port: Port, ssl: &SslMode) -> Result { let mut socket = match open_socket(host, port) { @@ -357,16 +365,26 @@ fn initialize_stream(host: &str, port: Port, ssl: &SslMode) } } +fn initialize_unix(path: &Path) + -> Result { + match open_unix(path) { + Ok(unix) => Ok(NormalUnix(unix)), + Err(err) => Err(err) + } +} + enum InternalStream { NormalStream(TcpStream), - SslStream(SslStream) + SslStream(SslStream), + NormalUnix(UnixStream) } impl Reader for InternalStream { fn read(&mut self, buf: &mut [u8]) -> IoResult { match *self { NormalStream(ref mut s) => s.read(buf), - SslStream(ref mut s) => s.read(buf) + SslStream(ref mut s) => s.read(buf), + NormalUnix(ref mut s) => s.read(buf) } } } @@ -375,14 +393,16 @@ impl Writer for InternalStream { fn write(&mut self, buf: &[u8]) -> IoResult<()> { match *self { NormalStream(ref mut s) => s.write(buf), - SslStream(ref mut s) => s.write(buf) + SslStream(ref mut s) => s.write(buf), + NormalUnix(ref mut s) => s.write(buf) } } fn flush(&mut self) -> IoResult<()> { match *self { NormalStream(ref mut s) => s.flush(), - SslStream(ref mut s) => s.flush() + SslStream(ref mut s) => s.flush(), + NormalUnix(ref mut s) => s.flush() } } } @@ -434,7 +454,7 @@ impl InnerPostgresConnection { let stream = try!(initialize_stream(host, port, ssl)); - let mut conn = InnerPostgresConnection { + let conn = InnerPostgresConnection { stream: BufferedStream::new(stream), next_stmt_id: 0, notice_handler: ~DefaultNoticeHandler, @@ -457,6 +477,44 @@ impl InnerPostgresConnection { let (_, path) = path.slice_shift_char(); args.push((~"database", path.to_owned())); } + + InnerPostgresConnection::connect_finish(conn, args, user) + } + + fn connect_unix(socket_dir: &Path, port: Port, user: UserInfo, database: ~str) + -> Result { + let mut socket = socket_dir.clone(); + socket.push(format!(".s.PGSQL.{}", port)); + + let stream = try!(initialize_unix(&socket)); + + let conn = InnerPostgresConnection { + stream: BufferedStream::new(stream), + next_stmt_id: 0, + notice_handler: ~DefaultNoticeHandler, + notifications: RingBuf::new(), + cancel_data: PostgresCancelData { process_id: 0, secret_key: 0 }, + unknown_types: HashMap::new(), + desynchronized: false, + finished: false, + canary: CANARY, + }; + + let mut args = Vec::new(); + + args.push((~"client_encoding", ~"UTF8")); + // Postgres uses the value of TimeZone as the time zone for TIMESTAMP + // WITH TIME ZONE values. Timespec converts to GMT internally. + args.push((~"TimeZone", ~"GMT")); + // We have to clone here since we need the user again for auth + args.push((~"user", user.user.clone())); + args.push((~"database", database)); + + InnerPostgresConnection::connect_finish(conn, args, user) + } + + fn connect_finish(mut conn: InnerPostgresConnection, args: Vec<(~str, ~str)>, user: UserInfo) + -> Result { try_pg_conn!(conn.write_messages([StartupMessage { version: message::PROTOCOL_VERSION, parameters: args.as_slice() @@ -730,6 +788,35 @@ impl PostgresConnection { }) } + /// Creates a new connection to a Postgres database. + /// + /// The path should be the directory containing the socket. + /// + /// The password in the UserInfo may be omitted if not required. + /// + /// # Example + /// + /// ```rust,no_run + /// # use postgres::PostgresConnection; + /// let path = Path::new("/tmp"); + /// let port = 5432; + /// let user = UserInfo::new(~"username", None); + /// let database = ~"postgres"; + /// let maybe_conn = PostgresConnection::connect_unix(&path, 5432, user, database); + /// let conn = match maybe_conn { + /// Ok(conn) => conn, + /// Err(err) => fail!("Error connecting: {}", err) + /// }; + /// ``` + pub fn connect_unix(path: &Path, port: Port, user: UserInfo, database: ~str) + -> Result { + InnerPostgresConnection::connect_unix(path, port, user, database).map(|conn| { + PostgresConnection { + conn: RefCell::new(conn) + } + }) + } + /// Sets the notice handler for the connection, returning the old handler. pub fn set_notice_handler(&self, handler: ~PostgresNoticeHandler:Send) -> ~PostgresNoticeHandler:Send { diff --git a/src/test.rs b/src/test.rs index dd9bb6b4..ba77c042 100644 --- a/src/test.rs +++ b/src/test.rs @@ -9,6 +9,7 @@ use openssl::ssl::{SslContext, Sslv3}; use std::f32; use std::f64; use std::io::timer; +use url::UserInfo; use {PostgresNoticeHandler, PostgresNotification, @@ -107,6 +108,12 @@ fn test_connection_finish() { assert!(conn.finish().is_ok()); } +#[test] +fn test_unix_connection() { + let conn = or_fail!(PostgresConnection::connect_unix(&Path::new("/tmp"), 5432, UserInfo::new(~"postgres", None), ~"postgres")); + assert!(conn.finish().is_ok()); +} + #[test] fn test_transaction_commit() { let conn = or_fail!(PostgresConnection::connect("postgres://postgres@localhost", &NoSsl)); From 5e85d6b9bdde45bf3b69153c1988f94808435d32 Mon Sep 17 00:00:00 2001 From: Matthijs van der Vleuten Date: Sat, 19 Apr 2014 11:04:37 +0200 Subject: [PATCH 2/2] test_unix_connection now detects the socket directory. Change pg_hba.conf to allow connections through the socket. Ignore connect_unix doc test. It requires `extern crate url;` which is not allowed with rustdoc. Also, per comments on PR #35: - Inline open_unix - Centralize common code from connect and connect_unix in connect_finish. --- src/lib.rs | 54 ++++++++++++++-------------------------------- src/test.rs | 13 ++++++++++- travis/pg_hba.conf | 2 ++ 3 files changed, 30 insertions(+), 39 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 44c65268..029052c3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -328,13 +328,6 @@ fn open_socket(host: &str, port: Port) Err(SocketError(err.unwrap())) } -fn open_unix(path: &Path) -> Result { - match UnixStream::connect(path) { - Ok(unix) => Ok(unix), - Err(err) => Err(SocketError(err)) - } -} - fn initialize_stream(host: &str, port: Port, ssl: &SslMode) -> Result { let mut socket = match open_socket(host, port) { @@ -367,9 +360,9 @@ fn initialize_stream(host: &str, port: Port, ssl: &SslMode) fn initialize_unix(path: &Path) -> Result { - match open_unix(path) { + match UnixStream::connect(path) { Ok(unix) => Ok(NormalUnix(unix)), - Err(err) => Err(err) + Err(err) => Err(SocketError(err)) } } @@ -454,31 +447,13 @@ impl InnerPostgresConnection { let stream = try!(initialize_stream(host, port, ssl)); - let conn = InnerPostgresConnection { - stream: BufferedStream::new(stream), - next_stmt_id: 0, - notice_handler: ~DefaultNoticeHandler, - notifications: RingBuf::new(), - cancel_data: PostgresCancelData { process_id: 0, secret_key: 0 }, - unknown_types: HashMap::new(), - desynchronized: false, - finished: false, - canary: CANARY, - }; - - args.push((~"client_encoding", ~"UTF8")); - // Postgres uses the value of TimeZone as the time zone for TIMESTAMP - // WITH TIME ZONE values. Timespec converts to GMT internally. - args.push((~"TimeZone", ~"GMT")); - // We have to clone here since we need the user again for auth - args.push((~"user", user.user.clone())); if !path.is_empty() { // path contains the leading / let (_, path) = path.slice_shift_char(); args.push((~"database", path.to_owned())); } - InnerPostgresConnection::connect_finish(conn, args, user) + InnerPostgresConnection::connect_finish(stream, args, user) } fn connect_unix(socket_dir: &Path, port: Port, user: UserInfo, database: ~str) @@ -488,7 +463,16 @@ impl InnerPostgresConnection { let stream = try!(initialize_unix(&socket)); - let conn = InnerPostgresConnection { + let mut args = Vec::new(); + args.push((~"database", database)); + + InnerPostgresConnection::connect_finish(stream, args, user) + } + + fn connect_finish(stream: InternalStream, mut args: Vec<(~str, ~str)>, user: UserInfo) + -> Result { + + let mut conn = InnerPostgresConnection { stream: BufferedStream::new(stream), next_stmt_id: 0, notice_handler: ~DefaultNoticeHandler, @@ -500,21 +484,13 @@ impl InnerPostgresConnection { canary: CANARY, }; - let mut args = Vec::new(); - args.push((~"client_encoding", ~"UTF8")); // Postgres uses the value of TimeZone as the time zone for TIMESTAMP // WITH TIME ZONE values. Timespec converts to GMT internally. args.push((~"TimeZone", ~"GMT")); // We have to clone here since we need the user again for auth args.push((~"user", user.user.clone())); - args.push((~"database", database)); - InnerPostgresConnection::connect_finish(conn, args, user) - } - - fn connect_finish(mut conn: InnerPostgresConnection, args: Vec<(~str, ~str)>, user: UserInfo) - -> Result { try_pg_conn!(conn.write_messages([StartupMessage { version: message::PROTOCOL_VERSION, parameters: args.as_slice() @@ -796,7 +772,9 @@ impl PostgresConnection { /// /// # Example /// - /// ```rust,no_run + /// ```rust,ignore + /// # extern crate url; + /// # use url::UserInfo; /// # use postgres::PostgresConnection; /// let path = Path::new("/tmp"); /// let port = 5432; diff --git a/src/test.rs b/src/test.rs index ba77c042..cc61c802 100644 --- a/src/test.rs +++ b/src/test.rs @@ -110,7 +110,18 @@ fn test_connection_finish() { #[test] fn test_unix_connection() { - let conn = or_fail!(PostgresConnection::connect_unix(&Path::new("/tmp"), 5432, UserInfo::new(~"postgres", None), ~"postgres")); + let conn = or_fail!(PostgresConnection::connect("postgres://postgres@localhost", &NoSsl)); + let stmt = or_fail!(conn.prepare("SHOW unix_socket_directories")); + let result = or_fail!(stmt.query([])); + let unix_socket_directories: ~str = result.map(|row| row[1]).next().unwrap(); + + if unix_socket_directories == ~"" { + fail!("can't test connect_unix; unix_socket_directories is empty"); + } + + let unix_socket_directory = unix_socket_directories.splitn(',', 1).next().unwrap(); + + let conn = or_fail!(PostgresConnection::connect_unix(&Path::new(unix_socket_directory), 5432, UserInfo::new(~"postgres", None), ~"postgres")); assert!(conn.finish().is_ok()); } diff --git a/travis/pg_hba.conf b/travis/pg_hba.conf index ee602658..7fd94263 100644 --- a/travis/pg_hba.conf +++ b/travis/pg_hba.conf @@ -8,3 +8,5 @@ host all md5_user ::1/128 md5 host all postgres 127.0.0.1/32 trust # IPv6 local connections: host all postgres ::1/128 trust +# Unix socket connections: +local all postgres trust