Working statement preparation

This commit is contained in:
Steven Fackler 2018-06-18 22:34:25 -04:00
parent 0d0435fc2e
commit 13fcea7ae2
10 changed files with 336 additions and 32 deletions

View File

@ -14,3 +14,5 @@ tokio-uds = { git = "https://github.com/sfackler/tokio" }
tokio-io = { git = "https://github.com/sfackler/tokio" } tokio-io = { git = "https://github.com/sfackler/tokio" }
tokio-timer = { git = "https://github.com/sfackler/tokio" } tokio-timer = { git = "https://github.com/sfackler/tokio" }
tokio-codec = { git = "https://github.com/sfackler/tokio" } tokio-codec = { git = "https://github.com/sfackler/tokio" }
tokio-reactor = { git = "https://github.com/sfackler/tokio" }
tokio-executor = { git = "https://github.com/sfackler/tokio" }

View File

@ -2,8 +2,10 @@
use std::error::Error; use std::error::Error;
use std::mem; use std::mem;
use std::path::PathBuf; use std::path::PathBuf;
use std::str::FromStr;
use std::time::Duration; use std::time::Duration;
use error;
use params::url::Url; use params::url::Url;
mod url; mod url;
@ -96,6 +98,14 @@ impl ConnectParams {
} }
} }
impl FromStr for ConnectParams {
type Err = error::Error;
fn from_str(s: &str) -> Result<ConnectParams, error::Error> {
s.into_connect_params().map_err(error::connect)
}
}
/// A builder for `ConnectParams`. /// A builder for `ConnectParams`.
pub struct Builder { pub struct Builder {
port: u16, port: u16,

View File

@ -37,6 +37,7 @@ fallible-iterator = "0.1.3"
futures = "0.1.7" futures = "0.1.7"
futures-cpupool = "0.1" futures-cpupool = "0.1"
lazy_static = "1.0" lazy_static = "1.0"
log = "0.4"
postgres-protocol = { version = "0.3.0", path = "../postgres-protocol" } postgres-protocol = { version = "0.3.0", path = "../postgres-protocol" }
postgres-shared = { version = "0.4.0", path = "../postgres-shared" } postgres-shared = { version = "0.4.0", path = "../postgres-shared" }
state_machine_future = "0.1.7" state_machine_future = "0.1.7"
@ -48,3 +49,7 @@ want = "0.0.5"
[target.'cfg(unix)'.dependencies] [target.'cfg(unix)'.dependencies]
tokio-uds = "0.2" tokio-uds = "0.2"
[dev-dependencies]
tokio = "0.1.7"
env_logger = "0.5"

View File

@ -14,6 +14,8 @@ extern crate futures;
#[macro_use] #[macro_use]
extern crate lazy_static; extern crate lazy_static;
#[macro_use] #[macro_use]
extern crate log;
#[macro_use]
extern crate state_machine_future; extern crate state_machine_future;
#[cfg(unix)] #[cfg(unix)]
@ -21,6 +23,7 @@ extern crate tokio_uds;
use futures::{Async, Future, Poll}; use futures::{Async, Future, Poll};
use std::io; use std::io;
use std::sync::atomic::{AtomicUsize, Ordering};
#[doc(inline)] #[doc(inline)]
pub use postgres_shared::stmt::Column; pub use postgres_shared::stmt::Column;
@ -31,9 +34,12 @@ pub use postgres_shared::{CancelData, Notification};
use error::Error; use error::Error;
use params::ConnectParams; use params::ConnectParams;
use types::Type;
mod proto; mod proto;
static NEXT_STATEMENT_ID: AtomicUsize = AtomicUsize::new(0);
fn bad_response() -> Error { fn bad_response() -> Error {
Error::from(io::Error::new( Error::from(io::Error::new(
io::ErrorKind::InvalidInput, io::ErrorKind::InvalidInput,
@ -48,13 +54,13 @@ fn disconnected() -> Error {
)) ))
} }
pub fn connect(params: ConnectParams) -> Handshake {
Handshake(proto::HandshakeFuture::new(params))
}
pub struct Client(proto::Client); pub struct Client(proto::Client);
impl Client { impl Client {
pub fn connect(params: ConnectParams) -> Handshake {
Handshake(proto::HandshakeFuture::new(params))
}
/// Polls to to determine whether the connection is ready to send new requests to the backend. /// Polls to to determine whether the connection is ready to send new requests to the backend.
/// ///
/// Requests are unboundedly buffered to enable pipelining, but this risks unbounded memory consumption if requests /// Requests are unboundedly buffered to enable pipelining, but this risks unbounded memory consumption if requests
@ -64,6 +70,15 @@ impl Client {
pub fn poll_ready(&mut self) -> Poll<(), Error> { pub fn poll_ready(&mut self) -> Poll<(), Error> {
self.0.poll_ready() self.0.poll_ready()
} }
pub fn prepare(&mut self, query: &str) -> Prepare {
self.prepare_typed(query, &[])
}
pub fn prepare_typed(&mut self, query: &str, param_types: &[Type]) -> Prepare {
let name = format!("s{}", NEXT_STATEMENT_ID.fetch_add(1, Ordering::SeqCst));
Prepare(self.0.prepare(name, query, param_types))
}
} }
pub struct Connection(proto::Connection); pub struct Connection(proto::Connection);
@ -99,3 +114,28 @@ impl Future for Handshake {
Ok(Async::Ready((Client(client), Connection(connection)))) Ok(Async::Ready((Client(client), Connection(connection))))
} }
} }
pub struct Prepare(proto::PrepareFuture);
impl Future for Prepare {
type Item = Statement;
type Error = Error;
fn poll(&mut self) -> Poll<Statement, Error> {
let statement = try_ready!(self.0.poll());
Ok(Async::Ready(Statement(statement)))
}
}
pub struct Statement(proto::Statement);
impl Statement {
pub fn params(&self) -> &[Type] {
self.0.params()
}
pub fn columns(&self) -> &[Column] {
self.0.columns()
}
}

View File

@ -1,11 +1,14 @@
use futures::sync::mpsc; use futures::sync::mpsc;
use futures::Poll; use futures::Poll;
use postgres_protocol::message::backend::Message; use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use want::Giver; use want::Giver;
use disconnected; use disconnected;
use error::Error; use error::Error;
use proto::connection::Request; use proto::connection::Request;
use proto::prepare::PrepareFuture;
use types::Type;
pub struct Client { pub struct Client {
sender: mpsc::UnboundedSender<Request>, sender: mpsc::UnboundedSender<Request>,
@ -21,7 +24,18 @@ impl Client {
self.giver.poll_want().map_err(|_| disconnected()) self.giver.poll_want().map_err(|_| disconnected())
} }
pub fn send(&mut self, messages: Vec<u8>) -> Result<mpsc::Receiver<Message>, Error> { pub fn prepare(&mut self, name: String, query: &str, param_types: &[Type]) -> PrepareFuture {
let mut buf = vec![];
let receiver = frontend::parse(&name, query, param_types.iter().map(|t| t.oid()), &mut buf)
.and_then(|()| frontend::describe(b'S', &name, &mut buf))
.and_then(|()| Ok(frontend::sync(&mut buf)))
.map_err(Into::into)
.and_then(|()| self.send(buf));
PrepareFuture::new(self.sender.clone(), receiver, name)
}
fn send(&mut self, messages: Vec<u8>) -> Result<mpsc::Receiver<Message>, Error> {
let (sender, receiver) = mpsc::channel(0); let (sender, receiver) = mpsc::channel(0);
self.giver.give(); self.giver.give();
self.sender self.sender

View File

@ -7,6 +7,7 @@ use std::io;
use tokio_codec::Framed; use tokio_codec::Framed;
use want::Taker; use want::Taker;
use disconnected;
use error::{self, Error}; use error::{self, Error};
use proto::codec::PostgresCodec; use proto::codec::PostgresCodec;
use proto::socket::Socket; use proto::socket::Socket;
@ -17,7 +18,7 @@ pub struct Request {
pub sender: mpsc::Sender<Message>, pub sender: mpsc::Sender<Message>,
} }
#[derive(PartialEq)] #[derive(PartialEq, Debug)]
enum State { enum State {
Active, Active,
Terminating, Terminating,
@ -67,22 +68,28 @@ impl Connection {
fn poll_response(&mut self) -> Poll<Option<Message>, io::Error> { fn poll_response(&mut self) -> Poll<Option<Message>, io::Error> {
if let Some(message) = self.pending_response.take() { if let Some(message) = self.pending_response.take() {
trace!("retrying pending response");
return Ok(Async::Ready(Some(message))); return Ok(Async::Ready(Some(message)));
} }
self.stream.poll() self.stream.poll()
} }
fn poll_read(&mut self) -> Result<(), Error> { fn poll_read(&mut self) -> Poll<(), Error> {
if self.state != State::Active {
trace!("poll_read: done");
return Ok(Async::Ready(()));
}
loop { loop {
let message = match self.poll_response()? { let message = match self.poll_response()? {
Async::Ready(Some(message)) => message, Async::Ready(Some(message)) => message,
Async::Ready(None) => { Async::Ready(None) => {
return Err(Error::from(io::Error::from(io::ErrorKind::UnexpectedEof))); return Err(disconnected());
} }
Async::NotReady => { Async::NotReady => {
self.taker.want(); trace!("poll_read: waiting on response");
return Ok(()); return Ok(Async::NotReady);
} }
}; };
@ -123,7 +130,8 @@ impl Connection {
Ok(AsyncSink::NotReady(message)) => { Ok(AsyncSink::NotReady(message)) => {
self.responses.push_front(sender); self.responses.push_front(sender);
self.pending_response = Some(message); self.pending_response = Some(message);
return Ok(()); trace!("poll_read: waiting on socket");
return Ok(Async::NotReady);
} }
} }
} }
@ -131,57 +139,86 @@ impl Connection {
fn poll_request(&mut self) -> Poll<Option<Vec<u8>>, Error> { fn poll_request(&mut self) -> Poll<Option<Vec<u8>>, Error> {
if let Some(message) = self.pending_request.take() { if let Some(message) = self.pending_request.take() {
trace!("retrying pending request");
return Ok(Async::Ready(Some(message))); return Ok(Async::Ready(Some(message)));
} }
match self.receiver.poll() { match try_receive!(self.receiver.poll()) {
Ok(Async::Ready(Some(request))) => { Some(request) => {
trace!("polled new request");
self.responses.push_back(request.sender); self.responses.push_back(request.sender);
Ok(Async::Ready(Some(request.messages))) Ok(Async::Ready(Some(request.messages)))
} }
Ok(Async::Ready(None)) => Ok(Async::Ready(None)), None => Ok(Async::Ready(None)),
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(()) => unreachable!("mpsc::Receiver doesn't return errors"),
} }
} }
fn poll_write(&mut self) -> Result<(), Error> { fn poll_write(&mut self) -> Poll<(), Error> {
loop { loop {
if self.state == State::Closing {
trace!("poll_write: done");
return Ok(Async::Ready(()));
}
let request = match self.poll_request()? { let request = match self.poll_request()? {
Async::Ready(Some(request)) => request, Async::Ready(Some(request)) => request,
Async::Ready(None) if self.responses.is_empty() && self.state == State::Active => { Async::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
trace!("poll_write: at eof, terminating");
self.state = State::Terminating; self.state = State::Terminating;
let mut request = vec![]; let mut request = vec![];
frontend::terminate(&mut request); frontend::terminate(&mut request);
request request
} }
Async::Ready(None) => return Ok(()), Async::Ready(None) => {
Async::NotReady => return Ok(()), trace!(
"poll_write: at eof, pending responses {}",
self.responses.len(),
);
return Ok(Async::Ready(()));
}
Async::NotReady => {
trace!("poll_write: waiting on request");
self.taker.want();
return Ok(Async::NotReady);
}
}; };
match self.stream.start_send(request)? { match self.stream.start_send(request)? {
AsyncSink::Ready => { AsyncSink::Ready => {
if self.state == State::Terminating { if self.state == State::Terminating {
trace!("poll_write: sent eof, closing");
self.state = State::Closing; self.state = State::Closing;
} }
} }
AsyncSink::NotReady(request) => { AsyncSink::NotReady(request) => {
trace!("poll_write: waiting on socket");
self.pending_request = Some(request); self.pending_request = Some(request);
return Ok(()); return Ok(Async::NotReady);
} }
} }
} }
} }
fn poll_flush(&mut self) -> Result<(), Error> { fn poll_flush(&mut self) -> Poll<(), Error> {
self.stream.poll_complete()?; trace!("flushing");
Ok(()) self.stream.poll_complete().map_err(Into::into)
} }
fn poll_shutdown(&mut self) -> Poll<(), Error> { fn poll_shutdown(&mut self) -> Poll<(), Error> {
match self.state { if self.state != State::Closing {
State::Active | State::Terminating => Ok(Async::NotReady), return Ok(Async::NotReady);
State::Closing => self.stream.close().map_err(Into::into), }
match self.stream.close() {
Ok(Async::Ready(())) => {
trace!("poll_shutdown: complete");
Ok(Async::Ready(()))
}
Ok(Async::NotReady) => {
trace!("poll_shutdown: waiting on socket");
Ok(Async::NotReady)
}
Err(e) => Err(Error::from(e)),
} }
} }
} }

View File

@ -1,3 +1,13 @@
macro_rules! try_receive {
($e:expr) => {
match $e {
Ok(::futures::Async::Ready(v)) => v,
Ok(::futures::Async::NotReady) => return Ok(::futures::Async::NotReady),
Err(()) => unreachable!("mpsc::Receiver doesn't return errors"),
}
};
}
mod client; mod client;
mod codec; mod codec;
mod connection; mod connection;
@ -12,4 +22,4 @@ pub use proto::connection::Connection;
pub use proto::handshake::HandshakeFuture; pub use proto::handshake::HandshakeFuture;
pub use proto::prepare::PrepareFuture; pub use proto::prepare::PrepareFuture;
pub use proto::socket::Socket; pub use proto::socket::Socket;
pub use statement::Statement; pub use proto::statement::Statement;

View File

@ -1,20 +1,168 @@
use fallible_iterator::FallibleIterator;
use futures::sync::mpsc; use futures::sync::mpsc;
use postgres_protocol::message::backend::Message; use futures::{Poll, Stream};
use postgres_protocol::message::backend::{Message, ParameterDescriptionBody, RowDescriptionBody};
use state_machine_future::RentToOwn;
use error::Error; use error::{self, Error};
use proto::connection::Request; use proto::connection::Request;
use proto::statement::Statement; use proto::statement::Statement;
use types::Type;
use Column;
use {bad_response, disconnected};
#[derive(StateMachineFuture)] #[derive(StateMachineFuture)]
pub enum Prepare { pub enum Prepare {
#[state_machine_future(start)] #[state_machine_future(start, transitions(ReadParseComplete))]
Start { Start {
sender: mpsc::UnboundedSender<Request>, sender: mpsc::UnboundedSender<Request>,
receiver: Result<mpsc::Receiver<Message>, Error>, receiver: Result<mpsc::Receiver<Message>, Error>,
name: String, name: String,
}, },
#[state_machine_future(transitions(ReadParameterDescription))]
ReadParseComplete {
sender: mpsc::UnboundedSender<Request>,
receiver: mpsc::Receiver<Message>,
name: String,
},
#[state_machine_future(transitions(ReadRowDescription))]
ReadParameterDescription {
sender: mpsc::UnboundedSender<Request>,
receiver: mpsc::Receiver<Message>,
name: String,
},
#[state_machine_future(transitions(ReadReadyForQuery))]
ReadRowDescription {
sender: mpsc::UnboundedSender<Request>,
receiver: mpsc::Receiver<Message>,
name: String,
parameters: ParameterDescriptionBody,
},
#[state_machine_future(transitions(Finished))]
ReadReadyForQuery {
sender: mpsc::UnboundedSender<Request>,
receiver: mpsc::Receiver<Message>,
name: String,
parameters: ParameterDescriptionBody,
columns: RowDescriptionBody,
},
#[state_machine_future(ready)] #[state_machine_future(ready)]
Finished(Statement), Finished(Statement),
#[state_machine_future(error)] #[state_machine_future(error)]
Failed(Error), Failed(Error),
} }
impl PollPrepare for Prepare {
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll<AfterStart, Error> {
let state = state.take();
let receiver = state.receiver?;
transition!(ReadParseComplete {
sender: state.sender,
receiver,
name: state.name,
})
}
fn poll_read_parse_complete<'a>(
state: &'a mut RentToOwn<'a, ReadParseComplete>,
) -> Poll<AfterReadParseComplete, Error> {
let message = try_receive!(state.receiver.poll());
let state = state.take();
match message {
Some(Message::ParseComplete) => transition!(ReadParameterDescription {
sender: state.sender,
receiver: state.receiver,
name: state.name,
}),
Some(Message::ErrorResponse(body)) => Err(error::__db(body)),
Some(_) => Err(bad_response()),
None => Err(disconnected()),
}
}
fn poll_read_parameter_description<'a>(
state: &'a mut RentToOwn<'a, ReadParameterDescription>,
) -> Poll<AfterReadParameterDescription, Error> {
let message = try_receive!(state.receiver.poll());
let state = state.take();
match message {
Some(Message::ParameterDescription(body)) => transition!(ReadRowDescription {
sender: state.sender,
receiver: state.receiver,
name: state.name,
parameters: body,
}),
Some(Message::ErrorResponse(body)) => Err(error::__db(body)),
Some(_) => Err(bad_response()),
None => Err(disconnected()),
}
}
fn poll_read_row_description<'a>(
state: &'a mut RentToOwn<'a, ReadRowDescription>,
) -> Poll<AfterReadRowDescription, Error> {
let message = try_receive!(state.receiver.poll());
let state = state.take();
match message {
Some(Message::RowDescription(body)) => transition!(ReadReadyForQuery {
sender: state.sender,
receiver: state.receiver,
name: state.name,
parameters: state.parameters,
columns: body,
}),
Some(Message::ErrorResponse(body)) => Err(error::__db(body)),
Some(_) => Err(bad_response()),
None => Err(disconnected()),
}
}
fn poll_read_ready_for_query<'a>(
state: &'a mut RentToOwn<'a, ReadReadyForQuery>,
) -> Poll<AfterReadReadyForQuery, Error> {
let message = try_receive!(state.receiver.poll());
let state = state.take();
match message {
Some(Message::ReadyForQuery(_)) => {
// FIXME handle custom types
let parameters = state
.parameters
.parameters()
.map(|oid| Type::from_oid(oid).unwrap())
.collect()?;
let columns = state
.columns
.fields()
.map(|f| {
Column::new(f.name().to_string(), Type::from_oid(f.type_oid()).unwrap())
})
.collect()?;
transition!(Finished(Statement::new(
state.sender,
state.name,
parameters,
columns
)))
}
Some(Message::ErrorResponse(body)) => Err(error::__db(body)),
Some(_) => Err(bad_response()),
None => Err(disconnected()),
}
}
}
impl PrepareFuture {
pub fn new(
sender: mpsc::UnboundedSender<Request>,
receiver: Result<mpsc::Receiver<Message>, Error>,
name: String,
) -> PrepareFuture {
Prepare::start(sender, receiver, name)
}
}

View File

@ -16,8 +16,9 @@ impl Drop for Statement {
fn drop(&mut self) { fn drop(&mut self) {
let mut buf = vec![]; let mut buf = vec![];
frontend::close(b'S', &self.name, &mut buf).expect("statement name not valid"); frontend::close(b'S', &self.name, &mut buf).expect("statement name not valid");
frontend::sync(&mut buf);
let (sender, _) = mpsc::channel(0); let (sender, _) = mpsc::channel(0);
self.sender.unbounded_send(Request { let _ = self.sender.unbounded_send(Request {
messages: buf, messages: buf,
sender, sender,
}); });
@ -26,7 +27,7 @@ impl Drop for Statement {
impl Statement { impl Statement {
pub fn new( pub fn new(
sender: mpsc::UnboundedReceiver<Request>, sender: mpsc::UnboundedSender<Request>,
name: String, name: String,
params: Vec<Type>, params: Vec<Type>,
columns: Vec<Column>, columns: Vec<Column>,

View File

@ -0,0 +1,37 @@
extern crate env_logger;
extern crate tokio;
extern crate tokio_postgres;
use tokio::prelude::*;
use tokio::runtime::current_thread::Runtime;
use tokio_postgres::types::Type;
#[test]
fn pipelined_prepare() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake = tokio_postgres::connect("postgres://postgres@localhost:5433".parse().unwrap());
let (mut client, connection) = runtime.block_on(handshake).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
let prepare1 = client.prepare("SELECT 1::BIGINT WHERE $1::BOOL");
let prepare2 = client.prepare("SELECT ''::TEXT, 1::FLOAT4 WHERE $1::VARCHAR IS NOT NULL");
let prepare = prepare1.join(prepare2);
let (statement1, statement2) = runtime.block_on(prepare).unwrap();
assert_eq!(statement1.params(), &[Type::BOOL]);
assert_eq!(statement1.columns().len(), 1);
assert_eq!(statement1.columns()[0].type_(), &Type::INT8);
assert_eq!(statement2.params(), &[Type::VARCHAR]);
assert_eq!(statement2.columns().len(), 2);
assert_eq!(statement2.columns()[0].type_(), &Type::TEXT);
assert_eq!(statement2.columns()[1].type_(), &Type::FLOAT4);
drop(statement1);
drop(statement2);
drop(client);
runtime.run().unwrap();
}