Overhaul query_portal
This commit is contained in:
parent
2517100132
commit
b8577b45b1
@ -1,4 +1,4 @@
|
||||
use futures::{FutureExt};
|
||||
use futures::FutureExt;
|
||||
use native_tls::{self, Certificate};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::tls::TlsConnect;
|
||||
@ -21,10 +21,7 @@ where
|
||||
tokio::spawn(connection);
|
||||
|
||||
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
|
||||
let rows = client
|
||||
.query(&stmt, &[&1i32])
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client.query(&stmt, &[&1i32]).await.unwrap();
|
||||
|
||||
assert_eq!(rows.len(), 1);
|
||||
assert_eq!(rows[0].get::<_, i32>(0), 1);
|
||||
@ -96,10 +93,7 @@ async fn runtime() {
|
||||
tokio::spawn(connection);
|
||||
|
||||
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
|
||||
let rows = client
|
||||
.query(&stmt, &[&1i32])
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client.query(&stmt, &[&1i32]).await.unwrap();
|
||||
|
||||
assert_eq!(rows.len(), 1);
|
||||
assert_eq!(rows[0].get::<_, i32>(0), 1);
|
||||
|
@ -1,4 +1,4 @@
|
||||
use futures::{FutureExt};
|
||||
use futures::FutureExt;
|
||||
use openssl::ssl::{SslConnector, SslMethod};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::tls::TlsConnect;
|
||||
@ -19,10 +19,7 @@ where
|
||||
tokio::spawn(connection);
|
||||
|
||||
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
|
||||
let rows = client
|
||||
.query(&stmt, &[&1i32])
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client.query(&stmt, &[&1i32]).await.unwrap();
|
||||
|
||||
assert_eq!(rows.len(), 1);
|
||||
assert_eq!(rows[0].get::<_, i32>(0), 1);
|
||||
@ -107,10 +104,7 @@ async fn runtime() {
|
||||
tokio::spawn(connection);
|
||||
|
||||
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
|
||||
let rows = client
|
||||
.query(&stmt, &[&1i32])
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client.query(&stmt, &[&1i32]).await.unwrap();
|
||||
|
||||
assert_eq!(rows.len(), 1);
|
||||
assert_eq!(rows[0].get::<_, i32>(0), 1);
|
||||
|
@ -95,17 +95,17 @@ impl<'a> Transaction<'a> {
|
||||
/// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
|
||||
/// `query_portal`. If the requested number is negative or 0, all remaining rows will be returned.
|
||||
pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
|
||||
self.query_portal_iter(portal, max_rows)?.collect()
|
||||
executor::block_on(self.0.query_portal(portal, max_rows))
|
||||
}
|
||||
|
||||
/// Like `query_portal`, except that it returns a fallible iterator over the resulting rows rather than buffering
|
||||
/// the entire response in memory.
|
||||
pub fn query_portal_iter<'b>(
|
||||
&'b mut self,
|
||||
portal: &'b Portal,
|
||||
/// The maximally flexible version of `query_portal`.
|
||||
pub fn query_portal_raw(
|
||||
&mut self,
|
||||
portal: &Portal,
|
||||
max_rows: i32,
|
||||
) -> Result<impl FallibleIterator<Item = Row, Error = Error> + 'b, Error> {
|
||||
Ok(Iter::new(self.0.query_portal(&portal, max_rows)))
|
||||
) -> Result<impl FallibleIterator<Item = Row, Error = Error>, Error> {
|
||||
let stream = executor::block_on(self.0.query_portal_raw(portal, max_rows))?;
|
||||
Ok(Iter::new(stream))
|
||||
}
|
||||
|
||||
/// Like `Client::copy_in`.
|
||||
|
@ -3,7 +3,7 @@ use crate::codec::FrontendMessage;
|
||||
use crate::connection::RequestMessages;
|
||||
use crate::types::{IsNull, ToSql};
|
||||
use crate::{Error, Portal, Row, Statement};
|
||||
use futures::{ready, Stream, TryFutureExt};
|
||||
use futures::{ready, Stream};
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
use std::pin::Pin;
|
||||
@ -26,25 +26,21 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
pub fn query_portal<'a>(
|
||||
client: &'a InnerClient,
|
||||
portal: &'a Portal,
|
||||
pub async fn query_portal(
|
||||
client: &InnerClient,
|
||||
portal: &Portal,
|
||||
max_rows: i32,
|
||||
) -> impl Stream<Item = Result<Row, Error>> + 'a {
|
||||
let start = async move {
|
||||
let mut buf = vec![];
|
||||
frontend::execute(portal.name(), max_rows, &mut buf).map_err(Error::encode)?;
|
||||
frontend::sync(&mut buf);
|
||||
) -> Result<RowStream, Error> {
|
||||
let mut buf = vec![];
|
||||
frontend::execute(portal.name(), max_rows, &mut buf).map_err(Error::encode)?;
|
||||
frontend::sync(&mut buf);
|
||||
|
||||
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
|
||||
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
|
||||
|
||||
Ok(RowStream {
|
||||
statement: portal.statement().clone(),
|
||||
responses,
|
||||
})
|
||||
};
|
||||
|
||||
start.try_flatten_stream()
|
||||
Ok(RowStream {
|
||||
statement: portal.statement().clone(),
|
||||
responses,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn execute<'a, I>(
|
||||
|
@ -11,7 +11,7 @@ use crate::{
|
||||
bind, query, slice_iter, Client, Error, Portal, Row, SimpleQueryMessage, Statement, ToStatement,
|
||||
};
|
||||
use bytes::{Bytes, IntoBuf};
|
||||
use futures::{Stream, TryStream};
|
||||
use futures::{Stream, TryStream, TryStreamExt};
|
||||
use postgres_protocol::message::frontend;
|
||||
use std::error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
@ -177,12 +177,20 @@ impl<'a> Transaction<'a> {
|
||||
///
|
||||
/// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
|
||||
/// `query_portal`. If the requested number is negative or 0, all rows will be returned.
|
||||
pub fn query_portal<'b>(
|
||||
&'b self,
|
||||
portal: &'b Portal,
|
||||
pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
|
||||
self.query_portal_raw(portal, max_rows)
|
||||
.await?
|
||||
.try_collect()
|
||||
.await
|
||||
}
|
||||
|
||||
/// The maximally flexible version of `query_portal`.
|
||||
pub async fn query_portal_raw(
|
||||
&self,
|
||||
portal: &Portal,
|
||||
max_rows: i32,
|
||||
) -> impl Stream<Item = Result<Row, Error>> + 'b {
|
||||
query::query_portal(self.client.inner(), portal, max_rows)
|
||||
) -> Result<RowStream, Error> {
|
||||
query::query_portal(self.client.inner(), portal, max_rows).await
|
||||
}
|
||||
|
||||
/// Like `Client::copy_in`.
|
||||
|
@ -335,10 +335,7 @@ async fn transaction_commit() {
|
||||
transaction.commit().await.unwrap();
|
||||
|
||||
let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
|
||||
let rows = client
|
||||
.query(&stmt, &[])
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client.query(&stmt, &[]).await.unwrap();
|
||||
|
||||
assert_eq!(rows.len(), 1);
|
||||
assert_eq!(rows[0].get::<_, &str>(0), "steven");
|
||||
@ -366,10 +363,7 @@ async fn transaction_rollback() {
|
||||
transaction.rollback().await.unwrap();
|
||||
|
||||
let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
|
||||
let rows = client
|
||||
.query(&stmt, &[])
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client.query(&stmt, &[]).await.unwrap();
|
||||
|
||||
assert_eq!(rows.len(), 0);
|
||||
}
|
||||
@ -396,10 +390,7 @@ async fn transaction_rollback_drop() {
|
||||
drop(transaction);
|
||||
|
||||
let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
|
||||
let rows = client
|
||||
.query(&stmt, &[])
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client.query(&stmt, &[]).await.unwrap();
|
||||
|
||||
assert_eq!(rows.len(), 0);
|
||||
}
|
||||
@ -431,10 +422,7 @@ async fn copy_in() {
|
||||
.prepare("SELECT id, name FROM foo ORDER BY id")
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client
|
||||
.query(&stmt, &[])
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client.query(&stmt, &[]).await.unwrap();
|
||||
|
||||
assert_eq!(rows.len(), 2);
|
||||
assert_eq!(rows[0].get::<_, i32>(0), 1);
|
||||
@ -497,10 +485,7 @@ async fn copy_in_error() {
|
||||
.prepare("SELECT id, name FROM foo ORDER BY id")
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client
|
||||
.query(&stmt, &[])
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client.query(&stmt, &[]).await.unwrap();
|
||||
assert_eq!(rows.len(), 0);
|
||||
}
|
||||
|
||||
@ -583,9 +568,9 @@ async fn query_portal() {
|
||||
let transaction = client.transaction().await.unwrap();
|
||||
|
||||
let portal = transaction.bind(&stmt, &[]).await.unwrap();
|
||||
let f1 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
|
||||
let f2 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
|
||||
let f3 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
|
||||
let f1 = transaction.query_portal(&portal, 2);
|
||||
let f2 = transaction.query_portal(&portal, 2);
|
||||
let f3 = transaction.query_portal(&portal, 2);
|
||||
|
||||
let (r1, r2, r3) = try_join!(f1, f2, f3).unwrap();
|
||||
|
||||
|
@ -16,10 +16,7 @@ async fn smoke_test(s: &str) {
|
||||
let client = connect(s).await;
|
||||
|
||||
let stmt = client.prepare("SELECT $1::INT").await.unwrap();
|
||||
let rows = client
|
||||
.query(&stmt, &[&1i32])
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client.query(&stmt, &[&1i32]).await.unwrap();
|
||||
assert_eq!(rows[0].get::<_, i32>(0), 1i32);
|
||||
}
|
||||
|
||||
|
@ -195,10 +195,7 @@ async fn test_borrowed_text() {
|
||||
let client = connect("user=postgres").await;
|
||||
|
||||
let stmt = client.prepare("SELECT 'foo'").await.unwrap();
|
||||
let rows = client
|
||||
.query(&stmt, &[])
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client.query(&stmt, &[]).await.unwrap();
|
||||
let s: &str = rows[0].get(0);
|
||||
assert_eq!(s, "foo");
|
||||
}
|
||||
@ -298,10 +295,7 @@ async fn test_bytea_params() {
|
||||
async fn test_borrowed_bytea() {
|
||||
let client = connect("user=postgres").await;
|
||||
let stmt = client.prepare("SELECT 'foo'::BYTEA").await.unwrap();
|
||||
let rows = client
|
||||
.query(&stmt, &[])
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client.query(&stmt, &[]).await.unwrap();
|
||||
let s: &[u8] = rows[0].get(0);
|
||||
assert_eq!(s, b"foo");
|
||||
}
|
||||
@ -360,10 +354,7 @@ where
|
||||
.prepare(&format!("SELECT 'NaN'::{}", sql_type))
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client
|
||||
.query(&stmt, &[])
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client.query(&stmt, &[]).await.unwrap();
|
||||
let val: T = rows[0].get(0);
|
||||
assert!(val != val);
|
||||
}
|
||||
@ -385,10 +376,7 @@ async fn test_pg_database_datname() {
|
||||
.prepare("SELECT datname FROM pg_database")
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client
|
||||
.query(&stmt, &[])
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client.query(&stmt, &[]).await.unwrap();
|
||||
assert_eq!(rows[0].get::<_, &str>(0), "postgres");
|
||||
}
|
||||
|
||||
@ -439,11 +427,7 @@ async fn test_slice_wrong_type() {
|
||||
.prepare("SELECT * FROM foo WHERE id = ANY($1)")
|
||||
.await
|
||||
.unwrap();
|
||||
let err = client
|
||||
.query(&stmt, &[&&[&"hi"][..]])
|
||||
.await
|
||||
.err()
|
||||
.unwrap();
|
||||
let err = client.query(&stmt, &[&&[&"hi"][..]]).await.err().unwrap();
|
||||
match err.source() {
|
||||
Some(e) if e.is::<WrongType>() => {}
|
||||
_ => panic!("Unexpected error {:?}", err),
|
||||
@ -455,11 +439,7 @@ async fn test_slice_range() {
|
||||
let client = connect("user=postgres").await;
|
||||
|
||||
let stmt = client.prepare("SELECT $1::INT8RANGE").await.unwrap();
|
||||
let err = client
|
||||
.query(&stmt, &[&&[&1i64][..]])
|
||||
.await
|
||||
.err()
|
||||
.unwrap();
|
||||
let err = client.query(&stmt, &[&&[&1i64][..]]).await.err().unwrap();
|
||||
match err.source() {
|
||||
Some(e) if e.is::<WrongType>() => {}
|
||||
_ => panic!("Unexpected error {:?}", err),
|
||||
@ -527,10 +507,7 @@ async fn domain() {
|
||||
client.execute(&stmt, &[&id]).await.unwrap();
|
||||
|
||||
let stmt = client.prepare("SELECT id FROM pg_temp.foo").await.unwrap();
|
||||
let rows = client
|
||||
.query(&stmt, &[])
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client.query(&stmt, &[]).await.unwrap();
|
||||
assert_eq!(id, rows[0].get(0));
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user