Merge branch 'breaks'

This commit is contained in:
Steven Fackler 2015-05-16 20:42:25 -07:00
commit ed26ff0042
9 changed files with 258 additions and 154 deletions

View File

@ -8,7 +8,7 @@ before_script:
- "./.travis/setup.sh"
script:
- cargo test
- cargo test --features "uuid rustc-serialize time unix_socket serde chrono"
- cargo test --features "uuid rustc-serialize time unix_socket serde chrono openssl"
- cargo doc --no-deps --features "unix_socket"
after_success:
- test $TRAVIS_PULL_REQUEST == "false" && test $TRAVIS_BRANCH == "master" && test $TRAVIS_RUST_VERSION == "nightly" && ./.travis/update_docs.sh

View File

@ -24,18 +24,19 @@ path = "tests/test.rs"
phf_codegen = "0.7"
[dependencies]
phf = "0.7"
openssl = "0.6"
log = "0.3"
rustc-serialize = "0.3"
bufstream = "0.1"
byteorder = "0.3"
debug-builders = "0.1"
bufstream = "0.1"
uuid = { version = "0.1", optional = true }
unix_socket = { version = "0.3", optional = true }
time = { version = "0.1.14", optional = true }
serde = { version = "0.3", optional = true }
log = "0.3"
phf = "0.7"
rust-crypto = "0.2"
rustc-serialize = "0.3"
chrono = { version = "0.2.14", optional = true }
openssl = { version = "0.6", optional = true }
serde = { version = "0.3", optional = true }
time = { version = "0.1.14", optional = true }
unix_socket = { version = "0.3", optional = true }
uuid = { version = "0.1", optional = true }
[dev-dependencies]
url = "0.2"

View File

