Overhaul query_portal

This commit is contained in:
Steven Fackler 2019-10-08 17:22:56 -07:00
parent 2517100132
commit b8577b45b1
8 changed files with 57 additions and 106 deletions

View File

@ -1,4 +1,4 @@
use futures::{FutureExt};
use futures::FutureExt;
use native_tls::{self, Certificate};
use tokio::net::TcpStream;
use tokio_postgres::tls::TlsConnect;
@ -21,10 +21,7 @@ where
tokio::spawn(connection);
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
let rows = client
.query(&stmt, &[&1i32])
.await
.unwrap();
let rows = client.query(&stmt, &[&1i32]).await.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, i32>(0), 1);
@ -96,10 +93,7 @@ async fn runtime() {
tokio::spawn(connection);
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
let rows = client
.query(&stmt, &[&1i32])
.await
.unwrap();
let rows = client.query(&stmt, &[&1i32]).await.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, i32>(0), 1);

View File

@ -1,4 +1,4 @@
use futures::{FutureExt};
use futures::FutureExt;
use openssl::ssl::{SslConnector, SslMethod};
use tokio::net::TcpStream;
use tokio_postgres::tls::TlsConnect;
@ -19,10 +19,7 @@ where
tokio::spawn(connection);
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
let rows = client
.query(&stmt, &[&1i32])
.await
.unwrap();
let rows = client.query(&stmt, &[&1i32]).await.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, i32>(0), 1);
@ -107,10 +104,7 @@ async fn runtime() {
tokio::spawn(connection);
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
let rows = client
.query(&stmt, &[&1i32])
.await
.unwrap();
let rows = client.query(&stmt, &[&1i32]).await.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, i32>(0), 1);

View File

@ -95,17 +95,17 @@ impl<'a> Transaction<'a> {
/// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
/// `query_portal`. If the requested number is negative or 0, all remaining rows will be returned.
pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
self.query_portal_iter(portal, max_rows)?.collect()
executor::block_on(self.0.query_portal(portal, max_rows))
}
/// Like `query_portal`, except that it returns a fallible iterator over the resulting rows rather than buffering
/// the entire response in memory.
pub fn query_portal_iter<'b>(
&'b mut self,
portal: &'b Portal,
/// The maximally flexible version of `query_portal`.
pub fn query_portal_raw(
&mut self,
portal: &Portal,
max_rows: i32,
) -> Result<impl FallibleIterator<Item = Row, Error = Error> + 'b, Error> {
Ok(Iter::new(self.0.query_portal(&portal, max_rows)))
) -> Result<impl FallibleIterator<Item = Row, Error = Error>, Error> {
let stream = executor::block_on(self.0.query_portal_raw(portal, max_rows))?;
Ok(Iter::new(stream))
}
/// Like `Client::copy_in`.

View File

@ -3,7 +3,7 @@ use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::types::{IsNull, ToSql};
use crate::{Error, Portal, Row, Statement};
use futures::{ready, Stream, TryFutureExt};
use futures::{ready, Stream};
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use std::pin::Pin;
@ -26,25 +26,21 @@ where
})
}
pub fn query_portal<'a>(
client: &'a InnerClient,
portal: &'a Portal,
pub async fn query_portal(
client: &InnerClient,
portal: &Portal,
max_rows: i32,
) -> impl Stream<Item = Result<Row, Error>> + 'a {
let start = async move {
let mut buf = vec![];
frontend::execute(portal.name(), max_rows, &mut buf).map_err(Error::encode)?;
frontend::sync(&mut buf);
) -> Result<RowStream, Error> {
let mut buf = vec![];
frontend::execute(portal.name(), max_rows, &mut buf).map_err(Error::encode)?;
frontend::sync(&mut buf);
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
Ok(RowStream {
statement: portal.statement().clone(),
responses,
})
};
start.try_flatten_stream()
Ok(RowStream {
statement: portal.statement().clone(),
responses,
})
}
pub async fn execute<'a, I>(

View File

@ -11,7 +11,7 @@ use crate::{
bind, query, slice_iter, Client, Error, Portal, Row, SimpleQueryMessage, Statement, ToStatement,
};
use bytes::{Bytes, IntoBuf};
use futures::{Stream, TryStream};
use futures::{Stream, TryStream, TryStreamExt};
use postgres_protocol::message::frontend;
use std::error;
use tokio::io::{AsyncRead, AsyncWrite};
@ -177,12 +177,20 @@ impl<'a> Transaction<'a> {
///
/// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
/// `query_portal`. If the requested number is negative or 0, all rows will be returned.
pub fn query_portal<'b>(
&'b self,
portal: &'b Portal,
pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
self.query_portal_raw(portal, max_rows)
.await?
.try_collect()
.await
}
/// The maximally flexible version of `query_portal`.
pub async fn query_portal_raw(
&self,
portal: &Portal,
max_rows: i32,
) -> impl Stream<Item = Result<Row, Error>> + 'b {
query::query_portal(self.client.inner(), portal, max_rows)
) -> Result<RowStream, Error> {
query::query_portal(self.client.inner(), portal, max_rows).await
}
/// Like `Client::copy_in`.

View File

