From 1fdfefbeda2658c3506aeb0a8172aac94bed6725 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sat, 22 Dec 2018 17:02:48 -0800 Subject: [PATCH] Add Client::poll_idle Closes #403 --- tokio-postgres/src/lib.rs | 4 +++ tokio-postgres/src/proto/client.rs | 44 ++++++++++++++++------- tokio-postgres/src/proto/connection.rs | 24 +++++++++---- tokio-postgres/src/proto/idle.rs | 47 ++++++++++++++++++++++++ tokio-postgres/src/proto/mod.rs | 1 + tokio-postgres/tests/test/main.rs | 50 ++++++++++++++++++++++++++ 6 files changed, 150 insertions(+), 20 deletions(-) create mode 100644 tokio-postgres/src/proto/idle.rs diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index fbcb81b7..7bddecfe 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -102,6 +102,10 @@ impl Client { pub fn is_closed(&self) -> bool { self.0.is_closed() } + + pub fn poll_idle(&mut self) -> Poll<(), Error> { + self.0.poll_idle() + } } #[must_use = "futures do nothing unless polled"] diff --git a/tokio-postgres/src/proto/client.rs b/tokio-postgres/src/proto/client.rs index 20bc841a..e0c6341c 100644 --- a/tokio-postgres/src/proto/client.rs +++ b/tokio-postgres/src/proto/client.rs @@ -1,7 +1,7 @@ use antidote::Mutex; use bytes::IntoBuf; use futures::sync::mpsc; -use futures::{AsyncSink, Sink, Stream}; +use futures::{AsyncSink, Poll, Sink, Stream}; use postgres_protocol; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; @@ -14,6 +14,7 @@ use crate::proto::connection::{Request, RequestMessages}; use crate::proto::copy_in::{CopyInFuture, CopyInReceiver, CopyMessage}; use crate::proto::copy_out::CopyOutStream; use crate::proto::execute::ExecuteFuture; +use crate::proto::idle::{IdleGuard, IdleState}; use crate::proto::portal::Portal; use crate::proto::prepare::PrepareFuture; use crate::proto::query::QueryStream; @@ -22,7 +23,7 @@ use crate::proto::statement::Statement; use crate::types::{IsNull, Oid, ToSql, Type}; use crate::Error; -pub struct PendingRequest(Result); +pub struct PendingRequest(Result<(RequestMessages, IdleGuard), Error>); pub struct WeakClient(Weak); @@ -41,6 +42,7 @@ struct State { struct Inner { state: Mutex, + idle: IdleState, sender: mpsc::UnboundedSender, } @@ -56,6 +58,7 @@ impl Client { typeinfo_enum_query: None, typeinfo_composite_query: None, }), + idle: IdleState::new(), sender, })) } @@ -64,6 +67,10 @@ impl Client { self.0.sender.is_closed() } + pub fn poll_idle(&self) -> Poll<(), Error> { + self.0.idle.poll_idle() + } + pub fn downgrade(&self) -> WeakClient { WeakClient(Arc::downgrade(&self.0)) } @@ -101,11 +108,15 @@ impl Client { } pub fn send(&self, request: PendingRequest) -> Result, Error> { - let messages = request.0?; + let (messages, idle) = request.0?; let (sender, receiver) = mpsc::channel(0); self.0 .sender - .unbounded_send(Request { messages, sender }) + .unbounded_send(Request { + messages, + sender, + idle: Some(idle), + }) .map(|_| receiver) .map_err(|_| Error::closed()) } @@ -134,7 +145,7 @@ impl Client { pub fn execute(&self, statement: &Statement, params: &[&dyn ToSql]) -> ExecuteFuture { let pending = PendingRequest( self.excecute_message(statement, params) - .map(RequestMessages::Single), + .map(|m| (RequestMessages::Single(m), self.0.idle.guard())), ); ExecuteFuture::new(self.clone(), pending, statement.clone()) } @@ -142,7 +153,7 @@ impl Client { pub fn query(&self, statement: &Statement, params: &[&dyn ToSql]) -> QueryStream { let pending = PendingRequest( self.excecute_message(statement, params) - .map(RequestMessages::Single), + .map(|m| (RequestMessages::Single(m), self.0.idle.guard())), ); QueryStream::new(self.clone(), pending, statement.clone()) } @@ -152,7 +163,8 @@ impl Client { if let Ok(ref mut buf) = buf { frontend::sync(buf); } - let pending = PendingRequest(buf.map(RequestMessages::Single)); + let pending = + PendingRequest(buf.map(|m| (RequestMessages::Single(m), self.0.idle.guard()))); BindFuture::new(self.clone(), pending, name, statement.clone()) } @@ -183,10 +195,13 @@ impl Client { Ok(AsyncSink::Ready) => {} _ => unreachable!("channel should have capacity"), } - RequestMessages::CopyIn { - receiver: CopyInReceiver::new(receiver), - pending_message: None, - } + ( + RequestMessages::CopyIn { + receiver: CopyInReceiver::new(receiver), + pending_message: None, + }, + self.0.idle.guard(), + ) })); CopyInFuture::new(self.clone(), pending, statement.clone(), stream, sender) } @@ -194,7 +209,7 @@ impl Client { pub fn copy_out(&self, statement: &Statement, params: &[&dyn ToSql]) -> CopyOutStream { let pending = PendingRequest( self.excecute_message(statement, params) - .map(RequestMessages::Single), + .map(|m| (RequestMessages::Single(m), self.0.idle.guard())), ); CopyOutStream::new(self.clone(), pending, statement.clone()) } @@ -215,6 +230,7 @@ impl Client { let _ = self.0.sender.unbounded_send(Request { messages: RequestMessages::Single(buf), sender, + idle: None, }); } @@ -261,6 +277,8 @@ impl Client { F: FnOnce(&mut Vec) -> Result<(), Error>, { let mut buf = vec![]; - PendingRequest(messages(&mut buf).map(|()| RequestMessages::Single(buf))) + PendingRequest( + messages(&mut buf).map(|()| (RequestMessages::Single(buf), self.0.idle.guard())), + ) } } diff --git a/tokio-postgres/src/proto/connection.rs b/tokio-postgres/src/proto/connection.rs index 4684b4f3..e4c80fa1 100644 --- a/tokio-postgres/src/proto/connection.rs +++ b/tokio-postgres/src/proto/connection.rs @@ -10,6 +10,7 @@ use tokio_io::{AsyncRead, AsyncWrite}; use crate::proto::codec::PostgresCodec; use crate::proto::copy_in::CopyInReceiver; +use crate::proto::idle::IdleGuard; use crate::{AsyncMessage, CancelData, Notification}; use crate::{DbError, Error}; @@ -24,6 +25,12 @@ pub enum RequestMessages { pub struct Request { pub messages: RequestMessages, pub sender: mpsc::Sender, + pub idle: Option, +} + +struct Response { + sender: mpsc::Sender, + _idle: Option, } #[derive(PartialEq, Debug)] @@ -40,7 +47,7 @@ pub struct Connection { receiver: mpsc::UnboundedReceiver, pending_request: Option, pending_response: Option, - responses: VecDeque>, + responses: VecDeque, state: State, } @@ -124,8 +131,8 @@ where m => m, }; - let mut sender = match self.responses.pop_front() { - Some(sender) => sender, + let mut response = match self.responses.pop_front() { + Some(response) => response, None => match message { Message::ErrorResponse(error) => return Err(Error::db(error)), _ => return Err(Error::unexpected_message()), @@ -137,16 +144,16 @@ where _ => false, }; - match sender.start_send(message) { + match response.sender.start_send(message) { // if the receiver's hung up we still need to page through the rest of the messages // designated to it Ok(AsyncSink::Ready) | Err(_) => { if !request_complete { - self.responses.push_front(sender); + self.responses.push_front(response); } } Ok(AsyncSink::NotReady(message)) => { - self.responses.push_front(sender); + self.responses.push_front(response); self.pending_response = Some(message); trace!("poll_read: waiting on sender"); return Ok(None); @@ -164,7 +171,10 @@ where match try_ready_receive!(self.receiver.poll()) { Some(request) => { trace!("polled new request"); - self.responses.push_back(request.sender); + self.responses.push_back(Response { + sender: request.sender, + _idle: request.idle, + }); Ok(Async::Ready(Some(request.messages))) } None => Ok(Async::Ready(None)), diff --git a/tokio-postgres/src/proto/idle.rs b/tokio-postgres/src/proto/idle.rs new file mode 100644 index 00000000..d4cbe8f0 --- /dev/null +++ b/tokio-postgres/src/proto/idle.rs @@ -0,0 +1,47 @@ +use futures::task::AtomicTask; +use futures::{Async, Poll}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +use crate::Error; + +struct Inner { + active: AtomicUsize, + task: AtomicTask, +} + +pub struct IdleState(Arc); + +impl IdleState { + pub fn new() -> IdleState { + IdleState(Arc::new(Inner { + active: AtomicUsize::new(0), + task: AtomicTask::new(), + })) + } + + pub fn guard(&self) -> IdleGuard { + self.0.active.fetch_add(1, Ordering::SeqCst); + IdleGuard(self.0.clone()) + } + + pub fn poll_idle(&self) -> Poll<(), Error> { + self.0.task.register(); + + if self.0.active.load(Ordering::SeqCst) == 0 { + Ok(Async::Ready(())) + } else { + Ok(Async::NotReady) + } + } +} + +pub struct IdleGuard(Arc); + +impl Drop for IdleGuard { + fn drop(&mut self) { + if self.0.active.fetch_sub(1, Ordering::SeqCst) == 1 { + self.0.task.notify(); + } + } +} diff --git a/tokio-postgres/src/proto/mod.rs b/tokio-postgres/src/proto/mod.rs index de81620e..079deeee 100644 --- a/tokio-postgres/src/proto/mod.rs +++ b/tokio-postgres/src/proto/mod.rs @@ -31,6 +31,7 @@ mod copy_in; mod copy_out; mod execute; mod handshake; +mod idle; mod portal; mod prepare; mod query; diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 41918e65..b6be8662 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -4,6 +4,7 @@ use futures::sync::mpsc; use futures::{future, stream, try_ready}; use log::debug; use std::error::Error; +use std::sync::atomic::{AtomicBool, Ordering}; use std::time::{Duration, Instant}; use tokio::net::TcpStream; use tokio::prelude::*; @@ -683,3 +684,52 @@ fn transaction_builder_around_moved_client() { drop(client); runtime.run().unwrap(); } + +#[test] +fn poll_idle() { + struct IdleFuture { + client: tokio_postgres::Client, + query: Option, + } + + impl Future for IdleFuture { + type Item = (); + type Error = tokio_postgres::Error; + + fn poll(&mut self) -> Poll<(), tokio_postgres::Error> { + if let Some(_) = self.query.take() { + assert!(!self.client.poll_idle().unwrap().is_ready()); + return Ok(Async::NotReady); + } + + try_ready!(self.client.poll_idle()); + assert!(QUERY_DONE.load(Ordering::SeqCst)); + + Ok(Async::Ready(())) + } + } + + static QUERY_DONE: AtomicBool = AtomicBool::new(false); + + let _ = env_logger::try_init(); + let mut runtime = Runtime::new().unwrap(); + + let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap(); + let connection = connection.map_err(|e| panic!("{}", e)); + runtime.handle().spawn(connection).unwrap(); + + let stmt = runtime.block_on(client.prepare("SELECT 1")).unwrap(); + + let query = client + .query(&stmt, &[]) + .collect() + .map(|_| QUERY_DONE.store(true, Ordering::SeqCst)) + .map_err(|e| panic!("{}", e)); + runtime.spawn(query); + + let future = IdleFuture { + query: Some(client.prepare("")), + client, + }; + runtime.block_on(future).unwrap(); +}