From be499822860e2aa7d50590be99cf7aafd9dd5998 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 26 Apr 2015 23:21:46 -0700 Subject: [PATCH] Make SSL infrastructure implementation agnostic --- src/error.rs | 13 +---- src/io_util.rs | 81 ++++++++++++++++++++++---- src/lib.rs | 30 ++++------ tests/test.rs | 138 ++++++++++++++++++++++----------------------- tests/types/mod.rs | 16 +++--- 5 files changed, 161 insertions(+), 117 deletions(-) diff --git a/src/error.rs b/src/error.rs index 7b491d09..62a5b778 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,7 +1,6 @@ pub use ugh_privacy::DbError; use byteorder; -use openssl::ssl::error::SslError; use phf; use std::error; use std::convert::From; @@ -29,8 +28,8 @@ pub enum ConnectError { UnsupportedAuthentication, /// The Postgres server does not support SSL encryption. NoSslSupport, - /// There was an error initializing the SSL session. - SslError(SslError), + /// There was an error initializing the SSL session + SslError(Box), /// There was an error communicating with the server. IoError(io::Error), /// The server sent an unexpected response. @@ -67,7 +66,7 @@ impl error::Error for ConnectError { fn cause(&self) -> Option<&error::Error> { match *self { ConnectError::DbError(ref err) => Some(err), - ConnectError::SslError(ref err) => Some(err), + ConnectError::SslError(ref err) => Some(&**err), ConnectError::IoError(ref err) => Some(err), _ => None } @@ -86,12 +85,6 @@ impl From for ConnectError { } } -impl From for ConnectError { - fn from(err: SslError) -> ConnectError { - ConnectError::SslError(err) - } -} - impl From for ConnectError { fn from(err: byteorder::Error) -> ConnectError { ConnectError::IoError(From::from(err)) diff --git a/src/io_util.rs b/src/io_util.rs index 2661d09d..15ccc525 100644 --- a/src/io_util.rs +++ b/src/io_util.rs @@ -1,4 +1,5 @@ -use openssl::ssl::{SslStream, MaybeSslStream}; +use openssl::ssl::{SslStream, SslContext}; +use std::error::Error; use std::io; use std::io::prelude::*; use std::net::TcpStream; @@ -6,19 +7,76 @@ use std::net::TcpStream; use unix_socket::UnixStream; use byteorder::ReadBytesExt; -use {ConnectParams, SslMode, ConnectTarget, ConnectError}; +use {ConnectParams, ConnectTarget, ConnectError}; use message; use message::WriteMessage; use message::FrontendMessage::SslRequest; const DEFAULT_PORT: u16 = 5432; +pub trait StreamWrapper: Read+Write+Send { + fn get_ref(&self) -> &S; + fn get_mut(&mut self) -> &mut S; +} + +impl StreamWrapper for SslStream { + fn get_ref(&self) -> &S { + self.get_ref() + } + + fn get_mut(&mut self) -> &mut S { + self.get_mut() + } +} + +pub trait NegotiateSsl { + fn negotiate_ssl(&mut self, stream: S) -> Result>, Box> + where S: Read+Write+Send+'static; +} + +impl NegotiateSsl for SslContext { + fn negotiate_ssl(&mut self, stream: S) -> Result>, Box> + where S: Read+Write+Send+'static { + let stream = try!(SslStream::new(self, stream)); + Ok(Box::new(stream)) + } +} + +/// Specifies the SSL support requested for a new connection. +pub enum SslMode { + /// The connection will not use SSL. + None, + /// The connection will use SSL if the backend supports it. + Prefer(N), + /// The connection must use SSL. + Require(N), +} + +pub enum NoSsl {} + +impl NegotiateSsl for NoSsl { + fn negotiate_ssl(&mut self, stream: S) + -> Result>, Box> { + match *self {} + } +} + pub enum InternalStream { Tcp(TcpStream), #[cfg(feature = "unix_socket")] Unix(UnixStream), } +impl StreamWrapper for InternalStream { + fn get_ref(&self) -> &InternalStream { + self + } + + fn get_mut(&mut self) -> &mut InternalStream { + self + } +} + impl Read for InternalStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { match *self { @@ -62,14 +120,15 @@ fn open_socket(params: &ConnectParams) -> Result { } } -pub fn initialize_stream(params: &ConnectParams, ssl: &SslMode) - -> Result, ConnectError> { +pub fn initialize_stream(params: &ConnectParams, ssl: &mut SslMode) + -> Result>, ConnectError> + where N: NegotiateSsl { let mut socket = try!(open_socket(params)); - let (ssl_required, ctx) = match *ssl { - SslMode::None => return Ok(MaybeSslStream::Normal(socket)), - SslMode::Prefer(ref ctx) => (false, ctx), - SslMode::Require(ref ctx) => (true, ctx) + let (ssl_required, negotiator) = match *ssl { + SslMode::None => return Ok(Box::new(socket)), + SslMode::Prefer(ref mut negotiator) => (false, negotiator), + SslMode::Require(ref mut negotiator) => (true, negotiator), }; try!(socket.write_message(&SslRequest { code: message::SSL_CODE })); @@ -79,12 +138,12 @@ pub fn initialize_stream(params: &ConnectParams, ssl: &SslMode) if ssl_required { return Err(ConnectError::NoSslSupport); } else { - return Ok(MaybeSslStream::Normal(socket)); + return Ok(Box::new(socket)); } } - match SslStream::new(ctx, socket) { - Ok(stream) => Ok(MaybeSslStream::Ssl(stream)), + match negotiator.negotiate_ssl(socket) { + Ok(stream) => Ok(stream), Err(err) => Err(ConnectError::SslError(err)) } } diff --git a/src/lib.rs b/src/lib.rs index abf83fe8..e9a1fd8c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -59,7 +59,6 @@ extern crate debug_builders; use bufstream::BufStream; use debug_builders::DebugStruct; use openssl::crypto::hash::{self, Hasher}; -use openssl::ssl::{SslContext, MaybeSslStream}; use serialize::hex::ToHex; use std::ascii::AsciiExt; use std::borrow::{ToOwned, Cow}; @@ -80,6 +79,7 @@ use std::path::PathBuf; pub use error::{Error, ConnectError, SqlState, DbError, ErrorPosition}; #[doc(inline)] pub use types::{Oid, Type, Kind, ToSql, FromSql}; +pub use io_util::{SslMode, NegotiateSsl, StreamWrapper, NoSsl}; use types::IsNull; #[doc(inline)] pub use types::Slice; @@ -387,8 +387,9 @@ pub struct CancelData { /// # let _ = /// postgres::cancel_query(url, &SslMode::None, cancel_data); /// ``` -pub fn cancel_query(params: T, ssl: &SslMode, data: CancelData) - -> result::Result<(), ConnectError> where T: IntoConnectParams { +pub fn cancel_query(params: T, ssl: &mut SslMode, data: CancelData) + -> result::Result<(), ConnectError> + where T: IntoConnectParams, N: NegotiateSsl { let params = try!(params.into_connect_params()); let mut socket = try!(io_util::initialize_stream(¶ms, ssl)); @@ -464,7 +465,7 @@ struct CachedStatement { } struct InnerConnection { - stream: BufStream>, + stream: BufStream>>, notice_handler: Box, notifications: VecDeque, cancel_data: CancelData, @@ -486,8 +487,9 @@ impl Drop for InnerConnection { } impl InnerConnection { - fn connect(params: T, ssl: &SslMode) -> result::Result - where T: IntoConnectParams { + fn connect(params: T, ssl: &mut SslMode) + -> result::Result + where T: IntoConnectParams, N: NegotiateSsl { let params = try!(params.into_connect_params()); let stream = try!(io_util::initialize_stream(¶ms, ssl)); @@ -1005,8 +1007,9 @@ impl Connection { /// let conn = try!(Connection::connect(params, &SslMode::None)); /// # Ok(()) }; /// ``` - pub fn connect(params: T, ssl: &SslMode) -> result::Result - where T: IntoConnectParams { + pub fn connect(params: T, ssl: &mut SslMode) + -> result::Result + where T: IntoConnectParams, N: NegotiateSsl { InnerConnection::connect(params, ssl).map(|conn| { Connection { conn: RefCell::new(conn) } }) @@ -1244,17 +1247,6 @@ impl Connection { } } -/// Specifies the SSL support requested for a new connection. -#[derive(Debug)] -pub enum SslMode { - /// The connection will not use SSL. - None, - /// The connection will use SSL if the backend supports it. - Prefer(SslContext), - /// The connection must use SSL. - Require(SslContext) -} - /// Represents a transaction on a database connection. /// /// The transaction will roll back by default. diff --git a/tests/test.rs b/tests/test.rs index df596854..c15257f4 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -41,17 +41,17 @@ mod types; #[test] fn test_non_default_database() { - or_panic!(Connection::connect("postgres://postgres@localhost/postgres", &SslMode::None)); + or_panic!(Connection::connect("postgres://postgres@localhost/postgres", &mut SslMode::None)); } #[test] fn test_url_terminating_slash() { - or_panic!(Connection::connect("postgres://postgres@localhost/", &SslMode::None)); + or_panic!(Connection::connect("postgres://postgres@localhost/", &mut SslMode::None)); } #[test] fn test_prepare_err() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let stmt = conn.prepare("invalid sql database"); match stmt { Err(Error::DbError(ref e)) if e.code() == &SyntaxError && e.position() == Some(&Normal(1)) => {} @@ -62,7 +62,7 @@ fn test_prepare_err() { #[test] fn test_unknown_database() { - match Connection::connect("postgres://postgres@localhost/asdf", &SslMode::None) { + match Connection::connect("postgres://postgres@localhost/asdf", &mut SslMode::None) { Err(ConnectError::DbError(ref e)) if e.code() == &InvalidCatalogName => {} Err(resp) => panic!("Unexpected result {:?}", resp), _ => panic!("Unexpected result"), @@ -71,14 +71,14 @@ fn test_unknown_database() { #[test] fn test_connection_finish() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); assert!(conn.finish().is_ok()); } #[test] #[cfg_attr(not(feature = "unix_socket"), ignore)] fn test_unix_connection() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let stmt = or_panic!(conn.prepare("SHOW unix_socket_directories")); let result = or_panic!(stmt.query(&[])); let unix_socket_directories: String = result.iter().map(|row| row.get(0)).next().unwrap(); @@ -92,13 +92,13 @@ fn test_unix_connection() { let path = url::percent_encoding::utf8_percent_encode( unix_socket_directory, url::percent_encoding::USERNAME_ENCODE_SET); let url = format!("postgres://postgres@{}", path); - let conn = or_panic!(Connection::connect(&url[..], &SslMode::None)); + let conn = or_panic!(Connection::connect(&url[..], &mut SslMode::None)); assert!(conn.finish().is_ok()); } #[test] fn test_transaction_commit() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)", &[])); let trans = or_panic!(conn.transaction()); @@ -114,7 +114,7 @@ fn test_transaction_commit() { #[test] fn test_transaction_commit_finish() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)", &[])); let trans = or_panic!(conn.transaction()); @@ -130,7 +130,7 @@ fn test_transaction_commit_finish() { #[test] fn test_transaction_commit_method() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)", &[])); let trans = or_panic!(conn.transaction()); @@ -145,7 +145,7 @@ fn test_transaction_commit_method() { #[test] fn test_transaction_rollback() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)", &[])); or_panic!(conn.execute("INSERT INTO foo (id) VALUES ($1)", &[&1i32])); @@ -162,7 +162,7 @@ fn test_transaction_rollback() { #[test] fn test_transaction_rollback_finish() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)", &[])); or_panic!(conn.execute("INSERT INTO foo (id) VALUES ($1)", &[&1i32])); @@ -179,7 +179,7 @@ fn test_transaction_rollback_finish() { #[test] fn test_nested_transactions() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)", &[])); or_panic!(conn.execute("INSERT INTO foo (id) VALUES (1)", &[])); @@ -225,7 +225,7 @@ fn test_nested_transactions() { #[test] fn test_nested_transactions_finish() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)", &[])); or_panic!(conn.execute("INSERT INTO foo (id) VALUES (1)", &[])); @@ -281,7 +281,7 @@ fn test_nested_transactions_finish() { #[test] #[should_panic(expected = "active transaction")] fn test_conn_trans_when_nested() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let _trans = or_panic!(conn.transaction()); conn.transaction().unwrap(); } @@ -289,7 +289,7 @@ fn test_conn_trans_when_nested() { #[test] #[should_panic(expected = "active transaction")] fn test_trans_with_nested_trans() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let trans = or_panic!(conn.transaction()); let _trans2 = or_panic!(trans.transaction()); trans.transaction().unwrap(); @@ -297,7 +297,7 @@ fn test_trans_with_nested_trans() { #[test] fn test_stmt_execute_after_transaction() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let trans = or_panic!(conn.transaction()); let stmt = or_panic!(trans.prepare("SELECT 1")); or_panic!(trans.finish()); @@ -307,7 +307,7 @@ fn test_stmt_execute_after_transaction() { #[test] fn test_stmt_finish() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id BIGINT PRIMARY KEY)", &[])); let stmt = or_panic!(conn.prepare("SELECT * FROM foo")); assert!(stmt.finish().is_ok()); @@ -315,7 +315,7 @@ fn test_stmt_finish() { #[test] fn test_batch_execute() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let query = "CREATE TEMPORARY TABLE foo (id BIGINT PRIMARY KEY); INSERT INTO foo (id) VALUES (10);"; or_panic!(conn.batch_execute(query)); @@ -328,7 +328,7 @@ fn test_batch_execute() { #[test] fn test_batch_execute_error() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let query = "CREATE TEMPORARY TABLE foo (id BIGINT PRIMARY KEY); INSERT INTO foo (id) VALUES (10); asdfa; @@ -345,7 +345,7 @@ fn test_batch_execute_error() { #[test] fn test_transaction_batch_execute() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let trans = or_panic!(conn.transaction()); let query = "CREATE TEMPORARY TABLE foo (id BIGINT PRIMARY KEY); INSERT INTO foo (id) VALUES (10);"; @@ -359,7 +359,7 @@ fn test_transaction_batch_execute() { #[test] fn test_query() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id BIGINT PRIMARY KEY)", &[])); or_panic!(conn.execute("INSERT INTO foo (id) VALUES ($1), ($2)", &[&1i64, &2i64])); @@ -371,7 +371,7 @@ fn test_query() { #[test] fn test_error_after_datarow() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let stmt = or_panic!(conn.prepare(" SELECT (SELECT generate_series(1, ss.i)) @@ -388,7 +388,7 @@ FROM (SELECT gs.i #[test] fn test_lazy_query() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let trans = or_panic!(conn.transaction()); or_panic!(trans.execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)", &[])); @@ -405,8 +405,8 @@ fn test_lazy_query() { #[test] #[should_panic(expected = "same `Connection` as")] fn test_lazy_query_wrong_conn() { - let conn1 = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); - let conn2 = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn1 = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); + let conn2 = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let trans = or_panic!(conn1.transaction()); let stmt = or_panic!(conn2.prepare("SELECT 1::INT")); @@ -415,14 +415,14 @@ fn test_lazy_query_wrong_conn() { #[test] fn test_param_types() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let stmt = or_panic!(conn.prepare("SELECT $1::INT, $2::VARCHAR")); assert_eq!(stmt.param_types(), &[Type::Int4, Type::Varchar][..]); } #[test] fn test_columns() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let stmt = or_panic!(conn.prepare("SELECT 1::INT as a, 'hi'::VARCHAR as b")); let cols = stmt.columns(); assert_eq!(2, cols.len()); @@ -434,7 +434,7 @@ fn test_columns() { #[test] fn test_execute_counts() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); assert_eq!(0, or_panic!(conn.execute("CREATE TEMPORARY TABLE foo ( id SERIAL PRIMARY KEY, b INT @@ -447,7 +447,7 @@ fn test_execute_counts() { #[test] fn test_wrong_param_type() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); match conn.execute("SELECT $1::VARCHAR", &[&1i32]) { Err(Error::WrongType(_)) => {} res => panic!("unexpected result {:?}", res) @@ -457,20 +457,20 @@ fn test_wrong_param_type() { #[test] #[should_panic(expected = "expected 2 parameters but got 1")] fn test_too_few_params() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let _ = conn.execute("SELECT $1::INT, $2::INT", &[&1i32]); } #[test] #[should_panic(expected = "expected 2 parameters but got 3")] fn test_too_many_params() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let _ = conn.execute("SELECT $1::INT, $2::INT", &[&1i32, &2i32, &3i32]); } #[test] fn test_index_named() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let stmt = or_panic!(conn.prepare("SELECT 10::INT as val")); let result = or_panic!(stmt.query(&[])); @@ -480,7 +480,7 @@ fn test_index_named() { #[test] #[should_panic] fn test_index_named_fail() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let stmt = or_panic!(conn.prepare("SELECT 10::INT as id")); let result = or_panic!(stmt.query(&[])); @@ -489,7 +489,7 @@ fn test_index_named_fail() { #[test] fn test_get_named_err() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let stmt = or_panic!(conn.prepare("SELECT 10::INT as id")); let result = or_panic!(stmt.query(&[])); @@ -501,7 +501,7 @@ fn test_get_named_err() { #[test] fn test_get_was_null() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let stmt = or_panic!(conn.prepare("SELECT NULL::INT as id")); let result = or_panic!(stmt.query(&[])); @@ -513,7 +513,7 @@ fn test_get_was_null() { #[test] fn test_get_off_by_one() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let stmt = or_panic!(conn.prepare("SELECT 10::INT as id")); let result = or_panic!(stmt.query(&[])); @@ -536,7 +536,7 @@ fn test_custom_notice_handler() { } let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost?client_min_messages=NOTICE", &SslMode::None)); + "postgres://postgres@localhost?client_min_messages=NOTICE", &mut SslMode::None)); conn.set_notice_handler(Box::new(Handler)); or_panic!(conn.execute("CREATE FUNCTION pg_temp.note() RETURNS INT AS $$ BEGIN @@ -550,7 +550,7 @@ fn test_custom_notice_handler() { #[test] fn test_notification_iterator_none() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); assert!(conn.notifications().next().is_none()); } @@ -561,7 +561,7 @@ fn check_notification(expected: Notification, actual: Notification) { #[test] fn test_notification_iterator_some() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let mut it = conn.notifications(); or_panic!(conn.execute("LISTEN test_notification_iterator_one_channel", &[])); or_panic!(conn.execute("LISTEN test_notification_iterator_one_channel2", &[])); @@ -591,11 +591,11 @@ fn test_notification_iterator_some() { #[test] fn test_notifications_next_block() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("LISTEN test_notifications_next_block", &[])); let _t = thread::spawn(|| { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); thread::sleep_ms(500); or_panic!(conn.execute("NOTIFY test_notifications_next_block, 'foo'", &[])); }); @@ -611,11 +611,11 @@ fn test_notifications_next_block() { /* #[test] fn test_notifications_next_block_for() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("LISTEN test_notifications_next_block_for", &[])); let _t = thread::spawn(|| { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); timer::sleep(Duration::milliseconds(500)); or_panic!(conn.execute("NOTIFY test_notifications_next_block_for, 'foo'", &[])); }); @@ -630,11 +630,11 @@ fn test_notifications_next_block_for() { #[test] fn test_notifications_next_block_for_timeout() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("LISTEN test_notifications_next_block_for_timeout", &[])); let _t = thread::spawn(|| { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); timer::sleep(Duration::seconds(2)); or_panic!(conn.execute("NOTIFY test_notifications_next_block_for_timeout, 'foo'", &[])); }); @@ -653,12 +653,12 @@ fn test_notifications_next_block_for_timeout() { #[test] // This test is pretty sad, but I don't think there's a better way :( fn test_cancel_query() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let cancel_data = conn.cancel_data(); let _t = thread::spawn(move || { thread::sleep_ms(500); - assert!(postgres::cancel_query("postgres://postgres@localhost", &SslMode::None, + assert!(postgres::cancel_query("postgres://postgres@localhost", &mut SslMode::None, cancel_data).is_ok()); }); @@ -673,7 +673,7 @@ fn test_cancel_query() { fn test_require_ssl_conn() { let ctx = SslContext::new(SslMethod::Sslv23).unwrap(); let conn = or_panic!(Connection::connect("postgres://postgres@localhost", - &SslMode::Require(ctx))); + &mut SslMode::Require(ctx))); or_panic!(conn.execute("SELECT 1::VARCHAR", &[])); } @@ -681,18 +681,18 @@ fn test_require_ssl_conn() { fn test_prefer_ssl_conn() { let ctx = SslContext::new(SslMethod::Sslv23).unwrap(); let conn = or_panic!(Connection::connect("postgres://postgres@localhost", - &SslMode::Prefer(ctx))); + &mut SslMode::Prefer(ctx))); or_panic!(conn.execute("SELECT 1::VARCHAR", &[])); } #[test] fn test_plaintext_pass() { - or_panic!(Connection::connect("postgres://pass_user:password@localhost/postgres", &SslMode::None)); + or_panic!(Connection::connect("postgres://pass_user:password@localhost/postgres", &mut SslMode::None)); } #[test] fn test_plaintext_pass_no_pass() { - let ret = Connection::connect("postgres://pass_user@localhost/postgres", &SslMode::None); + let ret = Connection::connect("postgres://pass_user@localhost/postgres", &mut SslMode::None); match ret { Err(ConnectError::MissingPassword) => (), Err(err) => panic!("Unexpected error {:?}", err), @@ -702,7 +702,7 @@ fn test_plaintext_pass_no_pass() { #[test] fn test_plaintext_pass_wrong_pass() { - let ret = Connection::connect("postgres://pass_user:asdf@localhost/postgres", &SslMode::None); + let ret = Connection::connect("postgres://pass_user:asdf@localhost/postgres", &mut SslMode::None); match ret { Err(ConnectError::DbError(ref e)) if e.code() == &InvalidPassword => {} Err(err) => panic!("Unexpected error {:?}", err), @@ -712,12 +712,12 @@ fn test_plaintext_pass_wrong_pass() { #[test] fn test_md5_pass() { - or_panic!(Connection::connect("postgres://md5_user:password@localhost/postgres", &SslMode::None)); + or_panic!(Connection::connect("postgres://md5_user:password@localhost/postgres", &mut SslMode::None)); } #[test] fn test_md5_pass_no_pass() { - let ret = Connection::connect("postgres://md5_user@localhost/postgres", &SslMode::None); + let ret = Connection::connect("postgres://md5_user@localhost/postgres", &mut SslMode::None); match ret { Err(ConnectError::MissingPassword) => (), Err(err) => panic!("Unexpected error {:?}", err), @@ -727,7 +727,7 @@ fn test_md5_pass_no_pass() { #[test] fn test_md5_pass_wrong_pass() { - let ret = Connection::connect("postgres://md5_user:asdf@localhost/postgres", &SslMode::None); + let ret = Connection::connect("postgres://md5_user:asdf@localhost/postgres", &mut SslMode::None); match ret { Err(ConnectError::DbError(ref e)) if e.code() == &InvalidPassword => {} Err(err) => panic!("Unexpected error {:?}", err), @@ -737,7 +737,7 @@ fn test_md5_pass_wrong_pass() { #[test] fn test_execute_copy_from_err() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[])); let stmt = or_panic!(conn.prepare("COPY foo (id) FROM STDIN")); match stmt.execute(&[]) { @@ -754,7 +754,7 @@ fn test_execute_copy_from_err() { #[test] fn test_copy_in() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT, name VARCHAR)", &[])); let stmt = or_panic!(conn.prepare_copy_in("foo", &["id", "name"])); @@ -773,7 +773,7 @@ fn test_copy_in() { #[test] fn test_copy_in_bad_column_count() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT, name VARCHAR)", &[])); let stmt = or_panic!(conn.prepare_copy_in("foo", &["id", "name"])); @@ -810,7 +810,7 @@ fn test_copy_in_bad_column_count() { #[test] fn test_copy_in_bad_type() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT, name VARCHAR)", &[])); let stmt = or_panic!(conn.prepare_copy_in("foo", &["id", "name"])); @@ -842,7 +842,7 @@ fn test_copy_in_weird_names() { #[test] fn test_batch_execute_copy_from_err() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[])); match conn.batch_execute("COPY foo (id) FROM STDIN") { Err(Error::DbError(ref err)) if err.message().contains("COPY") => {} @@ -858,7 +858,7 @@ fn test_generic_connection() { or_panic!(t.execute("SELECT 1", &[])); } - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); f(&conn); let trans = or_panic!(conn.transaction()); f(&trans); @@ -866,7 +866,7 @@ fn test_generic_connection() { #[test] fn test_custom_range_element_type() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let trans = or_panic!(conn.transaction()); or_panic!(trans.execute("CREATE TYPE floatrange AS RANGE ( subtype = float8, @@ -884,7 +884,7 @@ fn test_custom_range_element_type() { #[test] fn test_prepare_cached() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[])); or_panic!(conn.execute("INSERT INTO foo (id) VALUES (1), (2)", &[])); @@ -903,7 +903,7 @@ fn test_prepare_cached() { #[test] fn test_is_active() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); assert!(conn.is_active()); let trans = or_panic!(conn.transaction()); assert!(!conn.is_active()); @@ -923,14 +923,14 @@ fn test_is_active() { #[test] fn test_parameter() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); assert_eq!(Some("UTF8".to_string()), conn.parameter("client_encoding")); assert_eq!(None, conn.parameter("asdf")); } #[test] fn test_get_bytes() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let stmt = or_panic!(conn.prepare("SELECT '\\x00010203'::BYTEA")); let result = or_panic!(stmt.query(&[])); assert_eq!(b"\x00\x01\x02\x03", result.iter().next().unwrap().get_bytes(0).unwrap()); @@ -938,7 +938,7 @@ fn test_get_bytes() { #[test] fn test_get_opt_wrong_type() { - let conn = Connection::connect("postgres://postgres@localhost", &SslMode::None).unwrap(); + let conn = Connection::connect("postgres://postgres@localhost", &mut SslMode::None).unwrap(); let stmt = conn.prepare("SELECT 1::INT").unwrap(); let res = stmt.query(&[]).unwrap(); match res.iter().next().unwrap().get_opt::<_, String>(0) { @@ -957,7 +957,7 @@ fn url_encoded_password() { #[test] fn test_transaction_isolation_level() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); assert_eq!(IsolationLevel::ReadCommitted, or_panic!(conn.transaction_isolation())); or_panic!(conn.set_transaction_isolation(IsolationLevel::ReadUncommitted)); assert_eq!(IsolationLevel::ReadUncommitted, or_panic!(conn.transaction_isolation())); diff --git a/tests/types/mod.rs b/tests/types/mod.rs index d1a979f9..4545a812 100644 --- a/tests/types/mod.rs +++ b/tests/types/mod.rs @@ -16,7 +16,7 @@ mod rustc_serialize; mod serde; fn test_type(sql_type: &str, checks: &[(T, S)]) { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); for &(ref val, ref repr) in checks.iter() { let stmt = or_panic!(conn.prepare(&*format!("SELECT {}::{}", *repr, sql_type))); let result = or_panic!(stmt.query(&[])).iter().next().unwrap().get(0); @@ -102,7 +102,7 @@ fn test_text_params() { #[test] fn test_bpchar_params() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("CREATE TEMPORARY TABLE foo ( id SERIAL PRIMARY KEY, b CHAR(5) @@ -118,7 +118,7 @@ fn test_bpchar_params() { #[test] fn test_citext_params() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); or_panic!(conn.execute("CREATE TEMPORARY TABLE foo ( id SERIAL PRIMARY KEY, b CITEXT @@ -156,7 +156,7 @@ fn test_hstore_params() { } fn test_nan_param(sql_type: &str) { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let stmt = or_panic!(conn.prepare(&*format!("SELECT 'NaN'::{}", sql_type))); let result = or_panic!(stmt.query(&[])); let val: T = result.iter().next().unwrap().get(0); @@ -175,7 +175,7 @@ fn test_f64_nan_param() { #[test] fn test_pg_database_datname() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &mut SslMode::None)); let stmt = or_panic!(conn.prepare("SELECT datname FROM pg_database")); let result = or_panic!(stmt.query(&[])); @@ -186,7 +186,7 @@ fn test_pg_database_datname() { #[test] fn test_slice() { - let conn = Connection::connect("postgres://postgres@localhost", &SslMode::None).unwrap(); + let conn = Connection::connect("postgres://postgres@localhost", &mut SslMode::None).unwrap(); conn.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY, f VARCHAR); INSERT INTO foo (f) VALUES ('a'), ('b'), ('c'), ('d');").unwrap(); @@ -198,7 +198,7 @@ fn test_slice() { #[test] fn test_slice_wrong_type() { - let conn = Connection::connect("postgres://postgres@localhost", &SslMode::None).unwrap(); + let conn = Connection::connect("postgres://postgres@localhost", &mut SslMode::None).unwrap(); conn.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY)").unwrap(); let stmt = conn.prepare("SELECT * FROM foo WHERE id = ANY($1)").unwrap(); @@ -211,7 +211,7 @@ fn test_slice_wrong_type() { #[test] fn test_slice_range() { - let conn = Connection::connect("postgres://postgres@localhost", &SslMode::None).unwrap(); + let conn = Connection::connect("postgres://postgres@localhost", &mut SslMode::None).unwrap(); let stmt = conn.prepare("SELECT $1::INT8RANGE").unwrap(); match stmt.query(&[&Slice(&[1i64])]) {