Move postgres-derive in-tree

This commit is contained in:
Steven Fackler 2019-10-09 19:23:12 -07:00
parent 31855141d2
commit 218d889042
16 changed files with 1076 additions and 0 deletions

View File

@ -2,6 +2,8 @@
members = [
"codegen",
"postgres",
"postgres-derive",
"postgres-derive-test",
"postgres-native-tls",
"postgres-openssl",
"postgres-protocol",

View File

@ -0,0 +1,9 @@
[package]
name = "postgres-derive-test"
version = "0.1.0"
authors = ["Steven Fackler <sfackler@gmail.com>"]
edition = "2018"
[dependencies]
postgres-types = { path = "../postgres-types", features = ["derive"] }
postgres = { path = "../postgres" }

View File

@ -0,0 +1,217 @@
use crate::test_type;
use postgres::{Client, NoTls};
use postgres_types::{FromSql, ToSql, WrongType};
use std::error::Error;
#[test]
fn defaults() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
struct InventoryItem {
name: String,
supplier_id: i32,
price: Option<f64>,
}
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.batch_execute(
"CREATE TYPE pg_temp.\"InventoryItem\" AS (
name TEXT,
supplier_id INT,
price DOUBLE PRECISION
);",
)
.unwrap();
let item = InventoryItem {
name: "foobar".to_owned(),
supplier_id: 100,
price: Some(15.50),
};
let item_null = InventoryItem {
name: "foobar".to_owned(),
supplier_id: 100,
price: None,
};
test_type(
&mut conn,
"\"InventoryItem\"",
&[
(item, "ROW('foobar', 100, 15.50)"),
(item_null, "ROW('foobar', 100, NULL)"),
],
);
}
#[test]
fn name_overrides() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
#[postgres(name = "inventory_item")]
struct InventoryItem {
#[postgres(name = "name")]
_name: String,
#[postgres(name = "supplier_id")]
_supplier_id: i32,
#[postgres(name = "price")]
_price: Option<f64>,
}
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.batch_execute(
"CREATE TYPE pg_temp.inventory_item AS (
name TEXT,
supplier_id INT,
price DOUBLE PRECISION
);",
)
.unwrap();
let item = InventoryItem {
_name: "foobar".to_owned(),
_supplier_id: 100,
_price: Some(15.50),
};
let item_null = InventoryItem {
_name: "foobar".to_owned(),
_supplier_id: 100,
_price: None,
};
test_type(
&mut conn,
"inventory_item",
&[
(item, "ROW('foobar', 100, 15.50)"),
(item_null, "ROW('foobar', 100, NULL)"),
],
);
}
#[test]
fn wrong_name() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
struct InventoryItem {
name: String,
supplier_id: i32,
price: Option<f64>,
}
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.batch_execute(
"CREATE TYPE pg_temp.inventory_item AS (
name TEXT,
supplier_id INT,
price DOUBLE PRECISION
);",
)
.unwrap();
let item = InventoryItem {
name: "foobar".to_owned(),
supplier_id: 100,
price: Some(15.50),
};
let err = conn
.execute("SELECT $1::inventory_item", &[&item])
.unwrap_err();
assert!(err.source().unwrap().is::<WrongType>());
}
#[test]
fn extra_field() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
#[postgres(name = "inventory_item")]
struct InventoryItem {
name: String,
supplier_id: i32,
price: Option<f64>,
foo: i32,
}
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.batch_execute(
"CREATE TYPE pg_temp.inventory_item AS (
name TEXT,
supplier_id INT,
price DOUBLE PRECISION
);",
)
.unwrap();
let item = InventoryItem {
name: "foobar".to_owned(),
supplier_id: 100,
price: Some(15.50),
foo: 0,
};
let err = conn
.execute("SELECT $1::inventory_item", &[&item])
.unwrap_err();
assert!(err.source().unwrap().is::<WrongType>());
}
#[test]
fn missing_field() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
#[postgres(name = "inventory_item")]
struct InventoryItem {
name: String,
supplier_id: i32,
}
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.batch_execute(
"CREATE TYPE pg_temp.inventory_item AS (
name TEXT,
supplier_id INT,
price DOUBLE PRECISION
);",
)
.unwrap();
let item = InventoryItem {
name: "foobar".to_owned(),
supplier_id: 100,
};
let err = conn
.execute("SELECT $1::inventory_item", &[&item])
.unwrap_err();
assert!(err.source().unwrap().is::<WrongType>());
}
#[test]
fn wrong_type() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
#[postgres(name = "inventory_item")]
struct InventoryItem {
name: String,
supplier_id: i32,
price: i32,
}
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.batch_execute(
"CREATE TYPE pg_temp.inventory_item AS (
name TEXT,
supplier_id INT,
price DOUBLE PRECISION
);",
)
.unwrap();
let item = InventoryItem {
name: "foobar".to_owned(),
supplier_id: 100,
price: 0,
};
let err = conn
.execute("SELECT $1::inventory_item", &[&item])
.unwrap_err();
assert!(err.source().unwrap().is::<WrongType>());
}

