diff --git a/postgres-native-tls/src/test.rs b/postgres-native-tls/src/test.rs index 6eb27b23..7a50bc67 100644 --- a/postgres-native-tls/src/test.rs +++ b/postgres-native-tls/src/test.rs @@ -1,4 +1,4 @@ -use futures::{FutureExt}; +use futures::FutureExt; use native_tls::{self, Certificate}; use tokio::net::TcpStream; use tokio_postgres::tls::TlsConnect; @@ -21,10 +21,7 @@ where tokio::spawn(connection); let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); - let rows = client - .query(&stmt, &[&1i32]) - .await - .unwrap(); + let rows = client.query(&stmt, &[&1i32]).await.unwrap(); assert_eq!(rows.len(), 1); assert_eq!(rows[0].get::<_, i32>(0), 1); @@ -96,10 +93,7 @@ async fn runtime() { tokio::spawn(connection); let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); - let rows = client - .query(&stmt, &[&1i32]) - .await - .unwrap(); + let rows = client.query(&stmt, &[&1i32]).await.unwrap(); assert_eq!(rows.len(), 1); assert_eq!(rows[0].get::<_, i32>(0), 1); diff --git a/postgres-openssl/src/test.rs b/postgres-openssl/src/test.rs index eb3e5e29..15ed90ad 100644 --- a/postgres-openssl/src/test.rs +++ b/postgres-openssl/src/test.rs @@ -1,4 +1,4 @@ -use futures::{FutureExt}; +use futures::FutureExt; use openssl::ssl::{SslConnector, SslMethod}; use tokio::net::TcpStream; use tokio_postgres::tls::TlsConnect; @@ -19,10 +19,7 @@ where tokio::spawn(connection); let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); - let rows = client - .query(&stmt, &[&1i32]) - .await - .unwrap(); + let rows = client.query(&stmt, &[&1i32]).await.unwrap(); assert_eq!(rows.len(), 1); assert_eq!(rows[0].get::<_, i32>(0), 1); @@ -107,10 +104,7 @@ async fn runtime() { tokio::spawn(connection); let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); - let rows = client - .query(&stmt, &[&1i32]) - .await - .unwrap(); + let rows = client.query(&stmt, &[&1i32]).await.unwrap(); assert_eq!(rows.len(), 1); assert_eq!(rows[0].get::<_, i32>(0), 1); diff --git a/postgres/src/transaction.rs b/postgres/src/transaction.rs index d9dbbc9c..895c6939 100644 --- a/postgres/src/transaction.rs +++ b/postgres/src/transaction.rs @@ -95,17 +95,17 @@ impl<'a> Transaction<'a> { /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to /// `query_portal`. If the requested number is negative or 0, all remaining rows will be returned. pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result, Error> { - self.query_portal_iter(portal, max_rows)?.collect() + executor::block_on(self.0.query_portal(portal, max_rows)) } - /// Like `query_portal`, except that it returns a fallible iterator over the resulting rows rather than buffering - /// the entire response in memory. - pub fn query_portal_iter<'b>( - &'b mut self, - portal: &'b Portal, + /// The maximally flexible version of `query_portal`. + pub fn query_portal_raw( + &mut self, + portal: &Portal, max_rows: i32, - ) -> Result + 'b, Error> { - Ok(Iter::new(self.0.query_portal(&portal, max_rows))) + ) -> Result, Error> { + let stream = executor::block_on(self.0.query_portal_raw(portal, max_rows))?; + Ok(Iter::new(stream)) } /// Like `Client::copy_in`. diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 5260da26..57af33f6 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -3,7 +3,7 @@ use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::{IsNull, ToSql}; use crate::{Error, Portal, Row, Statement}; -use futures::{ready, Stream, TryFutureExt}; +use futures::{ready, Stream}; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; use std::pin::Pin; @@ -26,25 +26,21 @@ where }) } -pub fn query_portal<'a>( - client: &'a InnerClient, - portal: &'a Portal, +pub async fn query_portal( + client: &InnerClient, + portal: &Portal, max_rows: i32, -) -> impl Stream> + 'a { - let start = async move { - let mut buf = vec![]; - frontend::execute(portal.name(), max_rows, &mut buf).map_err(Error::encode)?; - frontend::sync(&mut buf); +) -> Result { + let mut buf = vec![]; + frontend::execute(portal.name(), max_rows, &mut buf).map_err(Error::encode)?; + frontend::sync(&mut buf); - let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - Ok(RowStream { - statement: portal.statement().clone(), - responses, - }) - }; - - start.try_flatten_stream() + Ok(RowStream { + statement: portal.statement().clone(), + responses, + }) } pub async fn execute<'a, I>( diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 8407c79c..ee6db4c9 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -11,7 +11,7 @@ use crate::{ bind, query, slice_iter, Client, Error, Portal, Row, SimpleQueryMessage, Statement, ToStatement, }; use bytes::{Bytes, IntoBuf}; -use futures::{Stream, TryStream}; +use futures::{Stream, TryStream, TryStreamExt}; use postgres_protocol::message::frontend; use std::error; use tokio::io::{AsyncRead, AsyncWrite}; @@ -177,12 +177,20 @@ impl<'a> Transaction<'a> { /// /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to /// `query_portal`. If the requested number is negative or 0, all rows will be returned. - pub fn query_portal<'b>( - &'b self, - portal: &'b Portal, + pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result, Error> { + self.query_portal_raw(portal, max_rows) + .await? + .try_collect() + .await + } + + /// The maximally flexible version of `query_portal`. + pub async fn query_portal_raw( + &self, + portal: &Portal, max_rows: i32, - ) -> impl Stream> + 'b { - query::query_portal(self.client.inner(), portal, max_rows) + ) -> Result { + query::query_portal(self.client.inner(), portal, max_rows).await } /// Like `Client::copy_in`. diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 547195bb..4beb3fe0 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -335,10 +335,7 @@ async fn transaction_commit() { transaction.commit().await.unwrap(); let stmt = client.prepare("SELECT name FROM foo").await.unwrap(); - let rows = client - .query(&stmt, &[]) - .await - .unwrap(); + let rows = client.query(&stmt, &[]).await.unwrap(); assert_eq!(rows.len(), 1); assert_eq!(rows[0].get::<_, &str>(0), "steven"); @@ -366,10 +363,7 @@ async fn transaction_rollback() { transaction.rollback().await.unwrap(); let stmt = client.prepare("SELECT name FROM foo").await.unwrap(); - let rows = client - .query(&stmt, &[]) - .await - .unwrap(); + let rows = client.query(&stmt, &[]).await.unwrap(); assert_eq!(rows.len(), 0); } @@ -396,10 +390,7 @@ async fn transaction_rollback_drop() { drop(transaction); let stmt = client.prepare("SELECT name FROM foo").await.unwrap(); - let rows = client - .query(&stmt, &[]) - .await - .unwrap(); + let rows = client.query(&stmt, &[]).await.unwrap(); assert_eq!(rows.len(), 0); } @@ -431,10 +422,7 @@ async fn copy_in() { .prepare("SELECT id, name FROM foo ORDER BY id") .await .unwrap(); - let rows = client - .query(&stmt, &[]) - .await - .unwrap(); + let rows = client.query(&stmt, &[]).await.unwrap(); assert_eq!(rows.len(), 2); assert_eq!(rows[0].get::<_, i32>(0), 1); @@ -497,10 +485,7 @@ async fn copy_in_error() { .prepare("SELECT id, name FROM foo ORDER BY id") .await .unwrap(); - let rows = client - .query(&stmt, &[]) - .await - .unwrap(); + let rows = client.query(&stmt, &[]).await.unwrap(); assert_eq!(rows.len(), 0); } @@ -583,9 +568,9 @@ async fn query_portal() { let transaction = client.transaction().await.unwrap(); let portal = transaction.bind(&stmt, &[]).await.unwrap(); - let f1 = transaction.query_portal(&portal, 2).try_collect::>(); - let f2 = transaction.query_portal(&portal, 2).try_collect::>(); - let f3 = transaction.query_portal(&portal, 2).try_collect::>(); + let f1 = transaction.query_portal(&portal, 2); + let f2 = transaction.query_portal(&portal, 2); + let f3 = transaction.query_portal(&portal, 2); let (r1, r2, r3) = try_join!(f1, f2, f3).unwrap(); diff --git a/tokio-postgres/tests/test/runtime.rs b/tokio-postgres/tests/test/runtime.rs index 50b3ab6f..dbfe9192 100644 --- a/tokio-postgres/tests/test/runtime.rs +++ b/tokio-postgres/tests/test/runtime.rs @@ -16,10 +16,7 @@ async fn smoke_test(s: &str) { let client = connect(s).await; let stmt = client.prepare("SELECT $1::INT").await.unwrap(); - let rows = client - .query(&stmt, &[&1i32]) - .await - .unwrap(); + let rows = client.query(&stmt, &[&1i32]).await.unwrap(); assert_eq!(rows[0].get::<_, i32>(0), 1i32); } diff --git a/tokio-postgres/tests/test/types/mod.rs b/tokio-postgres/tests/test/types/mod.rs index 40d3017c..fefd1ed5 100644 --- a/tokio-postgres/tests/test/types/mod.rs +++ b/tokio-postgres/tests/test/types/mod.rs @@ -195,10 +195,7 @@ async fn test_borrowed_text() { let client = connect("user=postgres").await; let stmt = client.prepare("SELECT 'foo'").await.unwrap(); - let rows = client - .query(&stmt, &[]) - .await - .unwrap(); + let rows = client.query(&stmt, &[]).await.unwrap(); let s: &str = rows[0].get(0); assert_eq!(s, "foo"); } @@ -298,10 +295,7 @@ async fn test_bytea_params() { async fn test_borrowed_bytea() { let client = connect("user=postgres").await; let stmt = client.prepare("SELECT 'foo'::BYTEA").await.unwrap(); - let rows = client - .query(&stmt, &[]) - .await - .unwrap(); + let rows = client.query(&stmt, &[]).await.unwrap(); let s: &[u8] = rows[0].get(0); assert_eq!(s, b"foo"); } @@ -360,10 +354,7 @@ where .prepare(&format!("SELECT 'NaN'::{}", sql_type)) .await .unwrap(); - let rows = client - .query(&stmt, &[]) - .await - .unwrap(); + let rows = client.query(&stmt, &[]).await.unwrap(); let val: T = rows[0].get(0); assert!(val != val); } @@ -385,10 +376,7 @@ async fn test_pg_database_datname() { .prepare("SELECT datname FROM pg_database") .await .unwrap(); - let rows = client - .query(&stmt, &[]) - .await - .unwrap(); + let rows = client.query(&stmt, &[]).await.unwrap(); assert_eq!(rows[0].get::<_, &str>(0), "postgres"); } @@ -439,11 +427,7 @@ async fn test_slice_wrong_type() { .prepare("SELECT * FROM foo WHERE id = ANY($1)") .await .unwrap(); - let err = client - .query(&stmt, &[&&[&"hi"][..]]) - .await - .err() - .unwrap(); + let err = client.query(&stmt, &[&&[&"hi"][..]]).await.err().unwrap(); match err.source() { Some(e) if e.is::() => {} _ => panic!("Unexpected error {:?}", err), @@ -455,11 +439,7 @@ async fn test_slice_range() { let client = connect("user=postgres").await; let stmt = client.prepare("SELECT $1::INT8RANGE").await.unwrap(); - let err = client - .query(&stmt, &[&&[&1i64][..]]) - .await - .err() - .unwrap(); + let err = client.query(&stmt, &[&&[&1i64][..]]).await.err().unwrap(); match err.source() { Some(e) if e.is::() => {} _ => panic!("Unexpected error {:?}", err), @@ -527,10 +507,7 @@ async fn domain() { client.execute(&stmt, &[&id]).await.unwrap(); let stmt = client.prepare("SELECT id FROM pg_temp.foo").await.unwrap(); - let rows = client - .query(&stmt, &[]) - .await - .unwrap(); + let rows = client.query(&stmt, &[]).await.unwrap(); assert_eq!(id, rows[0].get(0)); }