Add a ToStatement trait in tokio-postgres

This commit is contained in:
Steven Fackler 2019-09-28 14:42:00 -04:00
parent 286ecdb5b9
commit 0d2d554122
19 changed files with 410 additions and 407 deletions

View File

@ -38,7 +38,7 @@
//! .build()?; //! .build()?;
//! let connector = MakeTlsConnector::new(connector); //! let connector = MakeTlsConnector::new(connector);
//! //!
//! let mut client = postgres::Client::connect( //! let client = postgres::Client::connect(
//! "host=localhost user=postgres sslmode=require", //! "host=localhost user=postgres sslmode=require",
//! connector, //! connector,
//! )?; //! )?;

View File

@ -30,7 +30,7 @@
//! builder.set_ca_file("database_cert.pem")?; //! builder.set_ca_file("database_cert.pem")?;
//! let connector = MakeTlsConnector::new(builder.build()); //! let connector = MakeTlsConnector::new(builder.build());
//! //!
//! let mut client = postgres::Client::connect( //! let client = postgres::Client::connect(
//! "host=localhost user=postgres sslmode=require", //! "host=localhost user=postgres sslmode=require",
//! connector, //! connector,
//! )?; //! )?;

View File

@ -84,8 +84,7 @@ impl Client {
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
let statement = query.__statement(self)?; executor::block_on(self.0.execute(query, params))
executor::block_on(self.0.execute(&statement, params))
} }
/// Executes a statement, returning the resulting rows. /// Executes a statement, returning the resulting rows.
@ -154,14 +153,13 @@ impl Client {
/// ``` /// ```
pub fn query_iter<'a, T>( pub fn query_iter<'a, T>(
&'a mut self, &'a mut self,
query: &T, query: &'a T,
params: &[&(dyn ToSql + Sync)], params: &'a [&(dyn ToSql + Sync)],
) -> Result<impl FallibleIterator<Item = Row, Error = Error> + 'a, Error> ) -> Result<impl FallibleIterator<Item = Row, Error = Error> + 'a, Error>
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
let statement = query.__statement(self)?; Ok(Iter::new(self.0.query(query, params)))
Ok(Iter::new(self.0.query(&statement, params)))
} }
/// Creates a new prepared statement. /// Creates a new prepared statement.
@ -249,8 +247,7 @@ impl Client {
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
R: Read + Unpin, R: Read + Unpin,
{ {
let statement = query.__statement(self)?; executor::block_on(self.0.copy_in(query, params, CopyInStream(reader)))
executor::block_on(self.0.copy_in(&statement, params, CopyInStream(reader)))
} }
/// Executes a `COPY TO STDOUT` statement, returning a reader of the resulting data. /// Executes a `COPY TO STDOUT` statement, returning a reader of the resulting data.
@ -274,14 +271,13 @@ impl Client {
/// ``` /// ```
pub fn copy_out<'a, T>( pub fn copy_out<'a, T>(
&'a mut self, &'a mut self,
query: &T, query: &'a T,
params: &[&(dyn ToSql + Sync)], params: &'a [&(dyn ToSql + Sync)],
) -> Result<impl BufRead + 'a, Error> ) -> Result<impl BufRead + 'a, Error>
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
let statement = query.__statement(self)?; let stream = self.0.copy_out(query, params);
let stream = self.0.copy_out(&statement, params);
CopyOutReader::new(stream) CopyOutReader::new(stream)
} }
@ -314,7 +310,7 @@ impl Client {
/// them to this method! /// them to this method!
pub fn simple_query_iter<'a>( pub fn simple_query_iter<'a>(
&'a mut self, &'a mut self,
query: &str, query: &'a str,
) -> Result<impl FallibleIterator<Item = SimpleQueryMessage, Error = Error> + 'a, Error> { ) -> Result<impl FallibleIterator<Item = SimpleQueryMessage, Error = Error> + 'a, Error> {
Ok(Iter::new(self.0.simple_query(query))) Ok(Iter::new(self.0.simple_query(query)))
} }

View File

@ -62,7 +62,9 @@ use tokio::runtime::{self, Runtime};
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
pub use tokio_postgres::Socket; pub use tokio_postgres::Socket;
pub use tokio_postgres::{error, row, tls, types, Column, Portal, SimpleQueryMessage, Statement}; pub use tokio_postgres::{
error, row, tls, types, Column, Portal, SimpleQueryMessage, Statement, ToStatement,
};
pub use crate::client::*; pub use crate::client::*;
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
@ -73,7 +75,6 @@ pub use crate::error::Error;
pub use crate::row::{Row, SimpleQueryRow}; pub use crate::row::{Row, SimpleQueryRow};
#[doc(no_inline)] #[doc(no_inline)]
pub use crate::tls::NoTls; pub use crate::tls::NoTls;
pub use crate::to_statement::*;
pub use crate::transaction::*; pub use crate::transaction::*;
mod client; mod client;
@ -82,7 +83,6 @@ pub mod config;
mod copy_in_stream; mod copy_in_stream;
mod copy_out_reader; mod copy_out_reader;
mod iter; mod iter;
mod to_statement;
mod transaction; mod transaction;
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]

View File

@ -1,59 +0,0 @@
use tokio_postgres::Error;
use crate::{Client, Statement, Transaction};
mod sealed {
pub trait Sealed {}
}
#[doc(hidden)]
pub trait Prepare {
fn prepare(&mut self, query: &str) -> Result<Statement, Error>;
}
impl Prepare for Client {
fn prepare(&mut self, query: &str) -> Result<Statement, Error> {
self.prepare(query)
}
}
impl<'a> Prepare for Transaction<'a> {
fn prepare(&mut self, query: &str) -> Result<Statement, Error> {
self.prepare(query)
}
}
/// A trait abstracting over prepared and unprepared statements.
///
/// Many methods are generic over this bound, so that they support both a raw query string as well as a statement which
/// was prepared previously.
///
/// This trait is "sealed" and cannot be implemented by anything outside this crate.
pub trait ToStatement: sealed::Sealed {
#[doc(hidden)]
fn __statement<T>(&self, client: &mut T) -> Result<Statement, Error>
where
T: Prepare;
}
impl sealed::Sealed for str {}
impl ToStatement for str {
fn __statement<T>(&self, client: &mut T) -> Result<Statement, Error>
where
T: Prepare,
{
client.prepare(self)
}
}
impl sealed::Sealed for Statement {}
impl ToStatement for Statement {
fn __statement<T>(&self, _: &mut T) -> Result<Statement, Error>
where
T: Prepare,
{
Ok(self.clone())
}
}

View File

