Add a ToStatement trait in tokio-postgres

This commit is contained in:
Steven Fackler 2019-09-28 14:42:00 -04:00
parent 286ecdb5b9
commit 0d2d554122
19 changed files with 410 additions and 407 deletions

View File

@ -38,7 +38,7 @@
//! .build()?;
//! let connector = MakeTlsConnector::new(connector);
//!
//! let mut client = postgres::Client::connect(
//! let client = postgres::Client::connect(
//! "host=localhost user=postgres sslmode=require",
//! connector,
//! )?;

View File

@ -30,7 +30,7 @@
//! builder.set_ca_file("database_cert.pem")?;
//! let connector = MakeTlsConnector::new(builder.build());
//!
//! let mut client = postgres::Client::connect(
//! let client = postgres::Client::connect(
//! "host=localhost user=postgres sslmode=require",
//! connector,
//! )?;

View File

@ -84,8 +84,7 @@ impl Client {
where
T: ?Sized + ToStatement,
{
let statement = query.__statement(self)?;
executor::block_on(self.0.execute(&statement, params))
executor::block_on(self.0.execute(query, params))
}
/// Executes a statement, returning the resulting rows.
@ -154,14 +153,13 @@ impl Client {
/// ```
pub fn query_iter<'a, T>(
&'a mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
query: &'a T,
params: &'a [&(dyn ToSql + Sync)],
) -> Result<impl FallibleIterator<Item = Row, Error = Error> + 'a, Error>
where
T: ?Sized + ToStatement,
{
let statement = query.__statement(self)?;
Ok(Iter::new(self.0.query(&statement, params)))
Ok(Iter::new(self.0.query(query, params)))
}
/// Creates a new prepared statement.
@ -249,8 +247,7 @@ impl Client {
T: ?Sized + ToStatement,
R: Read + Unpin,
{
let statement = query.__statement(self)?;
executor::block_on(self.0.copy_in(&statement, params, CopyInStream(reader)))
executor::block_on(self.0.copy_in(query, params, CopyInStream(reader)))
}
/// Executes a `COPY TO STDOUT` statement, returning a reader of the resulting data.
@ -274,14 +271,13 @@ impl Client {
/// ```
pub fn copy_out<'a, T>(
&'a mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
query: &'a T,
params: &'a [&(dyn ToSql + Sync)],
) -> Result<impl BufRead + 'a, Error>
where
T: ?Sized + ToStatement,
{
let statement = query.__statement(self)?;
let stream = self.0.copy_out(&statement, params);
let stream = self.0.copy_out(query, params);
CopyOutReader::new(stream)
}
@ -314,7 +310,7 @@ impl Client {
/// them to this method!
pub fn simple_query_iter<'a>(
&'a mut self,
query: &str,
query: &'a str,
) -> Result<impl FallibleIterator<Item = SimpleQueryMessage, Error = Error> + 'a, Error> {
Ok(Iter::new(self.0.simple_query(query)))
}

View File

@ -62,7 +62,9 @@ use tokio::runtime::{self, Runtime};
#[cfg(feature = "runtime")]
pub use tokio_postgres::Socket;
pub use tokio_postgres::{error, row, tls, types, Column, Portal, SimpleQueryMessage, Statement};
pub use tokio_postgres::{
error, row, tls, types, Column, Portal, SimpleQueryMessage, Statement, ToStatement,
};
pub use crate::client::*;
#[cfg(feature = "runtime")]
@ -73,7 +75,6 @@ pub use crate::error::Error;
pub use crate::row::{Row, SimpleQueryRow};
#[doc(no_inline)]
pub use crate::tls::NoTls;
pub use crate::to_statement::*;
pub use crate::transaction::*;
mod client;
@ -82,7 +83,6 @@ pub mod config;
mod copy_in_stream;
mod copy_out_reader;
mod iter;
mod to_statement;
mod transaction;
#[cfg(feature = "runtime")]

View File

@ -1,59 +0,0 @@
use tokio_postgres::Error;
use crate::{Client, Statement, Transaction};
mod sealed {
pub trait Sealed {}
}
#[doc(hidden)]
pub trait Prepare {
fn prepare(&mut self, query: &str) -> Result<Statement, Error>;
}
impl Prepare for Client {
fn prepare(&mut self, query: &str) -> Result<Statement, Error> {
self.prepare(query)
}
}
impl<'a> Prepare for Transaction<'a> {
fn prepare(&mut self, query: &str) -> Result<Statement, Error> {
self.prepare(query)
}
}
/// A trait abstracting over prepared and unprepared statements.
///
/// Many methods are generic over this bound, so that they support both a raw query string as well as a statement which
/// was prepared previously.
///
/// This trait is "sealed" and cannot be implemented by anything outside this crate.
pub trait ToStatement: sealed::Sealed {
#[doc(hidden)]
fn __statement<T>(&self, client: &mut T) -> Result<Statement, Error>
where
T: Prepare;
}
impl sealed::Sealed for str {}
impl ToStatement for str {
fn __statement<T>(&self, client: &mut T) -> Result<Statement, Error>
where
T: Prepare,
{
client.prepare(self)
}
}
impl sealed::Sealed for Statement {}
impl ToStatement for Statement {
fn __statement<T>(&self, _: &mut T) -> Result<Statement, Error>
where
T: Prepare,
{
Ok(self.clone())
}
}

View File

@ -47,8 +47,7 @@ impl<'a> Transaction<'a> {
where
T: ?Sized + ToStatement,
{
let statement = query.__statement(self)?;
executor::block_on(self.0.execute(&statement, params))
executor::block_on(self.0.execute(query, params))
}
/// Like `Client::query`.
@ -60,16 +59,15 @@ impl<'a> Transaction<'a> {
}
/// Like `Client::query_iter`.
pub fn query_iter<T>(
&mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<impl FallibleIterator<Item = Row, Error = Error>, Error>
pub fn query_iter<'b, T>(
&'b mut self,
query: &'b T,
params: &'b [&(dyn ToSql + Sync)],
) -> Result<impl FallibleIterator<Item = Row, Error = Error> + 'b, Error>
where
T: ?Sized + ToStatement,
{
let statement = query.__statement(self)?;
Ok(Iter::new(self.0.query(&statement, params)))
Ok(Iter::new(self.0.query(query, params)))
}
/// Binds parameters to a statement, creating a "portal".
@ -86,8 +84,7 @@ impl<'a> Transaction<'a> {
where
T: ?Sized + ToStatement,
{
let statement = query.__statement(self)?;
executor::block_on(self.0.bind(&statement, params))
executor::block_on(self.0.bind(query, params))
}
/// Continues execution of a portal, returning the next set of rows.
@ -100,11 +97,11 @@ impl<'a> Transaction<'a> {
/// Like `query_portal`, except that it returns a fallible iterator over the resulting rows rather than buffering
/// the entire response in memory.
pub fn query_portal_iter(
&mut self,
portal: &Portal,
pub fn query_portal_iter<'b>(
&'b mut self,
portal: &'b Portal,
max_rows: i32,
) -> Result<impl FallibleIterator<Item = Row, Error = Error>, Error> {
) -> Result<impl FallibleIterator<Item = Row, Error = Error> + 'b, Error> {
Ok(Iter::new(self.0.query_portal(&portal, max_rows)))
}
@ -119,21 +116,19 @@ impl<'a> Transaction<'a> {
T: ?Sized + ToStatement,
R: Read + Unpin,
{
let statement = query.__statement(self)?;
executor::block_on(self.0.copy_in(&statement, params, CopyInStream(reader)))
executor::block_on(self.0.copy_in(query, params, CopyInStream(reader)))
}
/// Like `Client::copy_out`.
pub fn copy_out<'b, T>(
&'a mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
&'b mut self,
query: &'b T,
params: &'b [&(dyn ToSql + Sync)],
) -> Result<impl BufRead + 'b, Error>
where
T: ?Sized + ToStatement,
{
let statement = query.__statement(self)?;
let stream = self.0.copy_out(&statement, params);
let stream = self.0.copy_out(query, params);
CopyOutReader::new(stream)
}
@ -145,7 +140,7 @@ impl<'a> Transaction<'a> {
/// Like `Client::simple_query_iter`.
pub fn simple_query_iter<'b>(
&'b mut self,
query: &str,
query: &'b str,
) -> Result<impl FallibleIterator<Item = SimpleQueryMessage, Error = Error> + 'b, Error> {
Ok(Iter::new(self.0.simple_query(query)))
}

View File

@ -10,36 +10,25 @@ use std::sync::Arc;
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
pub async fn bind(
client: Arc<InnerClient>,
pub async fn bind<'a, I>(
client: &Arc<InnerClient>,
statement: Statement,
bind: Result<PendingBind, Error>,
) -> Result<Portal, Error> {
let bind = bind?;
params: I,
) -> Result<Portal, Error>
where
I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
let name = format!("p{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
let mut buf = query::encode_bind(&statement, params, &name)?;
frontend::sync(&mut buf);
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(bind.buf)))?;
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
}
Ok(Portal::new(&client, bind.name, statement))
}
pub struct PendingBind {
buf: Vec<u8>,
name: String,
}
pub fn encode<'a, I>(statement: &Statement, params: I) -> Result<PendingBind, Error>
where
I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
let name = format!("p{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
let mut buf = query::encode_bind(statement, params, &name)?;
frontend::sync(&mut buf);
Ok(PendingBind { buf, name })
Ok(Portal::new(client, name, statement))
}

