make iterators from &dyn ToSql or T: ToSql work as parameters

This commit is contained in:
Bernardo Uriarte Blanco 2020-10-26 20:59:28 +01:00
parent 5e065c36cd
commit 0eab5fad70
5 changed files with 60 additions and 21 deletions

View File

@ -951,3 +951,21 @@ fn downcast(len: usize) -> Result<i32, Box<dyn Error + Sync + Send>> {
Ok(len as i32) Ok(len as i32)
} }
} }
/// A helper trait to be able create a parameters iterator from `&dyn ToSql` or `T: ToSql`
pub trait BorrowToSql {
/// Get a reference to a `ToSql` trait object
fn borrow_to_sql(&self) -> &dyn ToSql;
}
impl BorrowToSql for &dyn ToSql {
fn borrow_to_sql(&self) -> &dyn ToSql {
*self
}
}
impl<T: ToSql> BorrowToSql for T {
fn borrow_to_sql(&self) -> &dyn ToSql {
self
}
}

View File

@ -6,6 +6,7 @@ use byteorder::{BigEndian, ByteOrder};
use bytes::{Buf, BufMut, Bytes, BytesMut}; use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::{ready, SinkExt, Stream}; use futures::{ready, SinkExt, Stream};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use postgres_types::BorrowToSql;
use std::convert::TryFrom; use std::convert::TryFrom;
use std::io; use std::io;
use std::io::Cursor; use std::io::Cursor;
@ -58,9 +59,10 @@ impl BinaryCopyInWriter {
/// # Panics /// # Panics
/// ///
/// Panics if the number of values provided does not match the number expected. /// Panics if the number of values provided does not match the number expected.
pub async fn write_raw<'a, I>(self: Pin<&mut Self>, values: I) -> Result<(), Error> pub async fn write_raw<P, I>(self: Pin<&mut Self>, values: I) -> Result<(), Error>
where where
I: IntoIterator<Item = &'a dyn ToSql>, P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,
{ {
let mut this = self.project(); let mut this = self.project();
@ -79,6 +81,7 @@ impl BinaryCopyInWriter {
let idx = this.buf.len(); let idx = this.buf.len();
this.buf.put_i32(0); this.buf.put_i32(0);
let len = match value let len = match value
.borrow_to_sql()
.to_sql_checked(type_, this.buf) .to_sql_checked(type_, this.buf)
.map_err(|e| Error::to_sql(e, i))? .map_err(|e| Error::to_sql(e, i))?
{ {

View File

@ -1,7 +1,7 @@
use crate::client::InnerClient; use crate::client::InnerClient;
use crate::codec::FrontendMessage; use crate::codec::FrontendMessage;
use crate::connection::RequestMessages; use crate::connection::RequestMessages;
use crate::types::ToSql; use crate::types::BorrowToSql;
use crate::{query, Error, Portal, Statement}; use crate::{query, Error, Portal, Statement};
use postgres_protocol::message::backend::Message; use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend; use postgres_protocol::message::frontend;
@ -10,13 +10,14 @@ use std::sync::Arc;
static NEXT_ID: AtomicUsize = AtomicUsize::new(0); static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
pub async fn bind<'a, I>( pub async fn bind<P, I>(
client: &Arc<InnerClient>, client: &Arc<InnerClient>,
statement: Statement, statement: Statement,
params: I, params: I,
) -> Result<Portal, Error> ) -> Result<Portal, Error>
where where
I: IntoIterator<Item = &'a dyn ToSql>, P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,
{ {
let name = format!("p{}", NEXT_ID.fetch_add(1, Ordering::SeqCst)); let name = format!("p{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));

View File

@ -20,6 +20,7 @@ use futures::channel::mpsc;
use futures::{future, pin_mut, ready, StreamExt, TryStreamExt}; use futures::{future, pin_mut, ready, StreamExt, TryStreamExt};
use parking_lot::Mutex; use parking_lot::Mutex;
use postgres_protocol::message::backend::Message; use postgres_protocol::message::backend::Message;
use postgres_types::BorrowToSql;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt; use std::fmt;
use std::sync::Arc; use std::sync::Arc;
@ -342,10 +343,11 @@ impl Client {
/// # Ok(()) /// # Ok(())
/// # } /// # }
/// ``` /// ```
pub async fn query_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<RowStream, Error> pub async fn query_raw<T, P, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
I: IntoIterator<Item = &'a dyn ToSql>, P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,
{ {
let statement = statement.__convert().into_statement(self).await?; let statement = statement.__convert().into_statement(self).await?;
@ -391,10 +393,11 @@ impl Client {
/// Panics if the number of parameters provided does not match the number expected. /// Panics if the number of parameters provided does not match the number expected.
/// ///
/// [`execute`]: #method.execute /// [`execute`]: #method.execute
pub async fn execute_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<u64, Error> pub async fn execute_raw<T, P, I>(&self, statement: &T, params: I) -> Result<u64, Error>
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
I: IntoIterator<Item = &'a dyn ToSql>, P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,
{ {
let statement = statement.__convert().into_statement(self).await?; let statement = statement.__convert().into_statement(self).await?;

View File

@ -1,7 +1,7 @@
use crate::client::{InnerClient, Responses}; use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage; use crate::codec::FrontendMessage;
use crate::connection::RequestMessages; use crate::connection::RequestMessages;
use crate::types::{IsNull, ToSql}; use crate::types::{BorrowToSql, IsNull};
use crate::{Error, Portal, Row, Statement}; use crate::{Error, Portal, Row, Statement};
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{ready, Stream}; use futures::{ready, Stream};
@ -9,17 +9,28 @@ use log::{debug, log_enabled, Level};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use postgres_protocol::message::backend::Message; use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend; use postgres_protocol::message::frontend;
use std::fmt;
use std::marker::PhantomPinned; use std::marker::PhantomPinned;
use std::pin::Pin; use std::pin::Pin;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
pub async fn query<'a, I>( struct BorrowToSqlParamsDebug<'a, T: BorrowToSql>(&'a [T]);
impl<'a, T: BorrowToSql> std::fmt::Debug for BorrowToSqlParamsDebug<'a, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list()
.entries(self.0.iter().map(|x| x.borrow_to_sql()))
.finish()
}
}
pub async fn query<P, I>(
client: &InnerClient, client: &InnerClient,
statement: Statement, statement: Statement,
params: I, params: I,
) -> Result<RowStream, Error> ) -> Result<RowStream, Error>
where where
I: IntoIterator<Item = &'a dyn ToSql>, P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,
{ {
let buf = if log_enabled!(Level::Debug) { let buf = if log_enabled!(Level::Debug) {
@ -27,7 +38,7 @@ where
debug!( debug!(
"executing statement {} with parameters: {:?}", "executing statement {} with parameters: {:?}",
statement.name(), statement.name(),
params, BorrowToSqlParamsDebug(params.as_slice()),
); );
encode(client, &statement, params)? encode(client, &statement, params)?
} else { } else {
@ -61,13 +72,14 @@ pub async fn query_portal(
}) })
} }
pub async fn execute<'a, I>( pub async fn execute<P, I>(
client: &InnerClient, client: &InnerClient,
statement: Statement, statement: Statement,
params: I, params: I,
) -> Result<u64, Error> ) -> Result<u64, Error>
where where
I: IntoIterator<Item = &'a dyn ToSql>, P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,
{ {
let buf = if log_enabled!(Level::Debug) { let buf = if log_enabled!(Level::Debug) {
@ -75,7 +87,7 @@ where
debug!( debug!(
"executing statement {} with parameters: {:?}", "executing statement {} with parameters: {:?}",
statement.name(), statement.name(),
params, BorrowToSqlParamsDebug(params.as_slice()),
); );
encode(client, &statement, params)? encode(client, &statement, params)?
} else { } else {
@ -114,9 +126,10 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
Ok(responses) Ok(responses)
} }
pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error> pub fn encode<P, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
where where
I: IntoIterator<Item = &'a dyn ToSql>, P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,
{ {
client.with_buf(|buf| { client.with_buf(|buf| {
@ -127,14 +140,15 @@ where
}) })
} }
pub fn encode_bind<'a, I>( pub fn encode_bind<P, I>(
statement: &Statement, statement: &Statement,
params: I, params: I,
portal: &str, portal: &str,
buf: &mut BytesMut, buf: &mut BytesMut,
) -> Result<(), Error> ) -> Result<(), Error>
where where
I: IntoIterator<Item = &'a dyn ToSql>, P: BorrowToSql,
I: IntoIterator<Item = P>,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,
{ {
let params = params.into_iter(); let params = params.into_iter();
@ -152,7 +166,7 @@ where
statement.name(), statement.name(),
Some(1), Some(1),
params.zip(statement.params()).enumerate(), params.zip(statement.params()).enumerate(),
|(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) { |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) {
Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No), Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes), Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
Err(e) => { Err(e) => {