View File

@ -0,0 +1,121 @@
use crate::test_type;
use postgres::{Client, NoTls};
use postgres_types::{FromSql, ToSql, WrongType};
use std::error::Error;
#[test]
fn defaults() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
struct SessionId(Vec<u8>);
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute(
"CREATE DOMAIN pg_temp.\"SessionId\" AS bytea CHECK(octet_length(VALUE) = 16);",
&[],
)
.unwrap();
test_type(
&mut conn,
"\"SessionId\"",
&[(
SessionId(b"0123456789abcdef".to_vec()),
"'0123456789abcdef'",
)],
);
}
#[test]
fn name_overrides() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
#[postgres(name = "session_id")]
struct SessionId(Vec<u8>);
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute(
"CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16);",
&[],
)
.unwrap();
test_type(
&mut conn,
"session_id",
&[(
SessionId(b"0123456789abcdef".to_vec()),
"'0123456789abcdef'",
)],
);
}
#[test]
fn wrong_name() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
struct SessionId(Vec<u8>);
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute(
"CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16);",
&[],
)
.unwrap();
let err = conn
.execute("SELECT $1::session_id", &[&SessionId(vec![])])
.unwrap_err();
assert!(err.source().unwrap().is::<WrongType>());
}
#[test]
fn wrong_type() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
#[postgres(name = "session_id")]
struct SessionId(i32);
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute(
"CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16);",
&[],
)
.unwrap();
let err = conn
.execute("SELECT $1::session_id", &[&SessionId(0)])
.unwrap_err();
assert!(err.source().unwrap().is::<WrongType>());
}
#[test]
fn domain_in_composite() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
#[postgres(name = "domain")]
struct Domain(String);
#[derive(FromSql, ToSql, Debug, PartialEq)]
#[postgres(name = "composite")]
struct Composite {
domain: Domain,
}
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.batch_execute(
"
CREATE DOMAIN pg_temp.domain AS TEXT;\
CREATE TYPE pg_temp.composite AS (
domain domain
);
",
)
.unwrap();
test_type(
&mut conn,
"composite",
&[(
Composite {
domain: Domain("hello".to_string()),
},
"ROW('hello')",
)],
);
}

View File

@ -0,0 +1,104 @@
use crate::test_type;
use postgres::{Client, NoTls};
use postgres_types::{FromSql, ToSql, WrongType};
use std::error::Error;
#[test]
fn defaults() {
#[derive(Debug, ToSql, FromSql, PartialEq)]
enum Foo {
Bar,
Baz,
}
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[])
.unwrap();
test_type(
&mut conn,
"\"Foo\"",
&[(Foo::Bar, "'Bar'"), (Foo::Baz, "'Baz'")],
);
}
#[test]
fn name_overrides() {
#[derive(Debug, ToSql, FromSql, PartialEq)]
#[postgres(name = "mood")]
enum Mood {
#[postgres(name = "sad")]
Sad,
#[postgres(name = "ok")]
Ok,
#[postgres(name = "happy")]
Happy,
}
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute(
"CREATE TYPE pg_temp.mood AS ENUM ('sad', 'ok', 'happy')",
&[],
)
.unwrap();
test_type(
&mut conn,
"mood",
&[
(Mood::Sad, "'sad'"),
(Mood::Ok, "'ok'"),
(Mood::Happy, "'happy'"),
],
);
}
#[test]
fn wrong_name() {
#[derive(Debug, ToSql, FromSql, PartialEq)]
enum Foo {
Bar,
Baz,
}
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[])
.unwrap();
let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err();
assert!(err.source().unwrap().is::<WrongType>());
}
#[test]
fn extra_variant() {
#[derive(Debug, ToSql, FromSql, PartialEq)]
#[postgres(name = "foo")]
enum Foo {
Bar,
Baz,
Buz,
}
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[])
.unwrap();
let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err();
assert!(err.source().unwrap().is::<WrongType>());
}
#[test]
fn missing_variant() {
#[derive(Debug, ToSql, FromSql, PartialEq)]
#[postgres(name = "foo")]
enum Foo {
Bar,
}
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[])
.unwrap();
let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err();
assert!(err.source().unwrap().is::<WrongType>());
}

