Return iterators from query in sync API

This commit is contained in:
Steven Fackler 2018-12-28 20:39:32 -08:00
parent 45593f5ad0
commit 5169820d6a
7 changed files with 138 additions and 22 deletions

View File

@ -11,6 +11,7 @@ runtime = ["tokio-postgres/runtime", "tokio", "lazy_static", "log"]
[dependencies]
bytes = "0.4"
fallible-iterator = "0.1"
futures = "0.1"
tokio-postgres = { version = "0.3", path = "../tokio-postgres", default-features = false }

View File

@ -4,13 +4,13 @@ use futures::{Async, Future, Poll, Stream};
use std::io::{self, BufRead, Cursor, Read};
use std::marker::PhantomData;
use tokio_postgres::types::{ToSql, Type};
use tokio_postgres::{Error, Row};
use tokio_postgres::Error;
#[cfg(feature = "runtime")]
use tokio_postgres::{MakeTlsMode, Socket, TlsMode};
#[cfg(feature = "runtime")]
use crate::Builder;
use crate::{Statement, ToStatement, Transaction};
use crate::{Query, Statement, ToStatement, Transaction};
pub struct Client(tokio_postgres::Client);
@ -48,12 +48,12 @@ impl Client {
self.0.execute(&statement.0, params).wait()
}
pub fn query<T>(&mut self, query: &T, params: &[&dyn ToSql]) -> Result<Vec<Row>, Error>
pub fn query<T>(&mut self, query: &T, params: &[&dyn ToSql]) -> Result<Query<'_>, Error>
where
T: ?Sized + ToStatement,
{
let statement = query.__statement(self)?;
self.0.query(&statement.0, params).collect().wait()
Ok(Query::new(self.0.query(&statement.0, params)))
}
pub fn copy_in<T, R>(

View File

@ -7,6 +7,8 @@ use tokio::runtime::{self, Runtime};
mod builder;
mod client;
mod portal;
mod query;
mod query_portal;
mod statement;
mod to_statement;
mod transaction;
@ -19,6 +21,8 @@ mod test;
pub use crate::builder::*;
pub use crate::client::*;
pub use crate::portal::*;
pub use crate::query::*;
pub use crate::query_portal::*;
pub use crate::statement::*;
pub use crate::to_statement::*;
pub use crate::transaction::*;

36
postgres/src/query.rs Normal file
View File

@ -0,0 +1,36 @@
use fallible_iterator::FallibleIterator;
use futures::stream::{self, Stream};
use std::marker::PhantomData;
use tokio_postgres::{Error, Row};
pub struct Query<'a> {
it: stream::Wait<tokio_postgres::Query>,
_p: PhantomData<&'a mut ()>,
}
// no-op impl to extend the borrow until drop
impl<'a> Drop for Query<'a> {
fn drop(&mut self) {}
}
impl<'a> Query<'a> {
pub(crate) fn new(stream: tokio_postgres::Query) -> Query<'a> {
Query {
it: stream.wait(),
_p: PhantomData,
}
}
}
impl<'a> FallibleIterator for Query<'a> {
type Item = Row;
type Error = Error;
fn next(&mut self) -> Result<Option<Row>, Error> {
match self.it.next() {
Some(Ok(row)) => Ok(Some(row)),
Some(Err(e)) => Err(e),
None => Ok(None),
}
}
}

View File

@ -0,0 +1,36 @@
use fallible_iterator::FallibleIterator;
use futures::stream::{self, Stream};
use std::marker::PhantomData;
use tokio_postgres::{Error, Row};
pub struct QueryPortal<'a> {
it: stream::Wait<tokio_postgres::QueryPortal>,
_p: PhantomData<&'a mut ()>,
}
// no-op impl to extend the borrow until drop
impl<'a> Drop for QueryPortal<'a> {
fn drop(&mut self) {}
}
impl<'a> QueryPortal<'a> {
pub(crate) fn new(stream: tokio_postgres::QueryPortal) -> QueryPortal<'a> {
QueryPortal {
it: stream.wait(),
_p: PhantomData,
}
}
}
impl<'a> FallibleIterator for QueryPortal<'a> {
type Item = Row;
type Error = Error;
fn next(&mut self) -> Result<Option<Row>, Error> {
match self.it.next() {
Some(Ok(row)) => Ok(Some(row)),
Some(Err(e)) => Err(e),
None => Ok(None),
}
}
}

View File

