Overhaul query

This is the template that we'll use for all other methods taking
parameters. The `foo_raw` variant is the most flexible (but annoying to
use), while `foo` covers the expected common case.
This commit is contained in:
Steven Fackler 2019-10-08 17:15:41 -07:00
parent 1473c09b83
commit 2517100132
12 changed files with 103 additions and 112 deletions

View File

@ -1,4 +1,4 @@
use futures::{FutureExt, TryStreamExt};
use futures::{FutureExt};
use native_tls::{self, Certificate};
use tokio::net::TcpStream;
use tokio_postgres::tls::TlsConnect;
@ -23,7 +23,6 @@ where
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
let rows = client
.query(&stmt, &[&1i32])
.try_collect::<Vec<_>>()
.await
.unwrap();
@ -99,7 +98,6 @@ async fn runtime() {
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
let rows = client
.query(&stmt, &[&1i32])
.try_collect::<Vec<_>>()
.await
.unwrap();

View File

@ -1,4 +1,4 @@
use futures::{FutureExt, TryStreamExt};
use futures::{FutureExt};
use openssl::ssl::{SslConnector, SslMethod};
use tokio::net::TcpStream;
use tokio_postgres::tls::TlsConnect;
@ -21,7 +21,6 @@ where
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
let rows = client
.query(&stmt, &[&1i32])
.try_collect::<Vec<_>>()
.await
.unwrap();
@ -110,7 +109,6 @@ async fn runtime() {
let stmt = client.prepare("SELECT $1::INT4").await.unwrap();
let rows = client
.query(&stmt, &[&1i32])
.try_collect::<Vec<_>>()
.await
.unwrap();

View File

@ -122,11 +122,13 @@ impl Client {
where
T: ?Sized + ToStatement,
{
self.query_iter(query, params)?.collect()
executor::block_on(self.0.query(query, params))
}
/// Like `query`, except that it returns a fallible iterator over the resulting rows rather than buffering the
/// response in memory.
/// A maximally-flexible version of `query`.
///
/// It takes an iterator of parameters rather than a slice, and returns an iterator of rows rather than collecting
/// them into an array.
///
/// # Panics
///
@ -137,12 +139,13 @@ impl Client {
/// ```no_run
/// use postgres::{Client, NoTls};
/// use fallible_iterator::FallibleIterator;
/// use std::iter;
///
/// # fn main() -> Result<(), postgres::Error> {
/// let mut client = Client::connect("host=localhost user=postgres", NoTls)?;
///
/// let baz = true;
/// let mut it = client.query_iter("SELECT foo FROM bar WHERE baz = $1", &[&baz])?;
/// let mut it = client.query_raw("SELECT foo FROM bar WHERE baz = $1", iter::once(&baz as _))?;
///
/// while let Some(row) = it.next()? {
/// let foo: i32 = row.get("foo");
@ -151,15 +154,18 @@ impl Client {
/// # Ok(())
/// # }
/// ```
pub fn query_iter<'a, T>(
&'a mut self,
query: &'a T,
params: &'a [&(dyn ToSql + Sync)],
) -> Result<impl FallibleIterator<Item = Row, Error = Error> + 'a, Error>
pub fn query_raw<'a, T, I>(
&mut self,
query: &T,
params: I,
) -> Result<impl FallibleIterator<Item = Row, Error = Error>, Error>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
Ok(Iter::new(self.0.query(query, params)))
let stream = executor::block_on(self.0.query_raw(query, params))?;
Ok(Iter::new(stream))
}
/// Creates a new prepared statement.

View File

@ -55,19 +55,22 @@ impl<'a> Transaction<'a> {
where
T: ?Sized + ToStatement,
{
self.query_iter(query, params)?.collect()
executor::block_on(self.0.query(query, params))
}
/// Like `Client::query_iter`.
pub fn query_iter<'b, T>(
&'b mut self,
query: &'b T,
params: &'b [&(dyn ToSql + Sync)],
) -> Result<impl FallibleIterator<Item = Row, Error = Error> + 'b, Error>
/// Like `Client::query_raw`.
pub fn query_raw<'b, T, I>(
&mut self,
query: &T,
params: I,
) -> Result<impl FallibleIterator<Item = Row, Error = Error>, Error>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'b dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
Ok(Iter::new(self.0.query(query, params)))
let stream = executor::block_on(self.0.query_raw(query, params))?;
Ok(Iter::new(stream))
}
/// Binds parameters to a statement, creating a "portal".