View File

@ -0,0 +1,27 @@
#![cfg(test)]
use postgres::Client;
use postgres_types::{FromSqlOwned, ToSql};
use std::fmt;
mod composites;
mod domains;
mod enums;
pub fn test_type<T, S>(conn: &mut Client, sql_type: &str, checks: &[(T, S)])
where
T: PartialEq + FromSqlOwned + ToSql + Sync,
S: fmt::Display,
{
for &(ref val, ref repr) in checks.iter() {
let stmt = conn
.prepare(&*format!("SELECT {}::{}", *repr, sql_type))
.unwrap();
let result = conn.query_one(&stmt, &[]).unwrap().get(0);
assert_eq!(val, &result);
let stmt = conn.prepare(&*format!("SELECT $1::{}", sql_type)).unwrap();
let result = conn.query_one(&stmt, &[val]).unwrap().get(0);
assert_eq!(val, &result);
}
}

View File

@ -0,0 +1,18 @@
[package]
name = "postgres-derive"
version = "0.3.3"
authors = ["Steven Fackler <sfackler@palantir.com>"]
license = "MIT/Apache-2.0"
description = "Deriving plugin support for Postgres enum, domain, and composite types"
repository = "https://github.com/sfackler/rust-postgres-derive"
readme = "README.md"
keywords = ["database", "postgres", "postgresql", "sql"]
[lib]
proc-macro = true
test = false
[dependencies]
syn = "1.0"
proc-macro2 = "1.0"
quote = "1.0"

View File

@ -0,0 +1,86 @@
use proc_macro2::{Span, TokenStream};
use std::iter;
use syn::Ident;
use composites::Field;
use enums::Variant;
pub fn domain_body(name: &str, field: &syn::Field) -> TokenStream {
let ty = &field.ty;
quote! {
if type_.name() != #name {
return false;
}
match *type_.kind() {
::postgres_types::Kind::Domain(ref type_) => {
<#ty as ::postgres_types::ToSql>::accepts(type_)
}
_ => false,
}
}
}
pub fn enum_body(name: &str, variants: &[Variant]) -> TokenStream {
let num_variants = variants.len();
let variant_names = variants.iter().map(|v| &v.name);
quote! {
if type_.name() != #name {
return false;
}
match *type_.kind() {
::postgres_types::Kind::Enum(ref variants) => {
if variants.len() != #num_variants {
return false;
}
variants.iter().all(|v| {
match &**v {
#(
#variant_names => true,
)*
_ => false,
}
})
}
_ => false,
}
}
}
pub fn composite_body(name: &str, trait_: &str, fields: &[Field]) -> TokenStream {
let num_fields = fields.len();
let trait_ = Ident::new(trait_, Span::call_site());
let traits = iter::repeat(&trait_);
let field_names = fields.iter().map(|f| &f.name);
let field_types = fields.iter().map(|f| &f.type_);
quote! {
if type_.name() != #name {
return false;
}
match *type_.kind() {
::postgres_types::Kind::Composite(ref fields) => {
if fields.len() != #num_fields {
return false;
}
fields.iter().all(|f| {
match f.name() {
#(
#field_names => {
<#field_types as ::postgres_types::#traits>::accepts(f.type_())
}
)*
_ => false,
}
})
}
_ => false,
}
}
}

View File

@ -0,0 +1,22 @@
use syn::{self, Error, Ident, Type};
use overrides::Overrides;
pub struct Field {
pub name: String,
pub ident: Ident,
pub type_: Type,
}
impl Field {
pub fn parse(raw: &syn::Field) -> Result<Field, Error> {
let overrides = Overrides::extract(&raw.attrs)?;
let ident = raw.ident.as_ref().unwrap().clone();
Ok(Field {
name: overrides.name.unwrap_or_else(|| ident.to_string()),
ident,
type_: raw.ty.clone(),
})
}
}

View File

