Overhaul query_portal

This commit is contained in:
Steven Fackler 2019-10-08 17:22:56 -07:00
parent 2517100132
commit b8577b45b1
8 changed files with 57 additions and 106 deletions

View File

@ -1,4 +1,4 @@
use futures::{FutureExt}; use futures::FutureExt;
use native_tls::{self, Certificate}; use native_tls::{self, Certificate};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_postgres::tls::TlsConnect; use tokio_postgres::tls::TlsConnect;
@ -21,10 +21,7 @@ where
tokio::spawn(connection); tokio::spawn(connection);
let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
let rows = client let rows = client.query(&stmt, &[&1i32]).await.unwrap();
.query(&stmt, &[&1i32])
.await
.unwrap();
assert_eq!(rows.len(), 1); assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, i32>(0), 1); assert_eq!(rows[0].get::<_, i32>(0), 1);
@ -96,10 +93,7 @@ async fn runtime() {
tokio::spawn(connection); tokio::spawn(connection);
let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
let rows = client let rows = client.query(&stmt, &[&1i32]).await.unwrap();
.query(&stmt, &[&1i32])
.await
.unwrap();
assert_eq!(rows.len(), 1); assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, i32>(0), 1); assert_eq!(rows[0].get::<_, i32>(0), 1);

View File

@ -1,4 +1,4 @@
use futures::{FutureExt}; use futures::FutureExt;
use openssl::ssl::{SslConnector, SslMethod}; use openssl::ssl::{SslConnector, SslMethod};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_postgres::tls::TlsConnect; use tokio_postgres::tls::TlsConnect;
@ -19,10 +19,7 @@ where
tokio::spawn(connection); tokio::spawn(connection);
let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
let rows = client let rows = client.query(&stmt, &[&1i32]).await.unwrap();
.query(&stmt, &[&1i32])
.await
.unwrap();
assert_eq!(rows.len(), 1); assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, i32>(0), 1); assert_eq!(rows[0].get::<_, i32>(0), 1);
@ -107,10 +104,7 @@ async fn runtime() {
tokio::spawn(connection); tokio::spawn(connection);
let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
let rows = client let rows = client.query(&stmt, &[&1i32]).await.unwrap();
.query(&stmt, &[&1i32])
.await
.unwrap();
assert_eq!(rows.len(), 1); assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, i32>(0), 1); assert_eq!(rows[0].get::<_, i32>(0), 1);

View File

@ -95,17 +95,17 @@ impl<'a> Transaction<'a> {
/// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
/// `query_portal`. If the requested number is negative or 0, all remaining rows will be returned. /// `query_portal`. If the requested number is negative or 0, all remaining rows will be returned.
pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> { pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
self.query_portal_iter(portal, max_rows)?.collect() executor::block_on(self.0.query_portal(portal, max_rows))
} }
/// Like `query_portal`, except that it returns a fallible iterator over the resulting rows rather than buffering /// The maximally flexible version of `query_portal`.
/// the entire response in memory. pub fn query_portal_raw(
pub fn query_portal_iter<'b>( &mut self,
&'b mut self, portal: &Portal,
portal: &'b Portal,
max_rows: i32, max_rows: i32,
) -> Result<impl FallibleIterator<Item = Row, Error = Error> + 'b, Error> { ) -> Result<impl FallibleIterator<Item = Row, Error = Error>, Error> {
Ok(Iter::new(self.0.query_portal(&portal, max_rows))) let stream = executor::block_on(self.0.query_portal_raw(portal, max_rows))?;
Ok(Iter::new(stream))
} }
/// Like `Client::copy_in`. /// Like `Client::copy_in`.

View File

