Support portals

This commit is contained in:
Steven Fackler 2019-08-01 20:43:13 -07:00
parent e4a1ec23a1
commit 26a17ac4ed
6 changed files with 225 additions and 10 deletions

View File

@ -0,0 +1,45 @@
use crate::client::InnerClient;
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::types::ToSql;
use crate::{query, Error, Portal, Statement};
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
pub async fn bind(
client: Arc<InnerClient>,
statement: Statement,
bind: Result<PendingBind, Error>,
) -> Result<Portal, Error> {
let bind = bind?;
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(bind.buf)))?;
match responses.next().await? {
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
}
Ok(Portal::new(&client, bind.name, statement))
}
pub struct PendingBind {
buf: Vec<u8>,
name: String,
}
pub fn encode<'a, I>(statement: &Statement, params: I) -> Result<PendingBind, Error>
where
I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
let name = format!("p{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
let mut buf = query::encode_bind(statement, params, &name)?;
frontend::sync(&mut buf);
Ok(PendingBind { buf, name })
}

View File

@ -113,6 +113,7 @@ pub use crate::config::Config;
pub use crate::connection::Connection;
use crate::error::DbError;
pub use crate::error::Error;
pub use crate::portal::Portal;
pub use crate::row::{Row, SimpleQueryRow};
#[cfg(feature = "runtime")]
pub use crate::socket::Socket;
@ -122,6 +123,7 @@ pub use crate::tls::NoTls;
pub use crate::transaction::Transaction;
pub use statement::{Column, Statement};
mod bind;
#[cfg(feature = "runtime")]
mod cancel_query;
mod cancel_query_raw;
@ -139,6 +141,7 @@ mod copy_in;
mod copy_out;
pub mod error;
mod maybe_tls_stream;
mod portal;
mod prepare;
mod query;
pub mod row;

View File

@ -0,0 +1,48 @@
use crate::client::InnerClient;
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::Statement;
use postgres_protocol::message::frontend;
use std::sync::{Arc, Weak};
struct Inner {
client: Weak<InnerClient>,
name: String,
statement: Statement,
}
impl Drop for Inner {
fn drop(&mut self) {
if let Some(client) = self.client.upgrade() {
let mut buf = vec![];
frontend::close(b'P', &self.name, &mut buf).expect("portal name not valid");
frontend::sync(&mut buf);
let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
}
}
}
/// A portal.
///
/// Portals can only be used with the connection that created them, and only exist for the duration of the transaction
/// in which they were created.
#[derive(Clone)]
pub struct Portal(Arc<Inner>);
impl Portal {
pub(crate) fn new(client: &Arc<InnerClient>, name: String, statement: Statement) -> Portal {
Portal(Arc::new(Inner {
client: Arc::downgrade(client),
name,
statement,
}))
}
pub(crate) fn name(&self) -> &str {
&self.0.name
}
pub(crate) fn statement(&self) -> &Statement {
&self.0.statement
}
}

View File

