diff --git a/postgres-native-tls/src/test.rs b/postgres-native-tls/src/test.rs index 45fc08a0..6eb27b23 100644 --- a/postgres-native-tls/src/test.rs +++ b/postgres-native-tls/src/test.rs @@ -1,4 +1,4 @@ -use futures::{FutureExt, TryStreamExt}; +use futures::{FutureExt}; use native_tls::{self, Certificate}; use tokio::net::TcpStream; use tokio_postgres::tls::TlsConnect; @@ -23,7 +23,6 @@ where let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); let rows = client .query(&stmt, &[&1i32]) - .try_collect::>() .await .unwrap(); @@ -99,7 +98,6 @@ async fn runtime() { let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); let rows = client .query(&stmt, &[&1i32]) - .try_collect::>() .await .unwrap(); diff --git a/postgres-openssl/src/test.rs b/postgres-openssl/src/test.rs index 9f29bab1..eb3e5e29 100644 --- a/postgres-openssl/src/test.rs +++ b/postgres-openssl/src/test.rs @@ -1,4 +1,4 @@ -use futures::{FutureExt, TryStreamExt}; +use futures::{FutureExt}; use openssl::ssl::{SslConnector, SslMethod}; use tokio::net::TcpStream; use tokio_postgres::tls::TlsConnect; @@ -21,7 +21,6 @@ where let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); let rows = client .query(&stmt, &[&1i32]) - .try_collect::>() .await .unwrap(); @@ -110,7 +109,6 @@ async fn runtime() { let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); let rows = client .query(&stmt, &[&1i32]) - .try_collect::>() .await .unwrap(); diff --git a/postgres/src/client.rs b/postgres/src/client.rs index a4157b3e..f99afa1c 100644 --- a/postgres/src/client.rs +++ b/postgres/src/client.rs @@ -122,11 +122,13 @@ impl Client { where T: ?Sized + ToStatement, { - self.query_iter(query, params)?.collect() + executor::block_on(self.0.query(query, params)) } - /// Like `query`, except that it returns a fallible iterator over the resulting rows rather than buffering the - /// response in memory. + /// A maximally-flexible version of `query`. + /// + /// It takes an iterator of parameters rather than a slice, and returns an iterator of rows rather than collecting + /// them into an array. /// /// # Panics /// @@ -137,12 +139,13 @@ impl Client { /// ```no_run /// use postgres::{Client, NoTls}; /// use fallible_iterator::FallibleIterator; + /// use std::iter; /// /// # fn main() -> Result<(), postgres::Error> { /// let mut client = Client::connect("host=localhost user=postgres", NoTls)?; /// /// let baz = true; - /// let mut it = client.query_iter("SELECT foo FROM bar WHERE baz = $1", &[&baz])?; + /// let mut it = client.query_raw("SELECT foo FROM bar WHERE baz = $1", iter::once(&baz as _))?; /// /// while let Some(row) = it.next()? { /// let foo: i32 = row.get("foo"); @@ -151,15 +154,18 @@ impl Client { /// # Ok(()) /// # } /// ``` - pub fn query_iter<'a, T>( - &'a mut self, - query: &'a T, - params: &'a [&(dyn ToSql + Sync)], - ) -> Result + 'a, Error> + pub fn query_raw<'a, T, I>( + &mut self, + query: &T, + params: I, + ) -> Result, Error> where T: ?Sized + ToStatement, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, { - Ok(Iter::new(self.0.query(query, params))) + let stream = executor::block_on(self.0.query_raw(query, params))?; + Ok(Iter::new(stream)) } /// Creates a new prepared statement. diff --git a/postgres/src/transaction.rs b/postgres/src/transaction.rs index ac30369e..d9dbbc9c 100644 --- a/postgres/src/transaction.rs +++ b/postgres/src/transaction.rs @@ -55,19 +55,22 @@ impl<'a> Transaction<'a> { where T: ?Sized + ToStatement, { - self.query_iter(query, params)?.collect() + executor::block_on(self.0.query(query, params)) } - /// Like `Client::query_iter`. - pub fn query_iter<'b, T>( - &'b mut self, - query: &'b T, - params: &'b [&(dyn ToSql + Sync)], - ) -> Result + 'b, Error> + /// Like `Client::query_raw`. + pub fn query_raw<'b, T, I>( + &mut self, + query: &T, + params: I, + ) -> Result, Error> where T: ?Sized + ToStatement, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, { - Ok(Iter::new(self.0.query(query, params))) + let stream = executor::block_on(self.0.query_raw(query, params))?; + Ok(Iter::new(stream)) } /// Binds parameters to a statement, creating a "portal". diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 29d4a331..cb7d6f26 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -3,6 +3,7 @@ use crate::cancel_query; use crate::codec::BackendMessages; use crate::config::{Host, SslMode}; use crate::connection::{Request, RequestMessages}; +use crate::query::RowStream; use crate::slice_iter; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; @@ -18,7 +19,7 @@ use crate::{Error, Statement}; use bytes::{Bytes, IntoBuf}; use fallible_iterator::FallibleIterator; use futures::channel::mpsc; -use futures::{future, Stream, TryFutureExt, TryStream}; +use futures::{future, Stream, TryFutureExt, TryStream, TryStreamExt}; use futures::{ready, StreamExt}; use parking_lot::Mutex; use postgres_protocol::message::backend::Message; @@ -190,40 +191,40 @@ impl Client { prepare::prepare(&self.inner, query, parameter_types).await } - /// Executes a statement, returning a stream of the resulting rows. + /// Executes a statement, returning a vector of the resulting rows. /// /// # Panics /// /// Panics if the number of parameters provided does not match the number expected. - pub fn query<'a, T>( - &'a self, - statement: &'a T, - params: &'a [&(dyn ToSql + Sync)], - ) -> impl Stream> + 'a + pub async fn query( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, Error> where T: ?Sized + ToStatement, { - self.query_iter(statement, slice_iter(params)) + self.query_raw(statement, slice_iter(params)) + .await? + .try_collect() + .await } - /// Like [`query`], but takes an iterator of parameters rather than a slice. + /// The maximally flexible version of [`query`]. + /// + /// # Panics + /// + /// Panics if the number of parameters provided does not match the number expected. /// /// [`query`]: #method.query - pub fn query_iter<'a, T, I>( - &'a self, - statement: &'a T, - params: I, - ) -> impl Stream> + 'a + pub async fn query_raw<'a, T, I>(&self, statement: &T, params: I) -> Result where T: ?Sized + ToStatement, - I: IntoIterator + 'a, + I: IntoIterator, I::IntoIter: ExactSizeIterator, { - let f = async move { - let statement = statement.__convert().into_statement(self).await?; - Ok(query::query(&self.inner, statement, params)) - }; - f.try_flatten_stream() + let statement = statement.__convert().into_statement(self).await?; + query::query(&self.inner, statement, params).await } /// Executes a statement, returning the number of rows modified. @@ -241,13 +242,17 @@ impl Client { where T: ?Sized + ToStatement, { - self.execute_iter(statement, slice_iter(params)).await + self.execute_raw(statement, slice_iter(params)).await } - /// Like [`execute`], but takes an iterator of parameters rather than a slice. + /// The maximally flexible version of [`execute`]. + /// + /// # Panics + /// + /// Panics if the number of parameters provided does not match the number expected. /// /// [`execute`]: #method.execute - pub async fn execute_iter<'a, T, I>(&self, statement: &T, params: I) -> Result + pub async fn execute_raw<'a, T, I>(&self, statement: &T, params: I) -> Result where T: ?Sized + ToStatement, I: IntoIterator, diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 0a3fa6ae..39036927 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -3,7 +3,7 @@ //! # Example //! //! ```no_run -//! use futures::{FutureExt, TryStreamExt}; +//! use futures::FutureExt; //! use tokio_postgres::{NoTls, Error, Row}; //! //! # #[cfg(not(feature = "runtime"))] fn main() {} @@ -29,7 +29,6 @@ //! // And then execute it, returning a Stream of Rows which we collect into a Vec. //! let rows: Vec = client //! .query(&stmt, &[&"hello world"]) -//! .try_collect() //! .await?; //! //! // Now we can check that we got back the same string we sent over. diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index f3c18712..8f27156d 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -6,7 +6,7 @@ use crate::types::{Field, Kind, Oid, Type}; use crate::{query, slice_iter}; use crate::{Column, Error, Statement}; use fallible_iterator::FallibleIterator; -use futures::{future, TryStreamExt}; +use futures::TryStreamExt; use pin_utils::pin_mut; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; @@ -132,8 +132,7 @@ async fn get_type(client: &Arc, oid: Oid) -> Result { let stmt = typeinfo_statement(client).await?; - let params = &[&oid as _]; - let rows = query::query(client, stmt, slice_iter(params)); + let rows = query::query(client, stmt, slice_iter(&[&oid])).await?; pin_mut!(rows); let row = match rows.try_next().await? { @@ -204,7 +203,8 @@ async fn get_enum_variants(client: &Arc, oid: Oid) -> Result, oid: Oid) -> Result>() .await?; diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index ee04866b..5260da26 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -9,24 +9,21 @@ use postgres_protocol::message::frontend; use std::pin::Pin; use std::task::{Context, Poll}; -pub fn query<'a, I>( - client: &'a InnerClient, +pub async fn query<'a, I>( + client: &InnerClient, statement: Statement, params: I, -) -> impl Stream> + 'a +) -> Result where - I: IntoIterator + 'a, + I: IntoIterator, I::IntoIter: ExactSizeIterator, { - let f = async move { - let buf = encode(&statement, params)?; - let responses = start(client, buf).await?; - Ok(Query { - statement, - responses, - }) - }; - f.try_flatten_stream() + let buf = encode(&statement, params)?; + let responses = start(client, buf).await?; + Ok(RowStream { + statement, + responses, + }) } pub fn query_portal<'a>( @@ -41,7 +38,7 @@ pub fn query_portal<'a>( let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - Ok(Query { + Ok(RowStream { statement: portal.statement().clone(), responses, }) @@ -145,12 +142,12 @@ where } } -struct Query { +pub struct RowStream { statement: Statement, responses: Responses, } -impl Stream for Query { +impl Stream for RowStream { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 5f9dc8fd..8407c79c 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -1,5 +1,6 @@ use crate::codec::FrontendMessage; use crate::connection::RequestMessages; +use crate::query::RowStream; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; use crate::tls::TlsConnect; @@ -93,29 +94,25 @@ impl<'a> Transaction<'a> { } /// Like `Client::query`. - pub fn query<'b, T>( - &'b self, - statement: &'b T, - params: &'b [&(dyn ToSql + Sync)], - ) -> impl Stream> + 'b + pub async fn query( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, Error> where T: ?Sized + ToStatement, { - self.client.query(statement, params) + self.client.query(statement, params).await } - /// Like `Client::query_iter`. - pub fn query_iter<'b, T, I>( - &'b self, - statement: &'b T, - params: I, - ) -> impl Stream> + 'b + /// Like `Client::query_raw`. + pub async fn query_raw<'b, T, I>(&self, statement: &T, params: I) -> Result where T: ?Sized + ToStatement, - I: IntoIterator + 'b, + I: IntoIterator, I::IntoIter: ExactSizeIterator, { - self.client.query_iter(statement, params) + self.client.query_raw(statement, params).await } /// Like `Client::execute`. @@ -131,7 +128,7 @@ impl<'a> Transaction<'a> { } /// Like `Client::execute_iter`. - pub async fn execute_iter<'b, I, T>( + pub async fn execute_raw<'b, I, T>( &self, statement: &Statement, params: I, @@ -141,7 +138,7 @@ impl<'a> Transaction<'a> { I: IntoIterator, I::IntoIter: ExactSizeIterator, { - self.client.execute_iter(statement, params).await + self.client.execute_raw(statement, params).await } /// Binds a statement to a set of parameters, creating a `Portal` which can be incrementally queried. diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index f3b37e96..547195bb 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -126,7 +126,7 @@ async fn insert_select() { let (insert, select) = try_join!(insert, select).unwrap(); let insert = client.execute(&insert, &[&"alice", &"bob"]); - let select = client.query(&select, &[]).try_collect::>(); + let select = client.query(&select, &[]); let (_, rows) = try_join!(insert, select).unwrap(); assert_eq!(rows.len(), 2); @@ -337,7 +337,6 @@ async fn transaction_commit() { let stmt = client.prepare("SELECT name FROM foo").await.unwrap(); let rows = client .query(&stmt, &[]) - .try_collect::>() .await .unwrap(); @@ -369,7 +368,6 @@ async fn transaction_rollback() { let stmt = client.prepare("SELECT name FROM foo").await.unwrap(); let rows = client .query(&stmt, &[]) - .try_collect::>() .await .unwrap(); @@ -400,7 +398,6 @@ async fn transaction_rollback_drop() { let stmt = client.prepare("SELECT name FROM foo").await.unwrap(); let rows = client .query(&stmt, &[]) - .try_collect::>() .await .unwrap(); @@ -436,7 +433,6 @@ async fn copy_in() { .unwrap(); let rows = client .query(&stmt, &[]) - .try_collect::>() .await .unwrap(); @@ -503,7 +499,6 @@ async fn copy_in_error() { .unwrap(); let rows = client .query(&stmt, &[]) - .try_collect::>() .await .unwrap(); assert_eq!(rows.len(), 0); diff --git a/tokio-postgres/tests/test/runtime.rs b/tokio-postgres/tests/test/runtime.rs index 07f0ed4f..50b3ab6f 100644 --- a/tokio-postgres/tests/test/runtime.rs +++ b/tokio-postgres/tests/test/runtime.rs @@ -1,4 +1,4 @@ -use futures::{join, FutureExt, TryStreamExt}; +use futures::{join, FutureExt}; use std::time::{Duration, Instant}; use tokio::timer; use tokio_postgres::error::SqlState; @@ -18,7 +18,6 @@ async fn smoke_test(s: &str) { let stmt = client.prepare("SELECT $1::INT").await.unwrap(); let rows = client .query(&stmt, &[&1i32]) - .try_collect::>() .await .unwrap(); assert_eq!(rows[0].get::<_, i32>(0), 1i32); diff --git a/tokio-postgres/tests/test/types/mod.rs b/tokio-postgres/tests/test/types/mod.rs index 6f7dd5eb..40d3017c 100644 --- a/tokio-postgres/tests/test/types/mod.rs +++ b/tokio-postgres/tests/test/types/mod.rs @@ -1,4 +1,3 @@ -use futures::TryStreamExt; use postgres_types::to_sql_checked; use std::collections::HashMap; use std::error::Error; @@ -35,7 +34,6 @@ where for (val, repr) in checks { let rows = client .query(&*format!("SELECT {}::{}", repr, sql_type), &[]) - .try_collect::>() .await .unwrap(); let result = rows[0].get(0); @@ -43,7 +41,6 @@ where let rows = client .query(&*format!("SELECT $1::{}", sql_type), &[&val]) - .try_collect::>() .await .unwrap(); let result = rows[0].get(0); @@ -200,7 +197,6 @@ async fn test_borrowed_text() { let stmt = client.prepare("SELECT 'foo'").await.unwrap(); let rows = client .query(&stmt, &[]) - .try_collect::>() .await .unwrap(); let s: &str = rows[0].get(0); @@ -236,10 +232,11 @@ async fn test_bpchar_params() { .unwrap(); let rows = client .query(&stmt, &[]) - .map_ok(|row| row.get(0)) - .try_collect::>>() .await - .unwrap(); + .unwrap() + .into_iter() + .map(|row| row.get(0)) + .collect::>>(); assert_eq!( vec![Some("12345".to_owned()), Some("123 ".to_owned()), None], @@ -276,10 +273,11 @@ async fn test_citext_params() { .unwrap(); let rows = client .query(&stmt, &[]) - .map_ok(|row| row.get(0)) - .try_collect::>() .await - .unwrap(); + .unwrap() + .into_iter() + .map(|row| row.get(0)) + .collect::>(); assert_eq!(vec!["foobar".to_string(), "FooBar".to_string()], rows,); } @@ -302,7 +300,6 @@ async fn test_borrowed_bytea() { let stmt = client.prepare("SELECT 'foo'::BYTEA").await.unwrap(); let rows = client .query(&stmt, &[]) - .try_collect::>() .await .unwrap(); let s: &[u8] = rows[0].get(0); @@ -365,7 +362,6 @@ where .unwrap(); let rows = client .query(&stmt, &[]) - .try_collect::>() .await .unwrap(); let val: T = rows[0].get(0); @@ -391,7 +387,6 @@ async fn test_pg_database_datname() { .unwrap(); let rows = client .query(&stmt, &[]) - .try_collect::>() .await .unwrap(); assert_eq!(rows[0].get::<_, &str>(0), "postgres"); @@ -418,10 +413,11 @@ async fn test_slice() { .unwrap(); let rows = client .query(&stmt, &[&&[1i32, 3, 4][..]]) - .map_ok(|r| r.get(0)) - .try_collect::>() .await - .unwrap(); + .unwrap() + .into_iter() + .map(|r| r.get(0)) + .collect::>(); assert_eq!(vec!["a".to_owned(), "c".to_owned(), "d".to_owned()], rows); } @@ -445,7 +441,6 @@ async fn test_slice_wrong_type() { .unwrap(); let err = client .query(&stmt, &[&&[&"hi"][..]]) - .try_collect::>() .await .err() .unwrap(); @@ -462,7 +457,6 @@ async fn test_slice_range() { let stmt = client.prepare("SELECT $1::INT8RANGE").await.unwrap(); let err = client .query(&stmt, &[&&[&1i64][..]]) - .try_collect::>() .await .err() .unwrap(); @@ -535,7 +529,6 @@ async fn domain() { let stmt = client.prepare("SELECT id FROM pg_temp.foo").await.unwrap(); let rows = client .query(&stmt, &[]) - .try_collect::>() .await .unwrap(); assert_eq!(id, rows[0].get(0));