@ -1,3 +1,4 @@
use fallible_iterator::FallibleIterator;
use std::io::Read;
use tokio_postgres::types::Type;
use tokio_postgres::NoTls;
@ -20,7 +21,11 @@ fn query_prepared() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
let stmt = client.prepare("SELECT $1::TEXT").unwrap();
let rows = client.query(&stmt, &[&"hello"]).unwrap();
let rows = client
.query(&stmt, &[&"hello"])
.unwrap()
.collect::<Vec<_>>()
.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, &str>(0), "hello");
}
@ -29,7 +34,11 @@ fn query_prepared() {
fn query_unprepared() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
let rows = client.query("SELECT $1::TEXT", &[&"hello"]).unwrap();
let rows = client
.query("SELECT $1::TEXT", &[&"hello"])
.unwrap()
.collect::<Vec<_>>()
.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, &str>(0), "hello");
}
@ -50,7 +59,11 @@ fn transaction_commit() {
transaction.commit().unwrap();
let rows = client.query("SELECT * FROM foo", &[]).unwrap();
let rows = client
.query("SELECT * FROM foo", &[])
.unwrap()
.collect::<Vec<_>>()
.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, i32>(0), 1);
}
@ -71,7 +84,11 @@ fn transaction_rollback() {
transaction.rollback().unwrap();
let rows = client.query("SELECT * FROM foo", &[]).unwrap();
let rows = client
.query("SELECT * FROM foo", &[])
.unwrap()
.collect::<Vec<_>>()
.unwrap();
assert_eq!(rows.len(), 0);
}
@ -91,7 +108,11 @@ fn transaction_drop() {
drop(transaction);
let rows = client.query("SELECT * FROM foo", &[]).unwrap();
let rows = client
.query("SELECT * FROM foo", &[])
.unwrap()
.collect::<Vec<_>>()
.unwrap();
assert_eq!(rows.len(), 0);
}
@ -119,6 +140,8 @@ fn nested_transactions() {
let rows = transaction
.query("SELECT id FROM foo ORDER BY id", &[])
.unwrap()
.collect::<Vec<_>>()
.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, i32>(0), 1);
@ -139,7 +162,11 @@ fn nested_transactions() {
transaction3.commit().unwrap();
transaction.commit().unwrap();
let rows = client.query("SELECT id FROM foo ORDER BY id", &[]).unwrap();
let rows = client
.query("SELECT id FROM foo ORDER BY id", &[])
.unwrap()
.collect::<Vec<_>>()
.unwrap();
assert_eq!(rows.len(), 3);
assert_eq!(rows[0].get::<_, i32>(0), 1);
assert_eq!(rows[1].get::<_, i32>(0), 3);
@ -164,6 +191,8 @@ fn copy_in() {
let rows = client
.query("SELECT id, name FROM foo ORDER BY id", &[])
.unwrap()
.collect::<Vec<_>>()
.unwrap();
assert_eq!(rows.len(), 2);
@ -219,12 +248,20 @@ fn portal() {
.bind("SELECT * FROM foo ORDER BY id", &[])
.unwrap();
let rows = transaction.query_portal(&portal, 2).unwrap();
let rows = transaction
.query_portal(&portal, 2)
.unwrap()
.collect::<Vec<_>>()
.unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get::<_, i32>(0), 1);
assert_eq!(rows[1].get::<_, i32>(0), 2);
let rows = transaction.query_portal(&portal, 2).unwrap();
let rows = transaction
.query_portal(&portal, 2)
.unwrap()
.collect::<Vec<_>>()
.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, i32>(0), 3);
}

View File

@ -1,9 +1,9 @@
use futures::{Future, Stream};
use futures::Future;
use std::io::Read;
use tokio_postgres::types::{ToSql, Type};
use tokio_postgres::{Error, Row};
use tokio_postgres::Error;
use crate::{Client, CopyOutReader, Portal, Statement, ToStatement};
use crate::{Client, CopyOutReader, Portal, Query, QueryPortal, Statement, ToStatement};
pub struct Transaction<'a> {
client: &'a mut Client,
@ -67,7 +67,7 @@ impl<'a> Transaction<'a> {
self.client.execute(query, params)
}
pub fn query<T>(&mut self, query: &T, params: &[&dyn ToSql]) -> Result<Vec<Row>, Error>
pub fn query<T>(&mut self, query: &T, params: &[&dyn ToSql]) -> Result<Query<'_>, Error>
where
T: ?Sized + ToStatement,
{
@ -86,12 +86,14 @@ impl<'a> Transaction<'a> {
.map(Portal)
}
pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
self.client
.get_mut()
.query_portal(&portal.0, max_rows)
.collect()
.wait()
pub fn query_portal(
&mut self,
portal: &Portal,
max_rows: i32,
) -> Result<QueryPortal<'_>, Error> {
Ok(QueryPortal::new(
self.client.get_mut().query_portal(&portal.0, max_rows),
))
}
pub fn copy_in<T, R>(