@ -0,0 +1,28 @@
use syn::{self, Error, Fields, Ident};
use overrides::Overrides;
pub struct Variant {
pub ident: Ident,
pub name: String,
}
impl Variant {
pub fn parse(raw: &syn::Variant) -> Result<Variant, Error> {
match raw.fields {
Fields::Unit => {}
_ => {
return Err(Error::new_spanned(
raw,
"non-C-like enums are not supported",
))
}
}
let overrides = Overrides::extract(&raw.attrs)?;
Ok(Variant {
ident: raw.ident.clone(),
name: overrides.name.unwrap_or_else(|| raw.ident.to_string()),
})
}
}

View File

@ -0,0 +1,200 @@
use proc_macro2::{Span, TokenStream};
use std::iter;
use syn::{self, Data, DataStruct, DeriveInput, Error, Fields, Ident};
use accepts;
use composites::Field;
use enums::Variant;
use overrides::Overrides;
pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
let overrides = Overrides::extract(&input.attrs)?;
let name = overrides.name.unwrap_or_else(|| input.ident.to_string());
let (accepts_body, to_sql_body) = match input.data {
Data::Enum(ref data) => {
let variants = data
.variants
.iter()
.map(Variant::parse)
.collect::<Result<Vec<_>, _>>()?;
(
accepts::enum_body(&name, &variants),
enum_body(&input.ident, &variants),
)
}
Data::Struct(DataStruct {
fields: Fields::Unnamed(ref fields),
..
}) if fields.unnamed.len() == 1 => {
let field = fields.unnamed.first().unwrap();
(
domain_accepts_body(&name, field),
domain_body(&input.ident, field),
)
}
Data::Struct(DataStruct {
fields: Fields::Named(ref fields),
..
}) => {
let fields = fields
.named
.iter()
.map(Field::parse)
.collect::<Result<Vec<_>, _>>()?;
(
accepts::composite_body(&name, "FromSql", &fields),
composite_body(&input.ident, &fields),
)
}
_ => {
return Err(Error::new_spanned(
input,
"#[derive(ToSql)] may only be applied to structs, single field tuple structs, and enums",
))
}
};
let ident = &input.ident;
let out = quote! {
impl<'a> ::postgres_types::FromSql<'a> for #ident {
fn from_sql(_type: &::postgres_types::Type, buf: &'a [u8])
-> ::std::result::Result<#ident,
::std::boxed::Box<::std::error::Error +
::std::marker::Sync +
::std::marker::Send>> {
#to_sql_body
}
fn accepts(type_: &::postgres_types::Type) -> bool {
#accepts_body
}
}
};
Ok(out)
}
fn enum_body(ident: &Ident, variants: &[Variant]) -> TokenStream {
let variant_names = variants.iter().map(|v| &v.name);
let idents = iter::repeat(ident);
let variant_idents = variants.iter().map(|v| &v.ident);
quote! {
match ::std::str::from_utf8(buf)? {
#(
#variant_names => ::std::result::Result::Ok(#idents::#variant_idents),
)*
s => {
::std::result::Result::Err(
::std::convert::Into::into(format!("invalid variant `{}`", s)))
}
}
}
}
// Domains are sometimes but not always just represented by the bare type (!?)
fn domain_accepts_body(name: &str, field: &syn::Field) -> TokenStream {
let ty = &field.ty;
let normal_body = accepts::domain_body(name, field);
quote! {
if <#ty as ::postgres_types::FromSql>::accepts(type_) {
return true;
}
#normal_body
}
}
fn domain_body(ident: &Ident, field: &syn::Field) -> TokenStream {
let ty = &field.ty;
quote! {
<#ty as ::postgres_types::FromSql>::from_sql(_type, buf).map(#ident)
}
}
fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream {
let temp_vars = &fields
.iter()
.map(|f| Ident::new(&format!("__{}", f.ident), Span::call_site()))
.collect::<Vec<_>>();
let field_names = &fields.iter().map(|f| &f.name).collect::<Vec<_>>();
let field_idents = &fields.iter().map(|f| &f.ident).collect::<Vec<_>>();
quote! {
fn read_be_i32(buf: &mut &[u8]) -> ::std::io::Result<i32> {
let mut bytes = [0; 4];
::std::io::Read::read_exact(buf, &mut bytes)?;
let num = ((bytes[0] as i32) << 24) |
((bytes[1] as i32) << 16) |
((bytes[2] as i32) << 8) |
(bytes[3] as i32);
::std::result::Result::Ok(num)
}
fn read_value<'a, T>(type_: &::postgres_types::Type,
buf: &mut &'a [u8])
-> ::std::result::Result<T,
::std::boxed::Box<::std::error::Error +
::std::marker::Sync +
::std::marker::Send>>
where T: ::postgres_types::FromSql<'a>
{
let len = read_be_i32(buf)?;
let value = if len < 0 {
::std::option::Option::None
} else {
if len as usize > buf.len() {
return ::std::result::Result::Err(
::std::convert::Into::into("invalid buffer size"));
}
let (head, tail) = buf.split_at(len as usize);
*buf = tail;
::std::option::Option::Some(&head[..])
};
::postgres_types::FromSql::from_sql_nullable(type_, value)
}
let fields = match *_type.kind() {
::postgres_types::Kind::Composite(ref fields) => fields,
_ => unreachable!(),
};
let mut buf = buf;
let num_fields = read_be_i32(&mut buf)?;
if num_fields as usize != fields.len() {
return ::std::result::Result::Err(
::std::convert::Into::into(format!("invalid field count: {} vs {}", num_fields,
fields.len())));
}
#(
let mut #temp_vars = ::std::option::Option::None;
)*
for field in fields {
let oid = read_be_i32(&mut buf)? as u32;
if oid != field.type_().oid() {
return ::std::result::Result::Err(::std::convert::Into::into("unexpected OID"));
}
match field.name() {
#(
#field_names => {
#temp_vars = ::std::option::Option::Some(
read_value(field.type_(), &mut buf)?);
}
)*
_ => unreachable!(),
}
}
::std::result::Result::Ok(#ident {
#(
#field_idents: #temp_vars.unwrap(),
)*
})
}
}

