Add Notifications::next_block_for method

The setup is a little hairy, but seems correct.

cc 
This commit is contained in:
Steven Fackler 2014-12-02 20:34:46 -08:00
parent 75641e121f
commit d5998d8f2a
8 changed files with 141 additions and 46 deletions

View File

@ -20,21 +20,18 @@ name = "test"
path = "tests/test.rs" path = "tests/test.rs"
[features] [features]
default = ["uuid", "time"] default = ["uuid"]
[dependencies] [dependencies]
phf = "0.1" phf = "0.1"
phf_mac = "0.1" phf_mac = "0.1"
openssl = "0.2.1" openssl = "0.2.1"
time = "0.1"
[dependencies.uuid] [dependencies.uuid]
optional = true optional = true
version = "0.1" version = "0.1"
[dependencies.time]
optional = true
version = "0.1"
[dev-dependencies] [dev-dependencies]
url = "0.2" url = "0.2"

View File

@ -204,10 +204,7 @@ types. The driver currently supports the following conversions:
<td>JSON</td> <td>JSON</td>
</tr> </tr>
<tr> <tr>
<td> <td>time::Timespec</td>
<a href="https://github.com/rust-lang/time">time::Timespec</a>
(<a href="#optional-features">optional</a>)
</td>
<td>TIMESTAMP, TIMESTAMP WITH TIME ZONE</td> <td>TIMESTAMP, TIMESTAMP WITH TIME ZONE</td>
</tr> </tr>
<tr> <tr>
@ -226,10 +223,7 @@ types. The driver currently supports the following conversions:
<td>INT8RANGE</td> <td>INT8RANGE</td>
</tr> </tr>
<tr> <tr>
<td> <td>types::range::Range&lt;Timespec&gt;</td>
<a href="https://github.com/rust-lang/time">types::range::Range&lt;Timespec&gt;</a>
(<a href="#optional-features">optional</a>)
</td>
<td>TSRANGE, TSTZRANGE</td> <td>TSRANGE, TSTZRANGE</td>
</tr> </tr>
<tr> <tr>
@ -265,10 +259,7 @@ types. The driver currently supports the following conversions:
<td>INT8[], INT8[][], ...</td> <td>INT8[], INT8[][], ...</td>
</tr> </tr>
<tr> <tr>
<td> <td>types::array::ArrayBase&lt;Option&lt;Timespec&gt;&gt;</td>
<a href="https://github.com/rust-lang/time">types::array::ArrayBase&lt;Option&lt;Timespec&gt;&gt;</a>
(<a href="#optional-features">optional</a>)
</td>
<td>TIMESTAMP[], TIMESTAMPTZ[], TIMESTAMP[][], ...</td> <td>TIMESTAMP[], TIMESTAMPTZ[], TIMESTAMP[][], ...</td>
</tr> </tr>
<tr> <tr>
@ -308,10 +299,6 @@ traits.
[UUID](http://www.postgresql.org/docs/9.4/static/datatype-uuid.html) support is [UUID](http://www.postgresql.org/docs/9.4/static/datatype-uuid.html) support is
provided optionally by the `uuid` feature. It is enabled by default. provided optionally by the `uuid` feature. It is enabled by default.
### Time types
[Time](http://www.postgresql.org/docs/9.3/static/datatype-datetime.html)
support is provided optionally by the `time` feature. It is enabled by default.
To disable support for optional features, add `default-features = false` to To disable support for optional features, add `default-features = false` to
your Cargo manifest: your Cargo manifest:

View File

@ -1,8 +1,9 @@
use openssl::ssl::{SslStream, MaybeSslStream}; use openssl::ssl::{SslStream, MaybeSslStream};
use std::io::BufferedStream;
use std::io::net::ip::Port; use std::io::net::ip::Port;
use std::io::net::tcp::TcpStream; use std::io::net::tcp::TcpStream;
use std::io::net::pipe::UnixStream; use std::io::net::pipe::UnixStream;
use std::io::IoResult; use std::io::{IoResult, Stream};
use {ConnectParams, SslMode, ConnectTarget, ConnectError}; use {ConnectParams, SslMode, ConnectTarget, ConnectError};
use message; use message;
@ -11,6 +12,23 @@ use message::FrontendMessage::SslRequest;
const DEFAULT_PORT: Port = 5432; const DEFAULT_PORT: Port = 5432;
#[doc(hidden)]
pub trait Timeout {
fn set_read_timeout(&mut self, timeout_ms: Option<u64>);
}
impl<S: Stream+Timeout> Timeout for MaybeSslStream<S> {
fn set_read_timeout(&mut self, timeout_ms: Option<u64>) {
self.get_mut().set_read_timeout(timeout_ms);
}
}
impl<S: Stream+Timeout> Timeout for BufferedStream<S> {
fn set_read_timeout(&mut self, timeout_ms: Option<u64>) {
self.get_mut().set_read_timeout(timeout_ms);
}
}
pub enum InternalStream { pub enum InternalStream {
Tcp(TcpStream), Tcp(TcpStream),
Unix(UnixStream), Unix(UnixStream),
@ -41,9 +59,8 @@ impl Writer for InternalStream {
} }
} }
impl InternalStream { impl Timeout for InternalStream {
#[allow(dead_code)] fn set_read_timeout(&mut self, timeout_ms: Option<u64>) {
pub fn set_read_timeout(&mut self, timeout_ms: Option<u64>) {
match *self { match *self {
InternalStream::Tcp(ref mut s) => s.set_read_timeout(timeout_ms), InternalStream::Tcp(ref mut s) => s.set_read_timeout(timeout_ms),
InternalStream::Unix(ref mut s) => s.set_read_timeout(timeout_ms), InternalStream::Unix(ref mut s) => s.set_read_timeout(timeout_ms),

View File

@ -65,6 +65,7 @@ extern crate phf;
extern crate phf_mac; extern crate phf_mac;
#[phase(plugin, link)] #[phase(plugin, link)]
extern crate log; extern crate log;
extern crate time;
use url::Url; use url::Url;
use openssl::crypto::hash::{HashType, Hasher}; use openssl::crypto::hash::{HashType, Hasher};
@ -72,14 +73,15 @@ use openssl::ssl::{SslContext, MaybeSslStream};
use serialize::hex::ToHex; use serialize::hex::ToHex;
use std::cell::{Cell, RefCell}; use std::cell::{Cell, RefCell};
use std::collections::{RingBuf, HashMap}; use std::collections::{RingBuf, HashMap};
use std::io::{BufferedStream, IoResult}; use std::io::{BufferedStream, IoResult, IoError, IoErrorKind};
use std::io::net::ip::Port; use std::io::net::ip::Port;
use std::iter::IteratorCloneExt; use std::iter::IteratorCloneExt;
use std::time::Duration;
use std::mem; use std::mem;
use std::fmt; use std::fmt;
use std::result; use std::result;
use io::InternalStream; use io::{InternalStream, Timeout};
use message::{FrontendMessage, BackendMessage, RowDescriptionEntry}; use message::{FrontendMessage, BackendMessage, RowDescriptionEntry};
use message::FrontendMessage::*; use message::FrontendMessage::*;
use message::BackendMessage::*; use message::BackendMessage::*;
@ -248,8 +250,9 @@ impl<'conn> Notifications<'conn> {
return Ok(notification); return Ok(notification);
} }
check_desync!(self.conn.conn.borrow()); let mut conn = self.conn.conn.borrow_mut();
match try!(self.conn.conn.borrow_mut().read_message_with_notification()) { check_desync!(conn);
match try!(conn.read_message_with_notification()) {
NotificationResponse { pid, channel, payload } => { NotificationResponse { pid, channel, payload } => {
Ok(Notification { Ok(Notification {
pid: pid, pid: pid,
@ -260,6 +263,42 @@ impl<'conn> Notifications<'conn> {
_ => unreachable!() _ => unreachable!()
} }
} }
/// Returns the oldest pending notification
///
/// If no notifications are pending, blocks for up to `timeout` time, after
/// which an `IoError` with the `TimedOut` kind is returned.
pub fn next_block_for(&mut self, timeout: Duration) -> Result<Notification> {
if let Some(notification) = self.next() {
return Ok(notification);
}
let mut conn = self.conn.conn.borrow_mut();
check_desync!(conn);
let end = time::now().to_timespec() + timeout;
loop {
let now = time::now().to_timespec();
conn.stream.set_read_timeout(Some((end - now).num_milliseconds() as u64));
match conn.read_one_message() {
Ok(Some(NotificationResponse { pid, channel, payload })) => {
return Ok(Notification {
pid: pid,
channel: channel,
payload: payload
})
}
Ok(Some(_)) => unreachable!(),
Ok(None) => {}
Err(e @ IoError { kind: IoErrorKind::TimedOut, .. }) => {
conn.desynchronized = false;
return Err(Error::IoError(e));
}
Err(e) => return Err(Error::IoError(e)),
}
}
}
} }
/// Contains information necessary to cancel queries for a session /// Contains information necessary to cancel queries for a session
@ -394,19 +433,27 @@ impl InnerConnection {
Ok(try_desync!(self, self.stream.flush())) Ok(try_desync!(self, self.stream.flush()))
} }
fn read_message_with_notification(&mut self) -> IoResult<BackendMessage> { fn read_one_message(&mut self) -> IoResult<Option<BackendMessage>> {
debug_assert!(!self.desynchronized); debug_assert!(!self.desynchronized);
loop {
match try_desync!(self, self.stream.read_message()) { match try_desync!(self, self.stream.read_message()) {
NoticeResponse { fields } => { NoticeResponse { fields } => {
if let Ok(err) = DbError::new_raw(fields) { if let Ok(err) = DbError::new_raw(fields) {
self.notice_handler.handle(err); self.notice_handler.handle(err);
} }
Ok(None)
} }
ParameterStatus { parameter, value } => { ParameterStatus { parameter, value } => {
debug!("Parameter {} = {}", parameter, value) debug!("Parameter {} = {}", parameter, value);
Ok(None)
} }
val => return Ok(val) val => Ok(Some(val))
}
}
fn read_message_with_notification(&mut self) -> IoResult<BackendMessage> {
loop {
if let Some(msg) = try!(self.read_one_message()) {
return Ok(msg);
} }
} }
} }

View File

@ -1,6 +1,7 @@
use std::io::{IoResult, IoError, OtherIoError, MemReader}; use std::io::{IoResult, IoError, OtherIoError, MemReader};
use std::mem; use std::mem;
use io::Timeout;
use types::Oid; use types::Oid;
use self::BackendMessage::*; use self::BackendMessage::*;
@ -272,9 +273,17 @@ pub trait ReadMessage {
fn read_message(&mut self) -> IoResult<BackendMessage>; fn read_message(&mut self) -> IoResult<BackendMessage>;
} }
impl<R: Reader> ReadMessage for R { impl<R: Reader+Timeout> ReadMessage for R {
fn read_message(&mut self) -> IoResult<BackendMessage> { fn read_message(&mut self) -> IoResult<BackendMessage> {
let ident = try!(self.read_u8()); // The first byte read is a bit complex to make
// Notifications#next_block_for work.
let ident = self.read_u8();
// At this point we've got to turn off any read timeout to prevent
// stream desynchronization. We're assuming that if we've got the first
// byte, there's more stuff to follow.
self.set_read_timeout(None);
let ident = try!(ident);
// subtract size of length value // subtract size of length value
let len = try!(self.read_be_u32()) as uint - mem::size_of::<i32>(); let len = try!(self.read_be_u32()) as uint - mem::size_of::<i32>();
let mut buf = MemReader::new(try!(self.read_exact(len))); let mut buf = MemReader::new(try!(self.read_exact(len)));

View File

@ -307,7 +307,6 @@ pub mod array;
pub mod range; pub mod range;
#[cfg(feature = "uuid")] #[cfg(feature = "uuid")]
mod uuid; mod uuid;
#[cfg(feature = "time")]
mod time; mod time;
/// A Postgres OID /// A Postgres OID

View File

@ -1,6 +1,4 @@
extern crate time; use time::Timespec;
use self::time::Timespec;
use Result; use Result;
use types::{RawFromSql, Type, RawToSql}; use types::{RawFromSql, Type, RawToSql};
use types::range::{Range, RangeBound, BoundSided, Normalizable}; use types::range::{Range, RangeBound, BoundSided, Normalizable};

View File

@ -8,6 +8,7 @@ extern crate openssl;
use openssl::ssl::SslContext; use openssl::ssl::SslContext;
use openssl::ssl::SslMethod::Sslv3; use openssl::ssl::SslMethod::Sslv3;
use std::io::{IoError, IoErrorKind};
use std::io::timer; use std::io::timer;
use std::time::Duration; use std::time::Duration;
@ -624,6 +625,46 @@ fn test_notifications_next_block() {
}, or_panic!(notifications.next_block())); }, or_panic!(notifications.next_block()));
} }
#[test]
fn test_notifications_next_block_for() {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
or_panic!(conn.execute("LISTEN test_notifications_next_block_for", &[]));
spawn(proc() {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
timer::sleep(Duration::milliseconds(500));
or_panic!(conn.execute("NOTIFY test_notifications_next_block_for, 'foo'", &[]));
});
let mut notifications = conn.notifications();
check_notification(Notification {
pid: 0,
channel: "test_notifications_next_block_for".to_string(),
payload: "foo".to_string()
}, or_panic!(notifications.next_block_for(Duration::seconds(2))));
}
#[test]
fn test_notifications_next_block_for_timeout() {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
or_panic!(conn.execute("LISTEN test_notifications_next_block_for_timeout", &[]));
spawn(proc() {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
timer::sleep(Duration::seconds(2));
or_panic!(conn.execute("NOTIFY test_notifications_next_block_for_timeout, 'foo'", &[]));
});
let mut notifications = conn.notifications();
match notifications.next_block_for(Duration::milliseconds(500)) {
Err(Error::IoError(IoError { kind: IoErrorKind::TimedOut, .. })) => {},
Err(e) => panic!("Unexpected error {}", e),
Ok(_) => panic!("expected error"),
}
or_panic!(conn.execute("SELECT 1", &[]));
}
#[test] #[test]
// This test is pretty sad, but I don't think there's a better way :( // This test is pretty sad, but I don't think there's a better way :(
fn test_cancel_query() { fn test_cancel_query() {