Implement prepare

This commit is contained in:
Steven Fackler 2019-07-23 19:54:22 -07:00
parent 2480fefd2c
commit f9e46510ba
6 changed files with 226 additions and 5 deletions

View File

@ -34,7 +34,7 @@ with-serde_json-1 = ["serde-1", "serde_json-1"]
antidote = "1.0" antidote = "1.0"
bytes = "0.4" bytes = "0.4"
fallible-iterator = "0.2" fallible-iterator = "0.2"
futures-preview = "0.3.0-alpha.17" futures-preview = { version = "0.3.0-alpha.17", features = ["nightly", "async-await"] }
log = "0.4" log = "0.4"
percent-encoding = "1.0" percent-encoding = "1.0"
phf = "0.7.23" phf = "0.7.23"

View File

@ -1,8 +1,60 @@
use crate::connection::Request; use crate::codec::BackendMessages;
use crate::connection::{Request, RequestMessages};
use crate::prepare::prepare;
use crate::types::Type;
use crate::{Error, Statement};
use fallible_iterator::FallibleIterator;
use futures::channel::mpsc; use futures::channel::mpsc;
use futures::{Stream, StreamExt};
use postgres_protocol::message::backend::Message;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
pub struct Responses {
receiver: mpsc::Receiver<BackendMessages>,
cur: BackendMessages,
}
impl Responses {
pub async fn next(&mut self) -> Result<Message, Error> {
loop {
match self.cur.next().map_err(Error::parse)? {
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(message) => return Ok(message),
None => {}
}
match self.receiver.next().await {
Some(messages) => self.cur = messages,
None => return Err(Error::closed()),
}
}
}
}
pub struct InnerClient {
sender: mpsc::UnboundedSender<Request>,
}
impl InnerClient {
pub fn send(&self, messages: RequestMessages) -> Result<Responses, Error> {
let (sender, receiver) = mpsc::channel(1);
let request = Request { messages, sender };
self.sender
.unbounded_send(request)
.map_err(|_| Error::closed())?;
Ok(Responses {
receiver,
cur: BackendMessages::empty(),
})
}
}
pub struct Client { pub struct Client {
sender: mpsc::UnboundedSender<Request>, inner: Arc<InnerClient>,
process_id: i32, process_id: i32,
secret_key: i32, secret_key: i32,
} }
@ -14,9 +66,28 @@ impl Client {
secret_key: i32, secret_key: i32,
) -> Client { ) -> Client {
Client { Client {
sender, inner: Arc::new(InnerClient { sender }),
process_id, process_id,
secret_key, secret_key,
} }
} }
pub(crate) fn inner(&self) -> Arc<InnerClient> {
self.inner.clone()
}
pub fn prepare<'a>(
&mut self,
query: &'a str,
) -> impl Future<Output = Result<Statement, Error>> + 'a {
self.prepare_typed(query, &[])
}
pub fn prepare_typed<'a>(
&mut self,
query: &'a str,
parameter_types: &'a [Type],
) -> impl Future<Output = Result<Statement, Error>> + 'a {
prepare(self.inner(), query, parameter_types)
}
} }

View File

@ -122,6 +122,7 @@ pub use crate::socket::Socket;
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
use crate::tls::MakeTlsConnect; use crate::tls::MakeTlsConnect;
pub use crate::tls::NoTls; pub use crate::tls::NoTls;
pub use statement::{Column, Statement};
mod client; mod client;
mod codec; mod codec;
@ -135,8 +136,10 @@ mod connect_tls;
mod connection; mod connection;
pub mod error; pub mod error;
mod maybe_tls_stream; mod maybe_tls_stream;
mod prepare;
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
mod socket; mod socket;
mod statement;
pub mod tls; pub mod tls;
pub mod types; pub mod types;

View File

