Add query_opt

Closes #510
This commit is contained in:
Steven Fackler 2019-11-30 18:18:50 -05:00
parent 299ef6c8dd
commit b4694471ad
10 changed files with 214 additions and 43 deletions

View File

@ -48,17 +48,17 @@
#![doc(html_root_url = "https://docs.rs/postgres-native-tls/0.3")]
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
use std::task::{Context, Poll};
use bytes::{Buf, BufMut};
use std::future::Future;
use std::io;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use bytes::{Buf, BufMut};
use tokio_postgres::tls;
#[cfg(feature = "runtime")]
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
use std::mem::MaybeUninit;
#[cfg(test)]
mod test;

View File

@ -42,7 +42,7 @@
#![doc(html_root_url = "https://docs.rs/postgres-openssl/0.3")]
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
use std::task::{Poll, Context};
use bytes::{Buf, BufMut};
#[cfg(feature = "runtime")]
use openssl::error::ErrorStack;
use openssl::hash::MessageDigest;
@ -53,17 +53,17 @@ use openssl::ssl::{ConnectConfiguration, SslRef};
use std::fmt::Debug;
use std::future::Future;
use std::io;
use std::mem::MaybeUninit;
use std::pin::Pin;
#[cfg(feature = "runtime")]
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use bytes::{Buf, BufMut};
use tokio_openssl::{HandshakeError, SslStream};
use tokio_postgres::tls;
#[cfg(feature = "runtime")]
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
use std::mem::MaybeUninit;
#[cfg(test)]
mod test;

View File