View File

@ -0,0 +1,28 @@
#![recursion_limit = "256"]
extern crate proc_macro;
extern crate syn;
#[macro_use]
extern crate quote;
extern crate proc_macro2;
use proc_macro::TokenStream;
mod accepts;
mod composites;
mod enums;
mod fromsql;
mod overrides;
mod tosql;
#[proc_macro_derive(ToSql, attributes(postgres))]
pub fn derive_tosql(input: TokenStream) -> TokenStream {
let input = syn::parse(input).unwrap();
tosql::expand_derive_tosql(input).unwrap().into()
}
#[proc_macro_derive(FromSql, attributes(postgres))]
pub fn derive_fromsql(input: TokenStream) -> TokenStream {
let input = syn::parse(input).unwrap();
fromsql::expand_derive_fromsql(input).unwrap().into()
}

View File

@ -0,0 +1,49 @@
use syn::{Attribute, Error, Lit, Meta, NestedMeta};
pub struct Overrides {
pub name: Option<String>,
}
impl Overrides {
pub fn extract(attrs: &[Attribute]) -> Result<Overrides, Error> {
let mut overrides = Overrides { name: None };
for attr in attrs {
let attr = match attr.parse_meta() {
Ok(meta) => meta,
Err(_) => continue,
};
if !attr.path().is_ident("postgres") {
continue;
}
let list = match attr {
Meta::List(ref list) => list,
bad => return Err(Error::new_spanned(bad, "expected a #[postgres(...)]")),
};
for item in &list.nested {
match item {
NestedMeta::Meta(Meta::NameValue(meta)) => {
if !meta.path.is_ident("name") {
return Err(Error::new_spanned(&meta.path, "unknown override"));
}
let value = match &meta.lit {
Lit::Str(s) => s.value(),
bad => {
return Err(Error::new_spanned(bad, "expected a string literal"))
}
};
overrides.name = Some(value);
}
bad => return Err(Error::new_spanned(bad, "expected a name-value meta item")),
}
}
}
Ok(overrides)
}
}

View File

