Implement prepare
This commit is contained in:
parent
2480fefd2c
commit
f9e46510ba
@ -34,7 +34,7 @@ with-serde_json-1 = ["serde-1", "serde_json-1"]
|
||||
antidote = "1.0"
|
||||
bytes = "0.4"
|
||||
fallible-iterator = "0.2"
|
||||
futures-preview = "0.3.0-alpha.17"
|
||||
futures-preview = { version = "0.3.0-alpha.17", features = ["nightly", "async-await"] }
|
||||
log = "0.4"
|
||||
percent-encoding = "1.0"
|
||||
phf = "0.7.23"
|
||||
|
@ -1,8 +1,60 @@
|
||||
use crate::connection::Request;
|
||||
use crate::codec::BackendMessages;
|
||||
use crate::connection::{Request, RequestMessages};
|
||||
use crate::prepare::prepare;
|
||||
use crate::types::Type;
|
||||
use crate::{Error, Statement};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures::channel::mpsc;
|
||||
use futures::{Stream, StreamExt};
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
pub struct Responses {
|
||||
receiver: mpsc::Receiver<BackendMessages>,
|
||||
cur: BackendMessages,
|
||||
}
|
||||
|
||||
impl Responses {
|
||||
pub async fn next(&mut self) -> Result<Message, Error> {
|
||||
loop {
|
||||
match self.cur.next().map_err(Error::parse)? {
|
||||
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
|
||||
Some(message) => return Ok(message),
|
||||
None => {}
|
||||
}
|
||||
|
||||
match self.receiver.next().await {
|
||||
Some(messages) => self.cur = messages,
|
||||
None => return Err(Error::closed()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct InnerClient {
|
||||
sender: mpsc::UnboundedSender<Request>,
|
||||
}
|
||||
|
||||
impl InnerClient {
|
||||
pub fn send(&self, messages: RequestMessages) -> Result<Responses, Error> {
|
||||
let (sender, receiver) = mpsc::channel(1);
|
||||
let request = Request { messages, sender };
|
||||
self.sender
|
||||
.unbounded_send(request)
|
||||
.map_err(|_| Error::closed())?;
|
||||
|
||||
Ok(Responses {
|
||||
receiver,
|
||||
cur: BackendMessages::empty(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Client {
|
||||
sender: mpsc::UnboundedSender<Request>,
|
||||
inner: Arc<InnerClient>,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
}
|
||||
@ -14,9 +66,28 @@ impl Client {
|
||||
secret_key: i32,
|
||||
) -> Client {
|
||||
Client {
|
||||
sender,
|
||||
inner: Arc::new(InnerClient { sender }),
|
||||
process_id,
|
||||
secret_key,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn inner(&self) -> Arc<InnerClient> {
|
||||
self.inner.clone()
|
||||
}
|
||||
|
||||
pub fn prepare<'a>(
|
||||
&mut self,
|
||||
query: &'a str,
|
||||
) -> impl Future<Output = Result<Statement, Error>> + 'a {
|
||||
self.prepare_typed(query, &[])
|
||||
}
|
||||
|
||||
pub fn prepare_typed<'a>(
|
||||
&mut self,
|
||||
query: &'a str,
|
||||
parameter_types: &'a [Type],
|
||||
) -> impl Future<Output = Result<Statement, Error>> + 'a {
|
||||
prepare(self.inner(), query, parameter_types)
|
||||
}
|
||||
}
|
||||
|
@ -122,6 +122,7 @@ pub use crate::socket::Socket;
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::tls::MakeTlsConnect;
|
||||
pub use crate::tls::NoTls;
|
||||
pub use statement::{Column, Statement};
|
||||
|
||||
mod client;
|
||||
mod codec;
|
||||
@ -135,8 +136,10 @@ mod connect_tls;
|
||||
mod connection;
|
||||
pub mod error;
|
||||
mod maybe_tls_stream;
|
||||
mod prepare;
|
||||
#[cfg(feature = "runtime")]
|
||||
mod socket;
|
||||
mod statement;
|
||||
pub mod tls;
|
||||
pub mod types;
|
||||
|
||||
|
71
tokio-postgres/src/prepare.rs
Normal file
71
tokio-postgres/src/prepare.rs
Normal file
@ -0,0 +1,71 @@
|
||||
use crate::client::InnerClient;
|
||||
use crate::codec::FrontendMessage;
|
||||
use crate::connection::{Request, RequestMessages};
|
||||
use crate::types::{Oid, Type};
|
||||
use crate::{Column, Error, Statement};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures::StreamExt;
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
pub async fn prepare(
|
||||
client: Arc<InnerClient>,
|
||||
query: &str,
|
||||
types: &[Type],
|
||||
) -> Result<Statement, Error> {
|
||||
let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
|
||||
|
||||
let mut buf = vec![];
|
||||
frontend::parse(&name, query, types.iter().map(Type::oid), &mut buf).map_err(Error::encode)?;
|
||||
frontend::describe(b'S', &name, &mut buf).map_err(Error::encode)?;
|
||||
frontend::sync(&mut buf);
|
||||
|
||||
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
|
||||
|
||||
match responses.next().await? {
|
||||
Message::ParseComplete => {}
|
||||
_ => return Err(Error::unexpected_message()),
|
||||
}
|
||||
|
||||
let parameter_description = match responses.next().await? {
|
||||
Message::ParameterDescription(body) => body,
|
||||
_ => return Err(Error::unexpected_message()),
|
||||
};
|
||||
|
||||
let row_description = match responses.next().await? {
|
||||
Message::RowDescription(body) => Some(body),
|
||||
Message::NoData => None,
|
||||
_ => return Err(Error::unexpected_message()),
|
||||
};
|
||||
|
||||
let mut parameters = vec![];
|
||||
let mut it = parameter_description.parameters();
|
||||
while let Some(oid) = it.next().map_err(Error::parse)? {
|
||||
let type_ = get_type(&client, oid).await?;
|
||||
parameters.push(type_);
|
||||
}
|
||||
|
||||
let mut columns = vec![];
|
||||
if let Some(row_description) = row_description {
|
||||
let mut it = row_description.fields();
|
||||
while let Some(field) = it.next().map_err(Error::parse)? {
|
||||
let type_ = get_type(&client, field.type_oid()).await?;
|
||||
let column = Column::new(field.name().to_string(), type_);
|
||||
columns.push(column);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Statement::new(&client, name, parameters, columns))
|
||||
}
|
||||
|
||||
async fn get_type(client: &InnerClient, oid: Oid) -> Result<Type, Error> {
|
||||
if let Some(type_) = Type::from_oid(oid) {
|
||||
return Ok(type_);
|
||||
}
|
||||
|
||||
unimplemented!()
|
||||
}
|
59
tokio-postgres/src/statement.rs
Normal file
59
tokio-postgres/src/statement.rs
Normal file
@ -0,0 +1,59 @@
|
||||
use crate::client::InnerClient;
|
||||
use crate::connection::Request;
|
||||
use crate::types::Type;
|
||||
use std::sync::{Arc, Weak};
|
||||
|
||||
pub struct Statement {
|
||||
client: Weak<InnerClient>,
|
||||
name: String,
|
||||
params: Vec<Type>,
|
||||
columns: Vec<Column>,
|
||||
}
|
||||
|
||||
impl Statement {
|
||||
pub(crate) fn new(
|
||||
inner: &Arc<InnerClient>,
|
||||
name: String,
|
||||
params: Vec<Type>,
|
||||
columns: Vec<Column>,
|
||||
) -> Statement {
|
||||
Statement {
|
||||
client: Arc::downgrade(inner),
|
||||
name,
|
||||
params,
|
||||
columns,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the expected types of the statement's parameters.
|
||||
pub fn params(&self) -> &[Type] {
|
||||
&self.params
|
||||
}
|
||||
|
||||
/// Returns information about the columns returned when the statement is queried.
|
||||
pub fn columns(&self) -> &[Column] {
|
||||
&self.columns
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Column {
|
||||
name: String,
|
||||
type_: Type,
|
||||
}
|
||||
|
||||
impl Column {
|
||||
pub(crate) fn new(name: String, type_: Type) -> Column {
|
||||
Column { name, type_ }
|
||||
}
|
||||
|
||||
/// Returns the name of the column.
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
/// Returns the type of the column.
|
||||
pub fn type_(&self) -> &Type {
|
||||
&self.type_
|
||||
}
|
||||
}
|
@ -1,11 +1,12 @@
|
||||
#![warn(rust_2018_idioms)]
|
||||
#![feature(async_await)]
|
||||
|
||||
use futures::FutureExt;
|
||||
use futures::{try_join, FutureExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::error::SqlState;
|
||||
use tokio_postgres::tls::{NoTls, NoTlsStream};
|
||||
use tokio_postgres::{Client, Config, Connection, Error};
|
||||
use tokio_postgres::types::Type;
|
||||
|
||||
mod parse;
|
||||
#[cfg(feature = "runtime")]
|
||||
@ -95,6 +96,22 @@ async fn scram_password_ok() {
|
||||
connect("user=scram_user password=password dbname=postgres").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pipelined_prepare() {
|
||||
let mut client = connect("user=postgres").await;
|
||||
|
||||
let prepare1 = client.prepare("SELECT $1::TEXT");
|
||||
let prepare2 = client.prepare("SELECT $1::BIGINT");
|
||||
|
||||
let (statement1, statement2) = try_join!(prepare1, prepare2).unwrap();
|
||||
|
||||
assert_eq!(statement1.params()[0], Type::TEXT);
|
||||
assert_eq!(statement1.columns()[0].type_(), &Type::TEXT);
|
||||
|
||||
assert_eq!(statement2.params()[0], Type::INT8);
|
||||
assert_eq!(statement2.columns()[0].type_(), &Type::INT8);
|
||||
}
|
||||
|
||||
/*
|
||||
#[test]
|
||||
fn pipelined_prepare() {
|
||||
|
Loading…
Reference in New Issue
Block a user