parent
0d3e18b251
commit
1fdfefbeda
@ -102,6 +102,10 @@ impl Client {
|
||||
pub fn is_closed(&self) -> bool {
|
||||
self.0.is_closed()
|
||||
}
|
||||
|
||||
pub fn poll_idle(&mut self) -> Poll<(), Error> {
|
||||
self.0.poll_idle()
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
|
@ -1,7 +1,7 @@
|
||||
use antidote::Mutex;
|
||||
use bytes::IntoBuf;
|
||||
use futures::sync::mpsc;
|
||||
use futures::{AsyncSink, Sink, Stream};
|
||||
use futures::{AsyncSink, Poll, Sink, Stream};
|
||||
use postgres_protocol;
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
@ -14,6 +14,7 @@ use crate::proto::connection::{Request, RequestMessages};
|
||||
use crate::proto::copy_in::{CopyInFuture, CopyInReceiver, CopyMessage};
|
||||
use crate::proto::copy_out::CopyOutStream;
|
||||
use crate::proto::execute::ExecuteFuture;
|
||||
use crate::proto::idle::{IdleGuard, IdleState};
|
||||
use crate::proto::portal::Portal;
|
||||
use crate::proto::prepare::PrepareFuture;
|
||||
use crate::proto::query::QueryStream;
|
||||
@ -22,7 +23,7 @@ use crate::proto::statement::Statement;
|
||||
use crate::types::{IsNull, Oid, ToSql, Type};
|
||||
use crate::Error;
|
||||
|
||||
pub struct PendingRequest(Result<RequestMessages, Error>);
|
||||
pub struct PendingRequest(Result<(RequestMessages, IdleGuard), Error>);
|
||||
|
||||
pub struct WeakClient(Weak<Inner>);
|
||||
|
||||
@ -41,6 +42,7 @@ struct State {
|
||||
|
||||
struct Inner {
|
||||
state: Mutex<State>,
|
||||
idle: IdleState,
|
||||
sender: mpsc::UnboundedSender<Request>,
|
||||
}
|
||||
|
||||
@ -56,6 +58,7 @@ impl Client {
|
||||
typeinfo_enum_query: None,
|
||||
typeinfo_composite_query: None,
|
||||
}),
|
||||
idle: IdleState::new(),
|
||||
sender,
|
||||
}))
|
||||
}
|
||||
@ -64,6 +67,10 @@ impl Client {
|
||||
self.0.sender.is_closed()
|
||||
}
|
||||
|
||||
pub fn poll_idle(&self) -> Poll<(), Error> {
|
||||
self.0.idle.poll_idle()
|
||||
}
|
||||
|
||||
pub fn downgrade(&self) -> WeakClient {
|
||||
WeakClient(Arc::downgrade(&self.0))
|
||||
}
|
||||
@ -101,11 +108,15 @@ impl Client {
|
||||
}
|
||||
|
||||
pub fn send(&self, request: PendingRequest) -> Result<mpsc::Receiver<Message>, Error> {
|
||||
let messages = request.0?;
|
||||
let (messages, idle) = request.0?;
|
||||
let (sender, receiver) = mpsc::channel(0);
|
||||
self.0
|
||||
.sender
|
||||
.unbounded_send(Request { messages, sender })
|
||||
.unbounded_send(Request {
|
||||
messages,
|
||||
sender,
|
||||
idle: Some(idle),
|
||||
})
|
||||
.map(|_| receiver)
|
||||
.map_err(|_| Error::closed())
|
||||
}
|
||||
@ -134,7 +145,7 @@ impl Client {
|
||||
pub fn execute(&self, statement: &Statement, params: &[&dyn ToSql]) -> ExecuteFuture {
|
||||
let pending = PendingRequest(
|
||||
self.excecute_message(statement, params)
|
||||
.map(RequestMessages::Single),
|
||||
.map(|m| (RequestMessages::Single(m), self.0.idle.guard())),
|
||||
);
|
||||
ExecuteFuture::new(self.clone(), pending, statement.clone())
|
||||
}
|
||||
@ -142,7 +153,7 @@ impl Client {
|
||||
pub fn query(&self, statement: &Statement, params: &[&dyn ToSql]) -> QueryStream<Statement> {
|
||||
let pending = PendingRequest(
|
||||
self.excecute_message(statement, params)
|
||||
.map(RequestMessages::Single),
|
||||
.map(|m| (RequestMessages::Single(m), self.0.idle.guard())),
|
||||
);
|
||||
QueryStream::new(self.clone(), pending, statement.clone())
|
||||
}
|
||||
@ -152,7 +163,8 @@ impl Client {
|
||||
if let Ok(ref mut buf) = buf {
|
||||
frontend::sync(buf);
|
||||
}
|
||||
let pending = PendingRequest(buf.map(RequestMessages::Single));
|
||||
let pending =
|
||||
PendingRequest(buf.map(|m| (RequestMessages::Single(m), self.0.idle.guard())));
|
||||
BindFuture::new(self.clone(), pending, name, statement.clone())
|
||||
}
|
||||
|
||||
@ -183,10 +195,13 @@ impl Client {
|
||||
Ok(AsyncSink::Ready) => {}
|
||||
_ => unreachable!("channel should have capacity"),
|
||||
}
|
||||
RequestMessages::CopyIn {
|
||||
receiver: CopyInReceiver::new(receiver),
|
||||
pending_message: None,
|
||||
}
|
||||
(
|
||||
RequestMessages::CopyIn {
|
||||
receiver: CopyInReceiver::new(receiver),
|
||||
pending_message: None,
|
||||
},
|
||||
self.0.idle.guard(),
|
||||
)
|
||||
}));
|
||||
CopyInFuture::new(self.clone(), pending, statement.clone(), stream, sender)
|
||||
}
|
||||
@ -194,7 +209,7 @@ impl Client {
|
||||
pub fn copy_out(&self, statement: &Statement, params: &[&dyn ToSql]) -> CopyOutStream {
|
||||
let pending = PendingRequest(
|
||||
self.excecute_message(statement, params)
|
||||
.map(RequestMessages::Single),
|
||||
.map(|m| (RequestMessages::Single(m), self.0.idle.guard())),
|
||||
);
|
||||
CopyOutStream::new(self.clone(), pending, statement.clone())
|
||||
}
|
||||
@ -215,6 +230,7 @@ impl Client {
|
||||
let _ = self.0.sender.unbounded_send(Request {
|
||||
messages: RequestMessages::Single(buf),
|
||||
sender,
|
||||
idle: None,
|
||||
});
|
||||
}
|
||||
|
||||
@ -261,6 +277,8 @@ impl Client {
|
||||
F: FnOnce(&mut Vec<u8>) -> Result<(), Error>,
|
||||
{
|
||||
let mut buf = vec![];
|
||||
PendingRequest(messages(&mut buf).map(|()| RequestMessages::Single(buf)))
|
||||
PendingRequest(
|
||||
messages(&mut buf).map(|()| (RequestMessages::Single(buf), self.0.idle.guard())),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::proto::codec::PostgresCodec;
|
||||
use crate::proto::copy_in::CopyInReceiver;
|
||||
use crate::proto::idle::IdleGuard;
|
||||
use crate::{AsyncMessage, CancelData, Notification};
|
||||
use crate::{DbError, Error};
|
||||
|
||||
@ -24,6 +25,12 @@ pub enum RequestMessages {
|
||||
pub struct Request {
|
||||
pub messages: RequestMessages,
|
||||
pub sender: mpsc::Sender<Message>,
|
||||
pub idle: Option<IdleGuard>,
|
||||
}
|
||||
|
||||
struct Response {
|
||||
sender: mpsc::Sender<Message>,
|
||||
_idle: Option<IdleGuard>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug)]
|
||||
@ -40,7 +47,7 @@ pub struct Connection<S> {
|
||||
receiver: mpsc::UnboundedReceiver<Request>,
|
||||
pending_request: Option<RequestMessages>,
|
||||
pending_response: Option<Message>,
|
||||
responses: VecDeque<mpsc::Sender<Message>>,
|
||||
responses: VecDeque<Response>,
|
||||
state: State,
|
||||
}
|
||||
|
||||
@ -124,8 +131,8 @@ where
|
||||
m => m,
|
||||
};
|
||||
|
||||
let mut sender = match self.responses.pop_front() {
|
||||
Some(sender) => sender,
|
||||
let mut response = match self.responses.pop_front() {
|
||||
Some(response) => response,
|
||||
None => match message {
|
||||
Message::ErrorResponse(error) => return Err(Error::db(error)),
|
||||
_ => return Err(Error::unexpected_message()),
|
||||
@ -137,16 +144,16 @@ where
|
||||
_ => false,
|
||||
};
|
||||
|
||||
match sender.start_send(message) {
|
||||
match response.sender.start_send(message) {
|
||||
// if the receiver's hung up we still need to page through the rest of the messages
|
||||
// designated to it
|
||||
Ok(AsyncSink::Ready) | Err(_) => {
|
||||
if !request_complete {
|
||||
self.responses.push_front(sender);
|
||||
self.responses.push_front(response);
|
||||
}
|
||||
}
|
||||
Ok(AsyncSink::NotReady(message)) => {
|
||||
self.responses.push_front(sender);
|
||||
self.responses.push_front(response);
|
||||
self.pending_response = Some(message);
|
||||
trace!("poll_read: waiting on sender");
|
||||
return Ok(None);
|
||||
@ -164,7 +171,10 @@ where
|
||||
match try_ready_receive!(self.receiver.poll()) {
|
||||
Some(request) => {
|
||||
trace!("polled new request");
|
||||
self.responses.push_back(request.sender);
|
||||
self.responses.push_back(Response {
|
||||
sender: request.sender,
|
||||
_idle: request.idle,
|
||||
});
|
||||
Ok(Async::Ready(Some(request.messages)))
|
||||
}
|
||||
None => Ok(Async::Ready(None)),
|
||||
|
47
tokio-postgres/src/proto/idle.rs
Normal file
47
tokio-postgres/src/proto/idle.rs
Normal file
@ -0,0 +1,47 @@
|
||||
use futures::task::AtomicTask;
|
||||
use futures::{Async, Poll};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::Error;
|
||||
|
||||
struct Inner {
|
||||
active: AtomicUsize,
|
||||
task: AtomicTask,
|
||||
}
|
||||
|
||||
pub struct IdleState(Arc<Inner>);
|
||||
|
||||
impl IdleState {
|
||||
pub fn new() -> IdleState {
|
||||
IdleState(Arc::new(Inner {
|
||||
active: AtomicUsize::new(0),
|
||||
task: AtomicTask::new(),
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn guard(&self) -> IdleGuard {
|
||||
self.0.active.fetch_add(1, Ordering::SeqCst);
|
||||
IdleGuard(self.0.clone())
|
||||
}
|
||||
|
||||
pub fn poll_idle(&self) -> Poll<(), Error> {
|
||||
self.0.task.register();
|
||||
|
||||
if self.0.active.load(Ordering::SeqCst) == 0 {
|
||||
Ok(Async::Ready(()))
|
||||
} else {
|
||||
Ok(Async::NotReady)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct IdleGuard(Arc<Inner>);
|
||||
|
||||
impl Drop for IdleGuard {
|
||||
fn drop(&mut self) {
|
||||
if self.0.active.fetch_sub(1, Ordering::SeqCst) == 1 {
|
||||
self.0.task.notify();
|
||||
}
|
||||
}
|
||||
}
|
@ -31,6 +31,7 @@ mod copy_in;
|
||||
mod copy_out;
|
||||
mod execute;
|
||||
mod handshake;
|
||||
mod idle;
|
||||
mod portal;
|
||||
mod prepare;
|
||||
mod query;
|
||||
|
@ -4,6 +4,7 @@ use futures::sync::mpsc;
|
||||
use futures::{future, stream, try_ready};
|
||||
use log::debug;
|
||||
use std::error::Error;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::prelude::*;
|
||||
@ -683,3 +684,52 @@ fn transaction_builder_around_moved_client() {
|
||||
drop(client);
|
||||
runtime.run().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn poll_idle() {
|
||||
struct IdleFuture {
|
||||
client: tokio_postgres::Client,
|
||||
query: Option<tokio_postgres::Prepare>,
|
||||
}
|
||||
|
||||
impl Future for IdleFuture {
|
||||
type Item = ();
|
||||
type Error = tokio_postgres::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<(), tokio_postgres::Error> {
|
||||
if let Some(_) = self.query.take() {
|
||||
assert!(!self.client.poll_idle().unwrap().is_ready());
|
||||
return Ok(Async::NotReady);
|
||||
}
|
||||
|
||||
try_ready!(self.client.poll_idle());
|
||||
assert!(QUERY_DONE.load(Ordering::SeqCst));
|
||||
|
||||
Ok(Async::Ready(()))
|
||||
}
|
||||
}
|
||||
|
||||
static QUERY_DONE: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
let _ = env_logger::try_init();
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
runtime.handle().spawn(connection).unwrap();
|
||||
|
||||
let stmt = runtime.block_on(client.prepare("SELECT 1")).unwrap();
|
||||
|
||||
let query = client
|
||||
.query(&stmt, &[])
|
||||
.collect()
|
||||
.map(|_| QUERY_DONE.store(true, Ordering::SeqCst))
|
||||
.map_err(|e| panic!("{}", e));
|
||||
runtime.spawn(query);
|
||||
|
||||
let future = IdleFuture {
|
||||
query: Some(client.prepare("")),
|
||||
client,
|
||||
};
|
||||
runtime.block_on(future).unwrap();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user