@ -0,0 +1,160 @@
use std::iter;
use syn::{Data, DataStruct, DeriveInput, Error, Fields, Ident};
use accepts;
use composites::Field;
use enums::Variant;
use overrides::Overrides;
use proc_macro2::TokenStream;
pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
let overrides = Overrides::extract(&input.attrs)?;
let name = overrides.name.unwrap_or_else(|| input.ident.to_string());
let (accepts_body, to_sql_body) = match input.data {
Data::Enum(ref data) => {
let variants = data
.variants
.iter()
.map(Variant::parse)
.collect::<Result<Vec<_>, _>>()?;
(
accepts::enum_body(&name, &variants),
enum_body(&input.ident, &variants),
)
}
Data::Struct(DataStruct {
fields: Fields::Unnamed(ref fields),
..
}) if fields.unnamed.len() == 1 => {
let field = fields.unnamed.first().unwrap();
(accepts::domain_body(&name, &field), domain_body())
}
Data::Struct(DataStruct {
fields: Fields::Named(ref fields),
..
}) => {
let fields = fields
.named
.iter()
.map(Field::parse)
.collect::<Result<Vec<_>, _>>()?;
(
accepts::composite_body(&name, "ToSql", &fields),
composite_body(&fields),
)
}
_ => {
return Err(Error::new_spanned(
input,
"#[derive(ToSql)] may only be applied to structs, single field tuple structs, and enums",
));
}
};
let ident = &input.ident;
let out = quote! {
impl ::postgres_types::ToSql for #ident {
fn to_sql(&self,
_type: &::postgres_types::Type,
buf: &mut ::std::vec::Vec<u8>)
-> ::std::result::Result<::postgres_types::IsNull,
::std::boxed::Box<::std::error::Error +
::std::marker::Sync +
::std::marker::Send>> {
#to_sql_body
}
fn accepts(type_: &::postgres_types::Type) -> bool {
#accepts_body
}
::postgres_types::to_sql_checked!();
}
};
Ok(out)
}
fn enum_body(ident: &Ident, variants: &[Variant]) -> TokenStream {
let idents = iter::repeat(ident);
let variant_idents = variants.iter().map(|v| &v.ident);
let variant_names = variants.iter().map(|v| &v.name);
quote! {
let s = match *self {
#(
#idents::#variant_idents => #variant_names,
)*
};
buf.extend_from_slice(s.as_bytes());
::std::result::Result::Ok(::postgres_types::IsNull::No)
}
}
fn domain_body() -> TokenStream {
quote! {
let type_ = match *_type.kind() {
::postgres_types::Kind::Domain(ref type_) => type_,
_ => unreachable!(),
};
::postgres_types::ToSql::to_sql(&self.0, type_, buf)
}
}
fn composite_body(fields: &[Field]) -> TokenStream {
let field_names = fields.iter().map(|f| &f.name);
let field_idents = fields.iter().map(|f| &f.ident);
quote! {
fn write_be_i32<W>(buf: &mut W, n: i32) -> ::std::io::Result<()>
where W: ::std::io::Write
{
let be = [(n >> 24) as u8, (n >> 16) as u8, (n >> 8) as u8, n as u8];
buf.write_all(&be)
}
let fields = match *_type.kind() {
::postgres_types::Kind::Composite(ref fields) => fields,
_ => unreachable!(),
};
write_be_i32(buf, fields.len() as i32)?;
for field in fields {
write_be_i32(buf, field.type_().oid() as i32)?;
let base = buf.len();
write_be_i32(buf, 0)?;
let r = match field.name() {
#(
#field_names => {
::postgres_types::ToSql::to_sql(&self.#field_idents,
field.type_(),
buf)
}
)*
_ => unreachable!(),
};
let count = match r? {
::postgres_types::IsNull::Yes => -1,
::postgres_types::IsNull::No => {
let len = buf.len() - base - 4;
if len > i32::max_value() as usize {
return ::std::result::Result::Err(
::std::convert::Into::into("value too large to transmit"));
}
len as i32
}
};
write_be_i32(&mut &mut buf[base..base + 4], count)?;
}
::std::result::Result::Ok(::postgres_types::IsNull::No)
}
}

View File

@ -5,6 +5,7 @@ authors = ["Steven Fackler <sfackler@gmail.com>"]
edition = "2018"
[features]
"derive" = ["postgres-derive"]
"with-bit-vec-0_6" = ["bit-vec-06"]
"with-chrono-0_4" = ["chrono-04"]
"with-eui48-0_4" = ["eui48-04"]
@ -15,6 +16,7 @@ with-serde_json-1 = ["serde-1", "serde_json-1"]
[dependencies]
fallible-iterator = "0.2"
postgres-protocol = { version = "0.4.1", path = "../postgres-protocol" }
postgres-derive = { version = "0.3.3", optional = true, path = "../postgres-derive" }
bit-vec-06 = { version = "0.6", package = "bit-vec", optional = true }
chrono-04 = { version = "0.4", package = "chrono", optional = true }

View File

@ -17,6 +17,9 @@ use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
#[cfg(feature = "derive")]
pub use postgres_derive::{FromSql, ToSql};
#[cfg(feature = "with-serde_json-1")]
pub use crate::serde_json_1::Json;
use crate::type_gen::{Inner, Other};