Add Notifications::next_block_for method

The setup is a little hairy, but seems correct.

cc #19
This commit is contained in:
Steven Fackler 2014-12-02 20:34:46 -08:00
parent 75641e121f
commit d5998d8f2a
8 changed files with 141 additions and 46 deletions

View File

@ -20,21 +20,18 @@ name = "test"
path = "tests/test.rs"
[features]
default = ["uuid", "time"]
default = ["uuid"]
[dependencies]
phf = "0.1"
phf_mac = "0.1"
openssl = "0.2.1"
time = "0.1"
[dependencies.uuid]
optional = true
version = "0.1"
[dependencies.time]
optional = true
version = "0.1"
[dev-dependencies]
url = "0.2"

View File

@ -204,10 +204,7 @@ types. The driver currently supports the following conversions:
<td>JSON</td>
</tr>
<tr>
<td>
<a href="https://github.com/rust-lang/time">time::Timespec</a>
(<a href="#optional-features">optional</a>)
</td>
<td>time::Timespec</td>
<td>TIMESTAMP, TIMESTAMP WITH TIME ZONE</td>
</tr>
<tr>
@ -226,10 +223,7 @@ types. The driver currently supports the following conversions:
<td>INT8RANGE</td>
</tr>
<tr>
<td>
<a href="https://github.com/rust-lang/time">types::range::Range&lt;Timespec&gt;</a>
(<a href="#optional-features">optional</a>)
</td>
<td>types::range::Range&lt;Timespec&gt;</td>
<td>TSRANGE, TSTZRANGE</td>
</tr>
<tr>
@ -265,10 +259,7 @@ types. The driver currently supports the following conversions:
<td>INT8[], INT8[][], ...</td>
</tr>
<tr>
<td>
<a href="https://github.com/rust-lang/time">types::array::ArrayBase&lt;Option&lt;Timespec&gt;&gt;</a>
(<a href="#optional-features">optional</a>)
</td>
<td>types::array::ArrayBase&lt;Option&lt;Timespec&gt;&gt;</td>
<td>TIMESTAMP[], TIMESTAMPTZ[], TIMESTAMP[][], ...</td>
</tr>
<tr>
@ -308,10 +299,6 @@ traits.
[UUID](http://www.postgresql.org/docs/9.4/static/datatype-uuid.html) support is
provided optionally by the `uuid` feature. It is enabled by default.
### Time types
[Time](http://www.postgresql.org/docs/9.3/static/datatype-datetime.html)
support is provided optionally by the `time` feature. It is enabled by default.
To disable support for optional features, add `default-features = false` to
your Cargo manifest:

View File

@ -1,8 +1,9 @@
use openssl::ssl::{SslStream, MaybeSslStream};
use std::io::BufferedStream;
use std::io::net::ip::Port;
use std::io::net::tcp::TcpStream;
use std::io::net::pipe::UnixStream;
use std::io::IoResult;
use std::io::{IoResult, Stream};
use {ConnectParams, SslMode, ConnectTarget, ConnectError};
use message;
@ -11,6 +12,23 @@ use message::FrontendMessage::SslRequest;
const DEFAULT_PORT: Port = 5432;
#[doc(hidden)]
pub trait Timeout {
fn set_read_timeout(&mut self, timeout_ms: Option<u64>);
}
impl<S: Stream+Timeout> Timeout for MaybeSslStream<S> {
fn set_read_timeout(&mut self, timeout_ms: Option<u64>) {
self.get_mut().set_read_timeout(timeout_ms);
}
}
impl<S: Stream+Timeout> Timeout for BufferedStream<S> {
fn set_read_timeout(&mut self, timeout_ms: Option<u64>) {
self.get_mut().set_read_timeout(timeout_ms);
}
}
pub enum InternalStream {
Tcp(TcpStream),
Unix(UnixStream),
@ -41,9 +59,8 @@ impl Writer for InternalStream {
}
}
impl InternalStream {
#[allow(dead_code)]
pub fn set_read_timeout(&mut self, timeout_ms: Option<u64>) {
impl Timeout for InternalStream {
fn set_read_timeout(&mut self, timeout_ms: Option<u64>) {
match *self {
InternalStream::Tcp(ref mut s) => s.set_read_timeout(timeout_ms),
InternalStream::Unix(ref mut s) => s.set_read_timeout(timeout_ms),

View File

@ -65,6 +65,7 @@ extern crate phf;
extern crate phf_mac;
#[phase(plugin, link)]
extern crate log;
extern crate time;
use url::Url;
use openssl::crypto::hash::{HashType, Hasher};
@ -72,14 +73,15 @@ use openssl::ssl::{SslContext, MaybeSslStream};
use serialize::hex::ToHex;
use std::cell::{Cell, RefCell};
use std::collections::{RingBuf, HashMap};
use std::io::{BufferedStream, IoResult};
use std::io::{BufferedStream, IoResult, IoError, IoErrorKind};
use std::io::net::ip::Port;
use std::iter::IteratorCloneExt;
use std::time::Duration;
use std::mem;
use std::fmt;
use std::result;
use io::InternalStream;
use io::{InternalStream, Timeout};
use message::{FrontendMessage, BackendMessage, RowDescriptionEntry};
use message::FrontendMessage::*;
use message::BackendMessage::*;
@ -248,8 +250,9 @@ impl<'conn> Notifications<'conn> {
return Ok(notification);
}
check_desync!(self.conn.conn.borrow());
match try!(self.conn.conn.borrow_mut().read_message_with_notification()) {
let mut conn = self.conn.conn.borrow_mut();
check_desync!(conn);
match try!(conn.read_message_with_notification()) {
NotificationResponse { pid, channel, payload } => {
Ok(Notification {
pid: pid,
@ -260,6 +263,42 @@ impl<'conn> Notifications<'conn> {
_ => unreachable!()
}
}
/// Returns the oldest pending notification
///
/// If no notifications are pending, blocks for up to `timeout` time, after
/// which an `IoError` with the `TimedOut` kind is returned.
pub fn next_block_for(&mut self, timeout: Duration) -> Result<Notification> {
if let Some(notification) = self.next() {
return Ok(notification);
}
let mut conn = self.conn.conn.borrow_mut();
check_desync!(conn);
let end = time::now().to_timespec() + timeout;
loop {
let now = time::now().to_timespec();
conn.stream.set_read_timeout(Some((end - now).num_milliseconds() as u64));
match conn.read_one_message() {
Ok(Some(NotificationResponse { pid, channel, payload })) => {
return Ok(Notification {
pid: pid,
channel: channel,
payload: payload
})
}
Ok(Some(_)) => unreachable!(),
Ok(None) => {}
Err(e @ IoError { kind: IoErrorKind::TimedOut, .. }) => {
conn.desynchronized = false;
return Err(Error::IoError(e));
}
Err(e) => return Err(Error::IoError(e)),
}
}
}
}
/// Contains information necessary to cancel queries for a session
@ -394,19 +433,27 @@ impl InnerConnection {
Ok(try_desync!(self, self.stream.flush()))
}
fn read_message_with_notification(&mut self) -> IoResult<BackendMessage> {
fn read_one_message(&mut self) -> IoResult<Option<BackendMessage>> {
debug_assert!(!self.desynchronized);
match try_desync!(self, self.stream.read_message()) {
NoticeResponse { fields } => {
if let Ok(err) = DbError::new_raw(fields) {
self.notice_handler.handle(err);
}
Ok(None)
}
ParameterStatus { parameter, value } => {
debug!("Parameter {} = {}", parameter, value);
Ok(None)
}
val => Ok(Some(val))
}
}
fn read_message_with_notification(&mut self) -> IoResult<BackendMessage> {
loop {
match try_desync!(self, self.stream.read_message()) {
NoticeResponse { fields } => {
if let Ok(err) = DbError::new_raw(fields) {
self.notice_handler.handle(err);
}
}
ParameterStatus { parameter, value } => {
debug!("Parameter {} = {}", parameter, value)
}
val => return Ok(val)
if let Some(msg) = try!(self.read_one_message()) {
return Ok(msg);
}
}
}

View File

@ -1,6 +1,7 @@
use std::io::{IoResult, IoError, OtherIoError, MemReader};
use std::mem;
use io::Timeout;
use types::Oid;
use self::BackendMessage::*;
@ -272,9 +273,17 @@ pub trait ReadMessage {
fn read_message(&mut self) -> IoResult<BackendMessage>;
}
impl<R: Reader> ReadMessage for R {
impl<R: Reader+Timeout> ReadMessage for R {
fn read_message(&mut self) -> IoResult<BackendMessage> {
let ident = try!(self.read_u8());
// The first byte read is a bit complex to make
// Notifications#next_block_for work.
let ident = self.read_u8();
// At this point we've got to turn off any read timeout to prevent
// stream desynchronization. We're assuming that if we've got the first
// byte, there's more stuff to follow.
self.set_read_timeout(None);
let ident = try!(ident);
// subtract size of length value
let len = try!(self.read_be_u32()) as uint - mem::size_of::<i32>();
let mut buf = MemReader::new(try!(self.read_exact(len)));

View File

@ -307,7 +307,6 @@ pub mod array;
pub mod range;
#[cfg(feature = "uuid")]
mod uuid;
#[cfg(feature = "time")]
mod time;
/// A Postgres OID

View File

@ -1,6 +1,4 @@
extern crate time;
use self::time::Timespec;
use time::Timespec;
use Result;
use types::{RawFromSql, Type, RawToSql};
use types::range::{Range, RangeBound, BoundSided, Normalizable};

View File

@ -8,6 +8,7 @@ extern crate openssl;
use openssl::ssl::SslContext;
use openssl::ssl::SslMethod::Sslv3;
use std::io::{IoError, IoErrorKind};
use std::io::timer;
use std::time::Duration;
@ -624,6 +625,46 @@ fn test_notifications_next_block() {
}, or_panic!(notifications.next_block()));
}
#[test]
fn test_notifications_next_block_for() {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
or_panic!(conn.execute("LISTEN test_notifications_next_block_for", &[]));
spawn(proc() {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
timer::sleep(Duration::milliseconds(500));
or_panic!(conn.execute("NOTIFY test_notifications_next_block_for, 'foo'", &[]));
});
let mut notifications = conn.notifications();
check_notification(Notification {
pid: 0,
channel: "test_notifications_next_block_for".to_string(),
payload: "foo".to_string()
}, or_panic!(notifications.next_block_for(Duration::seconds(2))));
}
#[test]
fn test_notifications_next_block_for_timeout() {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
or_panic!(conn.execute("LISTEN test_notifications_next_block_for_timeout", &[]));
spawn(proc() {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
timer::sleep(Duration::seconds(2));
or_panic!(conn.execute("NOTIFY test_notifications_next_block_for_timeout, 'foo'", &[]));
});
let mut notifications = conn.notifications();
match notifications.next_block_for(Duration::milliseconds(500)) {
Err(Error::IoError(IoError { kind: IoErrorKind::TimedOut, .. })) => {},
Err(e) => panic!("Unexpected error {}", e),
Ok(_) => panic!("expected error"),
}
or_panic!(conn.execute("SELECT 1", &[]));
}
#[test]
// This test is pretty sad, but I don't think there's a better way :(
fn test_cancel_query() {