@ -47,8 +47,7 @@ impl<'a> Transaction<'a> {
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
let statement = query.__statement(self)?; executor::block_on(self.0.execute(query, params))
executor::block_on(self.0.execute(&statement, params))
} }
/// Like `Client::query`. /// Like `Client::query`.
@ -60,16 +59,15 @@ impl<'a> Transaction<'a> {
} }
/// Like `Client::query_iter`. /// Like `Client::query_iter`.
pub fn query_iter<T>( pub fn query_iter<'b, T>(
&mut self, &'b mut self,
query: &T, query: &'b T,
params: &[&(dyn ToSql + Sync)], params: &'b [&(dyn ToSql + Sync)],
) -> Result<impl FallibleIterator<Item = Row, Error = Error>, Error> ) -> Result<impl FallibleIterator<Item = Row, Error = Error> + 'b, Error>
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
let statement = query.__statement(self)?; Ok(Iter::new(self.0.query(query, params)))
Ok(Iter::new(self.0.query(&statement, params)))
} }
/// Binds parameters to a statement, creating a "portal". /// Binds parameters to a statement, creating a "portal".
@ -86,8 +84,7 @@ impl<'a> Transaction<'a> {
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
let statement = query.__statement(self)?; executor::block_on(self.0.bind(query, params))
executor::block_on(self.0.bind(&statement, params))
} }
/// Continues execution of a portal, returning the next set of rows. /// Continues execution of a portal, returning the next set of rows.
@ -100,11 +97,11 @@ impl<'a> Transaction<'a> {
/// Like `query_portal`, except that it returns a fallible iterator over the resulting rows rather than buffering /// Like `query_portal`, except that it returns a fallible iterator over the resulting rows rather than buffering
/// the entire response in memory. /// the entire response in memory.
pub fn query_portal_iter( pub fn query_portal_iter<'b>(
&mut self, &'b mut self,
portal: &Portal, portal: &'b Portal,
max_rows: i32, max_rows: i32,
) -> Result<impl FallibleIterator<Item = Row, Error = Error>, Error> { ) -> Result<impl FallibleIterator<Item = Row, Error = Error> + 'b, Error> {
Ok(Iter::new(self.0.query_portal(&portal, max_rows))) Ok(Iter::new(self.0.query_portal(&portal, max_rows)))
} }
@ -119,21 +116,19 @@ impl<'a> Transaction<'a> {
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
R: Read + Unpin, R: Read + Unpin,
{ {
let statement = query.__statement(self)?; executor::block_on(self.0.copy_in(query, params, CopyInStream(reader)))
executor::block_on(self.0.copy_in(&statement, params, CopyInStream(reader)))
} }
/// Like `Client::copy_out`. /// Like `Client::copy_out`.
pub fn copy_out<'b, T>( pub fn copy_out<'b, T>(
&'a mut self, &'b mut self,
query: &T, query: &'b T,
params: &[&(dyn ToSql + Sync)], params: &'b [&(dyn ToSql + Sync)],
) -> Result<impl BufRead + 'b, Error> ) -> Result<impl BufRead + 'b, Error>
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
let statement = query.__statement(self)?; let stream = self.0.copy_out(query, params);
let stream = self.0.copy_out(&statement, params);
CopyOutReader::new(stream) CopyOutReader::new(stream)
} }
@ -145,7 +140,7 @@ impl<'a> Transaction<'a> {
/// Like `Client::simple_query_iter`. /// Like `Client::simple_query_iter`.
pub fn simple_query_iter<'b>( pub fn simple_query_iter<'b>(
&'b mut self, &'b mut self,
query: &str, query: &'b str,
) -> Result<impl FallibleIterator<Item = SimpleQueryMessage, Error = Error> + 'b, Error> { ) -> Result<impl FallibleIterator<Item = SimpleQueryMessage, Error = Error> + 'b, Error> {
Ok(Iter::new(self.0.simple_query(query))) Ok(Iter::new(self.0.simple_query(query)))
} }

View File

@ -10,36 +10,25 @@ use std::sync::Arc;
static NEXT_ID: AtomicUsize = AtomicUsize::new(0); static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
pub async fn bind( pub async fn bind<'a, I>(
client: Arc<InnerClient>, client: &Arc<InnerClient>,
statement: Statement, statement: Statement,
bind: Result<PendingBind, Error>, params: I,
) -> Result<Portal, Error> { ) -> Result<Portal, Error>
let bind = bind?; where
I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
let name = format!("p{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
let mut buf = query::encode_bind(&statement, params, &name)?;
frontend::sync(&mut buf);
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(bind.buf)))?; let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? { match responses.next().await? {
Message::BindComplete => {} Message::BindComplete => {}
_ => return Err(Error::unexpected_message()), _ => return Err(Error::unexpected_message()),
} }
Ok(Portal::new(&client, bind.name, statement)) Ok(Portal::new(client, name, statement))
}
pub struct PendingBind {
buf: Vec<u8>,
name: String,
}
pub fn encode<'a, I>(statement: &Statement, params: I) -> Result<PendingBind, Error>
where
I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
let name = format!("p{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
let mut buf = query::encode_bind(statement, params, &name)?;
frontend::sync(&mut buf);
Ok(PendingBind { buf, name })
} }

View File

