diff --git a/Cargo.toml b/Cargo.toml index e7937b3f..95f0daa8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,11 @@ [package] name = "postgres" -version = "0.9.0" +version = "0.9.2" authors = ["Steven Fackler "] license = "MIT" description = "A native PostgreSQL driver" repository = "https://github.com/sfackler/rust-postgres" -documentation = "https://sfackler.github.io/rust-postgres/doc/v0.9.0/postgres" +documentation = "https://sfackler.github.io/rust-postgres/doc/v0.9.2/postgres" readme = "README.md" keywords = ["database", "sql"] build = "build.rs" @@ -29,7 +29,6 @@ byteorder = "0.3" debug-builders = "0.1" log = "0.3" phf = "0.7" -rust-crypto = "0.2" rustc-serialize = "0.3" chrono = { version = "0.2.14", optional = true } openssl = { version = "0.6", optional = true } diff --git a/README.md b/README.md index d0c87b1f..93f075d5 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # Rust-Postgres A native PostgreSQL driver for Rust. -Documentation is available at https://sfackler.github.io/rust-postgres/doc/v0.8.9/postgres +[Documentation](https://sfackler.github.io/rust-postgres/doc/v0.9.2/postgres) [![Build Status](https://travis-ci.org/sfackler/rust-postgres.png?branch=master)](https://travis-ci.org/sfackler/rust-postgres) [![Latest Version](https://img.shields.io/crates/v/postgres.svg)](https://crates.io/crates/postgres) @@ -9,7 +9,7 @@ You can integrate Rust-Postgres into your project through the [releases on crate ```toml # Cargo.toml [dependencies] -postgres = "0.8" +postgres = "0.9" ``` ## Overview diff --git a/THIRD_PARTY b/THIRD_PARTY index 80336ea0..aac5bfc0 100644 --- a/THIRD_PARTY +++ b/THIRD_PARTY @@ -57,3 +57,34 @@ CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +------------------------------------------------------------------------------- + +* src/md5.rs has been copied from rust-crypto + +Copyright (c) 2006-2009 Graydon Hoare +Copyright (c) 2009-2013 Mozilla Foundation + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/src/lib.rs b/src/lib.rs index ef381e8c..98e62605 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,6 @@ //! package. //! //! ```rust,no_run -//! # #![allow(unstable)] //! extern crate postgres; //! //! use postgres::{Connection, SslMode}; @@ -42,12 +41,11 @@ //! } //! } //! ``` -#![doc(html_root_url="https://sfackler.github.io/rust-postgres/doc/v0.9.0")] +#![doc(html_root_url="https://sfackler.github.io/rust-postgres/doc/v0.9.2")] #![warn(missing_docs)] extern crate bufstream; extern crate byteorder; -extern crate crypto; #[macro_use] extern crate log; extern crate phf; @@ -57,8 +55,7 @@ extern crate unix_socket; extern crate debug_builders; use bufstream::BufStream; -use crypto::digest::Digest; -use crypto::md5::Md5; +use md5::Md5; use debug_builders::DebugStruct; use std::ascii::AsciiExt; use std::borrow::ToOwned; @@ -95,6 +92,7 @@ mod url; mod util; pub mod types; pub mod rows; +mod md5; const TYPEINFO_QUERY: &'static str = "t"; @@ -174,7 +172,7 @@ impl IntoConnectParams for Url { #[cfg(feature = "unix_socket")] fn make_unix(maybe_path: String) -> result::Result { - Ok(ConnectTarget::Unix(PathBuf::from(&maybe_path))) + Ok(ConnectTarget::Unix(PathBuf::from(maybe_path))) } #[cfg(not(feature = "unix_socket"))] fn make_unix(_: String) -> result::Result { @@ -289,62 +287,6 @@ impl<'conn> Notifications<'conn> { _ => unreachable!() } } - - /* - /// Returns the oldest pending notification - /// - /// If no notifications are pending, blocks for up to `timeout` time, after - /// which `None` is returned. - /// - /// ## Example - /// - /// ```rust,no_run - /// # #![allow(unstable)] - /// use std::old_io::{IoError, IoErrorKind}; - /// use std::time::Duration; - /// - /// use postgres::Error; - /// - /// # let conn = postgres::Connection::connect("", &postgres::SslMode::None).unwrap(); - /// match conn.notifications().next_block_for(Duration::seconds(2)) { - /// Some(Ok(notification)) => println!("notification: {}", notification.payload), - /// Some(Err(e)) => println!("Error: {:?}", e), - /// None => println!("Wait for notification timed out"), - /// } - /// ``` - pub fn next_block_for(&mut self, timeout: Duration) -> Option> { - if let Some(notification) = self.next() { - return Some(Ok(notification)); - } - - let mut conn = self.conn.conn.borrow_mut(); - if conn.desynchronized { - return Some(Err(Error::StreamDesynchronized)); - } - - let end = SteadyTime::now() + timeout; - loop { - let timeout = max(Duration::zero(), end - SteadyTime::now()).num_milliseconds() as u64; - conn.stream.set_read_timeout(Some(timeout)); - match conn.read_one_message() { - Ok(Some(NotificationResponse { pid, channel, payload })) => { - return Some(Ok(Notification { - pid: pid, - channel: channel, - payload: payload - })) - } - Ok(Some(_)) => unreachable!(), - Ok(None) => {} - Err(IoError { kind: IoErrorKind::TimedOut, .. }) => { - conn.desynchronized = false; - return None; - } - Err(e) => return Some(Err(Error::IoError(e))), - } - } - } - */ } /// Contains information necessary to cancel queries for a session. @@ -467,9 +409,19 @@ pub enum SslMode { /// The connection will not use SSL. None, /// The connection will use SSL if the backend supports it. - Prefer(Box), + Prefer(Box), /// The connection must use SSL. - Require(Box), + Require(Box), +} + +impl fmt::Debug for SslMode { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match *self { + SslMode::None => fmt.write_str("None"), + SslMode::Prefer(..) => fmt.write_str("Prefer"), + SslMode::Require(..) => fmt.write_str("Require"), + } + } } #[derive(Clone)] @@ -479,10 +431,6 @@ struct CachedStatement { columns: Vec, } -trait SessionInfoNew<'a> { - fn new(conn: &'a InnerConnection) -> SessionInfo<'a>; -} - struct InnerConnection { stream: BufStream>, notice_handler: Box, @@ -597,27 +545,19 @@ impl InnerConnection { Ok(try_desync!(self, self.stream.flush())) } - fn read_one_message(&mut self) -> std_io::Result> { - debug_assert!(!self.desynchronized); - match try_desync!(self, self.stream.read_message()) { - NoticeResponse { fields } => { - if let Ok(err) = DbError::new_raw(fields) { - self.notice_handler.handle_notice(err); - } - Ok(None) - } - ParameterStatus { parameter, value } => { - self.parameters.insert(parameter, value); - Ok(None) - } - val => Ok(Some(val)) - } - } - fn read_message_with_notification(&mut self) -> std_io::Result { + debug_assert!(!self.desynchronized); loop { - if let Some(msg) = try!(self.read_one_message()) { - return Ok(msg); + match try_desync!(self, self.stream.read_message()) { + NoticeResponse { fields } => { + if let Ok(err) = DbError::new_raw(fields) { + self.notice_handler.handle_notice(err); + } + } + ParameterStatus { parameter, value } => { + self.parameters.insert(parameter, value); + } + val => return Ok(val) } } } @@ -982,7 +922,6 @@ impl Connection { /// ``` /// /// ```rust,no_run - /// # #![allow(unstable)] /// # use postgres::{Connection, UserInfo, ConnectParams, SslMode, ConnectTarget}; /// # #[cfg(feature = "unix_socket")] /// # fn f() -> Result<(), ::postgres::error::ConnectError> { @@ -1630,22 +1569,31 @@ impl<'conn> Statement<'conn> { let mut buf = vec![]; loop { - match std::io::copy(&mut r.take(16 * 1024), &mut buf) { + match r.take(16 * 1024).read_to_end(&mut buf) { Ok(0) => break, - Ok(len) => { + Ok(_) => { try_desync!(conn, conn.stream.write_message( - &CopyData { - data: &buf[..len as usize], - })); + &CopyData { + data: &buf, + })); buf.clear(); } Err(err) => { - // FIXME better to return the error directly - try_desync!(conn, conn.stream.write_message( - &CopyFail { - message: &err.to_string(), - })); - break; + try!(conn.write_messages(&[ + CopyFail { + message: "", + }, + CopyDone, + Sync])); + match try!(conn.read_message()) { + ErrorResponse { .. } => { /* expected from the CopyFail */ } + _ => { + conn.desynchronized = true; + return Err(Error::IoError(bad_response())); + } + } + try!(conn.wait_for_ready()); + return Err(Error::IoError(err)); } } } @@ -1833,3 +1781,7 @@ trait LazyRowsNew<'trans, 'stmt> { finished: bool, trans: &'trans Transaction<'trans>) -> LazyRows<'trans, 'stmt>; } + +trait SessionInfoNew<'a> { + fn new(conn: &'a InnerConnection) -> SessionInfo<'a>; +} diff --git a/src/md5.rs b/src/md5.rs new file mode 100644 index 00000000..894d1bf4 --- /dev/null +++ b/src/md5.rs @@ -0,0 +1,525 @@ +// Copyright 2013 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::io::prelude::*; +use std::ptr; +use std::mem; +use std::ops::{Add, Range}; +use std::iter::repeat; + +#[derive(Clone)] +struct StepUp { + next: T, + end: T, + ammount: T +} + +impl Iterator for StepUp where + T: Add + PartialOrd + Copy { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + if self.next < self.end { + let n = self.next; + self.next = self.next + self.ammount; + Some(n) + } else { + None + } + } +} + +trait RangeExt { + fn step_up(self, ammount: T) -> StepUp; +} + +impl RangeExt for Range where + T: Add + PartialOrd + Copy { + fn step_up(self, ammount: T) -> StepUp { + StepUp { + next: self.start, + end: self.end, + ammount: ammount + } + } +} + +/// Copy bytes from src to dest +#[inline] +fn copy_memory(src: &[u8], dst: &mut [u8]) { + assert!(dst.len() >= src.len()); + unsafe { + let srcp = src.as_ptr(); + let dstp = dst.as_mut_ptr(); + ptr::copy_nonoverlapping(srcp, dstp, src.len()); + } +} + +/// Zero all bytes in dst +#[inline] +fn zero(dst: &mut [u8]) { + unsafe { + ptr::write_bytes(dst.as_mut_ptr(), 0, dst.len()); + } +} + +/// Read a vector of bytes into a vector of u32s. The values are read in little-endian format. +fn read_u32v_le(dst: &mut[u32], input: &[u8]) { + assert!(dst.len() * 4 == input.len()); + unsafe { + let mut x: *mut u32 = dst.get_unchecked_mut(0); + let mut y: *const u8 = input.get_unchecked(0); + for _ in (0..dst.len()) { + let mut tmp: u32 = mem::uninitialized(); + ptr::copy_nonoverlapping(y, &mut tmp as *mut _ as *mut u8, 4); + *x = u32::from_le(tmp); + x = x.offset(1); + y = y.offset(4); + } + } +} + +/// Write a u32 into a vector, which must be 4 bytes long. The value is written in little-endian +/// format. +fn write_u32_le(dst: &mut[u8], mut input: u32) { + assert!(dst.len() == 4); + input = input.to_le(); + unsafe { + let tmp = &input as *const _ as *const u8; + ptr::copy_nonoverlapping(tmp, dst.get_unchecked_mut(0), 4); + } +} + +/// The StandardPadding trait adds a method useful for various hash algorithms to a FixedBuffer +/// struct. +trait StandardPadding { + /// Add standard padding to the buffer. The buffer must not be full when this method is called + /// and is guaranteed to have exactly rem remaining bytes when it returns. If there are not at + /// least rem bytes available, the buffer will be zero padded, processed, cleared, and then + /// filled with zeros again until only rem bytes are remaining. + fn standard_padding(&mut self, rem: usize, func: F); +} + +impl StandardPadding for T { + fn standard_padding(&mut self, rem: usize, mut func: F) { + let size = self.size(); + + self.next(1)[0] = 128; + + if self.remaining() < rem { + self.zero_until(size); + func(self.full_buffer()); + } + + self.zero_until(size - rem); + } +} + +/// A FixedBuffer, likes its name implies, is a fixed size buffer. When the buffer becomes full, it +/// must be processed. The input() method takes care of processing and then clearing the buffer +/// automatically. However, other methods do not and require the caller to process the buffer. Any +/// method that modifies the buffer directory or provides the caller with bytes that can be modifies +/// results in those bytes being marked as used by the buffer. +trait FixedBuffer { + /// Input a vector of bytes. If the buffer becomes full, process it with the provided + /// function and then clear the buffer. + fn input(&mut self, input: &[u8], func: F); + + /// Reset the buffer. + fn reset(&mut self); + + /// Zero the buffer up until the specified index. The buffer position currently must not be + /// greater than that index. + fn zero_until(&mut self, idx: usize); + + /// Get a slice of the buffer of the specified size. There must be at least that many bytes + /// remaining in the buffer. + fn next<'s>(&'s mut self, len: usize) -> &'s mut [u8]; + + /// Get the current buffer. The buffer must already be full. This clears the buffer as well. + fn full_buffer<'s>(&'s mut self) -> &'s [u8]; + + /// Get the current buffer. + fn current_buffer<'s>(&'s mut self) -> &'s [u8]; + + /// Get the current position of the buffer. + fn position(&self) -> usize; + + /// Get the number of bytes remaining in the buffer until it is full. + fn remaining(&self) -> usize; + + /// Get the size of the buffer + fn size(&self) -> usize; +} + +macro_rules! impl_fixed_buffer( ($name:ident, $size:expr) => ( + impl FixedBuffer for $name { + fn input(&mut self, input: &[u8], mut func: F) { + let mut i = 0; + + // FIXME: #6304 - This local variable shouldn't be necessary. + let size = $size; + + // If there is already data in the buffer, copy as much as we can into it and process + // the data if the buffer becomes full. + if self.buffer_idx != 0 { + let buffer_remaining = size - self.buffer_idx; + if input.len() >= buffer_remaining { + copy_memory( + &input[..buffer_remaining], + &mut self.buffer[self.buffer_idx..size]); + self.buffer_idx = 0; + func(&self.buffer); + i += buffer_remaining; + } else { + copy_memory( + input, + &mut self.buffer[self.buffer_idx..self.buffer_idx + input.len()]); + self.buffer_idx += input.len(); + return; + } + } + + // While we have at least a full buffer size chunks's worth of data, process that data + // without copying it into the buffer + while input.len() - i >= size { + func(&input[i..i + size]); + i += size; + } + + // Copy any input data into the buffer. At this point in the method, the ammount of + // data left in the input vector will be less than the buffer size and the buffer will + // be empty. + let input_remaining = input.len() - i; + copy_memory( + &input[i..], + &mut self.buffer[0..input_remaining]); + self.buffer_idx += input_remaining; + } + + fn reset(&mut self) { + self.buffer_idx = 0; + } + + fn zero_until(&mut self, idx: usize) { + assert!(idx >= self.buffer_idx); + zero(&mut self.buffer[self.buffer_idx..idx]); + self.buffer_idx = idx; + } + + fn next<'s>(&'s mut self, len: usize) -> &'s mut [u8] { + self.buffer_idx += len; + &mut self.buffer[self.buffer_idx - len..self.buffer_idx] + } + + fn full_buffer<'s>(&'s mut self) -> &'s [u8] { + assert!(self.buffer_idx == $size); + self.buffer_idx = 0; + &self.buffer[..$size] + } + + fn current_buffer<'s>(&'s mut self) -> &'s [u8] { + let tmp = self.buffer_idx; + self.buffer_idx = 0; + &self.buffer[..tmp] + } + + fn position(&self) -> usize { self.buffer_idx } + + fn remaining(&self) -> usize { $size - self.buffer_idx } + + fn size(&self) -> usize { $size } + } +)); + +/// A fixed size buffer of 64 bytes useful for cryptographic operations. +#[derive(Copy)] +struct FixedBuffer64 { + buffer: [u8; 64], + buffer_idx: usize, +} + +impl Clone for FixedBuffer64 { fn clone(&self) -> FixedBuffer64 { *self } } + +impl FixedBuffer64 { + /// Create a new buffer + fn new() -> FixedBuffer64 { + FixedBuffer64 { + buffer: [0u8; 64], + buffer_idx: 0 + } + } +} + +impl_fixed_buffer!(FixedBuffer64, 64); + +// A structure that represents that state of a digest computation for the MD5 digest function +struct Md5State { + s0: u32, + s1: u32, + s2: u32, + s3: u32 +} + +impl Md5State { + fn new() -> Md5State { + Md5State { + s0: 0x67452301, + s1: 0xefcdab89, + s2: 0x98badcfe, + s3: 0x10325476 + } + } + + fn reset(&mut self) { + self.s0 = 0x67452301; + self.s1 = 0xefcdab89; + self.s2 = 0x98badcfe; + self.s3 = 0x10325476; + } + + fn process_block(&mut self, input: &[u8]) { + fn f(u: u32, v: u32, w: u32) -> u32 { + (u & v) | (!u & w) + } + + fn g(u: u32, v: u32, w: u32) -> u32 { + (u & w) | (v & !w) + } + + fn h(u: u32, v: u32, w: u32) -> u32 { + u ^ v ^ w + } + + fn i(u: u32, v: u32, w: u32) -> u32 { + v ^ (u | !w) + } + + fn op_f(w: u32, x: u32, y: u32, z: u32, m: u32, s: u32) -> u32 { + w.wrapping_add(f(x, y, z)).wrapping_add(m).rotate_left(s).wrapping_add(x) + } + + fn op_g(w: u32, x: u32, y: u32, z: u32, m: u32, s: u32) -> u32 { + w.wrapping_add(g(x, y, z)).wrapping_add(m).rotate_left(s).wrapping_add(x) + } + + fn op_h(w: u32, x: u32, y: u32, z: u32, m: u32, s: u32) -> u32 { + w.wrapping_add(h(x, y, z)).wrapping_add(m).rotate_left(s).wrapping_add(x) + } + + fn op_i(w: u32, x: u32, y: u32, z: u32, m: u32, s: u32) -> u32 { + w.wrapping_add(i(x, y, z)).wrapping_add(m).rotate_left(s).wrapping_add(x) + } + + let mut a = self.s0; + let mut b = self.s1; + let mut c = self.s2; + let mut d = self.s3; + + let mut data = [0u32; 16]; + + read_u32v_le(&mut data, input); + + // round 1 + for i in (0..16).step_up(4) { + a = op_f(a, b, c, d, data[i].wrapping_add(C1[i]), 7); + d = op_f(d, a, b, c, data[i + 1].wrapping_add(C1[i + 1]), 12); + c = op_f(c, d, a, b, data[i + 2].wrapping_add(C1[i + 2]), 17); + b = op_f(b, c, d, a, data[i + 3].wrapping_add(C1[i + 3]), 22); + } + + // round 2 + let mut t = 1; + for i in (0..16).step_up(4) { + a = op_g(a, b, c, d, data[t & 0x0f].wrapping_add(C2[i]), 5); + d = op_g(d, a, b, c, data[(t + 5) & 0x0f].wrapping_add(C2[i + 1]), 9); + c = op_g(c, d, a, b, data[(t + 10) & 0x0f].wrapping_add(C2[i + 2]), 14); + b = op_g(b, c, d, a, data[(t + 15) & 0x0f].wrapping_add(C2[i + 3]), 20); + t += 20; + } + + // round 3 + t = 5; + for i in (0..16).step_up(4) { + a = op_h(a, b, c, d, data[t & 0x0f].wrapping_add(C3[i]), 4); + d = op_h(d, a, b, c, data[(t + 3) & 0x0f].wrapping_add(C3[i + 1]), 11); + c = op_h(c, d, a, b, data[(t + 6) & 0x0f].wrapping_add(C3[i + 2]), 16); + b = op_h(b, c, d, a, data[(t + 9) & 0x0f].wrapping_add(C3[i + 3]), 23); + t += 12; + } + + // round 4 + t = 0; + for i in (0..16).step_up(4) { + a = op_i(a, b, c, d, data[t & 0x0f].wrapping_add(C4[i]), 6); + d = op_i(d, a, b, c, data[(t + 7) & 0x0f].wrapping_add(C4[i + 1]), 10); + c = op_i(c, d, a, b, data[(t + 14) & 0x0f].wrapping_add(C4[i + 2]), 15); + b = op_i(b, c, d, a, data[(t + 21) & 0x0f].wrapping_add(C4[i + 3]), 21); + t += 28; + } + + self.s0 = self.s0.wrapping_add(a); + self.s1 = self.s1.wrapping_add(b); + self.s2 = self.s2.wrapping_add(c); + self.s3 = self.s3.wrapping_add(d); + } +} + +// Round 1 constants +static C1: [u32; 16] = [ + 0xd76aa478, 0xe8c7b756, 0x242070db, 0xc1bdceee, 0xf57c0faf, 0x4787c62a, 0xa8304613, 0xfd469501, + 0x698098d8, 0x8b44f7af, 0xffff5bb1, 0x895cd7be, 0x6b901122, 0xfd987193, 0xa679438e, 0x49b40821 +]; + +// Round 2 constants +static C2: [u32; 16] = [ + 0xf61e2562, 0xc040b340, 0x265e5a51, 0xe9b6c7aa, 0xd62f105d, 0x02441453, 0xd8a1e681, 0xe7d3fbc8, + 0x21e1cde6, 0xc33707d6, 0xf4d50d87, 0x455a14ed, 0xa9e3e905, 0xfcefa3f8, 0x676f02d9, 0x8d2a4c8a +]; + +// Round 3 constants +static C3: [u32; 16] = [ + 0xfffa3942, 0x8771f681, 0x6d9d6122, 0xfde5380c, 0xa4beea44, 0x4bdecfa9, 0xf6bb4b60, 0xbebfbc70, + 0x289b7ec6, 0xeaa127fa, 0xd4ef3085, 0x04881d05, 0xd9d4d039, 0xe6db99e5, 0x1fa27cf8, 0xc4ac5665 +]; + +// Round 4 constants +static C4: [u32; 16] = [ + 0xf4292244, 0x432aff97, 0xab9423a7, 0xfc93a039, 0x655b59c3, 0x8f0ccc92, 0xffeff47d, 0x85845dd1, + 0x6fa87e4f, 0xfe2ce6e0, 0xa3014314, 0x4e0811a1, 0xf7537e82, 0xbd3af235, 0x2ad7d2bb, 0xeb86d391 +]; + +/// The MD5 Digest algorithm +pub struct Md5 { + length_bytes: u64, + buffer: FixedBuffer64, + state: Md5State, + finished: bool, +} + +impl Md5 { + /// Construct a new instance of the MD5 Digest. + pub fn new() -> Md5 { + Md5 { + length_bytes: 0, + buffer: FixedBuffer64::new(), + state: Md5State::new(), + finished: false + } + } + + pub fn input(&mut self, input: &[u8]) { + assert!(!self.finished); + // Unlike Sha1 and Sha2, the length value in MD5 is defined as the length of the message mod + // 2^64 - ie: integer overflow is OK. + self.length_bytes += input.len() as u64; + let self_state = &mut self.state; + self.buffer.input(input, |d: &[u8]| { self_state.process_block(d);} + ); + } + + pub fn reset(&mut self) { + self.length_bytes = 0; + self.buffer.reset(); + self.state.reset(); + self.finished = false; + } + + pub fn result(&mut self, out: &mut [u8]) { + if !self.finished { + let self_state = &mut self.state; + self.buffer.standard_padding(8, |d: &[u8]| { self_state.process_block(d); }); + write_u32_le(self.buffer.next(4), (self.length_bytes << 3) as u32); + write_u32_le(self.buffer.next(4), (self.length_bytes >> 29) as u32); + self_state.process_block(self.buffer.full_buffer()); + self.finished = true; + } + + write_u32_le(&mut out[0..4], self.state.s0); + write_u32_le(&mut out[4..8], self.state.s1); + write_u32_le(&mut out[8..12], self.state.s2); + write_u32_le(&mut out[12..16], self.state.s3); + } + + fn output_bits(&self) -> usize { 128 } + + pub fn result_str(&mut self) -> String { + use serialize::hex::ToHex; + + let mut buf: Vec = repeat(0).take((self.output_bits()+7)/8).collect(); + self.result(&mut buf); + buf[..].to_hex() + } +} + + +#[cfg(test)] +mod tests { + use md5::Md5; + + struct Test { + input: &'static str, + output_str: &'static str, + } + + fn test_hash(sh: &mut D, tests: &[Test]) { + // Test that it works when accepting the message all at once + for t in tests.iter() { + sh.input_str(t.input); + + let out_str = sh.result_str(); + assert_eq!(out_str, t.output_str); + + sh.reset(); + } + + // Test that it works when accepting the message in pieces + for t in tests.iter() { + let len = t.input.len(); + let mut left = len; + while left > 0 { + let take = (left + 1) / 2; + sh.input_str(&t.input[len - left..take + len - left]); + left = left - take; + } + + let out_str = sh.result_str(); + assert_eq!(out_str, t.output_str); + + sh.reset(); + } + } + + #[test] + fn test_md5() { + // Examples from wikipedia + let wikipedia_tests = vec![ + Test { + input: "", + output_str: "d41d8cd98f00b204e9800998ecf8427e" + }, + Test { + input: "The quick brown fox jumps over the lazy dog", + output_str: "9e107d9d372bb6826bd81d3542a419d6" + }, + Test { + input: "The quick brown fox jumps over the lazy dog.", + output_str: "e4d909c290d0fb1ca068ffaddf22cbd0" + }, + ]; + + let tests = wikipedia_tests; + + let mut sh = Md5::new(); + + test_hash(&mut sh, &tests[..]); + } +} diff --git a/src/message.rs b/src/message.rs index 81a7be82..dbd4609c 100644 --- a/src/message.rs +++ b/src/message.rs @@ -161,23 +161,23 @@ impl WriteMessage for W { try!(buf.write_cstr(portal)); try!(buf.write_cstr(statement)); - try!(buf.write_u16::(formats.len() as u16)); + try!(buf.write_u16::(try!(u16::from_usize(formats.len())))); for &format in formats { try!(buf.write_i16::(format)); } - try!(buf.write_u16::(values.len() as u16)); + try!(buf.write_u16::(try!(u16::from_usize(values.len())))); for value in values { match *value { None => try!(buf.write_i32::(-1)), Some(ref value) => { - try!(buf.write_i32::(value.len() as i32)); + try!(buf.write_i32::(try!(i32::from_usize(value.len())))); try!(buf.write_all(&**value)); } } } - try!(buf.write_u16::(result_formats.len() as u16)); + try!(buf.write_u16::(try!(u16::from_usize(result_formats.len())))); for &format in result_formats { try!(buf.write_i16::(format)); } @@ -215,7 +215,7 @@ impl WriteMessage for W { ident = Some(b'P'); try!(buf.write_cstr(name)); try!(buf.write_cstr(query)); - try!(buf.write_u16::(param_types.len() as u16)); + try!(buf.write_u16::(try!(u16::from_usize(param_types.len())))); for &ty in param_types { try!(buf.write_u32::(ty)); } @@ -246,6 +246,9 @@ impl WriteMessage for W { } // add size of length value + if buf.len() > u32::max_value() as usize - mem::size_of::() { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "value too large to transmit")); + } try!(self.write_u32::((buf.len() + mem::size_of::()) as u32)); try!(self.write_all(&*buf)); @@ -407,3 +410,24 @@ fn read_row_description(buf: &mut R) -> io::Result { Ok(RowDescription { descriptions: types }) } + +trait FromUsize { + fn from_usize(x: usize) -> io::Result; +} + +macro_rules! from_usize { + ($t:ty) => { + impl FromUsize for $t { + fn from_usize(x: usize) -> io::Result<$t> { + if x > <$t>::max_value() as usize { + Err(io::Error::new(io::ErrorKind::InvalidInput, "value too large to transmit")) + } else { + Ok(x as $t) + } + } + } + } +} + +from_usize!(u16); +from_usize!(i32); diff --git a/src/rows.rs b/src/rows.rs index 36148db5..1728ff4c 100644 --- a/src/rows.rs +++ b/src/rows.rs @@ -50,6 +50,23 @@ impl<'stmt> Rows<'stmt> { self.stmt.columns() } + /// Returns the number of rows present. + pub fn len(&self) -> usize { + self.data.len() + } + + /// Returns a specific `Row`. + /// + /// # Panics + /// + /// Panics if `idx` is out of bounds. + pub fn get<'a>(&'a self, idx: usize) -> Row<'a> { + Row { + stmt: self.stmt, + data: Cow::Borrowed(&self.data[idx]), + } + } + /// Returns an iterator over the `Row`s. pub fn iter<'a>(&'a self) -> Iter<'a> { Iter { diff --git a/src/types/chrono.rs b/src/types/chrono.rs index dc65b673..159bbdec 100644 --- a/src/types/chrono.rs +++ b/src/types/chrono.rs @@ -1,10 +1,12 @@ extern crate chrono; +use std::error; use std::io::prelude::*; use byteorder::{ReadBytesExt, WriteBytesExt, BigEndian}; use self::chrono::{Duration, NaiveDate, NaiveTime, NaiveDateTime, DateTime, UTC}; use Result; +use error::Error; use types::{FromSql, ToSql, IsNull, Type, SessionInfo}; fn base() -> NaiveDateTime { @@ -61,8 +63,13 @@ impl FromSql for NaiveDate { impl ToSql for NaiveDate { fn to_sql(&self, _: &Type, mut w: &mut W, _: &SessionInfo) -> Result { - let jd = *self - base().date(); - try!(w.write_i32::(jd.num_days() as i32)); + let jd = (*self - base().date()).num_days(); + if jd > i32::max_value() as i64 || jd < i32::min_value() as i64 { + let err: Box = "value too large to transmit".into(); + return Err(Error::Conversion(err)); + } + + try!(w.write_i32::(jd as i32)); Ok(IsNull::No) } diff --git a/src/types/mod.rs b/src/types/mod.rs index 5ad8eb56..49f5cc03 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -934,15 +934,15 @@ impl ToSql for HashMap> { fn to_sql(&self, _: &Type, mut w: &mut W, _: &SessionInfo) -> Result { - try!(w.write_i32::(self.len() as i32)); + try!(w.write_i32::(try!(downcast(self.len())))); for (key, val) in self { - try!(w.write_i32::(key.len() as i32)); + try!(w.write_i32::(try!(downcast(key.len())))); try!(w.write_all(key.as_bytes())); match *val { Some(ref val) => { - try!(w.write_i32::(val.len() as i32)); + try!(w.write_i32::(try!(downcast(val.len())))); try!(w.write_all(val.as_bytes())); } None => try!(w.write_i32::(-1)) @@ -959,3 +959,12 @@ impl ToSql for HashMap> { } } } + +fn downcast(len: usize) -> Result { + if len > i32::max_value() as usize { + let err: Box = "value too large to transmit".into(); + Err(Error::Conversion(err)) + } else { + Ok(len as i32) + } +} diff --git a/src/types/slice.rs b/src/types/slice.rs index cb9156b8..02cb3975 100644 --- a/src/types/slice.rs +++ b/src/types/slice.rs @@ -3,7 +3,7 @@ use byteorder::{WriteBytesExt, BigEndian}; use Result; use error::Error; -use types::{Type, ToSql, Kind, IsNull, SessionInfo}; +use types::{Type, ToSql, Kind, IsNull, SessionInfo, downcast}; /// An adapter type mapping slices to Postgres arrays. /// @@ -48,14 +48,14 @@ impl<'a, T: 'a + ToSql> ToSql for Slice<'a, T> { try!(w.write_i32::(1)); // has nulls try!(w.write_u32::(member_type.oid())); - try!(w.write_i32::(self.0.len() as i32)); + try!(w.write_i32::(try!(downcast(self.0.len())))); try!(w.write_i32::(0)); // index offset let mut inner_buf = vec![]; for e in self.0 { match try!(e.to_sql(&member_type, &mut inner_buf, ctx)) { IsNull::No => { - try!(w.write_i32::(inner_buf.len() as i32)); + try!(w.write_i32::(try!(downcast(inner_buf.len())))); try!(w.write_all(&inner_buf)); } IsNull::Yes => try!(w.write_i32::(-1)), diff --git a/tests/test.rs b/tests/test.rs index d55d5de3..f14ae4b8 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -7,6 +7,7 @@ extern crate openssl; #[cfg(feature = "openssl")] use openssl::ssl::{SslContext, SslMethod}; use std::thread; +use std::io; use postgres::{HandleNotice, Notification, @@ -605,48 +606,6 @@ fn test_notifications_next_block() { }, or_panic!(notifications.next_block())); } -/* -#[test] -fn test_notifications_next_block_for() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); - or_panic!(conn.execute("LISTEN test_notifications_next_block_for", &[])); - - let _t = thread::spawn(|| { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); - timer::sleep(Duration::milliseconds(500)); - or_panic!(conn.execute("NOTIFY test_notifications_next_block_for, 'foo'", &[])); - }); - - let mut notifications = conn.notifications(); - check_notification(Notification { - pid: 0, - channel: "test_notifications_next_block_for".to_string(), - payload: "foo".to_string() - }, or_panic!(notifications.next_block_for(Duration::seconds(2)).unwrap())); -} - -#[test] -fn test_notifications_next_block_for_timeout() { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); - or_panic!(conn.execute("LISTEN test_notifications_next_block_for_timeout", &[])); - - let _t = thread::spawn(|| { - let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); - timer::sleep(Duration::seconds(2)); - or_panic!(conn.execute("NOTIFY test_notifications_next_block_for_timeout, 'foo'", &[])); - }); - - let mut notifications = conn.notifications(); - match notifications.next_block_for(Duration::milliseconds(500)) { - None => {} - Some(Err(e)) => panic!("Unexpected error {:?}", e), - Some(Ok(_)) => panic!("expected error"), - } - - or_panic!(conn.execute("SELECT 1", &[])); -} -*/ - #[test] // This test is pretty sad, but I don't think there's a better way :( fn test_cancel_query() { @@ -741,12 +700,12 @@ fn test_execute_copy_from_err() { let stmt = or_panic!(conn.prepare("COPY foo (id) FROM STDIN")); match stmt.execute(&[]) { Err(Error::DbError(ref err)) if err.message().contains("COPY") => {} - Err(err) => panic!("Unexptected error {:?}", err), + Err(err) => panic!("Unexpected error {:?}", err), _ => panic!("Expected error"), } match stmt.query(&[]) { Err(Error::DbError(ref err)) if err.message().contains("COPY") => {} - Err(err) => panic!("Unexptected error {:?}", err), + Err(err) => panic!("Unexpected error {:?}", err), _ => panic!("Expected error"), } } @@ -757,11 +716,33 @@ fn test_batch_execute_copy_from_err() { or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[])); match conn.batch_execute("COPY foo (id) FROM STDIN") { Err(Error::DbError(ref err)) if err.message().contains("COPY") => {} - Err(err) => panic!("Unexptected error {:?}", err), + Err(err) => panic!("Unexpected error {:?}", err), _ => panic!("Expected error"), } } +#[test] +fn test_copy_io_error() { + struct ErrorReader; + + impl io::Read for ErrorReader { + fn read(&mut self, _: &mut [u8]) -> io::Result { + Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "boom")) + } + } + + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[])); + let stmt = or_panic!(conn.prepare("COPY foo (id) FROM STDIN")); + match stmt.copy_in(&[], &mut ErrorReader) { + Err(Error::IoError(ref e)) if e.kind() == io::ErrorKind::AddrNotAvailable => {} + Err(err) => panic!("Unexpected error {:?}", err), + _ => panic!("Expected error"), + } + + or_panic!(conn.execute("SELECT 1", &[])); +} + #[test] fn test_copy() { let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); @@ -892,3 +873,16 @@ fn test_transaction_isolation_level() { or_panic!(conn.set_transaction_isolation(IsolationLevel::ReadCommitted)); assert_eq!(IsolationLevel::ReadCommitted, or_panic!(conn.transaction_isolation())); } + +#[test] +fn test_rows_index() { + let conn = Connection::connect("postgres://postgres@localhost", &SslMode::None).unwrap(); + conn.batch_execute(" + CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY); + INSERT INTO foo (id) VALUES (1), (2), (3); + ").unwrap(); + let stmt = conn.prepare("SELECT id FROM foo ORDER BY id").unwrap(); + let rows = stmt.query(&[]).unwrap(); + assert_eq!(3, rows.len()); + assert_eq!(2i32, rows.get(1).get(0)); +}