@ -119,6 +119,8 @@ impl Client {
/// Executes a statement which returns a single row, returning it.
///
/// Returns an error if the query does not return exactly one row.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
@ -152,6 +154,52 @@ impl Client {
executor::block_on(self.0.query_one(query, params))
}
/// Executes a statement which returns zero or one rows, returning it.
///
/// Returns an error if the query returns more than one row.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
///
/// # Examples
///
/// ```no_run
/// use postgres::{Client, NoTls};
///
/// # fn main() -> Result<(), postgres::Error> {
/// let mut client = Client::connect("host=localhost user=postgres", NoTls)?;
///
/// let baz = true;
/// let row = client.query_opt("SELECT foo FROM bar WHERE baz = $1", &[&baz])?;
/// match row {
/// Some(row) => {
/// let foo: i32 = row.get("foo");
/// println!("foo: {}", foo);
/// }
/// None => println!("no matching foo"),
/// }
/// # Ok(())
/// # }
/// ```
pub fn query_opt<T>(
&mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Option<Row>, Error>
where
T: ?Sized + ToStatement,
{
executor::block_on(self.0.query_opt(query, params))
}
/// A maximally-flexible version of `query`.
///
/// It takes an iterator of parameters rather than a slice, and returns an iterator of rows rather than collecting

View File

@ -60,6 +60,18 @@ impl<'a> Transaction<'a> {
executor::block_on(self.0.query_one(query, params))
}
/// Like `Client::query_opt`.
pub fn query_opt<T>(
&mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Option<Row>, Error>
where
T: ?Sized + ToStatement,
{
executor::block_on(self.0.query_opt(query, params))
}
/// Like `Client::query_raw`.
pub fn query_raw<'b, T, I>(&mut self, query: &T, params: I) -> Result<RowIter<'_>, Error>
where

View File

@ -1,16 +1,16 @@
use bytes::{BufMut, Bytes, BytesMut, Buf};
use futures::{ready, Stream, SinkExt};
use byteorder::{BigEndian, ByteOrder};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::{ready, SinkExt, Stream};
use pin_project_lite::pin_project;
use std::convert::TryFrom;
use std::error::Error;
use std::io::Cursor;
use std::ops::Range;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio_postgres::types::{IsNull, ToSql, Type, FromSql, WrongType};
use tokio_postgres::{CopyOutStream, CopyInSink};
use std::io::Cursor;
use byteorder::{ByteOrder, BigEndian};
use tokio_postgres::types::{FromSql, IsNull, ToSql, Type, WrongType};
use tokio_postgres::{CopyInSink, CopyOutStream};
#[cfg(test)]
mod test;
@ -49,10 +49,13 @@ impl BinaryCopyInWriter {
self.write_raw(values.iter().cloned()).await
}
pub async fn write_raw<'a, I>(self: Pin<&mut Self>, values: I) -> Result<(), Box<dyn Error + Sync + Send>>
where
I: IntoIterator<Item = &'a (dyn ToSql + Send)>,
I::IntoIter: ExactSizeIterator,
pub async fn write_raw<'a, I>(
self: Pin<&mut Self>,
values: I,
) -> Result<(), Box<dyn Error + Sync + Send>>
where
I: IntoIterator<Item = &'a (dyn ToSql + Send)>,
I::IntoIter: ExactSizeIterator,
{
let mut this = self.project();
@ -126,7 +129,7 @@ impl Stream for BinaryCopyOutStream {
Some(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
None => return Poll::Ready(Some(Err("unexpected EOF".into()))),
};
let mut chunk= Cursor::new(chunk);
let mut chunk = Cursor::new(chunk);
let has_oids = match &this.header {
Some(header) => header.has_oids,
@ -200,7 +203,10 @@ pub struct BinaryCopyOutRow {
}
impl BinaryCopyOutRow {
pub fn try_get<'a, T>(&'a self, idx: usize) -> Result<T, Box<dyn Error + Sync + Send>> where T: FromSql<'a> {
pub fn try_get<'a, T>(&'a self, idx: usize) -> Result<T, Box<dyn Error + Sync + Send>>
where
T: FromSql<'a>,
{
let type_ = &self.types[idx];
if !T::accepts(type_) {
return Err(WrongType::new::<T>(type_.clone()).into());
@ -208,11 +214,14 @@ impl BinaryCopyOutRow {
match &self.ranges[idx] {
Some(range) => T::from_sql(type_, &self.buf[range.clone()]).map_err(Into::into),
None => T::from_sql_null(type_).map_err(Into::into)
None => T::from_sql_null(type_).map_err(Into::into),
}
}
pub fn get<'a, T>(&'a self, idx: usize) -> T where T: FromSql<'a> {
pub fn get<'a, T>(&'a self, idx: usize) -> T
where
T: FromSql<'a>,
{
match self.try_get(idx) {
Ok(value) => value,
Err(e) => panic!("error retrieving column {}: {}", idx, e),

View File

@ -1,7 +1,7 @@
use crate::{BinaryCopyInWriter, BinaryCopyOutStream};
use futures::{pin_mut, TryStreamExt};
use tokio_postgres::types::Type;
use tokio_postgres::{Client, NoTls};
use futures::{TryStreamExt, pin_mut};
async fn connect() -> Client {
let (client, connection) =
@ -23,11 +23,18 @@ async fn write_basic() {
.await
.unwrap();
let sink = client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[]).await.unwrap();
let sink = client
.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[])
.await
.unwrap();
let writer = BinaryCopyInWriter::new(sink, &[Type::INT4, Type::TEXT]);
pin_mut!(writer);
writer.as_mut().write(&[&1i32, &"foobar"]).await.unwrap();
writer.as_mut().write(&[&2i32, &None::<&str>]).await.unwrap();
writer
.as_mut()
.write(&[&2i32, &None::<&str>])
.await
.unwrap();
writer.finish().await.unwrap();
let rows = client
@ -50,12 +57,19 @@ async fn write_many_rows() {
.await
.unwrap();
let sink = client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[]).await.unwrap();
let sink = client
.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[])
.await
.unwrap();
let writer = BinaryCopyInWriter::new(sink, &[Type::INT4, Type::TEXT]);
pin_mut!(writer);
for i in 0..10_000i32 {
writer.as_mut().write(&[&i, &format!("the value for {}", i)]).await.unwrap();
writer
.as_mut()
.write(&[&i, &format!("the value for {}", i)])
.await
.unwrap();
}
writer.finish().await.unwrap();
@ -79,12 +93,19 @@ async fn write_big_rows() {
.await
.unwrap();
let sink = client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[]).await.unwrap();
let sink = client
.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[])
.await
.unwrap();
let writer = BinaryCopyInWriter::new(sink, &[Type::INT4, Type::BYTEA]);
pin_mut!(writer);
for i in 0..2i32 {
writer.as_mut().write(&[&i, &vec![i as u8; 128 * 1024]]).await.unwrap();
writer
.as_mut()
.write(&[&i, &vec![i as u8; 128 * 1024]])
.await
.unwrap();
}
writer.finish().await.unwrap();
@ -108,13 +129,19 @@ async fn read_basic() {
"
CREATE TEMPORARY TABLE foo (id INT, bar TEXT);
INSERT INTO foo (id, bar) VALUES (1, 'foobar'), (2, NULL);
"
",
)
.await
.unwrap();
let stream = client.copy_out("COPY foo (id, bar) TO STDIN BINARY", &[]).await.unwrap();
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::TEXT], stream).try_collect::<Vec<_>>().await.unwrap();
let stream = client
.copy_out("COPY foo (id, bar) TO STDIN BINARY", &[])
.await
.unwrap();
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::TEXT], stream)
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get::<i32>(0), 1);
@ -136,8 +163,14 @@ async fn read_many_rows() {
.await
.unwrap();
let stream = client.copy_out("COPY foo (id, bar) TO STDIN BINARY", &[]).await.unwrap();
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::TEXT], stream).try_collect::<Vec<_>>().await.unwrap();
let stream = client
.copy_out("COPY foo (id, bar) TO STDIN BINARY", &[])
.await
.unwrap();
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::TEXT], stream)
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(rows.len(), 10_000);
for (i, row) in rows.iter().enumerate() {
@ -155,11 +188,23 @@ async fn read_big_rows() {
.await
.unwrap();
for i in 0..2i32 {
client.execute("INSERT INTO foo (id, bar) VALUES ($1, $2)", &[&i, &vec![i as u8; 128 * 1024]]).await.unwrap();
client
.execute(
"INSERT INTO foo (id, bar) VALUES ($1, $2)",
&[&i, &vec![i as u8; 128 * 1024]],
)
.await
.unwrap();
}
let stream = client.copy_out("COPY foo (id, bar) TO STDIN BINARY", &[]).await.unwrap();
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::BYTEA], stream).try_collect::<Vec<_>>().await.unwrap();
let stream = client
.copy_out("COPY foo (id, bar) TO STDIN BINARY", &[])
.await
.unwrap();
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::BYTEA], stream)
.try_collect::<Vec<_>>()
.await
.unwrap();
assert_eq!(rows.len(), 2);
for (i, row) in rows.iter().enumerate() {

View File

@ -231,6 +231,8 @@ impl Client {
/// Executes a statement which returns a single row, returning it.
///
/// Returns an error if the query does not return exactly one row.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
@ -238,8 +240,6 @@ impl Client {
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// Returns an error if the query does not return exactly one row.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
@ -266,6 +266,43 @@ impl Client {
Ok(row)
}
/// Executes a statements which returns zero or one rows, returning it.
///
/// Returns an error if the query returns more than one row.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub async fn query_opt<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Option<Row>, Error>
where
T: ?Sized + ToStatement,
{
let stream = self.query_raw(statement, slice_iter(params)).await?;
pin_mut!(stream);
let row = match stream.try_next().await? {
Some(row) => row,
None => return Ok(None),
};
if stream.try_next().await?.is_some() {
return Err(Error::row_count());
}
Ok(Some(row))
}
/// The maximally flexible version of [`query`].
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list

View File

@ -6,13 +6,13 @@ use crate::{query, Error, Statement};
use bytes::buf::BufExt;
use bytes::{Buf, BufMut, BytesMut};
use futures::channel::mpsc;
use futures::{ready, Sink, SinkExt, Stream, StreamExt};
use futures::future;
use futures::{ready, Sink, SinkExt, Stream, StreamExt};
use pin_project_lite::pin_project;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use postgres_protocol::message::frontend::CopyData;
use std::marker::{PhantomPinned, PhantomData};
use std::marker::{PhantomData, PhantomPinned};
use std::pin::Pin;
use std::task::{Context, Poll};

View File

@ -13,7 +13,7 @@ use crate::{
ToStatement,
};
use bytes::Buf;
use futures::{TryStreamExt};
use futures::TryStreamExt;
use postgres_protocol::message::frontend;
use tokio::io::{AsyncRead, AsyncWrite};
@ -120,6 +120,18 @@ impl<'a> Transaction<'a> {
self.client.query_one(statement, params).await
}
/// Like `Client::query_opt`.
pub async fn query_opt<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Option<Row>, Error>
where
T: ?Sized + ToStatement,
{
self.client.query_opt(statement, params).await
}
/// Like `Client::query_raw`.
pub async fn query_raw<'b, T, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
where

View File

@ -2,7 +2,9 @@
use bytes::{Bytes, BytesMut};
use futures::channel::mpsc;
use futures::{future, stream, StreamExt, SinkExt, pin_mut, join, try_join, FutureExt, TryStreamExt};
use futures::{
future, join, pin_mut, stream, try_join, FutureExt, SinkExt, StreamExt, TryStreamExt,
};
use std::fmt::Write;
use std::time::Duration;
use tokio::net::TcpStream;
@ -422,7 +424,10 @@ async fn copy_in() {
let rows = sink.finish().await.unwrap();
assert_eq!(rows, 2);
let rows = client.query("SELECT id, name FROM foo ORDER BY id", &[]).await.unwrap();
let rows = client
.query("SELECT id, name FROM foo ORDER BY id", &[])
.await
.unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get::<_, i32>(0), 1);
@ -487,7 +492,10 @@ async fn copy_in_error() {
sink.send(Bytes::from_static(b"1\tsteven")).await.unwrap();
}
let rows = client.query("SELECT id, name FROM foo ORDER BY id", &[]).await.unwrap();
let rows = client
.query("SELECT id, name FROM foo ORDER BY id", &[])
.await
.unwrap();
assert_eq!(rows.len(), 0);
}