@ -2,7 +2,7 @@ use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::types::{IsNull, ToSql};
use crate::{Error, Row, Statement};
use crate::{Error, Portal, Row, Statement};
use futures::{ready, Stream, TryFutureExt};
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
@ -23,6 +23,27 @@ pub fn query(
.try_flatten_stream()
}
pub fn query_portal(
client: Arc<InnerClient>,
portal: Portal,
max_rows: i32,
) -> impl Stream<Item = Result<Row, Error>> {
let start = async move {
let mut buf = vec![];
frontend::execute(portal.name(), max_rows, &mut buf).map_err(Error::encode)?;
frontend::sync(&mut buf);
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
Ok(Query {
statement: portal.statement().clone(),
responses,
})
};
start.try_flatten_stream()
}
pub async fn execute(client: Arc<InnerClient>, buf: Result<Vec<u8>, Error>) -> Result<u64, Error> {
let mut responses = start(client, buf).await?;
@ -59,6 +80,18 @@ async fn start(client: Arc<InnerClient>, buf: Result<Vec<u8>, Error>) -> Result<
}
pub fn encode<'a, I>(statement: &Statement, params: I) -> Result<Vec<u8>, Error>
where
I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
let mut buf = encode_bind(statement, params, "")?;
frontend::execute("", 0, &mut buf).map_err(Error::encode)?;
frontend::sync(&mut buf);
Ok(buf)
}
pub fn encode_bind<'a, I>(statement: &Statement, params: I, portal: &str) -> Result<Vec<u8>, Error>
where
I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator,
@ -76,7 +109,7 @@ where
let mut error_idx = 0;
let r = frontend::bind(
"",
portal,
statement.name(),
Some(1),
params.zip(statement.params()).enumerate(),
@ -92,15 +125,10 @@ where
&mut buf,
);
match r {
Ok(()) => {}
Ok(()) => Ok(buf),
Err(frontend::BindError::Conversion(e)) => return Err(Error::to_sql(e, error_idx)),
Err(frontend::BindError::Serialization(e)) => return Err(Error::encode(e)),
}
frontend::execute("", 0, &mut buf).map_err(Error::encode)?;
frontend::sync(&mut buf);
Ok(buf)
}
struct Query {
@ -116,7 +144,9 @@ impl Stream for Query {
Message::DataRow(body) => {
Poll::Ready(Some(Ok(Row::new(self.statement.clone(), body)?)))
}
Message::EmptyQueryResponse | Message::CommandComplete(_) => Poll::Ready(None),
Message::EmptyQueryResponse
| Message::CommandComplete(_)
| Message::PortalSuspended => Poll::Ready(None),
Message::ErrorResponse(body) => Poll::Ready(Some(Err(Error::db(body)))),
_ => Poll::Ready(Some(Err(Error::unexpected_message()))),
}

View File

@ -6,7 +6,7 @@ use crate::tls::TlsConnect;
use crate::types::{ToSql, Type};
#[cfg(feature = "runtime")]
use crate::Socket;
use crate::{query, Client, Error, Row, SimpleQueryMessage, Statement};
use crate::{bind, query, Client, Error, Portal, Row, SimpleQueryMessage, Statement};
use bytes::{Bytes, IntoBuf};
use futures::{Stream, TryStream};
use postgres_protocol::message::frontend;
@ -122,6 +122,52 @@ impl<'a> Transaction<'a> {
query::execute(self.client.inner(), buf)
}
/// Binds a statement to a set of parameters, creating a `Portal` which can be incrementally queried.
///
/// Portals only last for the duration of the transaction in which they are created, and can only be used on the
/// connection that created them.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub fn bind(
&mut self,
statement: &Statement,
params: &[&dyn ToSql],
) -> impl Future<Output = Result<Portal, Error>> {
// https://github.com/rust-lang/rust/issues/63032
let buf = bind::encode(statement, params.iter().cloned());
bind::bind(self.client.inner(), statement.clone(), buf)
}
/// Like [`bind`], but takes an iterator of parameters rather than a slice.
///
/// [`bind`]: #method.bind
pub fn bind_iter<'b, I>(
&mut self,
statement: &Statement,
params: I,
) -> impl Future<Output = Result<Portal, Error>>
where
I: IntoIterator<Item = &'b dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
let buf = bind::encode(statement, params);
bind::bind(self.client.inner(), statement.clone(), buf)
}
/// Continues execution of a portal, returning a stream of the resulting rows.
///
/// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
/// `query_portal`. If the requested number is negative or 0, all rows will be returned.
pub fn query_portal(
&mut self,
portal: &Portal,
max_rows: i32,
) -> impl Stream<Item = Result<Row, Error>> {
query::query_portal(self.client.inner(), portal.clone(), max_rows)
}
/// Like `Client::copy_in`.
pub fn copy_in<S>(
&mut self,

View File

@ -569,6 +569,49 @@ async fn notifications() {
assert_eq!(notifications[1].payload(), "world");
}
#[tokio::test]
async fn query_portal() {
let mut client = connect("user=postgres").await;
client
.batch_execute(
"CREATE TEMPORARY TABLE foo (
id SERIAL,
name TEXT
);
INSERT INTO foo (name) VALUES ('alice'), ('bob'), ('charlie');",
)
.await
.unwrap();
let stmt = client
.prepare("SELECT id, name FROM foo ORDER BY id")
.await
.unwrap();
let mut transaction = client.transaction().await.unwrap();
let portal = transaction.bind(&stmt, &[]).await.unwrap();
let f1 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
let f2 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
let f3 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
let (r1, r2, r3) = try_join!(f1, f2, f3).unwrap();
assert_eq!(r1.len(), 2);
assert_eq!(r1[0].get::<_, i32>(0), 1);
assert_eq!(r1[0].get::<_, &str>(1), "alice");
assert_eq!(r1[1].get::<_, i32>(0), 2);
assert_eq!(r1[1].get::<_, &str>(1), "bob");
assert_eq!(r2.len(), 1);
assert_eq!(r2[0].get::<_, i32>(0), 3);
assert_eq!(r2[0].get::<_, &str>(1), "charlie");
assert_eq!(r3.len(), 0);
}
/*
#[test]
fn query_portal() {