@ -3,9 +3,11 @@ use crate::cancel_query;
use crate::codec::BackendMessages; use crate::codec::BackendMessages;
use crate::config::{Host, SslMode}; use crate::config::{Host, SslMode};
use crate::connection::{Request, RequestMessages}; use crate::connection::{Request, RequestMessages};
use crate::slice_iter;
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
use crate::tls::MakeTlsConnect; use crate::tls::MakeTlsConnect;
use crate::tls::TlsConnect; use crate::tls::TlsConnect;
use crate::to_statement::ToStatement;
use crate::types::{Oid, ToSql, Type}; use crate::types::{Oid, ToSql, Type};
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
use crate::Socket; use crate::Socket;
@ -16,13 +18,12 @@ use crate::{Error, Statement};
use bytes::{Bytes, IntoBuf}; use bytes::{Bytes, IntoBuf};
use fallible_iterator::FallibleIterator; use fallible_iterator::FallibleIterator;
use futures::channel::mpsc; use futures::channel::mpsc;
use futures::{future, Stream, TryStream}; use futures::{future, Stream, TryFutureExt, TryStream};
use futures::{ready, StreamExt}; use futures::{ready, StreamExt};
use parking_lot::Mutex; use parking_lot::Mutex;
use postgres_protocol::message::backend::Message; use postgres_protocol::message::backend::Message;
use std::collections::HashMap; use std::collections::HashMap;
use std::error; use std::error;
use std::future::Future;
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use std::time::Duration; use std::time::Duration;
@ -160,8 +161,8 @@ impl Client {
} }
} }
pub(crate) fn inner(&self) -> Arc<InnerClient> { pub(crate) fn inner(&self) -> &Arc<InnerClient> {
self.inner.clone() &self.inner
} }
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
@ -194,29 +195,35 @@ impl Client {
/// # Panics /// # Panics
/// ///
/// Panics if the number of parameters provided does not match the number expected. /// Panics if the number of parameters provided does not match the number expected.
pub fn query<'a>( pub fn query<'a, T>(
&'a self, &'a self,
statement: &'a Statement, statement: &'a T,
params: &'a [&'a (dyn ToSql + Sync)], params: &'a [&(dyn ToSql + Sync)],
) -> impl Stream<Item = Result<Row, Error>> + 'a { ) -> impl Stream<Item = Result<Row, Error>> + 'a
let buf = query::encode(statement, params.iter().map(|s| *s as _)); where
query::query(&self.inner, statement, buf) T: ?Sized + ToStatement,
{
self.query_iter(statement, slice_iter(params))
} }
/// Like [`query`], but takes an iterator of parameters rather than a slice. /// Like [`query`], but takes an iterator of parameters rather than a slice.
/// ///
/// [`query`]: #method.query /// [`query`]: #method.query
pub fn query_iter<'a, I>( pub fn query_iter<'a, T, I>(
&'a self, &'a self,
statement: &'a Statement, statement: &'a T,
params: I, params: I,
) -> impl Stream<Item = Result<Row, Error>> + 'a ) -> impl Stream<Item = Result<Row, Error>> + 'a
where where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'a dyn ToSql> + 'a, I: IntoIterator<Item = &'a dyn ToSql> + 'a,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,
{ {
let buf = query::encode(statement, params); let f = async move {
query::query(&self.inner, statement, buf) let statement = statement.__convert().into_statement(self).await?;
Ok(query::query(&self.inner, statement, params))
};
f.try_flatten_stream()
} }
/// Executes a statement, returning the number of rows modified. /// Executes a statement, returning the number of rows modified.
@ -226,29 +233,28 @@ impl Client {
/// # Panics /// # Panics
/// ///
/// Panics if the number of parameters provided does not match the number expected. /// Panics if the number of parameters provided does not match the number expected.
pub async fn execute( pub async fn execute<T>(
&self, &self,
statement: &Statement, statement: &T,
params: &[&(dyn ToSql + Sync)], params: &[&(dyn ToSql + Sync)],
) -> Result<u64, Error> { ) -> Result<u64, Error>
let buf = query::encode(statement, params.iter().map(|s| *s as _)); where
query::execute(&self.inner, buf).await T: ?Sized + ToStatement,
{
self.execute_iter(statement, slice_iter(params)).await
} }
/// Like [`execute`], but takes an iterator of parameters rather than a slice. /// Like [`execute`], but takes an iterator of parameters rather than a slice.
/// ///
/// [`execute`]: #method.execute /// [`execute`]: #method.execute
pub async fn execute_iter<'a, I>( pub async fn execute_iter<'a, T, I>(&self, statement: &T, params: I) -> Result<u64, Error>
&self,
statement: &Statement,
params: I,
) -> Result<u64, Error>
where where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'a dyn ToSql>, I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,
{ {
let buf = query::encode(statement, params); let statement = statement.__convert().into_statement(self).await?;
query::execute(&self.inner, buf).await query::execute(self.inner(), statement, params).await
} }
/// Executes a `COPY FROM STDIN` statement, returning the number of rows created. /// Executes a `COPY FROM STDIN` statement, returning the number of rows created.
@ -259,20 +265,22 @@ impl Client {
/// # Panics /// # Panics
/// ///
/// Panics if the number of parameters provided does not match the number expected. /// Panics if the number of parameters provided does not match the number expected.
pub fn copy_in<S>( pub async fn copy_in<T, S>(
&self, &self,
statement: &Statement, statement: &T,
params: &[&(dyn ToSql + Sync)], params: &[&(dyn ToSql + Sync)],
stream: S, stream: S,
) -> impl Future<Output = Result<u64, Error>> ) -> Result<u64, Error>
where where
T: ?Sized + ToStatement,
S: TryStream, S: TryStream,
S::Ok: IntoBuf, S::Ok: IntoBuf,
<S::Ok as IntoBuf>::Buf: 'static + Send, <S::Ok as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn error::Error + Sync + Send>>, S::Error: Into<Box<dyn error::Error + Sync + Send>>,
{ {
let buf = query::encode(statement, params.iter().map(|s| *s as _)); let statement = statement.__convert().into_statement(self).await?;
copy_in::copy_in(self.inner(), buf, stream) let params = slice_iter(params);
copy_in::copy_in(self.inner(), statement, params, stream).await
} }
/// Executes a `COPY TO STDOUT` statement, returning a stream of the resulting data. /// Executes a `COPY TO STDOUT` statement, returning a stream of the resulting data.
@ -280,13 +288,20 @@ impl Client {
/// # Panics /// # Panics
/// ///
/// Panics if the number of parameters provided does not match the number expected. /// Panics if the number of parameters provided does not match the number expected.
pub fn copy_out( pub fn copy_out<'a, T>(
&self, &'a self,
statement: &Statement, statement: &'a T,
params: &[&(dyn ToSql + Sync)], params: &'a [&(dyn ToSql + Sync)],
) -> impl Stream<Item = Result<Bytes, Error>> { ) -> impl Stream<Item = Result<Bytes, Error>> + 'a
let buf = query::encode(statement, params.iter().map(|s| *s as _)); where
copy_out::copy_out(self.inner(), buf) T: ?Sized + ToStatement,
{
let f = async move {
let statement = statement.__convert().into_statement(self).await?;
let params = slice_iter(params);
Ok(copy_out::copy_out(self.inner(), statement, params))
};
f.try_flatten_stream()
} }
/// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows. /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
@ -302,10 +317,10 @@ impl Client {
/// Prepared statements should be use for any query which contains user-specified data, as they provided the /// Prepared statements should be use for any query which contains user-specified data, as they provided the
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method! /// them to this method!
pub fn simple_query( pub fn simple_query<'a>(
&self, &'a self,
query: &str, query: &'a str,
) -> impl Stream<Item = Result<SimpleQueryMessage, Error>> { ) -> impl Stream<Item = Result<SimpleQueryMessage, Error>> + 'a {
simple_query::simple_query(self.inner(), query) simple_query::simple_query(self.inner(), query)
} }
@ -319,8 +334,8 @@ impl Client {
/// Prepared statements should be use for any query which contains user-specified data, as they provided the /// Prepared statements should be use for any query which contains user-specified data, as they provided the
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method! /// them to this method!
pub fn batch_execute(&self, query: &str) -> impl Future<Output = Result<(), Error>> { pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
simple_query::batch_execute(self.inner(), query) simple_query::batch_execute(self.inner(), query).await
} }
/// Begins a new database transaction. /// Begins a new database transaction.
@ -338,7 +353,7 @@ impl Client {
/// ///
/// Requires the `runtime` Cargo feature (enabled by default). /// Requires the `runtime` Cargo feature (enabled by default).
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
pub fn cancel_query<T>(&self, tls: T) -> impl Future<Output = Result<(), Error>> pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
where where
T: MakeTlsConnect<Socket>, T: MakeTlsConnect<Socket>,
{ {
@ -349,15 +364,12 @@ impl Client {
self.process_id, self.process_id,
self.secret_key, self.secret_key,
) )
.await
} }
/// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new /// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
/// connection itself. /// connection itself.
pub fn cancel_query_raw<S, T>( pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
&self,
stream: S,
tls: T,
) -> impl Future<Output = Result<(), Error>>
where where
S: AsyncRead + AsyncWrite + Unpin, S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>, T: TlsConnect<S>,
@ -369,6 +381,7 @@ impl Client {
self.process_id, self.process_id,
self.secret_key, self.secret_key,
) )
.await
} }
/// Determines if the connection to the server has already closed. /// Determines if the connection to the server has already closed.

View File

