diff --git a/Cargo.toml b/Cargo.toml index e394b39f..254b755c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,3 +14,5 @@ tokio-uds = { git = "https://github.com/sfackler/tokio" } tokio-io = { git = "https://github.com/sfackler/tokio" } tokio-timer = { git = "https://github.com/sfackler/tokio" } tokio-codec = { git = "https://github.com/sfackler/tokio" } +tokio-reactor = { git = "https://github.com/sfackler/tokio" } +tokio-executor = { git = "https://github.com/sfackler/tokio" } diff --git a/postgres-shared/src/params/mod.rs b/postgres-shared/src/params/mod.rs index b772ab0e..296483f9 100644 --- a/postgres-shared/src/params/mod.rs +++ b/postgres-shared/src/params/mod.rs @@ -2,8 +2,10 @@ use std::error::Error; use std::mem; use std::path::PathBuf; +use std::str::FromStr; use std::time::Duration; +use error; use params::url::Url; mod url; @@ -96,6 +98,14 @@ impl ConnectParams { } } +impl FromStr for ConnectParams { + type Err = error::Error; + + fn from_str(s: &str) -> Result { + s.into_connect_params().map_err(error::connect) + } +} + /// A builder for `ConnectParams`. pub struct Builder { port: u16, diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 38c78807..b492c624 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -37,6 +37,7 @@ fallible-iterator = "0.1.3" futures = "0.1.7" futures-cpupool = "0.1" lazy_static = "1.0" +log = "0.4" postgres-protocol = { version = "0.3.0", path = "../postgres-protocol" } postgres-shared = { version = "0.4.0", path = "../postgres-shared" } state_machine_future = "0.1.7" @@ -48,3 +49,7 @@ want = "0.0.5" [target.'cfg(unix)'.dependencies] tokio-uds = "0.2" + +[dev-dependencies] +tokio = "0.1.7" +env_logger = "0.5" diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index c95657a6..0b84664d 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -14,6 +14,8 @@ extern crate futures; #[macro_use] extern crate lazy_static; #[macro_use] +extern crate log; +#[macro_use] extern crate state_machine_future; #[cfg(unix)] @@ -21,6 +23,7 @@ extern crate tokio_uds; use futures::{Async, Future, Poll}; use std::io; +use std::sync::atomic::{AtomicUsize, Ordering}; #[doc(inline)] pub use postgres_shared::stmt::Column; @@ -31,9 +34,12 @@ pub use postgres_shared::{CancelData, Notification}; use error::Error; use params::ConnectParams; +use types::Type; mod proto; +static NEXT_STATEMENT_ID: AtomicUsize = AtomicUsize::new(0); + fn bad_response() -> Error { Error::from(io::Error::new( io::ErrorKind::InvalidInput, @@ -48,13 +54,13 @@ fn disconnected() -> Error { )) } +pub fn connect(params: ConnectParams) -> Handshake { + Handshake(proto::HandshakeFuture::new(params)) +} + pub struct Client(proto::Client); impl Client { - pub fn connect(params: ConnectParams) -> Handshake { - Handshake(proto::HandshakeFuture::new(params)) - } - /// Polls to to determine whether the connection is ready to send new requests to the backend. /// /// Requests are unboundedly buffered to enable pipelining, but this risks unbounded memory consumption if requests @@ -64,6 +70,15 @@ impl Client { pub fn poll_ready(&mut self) -> Poll<(), Error> { self.0.poll_ready() } + + pub fn prepare(&mut self, query: &str) -> Prepare { + self.prepare_typed(query, &[]) + } + + pub fn prepare_typed(&mut self, query: &str, param_types: &[Type]) -> Prepare { + let name = format!("s{}", NEXT_STATEMENT_ID.fetch_add(1, Ordering::SeqCst)); + Prepare(self.0.prepare(name, query, param_types)) + } } pub struct Connection(proto::Connection); @@ -99,3 +114,28 @@ impl Future for Handshake { Ok(Async::Ready((Client(client), Connection(connection)))) } } + +pub struct Prepare(proto::PrepareFuture); + +impl Future for Prepare { + type Item = Statement; + type Error = Error; + + fn poll(&mut self) -> Poll { + let statement = try_ready!(self.0.poll()); + + Ok(Async::Ready(Statement(statement))) + } +} + +pub struct Statement(proto::Statement); + +impl Statement { + pub fn params(&self) -> &[Type] { + self.0.params() + } + + pub fn columns(&self) -> &[Column] { + self.0.columns() + } +} diff --git a/tokio-postgres/src/proto/client.rs b/tokio-postgres/src/proto/client.rs index d646aad5..8f78b864 100644 --- a/tokio-postgres/src/proto/client.rs +++ b/tokio-postgres/src/proto/client.rs @@ -1,11 +1,14 @@ use futures::sync::mpsc; use futures::Poll; use postgres_protocol::message::backend::Message; +use postgres_protocol::message::frontend; use want::Giver; use disconnected; use error::Error; use proto::connection::Request; +use proto::prepare::PrepareFuture; +use types::Type; pub struct Client { sender: mpsc::UnboundedSender, @@ -21,7 +24,18 @@ impl Client { self.giver.poll_want().map_err(|_| disconnected()) } - pub fn send(&mut self, messages: Vec) -> Result, Error> { + pub fn prepare(&mut self, name: String, query: &str, param_types: &[Type]) -> PrepareFuture { + let mut buf = vec![]; + let receiver = frontend::parse(&name, query, param_types.iter().map(|t| t.oid()), &mut buf) + .and_then(|()| frontend::describe(b'S', &name, &mut buf)) + .and_then(|()| Ok(frontend::sync(&mut buf))) + .map_err(Into::into) + .and_then(|()| self.send(buf)); + + PrepareFuture::new(self.sender.clone(), receiver, name) + } + + fn send(&mut self, messages: Vec) -> Result, Error> { let (sender, receiver) = mpsc::channel(0); self.giver.give(); self.sender diff --git a/tokio-postgres/src/proto/connection.rs b/tokio-postgres/src/proto/connection.rs index cfd8a0b6..d841ad8c 100644 --- a/tokio-postgres/src/proto/connection.rs +++ b/tokio-postgres/src/proto/connection.rs @@ -7,6 +7,7 @@ use std::io; use tokio_codec::Framed; use want::Taker; +use disconnected; use error::{self, Error}; use proto::codec::PostgresCodec; use proto::socket::Socket; @@ -17,7 +18,7 @@ pub struct Request { pub sender: mpsc::Sender, } -#[derive(PartialEq)] +#[derive(PartialEq, Debug)] enum State { Active, Terminating, @@ -67,22 +68,28 @@ impl Connection { fn poll_response(&mut self) -> Poll, io::Error> { if let Some(message) = self.pending_response.take() { + trace!("retrying pending response"); return Ok(Async::Ready(Some(message))); } self.stream.poll() } - fn poll_read(&mut self) -> Result<(), Error> { + fn poll_read(&mut self) -> Poll<(), Error> { + if self.state != State::Active { + trace!("poll_read: done"); + return Ok(Async::Ready(())); + } + loop { let message = match self.poll_response()? { Async::Ready(Some(message)) => message, Async::Ready(None) => { - return Err(Error::from(io::Error::from(io::ErrorKind::UnexpectedEof))); + return Err(disconnected()); } Async::NotReady => { - self.taker.want(); - return Ok(()); + trace!("poll_read: waiting on response"); + return Ok(Async::NotReady); } }; @@ -123,7 +130,8 @@ impl Connection { Ok(AsyncSink::NotReady(message)) => { self.responses.push_front(sender); self.pending_response = Some(message); - return Ok(()); + trace!("poll_read: waiting on socket"); + return Ok(Async::NotReady); } } } @@ -131,57 +139,86 @@ impl Connection { fn poll_request(&mut self) -> Poll>, Error> { if let Some(message) = self.pending_request.take() { + trace!("retrying pending request"); return Ok(Async::Ready(Some(message))); } - match self.receiver.poll() { - Ok(Async::Ready(Some(request))) => { + match try_receive!(self.receiver.poll()) { + Some(request) => { + trace!("polled new request"); self.responses.push_back(request.sender); Ok(Async::Ready(Some(request.messages))) } - Ok(Async::Ready(None)) => Ok(Async::Ready(None)), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(()) => unreachable!("mpsc::Receiver doesn't return errors"), + None => Ok(Async::Ready(None)), } } - fn poll_write(&mut self) -> Result<(), Error> { + fn poll_write(&mut self) -> Poll<(), Error> { loop { + if self.state == State::Closing { + trace!("poll_write: done"); + return Ok(Async::Ready(())); + } + let request = match self.poll_request()? { Async::Ready(Some(request)) => request, Async::Ready(None) if self.responses.is_empty() && self.state == State::Active => { + trace!("poll_write: at eof, terminating"); self.state = State::Terminating; let mut request = vec![]; frontend::terminate(&mut request); request } - Async::Ready(None) => return Ok(()), - Async::NotReady => return Ok(()), + Async::Ready(None) => { + trace!( + "poll_write: at eof, pending responses {}", + self.responses.len(), + ); + return Ok(Async::Ready(())); + } + Async::NotReady => { + trace!("poll_write: waiting on request"); + self.taker.want(); + return Ok(Async::NotReady); + } }; match self.stream.start_send(request)? { AsyncSink::Ready => { if self.state == State::Terminating { + trace!("poll_write: sent eof, closing"); self.state = State::Closing; } } AsyncSink::NotReady(request) => { + trace!("poll_write: waiting on socket"); self.pending_request = Some(request); - return Ok(()); + return Ok(Async::NotReady); } } } } - fn poll_flush(&mut self) -> Result<(), Error> { - self.stream.poll_complete()?; - Ok(()) + fn poll_flush(&mut self) -> Poll<(), Error> { + trace!("flushing"); + self.stream.poll_complete().map_err(Into::into) } fn poll_shutdown(&mut self) -> Poll<(), Error> { - match self.state { - State::Active | State::Terminating => Ok(Async::NotReady), - State::Closing => self.stream.close().map_err(Into::into), + if self.state != State::Closing { + return Ok(Async::NotReady); + } + + match self.stream.close() { + Ok(Async::Ready(())) => { + trace!("poll_shutdown: complete"); + Ok(Async::Ready(())) + } + Ok(Async::NotReady) => { + trace!("poll_shutdown: waiting on socket"); + Ok(Async::NotReady) + } + Err(e) => Err(Error::from(e)), } } } diff --git a/tokio-postgres/src/proto/mod.rs b/tokio-postgres/src/proto/mod.rs index 16eb9c77..b4d725da 100644 --- a/tokio-postgres/src/proto/mod.rs +++ b/tokio-postgres/src/proto/mod.rs @@ -1,3 +1,13 @@ +macro_rules! try_receive { + ($e:expr) => { + match $e { + Ok(::futures::Async::Ready(v)) => v, + Ok(::futures::Async::NotReady) => return Ok(::futures::Async::NotReady), + Err(()) => unreachable!("mpsc::Receiver doesn't return errors"), + } + }; +} + mod client; mod codec; mod connection; @@ -12,4 +22,4 @@ pub use proto::connection::Connection; pub use proto::handshake::HandshakeFuture; pub use proto::prepare::PrepareFuture; pub use proto::socket::Socket; -pub use statement::Statement; +pub use proto::statement::Statement; diff --git a/tokio-postgres/src/proto/prepare.rs b/tokio-postgres/src/proto/prepare.rs index dcf5ef7e..c651222d 100644 --- a/tokio-postgres/src/proto/prepare.rs +++ b/tokio-postgres/src/proto/prepare.rs @@ -1,20 +1,168 @@ +use fallible_iterator::FallibleIterator; use futures::sync::mpsc; -use postgres_protocol::message::backend::Message; +use futures::{Poll, Stream}; +use postgres_protocol::message::backend::{Message, ParameterDescriptionBody, RowDescriptionBody}; +use state_machine_future::RentToOwn; -use error::Error; +use error::{self, Error}; use proto::connection::Request; use proto::statement::Statement; +use types::Type; +use Column; +use {bad_response, disconnected}; #[derive(StateMachineFuture)] pub enum Prepare { - #[state_machine_future(start)] + #[state_machine_future(start, transitions(ReadParseComplete))] Start { sender: mpsc::UnboundedSender, receiver: Result, Error>, name: String, }, + #[state_machine_future(transitions(ReadParameterDescription))] + ReadParseComplete { + sender: mpsc::UnboundedSender, + receiver: mpsc::Receiver, + name: String, + }, + #[state_machine_future(transitions(ReadRowDescription))] + ReadParameterDescription { + sender: mpsc::UnboundedSender, + receiver: mpsc::Receiver, + name: String, + }, + #[state_machine_future(transitions(ReadReadyForQuery))] + ReadRowDescription { + sender: mpsc::UnboundedSender, + receiver: mpsc::Receiver, + name: String, + parameters: ParameterDescriptionBody, + }, + #[state_machine_future(transitions(Finished))] + ReadReadyForQuery { + sender: mpsc::UnboundedSender, + receiver: mpsc::Receiver, + name: String, + parameters: ParameterDescriptionBody, + columns: RowDescriptionBody, + }, #[state_machine_future(ready)] Finished(Statement), #[state_machine_future(error)] Failed(Error), } + +impl PollPrepare for Prepare { + fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll { + let state = state.take(); + let receiver = state.receiver?; + + transition!(ReadParseComplete { + sender: state.sender, + receiver, + name: state.name, + }) + } + + fn poll_read_parse_complete<'a>( + state: &'a mut RentToOwn<'a, ReadParseComplete>, + ) -> Poll { + let message = try_receive!(state.receiver.poll()); + let state = state.take(); + + match message { + Some(Message::ParseComplete) => transition!(ReadParameterDescription { + sender: state.sender, + receiver: state.receiver, + name: state.name, + }), + Some(Message::ErrorResponse(body)) => Err(error::__db(body)), + Some(_) => Err(bad_response()), + None => Err(disconnected()), + } + } + + fn poll_read_parameter_description<'a>( + state: &'a mut RentToOwn<'a, ReadParameterDescription>, + ) -> Poll { + let message = try_receive!(state.receiver.poll()); + let state = state.take(); + + match message { + Some(Message::ParameterDescription(body)) => transition!(ReadRowDescription { + sender: state.sender, + receiver: state.receiver, + name: state.name, + parameters: body, + }), + Some(Message::ErrorResponse(body)) => Err(error::__db(body)), + Some(_) => Err(bad_response()), + None => Err(disconnected()), + } + } + + fn poll_read_row_description<'a>( + state: &'a mut RentToOwn<'a, ReadRowDescription>, + ) -> Poll { + let message = try_receive!(state.receiver.poll()); + let state = state.take(); + + match message { + Some(Message::RowDescription(body)) => transition!(ReadReadyForQuery { + sender: state.sender, + receiver: state.receiver, + name: state.name, + parameters: state.parameters, + columns: body, + }), + Some(Message::ErrorResponse(body)) => Err(error::__db(body)), + Some(_) => Err(bad_response()), + None => Err(disconnected()), + } + } + + fn poll_read_ready_for_query<'a>( + state: &'a mut RentToOwn<'a, ReadReadyForQuery>, + ) -> Poll { + let message = try_receive!(state.receiver.poll()); + let state = state.take(); + + match message { + Some(Message::ReadyForQuery(_)) => { + // FIXME handle custom types + let parameters = state + .parameters + .parameters() + .map(|oid| Type::from_oid(oid).unwrap()) + .collect()?; + let columns = state + .columns + .fields() + .map(|f| { + Column::new(f.name().to_string(), Type::from_oid(f.type_oid()).unwrap()) + }) + .collect()?; + + transition!(Finished(Statement::new( + state.sender, + state.name, + parameters, + columns + ))) + } + Some(Message::ErrorResponse(body)) => Err(error::__db(body)), + Some(_) => Err(bad_response()), + None => Err(disconnected()), + } + } +} + +impl PrepareFuture { + pub fn new( + sender: mpsc::UnboundedSender, + receiver: Result, Error>, + name: String, + ) -> PrepareFuture { + Prepare::start(sender, receiver, name) + } +} diff --git a/tokio-postgres/src/proto/statement.rs b/tokio-postgres/src/proto/statement.rs index 407e18b8..557c9b7c 100644 --- a/tokio-postgres/src/proto/statement.rs +++ b/tokio-postgres/src/proto/statement.rs @@ -16,8 +16,9 @@ impl Drop for Statement { fn drop(&mut self) { let mut buf = vec![]; frontend::close(b'S', &self.name, &mut buf).expect("statement name not valid"); + frontend::sync(&mut buf); let (sender, _) = mpsc::channel(0); - self.sender.unbounded_send(Request { + let _ = self.sender.unbounded_send(Request { messages: buf, sender, }); @@ -26,7 +27,7 @@ impl Drop for Statement { impl Statement { pub fn new( - sender: mpsc::UnboundedReceiver, + sender: mpsc::UnboundedSender, name: String, params: Vec, columns: Vec, diff --git a/tokio-postgres/tests/test.rs b/tokio-postgres/tests/test.rs new file mode 100644 index 00000000..4289d68f --- /dev/null +++ b/tokio-postgres/tests/test.rs @@ -0,0 +1,37 @@ +extern crate env_logger; +extern crate tokio; +extern crate tokio_postgres; + +use tokio::prelude::*; +use tokio::runtime::current_thread::Runtime; +use tokio_postgres::types::Type; + +#[test] +fn pipelined_prepare() { + let _ = env_logger::try_init(); + let mut runtime = Runtime::new().unwrap(); + + let handshake = tokio_postgres::connect("postgres://postgres@localhost:5433".parse().unwrap()); + let (mut client, connection) = runtime.block_on(handshake).unwrap(); + let connection = connection.map_err(|e| panic!("{}", e)); + runtime.handle().spawn(connection).unwrap(); + + let prepare1 = client.prepare("SELECT 1::BIGINT WHERE $1::BOOL"); + let prepare2 = client.prepare("SELECT ''::TEXT, 1::FLOAT4 WHERE $1::VARCHAR IS NOT NULL"); + let prepare = prepare1.join(prepare2); + let (statement1, statement2) = runtime.block_on(prepare).unwrap(); + + assert_eq!(statement1.params(), &[Type::BOOL]); + assert_eq!(statement1.columns().len(), 1); + assert_eq!(statement1.columns()[0].type_(), &Type::INT8); + + assert_eq!(statement2.params(), &[Type::VARCHAR]); + assert_eq!(statement2.columns().len(), 2); + assert_eq!(statement2.columns()[0].type_(), &Type::TEXT); + assert_eq!(statement2.columns()[1].type_(), &Type::FLOAT4); + + drop(statement1); + drop(statement2); + drop(client); + runtime.run().unwrap(); +}