From ed056758887464723881c3f6f21b0140746313de Mon Sep 17 00:00:00 2001 From: Matt Duch Date: Mon, 8 Aug 2022 19:22:06 -0500 Subject: [PATCH 1/2] derive generic FromSql/ToSql --- postgres-derive-test/src/composites.rs | 69 +++++++++++++++++++++++++- postgres-derive-test/src/lib.rs | 24 +++++++++ postgres-derive/src/fromsql.rs | 28 +++++++++-- postgres-derive/src/tosql.rs | 3 +- 4 files changed, 116 insertions(+), 8 deletions(-) diff --git a/postgres-derive-test/src/composites.rs b/postgres-derive-test/src/composites.rs index ed60bf48..66763351 100644 --- a/postgres-derive-test/src/composites.rs +++ b/postgres-derive-test/src/composites.rs @@ -1,6 +1,6 @@ -use crate::test_type; +use crate::{test_type, test_type_asymmetric}; use postgres::{Client, NoTls}; -use postgres_types::{FromSql, ToSql, WrongType}; +use postgres_types::{FromSql, FromSqlOwned, ToSql, WrongType}; use std::error::Error; #[test] @@ -238,3 +238,68 @@ fn raw_ident_field() { test_type(&mut conn, "inventory_item", &[(item, "ROW('foo')")]); } + +#[test] +fn generics() { + #[derive(FromSql, Debug, PartialEq)] + struct InventoryItem + where + U: FromSqlOwned, + { + name: String, + supplier_id: T, + price: Option, + } + + // doesn't make sense to implement derived FromSql on a type with borrows + #[derive(ToSql, Debug, PartialEq)] + #[postgres(name = "InventoryItem")] + struct InventoryItemRef<'a, T: 'a + ToSql, U> + where + U: 'a + ToSql, + { + name: &'a str, + supplier_id: &'a T, + price: Option<&'a U>, + } + + const NAME: &str = "foobar"; + const SUPPLIER_ID: i32 = 100; + const PRICE: f64 = 15.50; + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.batch_execute( + "CREATE TYPE pg_temp.\"InventoryItem\" AS ( + name TEXT, + supplier_id INT, + price DOUBLE PRECISION + );", + ) + .unwrap(); + + let item = InventoryItemRef { + name: NAME, + supplier_id: &SUPPLIER_ID, + price: Some(&PRICE), + }; + + let item_null = InventoryItemRef { + name: NAME, + supplier_id: &SUPPLIER_ID, + price: None, + }; + + test_type_asymmetric( + &mut conn, + "\"InventoryItem\"", + &[ + (item, "ROW('foobar', 100, 15.50)"), + (item_null, "ROW('foobar', 100, NULL)"), + ], + |t: &InventoryItemRef, f: &InventoryItem| { + t.name == f.name.as_str() + && t.supplier_id == &f.supplier_id + && t.price == f.price.as_ref() + }, + ); +} diff --git a/postgres-derive-test/src/lib.rs b/postgres-derive-test/src/lib.rs index 279ed141..8bfd147f 100644 --- a/postgres-derive-test/src/lib.rs +++ b/postgres-derive-test/src/lib.rs @@ -27,6 +27,30 @@ where } } +pub fn test_type_asymmetric( + conn: &mut Client, + sql_type: &str, + checks: &[(T, S)], + cmp: C, +) where + T: ToSql + Sync, + F: FromSqlOwned, + S: fmt::Display, + C: Fn(&T, &F) -> bool, +{ + for &(ref val, ref repr) in checks.iter() { + let stmt = conn + .prepare(&*format!("SELECT {}::{}", *repr, sql_type)) + .unwrap(); + let result: F = conn.query_one(&stmt, &[]).unwrap().get(0); + assert!(cmp(val, &result)); + + let stmt = conn.prepare(&*format!("SELECT $1::{}", sql_type)).unwrap(); + let result: F = conn.query_one(&stmt, &[val]).unwrap().get(0); + assert!(cmp(val, &result)); + } +} + #[test] fn compile_fail() { trybuild::TestCases::new().compile_fail("src/compile-fail/*.rs"); diff --git a/postgres-derive/src/fromsql.rs b/postgres-derive/src/fromsql.rs index c89cbb5e..41534365 100644 --- a/postgres-derive/src/fromsql.rs +++ b/postgres-derive/src/fromsql.rs @@ -1,7 +1,10 @@ -use proc_macro2::TokenStream; +use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; use std::iter; -use syn::{Data, DataStruct, DeriveInput, Error, Fields, Ident}; +use syn::{ + Data, DataStruct, DeriveInput, Error, Fields, GenericParam, Generics, Ident, Lifetime, + LifetimeDef, +}; use crate::accepts; use crate::composites::Field; @@ -86,10 +89,13 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { }; let ident = &input.ident; + let (generics, lifetime) = build_generics(&input.generics); + let (impl_generics, _, _) = generics.split_for_impl(); + let (_, ty_generics, where_clause) = input.generics.split_for_impl(); let out = quote! { - impl<'a> postgres_types::FromSql<'a> for #ident { - fn from_sql(_type: &postgres_types::Type, buf: &'a [u8]) - -> std::result::Result<#ident, + impl#impl_generics postgres_types::FromSql<#lifetime> for #ident#ty_generics #where_clause { + fn from_sql(_type: &postgres_types::Type, buf: &#lifetime [u8]) + -> std::result::Result<#ident#ty_generics, std::boxed::Box> { @@ -200,3 +206,15 @@ fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream { }) } } + +fn build_generics(source: &Generics) -> (Generics, Lifetime) { + let mut out = source.to_owned(); + // don't worry about lifetime name collisions, it doesn't make sense to derive FromSql on a struct with a lifetime + let lifetime = Lifetime::new("'a", Span::call_site()); + out.params.insert( + 0, + GenericParam::Lifetime(LifetimeDef::new(lifetime.to_owned())), + ); + + (out, lifetime) +} diff --git a/postgres-derive/src/tosql.rs b/postgres-derive/src/tosql.rs index 96f26138..299074f8 100644 --- a/postgres-derive/src/tosql.rs +++ b/postgres-derive/src/tosql.rs @@ -82,8 +82,9 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { }; let ident = &input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let out = quote! { - impl postgres_types::ToSql for #ident { + impl#impl_generics postgres_types::ToSql for #ident#ty_generics #where_clause { fn to_sql(&self, _type: &postgres_types::Type, buf: &mut postgres_types::private::BytesMut) From 3827b2e44279bbb8b869e07cf91c383d2445893d Mon Sep 17 00:00:00 2001 From: Matt Duch Date: Thu, 25 Aug 2022 19:29:09 -0500 Subject: [PATCH 2/2] derive bounds on generics --- postgres-derive-test/src/composites.rs | 10 ++++----- postgres-derive/src/composites.rs | 26 ++++++++++++++++++++++- postgres-derive/src/fromsql.rs | 29 +++++++++++++++++++++++--- postgres-derive/src/tosql.rs | 20 +++++++++++++++--- 4 files changed, 73 insertions(+), 12 deletions(-) diff --git a/postgres-derive-test/src/composites.rs b/postgres-derive-test/src/composites.rs index 66763351..a1b76345 100644 --- a/postgres-derive-test/src/composites.rs +++ b/postgres-derive-test/src/composites.rs @@ -1,6 +1,6 @@ use crate::{test_type, test_type_asymmetric}; use postgres::{Client, NoTls}; -use postgres_types::{FromSql, FromSqlOwned, ToSql, WrongType}; +use postgres_types::{FromSql, ToSql, WrongType}; use std::error::Error; #[test] @@ -242,9 +242,9 @@ fn raw_ident_field() { #[test] fn generics() { #[derive(FromSql, Debug, PartialEq)] - struct InventoryItem + struct InventoryItem where - U: FromSqlOwned, + U: Clone, { name: String, supplier_id: T, @@ -254,9 +254,9 @@ fn generics() { // doesn't make sense to implement derived FromSql on a type with borrows #[derive(ToSql, Debug, PartialEq)] #[postgres(name = "InventoryItem")] - struct InventoryItemRef<'a, T: 'a + ToSql, U> + struct InventoryItemRef<'a, T: 'a + Clone, U> where - U: 'a + ToSql, + U: 'a + Clone, { name: &'a str, supplier_id: &'a T, diff --git a/postgres-derive/src/composites.rs b/postgres-derive/src/composites.rs index c1e49515..15bfabc1 100644 --- a/postgres-derive/src/composites.rs +++ b/postgres-derive/src/composites.rs @@ -1,4 +1,8 @@ -use syn::{Error, Ident, Type}; +use proc_macro2::Span; +use syn::{ + punctuated::Punctuated, Error, GenericParam, Generics, Ident, Path, PathSegment, Type, + TypeParamBound, +}; use crate::overrides::Overrides; @@ -26,3 +30,23 @@ impl Field { }) } } + +pub(crate) fn append_generic_bound(mut generics: Generics, bound: &TypeParamBound) -> Generics { + for param in &mut generics.params { + if let GenericParam::Type(param) = param { + param.bounds.push(bound.to_owned()) + } + } + generics +} + +pub(crate) fn new_derive_path(last: PathSegment) -> Path { + let mut path = Path { + leading_colon: None, + segments: Punctuated::new(), + }; + path.segments + .push(Ident::new("postgres_types", Span::call_site()).into()); + path.segments.push(last); + path +} diff --git a/postgres-derive/src/fromsql.rs b/postgres-derive/src/fromsql.rs index 41534365..f458c6e3 100644 --- a/postgres-derive/src/fromsql.rs +++ b/postgres-derive/src/fromsql.rs @@ -2,12 +2,15 @@ use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote}; use std::iter; use syn::{ - Data, DataStruct, DeriveInput, Error, Fields, GenericParam, Generics, Ident, Lifetime, - LifetimeDef, + punctuated::Punctuated, token, AngleBracketedGenericArguments, Data, DataStruct, DeriveInput, + Error, Fields, GenericArgument, GenericParam, Generics, Ident, Lifetime, LifetimeDef, + PathArguments, PathSegment, }; +use syn::{TraitBound, TraitBoundModifier, TypeParamBound}; use crate::accepts; use crate::composites::Field; +use crate::composites::{append_generic_bound, new_derive_path}; use crate::enums::Variant; use crate::overrides::Overrides; @@ -208,9 +211,10 @@ fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream { } fn build_generics(source: &Generics) -> (Generics, Lifetime) { - let mut out = source.to_owned(); // don't worry about lifetime name collisions, it doesn't make sense to derive FromSql on a struct with a lifetime let lifetime = Lifetime::new("'a", Span::call_site()); + + let mut out = append_generic_bound(source.to_owned(), &new_fromsql_bound(&lifetime)); out.params.insert( 0, GenericParam::Lifetime(LifetimeDef::new(lifetime.to_owned())), @@ -218,3 +222,22 @@ fn build_generics(source: &Generics) -> (Generics, Lifetime) { (out, lifetime) } + +fn new_fromsql_bound(lifetime: &Lifetime) -> TypeParamBound { + let mut path_segment: PathSegment = Ident::new("FromSql", Span::call_site()).into(); + let mut seg_args = Punctuated::new(); + seg_args.push(GenericArgument::Lifetime(lifetime.to_owned())); + path_segment.arguments = PathArguments::AngleBracketed(AngleBracketedGenericArguments { + colon2_token: None, + lt_token: token::Lt::default(), + args: seg_args, + gt_token: token::Gt::default(), + }); + + TypeParamBound::Trait(TraitBound { + lifetimes: None, + modifier: TraitBoundModifier::None, + paren_token: None, + path: new_derive_path(path_segment), + }) +} diff --git a/postgres-derive/src/tosql.rs b/postgres-derive/src/tosql.rs index 299074f8..e51acc7f 100644 --- a/postgres-derive/src/tosql.rs +++ b/postgres-derive/src/tosql.rs @@ -1,10 +1,14 @@ -use proc_macro2::TokenStream; +use proc_macro2::{Span, TokenStream}; use quote::quote; use std::iter; -use syn::{Data, DataStruct, DeriveInput, Error, Fields, Ident}; +use syn::{ + Data, DataStruct, DeriveInput, Error, Fields, Ident, TraitBound, TraitBoundModifier, + TypeParamBound, +}; use crate::accepts; use crate::composites::Field; +use crate::composites::{append_generic_bound, new_derive_path}; use crate::enums::Variant; use crate::overrides::Overrides; @@ -82,7 +86,8 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { }; let ident = &input.ident; - let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + let generics = append_generic_bound(input.generics.to_owned(), &new_tosql_bound()); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let out = quote! { impl#impl_generics postgres_types::ToSql for #ident#ty_generics #where_clause { fn to_sql(&self, @@ -182,3 +187,12 @@ fn composite_body(fields: &[Field]) -> TokenStream { std::result::Result::Ok(postgres_types::IsNull::No) } } + +fn new_tosql_bound() -> TypeParamBound { + TypeParamBound::Trait(TraitBound { + lifetimes: None, + modifier: TraitBoundModifier::None, + paren_token: None, + path: new_derive_path(Ident::new("ToSql", Span::call_site()).into()), + }) +}