Working statement preparation

This commit is contained in:
Steven Fackler 2018-06-18 22:34:25 -04:00
parent 0d0435fc2e
commit 13fcea7ae2
10 changed files with 336 additions and 32 deletions

View File

@ -14,3 +14,5 @@ tokio-uds = { git = "https://github.com/sfackler/tokio" }
tokio-io = { git = "https://github.com/sfackler/tokio" }
tokio-timer = { git = "https://github.com/sfackler/tokio" }
tokio-codec = { git = "https://github.com/sfackler/tokio" }
tokio-reactor = { git = "https://github.com/sfackler/tokio" }
tokio-executor = { git = "https://github.com/sfackler/tokio" }

View File

@ -2,8 +2,10 @@
use std::error::Error;
use std::mem;
use std::path::PathBuf;
use std::str::FromStr;
use std::time::Duration;
use error;
use params::url::Url;
mod url;
@ -96,6 +98,14 @@ impl ConnectParams {
}
}
impl FromStr for ConnectParams {
type Err = error::Error;
fn from_str(s: &str) -> Result<ConnectParams, error::Error> {
s.into_connect_params().map_err(error::connect)
}
}
/// A builder for `ConnectParams`.
pub struct Builder {
port: u16,

View File

@ -37,6 +37,7 @@ fallible-iterator = "0.1.3"
futures = "0.1.7"
futures-cpupool = "0.1"
lazy_static = "1.0"
log = "0.4"
postgres-protocol = { version = "0.3.0", path = "../postgres-protocol" }
postgres-shared = { version = "0.4.0", path = "../postgres-shared" }
state_machine_future = "0.1.7"
@ -48,3 +49,7 @@ want = "0.0.5"
[target.'cfg(unix)'.dependencies]
tokio-uds = "0.2"
[dev-dependencies]
tokio = "0.1.7"
env_logger = "0.5"

View File

@ -14,6 +14,8 @@ extern crate futures;
#[macro_use]
extern crate lazy_static;
#[macro_use]
extern crate log;
#[macro_use]
extern crate state_machine_future;
#[cfg(unix)]
@ -21,6 +23,7 @@ extern crate tokio_uds;
use futures::{Async, Future, Poll};
use std::io;
use std::sync::atomic::{AtomicUsize, Ordering};
#[doc(inline)]
pub use postgres_shared::stmt::Column;
@ -31,9 +34,12 @@ pub use postgres_shared::{CancelData, Notification};
use error::Error;
use params::ConnectParams;
use types::Type;
mod proto;
static NEXT_STATEMENT_ID: AtomicUsize = AtomicUsize::new(0);
fn bad_response() -> Error {
Error::from(io::Error::new(
io::ErrorKind::InvalidInput,
@ -48,13 +54,13 @@ fn disconnected() -> Error {
))
}
pub fn connect(params: ConnectParams) -> Handshake {
Handshake(proto::HandshakeFuture::new(params))
}
pub struct Client(proto::Client);
impl Client {
pub fn connect(params: ConnectParams) -> Handshake {
Handshake(proto::HandshakeFuture::new(params))
}
/// Polls to to determine whether the connection is ready to send new requests to the backend.
///
/// Requests are unboundedly buffered to enable pipelining, but this risks unbounded memory consumption if requests
@ -64,6 +70,15 @@ impl Client {
pub fn poll_ready(&mut self) -> Poll<(), Error> {
self.0.poll_ready()
}
pub fn prepare(&mut self, query: &str) -> Prepare {
self.prepare_typed(query, &[])
}
pub fn prepare_typed(&mut self, query: &str, param_types: &[Type]) -> Prepare {
let name = format!("s{}", NEXT_STATEMENT_ID.fetch_add(1, Ordering::SeqCst));
Prepare(self.0.prepare(name, query, param_types))
}
}
pub struct Connection(proto::Connection);
@ -99,3 +114,28 @@ impl Future for Handshake {
Ok(Async::Ready((Client(client), Connection(connection))))
}
}
pub struct Prepare(proto::PrepareFuture);
impl Future for Prepare {
type Item = Statement;
type Error = Error;
fn poll(&mut self) -> Poll<Statement, Error> {
let statement = try_ready!(self.0.poll());
Ok(Async::Ready(Statement(statement)))
}
}
pub struct Statement(proto::Statement);
impl Statement {
pub fn params(&self) -> &[Type] {
self.0.params()
}
pub fn columns(&self) -> &[Column] {
self.0.columns()
}
}

View File

@ -1,11 +1,14 @@
use futures::sync::mpsc;
use futures::Poll;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use want::Giver;
use disconnected;
use error::Error;
use proto::connection::Request;
use proto::prepare::PrepareFuture;
use types::Type;
pub struct Client {
sender: mpsc::UnboundedSender<Request>,
@ -21,7 +24,18 @@ impl Client {
self.giver.poll_want().map_err(|_| disconnected())
}
pub fn send(&mut self, messages: Vec<u8>) -> Result<mpsc::Receiver<Message>, Error> {
pub fn prepare(&mut self, name: String, query: &str, param_types: &[Type]) -> PrepareFuture {
let mut buf = vec![];
let receiver = frontend::parse(&name, query, param_types.iter().map(|t| t.oid()), &mut buf)
.and_then(|()| frontend::describe(b'S', &name, &mut buf))
.and_then(|()| Ok(frontend::sync(&mut buf)))
.map_err(Into::into)
.and_then(|()| self.send(buf));
PrepareFuture::new(self.sender.clone(), receiver, name)
}
fn send(&mut self, messages: Vec<u8>) -> Result<mpsc::Receiver<Message>, Error> {
let (sender, receiver) = mpsc::channel(0);
self.giver.give();
self.sender

View File

@ -7,6 +7,7 @@ use std::io;
use tokio_codec::Framed;
use want::Taker;
use disconnected;
use error::{self, Error};
use proto::codec::PostgresCodec;
use proto::socket::Socket;
@ -17,7 +18,7 @@ pub struct Request {
pub sender: mpsc::Sender<Message>,
}
#[derive(PartialEq)]
#[derive(PartialEq, Debug)]
enum State {
Active,
Terminating,
@ -67,22 +68,28 @@ impl Connection {
fn poll_response(&mut self) -> Poll<Option<Message>, io::Error> {
if let Some(message) = self.pending_response.take() {
trace!("retrying pending response");
return Ok(Async::Ready(Some(message)));
}
self.stream.poll()
}
fn poll_read(&mut self) -> Result<(), Error> {
fn poll_read(&mut self) -> Poll<(), Error> {
if self.state != State::Active {
trace!("poll_read: done");
return Ok(Async::Ready(()));
}
loop {
let message = match self.poll_response()? {
Async::Ready(Some(message)) => message,
Async::Ready(None) => {
return Err(Error::from(io::Error::from(io::ErrorKind::UnexpectedEof)));
return Err(disconnected());
}
Async::NotReady => {
self.taker.want();
return Ok(());
trace!("poll_read: waiting on response");
return Ok(Async::NotReady);
}
};
@ -123,7 +130,8 @@ impl Connection {
Ok(AsyncSink::NotReady(message)) => {
self.responses.push_front(sender);
self.pending_response = Some(message);
return Ok(());
trace!("poll_read: waiting on socket");
return Ok(Async::NotReady);
}
}
}
@ -131,57 +139,86 @@ impl Connection {
fn poll_request(&mut self) -> Poll<Option<Vec<u8>>, Error> {
if let Some(message) = self.pending_request.take() {
trace!("retrying pending request");
return Ok(Async::Ready(Some(message)));
}
match self.receiver.poll() {
Ok(Async::Ready(Some(request))) => {
match try_receive!(self.receiver.poll()) {
Some(request) => {
trace!("polled new request");
self.responses.push_back(request.sender);
Ok(Async::Ready(Some(request.messages)))
}
Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
Ok(Async::NotReady) => Ok(Async::NotReady),
Err(()) => unreachable!("mpsc::Receiver doesn't return errors"),
None => Ok(Async::Ready(None)),
}
}
fn poll_write(&mut self) -> Result<(), Error> {
fn poll_write(&mut self) -> Poll<(), Error> {
loop {
if self.state == State::Closing {
trace!("poll_write: done");
return Ok(Async::Ready(()));
}
let request = match self.poll_request()? {
Async::Ready(Some(request)) => request,
Async::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
trace!("poll_write: at eof, terminating");
self.state = State::Terminating;
let mut request = vec![];
frontend::terminate(&mut request);
request
}
Async::Ready(None) => return Ok(()),
Async::NotReady => return Ok(()),
Async::Ready(None) => {
trace!(
"poll_write: at eof, pending responses {}",
self.responses.len(),
);
return Ok(Async::Ready(()));
}
Async::NotReady => {
trace!("poll_write: waiting on request");
self.taker.want();
return Ok(Async::NotReady);
}
};
match self.stream.start_send(request)? {
AsyncSink::Ready => {
if self.state == State::Terminating {
trace!("poll_write: sent eof, closing");
self.state = State::Closing;
}
}
AsyncSink::NotReady(request) => {
trace!("poll_write: waiting on socket");
self.pending_request = Some(request);
return Ok(());
return Ok(Async::NotReady);
}
}
}
}
fn poll_flush(&mut self) -> Result<(), Error> {
self.stream.poll_complete()?;
Ok(())
fn poll_flush(&mut self) -> Poll<(), Error> {
trace!("flushing");
self.stream.poll_complete().map_err(Into::into)
}
fn poll_shutdown(&mut self) -> Poll<(), Error> {
match self.state {
State::Active | State::Terminating => Ok(Async::NotReady),
State::Closing => self.stream.close().map_err(Into::into),
if self.state != State::Closing {
return Ok(Async::NotReady);
}
match self.stream.close() {
Ok(Async::Ready(())) => {
trace!("poll_shutdown: complete");
Ok(Async::Ready(()))
}
Ok(Async::NotReady) => {
trace!("poll_shutdown: waiting on socket");
Ok(Async::NotReady)
}
Err(e) => Err(Error::from(e)),
}
}
}

View File

@ -1,3 +1,13 @@
macro_rules! try_receive {
($e:expr) => {
match $e {
Ok(::futures::Async::Ready(v)) => v,
Ok(::futures::Async::NotReady) => return Ok(::futures::Async::NotReady),
Err(()) => unreachable!("mpsc::Receiver doesn't return errors"),
}
};
}
mod client;
mod codec;
mod connection;
@ -12,4 +22,4 @@ pub use proto::connection::Connection;
pub use proto::handshake::HandshakeFuture;
pub use proto::prepare::PrepareFuture;
pub use proto::socket::Socket;
pub use statement::Statement;
pub use proto::statement::Statement;

View File

@ -1,20 +1,168 @@
use fallible_iterator::FallibleIterator;
use futures::sync::mpsc;
use postgres_protocol::message::backend::Message;
use futures::{Poll, Stream};
use postgres_protocol::message::backend::{Message, ParameterDescriptionBody, RowDescriptionBody};
use state_machine_future::RentToOwn;
use error::Error;
use error::{self, Error};
use proto::connection::Request;
use proto::statement::Statement;
use types::Type;
use Column;
use {bad_response, disconnected};
#[derive(StateMachineFuture)]
pub enum Prepare {
#[state_machine_future(start)]
#[state_machine_future(start, transitions(ReadParseComplete))]
Start {
sender: mpsc::UnboundedSender<Request>,
receiver: Result<mpsc::Receiver<Message>, Error>,
name: String,
},
#[state_machine_future(transitions(ReadParameterDescription))]
ReadParseComplete {
sender: mpsc::UnboundedSender<Request>,
receiver: mpsc::Receiver<Message>,
name: String,
},
#[state_machine_future(transitions(ReadRowDescription))]
ReadParameterDescription {
sender: mpsc::UnboundedSender<Request>,
receiver: mpsc::Receiver<Message>,
name: String,
},
#[state_machine_future(transitions(ReadReadyForQuery))]
ReadRowDescription {
sender: mpsc::UnboundedSender<Request>,
receiver: mpsc::Receiver<Message>,
name: String,
parameters: ParameterDescriptionBody,
},
#[state_machine_future(transitions(Finished))]
ReadReadyForQuery {
sender: mpsc::UnboundedSender<Request>,
receiver: mpsc::Receiver<Message>,
name: String,
parameters: ParameterDescriptionBody,
columns: RowDescriptionBody,
},
#[state_machine_future(ready)]
Finished(Statement),
#[state_machine_future(error)]
Failed(Error),
}
impl PollPrepare for Prepare {
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll<AfterStart, Error> {
let state = state.take();
let receiver = state.receiver?;
transition!(ReadParseComplete {
sender: state.sender,
receiver,
name: state.name,
})
}
fn poll_read_parse_complete<'a>(
state: &'a mut RentToOwn<'a, ReadParseComplete>,
) -> Poll<AfterReadParseComplete, Error> {
let message = try_receive!(state.receiver.poll());
let state = state.take();
match message {
Some(Message::ParseComplete) => transition!(ReadParameterDescription {
sender: state.sender,
receiver: state.receiver,
name: state.name,
}),
Some(Message::ErrorResponse(body)) => Err(error::__db(body)),
Some(_) => Err(bad_response()),
None => Err(disconnected()),
}
}
fn poll_read_parameter_description<'a>(
state: &'a mut RentToOwn<'a, ReadParameterDescription>,
) -> Poll<AfterReadParameterDescription, Error> {
let message = try_receive!(state.receiver.poll());
let state = state.take();
match message {
Some(Message::ParameterDescription(body)) => transition!(ReadRowDescription {
sender: state.sender,
receiver: state.receiver,
name: state.name,
parameters: body,
}),
Some(Message::ErrorResponse(body)) => Err(error::__db(body)),
Some(_) => Err(bad_response()),
None => Err(disconnected()),
}
}
fn poll_read_row_description<'a>(
state: &'a mut RentToOwn<'a, ReadRowDescription>,
) -> Poll<AfterReadRowDescription, Error> {
let message = try_receive!(state.receiver.poll());
let state = state.take();
match message {
Some(Message::RowDescription(body)) => transition!(ReadReadyForQuery {
sender: state.sender,
receiver: state.receiver,
name: state.name,
parameters: state.parameters,
columns: body,
}),
Some(Message::ErrorResponse(body)) => Err(error::__db(body)),
Some(_) => Err(bad_response()),
None => Err(disconnected()),
}
}
fn poll_read_ready_for_query<'a>(
state: &'a mut RentToOwn<'a, ReadReadyForQuery>,
) -> Poll<AfterReadReadyForQuery, Error> {
let message = try_receive!(state.receiver.poll());
let state = state.take();
match message {
Some(Message::ReadyForQuery(_)) => {
// FIXME handle custom types
let parameters = state
.parameters
.parameters()
.map(|oid| Type::from_oid(oid).unwrap())
.collect()?;
let columns = state
.columns
.fields()
.map(|f| {
Column::new(f.name().to_string(), Type::from_oid(f.type_oid()).unwrap())
})
.collect()?;
transition!(Finished(Statement::new(
state.sender,
state.name,
parameters,
columns
)))
}
Some(Message::ErrorResponse(body)) => Err(error::__db(body)),
Some(_) => Err(bad_response()),
None => Err(disconnected()),
}
}
}
impl PrepareFuture {
pub fn new(
sender: mpsc::UnboundedSender<Request>,
receiver: Result<mpsc::Receiver<Message>, Error>,
name: String,
) -> PrepareFuture {
Prepare::start(sender, receiver, name)
}
}

View File

@ -16,8 +16,9 @@ impl Drop for Statement {
fn drop(&mut self) {
let mut buf = vec![];
frontend::close(b'S', &self.name, &mut buf).expect("statement name not valid");
frontend::sync(&mut buf);
let (sender, _) = mpsc::channel(0);
self.sender.unbounded_send(Request {
let _ = self.sender.unbounded_send(Request {
messages: buf,
sender,
});
@ -26,7 +27,7 @@ impl Drop for Statement {
impl Statement {
pub fn new(
sender: mpsc::UnboundedReceiver<Request>,
sender: mpsc::UnboundedSender<Request>,
name: String,
params: Vec<Type>,
columns: Vec<Column>,

View File

@ -0,0 +1,37 @@
extern crate env_logger;
extern crate tokio;
extern crate tokio_postgres;
use tokio::prelude::*;
use tokio::runtime::current_thread::Runtime;
use tokio_postgres::types::Type;
#[test]
fn pipelined_prepare() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake = tokio_postgres::connect("postgres://postgres@localhost:5433".parse().unwrap());
let (mut client, connection) = runtime.block_on(handshake).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
let prepare1 = client.prepare("SELECT 1::BIGINT WHERE $1::BOOL");
let prepare2 = client.prepare("SELECT ''::TEXT, 1::FLOAT4 WHERE $1::VARCHAR IS NOT NULL");
let prepare = prepare1.join(prepare2);
let (statement1, statement2) = runtime.block_on(prepare).unwrap();
assert_eq!(statement1.params(), &[Type::BOOL]);
assert_eq!(statement1.columns().len(), 1);
assert_eq!(statement1.columns()[0].type_(), &Type::INT8);
assert_eq!(statement2.params(), &[Type::VARCHAR]);
assert_eq!(statement2.columns().len(), 2);
assert_eq!(statement2.columns()[0].type_(), &Type::TEXT);
assert_eq!(statement2.columns()[1].type_(), &Type::FLOAT4);
drop(statement1);
drop(statement2);
drop(client);
runtime.run().unwrap();
}