Support portals
This commit is contained in:
parent
e4a1ec23a1
commit
26a17ac4ed
45
tokio-postgres/src/bind.rs
Normal file
45
tokio-postgres/src/bind.rs
Normal file
@ -0,0 +1,45 @@
|
||||
use crate::client::InnerClient;
|
||||
use crate::codec::FrontendMessage;
|
||||
use crate::connection::RequestMessages;
|
||||
use crate::types::ToSql;
|
||||
use crate::{query, Error, Portal, Statement};
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
pub async fn bind(
|
||||
client: Arc<InnerClient>,
|
||||
statement: Statement,
|
||||
bind: Result<PendingBind, Error>,
|
||||
) -> Result<Portal, Error> {
|
||||
let bind = bind?;
|
||||
|
||||
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(bind.buf)))?;
|
||||
|
||||
match responses.next().await? {
|
||||
Message::BindComplete => {}
|
||||
_ => return Err(Error::unexpected_message()),
|
||||
}
|
||||
|
||||
Ok(Portal::new(&client, bind.name, statement))
|
||||
}
|
||||
|
||||
pub struct PendingBind {
|
||||
buf: Vec<u8>,
|
||||
name: String,
|
||||
}
|
||||
|
||||
pub fn encode<'a, I>(statement: &Statement, params: I) -> Result<PendingBind, Error>
|
||||
where
|
||||
I: IntoIterator<Item = &'a dyn ToSql>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
{
|
||||
let name = format!("p{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
|
||||
let mut buf = query::encode_bind(statement, params, &name)?;
|
||||
frontend::sync(&mut buf);
|
||||
|
||||
Ok(PendingBind { buf, name })
|
||||
}
|
@ -113,6 +113,7 @@ pub use crate::config::Config;
|
||||
pub use crate::connection::Connection;
|
||||
use crate::error::DbError;
|
||||
pub use crate::error::Error;
|
||||
pub use crate::portal::Portal;
|
||||
pub use crate::row::{Row, SimpleQueryRow};
|
||||
#[cfg(feature = "runtime")]
|
||||
pub use crate::socket::Socket;
|
||||
@ -122,6 +123,7 @@ pub use crate::tls::NoTls;
|
||||
pub use crate::transaction::Transaction;
|
||||
pub use statement::{Column, Statement};
|
||||
|
||||
mod bind;
|
||||
#[cfg(feature = "runtime")]
|
||||
mod cancel_query;
|
||||
mod cancel_query_raw;
|
||||
@ -139,6 +141,7 @@ mod copy_in;
|
||||
mod copy_out;
|
||||
pub mod error;
|
||||
mod maybe_tls_stream;
|
||||
mod portal;
|
||||
mod prepare;
|
||||
mod query;
|
||||
pub mod row;
|
||||
|
48
tokio-postgres/src/portal.rs
Normal file
48
tokio-postgres/src/portal.rs
Normal file
@ -0,0 +1,48 @@
|
||||
use crate::client::InnerClient;
|
||||
use crate::codec::FrontendMessage;
|
||||
use crate::connection::RequestMessages;
|
||||
use crate::Statement;
|
||||
use postgres_protocol::message::frontend;
|
||||
use std::sync::{Arc, Weak};
|
||||
|
||||
struct Inner {
|
||||
client: Weak<InnerClient>,
|
||||
name: String,
|
||||
statement: Statement,
|
||||
}
|
||||
|
||||
impl Drop for Inner {
|
||||
fn drop(&mut self) {
|
||||
if let Some(client) = self.client.upgrade() {
|
||||
let mut buf = vec![];
|
||||
frontend::close(b'P', &self.name, &mut buf).expect("portal name not valid");
|
||||
frontend::sync(&mut buf);
|
||||
let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A portal.
|
||||
///
|
||||
/// Portals can only be used with the connection that created them, and only exist for the duration of the transaction
|
||||
/// in which they were created.
|
||||
#[derive(Clone)]
|
||||
pub struct Portal(Arc<Inner>);
|
||||
|
||||
impl Portal {
|
||||
pub(crate) fn new(client: &Arc<InnerClient>, name: String, statement: Statement) -> Portal {
|
||||
Portal(Arc::new(Inner {
|
||||
client: Arc::downgrade(client),
|
||||
name,
|
||||
statement,
|
||||
}))
|
||||
}
|
||||
|
||||
pub(crate) fn name(&self) -> &str {
|
||||
&self.0.name
|
||||
}
|
||||
|
||||
pub(crate) fn statement(&self) -> &Statement {
|
||||
&self.0.statement
|
||||
}
|
||||
}
|
@ -2,7 +2,7 @@ use crate::client::{InnerClient, Responses};
|
||||
use crate::codec::FrontendMessage;
|
||||
use crate::connection::RequestMessages;
|
||||
use crate::types::{IsNull, ToSql};
|
||||
use crate::{Error, Row, Statement};
|
||||
use crate::{Error, Portal, Row, Statement};
|
||||
use futures::{ready, Stream, TryFutureExt};
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
@ -23,6 +23,27 @@ pub fn query(
|
||||
.try_flatten_stream()
|
||||
}
|
||||
|
||||
pub fn query_portal(
|
||||
client: Arc<InnerClient>,
|
||||
portal: Portal,
|
||||
max_rows: i32,
|
||||
) -> impl Stream<Item = Result<Row, Error>> {
|
||||
let start = async move {
|
||||
let mut buf = vec![];
|
||||
frontend::execute(portal.name(), max_rows, &mut buf).map_err(Error::encode)?;
|
||||
frontend::sync(&mut buf);
|
||||
|
||||
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
|
||||
|
||||
Ok(Query {
|
||||
statement: portal.statement().clone(),
|
||||
responses,
|
||||
})
|
||||
};
|
||||
|
||||
start.try_flatten_stream()
|
||||
}
|
||||
|
||||
pub async fn execute(client: Arc<InnerClient>, buf: Result<Vec<u8>, Error>) -> Result<u64, Error> {
|
||||
let mut responses = start(client, buf).await?;
|
||||
|
||||
@ -59,6 +80,18 @@ async fn start(client: Arc<InnerClient>, buf: Result<Vec<u8>, Error>) -> Result<
|
||||
}
|
||||
|
||||
pub fn encode<'a, I>(statement: &Statement, params: I) -> Result<Vec<u8>, Error>
|
||||
where
|
||||
I: IntoIterator<Item = &'a dyn ToSql>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
{
|
||||
let mut buf = encode_bind(statement, params, "")?;
|
||||
frontend::execute("", 0, &mut buf).map_err(Error::encode)?;
|
||||
frontend::sync(&mut buf);
|
||||
|
||||
Ok(buf)
|
||||
}
|
||||
|
||||
pub fn encode_bind<'a, I>(statement: &Statement, params: I, portal: &str) -> Result<Vec<u8>, Error>
|
||||
where
|
||||
I: IntoIterator<Item = &'a dyn ToSql>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
@ -76,7 +109,7 @@ where
|
||||
|
||||
let mut error_idx = 0;
|
||||
let r = frontend::bind(
|
||||
"",
|
||||
portal,
|
||||
statement.name(),
|
||||
Some(1),
|
||||
params.zip(statement.params()).enumerate(),
|
||||
@ -92,15 +125,10 @@ where
|
||||
&mut buf,
|
||||
);
|
||||
match r {
|
||||
Ok(()) => {}
|
||||
Ok(()) => Ok(buf),
|
||||
Err(frontend::BindError::Conversion(e)) => return Err(Error::to_sql(e, error_idx)),
|
||||
Err(frontend::BindError::Serialization(e)) => return Err(Error::encode(e)),
|
||||
}
|
||||
|
||||
frontend::execute("", 0, &mut buf).map_err(Error::encode)?;
|
||||
frontend::sync(&mut buf);
|
||||
|
||||
Ok(buf)
|
||||
}
|
||||
|
||||
struct Query {
|
||||
@ -116,7 +144,9 @@ impl Stream for Query {
|
||||
Message::DataRow(body) => {
|
||||
Poll::Ready(Some(Ok(Row::new(self.statement.clone(), body)?)))
|
||||
}
|
||||
Message::EmptyQueryResponse | Message::CommandComplete(_) => Poll::Ready(None),
|
||||
Message::EmptyQueryResponse
|
||||
| Message::CommandComplete(_)
|
||||
| Message::PortalSuspended => Poll::Ready(None),
|
||||
Message::ErrorResponse(body) => Poll::Ready(Some(Err(Error::db(body)))),
|
||||
_ => Poll::Ready(Some(Err(Error::unexpected_message()))),
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ use crate::tls::TlsConnect;
|
||||
use crate::types::{ToSql, Type};
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::Socket;
|
||||
use crate::{query, Client, Error, Row, SimpleQueryMessage, Statement};
|
||||
use crate::{bind, query, Client, Error, Portal, Row, SimpleQueryMessage, Statement};
|
||||
use bytes::{Bytes, IntoBuf};
|
||||
use futures::{Stream, TryStream};
|
||||
use postgres_protocol::message::frontend;
|
||||
@ -122,6 +122,52 @@ impl<'a> Transaction<'a> {
|
||||
query::execute(self.client.inner(), buf)
|
||||
}
|
||||
|
||||
/// Binds a statement to a set of parameters, creating a `Portal` which can be incrementally queried.
|
||||
///
|
||||
/// Portals only last for the duration of the transaction in which they are created, and can only be used on the
|
||||
/// connection that created them.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the number of parameters provided does not match the number expected.
|
||||
pub fn bind(
|
||||
&mut self,
|
||||
statement: &Statement,
|
||||
params: &[&dyn ToSql],
|
||||
) -> impl Future<Output = Result<Portal, Error>> {
|
||||
// https://github.com/rust-lang/rust/issues/63032
|
||||
let buf = bind::encode(statement, params.iter().cloned());
|
||||
bind::bind(self.client.inner(), statement.clone(), buf)
|
||||
}
|
||||
|
||||
/// Like [`bind`], but takes an iterator of parameters rather than a slice.
|
||||
///
|
||||
/// [`bind`]: #method.bind
|
||||
pub fn bind_iter<'b, I>(
|
||||
&mut self,
|
||||
statement: &Statement,
|
||||
params: I,
|
||||
) -> impl Future<Output = Result<Portal, Error>>
|
||||
where
|
||||
I: IntoIterator<Item = &'b dyn ToSql>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
{
|
||||
let buf = bind::encode(statement, params);
|
||||
bind::bind(self.client.inner(), statement.clone(), buf)
|
||||
}
|
||||
|
||||
/// Continues execution of a portal, returning a stream of the resulting rows.
|
||||
///
|
||||
/// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
|
||||
/// `query_portal`. If the requested number is negative or 0, all rows will be returned.
|
||||
pub fn query_portal(
|
||||
&mut self,
|
||||
portal: &Portal,
|
||||
max_rows: i32,
|
||||
) -> impl Stream<Item = Result<Row, Error>> {
|
||||
query::query_portal(self.client.inner(), portal.clone(), max_rows)
|
||||
}
|
||||
|
||||
/// Like `Client::copy_in`.
|
||||
pub fn copy_in<S>(
|
||||
&mut self,
|
||||
|
@ -569,6 +569,49 @@ async fn notifications() {
|
||||
assert_eq!(notifications[1].payload(), "world");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn query_portal() {
|
||||
let mut client = connect("user=postgres").await;
|
||||
|
||||
client
|
||||
.batch_execute(
|
||||
"CREATE TEMPORARY TABLE foo (
|
||||
id SERIAL,
|
||||
name TEXT
|
||||
);
|
||||
|
||||
INSERT INTO foo (name) VALUES ('alice'), ('bob'), ('charlie');",
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let stmt = client
|
||||
.prepare("SELECT id, name FROM foo ORDER BY id")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut transaction = client.transaction().await.unwrap();
|
||||
|
||||
let portal = transaction.bind(&stmt, &[]).await.unwrap();
|
||||
let f1 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
|
||||
let f2 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
|
||||
let f3 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
|
||||
|
||||
let (r1, r2, r3) = try_join!(f1, f2, f3).unwrap();
|
||||
|
||||
assert_eq!(r1.len(), 2);
|
||||
assert_eq!(r1[0].get::<_, i32>(0), 1);
|
||||
assert_eq!(r1[0].get::<_, &str>(1), "alice");
|
||||
assert_eq!(r1[1].get::<_, i32>(0), 2);
|
||||
assert_eq!(r1[1].get::<_, &str>(1), "bob");
|
||||
|
||||
assert_eq!(r2.len(), 1);
|
||||
assert_eq!(r2[0].get::<_, i32>(0), 3);
|
||||
assert_eq!(r2[0].get::<_, &str>(1), "charlie");
|
||||
|
||||
assert_eq!(r3.len(), 0);
|
||||
}
|
||||
|
||||
/*
|
||||
#[test]
|
||||
fn query_portal() {
|
||||
|
Loading…
Reference in New Issue
Block a user