diff --git a/postgres-protocol/src/types/mod.rs b/postgres-protocol/src/types/mod.rs index 5939d9f0..05f515f7 100644 --- a/postgres-protocol/src/types/mod.rs +++ b/postgres-protocol/src/types/mod.rs @@ -1060,18 +1060,59 @@ impl Inet { } } -/// Serializes a Postgres l{tree,query,txtquery} string +/// Serializes a Postgres ltree string #[inline] pub fn ltree_to_sql(v: &str, buf: &mut BytesMut) { - // A version number is prepended to an Ltree string per spec + // A version number is prepended to an ltree string per spec buf.put_u8(1); // Append the rest of the query buf.put_slice(v.as_bytes()); } -/// Deserialize a Postgres l{tree,query,txtquery} string +/// Deserialize a Postgres ltree string #[inline] pub fn ltree_from_sql(buf: &[u8]) -> Result<&str, StdBox> { - // Remove the version number from the front of the string per spec - Ok(str::from_utf8(&buf[1..])?) + match buf { + // Remove the version number from the front of the ltree per spec + [1u8, rest @ ..] => Ok(str::from_utf8(rest)?), + _ => Err("ltree version 1 only supported".into()), + } +} + +/// Serializes a Postgres lquery string +#[inline] +pub fn lquery_to_sql(v: &str, buf: &mut BytesMut) { + // A version number is prepended to an lquery string per spec + buf.put_u8(1); + // Append the rest of the query + buf.put_slice(v.as_bytes()); +} + +/// Deserialize a Postgres lquery string +#[inline] +pub fn lquery_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + match buf { + // Remove the version number from the front of the lquery per spec + [1u8, rest @ ..] => Ok(str::from_utf8(rest)?), + _ => Err("lquery version 1 only supported".into()), + } +} + +/// Serializes a Postgres ltxtquery string +#[inline] +pub fn ltxtquery_to_sql(v: &str, buf: &mut BytesMut) { + // A version number is prepended to an ltxtquery string per spec + buf.put_u8(1); + // Append the rest of the query + buf.put_slice(v.as_bytes()); +} + +/// Deserialize a Postgres ltxtquery string +#[inline] +pub fn ltxtquery_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + match buf { + // Remove the version number from the front of the ltxtquery per spec + [1u8, rest @ ..] => Ok(str::from_utf8(rest)?), + _ => Err("ltxtquery version 1 only supported".into()), + } } diff --git a/postgres-protocol/src/types/test.rs b/postgres-protocol/src/types/test.rs index 7c20cf3e..1ce49b66 100644 --- a/postgres-protocol/src/types/test.rs +++ b/postgres-protocol/src/types/test.rs @@ -1,4 +1,4 @@ -use bytes::BytesMut; +use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; use std::collections::HashMap; @@ -156,3 +156,117 @@ fn non_null_array() { assert_eq!(array.dimensions().collect::>().unwrap(), dimensions); assert_eq!(array.values().collect::>().unwrap(), values); } + +#[test] +fn ltree_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + let mut buf = BytesMut::new(); + + ltree_to_sql("A.B.C", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn ltree_str() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + let success = match ltree_from_sql(query.as_slice()) { + Ok(_) => true, + _ => false, + }; + + assert!(success) +} + +#[test] +fn ltree_wrong_version() { + let mut query = vec![2u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + let success = match ltree_from_sql(query.as_slice()) { + Err(_) => true, + _ => false, + }; + + assert!(success) +} + +#[test] +fn lquery_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + let mut buf = BytesMut::new(); + + lquery_to_sql("A.B.C", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn lquery_str() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + let success = match lquery_from_sql(query.as_slice()) { + Ok(_) => true, + _ => false, + }; + + assert!(success) +} + +#[test] +fn lquery_wrong_version() { + let mut query = vec![2u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + let success = match lquery_from_sql(query.as_slice()) { + Err(_) => true, + _ => false, + }; + + assert!(success) +} + +#[test] +fn ltxtquery_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("a & b*".as_bytes()); + + let mut buf = BytesMut::new(); + + ltree_to_sql("a & b*", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn ltxtquery_str() { + let mut query = vec![1u8]; + query.extend_from_slice("a & b*".as_bytes()); + + let success = match ltree_from_sql(query.as_slice()) { + Ok(_) => true, + _ => false, + }; + + assert!(success) +} + +#[test] +fn ltxtquery_wrong_version() { + let mut query = vec![2u8]; + query.extend_from_slice("a & b*".as_bytes()); + + let success = match ltree_from_sql(query.as_slice()) { + Err(_) => true, + _ => false, + }; + + assert!(success) +} diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index 9580fb5c..d029d394 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -619,24 +619,24 @@ impl<'a> FromSql<'a> for Box { impl<'a> FromSql<'a> for &'a str { fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<&'a str, Box> { match *ty { - ref ty if ( - ty.name() == "ltree" || - ty.name() == "lquery" || - ty.name() == "ltxtquery" - ) => types::ltree_from_sql(raw), - _ => types::text_from_sql(raw) + ref ty if ty.name() == "ltree" => types::ltree_from_sql(raw), + ref ty if ty.name() == "lquery" => types::lquery_from_sql(raw), + ref ty if ty.name() == "ltxtquery" => types::ltxtquery_from_sql(raw), + _ => types::text_from_sql(raw), } } fn accepts(ty: &Type) -> bool { match *ty { Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, - ref ty if ( - ty.name() == "citext" || - ty.name() == "ltree" || - ty.name() == "lquery" || - ty.name() == "ltxtquery" - ) => true, + ref ty + if (ty.name() == "citext" + || ty.name() == "ltree" + || ty.name() == "lquery" + || ty.name() == "ltxtquery") => + { + true + } _ => false, } } @@ -939,13 +939,11 @@ impl ToSql for Vec { impl<'a> ToSql for &'a str { fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { - match ty { - ref ty if ( - ty.name() == "ltree" || - ty.name() == "lquery" || - ty.name() == "ltxtquery" - ) => types::ltree_to_sql(*self, w), - _ => types::text_to_sql(*self, w) + match *ty { + ref ty if ty.name() == "ltree" => types::ltree_to_sql(*self, w), + ref ty if ty.name() == "lquery" => types::lquery_to_sql(*self, w), + ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(*self, w), + _ => types::text_to_sql(*self, w), } Ok(IsNull::No) } @@ -953,12 +951,14 @@ impl<'a> ToSql for &'a str { fn accepts(ty: &Type) -> bool { match *ty { Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, - ref ty if ( - ty.name() == "citext" || - ty.name() == "ltree" || - ty.name() == "lquery" || - ty.name() == "ltxtquery" - ) => true, + ref ty + if (ty.name() == "citext" + || ty.name() == "ltree" + || ty.name() == "lquery" + || ty.name() == "ltxtquery") => + { + true + } _ => false, } } diff --git a/tokio-postgres/tests/test/types/mod.rs b/tokio-postgres/tests/test/types/mod.rs index 0ec329a4..f69932e5 100644 --- a/tokio-postgres/tests/test/types/mod.rs +++ b/tokio-postgres/tests/test/types/mod.rs @@ -652,74 +652,122 @@ async fn inet() { #[tokio::test] async fn ltree() { let client = connect("user=postgres").await; - client.execute("CREATE EXTENSION IF NOT EXISTS ltree;", &[]).await.unwrap(); + client + .execute("CREATE EXTENSION IF NOT EXISTS ltree;", &[]) + .await + .unwrap(); - test_type("ltree", &[ - (Some("b.c.d".to_owned()), "'b.c.d'"), - (None, "NULL"), - ]).await; + test_type( + "ltree", + &[(Some("b.c.d".to_owned()), "'b.c.d'"), (None, "NULL")], + ) + .await; } #[tokio::test] async fn ltree_any() { let client = connect("user=postgres").await; - client.execute("CREATE EXTENSION IF NOT EXISTS ltree;", &[]).await.unwrap(); + client + .execute("CREATE EXTENSION IF NOT EXISTS ltree;", &[]) + .await + .unwrap(); - test_type("ltree[]", &[ - (Some(vec![]), "ARRAY[]"), - (Some(vec!["a.b.c".to_string()]), "ARRAY['a.b.c']"), - (Some(vec!["a.b.c".to_string(), "e.f.g".to_string()]), "ARRAY['a.b.c','e.f.g']"), - (None, "NULL"), - ]).await; + test_type( + "ltree[]", + &[ + (Some(vec![]), "ARRAY[]"), + (Some(vec!["a.b.c".to_string()]), "ARRAY['a.b.c']"), + ( + Some(vec!["a.b.c".to_string(), "e.f.g".to_string()]), + "ARRAY['a.b.c','e.f.g']", + ), + (None, "NULL"), + ], + ) + .await; } #[tokio::test] async fn lquery() { let client = connect("user=postgres").await; - client.execute("CREATE EXTENSION IF NOT EXISTS ltree;", &[]).await.unwrap(); + client + .execute("CREATE EXTENSION IF NOT EXISTS ltree;", &[]) + .await + .unwrap(); - test_type("lquery", &[ - (Some("b.c.d".to_owned()), "'b.c.d'"), - (Some("b.c.*".to_owned()), "'b.c.*'"), - (Some("b.*{1,2}.d|e".to_owned()), "'b.*{1,2}.d|e'"), - (None, "NULL"), - ]).await; + test_type( + "lquery", + &[ + (Some("b.c.d".to_owned()), "'b.c.d'"), + (Some("b.c.*".to_owned()), "'b.c.*'"), + (Some("b.*{1,2}.d|e".to_owned()), "'b.*{1,2}.d|e'"), + (None, "NULL"), + ], + ) + .await; } #[tokio::test] async fn lquery_any() { let client = connect("user=postgres").await; - client.execute("CREATE EXTENSION IF NOT EXISTS ltree;", &[]).await.unwrap(); + client + .execute("CREATE EXTENSION IF NOT EXISTS ltree;", &[]) + .await + .unwrap(); - test_type("lquery[]", &[ - (Some(vec![]), "ARRAY[]"), - (Some(vec!["b.c.*".to_string()]), "ARRAY['b.c.*']"), - (Some(vec!["b.c.*".to_string(), "b.*{1,2}.d|e".to_string()]), "ARRAY['b.c.*','b.*{1,2}.d|e']"), - (None, "NULL"), - ]).await; + test_type( + "lquery[]", + &[ + (Some(vec![]), "ARRAY[]"), + (Some(vec!["b.c.*".to_string()]), "ARRAY['b.c.*']"), + ( + Some(vec!["b.c.*".to_string(), "b.*{1,2}.d|e".to_string()]), + "ARRAY['b.c.*','b.*{1,2}.d|e']", + ), + (None, "NULL"), + ], + ) + .await; } #[tokio::test] async fn ltxtquery() { let client = connect("user=postgres").await; - client.execute("CREATE EXTENSION IF NOT EXISTS ltree;", &[]).await.unwrap(); + client + .execute("CREATE EXTENSION IF NOT EXISTS ltree;", &[]) + .await + .unwrap(); - test_type("ltxtquery", &[ - (Some("b & c & d".to_owned()), "'b & c & d'"), - (Some("b@* & !c".to_owned()), "'b@* & !c'"), - (None, "NULL"), - ]).await; + test_type( + "ltxtquery", + &[ + (Some("b & c & d".to_owned()), "'b & c & d'"), + (Some("b@* & !c".to_owned()), "'b@* & !c'"), + (None, "NULL"), + ], + ) + .await; } #[tokio::test] async fn ltxtquery_any() { let client = connect("user=postgres").await; - client.execute("CREATE EXTENSION IF NOT EXISTS ltree;", &[]).await.unwrap(); + client + .execute("CREATE EXTENSION IF NOT EXISTS ltree;", &[]) + .await + .unwrap(); - test_type("ltxtquery[]", &[ - (Some(vec![]), "ARRAY[]"), - (Some(vec!["b & c & d".to_string()]), "ARRAY['b & c & d']"), - (Some(vec!["b & c & d".to_string(), "b@* & !c".to_string()]), "ARRAY['b & c & d','b@* & !c']"), - (None, "NULL"), - ]).await; + test_type( + "ltxtquery[]", + &[ + (Some(vec![]), "ARRAY[]"), + (Some(vec!["b & c & d".to_string()]), "ARRAY['b & c & d']"), + ( + Some(vec!["b & c & d".to_string(), "b@* & !c".to_string()]), + "ARRAY['b & c & d','b@* & !c']", + ), + (None, "NULL"), + ], + ) + .await; }