From f9e46510baf9955ab136ca062332e24a3b03ef74 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Tue, 23 Jul 2019 19:54:22 -0700 Subject: [PATCH] Implement prepare --- tokio-postgres/Cargo.toml | 2 +- tokio-postgres/src/client.rs | 77 +++++++++++++++++++++++++++++-- tokio-postgres/src/lib.rs | 3 ++ tokio-postgres/src/prepare.rs | 71 ++++++++++++++++++++++++++++ tokio-postgres/src/statement.rs | 59 +++++++++++++++++++++++ tokio-postgres/tests/test/main.rs | 19 +++++++- 6 files changed, 226 insertions(+), 5 deletions(-) create mode 100644 tokio-postgres/src/prepare.rs create mode 100644 tokio-postgres/src/statement.rs diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 7fbf32af..405517b9 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -34,7 +34,7 @@ with-serde_json-1 = ["serde-1", "serde_json-1"] antidote = "1.0" bytes = "0.4" fallible-iterator = "0.2" -futures-preview = "0.3.0-alpha.17" +futures-preview = { version = "0.3.0-alpha.17", features = ["nightly", "async-await"] } log = "0.4" percent-encoding = "1.0" phf = "0.7.23" diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 3bfd7e12..46fb60c4 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -1,8 +1,60 @@ -use crate::connection::Request; +use crate::codec::BackendMessages; +use crate::connection::{Request, RequestMessages}; +use crate::prepare::prepare; +use crate::types::Type; +use crate::{Error, Statement}; +use fallible_iterator::FallibleIterator; use futures::channel::mpsc; +use futures::{Stream, StreamExt}; +use postgres_protocol::message::backend::Message; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +pub struct Responses { + receiver: mpsc::Receiver, + cur: BackendMessages, +} + +impl Responses { + pub async fn next(&mut self) -> Result { + loop { + match self.cur.next().map_err(Error::parse)? { + Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), + Some(message) => return Ok(message), + None => {} + } + + match self.receiver.next().await { + Some(messages) => self.cur = messages, + None => return Err(Error::closed()), + } + } + } +} + +pub struct InnerClient { + sender: mpsc::UnboundedSender, +} + +impl InnerClient { + pub fn send(&self, messages: RequestMessages) -> Result { + let (sender, receiver) = mpsc::channel(1); + let request = Request { messages, sender }; + self.sender + .unbounded_send(request) + .map_err(|_| Error::closed())?; + + Ok(Responses { + receiver, + cur: BackendMessages::empty(), + }) + } +} pub struct Client { - sender: mpsc::UnboundedSender, + inner: Arc, process_id: i32, secret_key: i32, } @@ -14,9 +66,28 @@ impl Client { secret_key: i32, ) -> Client { Client { - sender, + inner: Arc::new(InnerClient { sender }), process_id, secret_key, } } + + pub(crate) fn inner(&self) -> Arc { + self.inner.clone() + } + + pub fn prepare<'a>( + &mut self, + query: &'a str, + ) -> impl Future> + 'a { + self.prepare_typed(query, &[]) + } + + pub fn prepare_typed<'a>( + &mut self, + query: &'a str, + parameter_types: &'a [Type], + ) -> impl Future> + 'a { + prepare(self.inner(), query, parameter_types) + } } diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 68b03e6d..e894b2e2 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -122,6 +122,7 @@ pub use crate::socket::Socket; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; pub use crate::tls::NoTls; +pub use statement::{Column, Statement}; mod client; mod codec; @@ -135,8 +136,10 @@ mod connect_tls; mod connection; pub mod error; mod maybe_tls_stream; +mod prepare; #[cfg(feature = "runtime")] mod socket; +mod statement; pub mod tls; pub mod types; diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs new file mode 100644 index 00000000..9e243c63 --- /dev/null +++ b/tokio-postgres/src/prepare.rs @@ -0,0 +1,71 @@ +use crate::client::InnerClient; +use crate::codec::FrontendMessage; +use crate::connection::{Request, RequestMessages}; +use crate::types::{Oid, Type}; +use crate::{Column, Error, Statement}; +use fallible_iterator::FallibleIterator; +use futures::StreamExt; +use postgres_protocol::message::backend::Message; +use postgres_protocol::message::frontend; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +static NEXT_ID: AtomicUsize = AtomicUsize::new(0); + +pub async fn prepare( + client: Arc, + query: &str, + types: &[Type], +) -> Result { + let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst)); + + let mut buf = vec![]; + frontend::parse(&name, query, types.iter().map(Type::oid), &mut buf).map_err(Error::encode)?; + frontend::describe(b'S', &name, &mut buf).map_err(Error::encode)?; + frontend::sync(&mut buf); + + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + + let parameter_description = match responses.next().await? { + Message::ParameterDescription(body) => body, + _ => return Err(Error::unexpected_message()), + }; + + let row_description = match responses.next().await? { + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + let mut parameters = vec![]; + let mut it = parameter_description.parameters(); + while let Some(oid) = it.next().map_err(Error::parse)? { + let type_ = get_type(&client, oid).await?; + parameters.push(type_); + } + + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? { + let type_ = get_type(&client, field.type_oid()).await?; + let column = Column::new(field.name().to_string(), type_); + columns.push(column); + } + } + + Ok(Statement::new(&client, name, parameters, columns)) +} + +async fn get_type(client: &InnerClient, oid: Oid) -> Result { + if let Some(type_) = Type::from_oid(oid) { + return Ok(type_); + } + + unimplemented!() +} diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs new file mode 100644 index 00000000..b18592eb --- /dev/null +++ b/tokio-postgres/src/statement.rs @@ -0,0 +1,59 @@ +use crate::client::InnerClient; +use crate::connection::Request; +use crate::types::Type; +use std::sync::{Arc, Weak}; + +pub struct Statement { + client: Weak, + name: String, + params: Vec, + columns: Vec, +} + +impl Statement { + pub(crate) fn new( + inner: &Arc, + name: String, + params: Vec, + columns: Vec, + ) -> Statement { + Statement { + client: Arc::downgrade(inner), + name, + params, + columns, + } + } + + /// Returns the expected types of the statement's parameters. + pub fn params(&self) -> &[Type] { + &self.params + } + + /// Returns information about the columns returned when the statement is queried. + pub fn columns(&self) -> &[Column] { + &self.columns + } +} + +#[derive(Debug)] +pub struct Column { + name: String, + type_: Type, +} + +impl Column { + pub(crate) fn new(name: String, type_: Type) -> Column { + Column { name, type_ } + } + + /// Returns the name of the column. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the type of the column. + pub fn type_(&self) -> &Type { + &self.type_ + } +} diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 4c928d52..a31bef4a 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -1,11 +1,12 @@ #![warn(rust_2018_idioms)] #![feature(async_await)] -use futures::FutureExt; +use futures::{try_join, FutureExt}; use tokio::net::TcpStream; use tokio_postgres::error::SqlState; use tokio_postgres::tls::{NoTls, NoTlsStream}; use tokio_postgres::{Client, Config, Connection, Error}; +use tokio_postgres::types::Type; mod parse; #[cfg(feature = "runtime")] @@ -95,6 +96,22 @@ async fn scram_password_ok() { connect("user=scram_user password=password dbname=postgres").await; } +#[tokio::test] +async fn pipelined_prepare() { + let mut client = connect("user=postgres").await; + + let prepare1 = client.prepare("SELECT $1::TEXT"); + let prepare2 = client.prepare("SELECT $1::BIGINT"); + + let (statement1, statement2) = try_join!(prepare1, prepare2).unwrap(); + + assert_eq!(statement1.params()[0], Type::TEXT); + assert_eq!(statement1.columns()[0].type_(), &Type::TEXT); + + assert_eq!(statement2.params()[0], Type::INT8); + assert_eq!(statement2.columns()[0].type_(), &Type::INT8); +} + /* #[test] fn pipelined_prepare() {