@ -3,7 +3,7 @@ use crate::codec::FrontendMessage;
use crate::connection::RequestMessages; use crate::connection::RequestMessages;
use crate::types::{IsNull, ToSql}; use crate::types::{IsNull, ToSql};
use crate::{Error, Portal, Row, Statement}; use crate::{Error, Portal, Row, Statement};
use futures::{ready, Stream, TryFutureExt}; use futures::{ready, Stream};
use postgres_protocol::message::backend::Message; use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend; use postgres_protocol::message::frontend;
use std::pin::Pin; use std::pin::Pin;
@ -26,25 +26,21 @@ where
}) })
} }
pub fn query_portal<'a>( pub async fn query_portal(
client: &'a InnerClient, client: &InnerClient,
portal: &'a Portal, portal: &Portal,
max_rows: i32, max_rows: i32,
) -> impl Stream<Item = Result<Row, Error>> + 'a { ) -> Result<RowStream, Error> {
let start = async move { let mut buf = vec![];
let mut buf = vec![]; frontend::execute(portal.name(), max_rows, &mut buf).map_err(Error::encode)?;
frontend::execute(portal.name(), max_rows, &mut buf).map_err(Error::encode)?; frontend::sync(&mut buf);
frontend::sync(&mut buf);
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
Ok(RowStream { Ok(RowStream {
statement: portal.statement().clone(), statement: portal.statement().clone(),
responses, responses,
}) })
};
start.try_flatten_stream()
} }
pub async fn execute<'a, I>( pub async fn execute<'a, I>(

View File

@ -11,7 +11,7 @@ use crate::{
bind, query, slice_iter, Client, Error, Portal, Row, SimpleQueryMessage, Statement, ToStatement, bind, query, slice_iter, Client, Error, Portal, Row, SimpleQueryMessage, Statement, ToStatement,
}; };
use bytes::{Bytes, IntoBuf}; use bytes::{Bytes, IntoBuf};
use futures::{Stream, TryStream}; use futures::{Stream, TryStream, TryStreamExt};
use postgres_protocol::message::frontend; use postgres_protocol::message::frontend;
use std::error; use std::error;
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
@ -177,12 +177,20 @@ impl<'a> Transaction<'a> {
/// ///
/// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
/// `query_portal`. If the requested number is negative or 0, all rows will be returned. /// `query_portal`. If the requested number is negative or 0, all rows will be returned.
pub fn query_portal<'b>( pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
&'b self, self.query_portal_raw(portal, max_rows)
portal: &'b Portal, .await?
.try_collect()
.await
}
/// The maximally flexible version of `query_portal`.
pub async fn query_portal_raw(
&self,
portal: &Portal,
max_rows: i32, max_rows: i32,
) -> impl Stream<Item = Result<Row, Error>> + 'b { ) -> Result<RowStream, Error> {
query::query_portal(self.client.inner(), portal, max_rows) query::query_portal(self.client.inner(), portal, max_rows).await
} }
/// Like `Client::copy_in`. /// Like `Client::copy_in`.

View File

@ -335,10 +335,7 @@ async fn transaction_commit() {
transaction.commit().await.unwrap(); transaction.commit().await.unwrap();
let stmt = client.prepare("SELECT name FROM foo").await.unwrap(); let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
let rows = client let rows = client.query(&stmt, &[]).await.unwrap();
.query(&stmt, &[])
.await
.unwrap();
assert_eq!(rows.len(), 1); assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, &str>(0), "steven"); assert_eq!(rows[0].get::<_, &str>(0), "steven");
@ -366,10 +363,7 @@ async fn transaction_rollback() {
transaction.rollback().await.unwrap(); transaction.rollback().await.unwrap();
let stmt = client.prepare("SELECT name FROM foo").await.unwrap(); let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
let rows = client let rows = client.query(&stmt, &[]).await.unwrap();
.query(&stmt, &[])
.await
.unwrap();
assert_eq!(rows.len(), 0); assert_eq!(rows.len(), 0);
} }
@ -396,10 +390,7 @@ async fn transaction_rollback_drop() {
drop(transaction); drop(transaction);
let stmt = client.prepare("SELECT name FROM foo").await.unwrap(); let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
let rows = client let rows = client.query(&stmt, &[]).await.unwrap();
.query(&stmt, &[])
.await
.unwrap();
assert_eq!(rows.len(), 0); assert_eq!(rows.len(), 0);
} }
@ -431,10 +422,7 @@ async fn copy_in() {
.prepare("SELECT id, name FROM foo ORDER BY id") .prepare("SELECT id, name FROM foo ORDER BY id")
.await .await
.unwrap(); .unwrap();
let rows = client let rows = client.query(&stmt, &[]).await.unwrap();
.query(&stmt, &[])
.await
.unwrap();
assert_eq!(rows.len(), 2); assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get::<_, i32>(0), 1); assert_eq!(rows[0].get::<_, i32>(0), 1);
@ -497,10 +485,7 @@ async fn copy_in_error() {
.prepare("SELECT id, name FROM foo ORDER BY id") .prepare("SELECT id, name FROM foo ORDER BY id")
.await .await
.unwrap(); .unwrap();
let rows = client let rows = client.query(&stmt, &[]).await.unwrap();
.query(&stmt, &[])
.await
.unwrap();
assert_eq!(rows.len(), 0); assert_eq!(rows.len(), 0);
} }
@ -583,9 +568,9 @@ async fn query_portal() {
let transaction = client.transaction().await.unwrap(); let transaction = client.transaction().await.unwrap();
let portal = transaction.bind(&stmt, &[]).await.unwrap(); let portal = transaction.bind(&stmt, &[]).await.unwrap();
let f1 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>(); let f1 = transaction.query_portal(&portal, 2);
let f2 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>(); let f2 = transaction.query_portal(&portal, 2);
let f3 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>(); let f3 = transaction.query_portal(&portal, 2);
let (r1, r2, r3) = try_join!(f1, f2, f3).unwrap(); let (r1, r2, r3) = try_join!(f1, f2, f3).unwrap();

View File

@ -16,10 +16,7 @@ async fn smoke_test(s: &str) {
let client = connect(s).await; let client = connect(s).await;
let stmt = client.prepare("SELECT $1::INT").await.unwrap(); let stmt = client.prepare("SELECT $1::INT").await.unwrap();
let rows = client let rows = client.query(&stmt, &[&1i32]).await.unwrap();
.query(&stmt, &[&1i32])
.await
.unwrap();
assert_eq!(rows[0].get::<_, i32>(0), 1i32); assert_eq!(rows[0].get::<_, i32>(0), 1i32);
} }

