diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 34afad18..38c78807 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -44,6 +44,7 @@ tokio-codec = "0.1" tokio-io = "0.1" tokio-tcp = "0.1" tokio-timer = "0.2" +want = "0.0.5" [target.'cfg(unix)'.dependencies] tokio-uds = "0.2" diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index e9fe7297..c95657a6 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -7,6 +7,7 @@ extern crate tokio_codec; extern crate tokio_io; extern crate tokio_tcp; extern crate tokio_timer; +extern crate want; #[macro_use] extern crate futures; @@ -18,12 +19,13 @@ extern crate state_machine_future; #[cfg(unix)] extern crate tokio_uds; -use futures::sync::mpsc; use futures::{Async, Future, Poll}; use std::io; #[doc(inline)] -pub use postgres_shared::{error, params}; +pub use postgres_shared::stmt::Column; +#[doc(inline)] +pub use postgres_shared::{error, params, types}; #[doc(inline)] pub use postgres_shared::{CancelData, Notification}; @@ -46,12 +48,22 @@ fn disconnected() -> Error { )) } -pub struct Client(mpsc::Sender); +pub struct Client(proto::Client); impl Client { pub fn connect(params: ConnectParams) -> Handshake { Handshake(proto::HandshakeFuture::new(params)) } + + /// Polls to to determine whether the connection is ready to send new requests to the backend. + /// + /// Requests are unboundedly buffered to enable pipelining, but this risks unbounded memory consumption if requests + /// are produced at a faster pace than the backend can process. This method can be used to cooperatively "throttle" + /// request creation. Specifically, it returns ready when the connection has sent any queued requests and is waiting + /// on new requests from the client. + pub fn poll_ready(&mut self) -> Poll<(), Error> { + self.0.poll_ready() + } } pub struct Connection(proto::Connection); @@ -82,8 +94,8 @@ impl Future for Handshake { type Error = Error; fn poll(&mut self) -> Poll<(Client, Connection), Error> { - let (sender, connection) = try_ready!(self.0.poll()); + let (client, connection) = try_ready!(self.0.poll()); - Ok(Async::Ready((Client(sender), Connection(connection)))) + Ok(Async::Ready((Client(client), Connection(connection)))) } } diff --git a/tokio-postgres/src/proto/client.rs b/tokio-postgres/src/proto/client.rs new file mode 100644 index 00000000..d646aad5 --- /dev/null +++ b/tokio-postgres/src/proto/client.rs @@ -0,0 +1,32 @@ +use futures::sync::mpsc; +use futures::Poll; +use postgres_protocol::message::backend::Message; +use want::Giver; + +use disconnected; +use error::Error; +use proto::connection::Request; + +pub struct Client { + sender: mpsc::UnboundedSender, + giver: Giver, +} + +impl Client { + pub fn new(sender: mpsc::UnboundedSender, giver: Giver) -> Client { + Client { sender, giver } + } + + pub fn poll_ready(&mut self) -> Poll<(), Error> { + self.giver.poll_want().map_err(|_| disconnected()) + } + + pub fn send(&mut self, messages: Vec) -> Result, Error> { + let (sender, receiver) = mpsc::channel(0); + self.giver.give(); + self.sender + .unbounded_send(Request { messages, sender }) + .map(|_| receiver) + .map_err(|_| disconnected()) + } +} diff --git a/tokio-postgres/src/proto/connection.rs b/tokio-postgres/src/proto/connection.rs index 87905cb0..cfd8a0b6 100644 --- a/tokio-postgres/src/proto/connection.rs +++ b/tokio-postgres/src/proto/connection.rs @@ -5,6 +5,7 @@ use postgres_protocol::message::frontend; use std::collections::{HashMap, VecDeque}; use std::io; use tokio_codec::Framed; +use want::Taker; use error::{self, Error}; use proto::codec::PostgresCodec; @@ -27,7 +28,8 @@ pub struct Connection { stream: Framed, cancel_data: CancelData, parameters: HashMap, - receiver: mpsc::Receiver, + receiver: mpsc::UnboundedReceiver, + taker: Taker, pending_request: Option>, pending_response: Option, responses: VecDeque>, @@ -39,13 +41,15 @@ impl Connection { stream: Framed, cancel_data: CancelData, parameters: HashMap, - receiver: mpsc::Receiver, + receiver: mpsc::UnboundedReceiver, + taker: Taker, ) -> Connection { Connection { stream, cancel_data, parameters, receiver, + taker, pending_request: None, pending_response: None, responses: VecDeque::new(), @@ -76,7 +80,10 @@ impl Connection { Async::Ready(None) => { return Err(Error::from(io::Error::from(io::ErrorKind::UnexpectedEof))); } - Async::NotReady => return Ok(()), + Async::NotReady => { + self.taker.want(); + return Ok(()); + } }; let message = match message { @@ -100,7 +107,7 @@ impl Connection { }, }; - let ready = match message { + let request_complete = match message { Message::ReadyForQuery(_) => true, _ => false, }; @@ -109,7 +116,7 @@ impl Connection { // if the receiver's hung up we still need to page through the rest of the messages // designated to it Ok(AsyncSink::Ready) | Err(_) => { - if !ready { + if !request_complete { self.responses.push_front(sender); } } diff --git a/tokio-postgres/src/proto/handshake.rs b/tokio-postgres/src/proto/handshake.rs index cccd9a4a..2521e90c 100644 --- a/tokio-postgres/src/proto/handshake.rs +++ b/tokio-postgres/src/proto/handshake.rs @@ -10,11 +10,13 @@ use state_machine_future::RentToOwn; use std::collections::HashMap; use std::io; use tokio_codec::Framed; +use want; use error::{self, Error}; use params::{ConnectParams, User}; +use proto::client::Client; use proto::codec::PostgresCodec; -use proto::connection::{Connection, Request}; +use proto::connection::Connection; use proto::socket::{ConnectFuture, Socket}; use {bad_response, disconnected, CancelData}; @@ -60,7 +62,7 @@ pub enum Handshake { parameters: HashMap, }, #[state_machine_future(ready)] - Finished((mpsc::Sender, Connection)), + Finished((Client, Connection)), #[state_machine_future(error)] Failed(Error), } @@ -281,10 +283,17 @@ impl PollHandshake for Handshake { let cancel_data = state.cancel_data.ok_or_else(|| { io::Error::new(io::ErrorKind::InvalidData, "BackendKeyData message missing") })?; - let (sender, receiver) = mpsc::channel(0); - let connection = - Connection::new(state.stream, cancel_data, state.parameters, receiver); - transition!(Finished((sender, connection))) + let (sender, receiver) = mpsc::unbounded(); + let (giver, taker) = want::new(); + let client = Client::new(sender, giver); + let connection = Connection::new( + state.stream, + cancel_data, + state.parameters, + receiver, + taker, + ); + transition!(Finished((client, connection))) } Some(Message::ErrorResponse(body)) => return Err(error::__db(body)), Some(Message::NoticeResponse(_)) => {} diff --git a/tokio-postgres/src/proto/mod.rs b/tokio-postgres/src/proto/mod.rs index 541e3e87..16eb9c77 100644 --- a/tokio-postgres/src/proto/mod.rs +++ b/tokio-postgres/src/proto/mod.rs @@ -1,9 +1,15 @@ +mod client; mod codec; mod connection; mod handshake; +mod prepare; mod socket; +mod statement; +pub use proto::client::Client; pub use proto::codec::PostgresCodec; -pub use proto::connection::{Connection, Request}; +pub use proto::connection::Connection; pub use proto::handshake::HandshakeFuture; +pub use proto::prepare::PrepareFuture; pub use proto::socket::Socket; +pub use statement::Statement; diff --git a/tokio-postgres/src/proto/prepare.rs b/tokio-postgres/src/proto/prepare.rs new file mode 100644 index 00000000..dcf5ef7e --- /dev/null +++ b/tokio-postgres/src/proto/prepare.rs @@ -0,0 +1,20 @@ +use futures::sync::mpsc; +use postgres_protocol::message::backend::Message; + +use error::Error; +use proto::connection::Request; +use proto::statement::Statement; + +#[derive(StateMachineFuture)] +pub enum Prepare { + #[state_machine_future(start)] + Start { + sender: mpsc::UnboundedSender, + receiver: Result, Error>, + name: String, + }, + #[state_machine_future(ready)] + Finished(Statement), + #[state_machine_future(error)] + Failed(Error), +} diff --git a/tokio-postgres/src/proto/statement.rs b/tokio-postgres/src/proto/statement.rs new file mode 100644 index 00000000..407e18b8 --- /dev/null +++ b/tokio-postgres/src/proto/statement.rs @@ -0,0 +1,49 @@ +use futures::sync::mpsc; +use postgres_protocol::message::frontend; +use postgres_shared::stmt::Column; + +use proto::connection::Request; +use types::Type; + +pub struct Statement { + sender: mpsc::UnboundedSender, + name: String, + params: Vec, + columns: Vec, +} + +impl Drop for Statement { + fn drop(&mut self) { + let mut buf = vec![]; + frontend::close(b'S', &self.name, &mut buf).expect("statement name not valid"); + let (sender, _) = mpsc::channel(0); + self.sender.unbounded_send(Request { + messages: buf, + sender, + }); + } +} + +impl Statement { + pub fn new( + sender: mpsc::UnboundedReceiver, + name: String, + params: Vec, + columns: Vec, + ) -> Statement { + Statement { + sender, + name, + params, + columns, + } + } + + pub fn params(&self) -> &[Type] { + &self.params + } + + pub fn columns(&self) -> &[Column] { + &self.columns + } +}