@ -1,7 +1,6 @@
pub use ugh_privacy::DbError;
use byteorder;
use openssl::ssl::error::SslError;
use phf;
use std::error;
use std::convert::From;
@ -29,8 +28,8 @@ pub enum ConnectError {
UnsupportedAuthentication,
/// The Postgres server does not support SSL encryption.
NoSslSupport,
/// There was an error initializing the SSL session.
SslError(SslError),
/// There was an error initializing the SSL session
SslError(Box<error::Error>),
/// There was an error communicating with the server.
IoError(io::Error),
/// The server sent an unexpected response.
@ -67,7 +66,7 @@ impl error::Error for ConnectError {
fn cause(&self) -> Option<&error::Error> {
match *self {
ConnectError::DbError(ref err) => Some(err),
ConnectError::SslError(ref err) => Some(err),
ConnectError::SslError(ref err) => Some(&**err),
ConnectError::IoError(ref err) => Some(err),
_ => None
}
@ -86,12 +85,6 @@ impl From<DbError> for ConnectError {
}
}
impl From<SslError> for ConnectError {
fn from(err: SslError) -> ConnectError {
ConnectError::SslError(err)
}
}
impl From<byteorder::Error> for ConnectError {
fn from(err: byteorder::Error) -> ConnectError {
ConnectError::IoError(From::from(err))

27
src/io/mod.rs Normal file
View File

@ -0,0 +1,27 @@
//! Types and traits for SSL adaptors.
pub use priv_io::Stream;
use std::error::Error;
use std::io::prelude::*;
#[cfg(feature = "openssl")]
mod openssl;
/// A trait implemented by SSL adaptors.
pub trait StreamWrapper: Read+Write+Send {
/// Returns a reference to the underlying `Stream`.
fn get_ref(&self) -> &Stream;
/// Returns a mutable reference to the underlying `Stream`.
fn get_mut(&mut self) -> &mut Stream;
}
/// A trait implemented by types that can negotiate SSL over a Postgres stream.
pub trait NegotiateSsl {
/// Negotiates an SSL session, returning a wrapper around the provided
/// stream.
///
/// The host portion of the connection parameters is provided for hostname
/// verification.
fn negotiate_ssl(&self, host: &str, stream: Stream) -> Result<Box<StreamWrapper>, Box<Error>>;
}

23
src/io/openssl.rs Normal file
View File

@ -0,0 +1,23 @@
extern crate openssl;
use std::error::Error;
use self::openssl::ssl::{SslContext, SslStream};
use io::{StreamWrapper, Stream, NegotiateSsl};
impl StreamWrapper for SslStream<Stream> {
fn get_ref(&self) -> &Stream {
self.get_ref()
}
fn get_mut(&mut self) -> &mut Stream {
self.get_mut()
}
}
impl NegotiateSsl for SslContext {
fn negotiate_ssl(&self, _: &str, stream: Stream) -> Result<Box<StreamWrapper>, Box<Error>> {
let stream = try!(SslStream::new(self, stream));
Ok(Box::new(stream))
}
}

View File

@ -1,90 +0,0 @@
use openssl::ssl::{SslStream, MaybeSslStream};
use std::io;
use std::io::prelude::*;
use std::net::TcpStream;
#[cfg(feature = "unix_socket")]
use unix_socket::UnixStream;
use byteorder::ReadBytesExt;
use {ConnectParams, SslMode, ConnectTarget, ConnectError};
use message;
use message::WriteMessage;
use message::FrontendMessage::SslRequest;
const DEFAULT_PORT: u16 = 5432;
pub enum InternalStream {
Tcp(TcpStream),
#[cfg(feature = "unix_socket")]
Unix(UnixStream),
}
impl Read for InternalStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match *self {
InternalStream::Tcp(ref mut s) => s.read(buf),
#[cfg(feature = "unix_socket")]
InternalStream::Unix(ref mut s) => s.read(buf),
}
}
}
impl Write for InternalStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match *self {
InternalStream::Tcp(ref mut s) => s.write(buf),
#[cfg(feature = "unix_socket")]
InternalStream::Unix(ref mut s) => s.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match *self {
InternalStream::Tcp(ref mut s) => s.flush(),
#[cfg(feature = "unix_socket")]
InternalStream::Unix(ref mut s) => s.flush(),
}
}
}
fn open_socket(params: &ConnectParams) -> Result<InternalStream, ConnectError> {
let port = params.port.unwrap_or(DEFAULT_PORT);
match params.target {
ConnectTarget::Tcp(ref host) => {
Ok(try!(TcpStream::connect(&(&**host, port)).map(InternalStream::Tcp)))
}
#[cfg(feature = "unix_socket")]
ConnectTarget::Unix(ref path) => {
let mut path = path.clone();
path.push(&format!(".s.PGSQL.{}", port));
Ok(try!(UnixStream::connect(&path).map(InternalStream::Unix)))
}
}
}
pub fn initialize_stream(params: &ConnectParams, ssl: &SslMode)
-> Result<MaybeSslStream<InternalStream>, ConnectError> {
let mut socket = try!(open_socket(params));
let (ssl_required, ctx) = match *ssl {
SslMode::None => return Ok(MaybeSslStream::Normal(socket)),
SslMode::Prefer(ref ctx) => (false, ctx),
SslMode::Require(ref ctx) => (true, ctx)
};
try!(socket.write_message(&SslRequest { code: message::SSL_CODE }));
try!(socket.flush());
if try!(socket.read_u8()) == 'N' as u8 {
if ssl_required {
return Err(ConnectError::NoSslSupport);
} else {
return Ok(MaybeSslStream::Normal(socket));
}
}
match SslStream::new(ctx, socket) {
Ok(stream) => Ok(MaybeSslStream::Ssl(stream)),
Err(err) => Err(ConnectError::SslError(err))
}
}

View File

@ -47,9 +47,9 @@
extern crate bufstream;
extern crate byteorder;
extern crate crypto;
#[macro_use]
extern crate log;
extern crate openssl;
extern crate phf;
extern crate rustc_serialize as serialize;
#[cfg(feature = "unix_socket")]
@ -57,17 +57,16 @@ extern crate unix_socket;
extern crate debug_builders;
use bufstream::BufStream;
use crypto::digest::Digest;
use crypto::md5::Md5;
use debug_builders::DebugStruct;
use openssl::crypto::hash::{self, Hasher};
use openssl::ssl::{SslContext, MaybeSslStream};
use serialize::hex::ToHex;
use std::ascii::AsciiExt;
use std::borrow::{ToOwned, Cow};
use std::cell::{Cell, RefCell};
use std::collections::{VecDeque, HashMap};
use std::fmt;
use std::iter::IntoIterator;
use std::io;
use std::io as std_io;
use std::io::prelude::*;
use std::mem;
use std::slice;
@ -80,10 +79,10 @@ use std::path::PathBuf;
pub use error::{Error, ConnectError, SqlState, DbError, ErrorPosition};
#[doc(inline)]
pub use types::{Oid, Type, Kind, ToSql, FromSql};
use io::{StreamWrapper, NegotiateSsl};
use types::IsNull;
#[doc(inline)]
pub use types::Slice;
use io_util::InternalStream;
use message::BackendMessage::*;
use message::FrontendMessage::*;
use message::{FrontendMessage, BackendMessage, RowDescriptionEntry};
@ -94,9 +93,10 @@ use url::Url;
mod macros;
mod error;
mod io_util;
pub mod io;
mod message;
mod ugh_privacy;
mod priv_io;
mod url;
mod util;
pub mod types;
@ -388,9 +388,10 @@ pub struct CancelData {
/// postgres::cancel_query(url, &SslMode::None, cancel_data);
/// ```
pub fn cancel_query<T>(params: T, ssl: &SslMode, data: CancelData)
-> result::Result<(), ConnectError> where T: IntoConnectParams {
-> result::Result<(), ConnectError>
where T: IntoConnectParams {
let params = try!(params.into_connect_params());
let mut socket = try!(io_util::initialize_stream(&params, ssl));
let mut socket = try!(priv_io::initialize_stream(&params, ssl));
try!(socket.write_message(&CancelRequest {
code: message::CANCEL_CODE,
@ -456,6 +457,16 @@ impl IsolationLevel {
}
}
/// Specifies the SSL support requested for a new connection.
pub enum SslMode {
/// The connection will not use SSL.
None,
/// The connection will use SSL if the backend supports it.
Prefer(Box<NegotiateSsl>),
/// The connection must use SSL.
Require(Box<NegotiateSsl>),
}
#[derive(Clone)]
struct CachedStatement {
name: String,
@ -464,7 +475,7 @@ struct CachedStatement {
}
struct InnerConnection {
stream: BufStream<MaybeSslStream<InternalStream>>,
stream: BufStream<Box<StreamWrapper>>,
notice_handler: Box<HandleNotice>,
notifications: VecDeque<Notification>,
cancel_data: CancelData,
@ -489,7 +500,7 @@ impl InnerConnection {
fn connect<T>(params: T, ssl: &SslMode) -> result::Result<InnerConnection, ConnectError>
where T: IntoConnectParams {
let params = try!(params.into_connect_params());
let stream = try!(io_util::initialize_stream(&params, ssl));
let stream = try!(priv_io::initialize_stream(&params, ssl));
let ConnectParams { user, database, mut options, .. } = params;
@ -569,7 +580,7 @@ impl InnerConnection {
}
}
fn write_messages(&mut self, messages: &[FrontendMessage]) -> io::Result<()> {
fn write_messages(&mut self, messages: &[FrontendMessage]) -> std_io::Result<()> {
debug_assert!(!self.desynchronized);
for message in messages {
try_desync!(self, self.stream.write_message(message));
@ -577,7 +588,7 @@ impl InnerConnection {
Ok(try_desync!(self, self.stream.flush()))
}
fn read_one_message(&mut self) -> io::Result<Option<BackendMessage>> {
fn read_one_message(&mut self) -> std_io::Result<Option<BackendMessage>> {
debug_assert!(!self.desynchronized);
match try_desync!(self, self.stream.read_message()) {
NoticeResponse { fields } => {
@ -594,7 +605,7 @@ impl InnerConnection {
}
}
fn read_message_with_notification(&mut self) -> io::Result<BackendMessage> {
fn read_message_with_notification(&mut self) -> std_io::Result<BackendMessage> {
loop {
if let Some(msg) = try!(self.read_one_message()) {
return Ok(msg);
@ -602,7 +613,7 @@ impl InnerConnection {
}
}
fn read_message(&mut self) -> io::Result<BackendMessage> {
fn read_message(&mut self) -> std_io::Result<BackendMessage> {
loop {
match try!(self.read_message_with_notification()) {
NotificationResponse { pid, channel, payload } => {
@ -628,13 +639,14 @@ impl InnerConnection {
}
AuthenticationMD5Password { salt } => {
let pass = try!(user.password.ok_or(ConnectError::MissingPassword));
let mut hasher = Hasher::new(hash::Type::MD5);
let _ = hasher.write_all(pass.as_bytes());
let _ = hasher.write_all(user.user.as_bytes());
let output = hasher.finish().to_hex();
let _ = hasher.write_all(output.as_bytes());
let _ = hasher.write_all(&salt);
let output = format!("md5{}", hasher.finish().to_hex());
let mut hasher = Md5::new();
let _ = hasher.input(pass.as_bytes());
let _ = hasher.input(user.user.as_bytes());
let output = hasher.result_str();
hasher.reset();
let _ = hasher.input(output.as_bytes());
let _ = hasher.input(&salt);
let output = format!("md5{}", hasher.result_str());
try!(self.write_messages(&[PasswordMessage {
password: &output
}]));
@ -1131,13 +1143,6 @@ impl Connection {
self.batch_execute(level.to_set_query())
}
/// # Deprecated
///
/// Use `transaction_isolation` instead.
pub fn get_transaction_isolation(&self) -> Result<IsolationLevel> {
self.transaction_isolation()
}
/// Returns the isolation level which will be used for future transactions.
pub fn transaction_isolation(&self) -> Result<IsolationLevel> {
let mut conn = self.conn.borrow_mut();
@ -1251,17 +1256,6 @@ impl Connection {
}
}
/// Specifies the SSL support requested for a new connection.
#[derive(Debug)]
pub enum SslMode {
/// The connection will not use SSL.
None,
/// The connection will use SSL if the backend supports it.
Prefer(SslContext),
/// The connection must use SSL.
Require(SslContext)
}
/// Represents a transaction on a database connection.
///
/// The transaction will roll back by default.

153
src/priv_io.rs Normal file
View File

@ -0,0 +1,153 @@
use byteorder::ReadBytesExt;
use std::io;
use std::io::prelude::*;
use std::net::TcpStream;
#[cfg(feature = "unix_socket")]
use unix_socket::UnixStream;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
use {SslMode, ConnectError, ConnectParams, ConnectTarget};
use io::{NegotiateSsl, StreamWrapper};
use message::{self, WriteMessage};
use message::FrontendMessage::SslRequest;
const DEFAULT_PORT: u16 = 5432;
/// A connection to the Postgres server.
///
/// It implements `Read`, `Write` and `StreamWrapper`, as well as `AsRawFd` on
/// Unix platforms and `AsRawSocket` on Windows platforms.
pub struct Stream(InternalStream);
impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
impl StreamWrapper for Stream {
fn get_ref(&self) -> &Stream {
self
}
fn get_mut(&mut self) -> &mut Stream {
self
}
}
#[cfg(unix)]
impl AsRawFd for Stream {
fn as_raw_fd(&self) -> RawFd {
match self.0 {
InternalStream::Tcp(ref s) => s.as_raw_fd(),
#[cfg(feature = "unix_socket")]
InternalStream::Unix(ref s) => s.as_raw_fd(),
}
}
}
#[cfg(windows)]
impl AsRawSocket for Stream {
fn as_raw_socket(&self) -> RawSocket {
// Unix sockets aren't supported on windows, so no need to match
match self.0 {
InternalStream::Tcp(ref s) => s.as_raw_socket(),
}
}
}
enum InternalStream {
Tcp(TcpStream),
#[cfg(feature = "unix_socket")]
Unix(UnixStream),
}
impl Read for InternalStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match *self {
InternalStream::Tcp(ref mut s) => s.read(buf),
#[cfg(feature = "unix_socket")]
InternalStream::Unix(ref mut s) => s.read(buf),
}
}
}
impl Write for InternalStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match *self {
InternalStream::Tcp(ref mut s) => s.write(buf),
#[cfg(feature = "unix_socket")]
InternalStream::Unix(ref mut s) => s.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match *self {
InternalStream::Tcp(ref mut s) => s.flush(),
#[cfg(feature = "unix_socket")]
InternalStream::Unix(ref mut s) => s.flush(),
}
}
}
fn open_socket(params: &ConnectParams) -> Result<InternalStream, ConnectError> {
let port = params.port.unwrap_or(DEFAULT_PORT);
match params.target {
ConnectTarget::Tcp(ref host) => {
Ok(try!(TcpStream::connect(&(&**host, port)).map(InternalStream::Tcp)))
}
#[cfg(feature = "unix_socket")]
ConnectTarget::Unix(ref path) => {
let mut path = path.clone();
path.push(&format!(".s.PGSQL.{}", port));
Ok(try!(UnixStream::connect(&path).map(InternalStream::Unix)))
}
}
}
pub fn initialize_stream(params: &ConnectParams, ssl: &SslMode)
-> Result<Box<StreamWrapper>, ConnectError> {
let mut socket = Stream(try!(open_socket(params)));
let (ssl_required, negotiator) = match *ssl {
SslMode::None => return Ok(Box::new(socket)),
SslMode::Prefer(ref negotiator) => (false, negotiator),
SslMode::Require(ref negotiator) => (true, negotiator),
};
try!(socket.write_message(&SslRequest { code: message::SSL_CODE }));
try!(socket.flush());
if try!(socket.read_u8()) == 'N' as u8 {
if ssl_required {
return Err(ConnectError::NoSslSupport);
} else {
return Ok(Box::new(socket));
}
}
// Postgres doesn't support SSL over unix sockets
let host = match params.target {
ConnectTarget::Tcp(ref host) => host,
#[cfg(feature = "unix_socket")]
ConnectTarget::Unix(_) => return Err(ConnectError::BadResponse)
};
match negotiator.negotiate_ssl(host, socket) {
Ok(stream) => Ok(stream),
Err(err) => Err(ConnectError::SslError(err))
}
}

View File

@ -1,10 +1,11 @@
extern crate postgres;
extern crate rustc_serialize as serialize;
extern crate url;
#[cfg(feature = "openssl")]
extern crate openssl;
use openssl::ssl::SslContext;
use openssl::ssl::SslMethod;
#[cfg(feature = "openssl")]
use openssl::ssl::{SslContext, SslMethod};
use std::thread;
use postgres::{HandleNotice,
@ -670,18 +671,20 @@ fn test_cancel_query() {
}
#[test]
#[cfg(feature = "openssl")]
fn test_require_ssl_conn() {
let ctx = SslContext::new(SslMethod::Sslv23).unwrap();
let conn = or_panic!(Connection::connect("postgres://postgres@localhost",
&SslMode::Require(ctx)));
&mut SslMode::Require(Box::new(ctx))));
or_panic!(conn.execute("SELECT 1::VARCHAR", &[]));
}
#[test]
#[cfg(feature = "openssl")]
fn test_prefer_ssl_conn() {
let ctx = SslContext::new(SslMethod::Sslv23).unwrap();
let conn = or_panic!(Connection::connect("postgres://postgres@localhost",
&SslMode::Prefer(ctx)));
&mut SslMode::Prefer(Box::new(ctx))));
or_panic!(conn.execute("SELECT 1::VARCHAR", &[]));
}