View File

@ -3,6 +3,7 @@ use crate::cancel_query;
use crate::codec::BackendMessages;
use crate::config::{Host, SslMode};
use crate::connection::{Request, RequestMessages};
use crate::query::RowStream;
use crate::slice_iter;
#[cfg(feature = "runtime")]
use crate::tls::MakeTlsConnect;
@ -18,7 +19,7 @@ use crate::{Error, Statement};
use bytes::{Bytes, IntoBuf};
use fallible_iterator::FallibleIterator;
use futures::channel::mpsc;
use futures::{future, Stream, TryFutureExt, TryStream};
use futures::{future, Stream, TryFutureExt, TryStream, TryStreamExt};
use futures::{ready, StreamExt};
use parking_lot::Mutex;
use postgres_protocol::message::backend::Message;
@ -190,40 +191,40 @@ impl Client {
prepare::prepare(&self.inner, query, parameter_types).await
}
/// Executes a statement, returning a stream of the resulting rows.
/// Executes a statement, returning a vector of the resulting rows.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub fn query<'a, T>(
&'a self,
statement: &'a T,
params: &'a [&(dyn ToSql + Sync)],
) -> impl Stream<Item = Result<Row, Error>> + 'a
pub async fn query<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Vec<Row>, Error>
where
T: ?Sized + ToStatement,
{
self.query_iter(statement, slice_iter(params))
self.query_raw(statement, slice_iter(params))
.await?
.try_collect()
.await
}
/// Like [`query`], but takes an iterator of parameters rather than a slice.
/// The maximally flexible version of [`query`].
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
///
/// [`query`]: #method.query
pub fn query_iter<'a, T, I>(
&'a self,
statement: &'a T,
params: I,
) -> impl Stream<Item = Result<Row, Error>> + 'a
pub async fn query_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'a dyn ToSql> + 'a,
I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
let f = async move {
let statement = statement.__convert().into_statement(self).await?;
Ok(query::query(&self.inner, statement, params))
};
f.try_flatten_stream()
query::query(&self.inner, statement, params).await
}
/// Executes a statement, returning the number of rows modified.
@ -241,13 +242,17 @@ impl Client {
where
T: ?Sized + ToStatement,
{
self.execute_iter(statement, slice_iter(params)).await
self.execute_raw(statement, slice_iter(params)).await
}
/// Like [`execute`], but takes an iterator of parameters rather than a slice.
/// The maximally flexible version of [`execute`].
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
///
/// [`execute`]: #method.execute
pub async fn execute_iter<'a, T, I>(&self, statement: &T, params: I) -> Result<u64, Error>
pub async fn execute_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<u64, Error>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'a dyn ToSql>,

View File

@ -3,7 +3,7 @@
//! # Example
//!
//! ```no_run
//! use futures::{FutureExt, TryStreamExt};
//! use futures::FutureExt;
//! use tokio_postgres::{NoTls, Error, Row};
//!
//! # #[cfg(not(feature = "runtime"))] fn main() {}
@ -29,7 +29,6 @@
//! // And then execute it, returning a Stream of Rows which we collect into a Vec.
//! let rows: Vec<Row> = client
//! .query(&stmt, &[&"hello world"])
//! .try_collect()
//! .await?;
//!
//! // Now we can check that we got back the same string we sent over.

View File

@ -6,7 +6,7 @@ use crate::types::{Field, Kind, Oid, Type};
use crate::{query, slice_iter};
use crate::{Column, Error, Statement};
use fallible_iterator::FallibleIterator;
use futures::{future, TryStreamExt};
use futures::TryStreamExt;
use pin_utils::pin_mut;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
@ -132,8 +132,7 @@ async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
let stmt = typeinfo_statement(client).await?;
let params = &[&oid as _];
let rows = query::query(client, stmt, slice_iter(params));
let rows = query::query(client, stmt, slice_iter(&[&oid])).await?;
pin_mut!(rows);
let row = match rows.try_next().await? {
@ -204,7 +203,8 @@ async fn get_enum_variants(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<St
let stmt = typeinfo_enum_statement(client).await?;
query::query(client, stmt, slice_iter(&[&oid]))
.and_then(|row| future::ready(row.try_get(0)))
.await?
.and_then(|row| async move { row.try_get(0) })
.try_collect()
.await
}
@ -230,6 +230,7 @@ async fn get_composite_fields(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec
let stmt = typeinfo_composite_statement(client).await?;
let rows = query::query(client, stmt, slice_iter(&[&oid]))
.await?
.try_collect::<Vec<_>>()
.await?;

View File

@ -9,24 +9,21 @@ use postgres_protocol::message::frontend;
use std::pin::Pin;
use std::task::{Context, Poll};
pub fn query<'a, I>(
client: &'a InnerClient,
pub async fn query<'a, I>(
client: &InnerClient,
statement: Statement,
params: I,
) -> impl Stream<Item = Result<Row, Error>> + 'a
) -> Result<RowStream, Error>
where
I: IntoIterator<Item = &'a dyn ToSql> + 'a,
I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
let f = async move {
let buf = encode(&statement, params)?;
let responses = start(client, buf).await?;
Ok(Query {
Ok(RowStream {
statement,
responses,
})
};
f.try_flatten_stream()
}
pub fn query_portal<'a>(
@ -41,7 +38,7 @@ pub fn query_portal<'a>(
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
Ok(Query {
Ok(RowStream {
statement: portal.statement().clone(),
responses,
})
@ -145,12 +142,12 @@ where
}
}
struct Query {
pub struct RowStream {
statement: Statement,
responses: Responses,
}
impl Stream for Query {
impl Stream for RowStream {
type Item = Result<Row, Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {

View File

@ -1,5 +1,6 @@
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::query::RowStream;
#[cfg(feature = "runtime")]
use crate::tls::MakeTlsConnect;
use crate::tls::TlsConnect;
@ -93,29 +94,25 @@ impl<'a> Transaction<'a> {
}
/// Like `Client::query`.
pub fn query<'b, T>(
&'b self,
statement: &'b T,
params: &'b [&(dyn ToSql + Sync)],
) -> impl Stream<Item = Result<Row, Error>> + 'b
pub async fn query<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Vec<Row>, Error>
where
T: ?Sized + ToStatement,
{
self.client.query(statement, params)
self.client.query(statement, params).await
}
/// Like `Client::query_iter`.
pub fn query_iter<'b, T, I>(
&'b self,
statement: &'b T,
params: I,
) -> impl Stream<Item = Result<Row, Error>> + 'b
/// Like `Client::query_raw`.
pub async fn query_raw<'b, T, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'b dyn ToSql> + 'b,
I: IntoIterator<Item = &'b dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
self.client.query_iter(statement, params)
self.client.query_raw(statement, params).await
}
/// Like `Client::execute`.
@ -131,7 +128,7 @@ impl<'a> Transaction<'a> {
}
/// Like `Client::execute_iter`.
pub async fn execute_iter<'b, I, T>(
pub async fn execute_raw<'b, I, T>(
&self,
statement: &Statement,
params: I,
@ -141,7 +138,7 @@ impl<'a> Transaction<'a> {
I: IntoIterator<Item = &'b dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
self.client.execute_iter(statement, params).await
self.client.execute_raw(statement, params).await
}
/// Binds a statement to a set of parameters, creating a `Portal` which can be incrementally queried.

View File

@ -126,7 +126,7 @@ async fn insert_select() {
let (insert, select) = try_join!(insert, select).unwrap();
let insert = client.execute(&insert, &[&"alice", &"bob"]);
let select = client.query(&select, &[]).try_collect::<Vec<_>>();
let select = client.query(&select, &[]);
let (_, rows) = try_join!(insert, select).unwrap();
assert_eq!(rows.len(), 2);
@ -337,7 +337,6 @@ async fn transaction_commit() {
let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
let rows = client
.query(&stmt, &[])
.try_collect::<Vec<_>>()
.await
.unwrap();
@ -369,7 +368,6 @@ async fn transaction_rollback() {
let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
let rows = client
.query(&stmt, &[])
.try_collect::<Vec<_>>()
.await
.unwrap();
@ -400,7 +398,6 @@ async fn transaction_rollback_drop() {
let stmt = client.prepare("SELECT name FROM foo").await.unwrap();
let rows = client
.query(&stmt, &[])
.try_collect::<Vec<_>>()
.await
.unwrap();
@ -436,7 +433,6 @@ async fn copy_in() {
.unwrap();
let rows = client
.query(&stmt, &[])
.try_collect::<Vec<_>>()
.await
.unwrap();
@ -503,7 +499,6 @@ async fn copy_in_error() {
.unwrap();
let rows = client
.query(&stmt, &[])
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(rows.len(), 0);

View File

@ -1,4 +1,4 @@
use futures::{join, FutureExt, TryStreamExt};
use futures::{join, FutureExt};
use std::time::{Duration, Instant};
use tokio::timer;
use tokio_postgres::error::SqlState;
@ -18,7 +18,6 @@ async fn smoke_test(s: &str) {
let stmt = client.prepare("SELECT $1::INT").await.unwrap();
let rows = client
.query(&stmt, &[&1i32])
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(rows[0].get::<_, i32>(0), 1i32);

View File

@ -1,4 +1,3 @@
use futures::TryStreamExt;
use postgres_types::to_sql_checked;
use std::collections::HashMap;
use std::error::Error;
@ -35,7 +34,6 @@ where
for (val, repr) in checks {
let rows = client
.query(&*format!("SELECT {}::{}", repr, sql_type), &[])
.try_collect::<Vec<_>>()
.await
.unwrap();
let result = rows[0].get(0);
@ -43,7 +41,6 @@ where
let rows = client
.query(&*format!("SELECT $1::{}", sql_type), &[&val])
.try_collect::<Vec<_>>()
.await
.unwrap();
let result = rows[0].get(0);
@ -200,7 +197,6 @@ async fn test_borrowed_text() {
let stmt = client.prepare("SELECT 'foo'").await.unwrap();
let rows = client
.query(&stmt, &[])
.try_collect::<Vec<_>>()
.await
.unwrap();
let s: &str = rows[0].get(0);
@ -236,10 +232,11 @@ async fn test_bpchar_params() {
.unwrap();
let rows = client
.query(&stmt, &[])
.map_ok(|row| row.get(0))
.try_collect::<Vec<Option<String>>>()
.await
.unwrap();
.unwrap()
.into_iter()
.map(|row| row.get(0))
.collect::<Vec<Option<String>>>();
assert_eq!(
vec![Some("12345".to_owned()), Some("123 ".to_owned()), None],
@ -276,10 +273,11 @@ async fn test_citext_params() {
.unwrap();
let rows = client
.query(&stmt, &[])
.map_ok(|row| row.get(0))
.try_collect::<Vec<String>>()
.await
.unwrap();
.unwrap()
.into_iter()
.map(|row| row.get(0))
.collect::<Vec<String>>();
assert_eq!(vec!["foobar".to_string(), "FooBar".to_string()], rows,);
}
@ -302,7 +300,6 @@ async fn test_borrowed_bytea() {
let stmt = client.prepare("SELECT 'foo'::BYTEA").await.unwrap();
let rows = client
.query(&stmt, &[])
.try_collect::<Vec<_>>()
.await
.unwrap();
let s: &[u8] = rows[0].get(0);
@ -365,7 +362,6 @@ where
.unwrap();
let rows = client
.query(&stmt, &[])
.try_collect::<Vec<_>>()
.await
.unwrap();
let val: T = rows[0].get(0);
@ -391,7 +387,6 @@ async fn test_pg_database_datname() {
.unwrap();
let rows = client
.query(&stmt, &[])
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(rows[0].get::<_, &str>(0), "postgres");
@ -418,10 +413,11 @@ async fn test_slice() {
.unwrap();
let rows = client
.query(&stmt, &[&&[1i32, 3, 4][..]])
.map_ok(|r| r.get(0))
.try_collect::<Vec<String>>()
.await
.unwrap();
.unwrap()
.into_iter()
.map(|r| r.get(0))
.collect::<Vec<String>>();
assert_eq!(vec!["a".to_owned(), "c".to_owned(), "d".to_owned()], rows);
}
@ -445,7 +441,6 @@ async fn test_slice_wrong_type() {
.unwrap();
let err = client
.query(&stmt, &[&&[&"hi"][..]])
.try_collect::<Vec<_>>()
.await
.err()
.unwrap();
@ -462,7 +457,6 @@ async fn test_slice_range() {
let stmt = client.prepare("SELECT $1::INT8RANGE").await.unwrap();
let err = client
.query(&stmt, &[&&[&1i64][..]])
.try_collect::<Vec<_>>()
.await
.err()
.unwrap();
@ -535,7 +529,6 @@ async fn domain() {
let stmt = client.prepare("SELECT id FROM pg_temp.foo").await.unwrap();
let rows = client
.query(&stmt, &[])
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(id, rows[0].get(0));