View File

@ -195,10 +195,7 @@ async fn test_borrowed_text() {
let client = connect("user=postgres").await; let client = connect("user=postgres").await;
let stmt = client.prepare("SELECT 'foo'").await.unwrap(); let stmt = client.prepare("SELECT 'foo'").await.unwrap();
let rows = client let rows = client.query(&stmt, &[]).await.unwrap();
.query(&stmt, &[])
.await
.unwrap();
let s: &str = rows[0].get(0); let s: &str = rows[0].get(0);
assert_eq!(s, "foo"); assert_eq!(s, "foo");
} }
@ -298,10 +295,7 @@ async fn test_bytea_params() {
async fn test_borrowed_bytea() { async fn test_borrowed_bytea() {
let client = connect("user=postgres").await; let client = connect("user=postgres").await;
let stmt = client.prepare("SELECT 'foo'::BYTEA").await.unwrap(); let stmt = client.prepare("SELECT 'foo'::BYTEA").await.unwrap();
let rows = client let rows = client.query(&stmt, &[]).await.unwrap();
.query(&stmt, &[])
.await
.unwrap();
let s: &[u8] = rows[0].get(0); let s: &[u8] = rows[0].get(0);
assert_eq!(s, b"foo"); assert_eq!(s, b"foo");
} }
@ -360,10 +354,7 @@ where
.prepare(&format!("SELECT 'NaN'::{}", sql_type)) .prepare(&format!("SELECT 'NaN'::{}", sql_type))
.await .await
.unwrap(); .unwrap();
let rows = client let rows = client.query(&stmt, &[]).await.unwrap();
.query(&stmt, &[])
.await
.unwrap();
let val: T = rows[0].get(0); let val: T = rows[0].get(0);
assert!(val != val); assert!(val != val);
} }
@ -385,10 +376,7 @@ async fn test_pg_database_datname() {
.prepare("SELECT datname FROM pg_database") .prepare("SELECT datname FROM pg_database")
.await .await
.unwrap(); .unwrap();
let rows = client let rows = client.query(&stmt, &[]).await.unwrap();
.query(&stmt, &[])
.await
.unwrap();
assert_eq!(rows[0].get::<_, &str>(0), "postgres"); assert_eq!(rows[0].get::<_, &str>(0), "postgres");
} }
@ -439,11 +427,7 @@ async fn test_slice_wrong_type() {
.prepare("SELECT * FROM foo WHERE id = ANY($1)") .prepare("SELECT * FROM foo WHERE id = ANY($1)")
.await .await
.unwrap(); .unwrap();
let err = client let err = client.query(&stmt, &[&&[&"hi"][..]]).await.err().unwrap();
.query(&stmt, &[&&[&"hi"][..]])
.await
.err()
.unwrap();
match err.source() { match err.source() {
Some(e) if e.is::<WrongType>() => {} Some(e) if e.is::<WrongType>() => {}
_ => panic!("Unexpected error {:?}", err), _ => panic!("Unexpected error {:?}", err),
@ -455,11 +439,7 @@ async fn test_slice_range() {
let client = connect("user=postgres").await; let client = connect("user=postgres").await;
let stmt = client.prepare("SELECT $1::INT8RANGE").await.unwrap(); let stmt = client.prepare("SELECT $1::INT8RANGE").await.unwrap();
let err = client let err = client.query(&stmt, &[&&[&1i64][..]]).await.err().unwrap();
.query(&stmt, &[&&[&1i64][..]])
.await
.err()
.unwrap();
match err.source() { match err.source() {
Some(e) if e.is::<WrongType>() => {} Some(e) if e.is::<WrongType>() => {}
_ => panic!("Unexpected error {:?}", err), _ => panic!("Unexpected error {:?}", err),
@ -527,10 +507,7 @@ async fn domain() {
client.execute(&stmt, &[&id]).await.unwrap(); client.execute(&stmt, &[&id]).await.unwrap();
let stmt = client.prepare("SELECT id FROM pg_temp.foo").await.unwrap(); let stmt = client.prepare("SELECT id FROM pg_temp.foo").await.unwrap();
let rows = client let rows = client.query(&stmt, &[]).await.unwrap();
.query(&stmt, &[])
.await
.unwrap();
assert_eq!(id, rows[0].get(0)); assert_eq!(id, rows[0].get(0));
} }