View File

@ -3,9 +3,11 @@ use crate::cancel_query;
use crate::codec::BackendMessages;
use crate::config::{Host, SslMode};
use crate::connection::{Request, RequestMessages};
use crate::slice_iter;
#[cfg(feature = "runtime")]
use crate::tls::MakeTlsConnect;
use crate::tls::TlsConnect;
use crate::to_statement::ToStatement;
use crate::types::{Oid, ToSql, Type};
#[cfg(feature = "runtime")]
use crate::Socket;
@ -16,13 +18,12 @@ use crate::{Error, Statement};
use bytes::{Bytes, IntoBuf};
use fallible_iterator::FallibleIterator;
use futures::channel::mpsc;
use futures::{future, Stream, TryStream};
use futures::{future, Stream, TryFutureExt, TryStream};
use futures::{ready, StreamExt};
use parking_lot::Mutex;
use postgres_protocol::message::backend::Message;
use std::collections::HashMap;
use std::error;
use std::future::Future;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
@ -160,8 +161,8 @@ impl Client {
}
}
pub(crate) fn inner(&self) -> Arc<InnerClient> {
self.inner.clone()
pub(crate) fn inner(&self) -> &Arc<InnerClient> {
&self.inner
}
#[cfg(feature = "runtime")]
@ -194,29 +195,35 @@ impl Client {
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub fn query<'a>(
pub fn query<'a, T>(
&'a self,
statement: &'a Statement,
params: &'a [&'a (dyn ToSql + Sync)],
) -> impl Stream<Item = Result<Row, Error>> + 'a {
let buf = query::encode(statement, params.iter().map(|s| *s as _));
query::query(&self.inner, statement, buf)
statement: &'a T,
params: &'a [&(dyn ToSql + Sync)],
) -> impl Stream<Item = Result<Row, Error>> + 'a
where
T: ?Sized + ToStatement,
{
self.query_iter(statement, slice_iter(params))
}
/// Like [`query`], but takes an iterator of parameters rather than a slice.
///
/// [`query`]: #method.query
pub fn query_iter<'a, I>(
pub fn query_iter<'a, T, I>(
&'a self,
statement: &'a Statement,
statement: &'a T,
params: I,
) -> impl Stream<Item = Result<Row, Error>> + 'a
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'a dyn ToSql> + 'a,
I::IntoIter: ExactSizeIterator,
{
let buf = query::encode(statement, params);
query::query(&self.inner, statement, buf)
let f = async move {
let statement = statement.__convert().into_statement(self).await?;
Ok(query::query(&self.inner, statement, params))
};
f.try_flatten_stream()
}
/// Executes a statement, returning the number of rows modified.
@ -226,29 +233,28 @@ impl Client {
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub async fn execute(
pub async fn execute<T>(
&self,
statement: &Statement,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<u64, Error> {
let buf = query::encode(statement, params.iter().map(|s| *s as _));
query::execute(&self.inner, buf).await
) -> Result<u64, Error>
where
T: ?Sized + ToStatement,
{
self.execute_iter(statement, slice_iter(params)).await
}
/// Like [`execute`], but takes an iterator of parameters rather than a slice.
///
/// [`execute`]: #method.execute
pub async fn execute_iter<'a, I>(
&self,
statement: &Statement,
params: I,
) -> Result<u64, Error>
pub async fn execute_iter<'a, T, I>(&self, statement: &T, params: I) -> Result<u64, Error>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
let buf = query::encode(statement, params);
query::execute(&self.inner, buf).await
let statement = statement.__convert().into_statement(self).await?;
query::execute(self.inner(), statement, params).await
}
/// Executes a `COPY FROM STDIN` statement, returning the number of rows created.
@ -259,20 +265,22 @@ impl Client {
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub fn copy_in<S>(
pub async fn copy_in<T, S>(
&self,
statement: &Statement,
statement: &T,
params: &[&(dyn ToSql + Sync)],
stream: S,
) -> impl Future<Output = Result<u64, Error>>
) -> Result<u64, Error>
where
T: ?Sized + ToStatement,
S: TryStream,
S::Ok: IntoBuf,
<S::Ok as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn error::Error + Sync + Send>>,
{
let buf = query::encode(statement, params.iter().map(|s| *s as _));
copy_in::copy_in(self.inner(), buf, stream)
let statement = statement.__convert().into_statement(self).await?;
let params = slice_iter(params);
copy_in::copy_in(self.inner(), statement, params, stream).await
}
/// Executes a `COPY TO STDOUT` statement, returning a stream of the resulting data.
@ -280,13 +288,20 @@ impl Client {
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub fn copy_out(
&self,
statement: &Statement,
params: &[&(dyn ToSql + Sync)],
) -> impl Stream<Item = Result<Bytes, Error>> {
let buf = query::encode(statement, params.iter().map(|s| *s as _));
copy_out::copy_out(self.inner(), buf)
pub fn copy_out<'a, T>(
&'a self,
statement: &'a T,
params: &'a [&(dyn ToSql + Sync)],
) -> impl Stream<Item = Result<Bytes, Error>> + 'a
where
T: ?Sized + ToStatement,
{
let f = async move {
let statement = statement.__convert().into_statement(self).await?;
let params = slice_iter(params);
Ok(copy_out::copy_out(self.inner(), statement, params))
};
f.try_flatten_stream()
}
/// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
@ -302,10 +317,10 @@ impl Client {
/// Prepared statements should be use for any query which contains user-specified data, as they provided the
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method!
pub fn simple_query(
&self,
query: &str,
) -> impl Stream<Item = Result<SimpleQueryMessage, Error>> {
pub fn simple_query<'a>(
&'a self,
query: &'a str,
) -> impl Stream<Item = Result<SimpleQueryMessage, Error>> + 'a {
simple_query::simple_query(self.inner(), query)
}
@ -319,8 +334,8 @@ impl Client {
/// Prepared statements should be use for any query which contains user-specified data, as they provided the
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method!
pub fn batch_execute(&self, query: &str) -> impl Future<Output = Result<(), Error>> {
simple_query::batch_execute(self.inner(), query)
pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
simple_query::batch_execute(self.inner(), query).await
}
/// Begins a new database transaction.
@ -338,7 +353,7 @@ impl Client {
///
/// Requires the `runtime` Cargo feature (enabled by default).
#[cfg(feature = "runtime")]
pub fn cancel_query<T>(&self, tls: T) -> impl Future<Output = Result<(), Error>>
pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
where
T: MakeTlsConnect<Socket>,
{
@ -349,15 +364,12 @@ impl Client {
self.process_id,
self.secret_key,
)
.await
}
/// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
/// connection itself.
pub fn cancel_query_raw<S, T>(
&self,
stream: S,
tls: T,
) -> impl Future<Output = Result<(), Error>>
pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
@ -369,6 +381,7 @@ impl Client {
self.process_id,
self.secret_key,
)
.await
}
/// Determines if the connection to the server has already closed.

View File

@ -1,7 +1,8 @@
use crate::client::InnerClient;
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::Error;
use crate::types::ToSql;
use crate::{query, Error, Statement};
use bytes::{Buf, BufMut, BytesMut, IntoBuf};
use futures::channel::mpsc;
use futures::ready;
@ -12,7 +13,6 @@ use postgres_protocol::message::frontend;
use postgres_protocol::message::frontend::CopyData;
use std::error;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
enum CopyInMessage {
@ -62,18 +62,21 @@ impl Stream for CopyInReceiver {
}
}
pub async fn copy_in<S>(
client: Arc<InnerClient>,
buf: Result<Vec<u8>, Error>,
pub async fn copy_in<'a, I, S>(
client: &InnerClient,
statement: Statement,
params: I,
stream: S,
) -> Result<u64, Error>
where
I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator,
S: TryStream,
S::Ok: IntoBuf,
<S::Ok as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn error::Error + Sync + Send>>,
{
let buf = buf?;
let buf = query::encode(&statement, params)?;
let (mut sender, receiver) = mpsc::channel(1);
let receiver = CopyInReceiver::new(receiver);

View File

@ -1,25 +1,32 @@
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::Error;
use crate::types::ToSql;
use crate::{query, Error, Statement};
use bytes::Bytes;
use futures::{ready, Stream, TryFutureExt};
use postgres_protocol::message::backend::Message;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
pub fn copy_out(
client: Arc<InnerClient>,
buf: Result<Vec<u8>, Error>,
) -> impl Stream<Item = Result<Bytes, Error>> {
start(client, buf)
.map_ok(|responses| CopyOut { responses })
.try_flatten_stream()
pub fn copy_out<'a, I>(
client: &'a InnerClient,
statement: Statement,
params: I,
) -> impl Stream<Item = Result<Bytes, Error>> + 'a
where
I: IntoIterator<Item = &'a dyn ToSql> + 'a,
I::IntoIter: ExactSizeIterator,
{
let f = async move {
let buf = query::encode(&statement, params)?;
let responses = start(client, buf).await?;
Ok(CopyOut { responses })
};
f.try_flatten_stream()
}
async fn start(client: Arc<InnerClient>, buf: Result<Vec<u8>, Error>) -> Result<Responses, Error> {
let buf = buf?;
async fn start(client: &InnerClient, buf: Vec<u8>) -> Result<Responses, Error> {
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {

View File

@ -114,11 +114,13 @@ pub use crate::portal::Portal;
pub use crate::row::{Row, SimpleQueryRow};
#[cfg(feature = "runtime")]
pub use crate::socket::Socket;
pub use crate::statement::{Column, Statement};
#[cfg(feature = "runtime")]
use crate::tls::MakeTlsConnect;
pub use crate::tls::NoTls;
pub use crate::to_statement::ToStatement;
pub use crate::transaction::Transaction;
pub use statement::{Column, Statement};
use crate::types::ToSql;
mod bind;
#[cfg(feature = "runtime")]
@ -147,6 +149,7 @@ mod simple_query;
mod socket;
mod statement;
pub mod tls;
mod to_statement;
mod transaction;
pub mod types;
@ -220,3 +223,9 @@ pub enum SimpleQueryMessage {
#[doc(hidden)]
__NonExhaustive,
}
fn slice_iter<'a>(
s: &'a [&'a (dyn ToSql + Sync)],
) -> impl ExactSizeIterator<Item = &'a dyn ToSql> + 'a {
s.iter().map(|s| *s as _)
}

View File

@ -2,8 +2,8 @@ use crate::client::InnerClient;
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::error::SqlState;
use crate::query;
use crate::types::{Field, Kind, Oid, ToSql, Type};
use crate::types::{Field, Kind, Oid, Type};
use crate::{query, slice_iter};
use crate::{Column, Error, Statement};
use fallible_iterator::FallibleIterator;
use futures::{future, TryStreamExt};
@ -132,8 +132,8 @@ async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
let stmt = typeinfo_statement(client).await?;
let buf = query::encode(&stmt, (&[&oid as &dyn ToSql]).iter().cloned());
let rows = query::query(client, &stmt, buf);
let params = &[&oid as _];
let rows = query::query(client, stmt, slice_iter(params));
pin_mut!(rows);
let row = match rows.try_next().await? {
@ -203,8 +203,7 @@ async fn typeinfo_statement(client: &Arc<InnerClient>) -> Result<Statement, Erro
async fn get_enum_variants(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<String>, Error> {
let stmt = typeinfo_enum_statement(client).await?;
let buf = query::encode(&stmt, (&[&oid as &dyn ToSql]).iter().cloned());
query::query(client, &stmt, buf)
query::query(client, stmt, slice_iter(&[&oid]))
.and_then(|row| future::ready(row.try_get(0)))
.try_collect()
.await
@ -230,8 +229,7 @@ async fn typeinfo_enum_statement(client: &Arc<InnerClient>) -> Result<Statement,
async fn get_composite_fields(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<Field>, Error> {
let stmt = typeinfo_composite_statement(client).await?;
let buf = query::encode(&stmt, (&[&oid as &dyn ToSql]).iter().cloned());
let rows = query::query(client, &stmt, buf)
let rows = query::query(client, stmt, slice_iter(&[&oid]))
.try_collect::<Vec<_>>()
.await?;

View File

@ -7,26 +7,33 @@ use futures::{ready, Stream, TryFutureExt};
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
pub fn query<'a>(
client: &'a Arc<InnerClient>,
statement: &'a Statement,
buf: Result<Vec<u8>, Error>,
) -> impl Stream<Item = Result<Row, Error>> + 'a {
pub fn query<'a, I>(
client: &'a InnerClient,
statement: Statement,
params: I,
) -> impl Stream<Item = Result<Row, Error>> + 'a
where
I: IntoIterator<Item = &'a dyn ToSql> + 'a,
I::IntoIter: ExactSizeIterator,
{
let f = async move {
let buf = encode(&statement, params)?;
let responses = start(client, buf).await?;
Ok(Query { statement: statement.clone(), responses })
Ok(Query {
statement,
responses,
})
};
f.try_flatten_stream()
}
pub fn query_portal(
client: Arc<InnerClient>,
portal: Portal,
pub fn query_portal<'a>(
client: &'a InnerClient,
portal: &'a Portal,
max_rows: i32,
) -> impl Stream<Item = Result<Row, Error>> {
) -> impl Stream<Item = Result<Row, Error>> + 'a {
let start = async move {
let mut buf = vec![];
frontend::execute(portal.name(), max_rows, &mut buf).map_err(Error::encode)?;
@ -43,7 +50,16 @@ pub fn query_portal(
start.try_flatten_stream()
}
pub async fn execute(client: &InnerClient, buf: Result<Vec<u8>, Error>) -> Result<u64, Error> {
pub async fn execute<'a, I>(
client: &InnerClient,
statement: Statement,
params: I,
) -> Result<u64, Error>
where
I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
let buf = encode(&statement, params)?;
let mut responses = start(client, buf).await?;
loop {
@ -66,8 +82,7 @@ pub async fn execute(client: &InnerClient, buf: Result<Vec<u8>, Error>) -> Resul
}
}
async fn start(client: &InnerClient, buf: Result<Vec<u8>, Error>) -> Result<Responses, Error> {
let buf = buf?;
async fn start(client: &InnerClient, buf: Vec<u8>) -> Result<Responses, Error> {
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {

View File

@ -6,19 +6,16 @@ use fallible_iterator::FallibleIterator;
use futures::{ready, Stream, TryFutureExt};
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
pub fn simple_query(
client: Arc<InnerClient>,
query: &str,
) -> impl Stream<Item = Result<SimpleQueryMessage, Error>> {
let buf = encode(query);
let start = async move {
let buf = buf?;
pub fn simple_query<'a>(
client: &'a InnerClient,
query: &'a str,
) -> impl Stream<Item = Result<SimpleQueryMessage, Error>> + 'a {
let f = async move {
let buf = encode(query)?;
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
Ok(SimpleQuery {
@ -26,18 +23,11 @@ pub fn simple_query(
columns: None,
})
};
start.try_flatten_stream()
f.try_flatten_stream()
}
pub fn batch_execute(
client: Arc<InnerClient>,
query: &str,
) -> impl Future<Output = Result<(), Error>> {
let buf = encode(query);
async move {
let buf = buf?;
pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Error> {
let buf = encode(query)?;
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
loop {
@ -50,7 +40,6 @@ pub fn batch_execute(
_ => return Err(Error::unexpected_message()),
}
}
}
}
fn encode(query: &str) -> Result<Vec<u8>, Error> {

View File

@ -0,0 +1,49 @@
use crate::to_statement::private::{Sealed, ToStatementType};
use crate::Statement;
mod private {
use crate::{Client, Error, Statement};
pub trait Sealed {}
pub enum ToStatementType<'a> {
Statement(&'a Statement),
Query(&'a str),
}
impl<'a> ToStatementType<'a> {
pub async fn into_statement(self, client: &Client) -> Result<Statement, Error> {
match self {
ToStatementType::Statement(s) => Ok(s.clone()),
ToStatementType::Query(s) => client.prepare(s).await,
}
}
}
}
/// A trait abstracting over prepared and unprepared statements.
///
/// Many methods are generic over this bound, so that they support both a raw query string as well as a statement which
/// was prepared previously.
///
/// This trait is "sealed" and cannot be implemented by anything outside this crate.
pub trait ToStatement: private::Sealed {
#[doc(hidden)]
fn __convert(&self) -> ToStatementType<'_>;
}
impl ToStatement for Statement {
fn __convert(&self) -> ToStatementType<'_> {
ToStatementType::Statement(self)
}
}
impl Sealed for Statement {}
impl ToStatement for str {
fn __convert(&self) -> ToStatementType<'_> {
ToStatementType::Query(self)
}
}
impl Sealed for str {}

View File

@ -6,12 +6,13 @@ use crate::tls::TlsConnect;
use crate::types::{ToSql, Type};
#[cfg(feature = "runtime")]
use crate::Socket;
use crate::{bind, query, Client, Error, Portal, Row, SimpleQueryMessage, Statement};
use crate::{
bind, query, slice_iter, Client, Error, Portal, Row, SimpleQueryMessage, Statement, ToStatement,
};
use bytes::{Bytes, IntoBuf};
use futures::{Stream, TryStream};
use postgres_protocol::message::frontend;
use std::error;
use std::future::Future;
use tokio::io::{AsyncRead, AsyncWrite};
/// A representation of a PostgreSQL database transaction.
@ -92,21 +93,25 @@ impl<'a> Transaction<'a> {
}
/// Like `Client::query`.
pub fn query<'b>(
pub fn query<'b, T>(
&'b self,
statement: &'b Statement,
params: &'b [&'b (dyn ToSql + Sync)],
) -> impl Stream<Item = Result<Row, Error>> + 'b {
statement: &'b T,
params: &'b [&(dyn ToSql + Sync)],
) -> impl Stream<Item = Result<Row, Error>> + 'b
where
T: ?Sized + ToStatement,
{
self.client.query(statement, params)
}
/// Like `Client::query_iter`.
pub fn query_iter<'b, I>(
pub fn query_iter<'b, T, I>(
&'b self,
statement: &'b Statement,
statement: &'b T,
params: I,
) -> impl Stream<Item = Result<Row, Error>> + 'b
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'b dyn ToSql> + 'b,
I::IntoIter: ExactSizeIterator,
{
@ -114,21 +119,25 @@ impl<'a> Transaction<'a> {
}
/// Like `Client::execute`.
pub async fn execute(
pub async fn execute<T>(
&self,
statement: &Statement,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<u64, Error> {
) -> Result<u64, Error>
where
T: ?Sized + ToStatement,
{
self.client.execute(statement, params).await
}
/// Like `Client::execute_iter`.
pub async fn execute_iter<'b, I>(
pub async fn execute_iter<'b, I, T>(
&self,
statement: &Statement,
params: I,
) -> Result<u64, Error>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'b dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
@ -143,102 +152,100 @@ impl<'a> Transaction<'a> {
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub fn bind(
pub async fn bind<T>(
&self,
statement: &Statement,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> impl Future<Output = Result<Portal, Error>> {
// https://github.com/rust-lang/rust/issues/63032
let buf = bind::encode(statement, params.iter().map(|s| *s as _));
bind::bind(self.client.inner(), statement.clone(), buf)
) -> Result<Portal, Error>
where
T: ?Sized + ToStatement,
{
self.bind_iter(statement, slice_iter(params)).await
}
/// Like [`bind`], but takes an iterator of parameters rather than a slice.
///
/// [`bind`]: #method.bind
pub fn bind_iter<'b, I>(
&self,
statement: &Statement,
params: I,
) -> impl Future<Output = Result<Portal, Error>>
pub async fn bind_iter<'b, T, I>(&self, statement: &T, params: I) -> Result<Portal, Error>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'b dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
let buf = bind::encode(statement, params);
bind::bind(self.client.inner(), statement.clone(), buf)
let statement = statement.__convert().into_statement(&self.client).await?;
bind::bind(self.client.inner(), statement, params).await
}
/// Continues execution of a portal, returning a stream of the resulting rows.
///
/// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
/// `query_portal`. If the requested number is negative or 0, all rows will be returned.
pub fn query_portal(
&self,
portal: &Portal,
pub fn query_portal<'b>(
&'b self,
portal: &'b Portal,
max_rows: i32,
) -> impl Stream<Item = Result<Row, Error>> {
query::query_portal(self.client.inner(), portal.clone(), max_rows)
) -> impl Stream<Item = Result<Row, Error>> + 'b {
query::query_portal(self.client.inner(), portal, max_rows)
}
/// Like `Client::copy_in`.
pub fn copy_in<S>(
pub async fn copy_in<T, S>(
&self,
statement: &Statement,
statement: &T,
params: &[&(dyn ToSql + Sync)],
stream: S,
) -> impl Future<Output = Result<u64, Error>>
) -> Result<u64, Error>
where
T: ?Sized + ToStatement,
S: TryStream,
S::Ok: IntoBuf,
<S::Ok as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn error::Error + Sync + Send>>,
{
self.client.copy_in(statement, params, stream)
self.client.copy_in(statement, params, stream).await
}
/// Like `Client::copy_out`.
pub fn copy_out(
&self,
statement: &Statement,
params: &[&(dyn ToSql + Sync)],
) -> impl Stream<Item = Result<Bytes, Error>> {
pub fn copy_out<'b, T>(
&'b self,
statement: &'b T,
params: &'b [&(dyn ToSql + Sync)],
) -> impl Stream<Item = Result<Bytes, Error>> + 'b
where
T: ?Sized + ToStatement,
{
self.client.copy_out(statement, params)
}
/// Like `Client::simple_query`.
pub fn simple_query(
&self,
query: &str,
) -> impl Stream<Item = Result<SimpleQueryMessage, Error>> {
pub fn simple_query<'b>(
&'b self,
query: &'b str,
) -> impl Stream<Item = Result<SimpleQueryMessage, Error>> + 'b {
self.client.simple_query(query)
}
/// Like `Client::batch_execute`.
pub fn batch_execute(&self, query: &str) -> impl Future<Output = Result<(), Error>> {
self.client.batch_execute(query)
pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
self.client.batch_execute(query).await
}
/// Like `Client::cancel_query`.
#[cfg(feature = "runtime")]
pub fn cancel_query<T>(&self, tls: T) -> impl Future<Output = Result<(), Error>>
pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
where
T: MakeTlsConnect<Socket>,
{
self.client.cancel_query(tls)
self.client.cancel_query(tls).await
}
/// Like `Client::cancel_query_raw`.
pub fn cancel_query_raw<S, T>(
&self,
stream: S,
tls: T,
) -> impl Future<Output = Result<(), Error>>
pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
self.client.cancel_query_raw(stream, tls)
self.client.cancel_query_raw(stream, tls).await
}
/// Like `Client::transaction`.

View File

@ -98,7 +98,7 @@ async fn scram_password_ok() {
#[tokio::test]
async fn pipelined_prepare() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
let prepare1 = client.prepare("SELECT $1::HSTORE[]");
let prepare2 = client.prepare("SELECT $1::BIGINT");
@ -114,7 +114,7 @@ async fn pipelined_prepare() {
#[tokio::test]
async fn insert_select() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
client
.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name TEXT)")
@ -138,7 +138,7 @@ async fn insert_select() {
#[tokio::test]
async fn custom_enum() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
client
.batch_execute(
@ -167,7 +167,7 @@ async fn custom_enum() {
#[tokio::test]
async fn custom_domain() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
client
.batch_execute("CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16)")
@ -183,7 +183,7 @@ async fn custom_domain() {
#[tokio::test]
async fn custom_array() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
let select = client.prepare("SELECT $1::HSTORE[]").await.unwrap();
@ -200,7 +200,7 @@ async fn custom_array() {
#[tokio::test]
async fn custom_composite() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
client
.batch_execute(
@ -232,7 +232,7 @@ async fn custom_composite() {
#[tokio::test]
async fn custom_range() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
client
.batch_execute(
@ -253,7 +253,7 @@ async fn custom_range() {
#[tokio::test]
async fn simple_query() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
let messages = client
.simple_query(
@ -299,7 +299,7 @@ async fn simple_query() {
#[tokio::test]
async fn cancel_query_raw() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
let socket = TcpStream::connect("127.0.0.1:5433").await.unwrap();
let cancel = client.cancel_query_raw(socket, NoTls);
@ -327,7 +327,7 @@ async fn transaction_commit() {
.await
.unwrap();
let mut transaction = client.transaction().await.unwrap();
let transaction = client.transaction().await.unwrap();
transaction
.batch_execute("INSERT INTO foo (name) VALUES ('steven')")
.await
@ -359,7 +359,7 @@ async fn transaction_rollback() {
.await
.unwrap();
let mut transaction = client.transaction().await.unwrap();
let transaction = client.transaction().await.unwrap();
transaction
.batch_execute("INSERT INTO foo (name) VALUES ('steven')")
.await
@ -390,7 +390,7 @@ async fn transaction_rollback_drop() {
.await
.unwrap();
let mut transaction = client.transaction().await.unwrap();
let transaction = client.transaction().await.unwrap();
transaction
.batch_execute("INSERT INTO foo (name) VALUES ('steven')")
.await
@ -409,7 +409,7 @@ async fn transaction_rollback_drop() {
#[tokio::test]
async fn copy_in() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
client
.batch_execute(
@ -449,7 +449,7 @@ async fn copy_in() {
#[tokio::test]
async fn copy_in_large() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
client
.batch_execute(
@ -480,7 +480,7 @@ async fn copy_in_large() {
#[tokio::test]
async fn copy_in_error() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
client
.batch_execute(
@ -511,7 +511,7 @@ async fn copy_in_error() {
#[tokio::test]
async fn copy_out() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
client
.batch_execute(
@ -532,7 +532,7 @@ async fn copy_out() {
#[tokio::test]
async fn notifications() {
let (mut client, mut connection) = connect_raw("user=postgres").await.unwrap();
let (client, mut connection) = connect_raw("user=postgres").await.unwrap();
let (tx, rx) = mpsc::unbounded();
let stream = stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!(e));
@ -585,7 +585,7 @@ async fn query_portal() {
.await
.unwrap();
let mut transaction = client.transaction().await.unwrap();
let transaction = client.transaction().await.unwrap();
let portal = transaction.bind(&stmt, &[]).await.unwrap();
let f1 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
@ -624,3 +624,36 @@ async fn prefer_channel_binding() {
async fn disable_channel_binding() {
connect("user=postgres channel_binding=disable").await;
}
#[tokio::test]
async fn check_send() {
fn is_send<T: Send>(_: &T) {}
let f = connect("user=postgres");
is_send(&f);
let mut client = f.await;
let f = client.prepare("SELECT $1::TEXT");
is_send(&f);
let stmt = f.await.unwrap();
let f = client.query(&stmt, &[&"hello"]);
is_send(&f);
drop(f);
let f = client.execute(&stmt, &[&"hello"]);
is_send(&f);
drop(f);
let f = client.transaction();
is_send(&f);
let trans = f.await.unwrap();
let f = trans.query(&stmt, &[&"hello"]);
is_send(&f);
drop(f);
let f = trans.execute(&stmt, &[&"hello"]);
is_send(&f);
drop(f);
}

View File

@ -13,7 +13,7 @@ async fn connect(s: &str) -> Client {
}
async fn smoke_test(s: &str) {
let mut client = connect(s).await;
let client = connect(s).await;
let stmt = client.prepare("SELECT $1::INT").await.unwrap();
let rows = client
@ -72,7 +72,7 @@ async fn target_session_attrs_err() {
#[tokio::test]
async fn cancel_query() {
let mut client = connect("host=localhost port=5433 user=postgres").await;
let client = connect("host=localhost port=5433 user=postgres").await;
let cancel = client.cancel_query(NoTls);
let cancel = timer::delay(Instant::now() + Duration::from_millis(100)).then(|()| cancel);

View File

@ -1,4 +1,5 @@
use futures::TryStreamExt;
use postgres_types::to_sql_checked;
use std::collections::HashMap;
use std::error::Error;
use std::f32;
@ -7,7 +8,6 @@ use std::fmt;
use std::net::IpAddr;
use std::result;
use std::time::{Duration, UNIX_EPOCH};
use postgres_types::to_sql_checked;
use tokio_postgres::types::{FromSql, FromSqlOwned, IsNull, Kind, ToSql, Type, WrongType};
use crate::connect;
@ -30,27 +30,19 @@ where
T: PartialEq + for<'a> FromSqlOwned + ToSql + Sync,
S: fmt::Display,
{
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
for (val, repr) in checks {
let stmt = client
.prepare(&format!("SELECT {}::{}", repr, sql_type))
.await
.unwrap();
let rows = client
.query(&stmt, &[])
.query(&*format!("SELECT {}::{}", repr, sql_type), &[])
.try_collect::<Vec<_>>()
.await
.unwrap();
let result = rows[0].get(0);
assert_eq!(val, &result);
let stmt = client
.prepare(&format!("SELECT $1::{}", sql_type))
.await
.unwrap();
let rows = client
.query(&stmt, &[&val])
.query(&*format!("SELECT $1::{}", sql_type), &[&val])
.try_collect::<Vec<_>>()
.await
.unwrap();
@ -203,7 +195,7 @@ async fn test_text_params() {
#[tokio::test]
async fn test_borrowed_text() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
let stmt = client.prepare("SELECT 'foo'").await.unwrap();
let rows = client
@ -217,7 +209,7 @@ async fn test_borrowed_text() {
#[tokio::test]
async fn test_bpchar_params() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
client
.batch_execute(
@ -257,7 +249,7 @@ async fn test_bpchar_params() {
#[tokio::test]
async fn test_citext_params() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
client
.batch_execute(
@ -306,7 +298,7 @@ async fn test_bytea_params() {
#[tokio::test]
async fn test_borrowed_bytea() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
let stmt = client.prepare("SELECT 'foo'::BYTEA").await.unwrap();
let rows = client
.query(&stmt, &[])
@ -365,7 +357,7 @@ async fn test_nan_param<T>(sql_type: &str)
where
T: PartialEq + ToSql + FromSqlOwned,
{
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
let stmt = client
.prepare(&format!("SELECT 'NaN'::{}", sql_type))
@ -392,7 +384,7 @@ async fn test_f64_nan_param() {
#[tokio::test]
async fn test_pg_database_datname() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
let stmt = client
.prepare("SELECT datname FROM pg_database")
.await
@ -407,7 +399,7 @@ async fn test_pg_database_datname() {
#[tokio::test]
async fn test_slice() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
client
.batch_execute(
@ -436,7 +428,7 @@ async fn test_slice() {
#[tokio::test]
async fn test_slice_wrong_type() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
client
.batch_execute(
@ -465,7 +457,7 @@ async fn test_slice_wrong_type() {
#[tokio::test]
async fn test_slice_range() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
let stmt = client.prepare("SELECT $1::INT8RANGE").await.unwrap();
let err = client
@ -520,7 +512,7 @@ async fn domain() {
}
}
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
client
.batch_execute(
@ -551,7 +543,7 @@ async fn domain() {
#[tokio::test]
async fn composite() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
client
.batch_execute(
@ -582,7 +574,7 @@ async fn composite() {
#[tokio::test]
async fn enum_() {
let mut client = connect("user=postgres").await;
let client = connect("user=postgres").await;
client
.batch_execute("CREATE TYPE pg_temp.mood AS ENUM ('sad', 'ok', 'happy')")
@ -656,36 +648,3 @@ async fn inet() {
)
.await;
}
#[tokio::test]
async fn check_send() {
fn is_send<T: Send>(_: &T) {}
let f = connect("user=postgres");
is_send(&f);
let mut client = f.await;
let f = client.prepare("SELECT $1::TEXT");
is_send(&f);
let stmt = f.await.unwrap();
let f = client.query(&stmt, &[&"hello"]);
is_send(&f);
drop(f);
let f = client.execute(&stmt, &[&"hello"]);
is_send(&f);
drop(f);
let f = client.transaction();
is_send(&f);
let mut trans = f.await.unwrap();
let f = trans.query(&stmt, &[&"hello"]);
is_send(&f);
drop(f);
let f = trans.execute(&stmt, &[&"hello"]);
is_send(&f);
drop(f);
}