derive generic FromSql/ToSql

This commit is contained in:
Matt Duch 2022-08-08 19:22:06 -05:00
parent 54331183ea
commit ed05675888
4 changed files with 116 additions and 8 deletions

View File

@ -1,6 +1,6 @@
use crate::test_type;
use crate::{test_type, test_type_asymmetric};
use postgres::{Client, NoTls};
use postgres_types::{FromSql, ToSql, WrongType};
use postgres_types::{FromSql, FromSqlOwned, ToSql, WrongType};
use std::error::Error;
#[test]
@ -238,3 +238,68 @@ fn raw_ident_field() {
test_type(&mut conn, "inventory_item", &[(item, "ROW('foo')")]);
}
#[test]
fn generics() {
#[derive(FromSql, Debug, PartialEq)]
struct InventoryItem<T: FromSqlOwned, U>
where
U: FromSqlOwned,
{
name: String,
supplier_id: T,
price: Option<U>,
}
// doesn't make sense to implement derived FromSql on a type with borrows
#[derive(ToSql, Debug, PartialEq)]
#[postgres(name = "InventoryItem")]
struct InventoryItemRef<'a, T: 'a + ToSql, U>
where
U: 'a + ToSql,
{
name: &'a str,
supplier_id: &'a T,
price: Option<&'a U>,
}
const NAME: &str = "foobar";
const SUPPLIER_ID: i32 = 100;
const PRICE: f64 = 15.50;
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.batch_execute(
"CREATE TYPE pg_temp.\"InventoryItem\" AS (
name TEXT,
supplier_id INT,
price DOUBLE PRECISION
);",
)
.unwrap();
let item = InventoryItemRef {
name: NAME,
supplier_id: &SUPPLIER_ID,
price: Some(&PRICE),
};
let item_null = InventoryItemRef {
name: NAME,
supplier_id: &SUPPLIER_ID,
price: None,
};
test_type_asymmetric(
&mut conn,
"\"InventoryItem\"",
&[
(item, "ROW('foobar', 100, 15.50)"),
(item_null, "ROW('foobar', 100, NULL)"),
],
|t: &InventoryItemRef<i32, f64>, f: &InventoryItem<i32, f64>| {
t.name == f.name.as_str()
&& t.supplier_id == &f.supplier_id
&& t.price == f.price.as_ref()
},
);
}

View File

@ -27,6 +27,30 @@ where
}
}
pub fn test_type_asymmetric<T, F, S, C>(
conn: &mut Client,
sql_type: &str,
checks: &[(T, S)],
cmp: C,
) where
T: ToSql + Sync,
F: FromSqlOwned,
S: fmt::Display,
C: Fn(&T, &F) -> bool,
{
for &(ref val, ref repr) in checks.iter() {
let stmt = conn
.prepare(&*format!("SELECT {}::{}", *repr, sql_type))
.unwrap();
let result: F = conn.query_one(&stmt, &[]).unwrap().get(0);
assert!(cmp(val, &result));
let stmt = conn.prepare(&*format!("SELECT $1::{}", sql_type)).unwrap();
let result: F = conn.query_one(&stmt, &[val]).unwrap().get(0);
assert!(cmp(val, &result));
}
}
#[test]
fn compile_fail() {
trybuild::TestCases::new().compile_fail("src/compile-fail/*.rs");

View File

@ -1,7 +1,10 @@
use proc_macro2::TokenStream;
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote};
use std::iter;
use syn::{Data, DataStruct, DeriveInput, Error, Fields, Ident};
use syn::{
Data, DataStruct, DeriveInput, Error, Fields, GenericParam, Generics, Ident, Lifetime,
LifetimeDef,
};
use crate::accepts;
use crate::composites::Field;
@ -86,10 +89,13 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
};
let ident = &input.ident;
let (generics, lifetime) = build_generics(&input.generics);
let (impl_generics, _, _) = generics.split_for_impl();
let (_, ty_generics, where_clause) = input.generics.split_for_impl();
let out = quote! {
impl<'a> postgres_types::FromSql<'a> for #ident {
fn from_sql(_type: &postgres_types::Type, buf: &'a [u8])
-> std::result::Result<#ident,
impl#impl_generics postgres_types::FromSql<#lifetime> for #ident#ty_generics #where_clause {
fn from_sql(_type: &postgres_types::Type, buf: &#lifetime [u8])
-> std::result::Result<#ident#ty_generics,
std::boxed::Box<dyn std::error::Error +
std::marker::Sync +
std::marker::Send>> {
@ -200,3 +206,15 @@ fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream {
})
}
}
fn build_generics(source: &Generics) -> (Generics, Lifetime) {
let mut out = source.to_owned();
// don't worry about lifetime name collisions, it doesn't make sense to derive FromSql on a struct with a lifetime
let lifetime = Lifetime::new("'a", Span::call_site());
out.params.insert(
0,
GenericParam::Lifetime(LifetimeDef::new(lifetime.to_owned())),
);
(out, lifetime)
}

View File

@ -82,8 +82,9 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
};
let ident = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let out = quote! {
impl postgres_types::ToSql for #ident {
impl#impl_generics postgres_types::ToSql for #ident#ty_generics #where_clause {
fn to_sql(&self,
_type: &postgres_types::Type,
buf: &mut postgres_types::private::BytesMut)