Add Client::poll_idle

Closes #403
This commit is contained in:
Steven Fackler 2018-12-22 17:02:48 -08:00
parent 0d3e18b251
commit 1fdfefbeda
6 changed files with 150 additions and 20 deletions

View File

@ -102,6 +102,10 @@ impl Client {
pub fn is_closed(&self) -> bool {
self.0.is_closed()
}
pub fn poll_idle(&mut self) -> Poll<(), Error> {
self.0.poll_idle()
}
}
#[must_use = "futures do nothing unless polled"]

View File

@ -1,7 +1,7 @@
use antidote::Mutex;
use bytes::IntoBuf;
use futures::sync::mpsc;
use futures::{AsyncSink, Sink, Stream};
use futures::{AsyncSink, Poll, Sink, Stream};
use postgres_protocol;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
@ -14,6 +14,7 @@ use crate::proto::connection::{Request, RequestMessages};
use crate::proto::copy_in::{CopyInFuture, CopyInReceiver, CopyMessage};
use crate::proto::copy_out::CopyOutStream;
use crate::proto::execute::ExecuteFuture;
use crate::proto::idle::{IdleGuard, IdleState};
use crate::proto::portal::Portal;
use crate::proto::prepare::PrepareFuture;
use crate::proto::query::QueryStream;
@ -22,7 +23,7 @@ use crate::proto::statement::Statement;
use crate::types::{IsNull, Oid, ToSql, Type};
use crate::Error;
pub struct PendingRequest(Result<RequestMessages, Error>);
pub struct PendingRequest(Result<(RequestMessages, IdleGuard), Error>);
pub struct WeakClient(Weak<Inner>);
@ -41,6 +42,7 @@ struct State {
struct Inner {
state: Mutex<State>,
idle: IdleState,
sender: mpsc::UnboundedSender<Request>,
}
@ -56,6 +58,7 @@ impl Client {
typeinfo_enum_query: None,
typeinfo_composite_query: None,
}),
idle: IdleState::new(),
sender,
}))
}
@ -64,6 +67,10 @@ impl Client {
self.0.sender.is_closed()
}
pub fn poll_idle(&self) -> Poll<(), Error> {
self.0.idle.poll_idle()
}
pub fn downgrade(&self) -> WeakClient {
WeakClient(Arc::downgrade(&self.0))
}
@ -101,11 +108,15 @@ impl Client {
}
pub fn send(&self, request: PendingRequest) -> Result<mpsc::Receiver<Message>, Error> {
let messages = request.0?;
let (messages, idle) = request.0?;
let (sender, receiver) = mpsc::channel(0);
self.0
.sender
.unbounded_send(Request { messages, sender })
.unbounded_send(Request {
messages,
sender,
idle: Some(idle),
})
.map(|_| receiver)
.map_err(|_| Error::closed())
}
@ -134,7 +145,7 @@ impl Client {
pub fn execute(&self, statement: &Statement, params: &[&dyn ToSql]) -> ExecuteFuture {
let pending = PendingRequest(
self.excecute_message(statement, params)
.map(RequestMessages::Single),
.map(|m| (RequestMessages::Single(m), self.0.idle.guard())),
);
ExecuteFuture::new(self.clone(), pending, statement.clone())
}
@ -142,7 +153,7 @@ impl Client {
pub fn query(&self, statement: &Statement, params: &[&dyn ToSql]) -> QueryStream<Statement> {
let pending = PendingRequest(
self.excecute_message(statement, params)
.map(RequestMessages::Single),
.map(|m| (RequestMessages::Single(m), self.0.idle.guard())),
);
QueryStream::new(self.clone(), pending, statement.clone())
}
@ -152,7 +163,8 @@ impl Client {
if let Ok(ref mut buf) = buf {
frontend::sync(buf);
}
let pending = PendingRequest(buf.map(RequestMessages::Single));
let pending =
PendingRequest(buf.map(|m| (RequestMessages::Single(m), self.0.idle.guard())));
BindFuture::new(self.clone(), pending, name, statement.clone())
}
@ -183,10 +195,13 @@ impl Client {
Ok(AsyncSink::Ready) => {}
_ => unreachable!("channel should have capacity"),
}
RequestMessages::CopyIn {
receiver: CopyInReceiver::new(receiver),
pending_message: None,
}
(
RequestMessages::CopyIn {
receiver: CopyInReceiver::new(receiver),
pending_message: None,
},
self.0.idle.guard(),
)
}));
CopyInFuture::new(self.clone(), pending, statement.clone(), stream, sender)
}
@ -194,7 +209,7 @@ impl Client {
pub fn copy_out(&self, statement: &Statement, params: &[&dyn ToSql]) -> CopyOutStream {
let pending = PendingRequest(
self.excecute_message(statement, params)
.map(RequestMessages::Single),
.map(|m| (RequestMessages::Single(m), self.0.idle.guard())),
);
CopyOutStream::new(self.clone(), pending, statement.clone())
}
@ -215,6 +230,7 @@ impl Client {
let _ = self.0.sender.unbounded_send(Request {
messages: RequestMessages::Single(buf),
sender,
idle: None,
});
}
@ -261,6 +277,8 @@ impl Client {
F: FnOnce(&mut Vec<u8>) -> Result<(), Error>,
{
let mut buf = vec![];
PendingRequest(messages(&mut buf).map(|()| RequestMessages::Single(buf)))
PendingRequest(
messages(&mut buf).map(|()| (RequestMessages::Single(buf), self.0.idle.guard())),
)
}
}

View File

@ -10,6 +10,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto::codec::PostgresCodec;
use crate::proto::copy_in::CopyInReceiver;
use crate::proto::idle::IdleGuard;
use crate::{AsyncMessage, CancelData, Notification};
use crate::{DbError, Error};
@ -24,6 +25,12 @@ pub enum RequestMessages {
pub struct Request {
pub messages: RequestMessages,
pub sender: mpsc::Sender<Message>,
pub idle: Option<IdleGuard>,
}
struct Response {
sender: mpsc::Sender<Message>,
_idle: Option<IdleGuard>,
}
#[derive(PartialEq, Debug)]
@ -40,7 +47,7 @@ pub struct Connection<S> {
receiver: mpsc::UnboundedReceiver<Request>,
pending_request: Option<RequestMessages>,
pending_response: Option<Message>,
responses: VecDeque<mpsc::Sender<Message>>,
responses: VecDeque<Response>,
state: State,
}
@ -124,8 +131,8 @@ where
m => m,
};
let mut sender = match self.responses.pop_front() {
Some(sender) => sender,
let mut response = match self.responses.pop_front() {
Some(response) => response,
None => match message {
Message::ErrorResponse(error) => return Err(Error::db(error)),
_ => return Err(Error::unexpected_message()),
@ -137,16 +144,16 @@ where
_ => false,
};
match sender.start_send(message) {
match response.sender.start_send(message) {
// if the receiver's hung up we still need to page through the rest of the messages
// designated to it
Ok(AsyncSink::Ready) | Err(_) => {
if !request_complete {
self.responses.push_front(sender);
self.responses.push_front(response);
}
}
Ok(AsyncSink::NotReady(message)) => {
self.responses.push_front(sender);
self.responses.push_front(response);
self.pending_response = Some(message);
trace!("poll_read: waiting on sender");
return Ok(None);
@ -164,7 +171,10 @@ where
match try_ready_receive!(self.receiver.poll()) {
Some(request) => {
trace!("polled new request");
self.responses.push_back(request.sender);
self.responses.push_back(Response {
sender: request.sender,
_idle: request.idle,
});
Ok(Async::Ready(Some(request.messages)))
}
None => Ok(Async::Ready(None)),

View File

@ -0,0 +1,47 @@
use futures::task::AtomicTask;
use futures::{Async, Poll};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use crate::Error;
struct Inner {
active: AtomicUsize,
task: AtomicTask,
}
pub struct IdleState(Arc<Inner>);
impl IdleState {
pub fn new() -> IdleState {
IdleState(Arc::new(Inner {
active: AtomicUsize::new(0),
task: AtomicTask::new(),
}))
}
pub fn guard(&self) -> IdleGuard {
self.0.active.fetch_add(1, Ordering::SeqCst);
IdleGuard(self.0.clone())
}
pub fn poll_idle(&self) -> Poll<(), Error> {
self.0.task.register();
if self.0.active.load(Ordering::SeqCst) == 0 {
Ok(Async::Ready(()))
} else {
Ok(Async::NotReady)
}
}
}
pub struct IdleGuard(Arc<Inner>);
impl Drop for IdleGuard {
fn drop(&mut self) {
if self.0.active.fetch_sub(1, Ordering::SeqCst) == 1 {
self.0.task.notify();
}
}
}

View File

@ -31,6 +31,7 @@ mod copy_in;
mod copy_out;
mod execute;
mod handshake;
mod idle;
mod portal;
mod prepare;
mod query;

View File

@ -4,6 +4,7 @@ use futures::sync::mpsc;
use futures::{future, stream, try_ready};
use log::debug;
use std::error::Error;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
use tokio::prelude::*;
@ -683,3 +684,52 @@ fn transaction_builder_around_moved_client() {
drop(client);
runtime.run().unwrap();
}
#[test]
fn poll_idle() {
struct IdleFuture {
client: tokio_postgres::Client,
query: Option<tokio_postgres::Prepare>,
}
impl Future for IdleFuture {
type Item = ();
type Error = tokio_postgres::Error;
fn poll(&mut self) -> Poll<(), tokio_postgres::Error> {
if let Some(_) = self.query.take() {
assert!(!self.client.poll_idle().unwrap().is_ready());
return Ok(Async::NotReady);
}
try_ready!(self.client.poll_idle());
assert!(QUERY_DONE.load(Ordering::SeqCst));
Ok(Async::Ready(()))
}
}
static QUERY_DONE: AtomicBool = AtomicBool::new(false);
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
let stmt = runtime.block_on(client.prepare("SELECT 1")).unwrap();
let query = client
.query(&stmt, &[])
.collect()
.map(|_| QUERY_DONE.store(true, Ordering::SeqCst))
.map_err(|e| panic!("{}", e));
runtime.spawn(query);
let future = IdleFuture {
query: Some(client.prepare("")),
client,
};
runtime.block_on(future).unwrap();
}