@ -0,0 +1,71 @@
use crate::client::InnerClient;
use crate::codec::FrontendMessage;
use crate::connection::{Request, RequestMessages};
use crate::types::{Oid, Type};
use crate::{Column, Error, Statement};
use fallible_iterator::FallibleIterator;
use futures::StreamExt;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
pub async fn prepare(
client: Arc<InnerClient>,
query: &str,
types: &[Type],
) -> Result<Statement, Error> {
let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
let mut buf = vec![];
frontend::parse(&name, query, types.iter().map(Type::oid), &mut buf).map_err(Error::encode)?;
frontend::describe(b'S', &name, &mut buf).map_err(Error::encode)?;
frontend::sync(&mut buf);
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {
Message::ParseComplete => {}
_ => return Err(Error::unexpected_message()),
}
let parameter_description = match responses.next().await? {
Message::ParameterDescription(body) => body,
_ => return Err(Error::unexpected_message()),
};
let row_description = match responses.next().await? {
Message::RowDescription(body) => Some(body),
Message::NoData => None,
_ => return Err(Error::unexpected_message()),
};
let mut parameters = vec![];
let mut it = parameter_description.parameters();
while let Some(oid) = it.next().map_err(Error::parse)? {
let type_ = get_type(&client, oid).await?;
parameters.push(type_);
}
let mut columns = vec![];
if let Some(row_description) = row_description {
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(Error::parse)? {
let type_ = get_type(&client, field.type_oid()).await?;
let column = Column::new(field.name().to_string(), type_);
columns.push(column);
}
}
Ok(Statement::new(&client, name, parameters, columns))
}
async fn get_type(client: &InnerClient, oid: Oid) -> Result<Type, Error> {
if let Some(type_) = Type::from_oid(oid) {
return Ok(type_);
}
unimplemented!()
}

View File

@ -0,0 +1,59 @@
use crate::client::InnerClient;
use crate::connection::Request;
use crate::types::Type;
use std::sync::{Arc, Weak};
pub struct Statement {
client: Weak<InnerClient>,
name: String,
params: Vec<Type>,
columns: Vec<Column>,
}
impl Statement {
pub(crate) fn new(
inner: &Arc<InnerClient>,
name: String,
params: Vec<Type>,
columns: Vec<Column>,
) -> Statement {
Statement {
client: Arc::downgrade(inner),
name,
params,
columns,
}
}
/// Returns the expected types of the statement's parameters.
pub fn params(&self) -> &[Type] {
&self.params
}
/// Returns information about the columns returned when the statement is queried.
pub fn columns(&self) -> &[Column] {
&self.columns
}
}
#[derive(Debug)]
pub struct Column {
name: String,
type_: Type,
}
impl Column {
pub(crate) fn new(name: String, type_: Type) -> Column {
Column { name, type_ }
}
/// Returns the name of the column.
pub fn name(&self) -> &str {
&self.name
}
/// Returns the type of the column.
pub fn type_(&self) -> &Type {
&self.type_
}
}

View File

@ -1,11 +1,12 @@
#![warn(rust_2018_idioms)] #![warn(rust_2018_idioms)]
#![feature(async_await)] #![feature(async_await)]
use futures::FutureExt; use futures::{try_join, FutureExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_postgres::error::SqlState; use tokio_postgres::error::SqlState;
use tokio_postgres::tls::{NoTls, NoTlsStream}; use tokio_postgres::tls::{NoTls, NoTlsStream};
use tokio_postgres::{Client, Config, Connection, Error}; use tokio_postgres::{Client, Config, Connection, Error};
use tokio_postgres::types::Type;
mod parse; mod parse;
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
@ -95,6 +96,22 @@ async fn scram_password_ok() {
connect("user=scram_user password=password dbname=postgres").await; connect("user=scram_user password=password dbname=postgres").await;
} }
#[tokio::test]
async fn pipelined_prepare() {
let mut client = connect("user=postgres").await;
let prepare1 = client.prepare("SELECT $1::TEXT");
let prepare2 = client.prepare("SELECT $1::BIGINT");
let (statement1, statement2) = try_join!(prepare1, prepare2).unwrap();
assert_eq!(statement1.params()[0], Type::TEXT);
assert_eq!(statement1.columns()[0].type_(), &Type::TEXT);
assert_eq!(statement2.params()[0], Type::INT8);
assert_eq!(statement2.columns()[0].type_(), &Type::INT8);
}
/* /*
#[test] #[test]
fn pipelined_prepare() { fn pipelined_prepare() {