Move postgres-derive in-tree
This commit is contained in:
parent
31855141d2
commit
218d889042
@ -2,6 +2,8 @@
|
||||
members = [
|
||||
"codegen",
|
||||
"postgres",
|
||||
"postgres-derive",
|
||||
"postgres-derive-test",
|
||||
"postgres-native-tls",
|
||||
"postgres-openssl",
|
||||
"postgres-protocol",
|
||||
|
9
postgres-derive-test/Cargo.toml
Normal file
9
postgres-derive-test/Cargo.toml
Normal file
@ -0,0 +1,9 @@
|
||||
[package]
|
||||
name = "postgres-derive-test"
|
||||
version = "0.1.0"
|
||||
authors = ["Steven Fackler <sfackler@gmail.com>"]
|
||||
edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
postgres-types = { path = "../postgres-types", features = ["derive"] }
|
||||
postgres = { path = "../postgres" }
|
217
postgres-derive-test/src/composites.rs
Normal file
217
postgres-derive-test/src/composites.rs
Normal file
@ -0,0 +1,217 @@
|
||||
use crate::test_type;
|
||||
use postgres::{Client, NoTls};
|
||||
use postgres_types::{FromSql, ToSql, WrongType};
|
||||
use std::error::Error;
|
||||
|
||||
#[test]
|
||||
fn defaults() {
|
||||
#[derive(FromSql, ToSql, Debug, PartialEq)]
|
||||
struct InventoryItem {
|
||||
name: String,
|
||||
supplier_id: i32,
|
||||
price: Option<f64>,
|
||||
}
|
||||
|
||||
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
|
||||
conn.batch_execute(
|
||||
"CREATE TYPE pg_temp.\"InventoryItem\" AS (
|
||||
name TEXT,
|
||||
supplier_id INT,
|
||||
price DOUBLE PRECISION
|
||||
);",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let item = InventoryItem {
|
||||
name: "foobar".to_owned(),
|
||||
supplier_id: 100,
|
||||
price: Some(15.50),
|
||||
};
|
||||
|
||||
let item_null = InventoryItem {
|
||||
name: "foobar".to_owned(),
|
||||
supplier_id: 100,
|
||||
price: None,
|
||||
};
|
||||
|
||||
test_type(
|
||||
&mut conn,
|
||||
"\"InventoryItem\"",
|
||||
&[
|
||||
(item, "ROW('foobar', 100, 15.50)"),
|
||||
(item_null, "ROW('foobar', 100, NULL)"),
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn name_overrides() {
|
||||
#[derive(FromSql, ToSql, Debug, PartialEq)]
|
||||
#[postgres(name = "inventory_item")]
|
||||
struct InventoryItem {
|
||||
#[postgres(name = "name")]
|
||||
_name: String,
|
||||
#[postgres(name = "supplier_id")]
|
||||
_supplier_id: i32,
|
||||
#[postgres(name = "price")]
|
||||
_price: Option<f64>,
|
||||
}
|
||||
|
||||
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
|
||||
conn.batch_execute(
|
||||
"CREATE TYPE pg_temp.inventory_item AS (
|
||||
name TEXT,
|
||||
supplier_id INT,
|
||||
price DOUBLE PRECISION
|
||||
);",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let item = InventoryItem {
|
||||
_name: "foobar".to_owned(),
|
||||
_supplier_id: 100,
|
||||
_price: Some(15.50),
|
||||
};
|
||||
|
||||
let item_null = InventoryItem {
|
||||
_name: "foobar".to_owned(),
|
||||
_supplier_id: 100,
|
||||
_price: None,
|
||||
};
|
||||
|
||||
test_type(
|
||||
&mut conn,
|
||||
"inventory_item",
|
||||
&[
|
||||
(item, "ROW('foobar', 100, 15.50)"),
|
||||
(item_null, "ROW('foobar', 100, NULL)"),
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_name() {
|
||||
#[derive(FromSql, ToSql, Debug, PartialEq)]
|
||||
struct InventoryItem {
|
||||
name: String,
|
||||
supplier_id: i32,
|
||||
price: Option<f64>,
|
||||
}
|
||||
|
||||
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
|
||||
conn.batch_execute(
|
||||
"CREATE TYPE pg_temp.inventory_item AS (
|
||||
name TEXT,
|
||||
supplier_id INT,
|
||||
price DOUBLE PRECISION
|
||||
);",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let item = InventoryItem {
|
||||
name: "foobar".to_owned(),
|
||||
supplier_id: 100,
|
||||
price: Some(15.50),
|
||||
};
|
||||
|
||||
let err = conn
|
||||
.execute("SELECT $1::inventory_item", &[&item])
|
||||
.unwrap_err();
|
||||
assert!(err.source().unwrap().is::<WrongType>());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extra_field() {
|
||||
#[derive(FromSql, ToSql, Debug, PartialEq)]
|
||||
#[postgres(name = "inventory_item")]
|
||||
struct InventoryItem {
|
||||
name: String,
|
||||
supplier_id: i32,
|
||||
price: Option<f64>,
|
||||
foo: i32,
|
||||
}
|
||||
|
||||
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
|
||||
conn.batch_execute(
|
||||
"CREATE TYPE pg_temp.inventory_item AS (
|
||||
name TEXT,
|
||||
supplier_id INT,
|
||||
price DOUBLE PRECISION
|
||||
);",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let item = InventoryItem {
|
||||
name: "foobar".to_owned(),
|
||||
supplier_id: 100,
|
||||
price: Some(15.50),
|
||||
foo: 0,
|
||||
};
|
||||
|
||||
let err = conn
|
||||
.execute("SELECT $1::inventory_item", &[&item])
|
||||
.unwrap_err();
|
||||
assert!(err.source().unwrap().is::<WrongType>());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_field() {
|
||||
#[derive(FromSql, ToSql, Debug, PartialEq)]
|
||||
#[postgres(name = "inventory_item")]
|
||||
struct InventoryItem {
|
||||
name: String,
|
||||
supplier_id: i32,
|
||||
}
|
||||
|
||||
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
|
||||
conn.batch_execute(
|
||||
"CREATE TYPE pg_temp.inventory_item AS (
|
||||
name TEXT,
|
||||
supplier_id INT,
|
||||
price DOUBLE PRECISION
|
||||
);",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let item = InventoryItem {
|
||||
name: "foobar".to_owned(),
|
||||
supplier_id: 100,
|
||||
};
|
||||
|
||||
let err = conn
|
||||
.execute("SELECT $1::inventory_item", &[&item])
|
||||
.unwrap_err();
|
||||
assert!(err.source().unwrap().is::<WrongType>());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_type() {
|
||||
#[derive(FromSql, ToSql, Debug, PartialEq)]
|
||||
#[postgres(name = "inventory_item")]
|
||||
struct InventoryItem {
|
||||
name: String,
|
||||
supplier_id: i32,
|
||||
price: i32,
|
||||
}
|
||||
|
||||
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
|
||||
conn.batch_execute(
|
||||
"CREATE TYPE pg_temp.inventory_item AS (
|
||||
name TEXT,
|
||||
supplier_id INT,
|
||||
price DOUBLE PRECISION
|
||||
);",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let item = InventoryItem {
|
||||
name: "foobar".to_owned(),
|
||||
supplier_id: 100,
|
||||
price: 0,
|
||||
};
|
||||
|
||||
let err = conn
|
||||
.execute("SELECT $1::inventory_item", &[&item])
|
||||
.unwrap_err();
|
||||
assert!(err.source().unwrap().is::<WrongType>());
|
||||
}
|
121
postgres-derive-test/src/domains.rs
Normal file
121
postgres-derive-test/src/domains.rs
Normal file
@ -0,0 +1,121 @@
|
||||
use crate::test_type;
|
||||
use postgres::{Client, NoTls};
|
||||
use postgres_types::{FromSql, ToSql, WrongType};
|
||||
use std::error::Error;
|
||||
|
||||
#[test]
|
||||
fn defaults() {
|
||||
#[derive(FromSql, ToSql, Debug, PartialEq)]
|
||||
struct SessionId(Vec<u8>);
|
||||
|
||||
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
|
||||
conn.execute(
|
||||
"CREATE DOMAIN pg_temp.\"SessionId\" AS bytea CHECK(octet_length(VALUE) = 16);",
|
||||
&[],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
test_type(
|
||||
&mut conn,
|
||||
"\"SessionId\"",
|
||||
&[(
|
||||
SessionId(b"0123456789abcdef".to_vec()),
|
||||
"'0123456789abcdef'",
|
||||
)],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn name_overrides() {
|
||||
#[derive(FromSql, ToSql, Debug, PartialEq)]
|
||||
#[postgres(name = "session_id")]
|
||||
struct SessionId(Vec<u8>);
|
||||
|
||||
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
|
||||
conn.execute(
|
||||
"CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16);",
|
||||
&[],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
test_type(
|
||||
&mut conn,
|
||||
"session_id",
|
||||
&[(
|
||||
SessionId(b"0123456789abcdef".to_vec()),
|
||||
"'0123456789abcdef'",
|
||||
)],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_name() {
|
||||
#[derive(FromSql, ToSql, Debug, PartialEq)]
|
||||
struct SessionId(Vec<u8>);
|
||||
|
||||
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
|
||||
conn.execute(
|
||||
"CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16);",
|
||||
&[],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let err = conn
|
||||
.execute("SELECT $1::session_id", &[&SessionId(vec![])])
|
||||
.unwrap_err();
|
||||
assert!(err.source().unwrap().is::<WrongType>());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_type() {
|
||||
#[derive(FromSql, ToSql, Debug, PartialEq)]
|
||||
#[postgres(name = "session_id")]
|
||||
struct SessionId(i32);
|
||||
|
||||
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
|
||||
conn.execute(
|
||||
"CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16);",
|
||||
&[],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let err = conn
|
||||
.execute("SELECT $1::session_id", &[&SessionId(0)])
|
||||
.unwrap_err();
|
||||
assert!(err.source().unwrap().is::<WrongType>());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn domain_in_composite() {
|
||||
#[derive(FromSql, ToSql, Debug, PartialEq)]
|
||||
#[postgres(name = "domain")]
|
||||
struct Domain(String);
|
||||
|
||||
#[derive(FromSql, ToSql, Debug, PartialEq)]
|
||||
#[postgres(name = "composite")]
|
||||
struct Composite {
|
||||
domain: Domain,
|
||||
}
|
||||
|
||||
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
|
||||
conn.batch_execute(
|
||||
"
|
||||
CREATE DOMAIN pg_temp.domain AS TEXT;\
|
||||
CREATE TYPE pg_temp.composite AS (
|
||||
domain domain
|
||||
);
|
||||
",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
test_type(
|
||||
&mut conn,
|
||||
"composite",
|
||||
&[(
|
||||
Composite {
|
||||
domain: Domain("hello".to_string()),
|
||||
},
|
||||
"ROW('hello')",
|
||||
)],
|
||||
);
|
||||
}
|
104
postgres-derive-test/src/enums.rs
Normal file
104
postgres-derive-test/src/enums.rs
Normal file
@ -0,0 +1,104 @@
|
||||
use crate::test_type;
|
||||
use postgres::{Client, NoTls};
|
||||
use postgres_types::{FromSql, ToSql, WrongType};
|
||||
use std::error::Error;
|
||||
|
||||
#[test]
|
||||
fn defaults() {
|
||||
#[derive(Debug, ToSql, FromSql, PartialEq)]
|
||||
enum Foo {
|
||||
Bar,
|
||||
Baz,
|
||||
}
|
||||
|
||||
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
|
||||
conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[])
|
||||
.unwrap();
|
||||
|
||||
test_type(
|
||||
&mut conn,
|
||||
"\"Foo\"",
|
||||
&[(Foo::Bar, "'Bar'"), (Foo::Baz, "'Baz'")],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn name_overrides() {
|
||||
#[derive(Debug, ToSql, FromSql, PartialEq)]
|
||||
#[postgres(name = "mood")]
|
||||
enum Mood {
|
||||
#[postgres(name = "sad")]
|
||||
Sad,
|
||||
#[postgres(name = "ok")]
|
||||
Ok,
|
||||
#[postgres(name = "happy")]
|
||||
Happy,
|
||||
}
|
||||
|
||||
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
|
||||
conn.execute(
|
||||
"CREATE TYPE pg_temp.mood AS ENUM ('sad', 'ok', 'happy')",
|
||||
&[],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
test_type(
|
||||
&mut conn,
|
||||
"mood",
|
||||
&[
|
||||
(Mood::Sad, "'sad'"),
|
||||
(Mood::Ok, "'ok'"),
|
||||
(Mood::Happy, "'happy'"),
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_name() {
|
||||
#[derive(Debug, ToSql, FromSql, PartialEq)]
|
||||
enum Foo {
|
||||
Bar,
|
||||
Baz,
|
||||
}
|
||||
|
||||
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
|
||||
conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[])
|
||||
.unwrap();
|
||||
|
||||
let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err();
|
||||
assert!(err.source().unwrap().is::<WrongType>());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extra_variant() {
|
||||
#[derive(Debug, ToSql, FromSql, PartialEq)]
|
||||
#[postgres(name = "foo")]
|
||||
enum Foo {
|
||||
Bar,
|
||||
Baz,
|
||||
Buz,
|
||||
}
|
||||
|
||||
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
|
||||
conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[])
|
||||
.unwrap();
|
||||
|
||||
let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err();
|
||||
assert!(err.source().unwrap().is::<WrongType>());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_variant() {
|
||||
#[derive(Debug, ToSql, FromSql, PartialEq)]
|
||||
#[postgres(name = "foo")]
|
||||
enum Foo {
|
||||
Bar,
|
||||
}
|
||||
|
||||
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
|
||||
conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[])
|
||||
.unwrap();
|
||||
|
||||
let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err();
|
||||
assert!(err.source().unwrap().is::<WrongType>());
|
||||
}
|
27
postgres-derive-test/src/lib.rs
Normal file
27
postgres-derive-test/src/lib.rs
Normal file
@ -0,0 +1,27 @@
|
||||
#![cfg(test)]
|
||||
|
||||
use postgres::Client;
|
||||
use postgres_types::{FromSqlOwned, ToSql};
|
||||
use std::fmt;
|
||||
|
||||
mod composites;
|
||||
mod domains;
|
||||
mod enums;
|
||||
|
||||
pub fn test_type<T, S>(conn: &mut Client, sql_type: &str, checks: &[(T, S)])
|
||||
where
|
||||
T: PartialEq + FromSqlOwned + ToSql + Sync,
|
||||
S: fmt::Display,
|
||||
{
|
||||
for &(ref val, ref repr) in checks.iter() {
|
||||
let stmt = conn
|
||||
.prepare(&*format!("SELECT {}::{}", *repr, sql_type))
|
||||
.unwrap();
|
||||
let result = conn.query_one(&stmt, &[]).unwrap().get(0);
|
||||
assert_eq!(val, &result);
|
||||
|
||||
let stmt = conn.prepare(&*format!("SELECT $1::{}", sql_type)).unwrap();
|
||||
let result = conn.query_one(&stmt, &[val]).unwrap().get(0);
|
||||
assert_eq!(val, &result);
|
||||
}
|
||||
}
|
18
postgres-derive/Cargo.toml
Normal file
18
postgres-derive/Cargo.toml
Normal file
@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "postgres-derive"
|
||||
version = "0.3.3"
|
||||
authors = ["Steven Fackler <sfackler@palantir.com>"]
|
||||
license = "MIT/Apache-2.0"
|
||||
description = "Deriving plugin support for Postgres enum, domain, and composite types"
|
||||
repository = "https://github.com/sfackler/rust-postgres-derive"
|
||||
readme = "README.md"
|
||||
keywords = ["database", "postgres", "postgresql", "sql"]
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
test = false
|
||||
|
||||
[dependencies]
|
||||
syn = "1.0"
|
||||
proc-macro2 = "1.0"
|
||||
quote = "1.0"
|
86
postgres-derive/src/accepts.rs
Normal file
86
postgres-derive/src/accepts.rs
Normal file
@ -0,0 +1,86 @@
|
||||
use proc_macro2::{Span, TokenStream};
|
||||
use std::iter;
|
||||
use syn::Ident;
|
||||
|
||||
use composites::Field;
|
||||
use enums::Variant;
|
||||
|
||||
pub fn domain_body(name: &str, field: &syn::Field) -> TokenStream {
|
||||
let ty = &field.ty;
|
||||
|
||||
quote! {
|
||||
if type_.name() != #name {
|
||||
return false;
|
||||
}
|
||||
|
||||
match *type_.kind() {
|
||||
::postgres_types::Kind::Domain(ref type_) => {
|
||||
<#ty as ::postgres_types::ToSql>::accepts(type_)
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn enum_body(name: &str, variants: &[Variant]) -> TokenStream {
|
||||
let num_variants = variants.len();
|
||||
let variant_names = variants.iter().map(|v| &v.name);
|
||||
|
||||
quote! {
|
||||
if type_.name() != #name {
|
||||
return false;
|
||||
}
|
||||
|
||||
match *type_.kind() {
|
||||
::postgres_types::Kind::Enum(ref variants) => {
|
||||
if variants.len() != #num_variants {
|
||||
return false;
|
||||
}
|
||||
|
||||
variants.iter().all(|v| {
|
||||
match &**v {
|
||||
#(
|
||||
#variant_names => true,
|
||||
)*
|
||||
_ => false,
|
||||
}
|
||||
})
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn composite_body(name: &str, trait_: &str, fields: &[Field]) -> TokenStream {
|
||||
let num_fields = fields.len();
|
||||
let trait_ = Ident::new(trait_, Span::call_site());
|
||||
let traits = iter::repeat(&trait_);
|
||||
let field_names = fields.iter().map(|f| &f.name);
|
||||
let field_types = fields.iter().map(|f| &f.type_);
|
||||
|
||||
quote! {
|
||||
if type_.name() != #name {
|
||||
return false;
|
||||
}
|
||||
|
||||
match *type_.kind() {
|
||||
::postgres_types::Kind::Composite(ref fields) => {
|
||||
if fields.len() != #num_fields {
|
||||
return false;
|
||||
}
|
||||
|
||||
fields.iter().all(|f| {
|
||||
match f.name() {
|
||||
#(
|
||||
#field_names => {
|
||||
<#field_types as ::postgres_types::#traits>::accepts(f.type_())
|
||||
}
|
||||
)*
|
||||
_ => false,
|
||||
}
|
||||
})
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
22
postgres-derive/src/composites.rs
Normal file
22
postgres-derive/src/composites.rs
Normal file
@ -0,0 +1,22 @@
|
||||
use syn::{self, Error, Ident, Type};
|
||||
|
||||
use overrides::Overrides;
|
||||
|
||||
pub struct Field {
|
||||
pub name: String,
|
||||
pub ident: Ident,
|
||||
pub type_: Type,
|
||||
}
|
||||
|
||||
impl Field {
|
||||
pub fn parse(raw: &syn::Field) -> Result<Field, Error> {
|
||||
let overrides = Overrides::extract(&raw.attrs)?;
|
||||
|
||||
let ident = raw.ident.as_ref().unwrap().clone();
|
||||
Ok(Field {
|
||||
name: overrides.name.unwrap_or_else(|| ident.to_string()),
|
||||
ident,
|
||||
type_: raw.ty.clone(),
|
||||
})
|
||||
}
|
||||
}
|
28
postgres-derive/src/enums.rs
Normal file
28
postgres-derive/src/enums.rs
Normal file
@ -0,0 +1,28 @@
|
||||
use syn::{self, Error, Fields, Ident};
|
||||
|
||||
use overrides::Overrides;
|
||||
|
||||
pub struct Variant {
|
||||
pub ident: Ident,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
impl Variant {
|
||||
pub fn parse(raw: &syn::Variant) -> Result<Variant, Error> {
|
||||
match raw.fields {
|
||||
Fields::Unit => {}
|
||||
_ => {
|
||||
return Err(Error::new_spanned(
|
||||
raw,
|
||||
"non-C-like enums are not supported",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
let overrides = Overrides::extract(&raw.attrs)?;
|
||||
Ok(Variant {
|
||||
ident: raw.ident.clone(),
|
||||
name: overrides.name.unwrap_or_else(|| raw.ident.to_string()),
|
||||
})
|
||||
}
|
||||
}
|
200
postgres-derive/src/fromsql.rs
Normal file
200
postgres-derive/src/fromsql.rs
Normal file
@ -0,0 +1,200 @@
|
||||
use proc_macro2::{Span, TokenStream};
|
||||
use std::iter;
|
||||
use syn::{self, Data, DataStruct, DeriveInput, Error, Fields, Ident};
|
||||
|
||||
use accepts;
|
||||
use composites::Field;
|
||||
use enums::Variant;
|
||||
use overrides::Overrides;
|
||||
|
||||
pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
|
||||
let overrides = Overrides::extract(&input.attrs)?;
|
||||
|
||||
let name = overrides.name.unwrap_or_else(|| input.ident.to_string());
|
||||
|
||||
let (accepts_body, to_sql_body) = match input.data {
|
||||
Data::Enum(ref data) => {
|
||||
let variants = data
|
||||
.variants
|
||||
.iter()
|
||||
.map(Variant::parse)
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
(
|
||||
accepts::enum_body(&name, &variants),
|
||||
enum_body(&input.ident, &variants),
|
||||
)
|
||||
}
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Unnamed(ref fields),
|
||||
..
|
||||
}) if fields.unnamed.len() == 1 => {
|
||||
let field = fields.unnamed.first().unwrap();
|
||||
(
|
||||
domain_accepts_body(&name, field),
|
||||
domain_body(&input.ident, field),
|
||||
)
|
||||
}
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Named(ref fields),
|
||||
..
|
||||
}) => {
|
||||
let fields = fields
|
||||
.named
|
||||
.iter()
|
||||
.map(Field::parse)
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
(
|
||||
accepts::composite_body(&name, "FromSql", &fields),
|
||||
composite_body(&input.ident, &fields),
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::new_spanned(
|
||||
input,
|
||||
"#[derive(ToSql)] may only be applied to structs, single field tuple structs, and enums",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let ident = &input.ident;
|
||||
let out = quote! {
|
||||
impl<'a> ::postgres_types::FromSql<'a> for #ident {
|
||||
fn from_sql(_type: &::postgres_types::Type, buf: &'a [u8])
|
||||
-> ::std::result::Result<#ident,
|
||||
::std::boxed::Box<::std::error::Error +
|
||||
::std::marker::Sync +
|
||||
::std::marker::Send>> {
|
||||
#to_sql_body
|
||||
}
|
||||
|
||||
fn accepts(type_: &::postgres_types::Type) -> bool {
|
||||
#accepts_body
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn enum_body(ident: &Ident, variants: &[Variant]) -> TokenStream {
|
||||
let variant_names = variants.iter().map(|v| &v.name);
|
||||
let idents = iter::repeat(ident);
|
||||
let variant_idents = variants.iter().map(|v| &v.ident);
|
||||
|
||||
quote! {
|
||||
match ::std::str::from_utf8(buf)? {
|
||||
#(
|
||||
#variant_names => ::std::result::Result::Ok(#idents::#variant_idents),
|
||||
)*
|
||||
s => {
|
||||
::std::result::Result::Err(
|
||||
::std::convert::Into::into(format!("invalid variant `{}`", s)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Domains are sometimes but not always just represented by the bare type (!?)
|
||||
fn domain_accepts_body(name: &str, field: &syn::Field) -> TokenStream {
|
||||
let ty = &field.ty;
|
||||
let normal_body = accepts::domain_body(name, field);
|
||||
|
||||
quote! {
|
||||
if <#ty as ::postgres_types::FromSql>::accepts(type_) {
|
||||
return true;
|
||||
}
|
||||
|
||||
#normal_body
|
||||
}
|
||||
}
|
||||
|
||||
fn domain_body(ident: &Ident, field: &syn::Field) -> TokenStream {
|
||||
let ty = &field.ty;
|
||||
quote! {
|
||||
<#ty as ::postgres_types::FromSql>::from_sql(_type, buf).map(#ident)
|
||||
}
|
||||
}
|
||||
|
||||
fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream {
|
||||
let temp_vars = &fields
|
||||
.iter()
|
||||
.map(|f| Ident::new(&format!("__{}", f.ident), Span::call_site()))
|
||||
.collect::<Vec<_>>();
|
||||
let field_names = &fields.iter().map(|f| &f.name).collect::<Vec<_>>();
|
||||
let field_idents = &fields.iter().map(|f| &f.ident).collect::<Vec<_>>();
|
||||
|
||||
quote! {
|
||||
fn read_be_i32(buf: &mut &[u8]) -> ::std::io::Result<i32> {
|
||||
let mut bytes = [0; 4];
|
||||
::std::io::Read::read_exact(buf, &mut bytes)?;
|
||||
let num = ((bytes[0] as i32) << 24) |
|
||||
((bytes[1] as i32) << 16) |
|
||||
((bytes[2] as i32) << 8) |
|
||||
(bytes[3] as i32);
|
||||
::std::result::Result::Ok(num)
|
||||
}
|
||||
|
||||
fn read_value<'a, T>(type_: &::postgres_types::Type,
|
||||
buf: &mut &'a [u8])
|
||||
-> ::std::result::Result<T,
|
||||
::std::boxed::Box<::std::error::Error +
|
||||
::std::marker::Sync +
|
||||
::std::marker::Send>>
|
||||
where T: ::postgres_types::FromSql<'a>
|
||||
{
|
||||
let len = read_be_i32(buf)?;
|
||||
let value = if len < 0 {
|
||||
::std::option::Option::None
|
||||
} else {
|
||||
if len as usize > buf.len() {
|
||||
return ::std::result::Result::Err(
|
||||
::std::convert::Into::into("invalid buffer size"));
|
||||
}
|
||||
let (head, tail) = buf.split_at(len as usize);
|
||||
*buf = tail;
|
||||
::std::option::Option::Some(&head[..])
|
||||
};
|
||||
::postgres_types::FromSql::from_sql_nullable(type_, value)
|
||||
}
|
||||
|
||||
let fields = match *_type.kind() {
|
||||
::postgres_types::Kind::Composite(ref fields) => fields,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let mut buf = buf;
|
||||
let num_fields = read_be_i32(&mut buf)?;
|
||||
if num_fields as usize != fields.len() {
|
||||
return ::std::result::Result::Err(
|
||||
::std::convert::Into::into(format!("invalid field count: {} vs {}", num_fields,
|
||||
fields.len())));
|
||||
}
|
||||
|
||||
#(
|
||||
let mut #temp_vars = ::std::option::Option::None;
|
||||
)*
|
||||
|
||||
for field in fields {
|
||||
let oid = read_be_i32(&mut buf)? as u32;
|
||||
if oid != field.type_().oid() {
|
||||
return ::std::result::Result::Err(::std::convert::Into::into("unexpected OID"));
|
||||
}
|
||||
|
||||
match field.name() {
|
||||
#(
|
||||
#field_names => {
|
||||
#temp_vars = ::std::option::Option::Some(
|
||||
read_value(field.type_(), &mut buf)?);
|
||||
}
|
||||
)*
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
::std::result::Result::Ok(#ident {
|
||||
#(
|
||||
#field_idents: #temp_vars.unwrap(),
|
||||
)*
|
||||
})
|
||||
}
|
||||
}
|
28
postgres-derive/src/lib.rs
Normal file
28
postgres-derive/src/lib.rs
Normal file
@ -0,0 +1,28 @@
|
||||
#![recursion_limit = "256"]
|
||||
|
||||
extern crate proc_macro;
|
||||
extern crate syn;
|
||||
#[macro_use]
|
||||
extern crate quote;
|
||||
extern crate proc_macro2;
|
||||
|
||||
use proc_macro::TokenStream;
|
||||
|
||||
mod accepts;
|
||||
mod composites;
|
||||
mod enums;
|
||||
mod fromsql;
|
||||
mod overrides;
|
||||
mod tosql;
|
||||
|
||||
#[proc_macro_derive(ToSql, attributes(postgres))]
|
||||
pub fn derive_tosql(input: TokenStream) -> TokenStream {
|
||||
let input = syn::parse(input).unwrap();
|
||||
tosql::expand_derive_tosql(input).unwrap().into()
|
||||
}
|
||||
|
||||
#[proc_macro_derive(FromSql, attributes(postgres))]
|
||||
pub fn derive_fromsql(input: TokenStream) -> TokenStream {
|
||||
let input = syn::parse(input).unwrap();
|
||||
fromsql::expand_derive_fromsql(input).unwrap().into()
|
||||
}
|
49
postgres-derive/src/overrides.rs
Normal file
49
postgres-derive/src/overrides.rs
Normal file
@ -0,0 +1,49 @@
|
||||
use syn::{Attribute, Error, Lit, Meta, NestedMeta};
|
||||
|
||||
pub struct Overrides {
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
impl Overrides {
|
||||
pub fn extract(attrs: &[Attribute]) -> Result<Overrides, Error> {
|
||||
let mut overrides = Overrides { name: None };
|
||||
|
||||
for attr in attrs {
|
||||
let attr = match attr.parse_meta() {
|
||||
Ok(meta) => meta,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
if !attr.path().is_ident("postgres") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let list = match attr {
|
||||
Meta::List(ref list) => list,
|
||||
bad => return Err(Error::new_spanned(bad, "expected a #[postgres(...)]")),
|
||||
};
|
||||
|
||||
for item in &list.nested {
|
||||
match item {
|
||||
NestedMeta::Meta(Meta::NameValue(meta)) => {
|
||||
if !meta.path.is_ident("name") {
|
||||
return Err(Error::new_spanned(&meta.path, "unknown override"));
|
||||
}
|
||||
|
||||
let value = match &meta.lit {
|
||||
Lit::Str(s) => s.value(),
|
||||
bad => {
|
||||
return Err(Error::new_spanned(bad, "expected a string literal"))
|
||||
}
|
||||
};
|
||||
|
||||
overrides.name = Some(value);
|
||||
}
|
||||
bad => return Err(Error::new_spanned(bad, "expected a name-value meta item")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(overrides)
|
||||
}
|
||||
}
|
160
postgres-derive/src/tosql.rs
Normal file
160
postgres-derive/src/tosql.rs
Normal file
@ -0,0 +1,160 @@
|
||||
use std::iter;
|
||||
use syn::{Data, DataStruct, DeriveInput, Error, Fields, Ident};
|
||||
|
||||
use accepts;
|
||||
use composites::Field;
|
||||
use enums::Variant;
|
||||
use overrides::Overrides;
|
||||
use proc_macro2::TokenStream;
|
||||
|
||||
pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
|
||||
let overrides = Overrides::extract(&input.attrs)?;
|
||||
|
||||
let name = overrides.name.unwrap_or_else(|| input.ident.to_string());
|
||||
|
||||
let (accepts_body, to_sql_body) = match input.data {
|
||||
Data::Enum(ref data) => {
|
||||
let variants = data
|
||||
.variants
|
||||
.iter()
|
||||
.map(Variant::parse)
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
(
|
||||
accepts::enum_body(&name, &variants),
|
||||
enum_body(&input.ident, &variants),
|
||||
)
|
||||
}
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Unnamed(ref fields),
|
||||
..
|
||||
}) if fields.unnamed.len() == 1 => {
|
||||
let field = fields.unnamed.first().unwrap();
|
||||
(accepts::domain_body(&name, &field), domain_body())
|
||||
}
|
||||
Data::Struct(DataStruct {
|
||||
fields: Fields::Named(ref fields),
|
||||
..
|
||||
}) => {
|
||||
let fields = fields
|
||||
.named
|
||||
.iter()
|
||||
.map(Field::parse)
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
(
|
||||
accepts::composite_body(&name, "ToSql", &fields),
|
||||
composite_body(&fields),
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::new_spanned(
|
||||
input,
|
||||
"#[derive(ToSql)] may only be applied to structs, single field tuple structs, and enums",
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let ident = &input.ident;
|
||||
let out = quote! {
|
||||
impl ::postgres_types::ToSql for #ident {
|
||||
fn to_sql(&self,
|
||||
_type: &::postgres_types::Type,
|
||||
buf: &mut ::std::vec::Vec<u8>)
|
||||
-> ::std::result::Result<::postgres_types::IsNull,
|
||||
::std::boxed::Box<::std::error::Error +
|
||||
::std::marker::Sync +
|
||||
::std::marker::Send>> {
|
||||
#to_sql_body
|
||||
}
|
||||
|
||||
fn accepts(type_: &::postgres_types::Type) -> bool {
|
||||
#accepts_body
|
||||
}
|
||||
|
||||
::postgres_types::to_sql_checked!();
|
||||
}
|
||||
};
|
||||
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn enum_body(ident: &Ident, variants: &[Variant]) -> TokenStream {
|
||||
let idents = iter::repeat(ident);
|
||||
let variant_idents = variants.iter().map(|v| &v.ident);
|
||||
let variant_names = variants.iter().map(|v| &v.name);
|
||||
|
||||
quote! {
|
||||
let s = match *self {
|
||||
#(
|
||||
#idents::#variant_idents => #variant_names,
|
||||
)*
|
||||
};
|
||||
|
||||
buf.extend_from_slice(s.as_bytes());
|
||||
::std::result::Result::Ok(::postgres_types::IsNull::No)
|
||||
}
|
||||
}
|
||||
|
||||
fn domain_body() -> TokenStream {
|
||||
quote! {
|
||||
let type_ = match *_type.kind() {
|
||||
::postgres_types::Kind::Domain(ref type_) => type_,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
::postgres_types::ToSql::to_sql(&self.0, type_, buf)
|
||||
}
|
||||
}
|
||||
|
||||
fn composite_body(fields: &[Field]) -> TokenStream {
|
||||
let field_names = fields.iter().map(|f| &f.name);
|
||||
let field_idents = fields.iter().map(|f| &f.ident);
|
||||
|
||||
quote! {
|
||||
fn write_be_i32<W>(buf: &mut W, n: i32) -> ::std::io::Result<()>
|
||||
where W: ::std::io::Write
|
||||
{
|
||||
let be = [(n >> 24) as u8, (n >> 16) as u8, (n >> 8) as u8, n as u8];
|
||||
buf.write_all(&be)
|
||||
}
|
||||
|
||||
let fields = match *_type.kind() {
|
||||
::postgres_types::Kind::Composite(ref fields) => fields,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
write_be_i32(buf, fields.len() as i32)?;
|
||||
|
||||
for field in fields {
|
||||
write_be_i32(buf, field.type_().oid() as i32)?;
|
||||
|
||||
let base = buf.len();
|
||||
write_be_i32(buf, 0)?;
|
||||
let r = match field.name() {
|
||||
#(
|
||||
#field_names => {
|
||||
::postgres_types::ToSql::to_sql(&self.#field_idents,
|
||||
field.type_(),
|
||||
buf)
|
||||
}
|
||||
)*
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let count = match r? {
|
||||
::postgres_types::IsNull::Yes => -1,
|
||||
::postgres_types::IsNull::No => {
|
||||
let len = buf.len() - base - 4;
|
||||
if len > i32::max_value() as usize {
|
||||
return ::std::result::Result::Err(
|
||||
::std::convert::Into::into("value too large to transmit"));
|
||||
}
|
||||
len as i32
|
||||
}
|
||||
};
|
||||
|
||||
write_be_i32(&mut &mut buf[base..base + 4], count)?;
|
||||
}
|
||||
|
||||
::std::result::Result::Ok(::postgres_types::IsNull::No)
|
||||
}
|
||||
}
|
@ -5,6 +5,7 @@ authors = ["Steven Fackler <sfackler@gmail.com>"]
|
||||
edition = "2018"
|
||||
|
||||
[features]
|
||||
"derive" = ["postgres-derive"]
|
||||
"with-bit-vec-0_6" = ["bit-vec-06"]
|
||||
"with-chrono-0_4" = ["chrono-04"]
|
||||
"with-eui48-0_4" = ["eui48-04"]
|
||||
@ -15,6 +16,7 @@ with-serde_json-1 = ["serde-1", "serde_json-1"]
|
||||
[dependencies]
|
||||
fallible-iterator = "0.2"
|
||||
postgres-protocol = { version = "0.4.1", path = "../postgres-protocol" }
|
||||
postgres-derive = { version = "0.3.3", optional = true, path = "../postgres-derive" }
|
||||
|
||||
bit-vec-06 = { version = "0.6", package = "bit-vec", optional = true }
|
||||
chrono-04 = { version = "0.4", package = "chrono", optional = true }
|
||||
|
@ -17,6 +17,9 @@ use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
#[cfg(feature = "derive")]
|
||||
pub use postgres_derive::{FromSql, ToSql};
|
||||
|
||||
#[cfg(feature = "with-serde_json-1")]
|
||||
pub use crate::serde_json_1::Json;
|
||||
use crate::type_gen::{Inner, Other};
|
||||
|
Loading…
Reference in New Issue
Block a user