Make the default notifications iterator read nonblocking

It is always super confusing as to when a notification that's been sent
to the client will actually show up in the old version of this iterator,
so it's best to have it see if there's anything waiting in the TCP
buffer.

Closes #149
This commit is contained in:
Steven Fackler 2015-12-27 10:13:42 -07:00
parent 278ee1cfd7
commit bb837bd872
5 changed files with 76 additions and 13 deletions

View File

@ -527,6 +527,24 @@ impl InnerConnection {
}
}
fn read_message_with_notification_nonblocking(&mut self)
-> std::io::Result<Option<BackendMessage>> {
debug_assert!(!self.desynchronized);
loop {
match try_desync!(self, self.stream.read_message_nonblocking()) {
Some(NoticeResponse { fields }) => {
if let Ok(err) = DbError::new_raw(fields) {
self.notice_handler.handle_notice(err);
}
}
Some(ParameterStatus { parameter, value }) => {
self.parameters.insert(parameter, value);
}
val => return Ok(val),
}
}
}
fn read_message(&mut self) -> std_io::Result<BackendMessage> {
loop {
match try!(self.read_message_with_notification()) {

View File

@ -6,7 +6,7 @@ use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use types::Oid;
use util;
use priv_io::ReadTimeout;
use priv_io::StreamOptions;
use self::BackendMessage::*;
use self::FrontendMessage::*;
@ -287,10 +287,12 @@ pub trait ReadMessage {
fn read_message_timeout(&mut self, timeout: Duration) -> io::Result<Option<BackendMessage>>;
fn read_message_nonblocking(&mut self) -> io::Result<Option<BackendMessage>>;
fn finish_read_message(&mut self, ident: u8) -> io::Result<BackendMessage>;
}
impl<R: BufRead + ReadTimeout> ReadMessage for R {
impl<R: BufRead + StreamOptions> ReadMessage for R {
fn read_message(&mut self) -> io::Result<BackendMessage> {
let ident = try!(self.read_u8());
self.finish_read_message(ident)
@ -314,6 +316,24 @@ impl<R: BufRead + ReadTimeout> ReadMessage for R {
}
}
fn read_message_nonblocking(&mut self) -> io::Result<Option<BackendMessage>> {
try!(self.set_nonblocking(true));
let ident = self.read_u8();
try!(self.set_nonblocking(false));
match ident {
Ok(ident) => self.finish_read_message(ident).map(Some),
Err(e) => {
let e: io::Error = e.into();
if e.kind() == io::ErrorKind::WouldBlock {
Ok(None)
} else {
Err(e)
}
}
}
}
fn finish_read_message(&mut self, ident: u8) -> io::Result<BackendMessage> {
// subtract size of length value
let len = try!(self.read_u32::<BigEndian>()) - mem::size_of::<u32>() as u32;

View File

@ -42,8 +42,7 @@ impl<'conn> Notifications<'conn> {
/// # Note
///
/// This iterator may start returning `Some` after previously returning
/// `None` if more notifications are received. However, those notifications
/// will not be registered until the connection is used in some way.
/// `None` if more notifications are received.
pub fn iter<'a>(&'a self) -> Iter<'a> {
Iter { conn: self.conn }
}
@ -72,7 +71,7 @@ impl<'conn> Notifications<'conn> {
}
impl<'a, 'conn> IntoIterator for &'a Notifications<'conn> {
type Item = Notification;
type Item = Result<Notification>;
type IntoIter = Iter<'a>;
fn into_iter(self) -> Iter<'a> {
@ -92,10 +91,27 @@ pub struct Iter<'a> {
}
impl<'a> Iterator for Iter<'a> {
type Item = Notification;
type Item = Result<Notification>;
fn next(&mut self) -> Option<Notification> {
self.conn.conn.borrow_mut().notifications.pop_front()
fn next(&mut self) -> Option<Result<Notification>> {
let mut conn = self.conn.conn.borrow_mut();
if let Some(notification) = conn.notifications.pop_front() {
return Some(Ok(notification));
}
match conn.read_message_with_notification_nonblocking() {
Ok(Some(NotificationResponse { pid, channel, payload })) => {
Some(Ok(Notification {
pid: pid,
channel: channel,
payload: payload,
}))
}
Ok(None) => None,
Err(err) => Some(Err(Error::Io(err))),
_ => unreachable!(),
}
}
}

View File

@ -22,11 +22,12 @@ use message::FrontendMessage::SslRequest;
const DEFAULT_PORT: u16 = 5432;
#[doc(hidden)]
pub trait ReadTimeout {
pub trait StreamOptions {
fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>;
fn set_nonblocking(&self, nonblock: bool) -> io::Result<()>;
}
impl ReadTimeout for BufStream<Box<StreamWrapper>> {
impl StreamOptions for BufStream<Box<StreamWrapper>> {
fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
match self.get_ref().get_ref().0 {
InternalStream::Tcp(ref s) => {
@ -36,6 +37,14 @@ impl ReadTimeout for BufStream<Box<StreamWrapper>> {
InternalStream::Unix(ref s) => s.set_read_timeout(timeout),
}
}
fn set_nonblocking(&self, nonblock: bool) -> io::Result<()> {
match self.get_ref().get_ref().0 {
InternalStream::Tcp(ref s) => s.set_nonblocking(nonblock),
#[cfg(feature = "unix_socket")]
InternalStream::Unix(ref s) => s.set_nonblocking(nonblock),
}
}
}
/// A connection to the Postgres server.

View File

@ -575,12 +575,12 @@ fn test_notification_iterator_some() {
pid: 0,
channel: "test_notification_iterator_one_channel".to_string(),
payload: "hello".to_string()
}, it.next().unwrap());
}, it.next().unwrap().unwrap());
check_notification(Notification {
pid: 0,
channel: "test_notification_iterator_one_channel2".to_string(),
payload: "world".to_string()
}, it.next().unwrap());
}, it.next().unwrap().unwrap());
assert!(it.next().is_none());
or_panic!(conn.execute("NOTIFY test_notification_iterator_one_channel, '!'", &[]));
@ -588,7 +588,7 @@ fn test_notification_iterator_some() {
pid: 0,
channel: "test_notification_iterator_one_channel".to_string(),
payload: "!".to_string()
}, it.next().unwrap());
}, it.next().unwrap().unwrap());
assert!(it.next().is_none());
}