diff --git a/src/lib.rs b/src/lib.rs index e161c6ad..4896aec6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -68,6 +68,7 @@ use std::io::prelude::*; use std::marker::Sync as StdSync; use std::mem; use std::result; +use std::time::Duration; #[cfg(feature = "unix_socket")] use std::path::PathBuf; @@ -509,6 +510,24 @@ impl InnerConnection { } } + fn read_message_with_notification_timeout(&mut self, timeout: Duration) + -> std::io::Result> { + debug_assert!(!self.desynchronized); + loop { + match try_desync!(self, self.stream.read_message_timeout(timeout)) { + Some(NoticeResponse { fields }) => { + if let Ok(err) = DbError::new_raw(fields) { + self.notice_handler.handle_notice(err); + } + } + Some(ParameterStatus { parameter, value }) => { + self.parameters.insert(parameter, value); + } + val => return Ok(val) + } + } + } + fn read_message(&mut self) -> std_io::Result { loop { match try!(self.read_message_with_notification()) { diff --git a/src/notification.rs b/src/notification.rs index cfb8b28f..e12d2b46 100644 --- a/src/notification.rs +++ b/src/notification.rs @@ -1,6 +1,7 @@ //! Asynchronous notifications. use std::fmt; +use std::time::Duration; use {desynchronized, Result, Connection, NotificationsNew}; use message::BackendMessage::NotificationResponse; @@ -48,8 +49,8 @@ impl<'conn> Notifications<'conn> { } } - /// Returns an iterator over notifications, blocking until one is received - /// if none are pending. + /// Returns an iterator over notifications that blocks until one is + /// received if none are pending. /// /// The iterator will never return `None`. pub fn blocking_iter<'a>(&'a self) -> BlockingIter<'a> { @@ -57,6 +58,20 @@ impl<'conn> Notifications<'conn> { conn: self.conn, } } + + /// Returns an iterator over notifications that blocks for a limited time + /// waiting to receive one if none are pending. + /// + /// # Note + /// + /// THis iterator may start returning `Some` after previously returning + /// `None` if more notifications are received. + pub fn timeout_iter<'a>(&'a self, timeout: Duration) -> TimeoutIter<'a> { + TimeoutIter { + conn: self.conn, + timeout: timeout, + } + } } impl<'a, 'conn> IntoIterator for &'a Notifications<'conn> { @@ -121,3 +136,39 @@ impl<'a> Iterator for BlockingIter<'a> { } } } + +/// An iterator over notifications which will block for a period of time if +/// none are pending. +pub struct TimeoutIter<'a> { + conn: &'a Connection, + timeout: Duration, +} + +impl<'a> Iterator for TimeoutIter<'a> { + type Item = Result; + + fn next(&mut self) -> Option> { + let mut conn = self.conn.conn.borrow_mut(); + + if let Some(notification) = conn.notifications.pop_front() { + return Some(Ok(notification)); + } + + if conn.is_desynchronized() { + return Some(Err(Error::IoError(desynchronized()))); + } + + match conn.read_message_with_notification_timeout(self.timeout) { + Ok(Some(NotificationResponse { pid, channel, payload })) => { + Some(Ok(Notification { + pid: pid, + channel: channel, + payload: payload + })) + } + Ok(None) => None, + Err(err) => Some(Err(Error::IoError(err))), + _ => unreachable!() + } + } +} diff --git a/tests/test.rs b/tests/test.rs index e90ea3c4..21818028 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -9,6 +9,7 @@ use openssl::ssl::{SslContext, SslMethod}; use std::thread; use std::io; use std::io::prelude::*; +use std::time::Duration; use postgres::{HandleNotice, Connection, @@ -609,6 +610,30 @@ fn test_notifications_next_block() { }, or_panic!(notifications.blocking_iter().next().unwrap())); } +#[test] +fn test_notification_next_timeout() { + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + or_panic!(conn.execute("LISTEN test_notifications_next_timeout", &[])); + + let _t = thread::spawn(|| { + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + thread::sleep_ms(500); + or_panic!(conn.execute("NOTIFY test_notifications_next_timeout, 'foo'", &[])); + thread::sleep_ms(1500); + or_panic!(conn.execute("NOTIFY test_notifications_next_timeout, 'foo'", &[])); + }); + + let notifications = conn.notifications(); + let mut it = notifications.timeout_iter(Duration::from_secs(1)); + check_notification(Notification { + pid: 0, + channel: "test_notifications_next_timeout".to_string(), + payload: "foo".to_string() + }, or_panic!(it.next().unwrap())); + + assert!(it.next().is_none()); +} + #[test] // This test is pretty sad, but I don't think there's a better way :( fn test_cancel_query() {