@ -1,7 +1,8 @@
use crate::client::InnerClient; use crate::client::InnerClient;
use crate::codec::FrontendMessage; use crate::codec::FrontendMessage;
use crate::connection::RequestMessages; use crate::connection::RequestMessages;
use crate::Error; use crate::types::ToSql;
use crate::{query, Error, Statement};
use bytes::{Buf, BufMut, BytesMut, IntoBuf}; use bytes::{Buf, BufMut, BytesMut, IntoBuf};
use futures::channel::mpsc; use futures::channel::mpsc;
use futures::ready; use futures::ready;
@ -12,7 +13,6 @@ use postgres_protocol::message::frontend;
use postgres_protocol::message::frontend::CopyData; use postgres_protocol::message::frontend::CopyData;
use std::error; use std::error;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
enum CopyInMessage { enum CopyInMessage {
@ -62,18 +62,21 @@ impl Stream for CopyInReceiver {
} }
} }
pub async fn copy_in<S>( pub async fn copy_in<'a, I, S>(
client: Arc<InnerClient>, client: &InnerClient,
buf: Result<Vec<u8>, Error>, statement: Statement,
params: I,
stream: S, stream: S,
) -> Result<u64, Error> ) -> Result<u64, Error>
where where
I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator,
S: TryStream, S: TryStream,
S::Ok: IntoBuf, S::Ok: IntoBuf,
<S::Ok as IntoBuf>::Buf: 'static + Send, <S::Ok as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn error::Error + Sync + Send>>, S::Error: Into<Box<dyn error::Error + Sync + Send>>,
{ {
let buf = buf?; let buf = query::encode(&statement, params)?;
let (mut sender, receiver) = mpsc::channel(1); let (mut sender, receiver) = mpsc::channel(1);
let receiver = CopyInReceiver::new(receiver); let receiver = CopyInReceiver::new(receiver);

View File

@ -1,25 +1,32 @@
use crate::client::{InnerClient, Responses}; use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage; use crate::codec::FrontendMessage;
use crate::connection::RequestMessages; use crate::connection::RequestMessages;
use crate::Error; use crate::types::ToSql;
use crate::{query, Error, Statement};
use bytes::Bytes; use bytes::Bytes;
use futures::{ready, Stream, TryFutureExt}; use futures::{ready, Stream, TryFutureExt};
use postgres_protocol::message::backend::Message; use postgres_protocol::message::backend::Message;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
pub fn copy_out( pub fn copy_out<'a, I>(
client: Arc<InnerClient>, client: &'a InnerClient,
buf: Result<Vec<u8>, Error>, statement: Statement,
) -> impl Stream<Item = Result<Bytes, Error>> { params: I,
start(client, buf) ) -> impl Stream<Item = Result<Bytes, Error>> + 'a
.map_ok(|responses| CopyOut { responses }) where
.try_flatten_stream() I: IntoIterator<Item = &'a dyn ToSql> + 'a,
I::IntoIter: ExactSizeIterator,
{
let f = async move {
let buf = query::encode(&statement, params)?;
let responses = start(client, buf).await?;
Ok(CopyOut { responses })
};
f.try_flatten_stream()
} }
async fn start(client: Arc<InnerClient>, buf: Result<Vec<u8>, Error>) -> Result<Responses, Error> { async fn start(client: &InnerClient, buf: Vec<u8>) -> Result<Responses, Error> {
let buf = buf?;
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? { match responses.next().await? {

View File

@ -114,11 +114,13 @@ pub use crate::portal::Portal;
pub use crate::row::{Row, SimpleQueryRow}; pub use crate::row::{Row, SimpleQueryRow};
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
pub use crate::socket::Socket; pub use crate::socket::Socket;
pub use crate::statement::{Column, Statement};
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
use crate::tls::MakeTlsConnect; use crate::tls::MakeTlsConnect;
pub use crate::tls::NoTls; pub use crate::tls::NoTls;
pub use crate::to_statement::ToStatement;
pub use crate::transaction::Transaction; pub use crate::transaction::Transaction;
pub use statement::{Column, Statement}; use crate::types::ToSql;
mod bind; mod bind;
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
@ -147,6 +149,7 @@ mod simple_query;
mod socket; mod socket;
mod statement; mod statement;
pub mod tls; pub mod tls;
mod to_statement;
mod transaction; mod transaction;
pub mod types; pub mod types;
@ -220,3 +223,9 @@ pub enum SimpleQueryMessage {
#[doc(hidden)] #[doc(hidden)]
__NonExhaustive, __NonExhaustive,
} }
fn slice_iter<'a>(
s: &'a [&'a (dyn ToSql + Sync)],
) -> impl ExactSizeIterator<Item = &'a dyn ToSql> + 'a {
s.iter().map(|s| *s as _)
}

View File

@ -2,8 +2,8 @@ use crate::client::InnerClient;
use crate::codec::FrontendMessage; use crate::codec::FrontendMessage;
use crate::connection::RequestMessages; use crate::connection::RequestMessages;
use crate::error::SqlState; use crate::error::SqlState;
use crate::query; use crate::types::{Field, Kind, Oid, Type};
use crate::types::{Field, Kind, Oid, ToSql, Type}; use crate::{query, slice_iter};
use crate::{Column, Error, Statement}; use crate::{Column, Error, Statement};
use fallible_iterator::FallibleIterator; use fallible_iterator::FallibleIterator;
use futures::{future, TryStreamExt}; use futures::{future, TryStreamExt};
@ -65,43 +65,43 @@ pub async fn prepare(
let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst)); let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
let buf = encode(&name, query, types); let buf = encode(&name, query, types);
let buf = buf?; let buf = buf?;
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? { match responses.next().await? {
Message::ParseComplete => {} Message::ParseComplete => {}
_ => return Err(Error::unexpected_message()), _ => return Err(Error::unexpected_message()),
}
let parameter_description = match responses.next().await? {
Message::ParameterDescription(body) => body,
_ => return Err(Error::unexpected_message()),
};
let row_description = match responses.next().await? {
Message::RowDescription(body) => Some(body),
Message::NoData => None,
_ => return Err(Error::unexpected_message()),
};
let mut parameters = vec![];
let mut it = parameter_description.parameters();
while let Some(oid) = it.next().map_err(Error::parse)? {
let type_ = get_type(&client, oid).await?;
parameters.push(type_);
}
let mut columns = vec![];
if let Some(row_description) = row_description {
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(Error::parse)? {
let type_ = get_type(&client, field.type_oid()).await?;
let column = Column::new(field.name().to_string(), type_);
columns.push(column);
} }
}
let parameter_description = match responses.next().await? { Ok(Statement::new(&client, name, parameters, columns))
Message::ParameterDescription(body) => body,
_ => return Err(Error::unexpected_message()),
};
let row_description = match responses.next().await? {
Message::RowDescription(body) => Some(body),
Message::NoData => None,
_ => return Err(Error::unexpected_message()),
};
let mut parameters = vec![];
let mut it = parameter_description.parameters();
while let Some(oid) = it.next().map_err(Error::parse)? {
let type_ = get_type(&client, oid).await?;
parameters.push(type_);
}
let mut columns = vec![];
if let Some(row_description) = row_description {
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(Error::parse)? {
let type_ = get_type(&client, field.type_oid()).await?;
let column = Column::new(field.name().to_string(), type_);
columns.push(column);
}
}
Ok(Statement::new(&client, name, parameters, columns))
} }
fn prepare_rec<'a>( fn prepare_rec<'a>(
@ -132,8 +132,8 @@ async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
let stmt = typeinfo_statement(client).await?; let stmt = typeinfo_statement(client).await?;
let buf = query::encode(&stmt, (&[&oid as &dyn ToSql]).iter().cloned()); let params = &[&oid as _];
let rows = query::query(client, &stmt, buf); let rows = query::query(client, stmt, slice_iter(params));
pin_mut!(rows); pin_mut!(rows);
let row = match rows.try_next().await? { let row = match rows.try_next().await? {
@ -203,8 +203,7 @@ async fn typeinfo_statement(client: &Arc<InnerClient>) -> Result<Statement, Erro
async fn get_enum_variants(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<String>, Error> { async fn get_enum_variants(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<String>, Error> {
let stmt = typeinfo_enum_statement(client).await?; let stmt = typeinfo_enum_statement(client).await?;
let buf = query::encode(&stmt, (&[&oid as &dyn ToSql]).iter().cloned()); query::query(client, stmt, slice_iter(&[&oid]))
query::query(client, &stmt, buf)
.and_then(|row| future::ready(row.try_get(0))) .and_then(|row| future::ready(row.try_get(0)))
.try_collect() .try_collect()
.await .await
@ -230,8 +229,7 @@ async fn typeinfo_enum_statement(client: &Arc<InnerClient>) -> Result<Statement,
async fn get_composite_fields(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<Field>, Error> { async fn get_composite_fields(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<Field>, Error> {
let stmt = typeinfo_composite_statement(client).await?; let stmt = typeinfo_composite_statement(client).await?;
let buf = query::encode(&stmt, (&[&oid as &dyn ToSql]).iter().cloned()); let rows = query::query(client, stmt, slice_iter(&[&oid]))
let rows = query::query(client, &stmt, buf)
.try_collect::<Vec<_>>() .try_collect::<Vec<_>>()
.await?; .await?;

View File

@ -7,26 +7,33 @@ use futures::{ready, Stream, TryFutureExt};
use postgres_protocol::message::backend::Message; use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend; use postgres_protocol::message::frontend;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
pub fn query<'a>( pub fn query<'a, I>(
client: &'a Arc<InnerClient>, client: &'a InnerClient,
statement: &'a Statement, statement: Statement,
buf: Result<Vec<u8>, Error>, params: I,
) -> impl Stream<Item = Result<Row, Error>> + 'a { ) -> impl Stream<Item = Result<Row, Error>> + 'a
where
I: IntoIterator<Item = &'a dyn ToSql> + 'a,
I::IntoIter: ExactSizeIterator,
{
let f = async move { let f = async move {
let buf = encode(&statement, params)?;
let responses = start(client, buf).await?; let responses = start(client, buf).await?;
Ok(Query { statement: statement.clone(), responses }) Ok(Query {
statement,
responses,
})
}; };
f.try_flatten_stream() f.try_flatten_stream()
} }
pub fn query_portal( pub fn query_portal<'a>(
client: Arc<InnerClient>, client: &'a InnerClient,
portal: Portal, portal: &'a Portal,
max_rows: i32, max_rows: i32,
) -> impl Stream<Item = Result<Row, Error>> { ) -> impl Stream<Item = Result<Row, Error>> + 'a {
let start = async move { let start = async move {
let mut buf = vec![]; let mut buf = vec![];
frontend::execute(portal.name(), max_rows, &mut buf).map_err(Error::encode)?; frontend::execute(portal.name(), max_rows, &mut buf).map_err(Error::encode)?;
@ -43,7 +50,16 @@ pub fn query_portal(
start.try_flatten_stream() start.try_flatten_stream()
} }
pub async fn execute(client: &InnerClient, buf: Result<Vec<u8>, Error>) -> Result<u64, Error> { pub async fn execute<'a, I>(
client: &InnerClient,
statement: Statement,
params: I,
) -> Result<u64, Error>
where
I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
let buf = encode(&statement, params)?;
let mut responses = start(client, buf).await?; let mut responses = start(client, buf).await?;
loop { loop {
@ -66,8 +82,7 @@ pub async fn execute(client: &InnerClient, buf: Result<Vec<u8>, Error>) -> Resul
} }
} }
async fn start(client: &InnerClient, buf: Result<Vec<u8>, Error>) -> Result<Responses, Error> { async fn start(client: &InnerClient, buf: Vec<u8>) -> Result<Responses, Error> {
let buf = buf?;
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? { match responses.next().await? {

View File

@ -6,19 +6,16 @@ use fallible_iterator::FallibleIterator;
use futures::{ready, Stream, TryFutureExt}; use futures::{ready, Stream, TryFutureExt};
use postgres_protocol::message::backend::Message; use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend; use postgres_protocol::message::frontend;
use std::future::Future;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
pub fn simple_query( pub fn simple_query<'a>(
client: Arc<InnerClient>, client: &'a InnerClient,
query: &str, query: &'a str,
) -> impl Stream<Item = Result<SimpleQueryMessage, Error>> { ) -> impl Stream<Item = Result<SimpleQueryMessage, Error>> + 'a {
let buf = encode(query); let f = async move {
let buf = encode(query)?;
let start = async move {
let buf = buf?;
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
Ok(SimpleQuery { Ok(SimpleQuery {
@ -26,29 +23,21 @@ pub fn simple_query(
columns: None, columns: None,
}) })
}; };
f.try_flatten_stream()
start.try_flatten_stream()
} }
pub fn batch_execute( pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Error> {
client: Arc<InnerClient>, let buf = encode(query)?;
query: &str, let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
) -> impl Future<Output = Result<(), Error>> {
let buf = encode(query);
async move { loop {
let buf = buf?; match responses.next().await? {
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; Message::ReadyForQuery(_) => return Ok(()),
Message::CommandComplete(_)
loop { | Message::EmptyQueryResponse
match responses.next().await? { | Message::RowDescription(_)
Message::ReadyForQuery(_) => return Ok(()), | Message::DataRow(_) => {}
Message::CommandComplete(_) _ => return Err(Error::unexpected_message()),
| Message::EmptyQueryResponse
| Message::RowDescription(_)
| Message::DataRow(_) => {}
_ => return Err(Error::unexpected_message()),
}
} }
} }
} }

View File

@ -0,0 +1,49 @@
use crate::to_statement::private::{Sealed, ToStatementType};
use crate::Statement;
mod private {
use crate::{Client, Error, Statement};
pub trait Sealed {}
pub enum ToStatementType<'a> {
Statement(&'a Statement),
Query(&'a str),
}
impl<'a> ToStatementType<'a> {
pub async fn into_statement(self, client: &Client) -> Result<Statement, Error> {
match self {
ToStatementType::Statement(s) => Ok(s.clone()),
ToStatementType::Query(s) => client.prepare(s).await,
}
}
}
}
/// A trait abstracting over prepared and unprepared statements.
///
/// Many methods are generic over this bound, so that they support both a raw query string as well as a statement which
/// was prepared previously.
///
/// This trait is "sealed" and cannot be implemented by anything outside this crate.
pub trait ToStatement: private::Sealed {
#[doc(hidden)]
fn __convert(&self) -> ToStatementType<'_>;
}
impl ToStatement for Statement {
fn __convert(&self) -> ToStatementType<'_> {
ToStatementType::Statement(self)
}
}
impl Sealed for Statement {}
impl ToStatement for str {
fn __convert(&self) -> ToStatementType<'_> {
ToStatementType::Query(self)
}
}
impl Sealed for str {}

View File

@ -6,12 +6,13 @@ use crate::tls::TlsConnect;
use crate::types::{ToSql, Type}; use crate::types::{ToSql, Type};
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
use crate::Socket; use crate::Socket;
use crate::{bind, query, Client, Error, Portal, Row, SimpleQueryMessage, Statement}; use crate::{
bind, query, slice_iter, Client, Error, Portal, Row, SimpleQueryMessage, Statement, ToStatement,
};
use bytes::{Bytes, IntoBuf}; use bytes::{Bytes, IntoBuf};
use futures::{Stream, TryStream}; use futures::{Stream, TryStream};
use postgres_protocol::message::frontend; use postgres_protocol::message::frontend;
use std::error; use std::error;
use std::future::Future;
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
/// A representation of a PostgreSQL database transaction. /// A representation of a PostgreSQL database transaction.
@ -92,21 +93,25 @@ impl<'a> Transaction<'a> {
} }
/// Like `Client::query`. /// Like `Client::query`.
pub fn query<'b>( pub fn query<'b, T>(
&'b self, &'b self,
statement: &'b Statement, statement: &'b T,
params: &'b [&'b (dyn ToSql + Sync)], params: &'b [&(dyn ToSql + Sync)],
) -> impl Stream<Item = Result<Row, Error>> + 'b { ) -> impl Stream<Item = Result<Row, Error>> + 'b
where
T: ?Sized + ToStatement,
{
self.client.query(statement, params) self.client.query(statement, params)
} }
/// Like `Client::query_iter`. /// Like `Client::query_iter`.
pub fn query_iter<'b, I>( pub fn query_iter<'b, T, I>(
&'b self, &'b self,
statement: &'b Statement, statement: &'b T,
params: I, params: I,
) -> impl Stream<Item = Result<Row, Error>> + 'b ) -> impl Stream<Item = Result<Row, Error>> + 'b
where where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'b dyn ToSql> + 'b, I: IntoIterator<Item = &'b dyn ToSql> + 'b,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,
{ {
@ -114,21 +119,25 @@ impl<'a> Transaction<'a> {
} }
/// Like `Client::execute`. /// Like `Client::execute`.
pub async fn execute( pub async fn execute<T>(
&self, &self,
statement: &Statement, statement: &T,
params: &[&(dyn ToSql + Sync)], params: &[&(dyn ToSql + Sync)],
) -> Result<u64, Error> { ) -> Result<u64, Error>
where
T: ?Sized + ToStatement,
{
self.client.execute(statement, params).await self.client.execute(statement, params).await
} }
/// Like `Client::execute_iter`. /// Like `Client::execute_iter`.
pub async fn execute_iter<'b, I>( pub async fn execute_iter<'b, I, T>(
&self, &self,
statement: &Statement, statement: &Statement,
params: I, params: I,
) -> Result<u64, Error> ) -> Result<u64, Error>
where where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'b dyn ToSql>, I: IntoIterator<Item = &'b dyn ToSql>,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,
{ {
@ -143,102 +152,100 @@ impl<'a> Transaction<'a> {
/// # Panics /// # Panics
/// ///
/// Panics if the number of parameters provided does not match the number expected. /// Panics if the number of parameters provided does not match the number expected.
pub fn bind( pub async fn bind<T>(
&self, &self,
statement: &Statement, statement: &T,
params: &[&(dyn ToSql + Sync)], params: &[&(dyn ToSql + Sync)],
) -> impl Future<Output = Result<Portal, Error>> { ) -> Result<Portal, Error>
// https://github.com/rust-lang/rust/issues/63032 where
let buf = bind::encode(statement, params.iter().map(|s| *s as _)); T: ?Sized + ToStatement,
bind::bind(self.client.inner(), statement.clone(), buf) {
self.bind_iter(statement, slice_iter(params)).await
} }
/// Like [`bind`], but takes an iterator of parameters rather than a slice. /// Like [`bind`], but takes an iterator of parameters rather than a slice.
/// ///
/// [`bind`]: #method.bind /// [`bind`]: #method.bind
pub fn bind_iter<'b, I>( pub async fn bind_iter<'b, T, I>(&self, statement: &T, params: I) -> Result<Portal, Error>
&self,
statement: &Statement,
params: I,
) -> impl Future<Output = Result<Portal, Error>>
where where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'b dyn ToSql>, I: IntoIterator<Item = &'b dyn ToSql>,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,
{ {
let buf = bind::encode(statement, params); let statement = statement.__convert().into_statement(&self.client).await?;
bind::bind(self.client.inner(), statement.clone(), buf) bind::bind(self.client.inner(), statement, params).await
} }
/// Continues execution of a portal, returning a stream of the resulting rows. /// Continues execution of a portal, returning a stream of the resulting rows.
/// ///
/// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
/// `query_portal`. If the requested number is negative or 0, all rows will be returned. /// `query_portal`. If the requested number is negative or 0, all rows will be returned.
pub fn query_portal( pub fn query_portal<'b>(
&self, &'b self,
portal: &Portal, portal: &'b Portal,
max_rows: i32, max_rows: i32,
) -> impl Stream<Item = Result<Row, Error>> { ) -> impl Stream<Item = Result<Row, Error>> + 'b {
query::query_portal(self.client.inner(), portal.clone(), max_rows) query::query_portal(self.client.inner(), portal, max_rows)
} }
/// Like `Client::copy_in`. /// Like `Client::copy_in`.
pub fn copy_in<S>( pub async fn copy_in<T, S>(
&self, &self,
statement: &Statement, statement: &T,
params: &[&(dyn ToSql + Sync)], params: &[&(dyn ToSql + Sync)],
stream: S, stream: S,
) -> impl Future<Output = Result<u64, Error>> ) -> Result<u64, Error>
where where
T: ?Sized + ToStatement,
S: TryStream, S: TryStream,
S::Ok: IntoBuf, S::Ok: IntoBuf,
<S::Ok as IntoBuf>::Buf: 'static + Send, <S::Ok as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn error::Error + Sync + Send>>, S::Error: Into<Box<dyn error::Error + Sync + Send>>,
{ {
self.client.copy_in(statement, params, stream) self.client.copy_in(statement, params, stream).await
} }
/// Like `Client::copy_out`. /// Like `Client::copy_out`.
pub fn copy_out( pub fn copy_out<'b, T>(
&self, &'b self,
statement: &Statement, statement: &'b T,
params: &[&(dyn ToSql + Sync)], params: &'b [&(dyn ToSql + Sync)],
) -> impl Stream<Item = Result<Bytes, Error>> { ) -> impl Stream<Item = Result<Bytes, Error>> + 'b
where
T: ?Sized + ToStatement,
{
self.client.copy_out(statement, params) self.client.copy_out(statement, params)
} }
/// Like `Client::simple_query`. /// Like `Client::simple_query`.
pub fn simple_query( pub fn simple_query<'b>(
&self, &'b self,
query: &str, query: &'b str,
) -> impl Stream<Item = Result<SimpleQueryMessage, Error>> { ) -> impl Stream<Item = Result<SimpleQueryMessage, Error>> + 'b {
self.client.simple_query(query) self.client.simple_query(query)
} }
/// Like `Client::batch_execute`. /// Like `Client::batch_execute`.
pub fn batch_execute(&self, query: &str) -> impl Future<Output = Result<(), Error>> { pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
self.client.batch_execute(query) self.client.batch_execute(query).await
} }
/// Like `Client::cancel_query`. /// Like `Client::cancel_query`.
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
pub fn cancel_query<T>(&self, tls: T) -> impl Future<Output = Result<(), Error>> pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
where where
T: MakeTlsConnect<Socket>, T: MakeTlsConnect<Socket>,
{ {
self.client.cancel_query(tls) self.client.cancel_query(tls).await
} }
/// Like `Client::cancel_query_raw`. /// Like `Client::cancel_query_raw`.
pub fn cancel_query_raw<S, T>( pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
&self,
stream: S,
tls: T,
) -> impl Future<Output = Result<(), Error>>
where where
S: AsyncRead + AsyncWrite + Unpin, S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>, T: TlsConnect<S>,
{ {
self.client.cancel_query_raw(stream, tls) self.client.cancel_query_raw(stream, tls).await
} }
/// Like `Client::transaction`. /// Like `Client::transaction`.

View File

@ -98,7 +98,7 @@ async fn scram_password_ok() {
#[tokio::test] #[tokio::test]
async fn pipelined_prepare() { async fn pipelined_prepare() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
let prepare1 = client.prepare("SELECT $1::HSTORE[]"); let prepare1 = client.prepare("SELECT $1::HSTORE[]");
let prepare2 = client.prepare("SELECT $1::BIGINT"); let prepare2 = client.prepare("SELECT $1::BIGINT");
@ -114,7 +114,7 @@ async fn pipelined_prepare() {
#[tokio::test] #[tokio::test]
async fn insert_select() { async fn insert_select() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
client client
.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name TEXT)") .batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name TEXT)")
@ -138,7 +138,7 @@ async fn insert_select() {
#[tokio::test] #[tokio::test]
async fn custom_enum() { async fn custom_enum() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
client client
.batch_execute( .batch_execute(
@ -167,7 +167,7 @@ async fn custom_enum() {
#[tokio::test] #[tokio::test]
async fn custom_domain() { async fn custom_domain() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
client client
.batch_execute("CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16)") .batch_execute("CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16)")
@ -183,7 +183,7 @@ async fn custom_domain() {
#[tokio::test] #[tokio::test]
async fn custom_array() { async fn custom_array() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
let select = client.prepare("SELECT $1::HSTORE[]").await.unwrap(); let select = client.prepare("SELECT $1::HSTORE[]").await.unwrap();
@ -200,7 +200,7 @@ async fn custom_array() {
#[tokio::test] #[tokio::test]
async fn custom_composite() { async fn custom_composite() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
client client
.batch_execute( .batch_execute(
@ -232,7 +232,7 @@ async fn custom_composite() {
#[tokio::test] #[tokio::test]
async fn custom_range() { async fn custom_range() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
client client
.batch_execute( .batch_execute(
@ -253,7 +253,7 @@ async fn custom_range() {
#[tokio::test] #[tokio::test]
async fn simple_query() { async fn simple_query() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
let messages = client let messages = client
.simple_query( .simple_query(
@ -299,7 +299,7 @@ async fn simple_query() {
#[tokio::test] #[tokio::test]
async fn cancel_query_raw() { async fn cancel_query_raw() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
let socket = TcpStream::connect("127.0.0.1:5433").await.unwrap(); let socket = TcpStream::connect("127.0.0.1:5433").await.unwrap();
let cancel = client.cancel_query_raw(socket, NoTls); let cancel = client.cancel_query_raw(socket, NoTls);
@ -327,7 +327,7 @@ async fn transaction_commit() {
.await .await
.unwrap(); .unwrap();
let mut transaction = client.transaction().await.unwrap(); let transaction = client.transaction().await.unwrap();
transaction transaction
.batch_execute("INSERT INTO foo (name) VALUES ('steven')") .batch_execute("INSERT INTO foo (name) VALUES ('steven')")
.await .await
@ -359,7 +359,7 @@ async fn transaction_rollback() {
.await .await
.unwrap(); .unwrap();
let mut transaction = client.transaction().await.unwrap(); let transaction = client.transaction().await.unwrap();
transaction transaction
.batch_execute("INSERT INTO foo (name) VALUES ('steven')") .batch_execute("INSERT INTO foo (name) VALUES ('steven')")
.await .await
@ -390,7 +390,7 @@ async fn transaction_rollback_drop() {
.await .await
.unwrap(); .unwrap();
let mut transaction = client.transaction().await.unwrap(); let transaction = client.transaction().await.unwrap();
transaction transaction
.batch_execute("INSERT INTO foo (name) VALUES ('steven')") .batch_execute("INSERT INTO foo (name) VALUES ('steven')")
.await .await
@ -409,7 +409,7 @@ async fn transaction_rollback_drop() {
#[tokio::test] #[tokio::test]
async fn copy_in() { async fn copy_in() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
client client
.batch_execute( .batch_execute(
@ -449,7 +449,7 @@ async fn copy_in() {
#[tokio::test] #[tokio::test]
async fn copy_in_large() { async fn copy_in_large() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
client client
.batch_execute( .batch_execute(
@ -480,7 +480,7 @@ async fn copy_in_large() {
#[tokio::test] #[tokio::test]
async fn copy_in_error() { async fn copy_in_error() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
client client
.batch_execute( .batch_execute(
@ -511,7 +511,7 @@ async fn copy_in_error() {
#[tokio::test] #[tokio::test]
async fn copy_out() { async fn copy_out() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
client client
.batch_execute( .batch_execute(
@ -532,7 +532,7 @@ async fn copy_out() {
#[tokio::test] #[tokio::test]
async fn notifications() { async fn notifications() {
let (mut client, mut connection) = connect_raw("user=postgres").await.unwrap(); let (client, mut connection) = connect_raw("user=postgres").await.unwrap();
let (tx, rx) = mpsc::unbounded(); let (tx, rx) = mpsc::unbounded();
let stream = stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!(e)); let stream = stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!(e));
@ -585,7 +585,7 @@ async fn query_portal() {
.await .await
.unwrap(); .unwrap();
let mut transaction = client.transaction().await.unwrap(); let transaction = client.transaction().await.unwrap();
let portal = transaction.bind(&stmt, &[]).await.unwrap(); let portal = transaction.bind(&stmt, &[]).await.unwrap();
let f1 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>(); let f1 = transaction.query_portal(&portal, 2).try_collect::<Vec<_>>();
@ -624,3 +624,36 @@ async fn prefer_channel_binding() {
async fn disable_channel_binding() { async fn disable_channel_binding() {
connect("user=postgres channel_binding=disable").await; connect("user=postgres channel_binding=disable").await;
} }
#[tokio::test]
async fn check_send() {
fn is_send<T: Send>(_: &T) {}
let f = connect("user=postgres");
is_send(&f);
let mut client = f.await;
let f = client.prepare("SELECT $1::TEXT");
is_send(&f);
let stmt = f.await.unwrap();
let f = client.query(&stmt, &[&"hello"]);
is_send(&f);
drop(f);
let f = client.execute(&stmt, &[&"hello"]);
is_send(&f);
drop(f);
let f = client.transaction();
is_send(&f);
let trans = f.await.unwrap();
let f = trans.query(&stmt, &[&"hello"]);
is_send(&f);
drop(f);
let f = trans.execute(&stmt, &[&"hello"]);
is_send(&f);
drop(f);
}

View File

@ -13,7 +13,7 @@ async fn connect(s: &str) -> Client {
} }
async fn smoke_test(s: &str) { async fn smoke_test(s: &str) {
let mut client = connect(s).await; let client = connect(s).await;
let stmt = client.prepare("SELECT $1::INT").await.unwrap(); let stmt = client.prepare("SELECT $1::INT").await.unwrap();
let rows = client let rows = client
@ -72,7 +72,7 @@ async fn target_session_attrs_err() {
#[tokio::test] #[tokio::test]
async fn cancel_query() { async fn cancel_query() {
let mut client = connect("host=localhost port=5433 user=postgres").await; let client = connect("host=localhost port=5433 user=postgres").await;
let cancel = client.cancel_query(NoTls); let cancel = client.cancel_query(NoTls);
let cancel = timer::delay(Instant::now() + Duration::from_millis(100)).then(|()| cancel); let cancel = timer::delay(Instant::now() + Duration::from_millis(100)).then(|()| cancel);

View File

@ -1,4 +1,5 @@
use futures::TryStreamExt; use futures::TryStreamExt;
use postgres_types::to_sql_checked;
use std::collections::HashMap; use std::collections::HashMap;
use std::error::Error; use std::error::Error;
use std::f32; use std::f32;
@ -7,7 +8,6 @@ use std::fmt;
use std::net::IpAddr; use std::net::IpAddr;
use std::result; use std::result;
use std::time::{Duration, UNIX_EPOCH}; use std::time::{Duration, UNIX_EPOCH};
use postgres_types::to_sql_checked;
use tokio_postgres::types::{FromSql, FromSqlOwned, IsNull, Kind, ToSql, Type, WrongType}; use tokio_postgres::types::{FromSql, FromSqlOwned, IsNull, Kind, ToSql, Type, WrongType};
use crate::connect; use crate::connect;
@ -30,27 +30,19 @@ where
T: PartialEq + for<'a> FromSqlOwned + ToSql + Sync, T: PartialEq + for<'a> FromSqlOwned + ToSql + Sync,
S: fmt::Display, S: fmt::Display,
{ {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
for (val, repr) in checks { for (val, repr) in checks {
let stmt = client
.prepare(&format!("SELECT {}::{}", repr, sql_type))
.await
.unwrap();
let rows = client let rows = client
.query(&stmt, &[]) .query(&*format!("SELECT {}::{}", repr, sql_type), &[])
.try_collect::<Vec<_>>() .try_collect::<Vec<_>>()
.await .await
.unwrap(); .unwrap();
let result = rows[0].get(0); let result = rows[0].get(0);
assert_eq!(val, &result); assert_eq!(val, &result);
let stmt = client
.prepare(&format!("SELECT $1::{}", sql_type))
.await
.unwrap();
let rows = client let rows = client
.query(&stmt, &[&val]) .query(&*format!("SELECT $1::{}", sql_type), &[&val])
.try_collect::<Vec<_>>() .try_collect::<Vec<_>>()
.await .await
.unwrap(); .unwrap();
@ -203,7 +195,7 @@ async fn test_text_params() {
#[tokio::test] #[tokio::test]
async fn test_borrowed_text() { async fn test_borrowed_text() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
let stmt = client.prepare("SELECT 'foo'").await.unwrap(); let stmt = client.prepare("SELECT 'foo'").await.unwrap();
let rows = client let rows = client
@ -217,7 +209,7 @@ async fn test_borrowed_text() {
#[tokio::test] #[tokio::test]
async fn test_bpchar_params() { async fn test_bpchar_params() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
client client
.batch_execute( .batch_execute(
@ -257,7 +249,7 @@ async fn test_bpchar_params() {
#[tokio::test] #[tokio::test]
async fn test_citext_params() { async fn test_citext_params() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
client client
.batch_execute( .batch_execute(
@ -306,7 +298,7 @@ async fn test_bytea_params() {
#[tokio::test] #[tokio::test]
async fn test_borrowed_bytea() { async fn test_borrowed_bytea() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
let stmt = client.prepare("SELECT 'foo'::BYTEA").await.unwrap(); let stmt = client.prepare("SELECT 'foo'::BYTEA").await.unwrap();
let rows = client let rows = client
.query(&stmt, &[]) .query(&stmt, &[])
@ -365,7 +357,7 @@ async fn test_nan_param<T>(sql_type: &str)
where where
T: PartialEq + ToSql + FromSqlOwned, T: PartialEq + ToSql + FromSqlOwned,
{ {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
let stmt = client let stmt = client
.prepare(&format!("SELECT 'NaN'::{}", sql_type)) .prepare(&format!("SELECT 'NaN'::{}", sql_type))
@ -392,7 +384,7 @@ async fn test_f64_nan_param() {
#[tokio::test] #[tokio::test]
async fn test_pg_database_datname() { async fn test_pg_database_datname() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
let stmt = client let stmt = client
.prepare("SELECT datname FROM pg_database") .prepare("SELECT datname FROM pg_database")
.await .await
@ -407,7 +399,7 @@ async fn test_pg_database_datname() {
#[tokio::test] #[tokio::test]
async fn test_slice() { async fn test_slice() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
client client
.batch_execute( .batch_execute(
@ -436,7 +428,7 @@ async fn test_slice() {
#[tokio::test] #[tokio::test]
async fn test_slice_wrong_type() { async fn test_slice_wrong_type() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
client client
.batch_execute( .batch_execute(
@ -465,7 +457,7 @@ async fn test_slice_wrong_type() {
#[tokio::test] #[tokio::test]
async fn test_slice_range() { async fn test_slice_range() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
let stmt = client.prepare("SELECT $1::INT8RANGE").await.unwrap(); let stmt = client.prepare("SELECT $1::INT8RANGE").await.unwrap();
let err = client let err = client
@ -520,7 +512,7 @@ async fn domain() {
} }
} }
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
client client
.batch_execute( .batch_execute(
@ -551,7 +543,7 @@ async fn domain() {
#[tokio::test] #[tokio::test]
async fn composite() { async fn composite() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
client client
.batch_execute( .batch_execute(
@ -582,7 +574,7 @@ async fn composite() {
#[tokio::test] #[tokio::test]
async fn enum_() { async fn enum_() {
let mut client = connect("user=postgres").await; let client = connect("user=postgres").await;
client client
.batch_execute("CREATE TYPE pg_temp.mood AS ENUM ('sad', 'ok', 'happy')") .batch_execute("CREATE TYPE pg_temp.mood AS ENUM ('sad', 'ok', 'happy')")
@ -656,36 +648,3 @@ async fn inet() {
) )
.await; .await;
} }
#[tokio::test]
async fn check_send() {
fn is_send<T: Send>(_: &T) {}
let f = connect("user=postgres");
is_send(&f);
let mut client = f.await;
let f = client.prepare("SELECT $1::TEXT");
is_send(&f);
let stmt = f.await.unwrap();
let f = client.query(&stmt, &[&"hello"]);
is_send(&f);
drop(f);
let f = client.execute(&stmt, &[&"hello"]);
is_send(&f);
drop(f);
let f = client.transaction();
is_send(&f);
let mut trans = f.await.unwrap();
let f = trans.query(&stmt, &[&"hello"]);
is_send(&f);
drop(f);
let f = trans.execute(&stmt, &[&"hello"]);
is_send(&f);
drop(f);
}