From 0eab5fad70a0af4ce754560661cf263f5a847193 Mon Sep 17 00:00:00 2001 From: Bernardo Uriarte Blanco Date: Mon, 26 Oct 2020 20:59:28 +0100 Subject: [PATCH] make iterators from `&dyn ToSql` or `T: ToSql` work as parameters --- postgres-types/src/lib.rs | 18 +++++++++++++++ tokio-postgres/src/binary_copy.rs | 7 ++++-- tokio-postgres/src/bind.rs | 7 +++--- tokio-postgres/src/client.rs | 11 +++++---- tokio-postgres/src/query.rs | 38 +++++++++++++++++++++---------- 5 files changed, 60 insertions(+), 21 deletions(-) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index e9a5846e..c8d65e77 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -951,3 +951,21 @@ fn downcast(len: usize) -> Result> { Ok(len as i32) } } + +/// A helper trait to be able create a parameters iterator from `&dyn ToSql` or `T: ToSql` +pub trait BorrowToSql { + /// Get a reference to a `ToSql` trait object + fn borrow_to_sql(&self) -> &dyn ToSql; +} + +impl BorrowToSql for &dyn ToSql { + fn borrow_to_sql(&self) -> &dyn ToSql { + *self + } +} + +impl BorrowToSql for T { + fn borrow_to_sql(&self) -> &dyn ToSql { + self + } +} diff --git a/tokio-postgres/src/binary_copy.rs b/tokio-postgres/src/binary_copy.rs index 231f202d..20064c72 100644 --- a/tokio-postgres/src/binary_copy.rs +++ b/tokio-postgres/src/binary_copy.rs @@ -6,6 +6,7 @@ use byteorder::{BigEndian, ByteOrder}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use futures::{ready, SinkExt, Stream}; use pin_project_lite::pin_project; +use postgres_types::BorrowToSql; use std::convert::TryFrom; use std::io; use std::io::Cursor; @@ -58,9 +59,10 @@ impl BinaryCopyInWriter { /// # Panics /// /// Panics if the number of values provided does not match the number expected. - pub async fn write_raw<'a, I>(self: Pin<&mut Self>, values: I) -> Result<(), Error> + pub async fn write_raw(self: Pin<&mut Self>, values: I) -> Result<(), Error> where - I: IntoIterator, + P: BorrowToSql, + I: IntoIterator, I::IntoIter: ExactSizeIterator, { let mut this = self.project(); @@ -79,6 +81,7 @@ impl BinaryCopyInWriter { let idx = this.buf.len(); this.buf.put_i32(0); let len = match value + .borrow_to_sql() .to_sql_checked(type_, this.buf) .map_err(|e| Error::to_sql(e, i))? { diff --git a/tokio-postgres/src/bind.rs b/tokio-postgres/src/bind.rs index 69823a9a..9c5c4921 100644 --- a/tokio-postgres/src/bind.rs +++ b/tokio-postgres/src/bind.rs @@ -1,7 +1,7 @@ use crate::client::InnerClient; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::types::ToSql; +use crate::types::BorrowToSql; use crate::{query, Error, Portal, Statement}; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; @@ -10,13 +10,14 @@ use std::sync::Arc; static NEXT_ID: AtomicUsize = AtomicUsize::new(0); -pub async fn bind<'a, I>( +pub async fn bind( client: &Arc, statement: Statement, params: I, ) -> Result where - I: IntoIterator, + P: BorrowToSql, + I: IntoIterator, I::IntoIter: ExactSizeIterator, { let name = format!("p{}", NEXT_ID.fetch_add(1, Ordering::SeqCst)); diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index e19caae8..ecf3ea60 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -20,6 +20,7 @@ use futures::channel::mpsc; use futures::{future, pin_mut, ready, StreamExt, TryStreamExt}; use parking_lot::Mutex; use postgres_protocol::message::backend::Message; +use postgres_types::BorrowToSql; use std::collections::HashMap; use std::fmt; use std::sync::Arc; @@ -342,10 +343,11 @@ impl Client { /// # Ok(()) /// # } /// ``` - pub async fn query_raw<'a, T, I>(&self, statement: &T, params: I) -> Result + pub async fn query_raw(&self, statement: &T, params: I) -> Result where T: ?Sized + ToStatement, - I: IntoIterator, + P: BorrowToSql, + I: IntoIterator, I::IntoIter: ExactSizeIterator, { let statement = statement.__convert().into_statement(self).await?; @@ -391,10 +393,11 @@ impl Client { /// Panics if the number of parameters provided does not match the number expected. /// /// [`execute`]: #method.execute - pub async fn execute_raw<'a, T, I>(&self, statement: &T, params: I) -> Result + pub async fn execute_raw(&self, statement: &T, params: I) -> Result where T: ?Sized + ToStatement, - I: IntoIterator, + P: BorrowToSql, + I: IntoIterator, I::IntoIter: ExactSizeIterator, { let statement = statement.__convert().into_statement(self).await?; diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 7792f0a8..2245b982 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -1,7 +1,7 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::types::{IsNull, ToSql}; +use crate::types::{BorrowToSql, IsNull}; use crate::{Error, Portal, Row, Statement}; use bytes::{Bytes, BytesMut}; use futures::{ready, Stream}; @@ -9,17 +9,28 @@ use log::{debug, log_enabled, Level}; use pin_project_lite::pin_project; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; +use std::fmt; use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; -pub async fn query<'a, I>( +struct BorrowToSqlParamsDebug<'a, T: BorrowToSql>(&'a [T]); +impl<'a, T: BorrowToSql> std::fmt::Debug for BorrowToSqlParamsDebug<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list() + .entries(self.0.iter().map(|x| x.borrow_to_sql())) + .finish() + } +} + +pub async fn query( client: &InnerClient, statement: Statement, params: I, ) -> Result where - I: IntoIterator, + P: BorrowToSql, + I: IntoIterator, I::IntoIter: ExactSizeIterator, { let buf = if log_enabled!(Level::Debug) { @@ -27,7 +38,7 @@ where debug!( "executing statement {} with parameters: {:?}", statement.name(), - params, + BorrowToSqlParamsDebug(params.as_slice()), ); encode(client, &statement, params)? } else { @@ -61,13 +72,14 @@ pub async fn query_portal( }) } -pub async fn execute<'a, I>( +pub async fn execute( client: &InnerClient, statement: Statement, params: I, ) -> Result where - I: IntoIterator, + P: BorrowToSql, + I: IntoIterator, I::IntoIter: ExactSizeIterator, { let buf = if log_enabled!(Level::Debug) { @@ -75,7 +87,7 @@ where debug!( "executing statement {} with parameters: {:?}", statement.name(), - params, + BorrowToSqlParamsDebug(params.as_slice()), ); encode(client, &statement, params)? } else { @@ -114,9 +126,10 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result { Ok(responses) } -pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result +pub fn encode(client: &InnerClient, statement: &Statement, params: I) -> Result where - I: IntoIterator, + P: BorrowToSql, + I: IntoIterator, I::IntoIter: ExactSizeIterator, { client.with_buf(|buf| { @@ -127,14 +140,15 @@ where }) } -pub fn encode_bind<'a, I>( +pub fn encode_bind( statement: &Statement, params: I, portal: &str, buf: &mut BytesMut, ) -> Result<(), Error> where - I: IntoIterator, + P: BorrowToSql, + I: IntoIterator, I::IntoIter: ExactSizeIterator, { let params = params.into_iter(); @@ -152,7 +166,7 @@ where statement.name(), Some(1), params.zip(statement.params()).enumerate(), - |(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) { + |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) { Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No), Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes), Err(e) => {