@ -335,10 +335,7 @@ async fn transaction_commit() {
transaction.commit().await.unwrap();
let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
let rows = client
.query(&stmt, &[])
.await
.unwrap();
let rows = client.query(&stmt, &[]).await.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, &str>(0), "steven");
@ -366,10 +363,7 @@ async fn transaction_rollback() {
transaction.rollback().await.unwrap();
let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
let rows = client
.query(&stmt, &[])
.await
.unwrap();
let rows = client.query(&stmt, &[]).await.unwrap();
assert_eq!(rows.len(), 0);
}
@ -396,10 +390,7 @@ async fn transaction_rollback_drop() {
drop(transaction);
let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
let rows = client
.query(&stmt, &[])
.await
.unwrap();
let rows = client.query(&stmt, &[]).await.unwrap();
assert_eq!(rows.len(), 0);
}
@ -431,10 +422,7 @@ async fn copy_in() {
.prepare("SELECT id, name FROM foo ORDER BY id")
.await
.unwrap();
let rows = client
.query(&stmt, &[])
.await
.unwrap();
let rows = client.query(&stmt, &[]).await.unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get::<_, i32>(0), 1);
@ -497,10 +485,7 @@ async fn copy_in_error() {
.prepare("SELECT id, name FROM foo ORDER BY id")
.await
.unwrap();
let rows = client
.query(&stmt, &[])
.await
.unwrap();
let rows = client.query(&stmt, &[]).await.unwrap();
assert_eq!(rows.len(), 0);
}
@ -583,9 +568,9 @@ async fn query_portal() {
let transaction = client.transaction().await.unwrap();
let portal = transaction.bind(&stmt, &[]).await.unwrap();
let f1 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
let f2 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
let f3 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
let f1 = transaction.query_portal(&portal, 2);
let f2 = transaction.query_portal(&portal, 2);
let f3 = transaction.query_portal(&portal, 2);
let (r1, r2, r3) = try_join!(f1, f2, f3).unwrap();

View File

@ -16,10 +16,7 @@ async fn smoke_test(s: &str) {
let client = connect(s).await;
let stmt = client.prepare("SELECT $1::INT").await.unwrap();
let rows = client
.query(&stmt, &[&1i32])
.await
.unwrap();
let rows = client.query(&stmt, &[&1i32]).await.unwrap();
assert_eq!(rows[0].get::<_, i32>(0), 1i32);
}

View File

@ -195,10 +195,7 @@ async fn test_borrowed_text() {
let client = connect("user=postgres").await;
let stmt = client.prepare("SELECT 'foo'").await.unwrap();
let rows = client
.query(&stmt, &[])
.await
.unwrap();
let rows = client.query(&stmt, &[]).await.unwrap();
let s: &str = rows[0].get(0);
assert_eq!(s, "foo");
}
@ -298,10 +295,7 @@ async fn test_bytea_params() {
async fn test_borrowed_bytea() {
let client = connect("user=postgres").await;
let stmt = client.prepare("SELECT 'foo'::BYTEA").await.unwrap();
let rows = client
.query(&stmt, &[])
.await
.unwrap();
let rows = client.query(&stmt, &[]).await.unwrap();
let s: &[u8] = rows[0].get(0);
assert_eq!(s, b"foo");
}
@ -360,10 +354,7 @@ where
.prepare(&format!("SELECT 'NaN'::{}", sql_type))
.await
.unwrap();
let rows = client
.query(&stmt, &[])
.await
.unwrap();
let rows = client.query(&stmt, &[]).await.unwrap();
let val: T = rows[0].get(0);
assert!(val != val);
}
@ -385,10 +376,7 @@ async fn test_pg_database_datname() {
.prepare("SELECT datname FROM pg_database")
.await
.unwrap();
let rows = client
.query(&stmt, &[])
.await
.unwrap();
let rows = client.query(&stmt, &[]).await.unwrap();
assert_eq!(rows[0].get::<_, &str>(0), "postgres");
}
@ -439,11 +427,7 @@ async fn test_slice_wrong_type() {
.prepare("SELECT * FROM foo WHERE id = ANY($1)")
.await
.unwrap();
let err = client
.query(&stmt, &[&&[&"hi"][..]])
.await
.err()
.unwrap();
let err = client.query(&stmt, &[&&[&"hi"][..]]).await.err().unwrap();
match err.source() {
Some(e) if e.is::<WrongType>() => {}
_ => panic!("Unexpected error {:?}", err),
@ -455,11 +439,7 @@ async fn test_slice_range() {
let client = connect("user=postgres").await;
let stmt = client.prepare("SELECT $1::INT8RANGE").await.unwrap();
let err = client
.query(&stmt, &[&&[&1i64][..]])
.await
.err()
.unwrap();
let err = client.query(&stmt, &[&&[&1i64][..]]).await.err().unwrap();
match err.source() {
Some(e) if e.is::<WrongType>() => {}
_ => panic!("Unexpected error {:?}", err),
@ -527,10 +507,7 @@ async fn domain() {
client.execute(&stmt, &[&id]).await.unwrap();
let stmt = client.prepare("SELECT id FROM pg_temp.foo").await.unwrap();
let rows = client
.query(&stmt, &[])
.await
.unwrap();
let rows = client.query(&stmt, &[]).await.unwrap();
assert_eq!(id, rows[0].get(0));
}