This commit is contained in:
Steven Fackler 2017-06-30 17:35:17 -10:00
parent 3809972907
commit 6a86f8dd85
52 changed files with 2638 additions and 1605 deletions

View File

@ -71,43 +71,51 @@ fn variant_name(code: &str) -> Option<String> {
} }
fn make_header(file: &mut BufWriter<File>) { fn make_header(file: &mut BufWriter<File>) {
write!(file, write!(
"// Autogenerated file - DO NOT EDIT file,
"// Autogenerated file - DO NOT EDIT
use phf; use phf;
" "
).unwrap(); ).unwrap();
} }
fn make_enum(codes: &[Code], file: &mut BufWriter<File>) { fn make_enum(codes: &[Code], file: &mut BufWriter<File>) {
write!(file, write!(
r#"/// SQLSTATE error codes file,
r#"/// SQLSTATE error codes
#[derive(PartialEq, Eq, Clone, Debug)] #[derive(PartialEq, Eq, Clone, Debug)]
#[allow(enum_variant_names)] #[allow(enum_variant_names)]
pub enum SqlState {{ pub enum SqlState {{
"# "#
).unwrap(); ).unwrap();
for code in codes { for code in codes {
write!(file, write!(
" /// `{}` file,
" /// `{}`
{},\n", {},\n",
code.code, code.variant).unwrap(); code.code,
code.variant
).unwrap();
} }
write!(file, write!(
" /// An unknown code file,
" /// An unknown code
Other(String), Other(String),
}} }}
" "
).unwrap(); ).unwrap();
} }
fn make_map(codes: &[Code], file: &mut BufWriter<File>) { fn make_map(codes: &[Code], file: &mut BufWriter<File>) {
write!(file, write!(
"#[cfg_attr(rustfmt, rustfmt_skip)] file,
static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = ").unwrap(); "#[cfg_attr(rustfmt, rustfmt_skip)]
static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = "
).unwrap();
let mut builder = phf_codegen::Map::new(); let mut builder = phf_codegen::Map::new();
for code in codes { for code in codes {
builder.entry(&*code.code, &format!("SqlState::{}", code.variant)); builder.entry(&*code.code, &format!("SqlState::{}", code.variant));
@ -117,7 +125,9 @@ static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = ").unwrap();
} }
fn make_impl(codes: &[Code], file: &mut BufWriter<File>) { fn make_impl(codes: &[Code], file: &mut BufWriter<File>) {
write!(file, r#" write!(
file,
r#"
impl SqlState {{ impl SqlState {{
/// Creates a `SqlState` from its error code. /// Creates a `SqlState` from its error code.
pub fn from_code(s: &str) -> SqlState {{ pub fn from_code(s: &str) -> SqlState {{
@ -130,20 +140,25 @@ impl SqlState {{
/// Returns the error code corresponding to the `SqlState`. /// Returns the error code corresponding to the `SqlState`.
pub fn code(&self) -> &str {{ pub fn code(&self) -> &str {{
match *self {{"# match *self {{"#
).unwrap(); ).unwrap();
for code in codes { for code in codes {
write!(file, r#" write!(
file,
r#"
SqlState::{} => "{}","#, SqlState::{} => "{}","#,
code.variant, code.code).unwrap(); code.variant,
code.code
).unwrap();
} }
write!(file, r#" write!(
file,
r#"
SqlState::Other(ref s) => s, SqlState::Other(ref s) => s,
}} }}
}} }}
}} }}
"# "#
).unwrap(); ).unwrap();
} }

View File

@ -94,9 +94,10 @@ fn parse_types(ranges: &BTreeMap<u32, u32>) -> BTreeMap<u32, Type> {
let doc = array_re.replace(name, "$1[]"); let doc = array_re.replace(name, "$1[]");
let mut doc = doc.to_ascii_uppercase(); let mut doc = doc.to_ascii_uppercase();
let descr = lines.peek() let descr = lines
.and_then(|line| doc_re.captures(line)) .peek()
.and_then(|captures| captures.at(1)); .and_then(|line| doc_re.captures(line))
.and_then(|captures| captures.at(1));
if let Some(descr) = descr { if let Some(descr) = descr {
doc.push_str(" - "); doc.push_str(" - ");
doc.push_str(descr); doc.push_str(descr);
@ -119,38 +120,45 @@ fn parse_types(ranges: &BTreeMap<u32, u32>) -> BTreeMap<u32, Type> {
} }
fn make_header(w: &mut BufWriter<File>) { fn make_header(w: &mut BufWriter<File>) {
write!(w, write!(
"// Autogenerated file - DO NOT EDIT w,
"// Autogenerated file - DO NOT EDIT
use std::fmt; use std::fmt;
use types::{{Oid, Kind, Other}}; use types::{{Oid, Kind, Other}};
" "
).unwrap(); ).unwrap();
} }
fn make_enum(w: &mut BufWriter<File>, types: &BTreeMap<u32, Type>) { fn make_enum(w: &mut BufWriter<File>, types: &BTreeMap<u32, Type>) {
write!(w, write!(
"/// A Postgres type. w,
"/// A Postgres type.
#[derive(PartialEq, Eq, Clone, Debug)] #[derive(PartialEq, Eq, Clone, Debug)]
pub enum Type {{ pub enum Type {{
" "
).unwrap(); ).unwrap();
for type_ in types.values() { for type_ in types.values() {
write!(w, write!(
" /// {} w,
" /// {}
{}, {},
" ",
, type_.doc, type_.variant).unwrap(); type_.doc,
type_.variant
).unwrap();
} }
write!(w, write!(
r" /// An unknown type. w,
r" /// An unknown type.
Other(Other), Other(Other),
}} }}
" ).unwrap(); "
).unwrap();
} }
fn make_display_impl(w: &mut BufWriter<File>) { fn make_display_impl(w: &mut BufWriter<File>) {
@ -180,10 +188,13 @@ fn make_impl(w: &mut BufWriter<File>, types: &BTreeMap<u32, Type>) {
).unwrap(); ).unwrap();
for (oid, type_) in types { for (oid, type_) in types {
write!(w, write!(
" {} => Some(Type::{}), w,
" {} => Some(Type::{}),
", ",
oid, type_.variant).unwrap(); oid,
type_.variant
).unwrap();
} }
write!(w, write!(w,
@ -199,10 +210,13 @@ fn make_impl(w: &mut BufWriter<File>, types: &BTreeMap<u32, Type>) {
for (oid, type_) in types { for (oid, type_) in types {
write!(w, write!(
" Type::{} => {}, w,
" Type::{} => {},
", ",
type_.variant, oid).unwrap(); type_.variant,
oid
).unwrap();
} }
write!(w, write!(w,
@ -224,13 +238,16 @@ fn make_impl(w: &mut BufWriter<File>, types: &BTreeMap<u32, Type>) {
_ => "Simple".to_owned(), _ => "Simple".to_owned(),
}; };
write!(w, write!(
" Type::{} => {{ w,
" Type::{} => {{
const V: &'static Kind = &Kind::{}; const V: &'static Kind = &Kind::{};
V V
}} }}
", ",
type_.variant, kind).unwrap(); type_.variant,
kind
).unwrap();
} }
write!(w, write!(w,
@ -253,17 +270,21 @@ r#" Type::Other(ref u) => u.kind(),
).unwrap(); ).unwrap();
for type_ in types.values() { for type_ in types.values() {
write!(w, write!(
r#" Type::{} => "{}", w,
r#" Type::{} => "{}",
"#, "#,
type_.variant, type_.name).unwrap(); type_.variant,
type_.name
).unwrap();
} }
write!(w, write!(
" Type::Other(ref u) => u.name(), w,
" Type::Other(ref u) => u.name(),
}} }}
}} }}
}} }}
" "
).unwrap(); ).unwrap();
} }

View File

@ -30,7 +30,9 @@ mod test {
let password = b"password"; let password = b"password";
let salt = [0x2a, 0x3d, 0x8f, 0xe0]; let salt = [0x2a, 0x3d, 0x8f, 0xe0];
assert_eq!(md5_hash(username, password, salt), assert_eq!(
"md562af4dd09bbb41884907a838a3233294"); md5_hash(username, password, salt),
"md562af4dd09bbb41884907a838a3233294"
);
} }
} }

View File

@ -89,12 +89,12 @@ impl ScramSha256 {
let mut rng = OsRng::new()?; let mut rng = OsRng::new()?;
let nonce = (0..NONCE_LENGTH) let nonce = (0..NONCE_LENGTH)
.map(|_| { .map(|_| {
let mut v = rng.gen_range(0x21u8, 0x7e); let mut v = rng.gen_range(0x21u8, 0x7e);
if v == 0x2c { if v == 0x2c {
v = 0x7e v = 0x7e
} }
v as char v as char
}) })
.collect::<String>(); .collect::<String>();
ScramSha256::new_inner(password, nonce) ScramSha256::new_inner(password, nonce)
@ -108,12 +108,12 @@ impl ScramSha256 {
let password = normalize(password); let password = normalize(password);
Ok(ScramSha256 { Ok(ScramSha256 {
message: message, message: message,
state: State::Update { state: State::Update {
nonce: nonce, nonce: nonce,
password: password, password: password,
}, },
}) })
} }
/// Returns the message which should be sent to the backend in an `SASLResponse` message. /// Returns the message which should be sent to the backend in an `SASLResponse` message.
@ -133,8 +133,9 @@ impl ScramSha256 {
_ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")), _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
}; };
let message = str::from_utf8(message) let message = str::from_utf8(message).map_err(|e| {
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; io::Error::new(io::ErrorKind::InvalidInput, e)
})?;
let parsed = Parser::new(message).server_first_message()?; let parsed = Parser::new(message).server_first_message()?;
@ -193,14 +194,18 @@ impl ScramSha256 {
_ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")), _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
}; };
let message = str::from_utf8(message) let message = str::from_utf8(message).map_err(|e| {
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; io::Error::new(io::ErrorKind::InvalidInput, e)
})?;
let parsed = Parser::new(message).server_final_message()?; let parsed = Parser::new(message).server_final_message()?;
let verifier = match parsed { let verifier = match parsed {
ServerFinalMessage::Error(e) => { ServerFinalMessage::Error(e) => {
return Err(io::Error::new(io::ErrorKind::Other, format!("SCRAM error: {}", e))) return Err(io::Error::new(
io::ErrorKind::Other,
format!("SCRAM error: {}", e),
))
} }
ServerFinalMessage::Verifier(verifier) => verifier, ServerFinalMessage::Verifier(verifier) => verifier,
}; };
@ -219,7 +224,10 @@ impl ScramSha256 {
if hmac.verify(&verifier) { if hmac.verify(&verifier) {
Ok(()) Ok(())
} else { } else {
Err(io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error")) Err(io::Error::new(
io::ErrorKind::InvalidInput,
"SCRAM verification error",
))
} }
} }
} }
@ -241,18 +249,24 @@ impl<'a> Parser<'a> {
match self.it.next() { match self.it.next() {
Some((_, c)) if c == target => Ok(()), Some((_, c)) if c == target => Ok(()),
Some((i, c)) => { Some((i, c)) => {
let m = format!("unexpected character at byte {}: expected `{}` but got `{}", let m = format!(
i, "unexpected character at byte {}: expected `{}` but got `{}",
target, i,
c); target,
c
);
Err(io::Error::new(io::ErrorKind::InvalidInput, m)) Err(io::Error::new(io::ErrorKind::InvalidInput, m))
} }
None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")), None => Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
)),
} }
} }
fn take_while<F>(&mut self, f: F) -> io::Result<&'a str> fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
where F: Fn(char) -> bool where
F: Fn(char) -> bool,
{ {
let start = match self.it.peek() { let start = match self.it.peek() {
Some(&(i, _)) => i, Some(&(i, _)) => i,
@ -272,9 +286,9 @@ impl<'a> Parser<'a> {
fn printable(&mut self) -> io::Result<&'a str> { fn printable(&mut self) -> io::Result<&'a str> {
self.take_while(|c| match c { self.take_while(|c| match c {
'\x21'...'\x2b' | '\x2d'...'\x7e' => true, '\x21'...'\x2b' | '\x2d'...'\x7e' => true,
_ => false, _ => false,
}) })
} }
fn nonce(&mut self) -> io::Result<&'a str> { fn nonce(&mut self) -> io::Result<&'a str> {
@ -285,9 +299,9 @@ impl<'a> Parser<'a> {
fn base64(&mut self) -> io::Result<&'a str> { fn base64(&mut self) -> io::Result<&'a str> {
self.take_while(|c| match c { self.take_while(|c| match c {
'a'...'z' | 'A'...'Z' | '0'...'9' | '/' | '+' | '=' => true, 'a'...'z' | 'A'...'Z' | '0'...'9' | '/' | '+' | '=' => true,
_ => false, _ => false,
}) })
} }
fn salt(&mut self) -> io::Result<&'a str> { fn salt(&mut self) -> io::Result<&'a str> {
@ -298,11 +312,12 @@ impl<'a> Parser<'a> {
fn posit_number(&mut self) -> io::Result<u32> { fn posit_number(&mut self) -> io::Result<u32> {
let n = self.take_while(|c| match c { let n = self.take_while(|c| match c {
'0'...'9' => true, '0'...'9' => true,
_ => false, _ => false,
})?; })?;
n.parse() n.parse().map_err(
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e)) |e| io::Error::new(io::ErrorKind::InvalidInput, e),
)
} }
fn iteration_count(&mut self) -> io::Result<u32> { fn iteration_count(&mut self) -> io::Result<u32> {
@ -314,8 +329,10 @@ impl<'a> Parser<'a> {
fn eof(&mut self) -> io::Result<()> { fn eof(&mut self) -> io::Result<()> {
match self.it.peek() { match self.it.peek() {
Some(&(i, _)) => { Some(&(i, _)) => {
Err(io::Error::new(io::ErrorKind::InvalidInput, Err(io::Error::new(
format!("unexpected trailing data at byte {}", i))) io::ErrorKind::InvalidInput,
format!("unexpected trailing data at byte {}", i),
))
} }
None => Ok(()), None => Ok(()),
} }
@ -330,17 +347,17 @@ impl<'a> Parser<'a> {
self.eof()?; self.eof()?;
Ok(ServerFirstMessage { Ok(ServerFirstMessage {
nonce: nonce, nonce: nonce,
salt: salt, salt: salt,
iteration_count: iteration_count, iteration_count: iteration_count,
}) })
} }
fn value(&mut self) -> io::Result<&'a str> { fn value(&mut self) -> io::Result<&'a str> {
self.take_while(|c| match c { self.take_while(|c| match c {
'\0' | '=' | ',' => false, '\0' | '=' | ',' => false,
_ => true, _ => true,
}) })
} }
fn server_error(&mut self) -> io::Result<Option<&'a str>> { fn server_error(&mut self) -> io::Result<Option<&'a str>> {

View File

@ -43,8 +43,9 @@ pub enum IsNull {
#[inline] #[inline]
fn write_nullable<F, E>(serializer: F, buf: &mut Vec<u8>) -> Result<(), E> fn write_nullable<F, E>(serializer: F, buf: &mut Vec<u8>) -> Result<(), E>
where F: FnOnce(&mut Vec<u8>) -> Result<IsNull, E>, where
E: From<io::Error> F: FnOnce(&mut Vec<u8>) -> Result<IsNull, E>,
E: From<io::Error>,
{ {
let base = buf.len(); let base = buf.len();
buf.extend_from_slice(&[0; 4]); buf.extend_from_slice(&[0; 4]);

View File

@ -61,7 +61,10 @@ impl Message {
let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap(); let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
if len < 4 { if len < 4 {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length")); return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length",
));
} }
let total_len = len as usize + 1; let total_len = len as usize + 1;
@ -85,10 +88,10 @@ impl Message {
let channel = buf.read_cstr()?; let channel = buf.read_cstr()?;
let message = buf.read_cstr()?; let message = buf.read_cstr()?;
Message::NotificationResponse(NotificationResponseBody { Message::NotificationResponse(NotificationResponseBody {
process_id: process_id, process_id: process_id,
channel: channel, channel: channel,
message: message, message: message,
}) })
} }
b'c' => Message::CopyDone, b'c' => Message::CopyDone,
b'C' => { b'C' => {
@ -103,9 +106,9 @@ impl Message {
let len = buf.read_u16::<BigEndian>()?; let len = buf.read_u16::<BigEndian>()?;
let storage = buf.read_all(); let storage = buf.read_all();
Message::DataRow(DataRowBody { Message::DataRow(DataRowBody {
storage: storage, storage: storage,
len: len, len: len,
}) })
} }
b'E' => { b'E' => {
let storage = buf.read_all(); let storage = buf.read_all();
@ -116,29 +119,29 @@ impl Message {
let len = buf.read_u16::<BigEndian>()?; let len = buf.read_u16::<BigEndian>()?;
let storage = buf.read_all(); let storage = buf.read_all();
Message::CopyInResponse(CopyInResponseBody { Message::CopyInResponse(CopyInResponseBody {
format: format, format: format,
len: len, len: len,
storage: storage, storage: storage,
}) })
} }
b'H' => { b'H' => {
let format = buf.read_u8()?; let format = buf.read_u8()?;
let len = buf.read_u16::<BigEndian>()?; let len = buf.read_u16::<BigEndian>()?;
let storage = buf.read_all(); let storage = buf.read_all();
Message::CopyOutResponse(CopyOutResponseBody { Message::CopyOutResponse(CopyOutResponseBody {
format: format, format: format,
len: len, len: len,
storage: storage, storage: storage,
}) })
} }
b'I' => Message::EmptyQueryResponse, b'I' => Message::EmptyQueryResponse,
b'K' => { b'K' => {
let process_id = buf.read_i32::<BigEndian>()?; let process_id = buf.read_i32::<BigEndian>()?;
let secret_key = buf.read_i32::<BigEndian>()?; let secret_key = buf.read_i32::<BigEndian>()?;
Message::BackendKeyData(BackendKeyDataBody { Message::BackendKeyData(BackendKeyDataBody {
process_id: process_id, process_id: process_id,
secret_key: secret_key, secret_key: secret_key,
}) })
} }
b'n' => Message::NoData, b'n' => Message::NoData,
b'N' => { b'N' => {
@ -153,9 +156,9 @@ impl Message {
5 => { 5 => {
let mut salt = [0; 4]; let mut salt = [0; 4];
buf.read_exact(&mut salt)?; buf.read_exact(&mut salt)?;
Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { Message::AuthenticationMd5Password(
salt: salt, AuthenticationMd5PasswordBody { salt: salt },
}) )
} }
6 => Message::AuthenticationScmCredential, 6 => Message::AuthenticationScmCredential,
7 => Message::AuthenticationGss, 7 => Message::AuthenticationGss,
@ -177,8 +180,10 @@ impl Message {
Message::AuthenticationSaslFinal(AuthenticationSaslFinalBody(storage)) Message::AuthenticationSaslFinal(AuthenticationSaslFinalBody(storage))
} }
tag => { tag => {
return Err(io::Error::new(io::ErrorKind::InvalidInput, return Err(io::Error::new(
format!("unknown authentication tag `{}`", tag))); io::ErrorKind::InvalidInput,
format!("unknown authentication tag `{}`", tag),
));
} }
} }
} }
@ -187,38 +192,43 @@ impl Message {
let name = buf.read_cstr()?; let name = buf.read_cstr()?;
let value = buf.read_cstr()?; let value = buf.read_cstr()?;
Message::ParameterStatus(ParameterStatusBody { Message::ParameterStatus(ParameterStatusBody {
name: name, name: name,
value: value, value: value,
}) })
} }
b't' => { b't' => {
let len = buf.read_u16::<BigEndian>()?; let len = buf.read_u16::<BigEndian>()?;
let storage = buf.read_all(); let storage = buf.read_all();
Message::ParameterDescription(ParameterDescriptionBody { Message::ParameterDescription(ParameterDescriptionBody {
storage: storage, storage: storage,
len: len, len: len,
}) })
} }
b'T' => { b'T' => {
let len = buf.read_u16::<BigEndian>()?; let len = buf.read_u16::<BigEndian>()?;
let storage = buf.read_all(); let storage = buf.read_all();
Message::RowDescription(RowDescriptionBody { Message::RowDescription(RowDescriptionBody {
storage: storage, storage: storage,
len: len, len: len,
}) })
} }
b'Z' => { b'Z' => {
let status = buf.read_u8()?; let status = buf.read_u8()?;
Message::ReadyForQuery(ReadyForQueryBody { status: status }) Message::ReadyForQuery(ReadyForQueryBody { status: status })
} }
tag => { tag => {
return Err(io::Error::new(io::ErrorKind::InvalidInput, return Err(io::Error::new(
format!("unknown message tag `{}`", tag))); io::ErrorKind::InvalidInput,
format!("unknown message tag `{}`", tag),
));
} }
}; };
if !buf.is_empty() { if !buf.is_empty() {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length")); return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length",
));
} }
Ok(Some(message)) Ok(Some(message))
@ -248,7 +258,10 @@ impl Buffer {
self.idx = end + 1; self.idx = end + 1;
Ok(cstr) Ok(cstr)
} }
None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")), None => Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
)),
} }
} }
@ -312,7 +325,10 @@ impl<'a> FallibleIterator for SaslMechanisms<'a> {
let value_end = find_null(self.0, 0)?; let value_end = find_null(self.0, 0)?;
if value_end == 0 { if value_end == 0 {
if self.0.len() != 1 { if self.0.len() != 1 {
return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid message length")); return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid message length",
));
} }
Ok(None) Ok(None)
} else { } else {
@ -416,7 +432,10 @@ impl<'a> FallibleIterator for ColumnFormats<'a> {
if self.buf.is_empty() { if self.buf.is_empty() {
return Ok(None); return Ok(None);
} else { } else {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length")); return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length",
));
} }
} }
@ -489,7 +508,10 @@ impl<'a> FallibleIterator for DataRowRanges<'a> {
if self.buf.is_empty() { if self.buf.is_empty() {
return Ok(None); return Ok(None);
} else { } else {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length")); return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length",
));
} }
} }
@ -500,7 +522,10 @@ impl<'a> FallibleIterator for DataRowRanges<'a> {
} else { } else {
let len = len as usize; let len = len as usize;
if self.buf.len() < len { if self.buf.len() < len {
return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")); return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
));
} }
let base = self.len - self.buf.len(); let base = self.len - self.buf.len();
self.buf = &self.buf[len as usize..]; self.buf = &self.buf[len as usize..];
@ -541,7 +566,10 @@ impl<'a> FallibleIterator for ErrorFields<'a> {
if self.buf.is_empty() { if self.buf.is_empty() {
return Ok(None); return Ok(None);
} else { } else {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length")); return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length",
));
} }
} }
@ -550,9 +578,9 @@ impl<'a> FallibleIterator for ErrorFields<'a> {
self.buf = &self.buf[value_end + 1..]; self.buf = &self.buf[value_end + 1..];
Ok(Some(ErrorField { Ok(Some(ErrorField {
type_: type_, type_: type_,
value: value, value: value,
})) }))
} }
} }
@ -637,7 +665,10 @@ impl<'a> FallibleIterator for Parameters<'a> {
if self.buf.is_empty() { if self.buf.is_empty() {
return Ok(None); return Ok(None);
} else { } else {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length")); return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length",
));
} }
} }
@ -710,7 +741,10 @@ impl<'a> FallibleIterator for Fields<'a> {
if self.buf.is_empty() { if self.buf.is_empty() {
return Ok(None); return Ok(None);
} else { } else {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length")); return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length",
));
} }
} }
@ -726,14 +760,14 @@ impl<'a> FallibleIterator for Fields<'a> {
let format = self.buf.read_i16::<BigEndian>()?; let format = self.buf.read_i16::<BigEndian>()?;
Ok(Some(Field { Ok(Some(Field {
name: name, name: name,
table_oid: table_oid, table_oid: table_oid,
column_id: column_id, column_id: column_id,
type_oid: type_oid, type_oid: type_oid,
type_size: type_size, type_size: type_size,
type_modifier: type_modifier, type_modifier: type_modifier,
format: format, format: format,
})) }))
} }
} }
@ -788,7 +822,10 @@ impl<'a> Field<'a> {
fn find_null(buf: &[u8], start: usize) -> io::Result<usize> { fn find_null(buf: &[u8], start: usize) -> io::Result<usize> {
match memchr(0, &buf[start..]) { match memchr(0, &buf[start..]) {
Some(pos) => Ok(pos + start), Some(pos) => Ok(pos + start),
None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")), None => Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
)),
} }
} }

View File

@ -51,19 +51,21 @@ impl<'a> Message<'a> {
values, values,
result_formats, result_formats,
} => { } => {
let r = bind(portal, let r = bind(
statement, portal,
formats.iter().cloned(), statement,
values, formats.iter().cloned(),
|v, buf| match *v { values,
Some(ref v) => { |v, buf| match *v {
buf.extend_from_slice(v); Some(ref v) => {
Ok(IsNull::No) buf.extend_from_slice(v);
} Ok(IsNull::No)
None => Ok(IsNull::Yes), }
}, None => Ok(IsNull::Yes),
result_formats.iter().cloned(), },
buf); result_formats.iter().cloned(),
buf,
);
match r { match r {
Ok(()) => Ok(()), Ok(()) => Ok(()),
Err(BindError::Conversion(_)) => unreachable!(), Err(BindError::Conversion(_)) => unreachable!(),
@ -104,8 +106,9 @@ impl<'a> Message<'a> {
#[inline] #[inline]
fn write_body<F, E>(buf: &mut Vec<u8>, f: F) -> Result<(), E> fn write_body<F, E>(buf: &mut Vec<u8>, f: F) -> Result<(), E>
where F: FnOnce(&mut Vec<u8>) -> Result<(), E>, where
E: From<io::Error> F: FnOnce(&mut Vec<u8>) -> Result<(), E>,
E: From<io::Error>,
{ {
let base = buf.len(); let base = buf.len();
buf.extend_from_slice(&[0; 4]); buf.extend_from_slice(&[0; 4]);
@ -137,18 +140,20 @@ impl From<io::Error> for BindError {
} }
#[inline] #[inline]
pub fn bind<I, J, F, T, K>(portal: &str, pub fn bind<I, J, F, T, K>(
statement: &str, portal: &str,
formats: I, statement: &str,
values: J, formats: I,
mut serializer: F, values: J,
result_formats: K, mut serializer: F,
buf: &mut Vec<u8>) result_formats: K,
-> Result<(), BindError> buf: &mut Vec<u8>,
where I: IntoIterator<Item = i16>, ) -> Result<(), BindError>
J: IntoIterator<Item = T>, where
F: FnMut(T, &mut Vec<u8>) -> Result<IsNull, Box<Error + marker::Sync + Send>>, I: IntoIterator<Item = i16>,
K: IntoIterator<Item = i16> J: IntoIterator<Item = T>,
F: FnMut(T, &mut Vec<u8>) -> Result<IsNull, Box<Error + marker::Sync + Send>>,
K: IntoIterator<Item = i16>,
{ {
buf.push(b'B'); buf.push(b'B');
@ -156,9 +161,11 @@ pub fn bind<I, J, F, T, K>(portal: &str,
buf.write_cstr(portal)?; buf.write_cstr(portal)?;
buf.write_cstr(statement)?; buf.write_cstr(statement)?;
write_counted(formats, |f, buf| buf.write_i16::<BigEndian>(f), buf)?; write_counted(formats, |f, buf| buf.write_i16::<BigEndian>(f), buf)?;
write_counted(values, write_counted(
|v, buf| write_nullable(|buf| serializer(v, buf), buf), values,
buf)?; |v, buf| write_nullable(|buf| serializer(v, buf), buf),
buf,
)?;
write_counted(result_formats, |f, buf| buf.write_i16::<BigEndian>(f), buf)?; write_counted(result_formats, |f, buf| buf.write_i16::<BigEndian>(f), buf)?;
Ok(()) Ok(())
@ -167,9 +174,10 @@ pub fn bind<I, J, F, T, K>(portal: &str,
#[inline] #[inline]
fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut Vec<u8>) -> Result<(), E> fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut Vec<u8>) -> Result<(), E>
where I: IntoIterator<Item = T>, where
F: FnMut(T, &mut Vec<u8>) -> Result<(), E>, I: IntoIterator<Item = T>,
E: From<io::Error> F: FnMut(T, &mut Vec<u8>) -> Result<(), E>,
E: From<io::Error>,
{ {
let base = buf.len(); let base = buf.len();
buf.extend_from_slice(&[0; 2]); buf.extend_from_slice(&[0; 2]);
@ -190,8 +198,7 @@ pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut Vec<u8>) {
buf.write_i32::<BigEndian>(80877102).unwrap(); buf.write_i32::<BigEndian>(80877102).unwrap();
buf.write_i32::<BigEndian>(process_id).unwrap(); buf.write_i32::<BigEndian>(process_id).unwrap();
buf.write_i32::<BigEndian>(secret_key) buf.write_i32::<BigEndian>(secret_key)
}) }).unwrap();
.unwrap();
} }
#[inline] #[inline]
@ -246,7 +253,8 @@ pub fn execute(portal: &str, max_rows: i32, buf: &mut Vec<u8>) -> io::Result<()>
#[inline] #[inline]
pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut Vec<u8>) -> io::Result<()> pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut Vec<u8>) -> io::Result<()>
where I: IntoIterator<Item = Oid> where
I: IntoIterator<Item = Oid>,
{ {
buf.push(b'P'); buf.push(b'P');
write_body(buf, |buf| { write_body(buf, |buf| {
@ -294,7 +302,8 @@ pub fn ssl_request(buf: &mut Vec<u8>) {
#[inline] #[inline]
pub fn startup_message<'a, I>(parameters: I, buf: &mut Vec<u8>) -> io::Result<()> pub fn startup_message<'a, I>(parameters: I, buf: &mut Vec<u8>) -> io::Result<()>
where I: IntoIterator<Item = (&'a str, &'a str)> where
I: IntoIterator<Item = (&'a str, &'a str)>,
{ {
write_body(buf, |buf| { write_body(buf, |buf| {
buf.write_i32::<BigEndian>(196608).unwrap(); buf.write_i32::<BigEndian>(196608).unwrap();
@ -327,8 +336,10 @@ impl WriteCStr for Vec<u8> {
#[inline] #[inline]
fn write_cstr(&mut self, s: &str) -> Result<(), io::Error> { fn write_cstr(&mut self, s: &str) -> Result<(), io::Error> {
if s.as_bytes().contains(&0) { if s.as_bytes().contains(&0) {
return Err(io::Error::new(io::ErrorKind::InvalidInput, return Err(io::Error::new(
"string contains embedded null")); io::ErrorKind::InvalidInput,
"string contains embedded null",
));
} }
self.extend_from_slice(s.as_bytes()); self.extend_from_slice(s.as_bytes());
self.push(0); self.push(0);

View File

@ -168,7 +168,8 @@ pub fn float8_from_sql(mut buf: &[u8]) -> Result<f64, StdBox<Error + Sync + Send
/// Serializes an `HSTORE` value. /// Serializes an `HSTORE` value.
#[inline] #[inline]
pub fn hstore_to_sql<'a, I>(values: I, buf: &mut Vec<u8>) -> Result<(), StdBox<Error + Sync + Send>> pub fn hstore_to_sql<'a, I>(values: I, buf: &mut Vec<u8>) -> Result<(), StdBox<Error + Sync + Send>>
where I: IntoIterator<Item = (&'a str, Option<&'a str>)> where
I: IntoIterator<Item = (&'a str, Option<&'a str>)>,
{ {
let base = buf.len(); let base = buf.len();
buf.extend_from_slice(&[0; 4]); buf.extend_from_slice(&[0; 4]);
@ -204,17 +205,18 @@ fn write_pascal_string(s: &str, buf: &mut Vec<u8>) -> Result<(), StdBox<Error +
/// Deserializes an `HSTORE` value. /// Deserializes an `HSTORE` value.
#[inline] #[inline]
pub fn hstore_from_sql<'a>(mut buf: &'a [u8]) pub fn hstore_from_sql<'a>(
-> Result<HstoreEntries<'a>, StdBox<Error + Sync + Send>> { mut buf: &'a [u8],
) -> Result<HstoreEntries<'a>, StdBox<Error + Sync + Send>> {
let count = buf.read_i32::<BigEndian>()?; let count = buf.read_i32::<BigEndian>()?;
if count < 0 { if count < 0 {
return Err("invalid entry count".into()); return Err("invalid entry count".into());
} }
Ok(HstoreEntries { Ok(HstoreEntries {
remaining: count, remaining: count,
buf: buf, buf: buf,
}) })
} }
/// A fallible iterator over `HSTORE` entries. /// A fallible iterator over `HSTORE` entries.
@ -268,11 +270,13 @@ impl<'a> FallibleIterator for HstoreEntries<'a> {
/// Serializes a `VARBIT` or `BIT` value. /// Serializes a `VARBIT` or `BIT` value.
#[inline] #[inline]
pub fn varbit_to_sql<I>(len: usize, pub fn varbit_to_sql<I>(
v: I, len: usize,
buf: &mut Vec<u8>) v: I,
-> Result<(), StdBox<Error + Sync + Send>> buf: &mut Vec<u8>,
where I: Iterator<Item = u8> ) -> Result<(), StdBox<Error + Sync + Send>>
where
I: Iterator<Item = u8>,
{ {
let len = i32::from_usize(len)?; let len = i32::from_usize(len)?;
buf.write_i32::<BigEndian>(len).unwrap(); buf.write_i32::<BigEndian>(len).unwrap();
@ -297,9 +301,9 @@ pub fn varbit_from_sql<'a>(mut buf: &'a [u8]) -> Result<Varbit<'a>, StdBox<Error
} }
Ok(Varbit { Ok(Varbit {
len: len as usize, len: len as usize,
bytes: buf, bytes: buf,
}) })
} }
/// A `VARBIT` value. /// A `VARBIT` value.
@ -418,16 +422,18 @@ pub fn uuid_from_sql(buf: &[u8]) -> Result<[u8; 16], StdBox<Error + Sync + Send>
/// Serializes an array value. /// Serializes an array value.
#[inline] #[inline]
pub fn array_to_sql<T, I, J, F>(dimensions: I, pub fn array_to_sql<T, I, J, F>(
has_nulls: bool, dimensions: I,
element_type: Oid, has_nulls: bool,
elements: J, element_type: Oid,
mut serializer: F, elements: J,
buf: &mut Vec<u8>) mut serializer: F,
-> Result<(), StdBox<Error + Sync + Send>> buf: &mut Vec<u8>,
where I: IntoIterator<Item = ArrayDimension>, ) -> Result<(), StdBox<Error + Sync + Send>>
J: IntoIterator<Item = T>, where
F: FnMut(T, &mut Vec<u8>) -> Result<IsNull, StdBox<Error + Sync + Send>> I: IntoIterator<Item = ArrayDimension>,
J: IntoIterator<Item = T>,
F: FnMut(T, &mut Vec<u8>) -> Result<IsNull, StdBox<Error + Sync + Send>>,
{ {
let dimensions_idx = buf.len(); let dimensions_idx = buf.len();
buf.extend_from_slice(&[0; 4]); buf.extend_from_slice(&[0; 4]);
@ -482,12 +488,12 @@ pub fn array_from_sql<'a>(mut buf: &'a [u8]) -> Result<Array<'a>, StdBox<Error +
} }
Ok(Array { Ok(Array {
dimensions: dimensions, dimensions: dimensions,
has_nulls: has_nulls, has_nulls: has_nulls,
element_type: element_type, element_type: element_type,
elements: elements, elements: elements,
buf: buf, buf: buf,
}) })
} }
/// A Postgres array. /// A Postgres array.
@ -545,9 +551,9 @@ impl<'a> FallibleIterator for ArrayDimensions<'a> {
let lower_bound = self.0.read_i32::<BigEndian>()?; let lower_bound = self.0.read_i32::<BigEndian>()?;
Ok(Some(ArrayDimension { Ok(Some(ArrayDimension {
len: len, len: len,
lower_bound: lower_bound, lower_bound: lower_bound,
})) }))
} }
#[inline] #[inline]
@ -616,12 +622,14 @@ pub fn empty_range_to_sql(buf: &mut Vec<u8>) {
} }
/// Serializes a range value. /// Serializes a range value.
pub fn range_to_sql<F, G>(lower: F, pub fn range_to_sql<F, G>(
upper: G, lower: F,
buf: &mut Vec<u8>) upper: G,
-> Result<(), StdBox<Error + Sync + Send>> buf: &mut Vec<u8>,
where F: FnOnce(&mut Vec<u8>) -> Result<RangeBound<IsNull>, StdBox<Error + Sync + Send>>, ) -> Result<(), StdBox<Error + Sync + Send>>
G: FnOnce(&mut Vec<u8>) -> Result<RangeBound<IsNull>, StdBox<Error + Sync + Send>> where
F: FnOnce(&mut Vec<u8>) -> Result<RangeBound<IsNull>, StdBox<Error + Sync + Send>>,
G: FnOnce(&mut Vec<u8>) -> Result<RangeBound<IsNull>, StdBox<Error + Sync + Send>>,
{ {
let tag_idx = buf.len(); let tag_idx = buf.len();
buf.push(0); buf.push(0);
@ -644,10 +652,12 @@ pub fn range_to_sql<F, G>(lower: F,
Ok(()) Ok(())
} }
fn write_bound<F>(bound: F, fn write_bound<F>(
buf: &mut Vec<u8>) bound: F,
-> Result<RangeBound<()>, StdBox<Error + Sync + Send>> buf: &mut Vec<u8>,
where F: FnOnce(&mut Vec<u8>) -> Result<RangeBound<IsNull>, StdBox<Error + Sync + Send>> ) -> Result<RangeBound<()>, StdBox<Error + Sync + Send>>
where
F: FnOnce(&mut Vec<u8>) -> Result<RangeBound<IsNull>, StdBox<Error + Sync + Send>>,
{ {
let base = buf.len(); let base = buf.len();
buf.extend_from_slice(&[0; 4]); buf.extend_from_slice(&[0; 4]);
@ -707,11 +717,12 @@ pub fn range_from_sql<'a>(mut buf: &'a [u8]) -> Result<Range<'a>, StdBox<Error +
} }
#[inline] #[inline]
fn read_bound<'a>(buf: &mut &'a [u8], fn read_bound<'a>(
tag: u8, buf: &mut &'a [u8],
unbounded: u8, tag: u8,
inclusive: u8) unbounded: u8,
-> Result<RangeBound<Option<&'a [u8]>>, StdBox<Error + Sync + Send>> { inclusive: u8,
) -> Result<RangeBound<Option<&'a [u8]>>, StdBox<Error + Sync + Send>> {
if tag & unbounded != 0 { if tag & unbounded != 0 {
Ok(RangeBound::Unbounded) Ok(RangeBound::Unbounded)
} else { } else {
@ -803,9 +814,9 @@ pub fn box_from_sql(mut buf: &[u8]) -> Result<Box, StdBox<Error + Sync + Send>>
return Err("invalid buffer size".into()); return Err("invalid buffer size".into());
} }
Ok(Box { Ok(Box {
upper_right: Point { x: x1, y: y1 }, upper_right: Point { x: x1, y: y1 },
lower_left: Point { x: x2, y: y2 }, lower_left: Point { x: x2, y: y2 },
}) })
} }
/// A Postgres box. /// A Postgres box.
@ -831,11 +842,13 @@ impl Box {
/// Serializes a Postgres path. /// Serializes a Postgres path.
#[inline] #[inline]
pub fn path_to_sql<I>(closed: bool, pub fn path_to_sql<I>(
points: I, closed: bool,
buf: &mut Vec<u8>) points: I,
-> Result<(), StdBox<Error + Sync + Send>> buf: &mut Vec<u8>,
where I: IntoIterator<Item = (f64, f64)> ) -> Result<(), StdBox<Error + Sync + Send>>
where
I: IntoIterator<Item = (f64, f64)>,
{ {
buf.push(closed as u8); buf.push(closed as u8);
let points_idx = buf.len(); let points_idx = buf.len();
@ -863,10 +876,10 @@ pub fn path_from_sql<'a>(mut buf: &'a [u8]) -> Result<Path<'a>, StdBox<Error + S
let points = buf.read_i32::<BigEndian>()?; let points = buf.read_i32::<BigEndian>()?;
Ok(Path { Ok(Path {
closed: closed, closed: closed,
points: points, points: points,
buf: buf, buf: buf,
}) })
} }
/// A Postgres point. /// A Postgres point.
@ -988,11 +1001,13 @@ mod test {
let mut buf = vec![]; let mut buf = vec![];
hstore_to_sql(map.iter().map(|(&k, &v)| (k, v)), &mut buf).unwrap(); hstore_to_sql(map.iter().map(|(&k, &v)| (k, v)), &mut buf).unwrap();
assert_eq!(hstore_from_sql(&buf) assert_eq!(
.unwrap() hstore_from_sql(&buf)
.collect::<HashMap<_, _>>() .unwrap()
.unwrap(), .collect::<HashMap<_, _>>()
map); .unwrap(),
map
);
} }
#[test] #[test]
@ -1009,30 +1024,33 @@ mod test {
#[test] #[test]
fn array() { fn array() {
let dimensions = [ArrayDimension { let dimensions = [
len: 1, ArrayDimension {
lower_bound: 10, len: 1,
}, lower_bound: 10,
ArrayDimension { },
len: 2, ArrayDimension {
lower_bound: 0, len: 2,
}]; lower_bound: 0,
},
];
let values = [None, Some(&b"hello"[..])]; let values = [None, Some(&b"hello"[..])];
let mut buf = vec![]; let mut buf = vec![];
array_to_sql(dimensions.iter().cloned(), array_to_sql(
true, dimensions.iter().cloned(),
10, true,
values.iter().cloned(), 10,
|v, buf| match v { values.iter().cloned(),
Some(v) => { |v, buf| match v {
buf.extend_from_slice(v); Some(v) => {
Ok(IsNull::No) buf.extend_from_slice(v);
} Ok(IsNull::No)
None => Ok(IsNull::Yes), }
}, None => Ok(IsNull::Yes),
&mut buf) },
.unwrap(); &mut buf,
).unwrap();
let array = array_from_sql(&buf).unwrap(); let array = array_from_sql(&buf).unwrap();
assert_eq!(array.has_nulls(), true); assert_eq!(array.has_nulls(), true);

View File

@ -168,14 +168,18 @@ impl DbError {
b'H' => hint = Some(field.value().to_owned()), b'H' => hint = Some(field.value().to_owned()),
b'P' => { b'P' => {
normal_position = Some(field.value().parse::<u32>().map_err(|_| { normal_position = Some(field.value().parse::<u32>().map_err(|_| {
io::Error::new(io::ErrorKind::InvalidInput, io::Error::new(
"`P` field did not contain an integer") io::ErrorKind::InvalidInput,
"`P` field did not contain an integer",
)
})?); })?);
} }
b'p' => { b'p' => {
internal_position = Some(field.value().parse::<u32>().map_err(|_| { internal_position = Some(field.value().parse::<u32>().map_err(|_| {
io::Error::new(io::ErrorKind::InvalidInput, io::Error::new(
"`p` field did not contain an integer") io::ErrorKind::InvalidInput,
"`p` field did not contain an integer",
)
})?); })?);
} }
b'q' => internal_query = Some(field.value().to_owned()), b'q' => internal_query = Some(field.value().to_owned()),
@ -188,18 +192,22 @@ impl DbError {
b'F' => file = Some(field.value().to_owned()), b'F' => file = Some(field.value().to_owned()),
b'L' => { b'L' => {
line = Some(field.value().parse::<u32>().map_err(|_| { line = Some(field.value().parse::<u32>().map_err(|_| {
io::Error::new(io::ErrorKind::InvalidInput, io::Error::new(
"`L` field did not contain an integer") io::ErrorKind::InvalidInput,
"`L` field did not contain an integer",
)
})?); })?);
} }
b'R' => routine = Some(field.value().to_owned()), b'R' => routine = Some(field.value().to_owned()),
b'V' => { b'V' => {
parsed_severity = Some(Severity::from_str(field.value()).ok_or_else(|| { parsed_severity = Some(Severity::from_str(field.value()).ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, io::Error::new(
"`V` field contained an invalid value") io::ErrorKind::InvalidInput,
"`V` field contained an invalid value",
)
})?); })?);
} }
_ => {}, _ => {}
} }
} }
@ -208,10 +216,12 @@ impl DbError {
io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing") io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing")
})?, })?,
parsed_severity: parsed_severity, parsed_severity: parsed_severity,
code: code.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, code: code.ok_or_else(|| {
"`C` field missing"))?, io::Error::new(io::ErrorKind::InvalidInput, "`C` field missing")
message: message.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, })?,
"`M` field missing"))?, message: message.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "`M` field missing")
})?,
detail: detail, detail: detail,
hint: hint, hint: hint,
position: match normal_position { position: match normal_position {
@ -222,8 +232,10 @@ impl DbError {
Some(ErrorPosition::Internal { Some(ErrorPosition::Internal {
position: position, position: position,
query: internal_query.ok_or_else(|| { query: internal_query.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, io::Error::new(
"`q` field missing but `p` field present") io::ErrorKind::InvalidInput,
"`q` field missing but `p` field present",
)
})?, })?,
}) })
} }

View File

@ -168,7 +168,17 @@ impl IntoConnectParams for String {
impl IntoConnectParams for Url { impl IntoConnectParams for Url {
fn into_connect_params(self) -> Result<ConnectParams, Box<Error + Sync + Send>> { fn into_connect_params(self) -> Result<ConnectParams, Box<Error + Sync + Send>> {
let Url { host, port, user, path: url::Path { path, query: options, .. }, .. } = self; let Url {
host,
port,
user,
path: url::Path {
path,
query: options,
..
},
..
} = self;
let mut builder = ConnectParams::builder(); let mut builder = ConnectParams::builder();
@ -199,4 +209,3 @@ impl IntoConnectParams for Url {
Ok(builder.build(host)) Ok(builder.build(host))
} }
} }

View File

@ -32,14 +32,15 @@ pub struct UserInfo {
pub type Query = Vec<(String, String)>; pub type Query = Vec<(String, String)>;
impl Url { impl Url {
pub fn new(scheme: String, pub fn new(
user: Option<UserInfo>, scheme: String,
host: String, user: Option<UserInfo>,
port: Option<u16>, host: String,
path: String, port: Option<u16>,
query: Query, path: String,
fragment: Option<String>) query: Query,
-> Url { fragment: Option<String>,
) -> Url {
Url { Url {
scheme: scheme, scheme: scheme,
user: user, user: user,
@ -63,13 +64,15 @@ impl Url {
// query and fragment // query and fragment
let (query, fragment) = get_query_fragment(rest)?; let (query, fragment) = get_query_fragment(rest)?;
let url = Url::new(scheme.to_owned(), let url = Url::new(
userinfo, scheme.to_owned(),
host.to_owned(), userinfo,
port, host.to_owned(),
path, port,
query, path,
fragment); query,
fragment,
);
Ok(url) Ok(url)
} }
} }
@ -125,19 +128,23 @@ fn decode_inner(c: &str, full_url: bool) -> DecodeResult<String> {
let bytes = match (iter.next(), iter.next()) { let bytes = match (iter.next(), iter.next()) {
(Some(one), Some(two)) => [one, two], (Some(one), Some(two)) => [one, two],
_ => { _ => {
return Err("Malformed input: found '%' without two \ return Err(
"Malformed input: found '%' without two \
trailing bytes" trailing bytes"
.to_owned()) .to_owned(),
)
} }
}; };
let bytes_from_hex = match Vec::<u8>::from_hex(&bytes) { let bytes_from_hex = match Vec::<u8>::from_hex(&bytes) {
Ok(b) => b, Ok(b) => b,
_ => { _ => {
return Err("Malformed input: found '%' followed by \ return Err(
"Malformed input: found '%' followed by \
invalid hex values. Character '%' must \ invalid hex values. Character '%' must \
escaped." escaped."
.to_owned()) .to_owned(),
)
} }
}; };
@ -247,9 +254,7 @@ fn get_authority(rawurl: &str) -> DecodeResult<(Option<UserInfo>, &str, Option<u
let mut begin = 2; let mut begin = 2;
let mut end = len; let mut end = len;
for (i, c) in rawurl.chars() for (i, c) in rawurl.chars().enumerate().skip(2) {
.enumerate()
.skip(2) {
// deal with input class first // deal with input class first
match c { match c {
'0'...'9' => (), '0'...'9' => (),
@ -390,7 +395,9 @@ fn get_path(rawurl: &str, is_authority: bool) -> DecodeResult<(String, &str)> {
} }
if is_authority && end != 0 && !rawurl.starts_with('/') { if is_authority && end != 0 && !rawurl.starts_with('/') {
Err("Non-empty path must begin with '/' in presence of authority.".to_owned()) Err(
"Non-empty path must begin with '/' in presence of authority.".to_owned(),
)
} else { } else {
Ok((decode_component(&rawurl[0..end])?, &rawurl[end..len])) Ok((decode_component(&rawurl[0..end])?, &rawurl[end..len]))
} }
@ -409,7 +416,10 @@ fn get_query_fragment(rawurl: &str) -> DecodeResult<(Query, Option<String>)> {
match before_fragment.chars().next() { match before_fragment.chars().next() {
Some('?') => Ok((query_from_str(&before_fragment[1..])?, fragment)), Some('?') => Ok((query_from_str(&before_fragment[1..])?, fragment)),
None => Ok((vec![], fragment)), None => Ok((vec![], fragment)),
_ => Err(format!("Query didn't start with '?': '{}..'", before_fragment)), _ => Err(format!(
"Query didn't start with '?': '{}..'",
before_fragment
)),
} }
} }

View File

@ -1,4 +1,4 @@
use fallible_iterator::{FallibleIterator}; use fallible_iterator::FallibleIterator;
use postgres_protocol::message::backend::DataRowBody; use postgres_protocol::message::backend::DataRowBody;
use std::ascii::AsciiExt; use std::ascii::AsciiExt;
use std::io; use std::io;
@ -34,13 +34,15 @@ impl<'a> RowIndex for str {
// FIXME ASCII-only case insensitivity isn't really the right thing to // FIXME ASCII-only case insensitivity isn't really the right thing to
// do. Postgres itself uses a dubious wrapper around tolower and JDBC // do. Postgres itself uses a dubious wrapper around tolower and JDBC
// uses the US locale. // uses the US locale.
stmt.iter() stmt.iter().position(
.position(|d| d.name().eq_ignore_ascii_case(self)) |d| d.name().eq_ignore_ascii_case(self),
)
} }
} }
impl<'a, T: ?Sized> RowIndex for &'a T impl<'a, T: ?Sized> RowIndex for &'a T
where T: RowIndex where
T: RowIndex,
{ {
#[inline] #[inline]
fn idx(&self, columns: &[Column]) -> Option<usize> { fn idx(&self, columns: &[Column]) -> Option<usize> {
@ -58,9 +60,9 @@ impl RowData {
pub fn new(body: DataRowBody) -> io::Result<RowData> { pub fn new(body: DataRowBody) -> io::Result<RowData> {
let ranges = body.ranges().collect()?; let ranges = body.ranges().collect()?;
Ok(RowData { Ok(RowData {
body: body, body: body,
ranges: ranges, ranges: ranges,
}) })
} }
pub fn len(&self) -> usize { pub fn len(&self) -> usize {

View File

@ -21,10 +21,7 @@ impl FromSql for BitVec {
} }
impl ToSql for BitVec { impl ToSql for BitVec {
fn to_sql(&self, fn to_sql(&self, _: &Type, mut out: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> {
_: &Type,
mut out: &mut Vec<u8>)
-> Result<IsNull, Box<Error + Sync + Send>> {
types::varbit_to_sql(self.len(), self.to_bytes().into_iter(), out)?; types::varbit_to_sql(self.len(), self.to_bytes().into_iter(), out)?;
Ok(IsNull::No) Ok(IsNull::No)
} }

View File

@ -12,9 +12,7 @@ fn base() -> NaiveDateTime {
} }
impl FromSql for NaiveDateTime { impl FromSql for NaiveDateTime {
fn from_sql(_: &Type, fn from_sql(_: &Type, raw: &[u8]) -> Result<NaiveDateTime, Box<Error + Sync + Send>> {
raw: &[u8])
-> Result<NaiveDateTime, Box<Error + Sync + Send>> {
let t = types::timestamp_from_sql(raw)?; let t = types::timestamp_from_sql(raw)?;
Ok(base() + Duration::microseconds(t)) Ok(base() + Duration::microseconds(t))
} }
@ -23,10 +21,7 @@ impl FromSql for NaiveDateTime {
} }
impl ToSql for NaiveDateTime { impl ToSql for NaiveDateTime {
fn to_sql(&self, fn to_sql(&self, _: &Type, w: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> {
_: &Type,
w: &mut Vec<u8>)
-> Result<IsNull, Box<Error + Sync + Send>> {
let time = match self.signed_duration_since(base()).num_microseconds() { let time = match self.signed_duration_since(base()).num_microseconds() {
Some(time) => time, Some(time) => time,
None => return Err("value too large to transmit".into()), None => return Err("value too large to transmit".into()),
@ -40,9 +35,7 @@ impl ToSql for NaiveDateTime {
} }
impl FromSql for DateTime<Utc> { impl FromSql for DateTime<Utc> {
fn from_sql(type_: &Type, fn from_sql(type_: &Type, raw: &[u8]) -> Result<DateTime<Utc>, Box<Error + Sync + Send>> {
raw: &[u8])
-> Result<DateTime<Utc>, Box<Error + Sync + Send>> {
let naive = NaiveDateTime::from_sql(type_, raw)?; let naive = NaiveDateTime::from_sql(type_, raw)?;
Ok(DateTime::from_utc(naive, Utc)) Ok(DateTime::from_utc(naive, Utc))
} }
@ -51,10 +44,7 @@ impl FromSql for DateTime<Utc> {
} }
impl ToSql for DateTime<Utc> { impl ToSql for DateTime<Utc> {
fn to_sql(&self, fn to_sql(&self, type_: &Type, w: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> {
type_: &Type,
w: &mut Vec<u8>)
-> Result<IsNull, Box<Error + Sync + Send>> {
self.naive_utc().to_sql(type_, w) self.naive_utc().to_sql(type_, w)
} }
@ -63,9 +53,7 @@ impl ToSql for DateTime<Utc> {
} }
impl FromSql for DateTime<Local> { impl FromSql for DateTime<Local> {
fn from_sql(type_: &Type, fn from_sql(type_: &Type, raw: &[u8]) -> Result<DateTime<Local>, Box<Error + Sync + Send>> {
raw: &[u8])
-> Result<DateTime<Local>, Box<Error + Sync + Send>> {
let utc = DateTime::<Utc>::from_sql(type_, raw)?; let utc = DateTime::<Utc>::from_sql(type_, raw)?;
Ok(utc.with_timezone(&Local)) Ok(utc.with_timezone(&Local))
} }
@ -74,10 +62,11 @@ impl FromSql for DateTime<Local> {
} }
impl ToSql for DateTime<Local> { impl ToSql for DateTime<Local> {
fn to_sql(&self, fn to_sql(
type_: &Type, &self,
mut w: &mut Vec<u8>) type_: &Type,
-> Result<IsNull, Box<Error + Sync + Send>> { mut w: &mut Vec<u8>,
) -> Result<IsNull, Box<Error + Sync + Send>> {
self.with_timezone(&Utc).to_sql(type_, w) self.with_timezone(&Utc).to_sql(type_, w)
} }
@ -86,9 +75,10 @@ impl ToSql for DateTime<Local> {
} }
impl FromSql for DateTime<FixedOffset> { impl FromSql for DateTime<FixedOffset> {
fn from_sql(type_: &Type, fn from_sql(
raw: &[u8]) type_: &Type,
-> Result<DateTime<FixedOffset>, Box<Error + Sync + Send>> { raw: &[u8],
) -> Result<DateTime<FixedOffset>, Box<Error + Sync + Send>> {
let utc = DateTime::<Utc>::from_sql(type_, raw)?; let utc = DateTime::<Utc>::from_sql(type_, raw)?;
Ok(utc.with_timezone(&FixedOffset::east(0))) Ok(utc.with_timezone(&FixedOffset::east(0)))
} }
@ -97,10 +87,7 @@ impl FromSql for DateTime<FixedOffset> {
} }
impl ToSql for DateTime<FixedOffset> { impl ToSql for DateTime<FixedOffset> {
fn to_sql(&self, fn to_sql(&self, type_: &Type, w: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> {
type_: &Type,
w: &mut Vec<u8>)
-> Result<IsNull, Box<Error + Sync + Send>> {
self.with_timezone(&Utc).to_sql(type_, w) self.with_timezone(&Utc).to_sql(type_, w)
} }
@ -109,9 +96,7 @@ impl ToSql for DateTime<FixedOffset> {
} }
impl FromSql for NaiveDate { impl FromSql for NaiveDate {
fn from_sql(_: &Type, fn from_sql(_: &Type, raw: &[u8]) -> Result<NaiveDate, Box<Error + Sync + Send>> {
raw: &[u8])
-> Result<NaiveDate, Box<Error + Sync + Send>> {
let jd = types::date_from_sql(raw)?; let jd = types::date_from_sql(raw)?;
Ok(base().date() + Duration::days(jd as i64)) Ok(base().date() + Duration::days(jd as i64))
} }
@ -120,10 +105,7 @@ impl FromSql for NaiveDate {
} }
impl ToSql for NaiveDate { impl ToSql for NaiveDate {
fn to_sql(&self, fn to_sql(&self, _: &Type, w: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> {
_: &Type,
w: &mut Vec<u8>)
-> Result<IsNull, Box<Error + Sync + Send>> {
let jd = self.signed_duration_since(base().date()).num_days(); let jd = self.signed_duration_since(base().date()).num_days();
if jd > i32::max_value() as i64 || jd < i32::min_value() as i64 { if jd > i32::max_value() as i64 || jd < i32::min_value() as i64 {
return Err("value too large to transmit".into()); return Err("value too large to transmit".into());
@ -138,9 +120,7 @@ impl ToSql for NaiveDate {
} }
impl FromSql for NaiveTime { impl FromSql for NaiveTime {
fn from_sql(_: &Type, fn from_sql(_: &Type, raw: &[u8]) -> Result<NaiveTime, Box<Error + Sync + Send>> {
raw: &[u8])
-> Result<NaiveTime, Box<Error + Sync + Send>> {
let usec = types::time_from_sql(raw)?; let usec = types::time_from_sql(raw)?;
Ok(NaiveTime::from_hms(0, 0, 0) + Duration::microseconds(usec)) Ok(NaiveTime::from_hms(0, 0, 0) + Duration::microseconds(usec))
} }
@ -149,10 +129,7 @@ impl FromSql for NaiveTime {
} }
impl ToSql for NaiveTime { impl ToSql for NaiveTime {
fn to_sql(&self, fn to_sql(&self, _: &Type, w: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> {
_: &Type,
w: &mut Vec<u8>)
-> Result<IsNull, Box<Error + Sync + Send>> {
let delta = self.signed_duration_since(NaiveTime::from_hms(0, 0, 0)); let delta = self.signed_duration_since(NaiveTime::from_hms(0, 0, 0));
let time = match delta.num_microseconds() { let time = match delta.num_microseconds() {
Some(time) => time, Some(time) => time,

View File

@ -7,9 +7,7 @@ use postgres_protocol::types;
use types::{FromSql, ToSql, Type, IsNull}; use types::{FromSql, ToSql, Type, IsNull};
impl FromSql for MacAddress { impl FromSql for MacAddress {
fn from_sql(_: &Type, fn from_sql(_: &Type, raw: &[u8]) -> Result<MacAddress, Box<Error + Sync + Send>> {
raw: &[u8])
-> Result<MacAddress, Box<Error + Sync + Send>> {
let bytes = types::macaddr_from_sql(raw)?; let bytes = types::macaddr_from_sql(raw)?;
Ok(MacAddress::new(bytes)) Ok(MacAddress::new(bytes))
} }
@ -18,10 +16,7 @@ impl FromSql for MacAddress {
} }
impl ToSql for MacAddress { impl ToSql for MacAddress {
fn to_sql(&self, fn to_sql(&self, _: &Type, w: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> {
_: &Type,
w: &mut Vec<u8>)
-> Result<IsNull, Box<Error + Sync + Send>> {
let mut bytes = [0; 6]; let mut bytes = [0; 6];
bytes.copy_from_slice(self.as_bytes()); bytes.copy_from_slice(self.as_bytes());
types::macaddr_to_sql(bytes, w); types::macaddr_to_sql(bytes, w);

View File

@ -47,11 +47,13 @@ macro_rules! to_sql_checked {
// WARNING: this function is not considered part of this crate's public API. // WARNING: this function is not considered part of this crate's public API.
// It is subject to change at any time. // It is subject to change at any time.
#[doc(hidden)] #[doc(hidden)]
pub fn __to_sql_checked<T>(v: &T, pub fn __to_sql_checked<T>(
ty: &Type, v: &T,
out: &mut Vec<u8>) ty: &Type,
-> Result<IsNull, Box<Error + Sync + Send>> out: &mut Vec<u8>,
where T: ToSql ) -> Result<IsNull, Box<Error + Sync + Send>>
where
T: ToSql,
{ {
if !T::accepts(ty) { if !T::accepts(ty) {
return Err(Box::new(WrongType(ty.clone()))); return Err(Box::new(WrongType(ty.clone())));
@ -156,11 +158,11 @@ impl Other {
#[doc(hidden)] #[doc(hidden)]
pub fn new(name: String, oid: Oid, kind: Kind, schema: String) -> Other { pub fn new(name: String, oid: Oid, kind: Kind, schema: String) -> Other {
Other(Arc::new(OtherInner { Other(Arc::new(OtherInner {
name: name, name: name,
oid: oid, oid: oid,
kind: kind, kind: kind,
schema: schema, schema: schema,
})) }))
} }
} }
@ -210,9 +212,11 @@ pub struct WrongType(Type);
impl fmt::Display for WrongType { impl fmt::Display for WrongType {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, write!(
"cannot convert to or from a Postgres value of type `{}`", fmt,
self.0) "cannot convert to or from a Postgres value of type `{}`",
self.0
)
} }
} }
@ -402,9 +406,10 @@ simple_from!(f32, float4_from_sql, Type::Float4);
simple_from!(f64, float8_from_sql, Type::Float8); simple_from!(f64, float8_from_sql, Type::Float8);
impl FromSql for HashMap<String, Option<String>> { impl FromSql for HashMap<String, Option<String>> {
fn from_sql(_: &Type, fn from_sql(
raw: &[u8]) _: &Type,
-> Result<HashMap<String, Option<String>>, Box<Error + Sync + Send>> { raw: &[u8],
) -> Result<HashMap<String, Option<String>>, Box<Error + Sync + Send>> {
types::hstore_from_sql(raw)? types::hstore_from_sql(raw)?
.map(|(k, v)| (k.to_owned(), v.map(str::to_owned))) .map(|(k, v)| (k.to_owned(), v.map(str::to_owned)))
.collect() .collect()
@ -492,24 +497,29 @@ pub trait ToSql: fmt::Debug {
/// `NULL`. If this is the case, implementations **must not** write /// `NULL`. If this is the case, implementations **must not** write
/// anything to `out`. /// anything to `out`.
fn to_sql(&self, ty: &Type, out: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> fn to_sql(&self, ty: &Type, out: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>>
where Self: Sized; where
Self: Sized;
/// Determines if a value of this type can be converted to the specified /// Determines if a value of this type can be converted to the specified
/// Postgres `Type`. /// Postgres `Type`.
fn accepts(ty: &Type) -> bool where Self: Sized; fn accepts(ty: &Type) -> bool
where
Self: Sized;
/// An adaptor method used internally by Rust-Postgres. /// An adaptor method used internally by Rust-Postgres.
/// ///
/// *All* implementations of this method should be generated by the /// *All* implementations of this method should be generated by the
/// `to_sql_checked!()` macro. /// `to_sql_checked!()` macro.
fn to_sql_checked(&self, fn to_sql_checked(
ty: &Type, &self,
out: &mut Vec<u8>) ty: &Type,
-> Result<IsNull, Box<Error + Sync + Send>>; out: &mut Vec<u8>,
) -> Result<IsNull, Box<Error + Sync + Send>>;
} }
impl<'a, T> ToSql for &'a T impl<'a, T> ToSql for &'a T
where T: ToSql where
T: ToSql,
{ {
fn to_sql(&self, ty: &Type, out: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> { fn to_sql(&self, ty: &Type, out: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> {
(*self).to_sql(ty, out) (*self).to_sql(ty, out)
@ -549,15 +559,17 @@ impl<'a, T: ToSql> ToSql for &'a [T] {
lower_bound: 1, lower_bound: 1,
}; };
types::array_to_sql(Some(dimension), types::array_to_sql(
true, Some(dimension),
member_type.oid(), true,
self.iter(), member_type.oid(),
|e, w| match e.to_sql(member_type, w)? { self.iter(),
IsNull::No => Ok(postgres_protocol::IsNull::No), |e, w| match e.to_sql(member_type, w)? {
IsNull::Yes => Ok(postgres_protocol::IsNull::Yes), IsNull::No => Ok(postgres_protocol::IsNull::No),
}, IsNull::Yes => Ok(postgres_protocol::IsNull::Yes),
w)?; },
w,
)?;
Ok(IsNull::No) Ok(IsNull::No)
} }
@ -664,8 +676,10 @@ simple_to!(f64, float8_to_sql, Type::Float8);
impl ToSql for HashMap<String, Option<String>> { impl ToSql for HashMap<String, Option<String>> {
fn to_sql(&self, _: &Type, w: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> { fn to_sql(&self, _: &Type, w: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> {
types::hstore_to_sql(self.iter().map(|(k, v)| (&**k, v.as_ref().map(|v| &**v))), types::hstore_to_sql(
w)?; self.iter().map(|(k, v)| (&**k, v.as_ref().map(|v| &**v))),
w,
)?;
Ok(IsNull::No) Ok(IsNull::No)
} }

View File

@ -7,9 +7,7 @@ use std::error::Error;
use types::{FromSql, ToSql, IsNull, Type}; use types::{FromSql, ToSql, IsNull, Type};
impl FromSql for json::Json { impl FromSql for json::Json {
fn from_sql(ty: &Type, fn from_sql(ty: &Type, mut raw: &[u8]) -> Result<json::Json, Box<Error + Sync + Send>> {
mut raw: &[u8])
-> Result<json::Json, Box<Error + Sync + Send>> {
if let Type::Jsonb = *ty { if let Type::Jsonb = *ty {
let mut b = [0; 1]; let mut b = [0; 1];
raw.read_exact(&mut b)?; raw.read_exact(&mut b)?;
@ -25,10 +23,7 @@ impl FromSql for json::Json {
} }
impl ToSql for json::Json { impl ToSql for json::Json {
fn to_sql(&self, fn to_sql(&self, ty: &Type, mut out: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> {
ty: &Type,
mut out: &mut Vec<u8>)
-> Result<IsNull, Box<Error + Sync + Send>> {
if let Type::Jsonb = *ty { if let Type::Jsonb = *ty {
out.push(1); out.push(1);
} }

View File

@ -7,9 +7,7 @@ use std::io::{Read, Write};
use types::{FromSql, ToSql, IsNull, Type}; use types::{FromSql, ToSql, IsNull, Type};
impl FromSql for Value { impl FromSql for Value {
fn from_sql(ty: &Type, fn from_sql(ty: &Type, mut raw: &[u8]) -> Result<Value, Box<Error + Sync + Send>> {
mut raw: &[u8])
-> Result<Value, Box<Error + Sync + Send>> {
if let Type::Jsonb = *ty { if let Type::Jsonb = *ty {
let mut b = [0; 1]; let mut b = [0; 1];
raw.read_exact(&mut b)?; raw.read_exact(&mut b)?;
@ -25,10 +23,7 @@ impl FromSql for Value {
} }
impl ToSql for Value { impl ToSql for Value {
fn to_sql(&self, fn to_sql(&self, ty: &Type, mut out: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> {
ty: &Type,
mut out: &mut Vec<u8>)
-> Result<IsNull, Box<Error + Sync + Send>> {
if let Type::Jsonb = *ty { if let Type::Jsonb = *ty {
out.push(1); out.push(1);
} }

View File

@ -16,9 +16,7 @@ pub enum Date<T> {
} }
impl<T: FromSql> FromSql for Date<T> { impl<T: FromSql> FromSql for Date<T> {
fn from_sql(ty: &Type, fn from_sql(ty: &Type, raw: &[u8]) -> Result<Self, Box<Error + Sync + Send>> {
raw: &[u8])
-> Result<Self, Box<Error + Sync + Send>> {
match types::date_from_sql(raw)? { match types::date_from_sql(raw)? {
i32::MAX => Ok(Date::PosInfinity), i32::MAX => Ok(Date::PosInfinity),
i32::MIN => Ok(Date::NegInfinity), i32::MIN => Ok(Date::NegInfinity),
@ -31,10 +29,7 @@ impl<T: FromSql> FromSql for Date<T> {
} }
} }
impl<T: ToSql> ToSql for Date<T> { impl<T: ToSql> ToSql for Date<T> {
fn to_sql(&self, fn to_sql(&self, ty: &Type, out: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> {
ty: &Type,
out: &mut Vec<u8>)
-> Result<IsNull, Box<Error + Sync + Send>> {
let value = match *self { let value = match *self {
Date::PosInfinity => i32::MAX, Date::PosInfinity => i32::MAX,
Date::NegInfinity => i32::MIN, Date::NegInfinity => i32::MIN,
@ -65,9 +60,7 @@ pub enum Timestamp<T> {
} }
impl<T: FromSql> FromSql for Timestamp<T> { impl<T: FromSql> FromSql for Timestamp<T> {
fn from_sql(ty: &Type, fn from_sql(ty: &Type, raw: &[u8]) -> Result<Self, Box<Error + Sync + Send>> {
raw: &[u8])
-> Result<Self, Box<Error + Sync + Send>> {
match types::timestamp_from_sql(raw)? { match types::timestamp_from_sql(raw)? {
i64::MAX => Ok(Timestamp::PosInfinity), i64::MAX => Ok(Timestamp::PosInfinity),
i64::MIN => Ok(Timestamp::NegInfinity), i64::MIN => Ok(Timestamp::NegInfinity),
@ -81,10 +74,7 @@ impl<T: FromSql> FromSql for Timestamp<T> {
} }
impl<T: ToSql> ToSql for Timestamp<T> { impl<T: ToSql> ToSql for Timestamp<T> {
fn to_sql(&self, fn to_sql(&self, ty: &Type, out: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> {
ty: &Type,
out: &mut Vec<u8>)
-> Result<IsNull, Box<Error + Sync + Send>> {
let value = match *self { let value = match *self {
Timestamp::PosInfinity => i64::MAX, Timestamp::PosInfinity => i64::MAX,
Timestamp::NegInfinity => i64::MIN, Timestamp::NegInfinity => i64::MIN,

View File

@ -13,9 +13,7 @@ const NSEC_PER_USEC: i64 = 1_000;
const TIME_SEC_CONVERSION: i64 = 946684800; const TIME_SEC_CONVERSION: i64 = 946684800;
impl FromSql for Timespec { impl FromSql for Timespec {
fn from_sql(_: &Type, fn from_sql(_: &Type, raw: &[u8]) -> Result<Timespec, Box<Error + Sync + Send>> {
raw: &[u8])
-> Result<Timespec, Box<Error + Sync + Send>> {
let t = types::timestamp_from_sql(raw)?; let t = types::timestamp_from_sql(raw)?;
let mut sec = t / USEC_PER_SEC + TIME_SEC_CONVERSION; let mut sec = t / USEC_PER_SEC + TIME_SEC_CONVERSION;
let mut usec = t % USEC_PER_SEC; let mut usec = t % USEC_PER_SEC;
@ -32,10 +30,7 @@ impl FromSql for Timespec {
} }
impl ToSql for Timespec { impl ToSql for Timespec {
fn to_sql(&self, fn to_sql(&self, _: &Type, w: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> {
_: &Type,
w: &mut Vec<u8>)
-> Result<IsNull, Box<Error + Sync + Send>> {
let t = (self.sec - TIME_SEC_CONVERSION) * USEC_PER_SEC + self.nsec as i64 / NSEC_PER_USEC; let t = (self.sec - TIME_SEC_CONVERSION) * USEC_PER_SEC + self.nsec as i64 / NSEC_PER_USEC;
types::timestamp_to_sql(t, w); types::timestamp_to_sql(t, w);
Ok(IsNull::No) Ok(IsNull::No)

View File

@ -16,10 +16,7 @@ impl FromSql for Uuid {
} }
impl ToSql for Uuid { impl ToSql for Uuid {
fn to_sql(&self, fn to_sql(&self, _: &Type, w: &mut Vec<u8>) -> Result<IsNull, Box<Error + Sync + Send>> {
_: &Type,
w: &mut Vec<u8>)
-> Result<IsNull, Box<Error + Sync + Send>> {
types::uuid_to_sql(*self.as_bytes(), w); types::uuid_to_sql(*self.as_bytes(), w);
Ok(IsNull::No) Ok(IsNull::No)
} }

View File

@ -7,7 +7,8 @@ use postgres::{Connection, TlsMode};
#[bench] #[bench]
fn bench_naiive_execute(b: &mut test::Bencher) { fn bench_naiive_execute(b: &mut test::Bencher) {
let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap(); let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap();
conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[]).unwrap(); conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[])
.unwrap();
b.iter(|| { b.iter(|| {
let stmt = conn.prepare("UPDATE foo SET id = 1").unwrap(); let stmt = conn.prepare("UPDATE foo SET id = 1").unwrap();
@ -20,9 +21,8 @@ fn bench_naiive_execute(b: &mut test::Bencher) {
#[bench] #[bench]
fn bench_execute(b: &mut test::Bencher) { fn bench_execute(b: &mut test::Bencher) {
let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap(); let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap();
conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[]).unwrap(); conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[])
.unwrap();
b.iter(|| { b.iter(|| conn.execute("UPDATE foo SET id = 1", &[]).unwrap());
conn.execute("UPDATE foo SET id = 1", &[]).unwrap()
});
} }

View File

@ -178,15 +178,17 @@ impl HandleNotice for LoggingNoticeHandler {
/// }); /// });
/// postgres::cancel_query(url, TlsMode::None, &cancel_data).unwrap(); /// postgres::cancel_query(url, TlsMode::None, &cancel_data).unwrap();
/// ``` /// ```
pub fn cancel_query<T>(params: T, pub fn cancel_query<T>(
tls: TlsMode, params: T,
data: &CancelData) tls: TlsMode,
-> result::Result<(), ConnectError> data: &CancelData,
where T: IntoConnectParams ) -> result::Result<(), ConnectError>
where
T: IntoConnectParams,
{ {
let params = params let params = params.into_connect_params().map_err(
.into_connect_params() ConnectError::ConnectParams,
.map_err(ConnectError::ConnectParams)?; )?;
let mut socket = priv_io::initialize_stream(&params, tls)?; let mut socket = priv_io::initialize_stream(&params, tls)?;
let mut buf = vec![]; let mut buf = vec![];
@ -198,13 +200,17 @@ pub fn cancel_query<T>(params: T,
} }
fn bad_response() -> io::Error { fn bad_response() -> io::Error {
io::Error::new(io::ErrorKind::InvalidInput, io::Error::new(
"the server returned an unexpected response") io::ErrorKind::InvalidInput,
"the server returned an unexpected response",
)
} }
fn desynchronized() -> io::Error { fn desynchronized() -> io::Error {
io::Error::new(io::ErrorKind::Other, io::Error::new(
"communication with the server has desynchronized due to an earlier IO error") io::ErrorKind::Other,
"communication with the server has desynchronized due to an earlier IO error",
)
} }
/// Specifies the TLS support requested for a new connection. /// Specifies the TLS support requested for a new connection.
@ -252,18 +258,20 @@ impl Drop for InnerConnection {
impl InnerConnection { impl InnerConnection {
fn connect<T>(params: T, tls: TlsMode) -> result::Result<InnerConnection, ConnectError> fn connect<T>(params: T, tls: TlsMode) -> result::Result<InnerConnection, ConnectError>
where T: IntoConnectParams where
T: IntoConnectParams,
{ {
let params = params let params = params.into_connect_params().map_err(
.into_connect_params() ConnectError::ConnectParams,
.map_err(ConnectError::ConnectParams)?; )?;
let stream = priv_io::initialize_stream(&params, tls)?; let stream = priv_io::initialize_stream(&params, tls)?;
let user = match params.user() { let user = match params.user() {
Some(user) => user, Some(user) => user,
None => { None => {
return Err(ConnectError::ConnectParams("User missing from connection parameters" return Err(ConnectError::ConnectParams(
.into())); "User missing from connection parameters".into(),
));
} }
}; };
@ -299,8 +307,9 @@ impl InnerConnection {
} }
let options = options.iter().map(|&(ref a, ref b)| (&**a, &**b)); let options = options.iter().map(|&(ref a, ref b)| (&**a, &**b));
conn.stream conn.stream.write_message(
.write_message(|buf| frontend::startup_message(options, buf))?; |buf| frontend::startup_message(options, buf),
)?;
conn.stream.flush()?; conn.stream.flush()?;
conn.handle_auth(user)?; conn.handle_auth(user)?;
@ -332,17 +341,20 @@ impl InnerConnection {
} }
} }
backend::Message::ParameterStatus(body) => { backend::Message::ParameterStatus(body) => {
self.parameters self.parameters.insert(
.insert(body.name()?.to_owned(), body.value()?.to_owned()); body.name()?.to_owned(),
body.value()?.to_owned(),
);
} }
val => return Ok(val), val => return Ok(val),
} }
} }
} }
fn read_message_with_notification_timeout(&mut self, fn read_message_with_notification_timeout(
timeout: Duration) &mut self,
-> io::Result<Option<backend::Message>> { timeout: Duration,
) -> io::Result<Option<backend::Message>> {
debug_assert!(!self.desynchronized); debug_assert!(!self.desynchronized);
loop { loop {
match try_desync!(self, self.stream.read_message_timeout(timeout)) { match try_desync!(self, self.stream.read_message_timeout(timeout)) {
@ -352,16 +364,19 @@ impl InnerConnection {
} }
} }
Some(backend::Message::ParameterStatus(body)) => { Some(backend::Message::ParameterStatus(body)) => {
self.parameters self.parameters.insert(
.insert(body.name()?.to_owned(), body.value()?.to_owned()); body.name()?.to_owned(),
body.value()?.to_owned(),
);
} }
val => return Ok(val), val => return Ok(val),
} }
} }
} }
fn read_message_with_notification_nonblocking(&mut self) fn read_message_with_notification_nonblocking(
-> io::Result<Option<backend::Message>> { &mut self,
) -> io::Result<Option<backend::Message>> {
debug_assert!(!self.desynchronized); debug_assert!(!self.desynchronized);
loop { loop {
match try_desync!(self, self.stream.read_message_nonblocking()) { match try_desync!(self, self.stream.read_message_nonblocking()) {
@ -371,8 +386,10 @@ impl InnerConnection {
} }
} }
Some(backend::Message::ParameterStatus(body)) => { Some(backend::Message::ParameterStatus(body)) => {
self.parameters self.parameters.insert(
.insert(body.name()?.to_owned(), body.value()?.to_owned()); body.name()?.to_owned(),
body.value()?.to_owned(),
);
} }
val => return Ok(val), val => return Ok(val),
} }
@ -383,12 +400,11 @@ impl InnerConnection {
loop { loop {
match self.read_message_with_notification()? { match self.read_message_with_notification()? {
backend::Message::NotificationResponse(body) => { backend::Message::NotificationResponse(body) => {
self.notifications self.notifications.push_back(Notification {
.push_back(Notification { process_id: body.process_id(),
process_id: body.process_id(), channel: body.channel()?.to_owned(),
channel: body.channel()?.to_owned(), payload: body.message()?.to_owned(),
payload: body.message()?.to_owned(), })
})
} }
val => return Ok(val), val => return Ok(val),
} }
@ -399,50 +415,46 @@ impl InnerConnection {
match self.read_message()? { match self.read_message()? {
backend::Message::AuthenticationOk => return Ok(()), backend::Message::AuthenticationOk => return Ok(()),
backend::Message::AuthenticationCleartextPassword => { backend::Message::AuthenticationCleartextPassword => {
let pass = user.password() let pass = user.password().ok_or_else(|| {
.ok_or_else(|| { ConnectError::ConnectParams("a password was requested but not provided".into())
ConnectError::ConnectParams("a password was requested but not provided" })?;
.into()) self.stream.write_message(
})?; |buf| frontend::password_message(pass, buf),
self.stream )?;
.write_message(|buf| frontend::password_message(pass, buf))?;
self.stream.flush()?; self.stream.flush()?;
} }
backend::Message::AuthenticationMd5Password(body) => { backend::Message::AuthenticationMd5Password(body) => {
let pass = user.password() let pass = user.password().ok_or_else(|| {
.ok_or_else(|| { ConnectError::ConnectParams("a password was requested but not provided".into())
ConnectError::ConnectParams("a password was requested but not provided" })?;
.into())
})?;
let output = let output =
authentication::md5_hash(user.name().as_bytes(), pass.as_bytes(), body.salt()); authentication::md5_hash(user.name().as_bytes(), pass.as_bytes(), body.salt());
self.stream self.stream.write_message(
.write_message(|buf| frontend::password_message(&output, buf))?; |buf| frontend::password_message(&output, buf),
)?;
self.stream.flush()?; self.stream.flush()?;
} }
backend::Message::AuthenticationSasl(body) => { backend::Message::AuthenticationSasl(body) => {
// count to validate the entire message body. // count to validate the entire message body.
if body.mechanisms() if body.mechanisms()
.filter(|m| *m == sasl::SCRAM_SHA_256) .filter(|m| *m == sasl::SCRAM_SHA_256)
.count()? == 0 { .count()? == 0
return Err(ConnectError::Io(io::Error::new(io::ErrorKind::Other, {
"unsupported authentication"))); return Err(ConnectError::Io(io::Error::new(
io::ErrorKind::Other,
"unsupported authentication",
)));
} }
let pass = user.password() let pass = user.password().ok_or_else(|| {
.ok_or_else(|| { ConnectError::ConnectParams("a password was requested but not provided".into())
ConnectError::ConnectParams("a password was requested but not provided" })?;
.into())
})?;
let mut scram = ScramSha256::new(pass.as_bytes())?; let mut scram = ScramSha256::new(pass.as_bytes())?;
self.stream self.stream.write_message(|buf| {
.write_message(|buf| { frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), buf)
frontend::sasl_initial_response(sasl::SCRAM_SHA_256, })?;
scram.message(),
buf)
})?;
self.stream.flush()?; self.stream.flush()?;
let body = match self.read_message()? { let body = match self.read_message()? {
@ -455,8 +467,9 @@ impl InnerConnection {
scram.update(body.data())?; scram.update(body.data())?;
self.stream self.stream.write_message(|buf| {
.write_message(|buf| frontend::sasl_response(scram.message(), buf))?; frontend::sasl_response(scram.message(), buf)
})?;
self.stream.flush()?; self.stream.flush()?;
let body = match self.read_message()? { let body = match self.read_message()? {
@ -473,8 +486,10 @@ impl InnerConnection {
backend::Message::AuthenticationScmCredential | backend::Message::AuthenticationScmCredential |
backend::Message::AuthenticationGss | backend::Message::AuthenticationGss |
backend::Message::AuthenticationSspi => { backend::Message::AuthenticationSspi => {
return Err(ConnectError::Io(io::Error::new(io::ErrorKind::Other, return Err(ConnectError::Io(io::Error::new(
"unsupported authentication"))) io::ErrorKind::Other,
"unsupported authentication",
)))
} }
backend::Message::ErrorResponse(body) => return Err(connect_err(&mut body.fields())), backend::Message::ErrorResponse(body) => return Err(connect_err(&mut body.fields())),
_ => return Err(ConnectError::Io(bad_response())), _ => return Err(ConnectError::Io(bad_response())),
@ -494,12 +509,15 @@ impl InnerConnection {
fn raw_prepare(&mut self, stmt_name: &str, query: &str) -> Result<(Vec<Type>, Vec<Column>)> { fn raw_prepare(&mut self, stmt_name: &str, query: &str) -> Result<(Vec<Type>, Vec<Column>)> {
debug!("preparing query with name `{}`: {}", stmt_name, query); debug!("preparing query with name `{}`: {}", stmt_name, query);
self.stream self.stream.write_message(|buf| {
.write_message(|buf| frontend::parse(stmt_name, query, None, buf))?; frontend::parse(stmt_name, query, None, buf)
self.stream })?;
.write_message(|buf| frontend::describe(b'S', stmt_name, buf))?; self.stream.write_message(
self.stream |buf| frontend::describe(b'S', stmt_name, buf),
.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?; )?;
self.stream.write_message(
|buf| Ok::<(), io::Error>(frontend::sync(buf)),
)?;
self.stream.flush()?; self.stream.flush()?;
match self.read_message()? { match self.read_message()? {
@ -534,9 +552,11 @@ impl InnerConnection {
Some(body) => { Some(body) => {
body.fields() body.fields()
.and_then(|field| { .and_then(|field| {
Ok(Column::new(field.name().to_owned(), Ok(Column::new(
self.get_type(field.type_oid())?)) field.name().to_owned(),
}) self.get_type(field.type_oid())?,
))
})
.collect()? .collect()?
} }
None => vec![], None => vec![],
@ -546,7 +566,8 @@ impl InnerConnection {
} }
fn read_rows<F>(&mut self, mut consumer: F) -> Result<bool> fn read_rows<F>(&mut self, mut consumer: F) -> Result<bool>
where F: FnMut(RowData) where
F: FnMut(RowData),
{ {
let more_rows; let more_rows;
loop { loop {
@ -566,12 +587,12 @@ impl InnerConnection {
return Err(err(&mut body.fields())); return Err(err(&mut body.fields()));
} }
backend::Message::CopyInResponse(_) => { backend::Message::CopyInResponse(_) => {
self.stream self.stream.write_message(|buf| {
.write_message(|buf| { frontend::copy_fail("COPY queries cannot be directly executed", buf)
frontend::copy_fail("COPY queries cannot be directly executed", buf) })?;
})?; self.stream.write_message(
self.stream |buf| Ok::<(), io::Error>(frontend::sync(buf)),
.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?; )?;
self.stream.flush()?; self.stream.flush()?;
} }
backend::Message::CopyOutResponse(_) => { backend::Message::CopyOutResponse(_) => {
@ -580,9 +601,11 @@ impl InnerConnection {
break; break;
} }
} }
return Err(Error::Io(io::Error::new(io::ErrorKind::InvalidInput, return Err(Error::Io(io::Error::new(
"COPY queries cannot be directly \ io::ErrorKind::InvalidInput,
executed"))); "COPY queries cannot be directly \
executed",
)));
} }
_ => { _ => {
self.desynchronized = true; self.desynchronized = true;
@ -594,36 +617,42 @@ impl InnerConnection {
Ok(more_rows) Ok(more_rows)
} }
fn raw_execute(&mut self, fn raw_execute(
stmt_name: &str, &mut self,
portal_name: &str, stmt_name: &str,
row_limit: i32, portal_name: &str,
param_types: &[Type], row_limit: i32,
params: &[&ToSql]) param_types: &[Type],
-> Result<()> { params: &[&ToSql],
assert!(param_types.len() == params.len(), ) -> Result<()> {
"expected {} parameters but got {}", assert!(
param_types.len(), param_types.len() == params.len(),
params.len()); "expected {} parameters but got {}",
debug!("executing statement {} with parameters: {:?}", param_types.len(),
stmt_name, params.len()
params); );
debug!(
"executing statement {} with parameters: {:?}",
stmt_name,
params
);
{ {
let r = self.stream let r = self.stream.write_message(|buf| {
.write_message(|buf| { frontend::bind(
frontend::bind(portal_name, portal_name,
stmt_name, stmt_name,
Some(1), Some(1),
params.iter().zip(param_types), params.iter().zip(param_types),
|(param, ty), buf| match param.to_sql_checked(ty, buf) { |(param, ty), buf| match param.to_sql_checked(ty, buf) {
Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes), Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No), Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
Err(e) => Err(e), Err(e) => Err(e),
}, },
Some(1), Some(1),
buf) buf,
}); )
});
match r { match r {
Ok(()) => {} Ok(()) => {}
Err(frontend::BindError::Conversion(e)) => return Err(Error::Conversion(e)), Err(frontend::BindError::Conversion(e)) => return Err(Error::Conversion(e)),
@ -631,10 +660,12 @@ impl InnerConnection {
} }
} }
self.stream self.stream.write_message(|buf| {
.write_message(|buf| frontend::execute(portal_name, row_limit, buf))?; frontend::execute(portal_name, row_limit, buf)
self.stream })?;
.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?; self.stream.write_message(
|buf| Ok::<(), io::Error>(frontend::sync(buf)),
)?;
self.stream.flush()?; self.stream.flush()?;
match self.read_message()? { match self.read_message()? {
@ -660,10 +691,10 @@ impl InnerConnection {
let stmt_name = self.make_stmt_name(); let stmt_name = self.make_stmt_name();
let (param_types, columns) = self.raw_prepare(&stmt_name, query)?; let (param_types, columns) = self.raw_prepare(&stmt_name, query)?;
let info = Arc::new(StatementInfo { let info = Arc::new(StatementInfo {
name: stmt_name, name: stmt_name,
param_types: param_types, param_types: param_types,
columns: columns, columns: columns,
}); });
Ok(Statement::new(conn, info, Cell::new(0), false)) Ok(Statement::new(conn, info, Cell::new(0), false))
} }
@ -676,12 +707,14 @@ impl InnerConnection {
let stmt_name = self.make_stmt_name(); let stmt_name = self.make_stmt_name();
let (param_types, columns) = self.raw_prepare(&stmt_name, query)?; let (param_types, columns) = self.raw_prepare(&stmt_name, query)?;
let info = Arc::new(StatementInfo { let info = Arc::new(StatementInfo {
name: stmt_name, name: stmt_name,
param_types: param_types, param_types: param_types,
columns: columns, columns: columns,
}); });
self.cached_statements self.cached_statements.insert(
.insert(query.to_owned(), info.clone()); query.to_owned(),
info.clone(),
);
info info
} }
}; };
@ -690,10 +723,12 @@ impl InnerConnection {
} }
fn close_statement(&mut self, name: &str, type_: u8) -> Result<()> { fn close_statement(&mut self, name: &str, type_: u8) -> Result<()> {
self.stream self.stream.write_message(
.write_message(|buf| frontend::close(type_, name, buf))?; |buf| frontend::close(type_, name, buf),
self.stream )?;
.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?; self.stream.write_message(
|buf| Ok::<(), io::Error>(frontend::sync(buf)),
)?;
self.stream.flush()?; self.stream.flush()?;
let resp = match self.read_message()? { let resp = match self.read_message()? {
backend::Message::CloseComplete => Ok(()), backend::Message::CloseComplete => Ok(()),
@ -723,25 +758,29 @@ impl InnerConnection {
return Ok(()); return Ok(());
} }
match self.raw_prepare(TYPEINFO_QUERY, match self.raw_prepare(
"SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, \ TYPEINFO_QUERY,
"SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, \
t.typbasetype, n.nspname, t.typrelid \ t.typbasetype, n.nspname, t.typrelid \
FROM pg_catalog.pg_type t \ FROM pg_catalog.pg_type t \
LEFT OUTER JOIN pg_catalog.pg_range r ON \ LEFT OUTER JOIN pg_catalog.pg_range r ON \
r.rngtypid = t.oid \ r.rngtypid = t.oid \
INNER JOIN pg_catalog.pg_namespace n ON \ INNER JOIN pg_catalog.pg_namespace n ON \
t.typnamespace = n.oid \ t.typnamespace = n.oid \
WHERE t.oid = $1") { WHERE t.oid = $1",
) {
Ok(..) => {} Ok(..) => {}
// Range types weren't added until Postgres 9.2, so pg_range may not exist // Range types weren't added until Postgres 9.2, so pg_range may not exist
Err(Error::Db(ref e)) if e.code == SqlState::UndefinedTable => { Err(Error::Db(ref e)) if e.code == SqlState::UndefinedTable => {
self.raw_prepare(TYPEINFO_QUERY, self.raw_prepare(
"SELECT t.typname, t.typtype, t.typelem, NULL::OID, \ TYPEINFO_QUERY,
"SELECT t.typname, t.typtype, t.typelem, NULL::OID, \
t.typbasetype, n.nspname, t.typrelid \ t.typbasetype, n.nspname, t.typrelid \
FROM pg_catalog.pg_type t \ FROM pg_catalog.pg_type t \
INNER JOIN pg_catalog.pg_namespace n \ INNER JOIN pg_catalog.pg_namespace n \
ON t.typnamespace = n.oid \ ON t.typnamespace = n.oid \
WHERE t.oid = $1")?; WHERE t.oid = $1",
)?;
} }
Err(e) => return Err(e), Err(e) => return Err(e),
} }
@ -753,27 +792,39 @@ impl InnerConnection {
#[allow(if_not_else)] #[allow(if_not_else)]
fn read_type(&mut self, oid: Oid) -> Result<Other> { fn read_type(&mut self, oid: Oid) -> Result<Other> {
self.setup_typeinfo_query()?; self.setup_typeinfo_query()?;
self.raw_execute(TYPEINFO_QUERY, "", 0, &[Type::Oid], &[&oid])?; self.raw_execute(
TYPEINFO_QUERY,
"",
0,
&[Type::Oid],
&[&oid],
)?;
let mut row = None; let mut row = None;
self.read_rows(|r| row = Some(r))?; self.read_rows(|r| row = Some(r))?;
let get_raw = |i: usize| row.as_ref().and_then(|r| r.get(i)); let get_raw = |i: usize| row.as_ref().and_then(|r| r.get(i));
let (name, type_, elem_oid, rngsubtype, basetype, schema, relid) = { let (name, type_, elem_oid, rngsubtype, basetype, schema, relid) = {
let name = String::from_sql_nullable(&Type::Name, get_raw(0)) let name = String::from_sql_nullable(&Type::Name, get_raw(0)).map_err(
.map_err(Error::Conversion)?; Error::Conversion,
let type_ = i8::from_sql_nullable(&Type::Char, get_raw(1)) )?;
.map_err(Error::Conversion)?; let type_ = i8::from_sql_nullable(&Type::Char, get_raw(1)).map_err(
let elem_oid = Oid::from_sql_nullable(&Type::Oid, get_raw(2)) Error::Conversion,
.map_err(Error::Conversion)?; )?;
let elem_oid = Oid::from_sql_nullable(&Type::Oid, get_raw(2)).map_err(
Error::Conversion,
)?;
let rngsubtype = Option::<Oid>::from_sql_nullable(&Type::Oid, get_raw(3)) let rngsubtype = Option::<Oid>::from_sql_nullable(&Type::Oid, get_raw(3))
.map_err(Error::Conversion)?; .map_err(Error::Conversion)?;
let basetype = Oid::from_sql_nullable(&Type::Oid, get_raw(4)) let basetype = Oid::from_sql_nullable(&Type::Oid, get_raw(4)).map_err(
.map_err(Error::Conversion)?; Error::Conversion,
let schema = String::from_sql_nullable(&Type::Name, get_raw(5)) )?;
.map_err(Error::Conversion)?; let schema = String::from_sql_nullable(&Type::Name, get_raw(5)).map_err(
let relid = Oid::from_sql_nullable(&Type::Oid, get_raw(6)) Error::Conversion,
.map_err(Error::Conversion)?; )?;
let relid = Oid::from_sql_nullable(&Type::Oid, get_raw(6)).map_err(
Error::Conversion,
)?;
(name, type_, elem_oid, rngsubtype, basetype, schema, relid) (name, type_, elem_oid, rngsubtype, basetype, schema, relid)
}; };
@ -802,19 +853,23 @@ impl InnerConnection {
return Ok(()); return Ok(());
} }
match self.raw_prepare(TYPEINFO_ENUM_QUERY, match self.raw_prepare(
"SELECT enumlabel \ TYPEINFO_ENUM_QUERY,
"SELECT enumlabel \
FROM pg_catalog.pg_enum \ FROM pg_catalog.pg_enum \
WHERE enumtypid = $1 \ WHERE enumtypid = $1 \
ORDER BY enumsortorder") { ORDER BY enumsortorder",
) {
Ok(..) => {} Ok(..) => {}
// Postgres 9.0 doesn't have enumsortorder // Postgres 9.0 doesn't have enumsortorder
Err(Error::Db(ref e)) if e.code == SqlState::UndefinedColumn => { Err(Error::Db(ref e)) if e.code == SqlState::UndefinedColumn => {
self.raw_prepare(TYPEINFO_ENUM_QUERY, self.raw_prepare(
"SELECT enumlabel \ TYPEINFO_ENUM_QUERY,
"SELECT enumlabel \
FROM pg_catalog.pg_enum \ FROM pg_catalog.pg_enum \
WHERE enumtypid = $1 \ WHERE enumtypid = $1 \
ORDER BY oid")?; ORDER BY oid",
)?;
} }
Err(e) => return Err(e), Err(e) => return Err(e),
} }
@ -825,14 +880,21 @@ impl InnerConnection {
fn read_enum_variants(&mut self, oid: Oid) -> Result<Vec<String>> { fn read_enum_variants(&mut self, oid: Oid) -> Result<Vec<String>> {
self.setup_typeinfo_enum_query()?; self.setup_typeinfo_enum_query()?;
self.raw_execute(TYPEINFO_ENUM_QUERY, "", 0, &[Type::Oid], &[&oid])?; self.raw_execute(
TYPEINFO_ENUM_QUERY,
"",
0,
&[Type::Oid],
&[&oid],
)?;
let mut rows = vec![]; let mut rows = vec![];
self.read_rows(|row| rows.push(row))?; self.read_rows(|row| rows.push(row))?;
let mut variants = vec![]; let mut variants = vec![];
for row in rows { for row in rows {
variants.push(String::from_sql_nullable(&Type::Name, row.get(0)) variants.push(String::from_sql_nullable(&Type::Name, row.get(0)).map_err(
.map_err(Error::Conversion)?); Error::Conversion,
)?);
} }
Ok(variants) Ok(variants)
@ -843,13 +905,15 @@ impl InnerConnection {
return Ok(()); return Ok(());
} }
self.raw_prepare(TYPEINFO_COMPOSITE_QUERY, self.raw_prepare(
"SELECT attname, atttypid \ TYPEINFO_COMPOSITE_QUERY,
"SELECT attname, atttypid \
FROM pg_catalog.pg_attribute \ FROM pg_catalog.pg_attribute \
WHERE attrelid = $1 \ WHERE attrelid = $1 \
AND NOT attisdropped \ AND NOT attisdropped \
AND attnum > 0 \ AND attnum > 0 \
ORDER BY attnum")?; ORDER BY attnum",
)?;
self.has_typeinfo_composite_query = true; self.has_typeinfo_composite_query = true;
Ok(()) Ok(())
@ -857,17 +921,25 @@ impl InnerConnection {
fn read_composite_fields(&mut self, relid: Oid) -> Result<Vec<Field>> { fn read_composite_fields(&mut self, relid: Oid) -> Result<Vec<Field>> {
self.setup_typeinfo_composite_query()?; self.setup_typeinfo_composite_query()?;
self.raw_execute(TYPEINFO_COMPOSITE_QUERY, "", 0, &[Type::Oid], &[&relid])?; self.raw_execute(
TYPEINFO_COMPOSITE_QUERY,
"",
0,
&[Type::Oid],
&[&relid],
)?;
let mut rows = vec![]; let mut rows = vec![];
self.read_rows(|row| rows.push(row))?; self.read_rows(|row| rows.push(row))?;
let mut fields = vec![]; let mut fields = vec![];
for row in rows { for row in rows {
let (name, type_) = { let (name, type_) = {
let name = String::from_sql_nullable(&Type::Name, row.get(0)) let name = String::from_sql_nullable(&Type::Name, row.get(0)).map_err(
.map_err(Error::Conversion)?; Error::Conversion,
let type_ = Oid::from_sql_nullable(&Type::Oid, row.get(1)) )?;
.map_err(Error::Conversion)?; let type_ = Oid::from_sql_nullable(&Type::Oid, row.get(1)).map_err(
Error::Conversion,
)?;
(name, type_) (name, type_)
}; };
let type_ = self.get_type(type_)?; let type_ = self.get_type(type_)?;
@ -892,8 +964,7 @@ impl InnerConnection {
fn quick_query(&mut self, query: &str) -> Result<Vec<Vec<Option<String>>>> { fn quick_query(&mut self, query: &str) -> Result<Vec<Vec<Option<String>>>> {
check_desync!(self); check_desync!(self);
debug!("executing query: {}", query); debug!("executing query: {}", query);
self.stream self.stream.write_message(|buf| frontend::query(query, buf))?;
.write_message(|buf| frontend::query(query, buf))?;
self.stream.flush()?; self.stream.flush()?;
let mut result = vec![]; let mut result = vec![];
@ -903,18 +974,18 @@ impl InnerConnection {
backend::Message::DataRow(body) => { backend::Message::DataRow(body) => {
let row = body.ranges() let row = body.ranges()
.map(|r| { .map(|r| {
r.map(|r| String::from_utf8_lossy(&body.buffer()[r]).into_owned()) r.map(|r| String::from_utf8_lossy(&body.buffer()[r]).into_owned())
}) })
.collect()?; .collect()?;
result.push(row); result.push(row);
} }
backend::Message::CopyInResponse(_) => { backend::Message::CopyInResponse(_) => {
self.stream self.stream.write_message(|buf| {
.write_message(|buf| { frontend::copy_fail("COPY queries cannot be directly executed", buf)
frontend::copy_fail("COPY queries cannot be directly executed", buf) })?;
})?; self.stream.write_message(
self.stream |buf| Ok::<(), io::Error>(frontend::sync(buf)),
.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?; )?;
self.stream.flush()?; self.stream.flush()?;
} }
backend::Message::ErrorResponse(body) => { backend::Message::ErrorResponse(body) => {
@ -929,8 +1000,9 @@ impl InnerConnection {
fn finish_inner(&mut self) -> Result<()> { fn finish_inner(&mut self) -> Result<()> {
check_desync!(self); check_desync!(self);
self.stream self.stream.write_message(|buf| {
.write_message(|buf| Ok::<(), io::Error>(frontend::terminate(buf)))?; Ok::<(), io::Error>(frontend::terminate(buf))
})?;
self.stream.flush()?; self.stream.flush()?;
Ok(()) Ok(())
} }
@ -1015,7 +1087,8 @@ impl Connection {
/// # } /// # }
/// ``` /// ```
pub fn connect<T>(params: T, tls: TlsMode) -> result::Result<Connection, ConnectError> pub fn connect<T>(params: T, tls: TlsMode) -> result::Result<Connection, ConnectError>
where T: IntoConnectParams where
T: IntoConnectParams,
{ {
InnerConnection::connect(params, tls).map(|conn| Connection(RefCell::new(conn))) InnerConnection::connect(params, tls).map(|conn| Connection(RefCell::new(conn)))
} }
@ -1050,10 +1123,10 @@ impl Connection {
pub fn execute(&self, query: &str, params: &[&ToSql]) -> Result<u64> { pub fn execute(&self, query: &str, params: &[&ToSql]) -> Result<u64> {
let (param_types, columns) = self.0.borrow_mut().raw_prepare("", query)?; let (param_types, columns) = self.0.borrow_mut().raw_prepare("", query)?;
let info = Arc::new(StatementInfo { let info = Arc::new(StatementInfo {
name: String::new(), name: String::new(),
param_types: param_types, param_types: param_types,
columns: columns, columns: columns,
}); });
let stmt = Statement::new(self, info, Cell::new(0), true); let stmt = Statement::new(self, info, Cell::new(0), true);
stmt.execute(params) stmt.execute(params)
} }
@ -1086,10 +1159,10 @@ impl Connection {
pub fn query(&self, query: &str, params: &[&ToSql]) -> Result<Rows<'static>> { pub fn query(&self, query: &str, params: &[&ToSql]) -> Result<Rows<'static>> {
let (param_types, columns) = self.0.borrow_mut().raw_prepare("", query)?; let (param_types, columns) = self.0.borrow_mut().raw_prepare("", query)?;
let info = Arc::new(StatementInfo { let info = Arc::new(StatementInfo {
name: String::new(), name: String::new(),
param_types: param_types, param_types: param_types,
columns: columns, columns: columns,
}); });
let stmt = Statement::new(self, info, Cell::new(0), true); let stmt = Statement::new(self, info, Cell::new(0), true);
stmt.into_query(params) stmt.into_query(params)
} }
@ -1127,8 +1200,10 @@ impl Connection {
pub fn transaction_with<'a>(&'a self, config: &transaction::Config) -> Result<Transaction<'a>> { pub fn transaction_with<'a>(&'a self, config: &transaction::Config) -> Result<Transaction<'a>> {
let mut conn = self.0.borrow_mut(); let mut conn = self.0.borrow_mut();
check_desync!(conn); check_desync!(conn);
assert!(conn.trans_depth == 0, assert!(
"`transaction` must be called on the active transaction"); conn.trans_depth == 0,
"`transaction` must be called on the active transaction"
);
let mut query = "BEGIN".to_owned(); let mut query = "BEGIN".to_owned();
config.build_command(&mut query); config.build_command(&mut query);
conn.quick_query(&query)?; conn.quick_query(&query)?;
@ -1402,22 +1477,24 @@ trait RowsNew {
} }
trait LazyRowsNew<'trans, 'stmt> { trait LazyRowsNew<'trans, 'stmt> {
fn new(stmt: &'stmt Statement<'stmt>, fn new(
data: VecDeque<RowData>, stmt: &'stmt Statement<'stmt>,
name: String, data: VecDeque<RowData>,
row_limit: i32, name: String,
more_rows: bool, row_limit: i32,
finished: bool, more_rows: bool,
trans: &'trans Transaction<'trans>) finished: bool,
-> LazyRows<'trans, 'stmt>; trans: &'trans Transaction<'trans>,
) -> LazyRows<'trans, 'stmt>;
} }
trait StatementInternals<'conn> { trait StatementInternals<'conn> {
fn new(conn: &'conn Connection, fn new(
info: Arc<StatementInfo>, conn: &'conn Connection,
next_portal_id: Cell<u32>, info: Arc<StatementInfo>,
finished: bool) next_portal_id: Cell<u32>,
-> Statement<'conn>; finished: bool,
) -> Statement<'conn>;
fn conn(&self) -> &'conn Connection; fn conn(&self) -> &'conn Connection;

View File

@ -39,8 +39,9 @@ impl MessageStream {
} }
pub fn write_message<F, E>(&mut self, f: F) -> Result<(), E> pub fn write_message<F, E>(&mut self, f: F) -> Result<(), E>
where F: FnOnce(&mut Vec<u8>) -> Result<(), E>, where
E: From<io::Error> F: FnOnce(&mut Vec<u8>) -> Result<(), E>,
E: From<io::Error>,
{ {
self.out_buf.clear(); self.out_buf.clear();
f(&mut self.out_buf)?; f(&mut self.out_buf)?;
@ -59,8 +60,13 @@ impl MessageStream {
fn read_in(&mut self) -> io::Result<()> { fn read_in(&mut self) -> io::Result<()> {
self.in_buf.reserve(1); self.in_buf.reserve(1);
match self.stream.get_mut().read(unsafe { self.in_buf.bytes_mut() }) { match self.stream.get_mut().read(
Ok(0) => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")), unsafe { self.in_buf.bytes_mut() },
) {
Ok(0) => Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
)),
Ok(n) => { Ok(n) => {
unsafe { self.in_buf.advance_mut(n) }; unsafe { self.in_buf.advance_mut(n) };
Ok(()) Ok(())
@ -69,18 +75,20 @@ impl MessageStream {
} }
} }
pub fn read_message_timeout(&mut self, pub fn read_message_timeout(
timeout: Duration) &mut self,
-> io::Result<Option<backend::Message>> { timeout: Duration,
) -> io::Result<Option<backend::Message>> {
if self.in_buf.is_empty() { if self.in_buf.is_empty() {
self.set_read_timeout(Some(timeout))?; self.set_read_timeout(Some(timeout))?;
let r = self.read_in(); let r = self.read_in();
self.set_read_timeout(None)?; self.set_read_timeout(None)?;
match r { match r {
Ok(()) => {}, Ok(()) => {}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock || Err(ref e)
e.kind() == io::ErrorKind::TimedOut => return Ok(None), if e.kind() == io::ErrorKind::WouldBlock ||
e.kind() == io::ErrorKind::TimedOut => return Ok(None),
Err(e) => return Err(e), Err(e) => return Err(e),
} }
} }
@ -95,7 +103,7 @@ impl MessageStream {
self.set_nonblocking(false)?; self.set_nonblocking(false)?;
match r { match r {
Ok(()) => {}, Ok(()) => {}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
Err(e) => return Err(e), Err(e) => return Err(e),
} }
@ -225,7 +233,9 @@ fn open_socket(params: &ConnectParams) -> Result<InternalStream, ConnectError> {
let port = params.port(); let port = params.port();
match *params.host() { match *params.host() {
Host::Tcp(ref host) => { Host::Tcp(ref host) => {
Ok(TcpStream::connect(&(&**host, port)).map(InternalStream::Tcp)?) Ok(TcpStream::connect(&(&**host, port)).map(
InternalStream::Tcp,
)?)
} }
#[cfg(unix)] #[cfg(unix)]
Host::Unix(ref path) => { Host::Unix(ref path) => {
@ -234,15 +244,18 @@ fn open_socket(params: &ConnectParams) -> Result<InternalStream, ConnectError> {
} }
#[cfg(not(unix))] #[cfg(not(unix))]
Host::Unix(..) => { Host::Unix(..) => {
Err(ConnectError::Io(io::Error::new(io::ErrorKind::InvalidInput, Err(ConnectError::Io(io::Error::new(
"unix sockets are not supported on this system"))) io::ErrorKind::InvalidInput,
"unix sockets are not supported on this system",
)))
} }
} }
} }
pub fn initialize_stream(params: &ConnectParams, pub fn initialize_stream(
tls: TlsMode) params: &ConnectParams,
-> Result<Box<TlsStream>, ConnectError> { tls: TlsMode,
) -> Result<Box<TlsStream>, ConnectError> {
let mut socket = Stream(open_socket(params)?); let mut socket = Stream(open_socket(params)?);
let (tls_required, handshaker) = match tls { let (tls_required, handshaker) = match tls {
@ -272,5 +285,7 @@ pub fn initialize_stream(params: &ConnectParams,
Host::Unix(_) => return Err(ConnectError::Io(::bad_response())), Host::Unix(_) => return Err(ConnectError::Io(::bad_response())),
}; };
handshaker.tls_handshake(host, socket).map_err(ConnectError::Tls) handshaker.tls_handshake(host, socket).map_err(
ConnectError::Tls,
)
} }

View File

@ -39,7 +39,7 @@ impl<'a, T> Deref for MaybeOwned<'a, T> {
pub struct Rows<'compat> { pub struct Rows<'compat> {
stmt_info: Arc<StatementInfo>, stmt_info: Arc<StatementInfo>,
data: Vec<RowData>, data: Vec<RowData>,
_marker: PhantomData<&'compat u8> _marker: PhantomData<&'compat u8>,
} }
impl RowsNew for Rows<'static> { impl RowsNew for Rows<'static> {
@ -196,8 +196,9 @@ impl<'a> Row<'a> {
/// } /// }
/// ``` /// ```
pub fn get<I, T>(&self, idx: I) -> T pub fn get<I, T>(&self, idx: I) -> T
where I: RowIndex + fmt::Debug, where
T: FromSql I: RowIndex + fmt::Debug,
T: FromSql,
{ {
match self.get_inner(&idx) { match self.get_inner(&idx) {
Some(Ok(ok)) => ok, Some(Ok(ok)) => ok,
@ -215,15 +216,17 @@ impl<'a> Row<'a> {
/// if there was an error converting the result value, and `Some(Ok(..))` /// if there was an error converting the result value, and `Some(Ok(..))`
/// on success. /// on success.
pub fn get_opt<I, T>(&self, idx: I) -> Option<Result<T>> pub fn get_opt<I, T>(&self, idx: I) -> Option<Result<T>>
where I: RowIndex, where
T: FromSql I: RowIndex,
T: FromSql,
{ {
self.get_inner(&idx) self.get_inner(&idx)
} }
fn get_inner<I, T>(&self, idx: &I) -> Option<Result<T>> fn get_inner<I, T>(&self, idx: &I) -> Option<Result<T>>
where I: RowIndex, where
T: FromSql I: RowIndex,
T: FromSql,
{ {
let idx = match idx.idx(&self.stmt_info.columns) { let idx = match idx.idx(&self.stmt_info.columns) {
Some(idx) => idx, Some(idx) => idx,
@ -244,7 +247,8 @@ impl<'a> Row<'a> {
/// ///
/// Panics if the index does not reference a column. /// Panics if the index does not reference a column.
pub fn get_bytes<I>(&self, idx: I) -> Option<&[u8]> pub fn get_bytes<I>(&self, idx: I) -> Option<&[u8]>
where I: RowIndex + fmt::Debug where
I: RowIndex + fmt::Debug,
{ {
match idx.idx(&self.stmt_info.columns) { match idx.idx(&self.stmt_info.columns) {
Some(idx) => self.data.get(idx), Some(idx) => self.data.get(idx),
@ -281,7 +285,9 @@ impl<'a> RowIndex for &'a str {
// FIXME ASCII-only case insensitivity isn't really the right thing to // FIXME ASCII-only case insensitivity isn't really the right thing to
// do. Postgres itself uses a dubious wrapper around tolower and JDBC // do. Postgres itself uses a dubious wrapper around tolower and JDBC
// uses the US locale. // uses the US locale.
columns.iter().position(|d| d.name().eq_ignore_ascii_case(*self)) columns.iter().position(
|d| d.name().eq_ignore_ascii_case(*self),
)
} }
} }
@ -297,14 +303,15 @@ pub struct LazyRows<'trans, 'stmt> {
} }
impl<'trans, 'stmt> LazyRowsNew<'trans, 'stmt> for LazyRows<'trans, 'stmt> { impl<'trans, 'stmt> LazyRowsNew<'trans, 'stmt> for LazyRows<'trans, 'stmt> {
fn new(stmt: &'stmt Statement<'stmt>, fn new(
data: VecDeque<RowData>, stmt: &'stmt Statement<'stmt>,
name: String, data: VecDeque<RowData>,
row_limit: i32, name: String,
more_rows: bool, row_limit: i32,
finished: bool, more_rows: bool,
trans: &'trans Transaction<'trans>) finished: bool,
-> LazyRows<'trans, 'stmt> { trans: &'trans Transaction<'trans>,
) -> LazyRows<'trans, 'stmt> {
LazyRows { LazyRows {
stmt: stmt, stmt: stmt,
data: data, data: data,
@ -346,10 +353,18 @@ impl<'trans, 'stmt> LazyRows<'trans, 'stmt> {
fn execute(&mut self) -> Result<()> { fn execute(&mut self) -> Result<()> {
let mut conn = self.stmt.conn().0.borrow_mut(); let mut conn = self.stmt.conn().0.borrow_mut();
conn.stream.write_message(|buf| frontend::execute(&self.name, self.row_limit, buf))?; conn.stream.write_message(|buf| {
conn.stream.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?; frontend::execute(&self.name, self.row_limit, buf)
})?;
conn.stream.write_message(
|buf| Ok::<(), io::Error>(frontend::sync(buf)),
)?;
conn.stream.flush()?; conn.stream.flush()?;
conn.read_rows(|row| self.data.push_back(row)).map(|more_rows| self.more_rows = more_rows) conn.read_rows(|row| self.data.push_back(row)).map(
|more_rows| {
self.more_rows = more_rows
},
)
} }
/// Returns a slice describing the columns of the `LazyRows`. /// Returns a slice describing the columns of the `LazyRows`.
@ -375,14 +390,12 @@ impl<'trans, 'stmt> FallibleIterator for LazyRows<'trans, 'stmt> {
self.execute()?; self.execute()?;
} }
let row = self.data let row = self.data.pop_front().map(|r| {
.pop_front() Row {
.map(|r| { stmt_info: &**self.stmt.info(),
Row { data: MaybeOwned::Owned(r),
stmt_info: &**self.stmt.info(), }
data: MaybeOwned::Owned(r), });
}
});
Ok(row) Ok(row)
} }

View File

@ -37,11 +37,12 @@ impl<'conn> Drop for Statement<'conn> {
} }
impl<'conn> StatementInternals<'conn> for Statement<'conn> { impl<'conn> StatementInternals<'conn> for Statement<'conn> {
fn new(conn: &'conn Connection, fn new(
info: Arc<StatementInfo>, conn: &'conn Connection,
next_portal_id: Cell<u32>, info: Arc<StatementInfo>,
finished: bool) next_portal_id: Cell<u32>,
-> Statement<'conn> { finished: bool,
) -> Statement<'conn> {
Statement { Statement {
conn: conn, conn: conn,
info: info, info: info,
@ -79,21 +80,25 @@ impl<'conn> Statement<'conn> {
} }
#[allow(type_complexity)] #[allow(type_complexity)]
fn inner_query<F>(&self, fn inner_query<F>(
portal_name: &str, &self,
row_limit: i32, portal_name: &str,
params: &[&ToSql], row_limit: i32,
acceptor: F) params: &[&ToSql],
-> Result<bool> acceptor: F,
where F: FnMut(RowData) ) -> Result<bool>
where
F: FnMut(RowData),
{ {
let mut conn = self.conn.0.borrow_mut(); let mut conn = self.conn.0.borrow_mut();
conn.raw_execute(&self.info.name, conn.raw_execute(
portal_name, &self.info.name,
row_limit, portal_name,
self.param_types(), row_limit,
params)?; self.param_types(),
params,
)?;
conn.read_rows(acceptor) conn.read_rows(acceptor)
} }
@ -131,7 +136,13 @@ impl<'conn> Statement<'conn> {
pub fn execute(&self, params: &[&ToSql]) -> Result<u64> { pub fn execute(&self, params: &[&ToSql]) -> Result<u64> {
let mut conn = self.conn.0.borrow_mut(); let mut conn = self.conn.0.borrow_mut();
check_desync!(conn); check_desync!(conn);
conn.raw_execute(&self.info.name, "", 0, self.param_types(), params)?; conn.raw_execute(
&self.info.name,
"",
0,
self.param_types(),
params,
)?;
let num; let num;
loop { loop {
@ -153,8 +164,9 @@ impl<'conn> Statement<'conn> {
conn.stream.write_message(|buf| { conn.stream.write_message(|buf| {
frontend::copy_fail("COPY queries cannot be directly executed", buf) frontend::copy_fail("COPY queries cannot be directly executed", buf)
})?; })?;
conn.stream conn.stream.write_message(
.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?; |buf| Ok::<(), io::Error>(frontend::sync(buf)),
)?;
conn.stream.flush()?; conn.stream.flush()?;
} }
backend::Message::CopyOutResponse(_) => { backend::Message::CopyOutResponse(_) => {
@ -226,18 +238,23 @@ impl<'conn> Statement<'conn> {
/// `Connection` as this `Statement`, if the `Transaction` is not /// `Connection` as this `Statement`, if the `Transaction` is not
/// active, or if the number of parameters provided does not match the /// active, or if the number of parameters provided does not match the
/// number of parameters expected. /// number of parameters expected.
pub fn lazy_query<'trans, 'stmt>(&'stmt self, pub fn lazy_query<'trans, 'stmt>(
trans: &'trans Transaction, &'stmt self,
params: &[&ToSql], trans: &'trans Transaction,
row_limit: i32) params: &[&ToSql],
-> Result<LazyRows<'trans, 'stmt>> { row_limit: i32,
assert!(self.conn as *const _ == trans.conn() as *const _, ) -> Result<LazyRows<'trans, 'stmt>> {
"the `Transaction` passed to `lazy_query` must be associated with the same \ assert!(
`Connection` as the `Statement`"); self.conn as *const _ == trans.conn() as *const _,
"the `Transaction` passed to `lazy_query` must be associated with the same \
`Connection` as the `Statement`"
);
let conn = self.conn.0.borrow(); let conn = self.conn.0.borrow();
check_desync!(conn); check_desync!(conn);
assert!(conn.trans_depth == trans.depth(), assert!(
"`lazy_query` must be passed the active transaction"); conn.trans_depth == trans.depth(),
"`lazy_query` must be passed the active transaction"
);
drop(conn); drop(conn);
let id = self.next_portal_id.get(); let id = self.next_portal_id.get();
@ -245,11 +262,21 @@ impl<'conn> Statement<'conn> {
let portal_name = format!("{}p{}", self.info.name, id); let portal_name = format!("{}p{}", self.info.name, id);
let mut rows = VecDeque::new(); let mut rows = VecDeque::new();
let more_rows = self.inner_query(&portal_name, let more_rows = self.inner_query(
row_limit, &portal_name,
params, row_limit,
|row| rows.push_back(row))?; params,
Ok(LazyRows::new(self, rows, portal_name, row_limit, more_rows, false, trans)) |row| rows.push_back(row),
)?;
Ok(LazyRows::new(
self,
rows,
portal_name,
row_limit,
more_rows,
false,
trans,
))
} }
/// Executes a `COPY FROM STDIN` statement, returning the number of rows /// Executes a `COPY FROM STDIN` statement, returning the number of rows
@ -275,14 +302,18 @@ impl<'conn> Statement<'conn> {
/// ``` /// ```
pub fn copy_in<R: ReadWithInfo>(&self, params: &[&ToSql], r: &mut R) -> Result<u64> { pub fn copy_in<R: ReadWithInfo>(&self, params: &[&ToSql], r: &mut R) -> Result<u64> {
let mut conn = self.conn.0.borrow_mut(); let mut conn = self.conn.0.borrow_mut();
conn.raw_execute(&self.info.name, "", 0, self.param_types(), params)?; conn.raw_execute(
&self.info.name,
"",
0,
self.param_types(),
params,
)?;
let (format, column_formats) = match conn.read_message()? { let (format, column_formats) = match conn.read_message()? {
backend::Message::CopyInResponse(body) => { backend::Message::CopyInResponse(body) => {
let format = body.format(); let format = body.format();
let column_formats = body.column_formats() let column_formats = body.column_formats().map(|f| Format::from_u16(f)).collect()?;
.map(|f| Format::from_u16(f))
.collect()?;
(format, column_formats) (format, column_formats)
} }
backend::Message::ErrorResponse(body) => { backend::Message::ErrorResponse(body) => {
@ -292,10 +323,12 @@ impl<'conn> Statement<'conn> {
_ => { _ => {
loop { loop {
if let backend::Message::ReadyForQuery(_) = conn.read_message()? { if let backend::Message::ReadyForQuery(_) = conn.read_message()? {
return Err(Error::Io(io::Error::new(io::ErrorKind::InvalidInput, return Err(Error::Io(io::Error::new(
"called `copy_in` on a \ io::ErrorKind::InvalidInput,
"called `copy_in` on a \
non-`COPY FROM STDIN` \ non-`COPY FROM STDIN` \
statement"))); statement",
)));
} }
} }
} }
@ -311,14 +344,20 @@ impl<'conn> Statement<'conn> {
match fill_copy_buf(&mut buf, r, &info) { match fill_copy_buf(&mut buf, r, &info) {
Ok(0) => break, Ok(0) => break,
Ok(len) => { Ok(len) => {
conn.stream.write_message(|out| frontend::copy_data(&buf[..len], out))?; conn.stream.write_message(
|out| frontend::copy_data(&buf[..len], out),
)?;
} }
Err(err) => { Err(err) => {
conn.stream.write_message(|buf| frontend::copy_fail("", buf))?; conn.stream.write_message(
conn.stream |buf| frontend::copy_fail("", buf),
.write_message(|buf| Ok::<(), io::Error>(frontend::copy_done(buf)))?; )?;
conn.stream conn.stream.write_message(|buf| {
.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?; Ok::<(), io::Error>(frontend::copy_done(buf))
})?;
conn.stream.write_message(
|buf| Ok::<(), io::Error>(frontend::sync(buf)),
)?;
conn.stream.flush()?; conn.stream.flush()?;
match conn.read_message()? { match conn.read_message()? {
backend::Message::ErrorResponse(_) => { backend::Message::ErrorResponse(_) => {
@ -335,8 +374,12 @@ impl<'conn> Statement<'conn> {
} }
} }
conn.stream.write_message(|buf| Ok::<(), io::Error>(frontend::copy_done(buf)))?; conn.stream.write_message(|buf| {
conn.stream.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?; Ok::<(), io::Error>(frontend::copy_done(buf))
})?;
conn.stream.write_message(
|buf| Ok::<(), io::Error>(frontend::sync(buf)),
)?;
conn.stream.flush()?; conn.stream.flush()?;
let num = match conn.read_message()? { let num = match conn.read_message()? {
@ -379,21 +422,30 @@ impl<'conn> Statement<'conn> {
/// ``` /// ```
pub fn copy_out<'a, W: WriteWithInfo>(&'a self, params: &[&ToSql], w: &mut W) -> Result<u64> { pub fn copy_out<'a, W: WriteWithInfo>(&'a self, params: &[&ToSql], w: &mut W) -> Result<u64> {
let mut conn = self.conn.0.borrow_mut(); let mut conn = self.conn.0.borrow_mut();
conn.raw_execute(&self.info.name, "", 0, self.param_types(), params)?; conn.raw_execute(
&self.info.name,
"",
0,
self.param_types(),
params,
)?;
let (format, column_formats) = match conn.read_message()? { let (format, column_formats) = match conn.read_message()? {
backend::Message::CopyOutResponse(body) => { backend::Message::CopyOutResponse(body) => {
let format = body.format(); let format = body.format();
let column_formats = body.column_formats() let column_formats = body.column_formats().map(|f| Format::from_u16(f)).collect()?;
.map(|f| Format::from_u16(f))
.collect()?;
(format, column_formats) (format, column_formats)
} }
backend::Message::CopyInResponse(_) => { backend::Message::CopyInResponse(_) => {
conn.stream.write_message(|buf| frontend::copy_fail("", buf))?; conn.stream.write_message(
conn.stream |buf| frontend::copy_fail("", buf),
.write_message(|buf| Ok::<(), io::Error>(frontend::copy_done(buf)))?; )?;
conn.stream.write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?; conn.stream.write_message(|buf| {
Ok::<(), io::Error>(frontend::copy_done(buf))
})?;
conn.stream.write_message(
|buf| Ok::<(), io::Error>(frontend::sync(buf)),
)?;
conn.stream.flush()?; conn.stream.flush()?;
match conn.read_message()? { match conn.read_message()? {
backend::Message::ErrorResponse(_) => { backend::Message::ErrorResponse(_) => {
@ -405,9 +457,11 @@ impl<'conn> Statement<'conn> {
} }
} }
conn.wait_for_ready()?; conn.wait_for_ready()?;
return Err(Error::Io(io::Error::new(io::ErrorKind::InvalidInput, return Err(Error::Io(io::Error::new(
"called `copy_out` on a non-`COPY TO \ io::ErrorKind::InvalidInput,
STDOUT` statement"))); "called `copy_out` on a non-`COPY TO \
STDOUT` statement",
)));
} }
backend::Message::ErrorResponse(body) => { backend::Message::ErrorResponse(body) => {
conn.wait_for_ready()?; conn.wait_for_ready()?;
@ -416,9 +470,11 @@ impl<'conn> Statement<'conn> {
_ => { _ => {
loop { loop {
if let backend::Message::ReadyForQuery(_) = conn.read_message()? { if let backend::Message::ReadyForQuery(_) = conn.read_message()? {
return Err(Error::Io(io::Error::new(io::ErrorKind::InvalidInput, return Err(Error::Io(io::Error::new(
"called `copy_out` on a \ io::ErrorKind::InvalidInput,
non-`COPY TO STDOUT` statement"))); "called `copy_out` on a \
non-`COPY TO STDOUT` statement",
)));
} }
} }
} }
@ -440,7 +496,8 @@ impl<'conn> Statement<'conn> {
Err(e) => { Err(e) => {
loop { loop {
if let backend::Message::ReadyForQuery(_) = if let backend::Message::ReadyForQuery(_) =
conn.read_message()? { conn.read_message()?
{
return Err(Error::Io(e)); return Err(Error::Io(e));
} }
} }
@ -455,16 +512,14 @@ impl<'conn> Statement<'conn> {
} }
backend::Message::ErrorResponse(body) => { backend::Message::ErrorResponse(body) => {
loop { loop {
if let backend::Message::ReadyForQuery(_) = if let backend::Message::ReadyForQuery(_) = conn.read_message()? {
conn.read_message()? {
return Err(err(&mut body.fields())); return Err(err(&mut body.fields()));
} }
} }
} }
_ => { _ => {
loop { loop {
if let backend::Message::ReadyForQuery(_) = if let backend::Message::ReadyForQuery(_) = conn.read_message()? {
conn.read_message()? {
return Err(Error::Io(bad_response())); return Err(Error::Io(bad_response()));
} }
} }

View File

@ -31,18 +31,19 @@ pub trait TlsHandshake: fmt::Debug {
/// ///
/// The host portion of the connection parameters is provided for hostname /// The host portion of the connection parameters is provided for hostname
/// verification. /// verification.
fn tls_handshake(&self, fn tls_handshake(
host: &str, &self,
stream: Stream) host: &str,
-> Result<Box<TlsStream>, Box<Error + Sync + Send>>; stream: Stream,
) -> Result<Box<TlsStream>, Box<Error + Sync + Send>>;
} }
impl<T: TlsHandshake + ?Sized> TlsHandshake for Box<T> { impl<T: TlsHandshake + ?Sized> TlsHandshake for Box<T> {
fn tls_handshake(&self, fn tls_handshake(
host: &str, &self,
stream: Stream) host: &str,
-> Result<Box<TlsStream>, Box<Error + Sync + Send>> { stream: Stream,
) -> Result<Box<TlsStream>, Box<Error + Sync + Send>> {
(**self).tls_handshake(host, stream) (**self).tls_handshake(host, stream)
} }
} }

View File

@ -54,10 +54,11 @@ impl From<TlsConnector> for NativeTls {
} }
impl TlsHandshake for NativeTls { impl TlsHandshake for NativeTls {
fn tls_handshake(&self, fn tls_handshake(
domain: &str, &self,
stream: Stream) domain: &str,
-> Result<Box<TlsStream>, Box<Error + Send + Sync>> { stream: Stream,
) -> Result<Box<TlsStream>, Box<Error + Send + Sync>> {
let stream = self.0.connect(domain, stream)?; let stream = self.0.connect(domain, stream)?;
Ok(Box::new(stream)) Ok(Box::new(stream))
} }

View File

@ -70,10 +70,11 @@ impl From<SslConnector> for OpenSsl {
} }
impl TlsHandshake for OpenSsl { impl TlsHandshake for OpenSsl {
fn tls_handshake(&self, fn tls_handshake(
domain: &str, &self,
stream: Stream) domain: &str,
-> Result<Box<TlsStream>, Box<Error + Send + Sync>> { stream: Stream,
) -> Result<Box<TlsStream>, Box<Error + Send + Sync>> {
let stream = if self.disable_verification { let stream = if self.disable_verification {
self.connector.danger_connect_without_providing_domain_for_certificate_verification_and_server_name_indication(stream)? self.connector.danger_connect_without_providing_domain_for_certificate_verification_and_server_name_indication(stream)?
} else { } else {

View File

@ -38,14 +38,16 @@ impl Schannel {
} }
impl TlsHandshake for Schannel { impl TlsHandshake for Schannel {
fn tls_handshake(&self, fn tls_handshake(
host: &str, &self,
stream: Stream) host: &str,
-> Result<Box<TlsStream>, Box<Error + Sync + Send>> { stream: Stream,
) -> Result<Box<TlsStream>, Box<Error + Sync + Send>> {
let creds = SchannelCred::builder().acquire(Direction::Outbound)?; let creds = SchannelCred::builder().acquire(Direction::Outbound)?;
let stream = tls_stream::Builder::new() let stream = tls_stream::Builder::new().domain(host).connect(
.domain(host) creds,
.connect(creds, stream)?; stream,
)?;
Ok(Box::new(stream)) Ok(Box::new(stream))
} }
} }

View File

@ -45,10 +45,11 @@ impl From<ClientBuilder> for SecurityFramework {
} }
impl TlsHandshake for SecurityFramework { impl TlsHandshake for SecurityFramework {
fn tls_handshake(&self, fn tls_handshake(
domain: &str, &self,
stream: Stream) domain: &str,
-> Result<Box<TlsStream>, Box<Error + Send + Sync>> { stream: Stream,
) -> Result<Box<TlsStream>, Box<Error + Send + Sync>> {
let stream = self.0.handshake(domain, stream)?; let stream = self.0.handshake(domain, stream)?;
Ok(Box::new(stream)) Ok(Box::new(stream))
} }

View File

@ -256,8 +256,10 @@ impl<'conn> Transaction<'conn> {
pub fn savepoint<'a>(&'a self, name: &str) -> Result<Transaction<'a>> { pub fn savepoint<'a>(&'a self, name: &str) -> Result<Transaction<'a>> {
let mut conn = self.conn.0.borrow_mut(); let mut conn = self.conn.0.borrow_mut();
check_desync!(conn); check_desync!(conn);
assert!(conn.trans_depth == self.depth, assert!(
"`savepoint` may only be called on the active transaction"); conn.trans_depth == self.depth,
"`savepoint` may only be called on the active transaction"
);
conn.quick_query(&format!("SAVEPOINT {}", name))?; conn.quick_query(&format!("SAVEPOINT {}", name))?;
conn.trans_depth += 1; conn.trans_depth += 1;
Ok(Transaction { Ok(Transaction {

View File

@ -1,8 +1,8 @@
//! Traits dealing with Postgres data types //! Traits dealing with Postgres data types
#[doc(inline)] #[doc(inline)]
pub use postgres_shared::types::{Oid, Type, Date, Timestamp, Kind, Field, Other, pub use postgres_shared::types::{Oid, Type, Date, Timestamp, Kind, Field, Other, WasNull,
WasNull, WrongType, FromSql, IsNull, ToSql}; WrongType, FromSql, IsNull, ToSql};
#[doc(hidden)] #[doc(hidden)]
pub use postgres_shared::types::__to_sql_checked; pub use postgres_shared::types::__to_sql_checked;

File diff suppressed because it is too large Load Diff

View File

@ -8,8 +8,10 @@ fn test_bit_params() {
let mut bv = BitVec::from_bytes(&[0b0110_1001, 0b0000_0111]); let mut bv = BitVec::from_bytes(&[0b0110_1001, 0b0000_0111]);
bv.pop(); bv.pop();
bv.pop(); bv.pop();
test_type("BIT(14)", &[(Some(bv), "B'01101001000001'"), test_type(
(None, "NULL")]) "BIT(14)",
&[(Some(bv), "B'01101001000001'"), (None, "NULL")],
)
} }
#[test] #[test]
@ -17,7 +19,12 @@ fn test_varbit_params() {
let mut bv = BitVec::from_bytes(&[0b0110_1001, 0b0000_0111]); let mut bv = BitVec::from_bytes(&[0b0110_1001, 0b0000_0111]);
bv.pop(); bv.pop();
bv.pop(); bv.pop();
test_type("VARBIT", &[(Some(bv), "B'01101001000001'"), test_type(
(Some(BitVec::from_bytes(&[])), "B''"), "VARBIT",
(None, "NULL")]) &[
(Some(bv), "B'01101001000001'"),
(Some(BitVec::from_bytes(&[])), "B''"),
(None, "NULL"),
],
)
} }

View File

@ -8,87 +8,145 @@ use postgres::types::{Date, Timestamp};
#[test] #[test]
fn test_naive_date_time_params() { fn test_naive_date_time_params() {
fn make_check<'a>(time: &'a str) -> (Option<NaiveDateTime>, &'a str) { fn make_check<'a>(time: &'a str) -> (Option<NaiveDateTime>, &'a str) {
(Some(NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap()), time) (
Some(
NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap(),
),
time,
)
} }
test_type("TIMESTAMP", test_type(
&[make_check("'1970-01-01 00:00:00.010000000'"), "TIMESTAMP",
make_check("'1965-09-25 11:19:33.100314000'"), &[
make_check("'2010-02-09 23:11:45.120200000'"), make_check("'1970-01-01 00:00:00.010000000'"),
(None, "NULL")]); make_check("'1965-09-25 11:19:33.100314000'"),
make_check("'2010-02-09 23:11:45.120200000'"),
(None, "NULL"),
],
);
} }
#[test] #[test]
fn test_with_special_naive_date_time_params() { fn test_with_special_naive_date_time_params() {
fn make_check<'a>(time: &'a str) -> (Timestamp<NaiveDateTime>, &'a str) { fn make_check<'a>(time: &'a str) -> (Timestamp<NaiveDateTime>, &'a str) {
(Timestamp::Value(NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap()), (
time) Timestamp::Value(
NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap(),
),
time,
)
} }
test_type("TIMESTAMP", test_type(
&[make_check("'1970-01-01 00:00:00.010000000'"), "TIMESTAMP",
make_check("'1965-09-25 11:19:33.100314000'"), &[
make_check("'2010-02-09 23:11:45.120200000'"), make_check("'1970-01-01 00:00:00.010000000'"),
(Timestamp::PosInfinity, "'infinity'"), make_check("'1965-09-25 11:19:33.100314000'"),
(Timestamp::NegInfinity, "'-infinity'")]); make_check("'2010-02-09 23:11:45.120200000'"),
(Timestamp::PosInfinity, "'infinity'"),
(Timestamp::NegInfinity, "'-infinity'"),
],
);
} }
#[test] #[test]
fn test_date_time_params() { fn test_date_time_params() {
fn make_check<'a>(time: &'a str) -> (Option<DateTime<Utc>>, &'a str) { fn make_check<'a>(time: &'a str) -> (Option<DateTime<Utc>>, &'a str) {
(Some(Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap()), time) (
Some(
Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'")
.unwrap(),
),
time,
)
} }
test_type("TIMESTAMP WITH TIME ZONE", test_type(
&[make_check("'1970-01-01 00:00:00.010000000'"), "TIMESTAMP WITH TIME ZONE",
make_check("'1965-09-25 11:19:33.100314000'"), &[
make_check("'2010-02-09 23:11:45.120200000'"), make_check("'1970-01-01 00:00:00.010000000'"),
(None, "NULL")]); make_check("'1965-09-25 11:19:33.100314000'"),
make_check("'2010-02-09 23:11:45.120200000'"),
(None, "NULL"),
],
);
} }
#[test] #[test]
fn test_with_special_date_time_params() { fn test_with_special_date_time_params() {
fn make_check<'a>(time: &'a str) -> (Timestamp<DateTime<Utc>>, &'a str) { fn make_check<'a>(time: &'a str) -> (Timestamp<DateTime<Utc>>, &'a str) {
(Timestamp::Value(Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap()), time) (
Timestamp::Value(
Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'")
.unwrap(),
),
time,
)
} }
test_type("TIMESTAMP WITH TIME ZONE", test_type(
&[make_check("'1970-01-01 00:00:00.010000000'"), "TIMESTAMP WITH TIME ZONE",
make_check("'1965-09-25 11:19:33.100314000'"), &[
make_check("'2010-02-09 23:11:45.120200000'"), make_check("'1970-01-01 00:00:00.010000000'"),
(Timestamp::PosInfinity, "'infinity'"), make_check("'1965-09-25 11:19:33.100314000'"),
(Timestamp::NegInfinity, "'-infinity'")]); make_check("'2010-02-09 23:11:45.120200000'"),
(Timestamp::PosInfinity, "'infinity'"),
(Timestamp::NegInfinity, "'-infinity'"),
],
);
} }
#[test] #[test]
fn test_date_params() { fn test_date_params() {
fn make_check<'a>(time: &'a str) -> (Option<NaiveDate>, &'a str) { fn make_check<'a>(time: &'a str) -> (Option<NaiveDate>, &'a str) {
(Some(NaiveDate::parse_from_str(time, "'%Y-%m-%d'").unwrap()), time) (
Some(NaiveDate::parse_from_str(time, "'%Y-%m-%d'").unwrap()),
time,
)
} }
test_type("DATE", test_type(
&[make_check("'1970-01-01'"), "DATE",
make_check("'1965-09-25'"), &[
make_check("'2010-02-09'"), make_check("'1970-01-01'"),
(None, "NULL")]); make_check("'1965-09-25'"),
make_check("'2010-02-09'"),
(None, "NULL"),
],
);
} }
#[test] #[test]
fn test_with_special_date_params() { fn test_with_special_date_params() {
fn make_check<'a>(date: &'a str) -> (Date<NaiveDate>, &'a str) { fn make_check<'a>(date: &'a str) -> (Date<NaiveDate>, &'a str) {
(Date::Value(NaiveDate::parse_from_str(date, "'%Y-%m-%d'").unwrap()), date) (
Date::Value(NaiveDate::parse_from_str(date, "'%Y-%m-%d'").unwrap()),
date,
)
} }
test_type("DATE", test_type(
&[make_check("'1970-01-01'"), "DATE",
make_check("'1965-09-25'"), &[
make_check("'2010-02-09'"), make_check("'1970-01-01'"),
(Date::PosInfinity, "'infinity'"), make_check("'1965-09-25'"),
(Date::NegInfinity, "'-infinity'")]); make_check("'2010-02-09'"),
(Date::PosInfinity, "'infinity'"),
(Date::NegInfinity, "'-infinity'"),
],
);
} }
#[test] #[test]
fn test_time_params() { fn test_time_params() {
fn make_check<'a>(time: &'a str) -> (Option<NaiveTime>, &'a str) { fn make_check<'a>(time: &'a str) -> (Option<NaiveTime>, &'a str) {
(Some(NaiveTime::parse_from_str(time, "'%H:%M:%S.%f'").unwrap()), time) (
Some(NaiveTime::parse_from_str(time, "'%H:%M:%S.%f'").unwrap()),
time,
)
} }
test_type("TIME", test_type(
&[make_check("'00:00:00.010000000'"), "TIME",
make_check("'11:19:33.100314000'"), &[
make_check("'23:11:45.120200000'"), make_check("'00:00:00.010000000'"),
(None, "NULL")]); make_check("'11:19:33.100314000'"),
make_check("'23:11:45.120200000'"),
(None, "NULL"),
],
);
} }

View File

@ -4,6 +4,14 @@ use types::test_type;
#[test] #[test]
fn test_eui48_params() { fn test_eui48_params() {
test_type("MACADDR", &[(Some(eui48::MacAddress::parse_str("12-34-56-AB-CD-EF").unwrap()), test_type(
"'12-34-56-ab-cd-ef'"), (None, "NULL")]) "MACADDR",
&[
(
Some(eui48::MacAddress::parse_str("12-34-56-AB-CD-EF").unwrap()),
"'12-34-56-ab-cd-ef'",
),
(None, "NULL"),
],
)
} }

View File

@ -5,24 +5,50 @@ use types::test_type;
#[test] #[test]
fn test_point_params() { fn test_point_params() {
test_type("POINT", test_type(
&[(Some(Point::new(0.0, 0.0)), "POINT(0, 0)"), "POINT",
(Some(Point::new(-3.14, 1.618)), "POINT(-3.14, 1.618)"), &[
(None, "NULL")]); (Some(Point::new(0.0, 0.0)), "POINT(0, 0)"),
(Some(Point::new(-3.14, 1.618)), "POINT(-3.14, 1.618)"),
(None, "NULL"),
],
);
} }
#[test] #[test]
fn test_box_params() { fn test_box_params() {
test_type("BOX", test_type(
&[(Some(Bbox{xmax: 160.0, ymax: 69701.5615, xmin: -3.14, ymin: 1.618}), "BOX",
"BOX(POINT(160.0, 69701.5615), POINT(-3.14, 1.618))"), &[
(None, "NULL")]); (
Some(Bbox {
xmax: 160.0,
ymax: 69701.5615,
xmin: -3.14,
ymin: 1.618,
}),
"BOX(POINT(160.0, 69701.5615), POINT(-3.14, 1.618))",
),
(None, "NULL"),
],
);
} }
#[test] #[test]
fn test_path_params() { fn test_path_params() {
let points = vec![Point::new(0.0, 0.0), Point::new(-3.14, 1.618), Point::new(160.0, 69701.5615)]; let points = vec![
test_type("PATH", Point::new(0.0, 0.0),
&[(Some(LineString(points)),"path '((0, 0), (-3.14, 1.618), (160.0, 69701.5615))'"), Point::new(-3.14, 1.618),
(None, "NULL")]); Point::new(160.0, 69701.5615),
];
test_type(
"PATH",
&[
(
Some(LineString(points)),
"path '((0, 0), (-3.14, 1.618), (160.0, 69701.5615))'",
),
(None, "NULL"),
],
);
} }

View File

@ -27,7 +27,10 @@ mod chrono;
mod geo; mod geo;
fn test_type<T: PartialEq + FromSql + ToSql, S: fmt::Display>(sql_type: &str, checks: &[(T, S)]) { fn test_type<T: PartialEq + FromSql + ToSql, S: fmt::Display>(sql_type: &str, checks: &[(T, S)]) {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", TlsMode::None)); let conn = or_panic!(Connection::connect(
"postgres://postgres@localhost",
TlsMode::None,
));
for &(ref val, ref repr) in checks.iter() { for &(ref val, ref repr) in checks.iter() {
let stmt = or_panic!(conn.prepare(&*format!("SELECT {}::{}", *repr, sql_type))); let stmt = or_panic!(conn.prepare(&*format!("SELECT {}::{}", *repr, sql_type)));
let result = or_panic!(stmt.query(&[])).iter().next().unwrap().get(0); let result = or_panic!(stmt.query(&[])).iter().next().unwrap().get(0);
@ -41,7 +44,10 @@ fn test_type<T: PartialEq + FromSql + ToSql, S: fmt::Display>(sql_type: &str, ch
#[test] #[test]
fn test_ref_tosql() { fn test_ref_tosql() {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", TlsMode::None)); let conn = or_panic!(Connection::connect(
"postgres://postgres@localhost",
TlsMode::None,
));
let stmt = conn.prepare("SELECT $1::Int").unwrap(); let stmt = conn.prepare("SELECT $1::Int").unwrap();
let num: &ToSql = &&7; let num: &ToSql = &&7;
stmt.query(&[num]).unwrap(); stmt.query(&[num]).unwrap();
@ -49,8 +55,10 @@ fn test_ref_tosql() {
#[test] #[test]
fn test_bool_params() { fn test_bool_params() {
test_type("BOOL", test_type(
&[(Some(true), "'t'"), (Some(false), "'f'"), (None, "NULL")]); "BOOL",
&[(Some(true), "'t'"), (Some(false), "'f'"), (None, "NULL")],
);
} }
#[test] #[test]
@ -60,120 +68,190 @@ fn test_i8_params() {
#[test] #[test]
fn test_name_params() { fn test_name_params() {
test_type("NAME", test_type(
&[(Some("hello world".to_owned()), "'hello world'"), "NAME",
(Some("イロハニホヘト チリヌルヲ".to_owned()), &[
"'イロハニホヘト チリヌルヲ'"), (Some("hello world".to_owned()), "'hello world'"),
(None, "NULL")]); (
Some("イロハニホヘト チリヌルヲ".to_owned()),
"'イロハニホヘト チリヌルヲ'",
),
(None, "NULL"),
],
);
} }
#[test] #[test]
fn test_i16_params() { fn test_i16_params() {
test_type("SMALLINT", test_type(
&[(Some(15001i16), "15001"), "SMALLINT",
(Some(-15001i16), "-15001"), &[
(None, "NULL")]); (Some(15001i16), "15001"),
(Some(-15001i16), "-15001"),
(None, "NULL"),
],
);
} }
#[test] #[test]
fn test_i32_params() { fn test_i32_params() {
test_type("INT", test_type(
&[(Some(2147483548i32), "2147483548"), "INT",
(Some(-2147483548i32), "-2147483548"), &[
(None, "NULL")]); (Some(2147483548i32), "2147483548"),
(Some(-2147483548i32), "-2147483548"),
(None, "NULL"),
],
);
} }
#[test] #[test]
fn test_oid_params() { fn test_oid_params() {
test_type("OID", test_type(
&[(Some(2147483548u32), "2147483548"), "OID",
(Some(4000000000), "4000000000"), &[
(None, "NULL")]); (Some(2147483548u32), "2147483548"),
(Some(4000000000), "4000000000"),
(None, "NULL"),
],
);
} }
#[test] #[test]
fn test_i64_params() { fn test_i64_params() {
test_type("BIGINT", test_type(
&[(Some(9223372036854775708i64), "9223372036854775708"), "BIGINT",
(Some(-9223372036854775708i64), "-9223372036854775708"), &[
(None, "NULL")]); (Some(9223372036854775708i64), "9223372036854775708"),
(Some(-9223372036854775708i64), "-9223372036854775708"),
(None, "NULL"),
],
);
} }
#[test] #[test]
fn test_f32_params() { fn test_f32_params() {
test_type("REAL", test_type(
&[(Some(f32::INFINITY), "'infinity'"), "REAL",
(Some(f32::NEG_INFINITY), "'-infinity'"), &[
(Some(1000.55), "1000.55"), (Some(f32::INFINITY), "'infinity'"),
(None, "NULL")]); (Some(f32::NEG_INFINITY), "'-infinity'"),
(Some(1000.55), "1000.55"),
(None, "NULL"),
],
);
} }
#[test] #[test]
fn test_f64_params() { fn test_f64_params() {
test_type("DOUBLE PRECISION", test_type(
&[(Some(f64::INFINITY), "'infinity'"), "DOUBLE PRECISION",
(Some(f64::NEG_INFINITY), "'-infinity'"), &[
(Some(10000.55), "10000.55"), (Some(f64::INFINITY), "'infinity'"),
(None, "NULL")]); (Some(f64::NEG_INFINITY), "'-infinity'"),
(Some(10000.55), "10000.55"),
(None, "NULL"),
],
);
} }
#[test] #[test]
fn test_varchar_params() { fn test_varchar_params() {
test_type("VARCHAR", test_type(
&[(Some("hello world".to_owned()), "'hello world'"), "VARCHAR",
(Some("イロハニホヘト チリヌルヲ".to_owned()), &[
"'イロハニホヘト チリヌルヲ'"), (Some("hello world".to_owned()), "'hello world'"),
(None, "NULL")]); (
Some("イロハニホヘト チリヌルヲ".to_owned()),
"'イロハニホヘト チリヌルヲ'",
),
(None, "NULL"),
],
);
} }
#[test] #[test]
fn test_text_params() { fn test_text_params() {
test_type("TEXT", test_type(
&[(Some("hello world".to_owned()), "'hello world'"), "TEXT",
(Some("イロハニホヘト チリヌルヲ".to_owned()), &[
"'イロハニホヘト チリヌルヲ'"), (Some("hello world".to_owned()), "'hello world'"),
(None, "NULL")]); (
Some("イロハニホヘト チリヌルヲ".to_owned()),
"'イロハニホヘト チリヌルヲ'",
),
(None, "NULL"),
],
);
} }
#[test] #[test]
fn test_bpchar_params() { fn test_bpchar_params() {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", TlsMode::None)); let conn = or_panic!(Connection::connect(
or_panic!(conn.execute("CREATE TEMPORARY TABLE foo ( "postgres://postgres@localhost",
TlsMode::None,
));
or_panic!(conn.execute(
"CREATE TEMPORARY TABLE foo (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
b CHAR(5) b CHAR(5)
)", )",
&[])); &[],
or_panic!(conn.execute("INSERT INTO foo (b) VALUES ($1), ($2), ($3)", ));
&[&Some("12345"), &Some("123"), &None::<&'static str>])); or_panic!(conn.execute(
"INSERT INTO foo (b) VALUES ($1), ($2), ($3)",
&[&Some("12345"), &Some("123"), &None::<&'static str>],
));
let stmt = or_panic!(conn.prepare("SELECT b FROM foo ORDER BY id")); let stmt = or_panic!(conn.prepare("SELECT b FROM foo ORDER BY id"));
let res = or_panic!(stmt.query(&[])); let res = or_panic!(stmt.query(&[]));
assert_eq!(vec![Some("12345".to_owned()), Some("123 ".to_owned()), None], assert_eq!(
res.iter().map(|row| row.get(0)).collect::<Vec<_>>()); vec![Some("12345".to_owned()), Some("123 ".to_owned()), None],
res.iter().map(|row| row.get(0)).collect::<Vec<_>>()
);
} }
#[test] #[test]
fn test_citext_params() { fn test_citext_params() {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", TlsMode::None)); let conn = or_panic!(Connection::connect(
or_panic!(conn.execute("CREATE TEMPORARY TABLE foo ( "postgres://postgres@localhost",
TlsMode::None,
));
or_panic!(conn.execute(
"CREATE TEMPORARY TABLE foo (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
b CITEXT b CITEXT
)", )",
&[])); &[],
or_panic!(conn.execute("INSERT INTO foo (b) VALUES ($1), ($2), ($3)", ));
&[&Some("foobar"), &Some("FooBar"), &None::<&'static str>])); or_panic!(conn.execute(
let stmt = or_panic!(conn.prepare("SELECT id FROM foo WHERE b = 'FOOBAR' ORDER BY id")); "INSERT INTO foo (b) VALUES ($1), ($2), ($3)",
&[
&Some("foobar"),
&Some("FooBar"),
&None::<&'static str>,
],
));
let stmt = or_panic!(conn.prepare(
"SELECT id FROM foo WHERE b = 'FOOBAR' ORDER BY id",
));
let res = or_panic!(stmt.query(&[])); let res = or_panic!(stmt.query(&[]));
assert_eq!(vec![Some(1i32), Some(2i32)], assert_eq!(
res.iter().map(|row| row.get(0)).collect::<Vec<_>>()); vec![Some(1i32), Some(2i32)],
res.iter().map(|row| row.get(0)).collect::<Vec<_>>()
);
} }
#[test] #[test]
fn test_bytea_params() { fn test_bytea_params() {
test_type("BYTEA", test_type(
&[(Some(vec![0u8, 1, 2, 3, 254, 255]), "'\\x00010203feff'"), "BYTEA",
(None, "NULL")]); &[
(Some(vec![0u8, 1, 2, 3, 254, 255]), "'\\x00010203feff'"),
(None, "NULL"),
],
);
} }
#[test] #[test]
@ -185,26 +263,42 @@ fn test_hstore_params() {
map map
}) })
} }
test_type("hstore", test_type(
&[(Some(make_map!("a".to_owned() => Some("1".to_owned()))), "'a=>1'"), "hstore",
(Some(make_map!("hello".to_owned() => Some("world!".to_owned()), &[
(
Some(make_map!("a".to_owned() => Some("1".to_owned()))),
"'a=>1'",
),
(
Some(make_map!("hello".to_owned() => Some("world!".to_owned()),
"hola".to_owned() => Some("mundo!".to_owned()), "hola".to_owned() => Some("mundo!".to_owned()),
"what".to_owned() => None)), "what".to_owned() => None)),
"'hello=>world!,hola=>mundo!,what=>NULL'"), "'hello=>world!,hola=>mundo!,what=>NULL'",
(None, "NULL")]); ),
(None, "NULL"),
],
);
} }
#[test] #[test]
fn test_array_params() { fn test_array_params() {
test_type("integer[]", test_type(
&[(Some(vec![1i32, 2i32]), "ARRAY[1,2]"), "integer[]",
(Some(vec![1i32]), "ARRAY[1]"), &[
(Some(vec![]), "ARRAY[]"), (Some(vec![1i32, 2i32]), "ARRAY[1,2]"),
(None, "NULL")]); (Some(vec![1i32]), "ARRAY[1]"),
(Some(vec![]), "ARRAY[]"),
(None, "NULL"),
],
);
} }
fn test_nan_param<T: PartialEq + ToSql + FromSql>(sql_type: &str) { fn test_nan_param<T: PartialEq + ToSql + FromSql>(sql_type: &str) {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", TlsMode::None)); let conn = or_panic!(Connection::connect(
"postgres://postgres@localhost",
TlsMode::None,
));
let stmt = or_panic!(conn.prepare(&*format!("SELECT 'NaN'::{}", sql_type))); let stmt = or_panic!(conn.prepare(&*format!("SELECT 'NaN'::{}", sql_type)));
let result = or_panic!(stmt.query(&[])); let result = or_panic!(stmt.query(&[]));
let val: T = result.iter().next().unwrap().get(0); let val: T = result.iter().next().unwrap().get(0);
@ -223,7 +317,10 @@ fn test_f64_nan_param() {
#[test] #[test]
fn test_pg_database_datname() { fn test_pg_database_datname() {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", TlsMode::None)); let conn = or_panic!(Connection::connect(
"postgres://postgres@localhost",
TlsMode::None,
));
let stmt = or_panic!(conn.prepare("SELECT datname FROM pg_database")); let stmt = or_panic!(conn.prepare("SELECT datname FROM pg_database"));
let result = or_panic!(stmt.query(&[])); let result = or_panic!(stmt.query(&[]));
@ -235,18 +332,21 @@ fn test_pg_database_datname() {
#[test] #[test]
fn test_slice() { fn test_slice() {
let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap(); let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap();
conn.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY, f VARCHAR); conn.batch_execute(
INSERT INTO foo (f) VALUES ('a'), ('b'), ('c'), ('d');") "CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY, f VARCHAR);
.unwrap(); INSERT INTO foo (f) VALUES ('a'), ('b'), ('c'), ('d');",
).unwrap();
let stmt = conn.prepare("SELECT f FROM foo WHERE id = ANY($1)") let stmt = conn.prepare("SELECT f FROM foo WHERE id = ANY($1)")
.unwrap(); .unwrap();
let result = stmt.query(&[&&[1i32, 3, 4][..]]).unwrap(); let result = stmt.query(&[&&[1i32, 3, 4][..]]).unwrap();
assert_eq!(vec!["a".to_owned(), "c".to_owned(), "d".to_owned()], assert_eq!(
result vec!["a".to_owned(), "c".to_owned(), "d".to_owned()],
.iter() result
.map(|r| r.get::<_, String>(0)) .iter()
.collect::<Vec<_>>()); .map(|r| r.get::<_, String>(0))
.collect::<Vec<_>>()
);
} }
#[test] #[test]
@ -282,10 +382,11 @@ fn domain() {
struct SessionId(Vec<u8>); struct SessionId(Vec<u8>);
impl ToSql for SessionId { impl ToSql for SessionId {
fn to_sql(&self, fn to_sql(
ty: &Type, &self,
out: &mut Vec<u8>) ty: &Type,
-> result::Result<IsNull, Box<error::Error + Sync + Send>> { out: &mut Vec<u8>,
) -> result::Result<IsNull, Box<error::Error + Sync + Send>> {
let inner = match *ty.kind() { let inner = match *ty.kind() {
Kind::Domain(ref inner) => inner, Kind::Domain(ref inner) => inner,
_ => unreachable!(), _ => unreachable!(),
@ -295,19 +396,20 @@ fn domain() {
fn accepts(ty: &Type) -> bool { fn accepts(ty: &Type) -> bool {
ty.name() == "session_id" && ty.name() == "session_id" &&
match *ty.kind() { match *ty.kind() {
Kind::Domain(_) => true, Kind::Domain(_) => true,
_ => false, _ => false,
} }
} }
to_sql_checked!(); to_sql_checked!();
} }
impl FromSql for SessionId { impl FromSql for SessionId {
fn from_sql(ty: &Type, fn from_sql(
raw: &[u8]) ty: &Type,
-> result::Result<Self, Box<error::Error + Sync + Send>> { raw: &[u8],
) -> result::Result<Self, Box<error::Error + Sync + Send>> {
Vec::<u8>::from_sql(ty, raw).map(SessionId) Vec::<u8>::from_sql(ty, raw).map(SessionId)
} }
@ -318,9 +420,10 @@ fn domain() {
} }
let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap(); let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap();
conn.batch_execute("CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16); conn.batch_execute(
CREATE TABLE pg_temp.foo (id pg_temp.session_id);") "CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16);
.unwrap(); CREATE TABLE pg_temp.foo (id pg_temp.session_id);",
).unwrap();
let id = SessionId(b"0123456789abcdef".to_vec()); let id = SessionId(b"0123456789abcdef".to_vec());
conn.execute("INSERT INTO pg_temp.foo (id) VALUES ($1)", &[&id]) conn.execute("INSERT INTO pg_temp.foo (id) VALUES ($1)", &[&id])
@ -332,12 +435,13 @@ fn domain() {
#[test] #[test]
fn composite() { fn composite() {
let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap(); let conn = Connection::connect("postgres://postgres@localhost", TlsMode::None).unwrap();
conn.batch_execute("CREATE TYPE pg_temp.inventory_item AS ( conn.batch_execute(
"CREATE TYPE pg_temp.inventory_item AS (
name TEXT, name TEXT,
supplier INTEGER, supplier INTEGER,
price NUMERIC price NUMERIC
)") )",
.unwrap(); ).unwrap();
let stmt = conn.prepare("SELECT $1::inventory_item").unwrap(); let stmt = conn.prepare("SELECT $1::inventory_item").unwrap();
let type_ = &stmt.param_types()[0]; let type_ = &stmt.param_types()[0];
@ -366,8 +470,10 @@ fn enum_() {
assert_eq!(type_.name(), "mood"); assert_eq!(type_.name(), "mood");
match *type_.kind() { match *type_.kind() {
Kind::Enum(ref variants) => { Kind::Enum(ref variants) => {
assert_eq!(variants, assert_eq!(
&["sad".to_owned(), "ok".to_owned(), "happy".to_owned()]); variants,
&["sad".to_owned(), "ok".to_owned(), "happy".to_owned()]
);
} }
_ => panic!("bad type"), _ => panic!("bad type"),
} }

View File

@ -6,18 +6,36 @@ use types::test_type;
#[test] #[test]
fn test_json_params() { fn test_json_params() {
test_type("JSON", &[(Some(Json::from_str("[10, 11, 12]").unwrap()), test_type(
"'[10, 11, 12]'"), "JSON",
(Some(Json::from_str("{\"f\": \"asd\"}").unwrap()), &[
"'{\"f\": \"asd\"}'"), (
(None, "NULL")]) Some(Json::from_str("[10, 11, 12]").unwrap()),
"'[10, 11, 12]'",
),
(
Some(Json::from_str("{\"f\": \"asd\"}").unwrap()),
"'{\"f\": \"asd\"}'",
),
(None, "NULL"),
],
)
} }
#[test] #[test]
fn test_jsonb_params() { fn test_jsonb_params() {
test_type("JSONB", &[(Some(Json::from_str("[10, 11, 12]").unwrap()), test_type(
"'[10, 11, 12]'"), "JSONB",
(Some(Json::from_str("{\"f\": \"asd\"}").unwrap()), &[
"'{\"f\": \"asd\"}'"), (
(None, "NULL")]) Some(Json::from_str("[10, 11, 12]").unwrap()),
"'[10, 11, 12]'",
),
(
Some(Json::from_str("{\"f\": \"asd\"}").unwrap()),
"'{\"f\": \"asd\"}'",
),
(None, "NULL"),
],
)
} }

View File

@ -5,18 +5,36 @@ use types::test_type;
#[test] #[test]
fn test_json_params() { fn test_json_params() {
test_type("JSON", &[(Some(serde_json::from_str::<Value>("[10, 11, 12]").unwrap()), test_type(
"'[10, 11, 12]'"), "JSON",
(Some(serde_json::from_str::<Value>("{\"f\": \"asd\"}").unwrap()), &[
"'{\"f\": \"asd\"}'"), (
(None, "NULL")]) Some(serde_json::from_str::<Value>("[10, 11, 12]").unwrap()),
"'[10, 11, 12]'",
),
(
Some(serde_json::from_str::<Value>("{\"f\": \"asd\"}").unwrap()),
"'{\"f\": \"asd\"}'",
),
(None, "NULL"),
],
)
} }
#[test] #[test]
fn test_jsonb_params() { fn test_jsonb_params() {
test_type("JSONB", &[(Some(serde_json::from_str::<Value>("[10, 11, 12]").unwrap()), test_type(
"'[10, 11, 12]'"), "JSONB",
(Some(serde_json::from_str::<Value>("{\"f\": \"asd\"}").unwrap()), &[
"'{\"f\": \"asd\"}'"), (
(None, "NULL")]) Some(serde_json::from_str::<Value>("[10, 11, 12]").unwrap()),
"'[10, 11, 12]'",
),
(
Some(serde_json::from_str::<Value>("{\"f\": \"asd\"}").unwrap()),
"'{\"f\": \"asd\"}'",
),
(None, "NULL"),
],
)
} }

View File

@ -8,36 +8,65 @@ use postgres::types::Timestamp;
#[test] #[test]
fn test_tm_params() { fn test_tm_params() {
fn make_check<'a>(time: &'a str) -> (Option<Timespec>, &'a str) { fn make_check<'a>(time: &'a str) -> (Option<Timespec>, &'a str) {
(Some(time::strptime(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap().to_timespec()), time) (
Some(
time::strptime(time, "'%Y-%m-%d %H:%M:%S.%f'")
.unwrap()
.to_timespec(),
),
time,
)
} }
test_type("TIMESTAMP", test_type(
&[make_check("'1970-01-01 00:00:00.01'"), "TIMESTAMP",
make_check("'1965-09-25 11:19:33.100314'"), &[
make_check("'2010-02-09 23:11:45.1202'"), make_check("'1970-01-01 00:00:00.01'"),
(None, "NULL")]); make_check("'1965-09-25 11:19:33.100314'"),
test_type("TIMESTAMP WITH TIME ZONE", make_check("'2010-02-09 23:11:45.1202'"),
&[make_check("'1970-01-01 00:00:00.01'"), (None, "NULL"),
make_check("'1965-09-25 11:19:33.100314'"), ],
make_check("'2010-02-09 23:11:45.1202'"), );
(None, "NULL")]); test_type(
"TIMESTAMP WITH TIME ZONE",
&[
make_check("'1970-01-01 00:00:00.01'"),
make_check("'1965-09-25 11:19:33.100314'"),
make_check("'2010-02-09 23:11:45.1202'"),
(None, "NULL"),
],
);
} }
#[test] #[test]
fn test_with_special_tm_params() { fn test_with_special_tm_params() {
fn make_check<'a>(time: &'a str) -> (Timestamp<Timespec>, &'a str) { fn make_check<'a>(time: &'a str) -> (Timestamp<Timespec>, &'a str) {
(Timestamp::Value(time::strptime(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap().to_timespec()), (
time) Timestamp::Value(
time::strptime(time, "'%Y-%m-%d %H:%M:%S.%f'")
.unwrap()
.to_timespec(),
),
time,
)
} }
test_type("TIMESTAMP", test_type(
&[make_check("'1970-01-01 00:00:00.01'"), "TIMESTAMP",
make_check("'1965-09-25 11:19:33.100314'"), &[
make_check("'2010-02-09 23:11:45.1202'"), make_check("'1970-01-01 00:00:00.01'"),
(Timestamp::PosInfinity, "'infinity'"), make_check("'1965-09-25 11:19:33.100314'"),
(Timestamp::NegInfinity, "'-infinity'")]); make_check("'2010-02-09 23:11:45.1202'"),
test_type("TIMESTAMP WITH TIME ZONE", (Timestamp::PosInfinity, "'infinity'"),
&[make_check("'1970-01-01 00:00:00.01'"), (Timestamp::NegInfinity, "'-infinity'"),
make_check("'1965-09-25 11:19:33.100314'"), ],
make_check("'2010-02-09 23:11:45.1202'"), );
(Timestamp::PosInfinity, "'infinity'"), test_type(
(Timestamp::NegInfinity, "'-infinity'")]); "TIMESTAMP WITH TIME ZONE",
&[
make_check("'1970-01-01 00:00:00.01'"),
make_check("'1965-09-25 11:19:33.100314'"),
make_check("'2010-02-09 23:11:45.1202'"),
(Timestamp::PosInfinity, "'infinity'"),
(Timestamp::NegInfinity, "'-infinity'"),
],
);
} }

View File

@ -4,7 +4,16 @@ use types::test_type;
#[test] #[test]
fn test_uuid_params() { fn test_uuid_params() {
test_type("UUID", &[(Some(uuid::Uuid::parse_str("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11").unwrap()), test_type(
"'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'"), "UUID",
(None, "NULL")]) &[
(
Some(
uuid::Uuid::parse_str("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11").unwrap(),
),
"'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'",
),
(None, "NULL"),
],
)
} }

View File

@ -137,21 +137,29 @@ pub enum TlsMode {
/// ///
/// Only the host and port of the connection info are used. See /// Only the host and port of the connection info are used. See
/// `Connection::connect` for details of the `params` argument. /// `Connection::connect` for details of the `params` argument.
pub fn cancel_query<T>(params: T, pub fn cancel_query<T>(
tls_mode: TlsMode, params: T,
cancel_data: CancelData, tls_mode: TlsMode,
handle: &Handle) cancel_data: CancelData,
-> BoxFuture<(), ConnectError> handle: &Handle,
where T: IntoConnectParams ) -> BoxFuture<(), ConnectError>
where
T: IntoConnectParams,
{ {
let params = match params.into_connect_params() { let params = match params.into_connect_params() {
Ok(params) => { Ok(params) => {
Either::A(stream::connect(params.host().clone(), params.port(), tls_mode, handle)) Either::A(stream::connect(
params.host().clone(),
params.port(),
tls_mode,
handle,
))
} }
Err(e) => Either::B(Err(ConnectError::ConnectParams(e)).into_future()), Err(e) => Either::B(Err(ConnectError::ConnectParams(e)).into_future()),
}; };
params.and_then(move |c| { params
.and_then(move |c| {
let mut buf = vec![]; let mut buf = vec![];
frontend::cancel_request(cancel_data.process_id, cancel_data.secret_key, &mut buf); frontend::cancel_request(cancel_data.process_id, cancel_data.secret_key, &mut buf);
c.send(buf).map_err(ConnectError::Io) c.send(buf).map_err(ConnectError::Io)
@ -177,31 +185,29 @@ impl InnerConnection {
fn read(self) -> IoFuture<(backend::Message, InnerConnection)> { fn read(self) -> IoFuture<(backend::Message, InnerConnection)> {
self.into_future() self.into_future()
.map_err(|e| e.0) .map_err(|e| e.0)
.and_then(|(m, mut s)| { .and_then(|(m, mut s)| match m {
match m { Some(backend::Message::NotificationResponse(body)) => {
Some(backend::Message::NotificationResponse(body)) => { let process_id = body.process_id();
let process_id = body.process_id(); let channel = match body.channel() {
let channel = match body.channel() { Ok(channel) => channel.to_owned(),
Ok(channel) => channel.to_owned(), Err(e) => return Either::A(Err(e).into_future()),
Err(e) => return Either::A(Err(e).into_future()), };
}; let message = match body.message() {
let message = match body.message() { Ok(channel) => channel.to_owned(),
Ok(channel) => channel.to_owned(), Err(e) => return Either::A(Err(e).into_future()),
Err(e) => return Either::A(Err(e).into_future()), };
}; let notification = Notification {
let notification = Notification { process_id: process_id,
process_id: process_id, channel: channel,
channel: channel, payload: message,
payload: message, };
}; s.notifications.push_back(notification);
s.notifications.push_back(notification); Either::B(s.read())
Either::B(s.read()) }
} Some(m) => Either::A(Ok((m, s)).into_future()),
Some(m) => Either::A(Ok((m, s)).into_future()), None => {
None => { let err = io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF");
let err = io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF"); Either::A(Err(err).into_future())
Either::A(Err(err).into_future())
}
} }
}) })
.boxed() .boxed()
@ -247,8 +253,7 @@ pub struct Connection(InnerConnection);
// FIXME fill out // FIXME fill out
impl fmt::Debug for Connection { impl fmt::Debug for Connection {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("Connection") fmt.debug_struct("Connection").finish()
.finish()
} }
} }
@ -271,48 +276,54 @@ impl Connection {
/// path contains non-UTF 8 characters, a `ConnectParams` struct should be /// path contains non-UTF 8 characters, a `ConnectParams` struct should be
/// created manually and passed in. Note that Postgres does not support TLS /// created manually and passed in. Note that Postgres does not support TLS
/// over Unix sockets. /// over Unix sockets.
pub fn connect<T>(params: T, pub fn connect<T>(
tls_mode: TlsMode, params: T,
handle: &Handle) tls_mode: TlsMode,
-> BoxFuture<Connection, ConnectError> handle: &Handle,
where T: IntoConnectParams ) -> BoxFuture<Connection, ConnectError>
where
T: IntoConnectParams,
{ {
let fut = match params.into_connect_params() { let fut = match params.into_connect_params() {
Ok(params) => { Ok(params) => {
Either::A(stream::connect(params.host().clone(), params.port(), tls_mode, handle) Either::A(
.map(|s| (s, params))) stream::connect(params.host().clone(), params.port(), tls_mode, handle)
.map(|s| (s, params)),
)
} }
Err(e) => Either::B(Err(ConnectError::ConnectParams(e)).into_future()), Err(e) => Either::B(Err(ConnectError::ConnectParams(e)).into_future()),
}; };
fut.map(|(s, params)| { fut.map(|(s, params)| {
let (sender, receiver) = mpsc::channel(); let (sender, receiver) = mpsc::channel();
(Connection(InnerConnection { (
stream: s, Connection(InnerConnection {
close_sender: sender, stream: s,
close_receiver: receiver, close_sender: sender,
parameters: HashMap::new(), close_receiver: receiver,
types: HashMap::new(), parameters: HashMap::new(),
notifications: VecDeque::new(), types: HashMap::new(),
cancel_data: CancelData { notifications: VecDeque::new(),
process_id: 0, cancel_data: CancelData {
secret_key: 0, process_id: 0,
}, secret_key: 0,
has_typeinfo_query: false, },
has_typeinfo_enum_query: false, has_typeinfo_query: false,
has_typeinfo_composite_query: false, has_typeinfo_enum_query: false,
}), has_typeinfo_composite_query: false,
params) }),
}) params,
.and_then(|(s, params)| s.startup(params)) )
}).and_then(|(s, params)| s.startup(params))
.and_then(|(s, params)| s.handle_auth(params)) .and_then(|(s, params)| s.handle_auth(params))
.and_then(|s| s.finish_startup()) .and_then(|s| s.finish_startup())
.boxed() .boxed()
} }
fn startup(self, fn startup(
params: ConnectParams) self,
-> BoxFuture<(Connection, ConnectParams), ConnectError> { params: ConnectParams,
) -> BoxFuture<(Connection, ConnectParams), ConnectError> {
let mut buf = vec![]; let mut buf = vec![];
let result = { let result = {
let options = [("client_encoding", "UTF8"), ("timezone", "GMT")]; let options = [("client_encoding", "UTF8"), ("timezone", "GMT")];
@ -348,27 +359,35 @@ impl Connection {
.map_err(Into::into) .map_err(Into::into)
} }
None => { None => {
Err(ConnectError::ConnectParams("password was required but not \ Err(ConnectError::ConnectParams(
"password was required but not \
provided" provided"
.into())) .into(),
))
} }
} }
} }
backend::Message::AuthenticationMd5Password(body) => { backend::Message::AuthenticationMd5Password(body) => {
match params.user().and_then(|u| u.password().map(|p| (u.name(), p))) { match params.user().and_then(
|u| u.password().map(|p| (u.name(), p)),
) {
Some((user, pass)) => { Some((user, pass)) => {
let pass = authentication::md5_hash(user.as_bytes(), let pass = authentication::md5_hash(
pass.as_bytes(), user.as_bytes(),
body.salt()); pass.as_bytes(),
body.salt(),
);
let mut buf = vec![]; let mut buf = vec![];
frontend::password_message(&pass, &mut buf) frontend::password_message(&pass, &mut buf)
.map(|()| Some(buf)) .map(|()| Some(buf))
.map_err(Into::into) .map_err(Into::into)
} }
None => { None => {
Err(ConnectError::ConnectParams("password was required but not \ Err(ConnectError::ConnectParams(
"password was required but not \
provided" provided"
.into())) .into(),
))
} }
} }
} }
@ -428,9 +447,10 @@ impl Connection {
} }
// This has its own read_rows since it will need to handle multiple query completions // This has its own read_rows since it will need to handle multiple query completions
fn simple_read_rows(self, fn simple_read_rows(
mut rows: Vec<RowData>) self,
-> BoxFuture<(Vec<RowData>, Connection), Error> { mut rows: Vec<RowData>,
) -> BoxFuture<(Vec<RowData>, Connection), Error> {
self.0 self.0
.read() .read()
.map_err(Error::Io) .map_err(Error::Io)
@ -457,7 +477,8 @@ impl Connection {
} }
fn ready<T>(self, t: T) -> BoxFuture<(T, Connection), Error> fn ready<T>(self, t: T) -> BoxFuture<(T, Connection), Error>
where T: 'static + Send where
T: 'static + Send,
{ {
self.0 self.0
.read() .read()
@ -485,7 +506,9 @@ impl Connection {
frontend::sync(&mut buf); frontend::sync(&mut buf);
messages.push(buf); messages.push(buf);
self.0 self.0
.send_all(futures::stream::iter(messages.into_iter().map(Ok::<_, io::Error>))) .send_all(futures::stream::iter(
messages.into_iter().map(Ok::<_, io::Error>),
))
.map_err(Error::Io) .map_err(Error::Io)
.and_then(|s| Connection(s.0).finish_close_gc()) .and_then(|s| Connection(s.0).finish_close_gc())
.boxed() .boxed()
@ -505,7 +528,8 @@ impl Connection {
} }
fn ready_err<T>(self, body: ErrorResponseBody) -> BoxFuture<T, Error> fn ready_err<T>(self, body: ErrorResponseBody) -> BoxFuture<T, Error>
where T: 'static + Send where
T: 'static + Send,
{ {
DbError::new(&mut body.fields()) DbError::new(&mut body.fields())
.map_err(Error::Io) .map_err(Error::Io)
@ -529,15 +553,14 @@ impl Connection {
/// data in the statement. Do not form statements via string concatenation /// data in the statement. Do not form statements via string concatenation
/// and feed them into this method. /// and feed them into this method.
pub fn batch_execute(self, query: &str) -> BoxFuture<Connection, Error> { pub fn batch_execute(self, query: &str) -> BoxFuture<Connection, Error> {
self.simple_query(query) self.simple_query(query).map(|r| r.1).boxed()
.map(|r| r.1)
.boxed()
} }
fn raw_prepare(self, fn raw_prepare(
name: &str, self,
query: &str) name: &str,
-> BoxFuture<(Vec<Type>, Vec<Column>, Connection), Error> { query: &str,
) -> BoxFuture<(Vec<Type>, Vec<Column>, Connection), Error> {
let mut parse = vec![]; let mut parse = vec![];
let mut describe = vec![]; let mut describe = vec![];
let mut sync = vec![]; let mut sync = vec![];
@ -605,17 +628,19 @@ impl Connection {
.boxed() .boxed()
} }
fn get_types<T, U, I, F, G>(self, fn get_types<T, U, I, F, G>(
mut raw: I, self,
mut out: Vec<U>, mut raw: I,
mut get_oid: F, mut out: Vec<U>,
mut build: G) mut get_oid: F,
-> BoxFuture<(Vec<U>, Connection), Error> mut build: G,
where T: 'static + Send, ) -> BoxFuture<(Vec<U>, Connection), Error>
U: 'static + Send, where
I: 'static + Send + Iterator<Item = T>, T: 'static + Send,
F: 'static + Send + FnMut(&T) -> Oid, U: 'static + Send,
G: 'static + Send + FnMut(T, Type) -> U I: 'static + Send + Iterator<Item = T>,
F: 'static + Send + FnMut(&T) -> Oid,
G: 'static + Send + FnMut(T, Type) -> U,
{ {
match raw.next() { match raw.next() {
Some(v) => { Some(v) => {
@ -651,7 +676,9 @@ impl Connection {
fn get_unknown_type(self, oid: Oid) -> BoxFuture<(Other, Connection), Error> { fn get_unknown_type(self, oid: Oid) -> BoxFuture<(Other, Connection), Error> {
self.setup_typeinfo_query() self.setup_typeinfo_query()
.and_then(move |c| c.raw_execute(TYPEINFO_QUERY, "", &[Type::Oid], &[&oid])) .and_then(move |c| {
c.raw_execute(TYPEINFO_QUERY, "", &[Type::Oid], &[&oid])
})
.and_then(|c| c.read_rows().collect()) .and_then(|c| c.read_rows().collect())
.and_then(move |(r, c)| { .and_then(move |(r, c)| {
let get = |idx| r.get(0).and_then(|r| r.get(idx)); let get = |idx| r.get(0).and_then(|r| r.get(idx));
@ -688,22 +715,42 @@ impl Connection {
let kind = if type_ == b'p' as i8 { let kind = if type_ == b'p' as i8 {
Either::A(Ok((Kind::Pseudo, c)).into_future()) Either::A(Ok((Kind::Pseudo, c)).into_future())
} else if type_ == b'e' as i8 { } else if type_ == b'e' as i8 {
Either::B(c.get_enum_variants(oid).map(|(v, c)| (Kind::Enum(v), c)).boxed()) Either::B(
c.get_enum_variants(oid)
.map(|(v, c)| (Kind::Enum(v), c))
.boxed(),
)
} else if basetype != 0 { } else if basetype != 0 {
Either::B(c.get_type(basetype).map(|(t, c)| (Kind::Domain(t), c)).boxed()) Either::B(
c.get_type(basetype)
.map(|(t, c)| (Kind::Domain(t), c))
.boxed(),
)
} else if elem_oid != 0 { } else if elem_oid != 0 {
Either::B(c.get_type(elem_oid).map(|(t, c)| (Kind::Array(t), c)).boxed()) Either::B(
c.get_type(elem_oid)
.map(|(t, c)| (Kind::Array(t), c))
.boxed(),
)
} else if relid != 0 { } else if relid != 0 {
Either::B(c.get_composite_fields(relid) Either::B(
.map(|(f, c)| (Kind::Composite(f), c)) c.get_composite_fields(relid)
.boxed()) .map(|(f, c)| (Kind::Composite(f), c))
.boxed(),
)
} else if let Some(rngsubtype) = rngsubtype { } else if let Some(rngsubtype) = rngsubtype {
Either::B(c.get_type(rngsubtype).map(|(t, c)| (Kind::Range(t), c)).boxed()) Either::B(
c.get_type(rngsubtype)
.map(|(t, c)| (Kind::Range(t), c))
.boxed(),
)
} else { } else {
Either::A(Ok((Kind::Simple, c)).into_future()) Either::A(Ok((Kind::Simple, c)).into_future())
}; };
Either::B(kind.map(move |(k, c)| (Other::new(name, oid, k, schema), c))) Either::B(kind.map(
move |(k, c)| (Other::new(name, oid, k, schema), c),
))
}) })
.boxed() .boxed()
} }
@ -713,16 +760,17 @@ impl Connection {
return Ok(self).into_future().boxed(); return Ok(self).into_future().boxed();
} }
self.raw_prepare(TYPEINFO_QUERY, self.raw_prepare(
"SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, \ TYPEINFO_QUERY,
"SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, \
t.typbasetype, n.nspname, t.typrelid \ t.typbasetype, n.nspname, t.typrelid \
FROM pg_catalog.pg_type t \ FROM pg_catalog.pg_type t \
LEFT OUTER JOIN pg_catalog.pg_range r ON \ LEFT OUTER JOIN pg_catalog.pg_range r ON \
r.rngtypid = t.oid \ r.rngtypid = t.oid \
INNER JOIN pg_catalog.pg_namespace n ON \ INNER JOIN pg_catalog.pg_namespace n ON \
t.typnamespace = n.oid \ t.typnamespace = n.oid \
WHERE t.oid = $1") WHERE t.oid = $1",
.or_else(|e| { ).or_else(|e| {
match e { match e {
// Range types weren't added until Postgres 9.2, so pg_range may not exist // Range types weren't added until Postgres 9.2, so pg_range may not exist
Error::Db(e, c) => { Error::Db(e, c) => {
@ -730,14 +778,16 @@ impl Connection {
return Either::B(Err(Error::Db(e, c)).into_future()); return Either::B(Err(Error::Db(e, c)).into_future());
} }
Either::A(c.raw_prepare(TYPEINFO_QUERY, Either::A(c.raw_prepare(
"SELECT t.typname, t.typtype, t.typelem, \ TYPEINFO_QUERY,
"SELECT t.typname, t.typtype, t.typelem, \
NULL::OID, t.typbasetype, n.nspname, \ NULL::OID, t.typbasetype, n.nspname, \
t.typrelid \ t.typrelid \
FROM pg_catalog.pg_type t \ FROM pg_catalog.pg_type t \
INNER JOIN pg_catalog.pg_namespace n \ INNER JOIN pg_catalog.pg_namespace n \
ON t.typnamespace = n.oid \ ON t.typnamespace = n.oid \
WHERE t.oid = $1")) WHERE t.oid = $1",
))
} }
e => Either::B(Err(e).into_future()), e => Either::B(Err(e).into_future()),
} }
@ -751,7 +801,9 @@ impl Connection {
fn get_enum_variants(self, oid: Oid) -> BoxFuture<(Vec<String>, Connection), Error> { fn get_enum_variants(self, oid: Oid) -> BoxFuture<(Vec<String>, Connection), Error> {
self.setup_typeinfo_enum_query() self.setup_typeinfo_enum_query()
.and_then(move |c| c.raw_execute(TYPEINFO_ENUM_QUERY, "", &[Type::Oid], &[&oid])) .and_then(move |c| {
c.raw_execute(TYPEINFO_ENUM_QUERY, "", &[Type::Oid], &[&oid])
})
.and_then(|c| c.read_rows().collect()) .and_then(|c| c.read_rows().collect())
.and_then(|(r, c)| { .and_then(|(r, c)| {
let mut variants = vec![]; let mut variants = vec![];
@ -772,20 +824,23 @@ impl Connection {
return Ok(self).into_future().boxed(); return Ok(self).into_future().boxed();
} }
self.raw_prepare(TYPEINFO_ENUM_QUERY, self.raw_prepare(
"SELECT enumlabel \ TYPEINFO_ENUM_QUERY,
"SELECT enumlabel \
FROM pg_catalog.pg_enum \ FROM pg_catalog.pg_enum \
WHERE enumtypid = $1 \ WHERE enumtypid = $1 \
ORDER BY enumsortorder") ORDER BY enumsortorder",
.or_else(|e| match e { ).or_else(|e| match e {
Error::Db(e, c) => { Error::Db(e, c) => {
if e.code != SqlState::UndefinedColumn { if e.code != SqlState::UndefinedColumn {
return Either::B(Err(Error::Db(e, c)).into_future()); return Either::B(Err(Error::Db(e, c)).into_future());
} }
Either::A(c.raw_prepare(TYPEINFO_ENUM_QUERY, Either::A(c.raw_prepare(
"SELECT enumlabel FROM pg_catalog.pg_enum WHERE \ TYPEINFO_ENUM_QUERY,
enumtypid = $1 ORDER BY oid")) "SELECT enumlabel FROM pg_catalog.pg_enum WHERE \
enumtypid = $1 ORDER BY oid",
))
} }
e => Either::B(Err(e).into_future()), e => Either::B(Err(e).into_future()),
}) })
@ -803,8 +858,9 @@ impl Connection {
}) })
.and_then(|c| c.read_rows().collect()) .and_then(|c| c.read_rows().collect())
.and_then(|(r, c)| { .and_then(|(r, c)| {
futures::stream::iter(r.into_iter().map(Ok)) futures::stream::iter(r.into_iter().map(Ok)).fold(
.fold((vec![], c), |(mut fields, c), row| { (vec![], c),
|(mut fields, c), row| {
let name = match String::from_sql_nullable(&Type::Name, row.get(0)) { let name = match String::from_sql_nullable(&Type::Name, row.get(0)) {
Ok(name) => name, Ok(name) => name,
Err(e) => return Either::A(Err(Error::Conversion(e, c)).into_future()), Err(e) => return Either::A(Err(Error::Conversion(e, c)).into_future()),
@ -813,12 +869,12 @@ impl Connection {
Ok(oid) => oid, Ok(oid) => oid,
Err(e) => return Either::A(Err(Error::Conversion(e, c)).into_future()), Err(e) => return Either::A(Err(Error::Conversion(e, c)).into_future()),
}; };
Either::B(c.get_type(oid) Either::B(c.get_type(oid).map(move |(ty, c)| {
.map(move |(ty, c)| { fields.push(Field::new(name, ty));
fields.push(Field::new(name, ty)); (fields, c)
(fields, c) }))
})) },
}) )
}) })
.boxed() .boxed()
} }
@ -828,46 +884,52 @@ impl Connection {
return Ok(self).into_future().boxed(); return Ok(self).into_future().boxed();
} }
self.raw_prepare(TYPEINFO_COMPOSITE_QUERY, self.raw_prepare(
"SELECT attname, atttypid \ TYPEINFO_COMPOSITE_QUERY,
"SELECT attname, atttypid \
FROM pg_catalog.pg_attribute \ FROM pg_catalog.pg_attribute \
WHERE attrelid = $1 \ WHERE attrelid = $1 \
AND NOT attisdropped \ AND NOT attisdropped \
AND attnum > 0 \ AND attnum > 0 \
ORDER BY attnum") ORDER BY attnum",
.map(|(_, _, mut c)| { ).map(|(_, _, mut c)| {
c.0.has_typeinfo_composite_query = true; c.0.has_typeinfo_composite_query = true;
c c
}) })
.boxed() .boxed()
} }
fn raw_execute(self, fn raw_execute(
stmt: &str, self,
portal: &str, stmt: &str,
param_types: &[Type], portal: &str,
params: &[&ToSql]) param_types: &[Type],
-> BoxFuture<Connection, Error> { params: &[&ToSql],
assert!(param_types.len() == params.len(), ) -> BoxFuture<Connection, Error> {
"expected {} parameters but got {}", assert!(
param_types.len(), param_types.len() == params.len(),
params.len()); "expected {} parameters but got {}",
param_types.len(),
params.len()
);
let mut bind = vec![]; let mut bind = vec![];
let mut execute = vec![]; let mut execute = vec![];
let mut sync = vec![]; let mut sync = vec![];
frontend::sync(&mut sync); frontend::sync(&mut sync);
let r = frontend::bind(portal, let r = frontend::bind(
stmt, portal,
Some(1), stmt,
params.iter().zip(param_types), Some(1),
|(param, ty), buf| match param.to_sql_checked(ty, buf) { params.iter().zip(param_types),
Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes), |(param, ty), buf| match param.to_sql_checked(ty, buf) {
Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No), Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
Err(e) => Err(e), Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
}, Err(e) => Err(e),
Some(1), },
&mut bind); Some(1),
&mut bind,
);
let r = match r { let r = match r {
Ok(()) => Ok(self), Ok(()) => Ok(self),
Err(frontend::BindError::Conversion(e)) => Err(Error::Conversion(e, self)), Err(frontend::BindError::Conversion(e)) => Err(Error::Conversion(e, self)),
@ -875,11 +937,10 @@ impl Connection {
}; };
r.and_then(|s| { r.and_then(|s| {
frontend::execute(portal, 0, &mut execute) frontend::execute(portal, 0, &mut execute)
.map(|()| s) .map(|()| s)
.map_err(Error::Io) .map_err(Error::Io)
}) }).into_future()
.into_future()
.and_then(|s| { .and_then(|s| {
let it = Some(bind) let it = Some(bind)
.into_iter() .into_iter()
@ -901,43 +962,35 @@ impl Connection {
self.0 self.0
.read() .read()
.map_err(Error::Io) .map_err(Error::Io)
.and_then(|(m, s)| { .and_then(|(m, s)| match m {
match m { backend::Message::DataRow(_) => Connection(s).finish_execute().boxed(),
backend::Message::DataRow(_) => Connection(s).finish_execute().boxed(), backend::Message::CommandComplete(body) => {
backend::Message::CommandComplete(body) => { body.tag()
body.tag() .map(|tag| {
.map(|tag| { tag.split_whitespace().last().unwrap().parse().unwrap_or(0)
tag.split_whitespace() })
.last() .map_err(Error::Io)
.unwrap() .into_future()
.parse() .and_then(|n| Connection(s).ready(n))
.unwrap_or(0) .boxed()
})
.map_err(Error::Io)
.into_future()
.and_then(|n| Connection(s).ready(n))
.boxed()
}
backend::Message::EmptyQueryResponse => Connection(s).ready(0).boxed(),
backend::Message::ErrorResponse(body) => Connection(s).ready_err(body).boxed(),
_ => Err(bad_message()).into_future().boxed(),
} }
backend::Message::EmptyQueryResponse => Connection(s).ready(0).boxed(),
backend::Message::ErrorResponse(body) => Connection(s).ready_err(body).boxed(),
_ => Err(bad_message()).into_future().boxed(),
}) })
.boxed() .boxed()
} }
fn read_rows(self) -> BoxStateStream<RowData, Connection, Error> { fn read_rows(self) -> BoxStateStream<RowData, Connection, Error> {
futures_state_stream::unfold(self, |c| { futures_state_stream::unfold(self, |c| {
c.read_row() c.read_row().and_then(|(r, c)| match r {
.and_then(|(r, c)| match r { Some(data) => {
Some(data) => { let event = StreamEvent::Next((data, c));
let event = StreamEvent::Next((data, c)); Either::A(Ok(event).into_future())
Either::A(Ok(event).into_future()) }
} None => Either::B(c.ready(()).map(|((), c)| StreamEvent::Done(c))),
None => Either::B(c.ready(()).map(|((), c)| StreamEvent::Done(c))),
})
}) })
.boxed() }).boxed()
} }
fn read_row(self) -> BoxFuture<(Option<RowData>, Connection), Error> { fn read_row(self) -> BoxFuture<(Option<RowData>, Connection), Error> {
@ -948,10 +1001,12 @@ impl Connection {
let c = Connection(s); let c = Connection(s);
match m { match m {
backend::Message::DataRow(body) => { backend::Message::DataRow(body) => {
Either::A(RowData::new(body) Either::A(
.map(|r| (Some(r), c)) RowData::new(body)
.map_err(Error::Io) .map(|r| (Some(r), c))
.into_future()) .map_err(Error::Io)
.into_future(),
)
} }
backend::Message::EmptyQueryResponse | backend::Message::EmptyQueryResponse |
backend::Message::CommandComplete(_) => Either::A(Ok((None, c)).into_future()), backend::Message::CommandComplete(_) => Either::A(Ok((None, c)).into_future()),
@ -981,10 +1036,11 @@ impl Connection {
/// ///
/// Panics if the number of parameters provided does not match the number /// Panics if the number of parameters provided does not match the number
/// expected. /// expected.
pub fn execute(self, pub fn execute(
statement: &Statement, self,
params: &[&ToSql]) statement: &Statement,
-> BoxFuture<(u64, Connection), Error> { params: &[&ToSql],
) -> BoxFuture<(u64, Connection), Error> {
self.raw_execute(statement.name(), "", statement.parameters(), params) self.raw_execute(statement.name(), "", statement.parameters(), params)
.and_then(|conn| conn.finish_execute()) .and_then(|conn| conn.finish_execute())
.boxed() .boxed()
@ -996,10 +1052,11 @@ impl Connection {
/// ///
/// Panics if the number of parameters provided does not match the number /// Panics if the number of parameters provided does not match the number
/// expected. /// expected.
pub fn query(self, pub fn query(
statement: &Statement, self,
params: &[&ToSql]) statement: &Statement,
-> BoxStateStream<Row, Connection, Error> { params: &[&ToSql],
) -> BoxStateStream<Row, Connection, Error> {
let columns = statement.columns_arc().clone(); let columns = statement.columns_arc().clone();
self.raw_execute(statement.name(), "", statement.parameters(), params) self.raw_execute(statement.name(), "", statement.parameters(), params)
.map(|c| c.read_rows().map(move |r| Row::new(columns.clone(), r))) .map(|c| c.read_rows().map(move |r| Row::new(columns.clone(), r)))
@ -1077,7 +1134,8 @@ fn connect_err(fields: &mut ErrorFields) -> ConnectError {
} }
fn bad_message<T>() -> T fn bad_message<T>() -> T
where T: From<io::Error> where
T: From<io::Error>,
{ {
io::Error::new(io::ErrorKind::InvalidInput, "unexpected message").into() io::Error::new(io::ErrorKind::InvalidInput, "unexpected message").into()
} }
@ -1087,11 +1145,12 @@ trait RowNew {
} }
trait StatementNew { trait StatementNew {
fn new(close_sender: Sender<(u8, String)>, fn new(
name: String, close_sender: Sender<(u8, String)>,
params: Vec<Type>, name: String,
columns: Arc<Vec<Column>>) params: Vec<Type>,
-> Statement; columns: Arc<Vec<Column>>,
) -> Statement;
fn columns_arc(&self) -> &Arc<Vec<Column>>; fn columns_arc(&self) -> &Arc<Vec<Column>>;

View File

@ -48,8 +48,9 @@ impl Row {
/// Panics if the index does not reference a column or the return type is /// Panics if the index does not reference a column or the return type is
/// not compatible with the Postgres type. /// not compatible with the Postgres type.
pub fn get<T, I>(&self, idx: I) -> T pub fn get<T, I>(&self, idx: I) -> T
where T: FromSql, where
I: RowIndex + fmt::Debug T: FromSql,
I: RowIndex + fmt::Debug,
{ {
match self.try_get(&idx) { match self.try_get(&idx) {
Ok(Some(v)) => v, Ok(Some(v)) => v,
@ -67,8 +68,9 @@ impl Row {
/// if there was an error converting the result value, and `Some(Ok(..))` /// if there was an error converting the result value, and `Some(Ok(..))`
/// on success. /// on success.
pub fn try_get<T, I>(&self, idx: I) -> Result<Option<T>, Box<Error + Sync + Send>> pub fn try_get<T, I>(&self, idx: I) -> Result<Option<T>, Box<Error + Sync + Send>>
where T: FromSql, where
I: RowIndex T: FromSql,
I: RowIndex,
{ {
let idx = match idx.idx(&self.columns) { let idx = match idx.idx(&self.columns) {
Some(idx) => idx, Some(idx) => idx,

View File

@ -19,11 +19,12 @@ pub struct Statement {
} }
impl StatementNew for Statement { impl StatementNew for Statement {
fn new(close_sender: Sender<(u8, String)>, fn new(
name: String, close_sender: Sender<(u8, String)>,
params: Vec<Type>, name: String,
columns: Arc<Vec<Column>>) params: Vec<Type>,
-> Statement { columns: Arc<Vec<Column>>,
) -> Statement {
Statement { Statement {
close_sender: close_sender, close_sender: close_sender,
name: name, name: name,

View File

@ -20,31 +20,39 @@ use tls::TlsStream;
pub type PostgresStream = Framed<Box<TlsStream>, PostgresCodec>; pub type PostgresStream = Framed<Box<TlsStream>, PostgresCodec>;
pub fn connect(host: Host, pub fn connect(
port: u16, host: Host,
tls_mode: TlsMode, port: u16,
handle: &Handle) tls_mode: TlsMode,
-> BoxFuture<PostgresStream, ConnectError> { handle: &Handle,
) -> BoxFuture<PostgresStream, ConnectError> {
let inner = match host { let inner = match host {
Host::Tcp(ref host) => { Host::Tcp(ref host) => {
Either::A(tokio_dns::tcp_connect((&**host, port), handle.remote().clone()) Either::A(
.map(|s| Stream(InnerStream::Tcp(s))) tokio_dns::tcp_connect((&**host, port), handle.remote().clone())
.map_err(ConnectError::Io)) .map(|s| Stream(InnerStream::Tcp(s)))
.map_err(ConnectError::Io),
)
} }
#[cfg(unix)] #[cfg(unix)]
Host::Unix(ref host) => { Host::Unix(ref host) => {
let addr = host.join(format!(".s.PGSQL.{}", port)); let addr = host.join(format!(".s.PGSQL.{}", port));
Either::B(UnixStream::connect(addr, handle) Either::B(
.map(|s| Stream(InnerStream::Unix(s))) UnixStream::connect(addr, handle)
.map_err(ConnectError::Io) .map(|s| Stream(InnerStream::Unix(s)))
.into_future()) .map_err(ConnectError::Io)
.into_future(),
)
} }
#[cfg(not(unix))] #[cfg(not(unix))]
Host::Unix(_) => { Host::Unix(_) => {
Either::B(Err(ConnectError::ConnectParams("unix sockets are not supported on this \ Either::B(
Err(ConnectError::ConnectParams(
"unix sockets are not supported on this \
platform" platform"
.into())) .into(),
.into_future()) )).into_future(),
)
} }
}; };
@ -52,7 +60,8 @@ pub fn connect(host: Host,
TlsMode::Require(h) => (true, h), TlsMode::Require(h) => (true, h),
TlsMode::Prefer(h) => (false, h), TlsMode::Prefer(h) => (false, h),
TlsMode::None => { TlsMode::None => {
return inner.map(|s| { return inner
.map(|s| {
let s: Box<TlsStream> = Box::new(s); let s: Box<TlsStream> = Box::new(s);
s.framed(PostgresCodec) s.framed(PostgresCodec)
}) })
@ -60,29 +69,34 @@ pub fn connect(host: Host,
} }
}; };
inner.map(|s| s.framed(SslCodec)) inner
.map(|s| s.framed(SslCodec))
.and_then(|s| { .and_then(|s| {
let mut buf = vec![]; let mut buf = vec![];
frontend::ssl_request(&mut buf); frontend::ssl_request(&mut buf);
s.send(buf) s.send(buf).map_err(ConnectError::Io)
.map_err(ConnectError::Io)
}) })
.and_then(|s| s.into_future().map_err(|e| ConnectError::Io(e.0))) .and_then(|s| s.into_future().map_err(|e| ConnectError::Io(e.0)))
.and_then(move |(m, s)| { .and_then(move |(m, s)| {
let s = s.into_inner(); let s = s.into_inner();
match (m, required) { match (m, required) {
(Some(b'N'), true) => { (Some(b'N'), true) => {
Either::A(Err(ConnectError::Tls("the server does not support TLS".into())) Either::A(
.into_future()) Err(ConnectError::Tls("the server does not support TLS".into()))
.into_future(),
)
} }
(Some(b'N'), false) => { (Some(b'N'), false) => {
let s: Box<TlsStream> = Box::new(s); let s: Box<TlsStream> = Box::new(s);
Either::A(Ok(s).into_future()) Either::A(Ok(s).into_future())
} }
(None, _) => { (None, _) => {
Either::A(Err(ConnectError::Io(io::Error::new(io::ErrorKind::UnexpectedEof, Either::A(
"unexpected EOF"))) Err(ConnectError::Io(io::Error::new(
.into_future()) io::ErrorKind::UnexpectedEof,
"unexpected EOF",
))).into_future(),
)
} }
_ => { _ => {
let host = match host { let host = match host {
@ -144,7 +158,8 @@ impl AsyncRead for Stream {
} }
fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error> fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
where B: BufMut where
B: BufMut,
{ {
match self.0 { match self.0 {
InnerStream::Tcp(ref mut s) => s.read_buf(buf), InnerStream::Tcp(ref mut s) => s.read_buf(buf),

View File

@ -14,9 +14,11 @@ use types::{ToSql, FromSql, Type, IsNull, Kind};
fn md5_user() { fn md5_user() {
let mut l = Core::new().unwrap(); let mut l = Core::new().unwrap();
let handle = l.handle(); let handle = l.handle();
let done = Connection::connect("postgres://md5_user:password@localhost/postgres", let done = Connection::connect(
TlsMode::None, "postgres://md5_user:password@localhost/postgres",
&handle); TlsMode::None,
&handle,
);
l.run(done).unwrap(); l.run(done).unwrap();
} }
@ -24,9 +26,11 @@ fn md5_user() {
fn md5_user_no_pass() { fn md5_user_no_pass() {
let mut l = Core::new().unwrap(); let mut l = Core::new().unwrap();
let handle = l.handle(); let handle = l.handle();
let done = Connection::connect("postgres://md5_user@localhost/postgres", let done = Connection::connect(
TlsMode::None, "postgres://md5_user@localhost/postgres",
&handle); TlsMode::None,
&handle,
);
match l.run(done) { match l.run(done) {
Err(ConnectError::ConnectParams(_)) => {} Err(ConnectError::ConnectParams(_)) => {}
Err(e) => panic!("unexpected error {}", e), Err(e) => panic!("unexpected error {}", e),
@ -38,9 +42,11 @@ fn md5_user_no_pass() {
fn md5_user_wrong_pass() { fn md5_user_wrong_pass() {
let mut l = Core::new().unwrap(); let mut l = Core::new().unwrap();
let handle = l.handle(); let handle = l.handle();
let done = Connection::connect("postgres://md5_user:foobar@localhost/postgres", let done = Connection::connect(
TlsMode::None, "postgres://md5_user:foobar@localhost/postgres",
&handle); TlsMode::None,
&handle,
);
match l.run(done) { match l.run(done) {
Err(ConnectError::Db(ref e)) if e.code == SqlState::InvalidPassword => {} Err(ConnectError::Db(ref e)) if e.code == SqlState::InvalidPassword => {}
Err(e) => panic!("unexpected error {}", e), Err(e) => panic!("unexpected error {}", e),
@ -52,9 +58,11 @@ fn md5_user_wrong_pass() {
fn pass_user() { fn pass_user() {
let mut l = Core::new().unwrap(); let mut l = Core::new().unwrap();
let handle = l.handle(); let handle = l.handle();
let done = Connection::connect("postgres://pass_user:password@localhost/postgres", let done = Connection::connect(
TlsMode::None, "postgres://pass_user:password@localhost/postgres",
&handle); TlsMode::None,
&handle,
);
l.run(done).unwrap(); l.run(done).unwrap();
} }
@ -62,9 +70,11 @@ fn pass_user() {
fn pass_user_no_pass() { fn pass_user_no_pass() {
let mut l = Core::new().unwrap(); let mut l = Core::new().unwrap();
let handle = l.handle(); let handle = l.handle();
let done = Connection::connect("postgres://pass_user@localhost/postgres", let done = Connection::connect(
TlsMode::None, "postgres://pass_user@localhost/postgres",
&handle); TlsMode::None,
&handle,
);
match l.run(done) { match l.run(done) {
Err(ConnectError::ConnectParams(_)) => {} Err(ConnectError::ConnectParams(_)) => {}
Err(e) => panic!("unexpected error {}", e), Err(e) => panic!("unexpected error {}", e),
@ -76,9 +86,11 @@ fn pass_user_no_pass() {
fn pass_user_wrong_pass() { fn pass_user_wrong_pass() {
let mut l = Core::new().unwrap(); let mut l = Core::new().unwrap();
let handle = l.handle(); let handle = l.handle();
let done = Connection::connect("postgres://pass_user:foobar@localhost/postgres", let done = Connection::connect(
TlsMode::None, "postgres://pass_user:foobar@localhost/postgres",
&handle); TlsMode::None,
&handle,
);
match l.run(done) { match l.run(done) {
Err(ConnectError::Db(ref e)) if e.code == SqlState::InvalidPassword => {} Err(ConnectError::Db(ref e)) if e.code == SqlState::InvalidPassword => {}
Err(e) => panic!("unexpected error {}", e), Err(e) => panic!("unexpected error {}", e),
@ -90,7 +102,11 @@ fn pass_user_wrong_pass() {
fn batch_execute_ok() { fn batch_execute_ok() {
let mut l = Core::new().unwrap(); let mut l = Core::new().unwrap();
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle()) let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle())
.then(|c| c.unwrap().batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL);")); .then(|c| {
c.unwrap().batch_execute(
"CREATE TEMPORARY TABLE foo (id SERIAL);",
)
});
l.run(done).unwrap(); l.run(done).unwrap();
} }
@ -99,9 +115,10 @@ fn batch_execute_err() {
let mut l = Core::new().unwrap(); let mut l = Core::new().unwrap();
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle()) let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle())
.then(|r| { .then(|r| {
r.unwrap() r.unwrap().batch_execute(
.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL); INSERT INTO foo DEFAULT \ "CREATE TEMPORARY TABLE foo (id SERIAL); INSERT INTO foo DEFAULT \
VALUES;") VALUES;",
)
}) })
.and_then(|c| c.batch_execute("SELECT * FROM bogo")) .and_then(|c| c.batch_execute("SELECT * FROM bogo"))
.then(|r| match r { .then(|r| match r {
@ -120,7 +137,9 @@ fn prepare_execute() {
let mut l = Core::new().unwrap(); let mut l = Core::new().unwrap();
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle()) let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle())
.then(|c| { .then(|c| {
c.unwrap().prepare("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY, name VARCHAR)") c.unwrap().prepare(
"CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY, name VARCHAR)",
)
}) })
.and_then(|(s, c)| c.execute(&s, &[])) .and_then(|(s, c)| c.execute(&s, &[]))
.and_then(|(n, c)| { .and_then(|(n, c)| {
@ -146,8 +165,10 @@ fn query() {
let mut l = Core::new().unwrap(); let mut l = Core::new().unwrap();
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle()) let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle())
.then(|c| { .then(|c| {
c.unwrap().batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name VARCHAR); c.unwrap().batch_execute(
INSERT INTO foo (name) VALUES ('joe'), ('bob')") "CREATE TEMPORARY TABLE foo (id SERIAL, name VARCHAR);
INSERT INTO foo (name) VALUES ('joe'), ('bob')",
)
}) })
.and_then(|c| c.prepare("SELECT id, name FROM foo ORDER BY id")) .and_then(|c| c.prepare("SELECT id, name FROM foo ORDER BY id"))
.and_then(|(s, c)| c.query(&s, &[]).collect()) .and_then(|(s, c)| c.query(&s, &[]).collect())
@ -166,23 +187,32 @@ fn query() {
#[test] #[test]
fn transaction() { fn transaction() {
let mut l = Core::new().unwrap(); let mut l = Core::new().unwrap();
let done = let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle())
Connection::connect("postgres://postgres@localhost", TlsMode::None, &l.handle()) .then(|c| {
.then(|c| { c.unwrap().batch_execute(
c.unwrap().batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name VARCHAR);") "CREATE TEMPORARY TABLE foo (id SERIAL, name VARCHAR);",
}) )
.then(|c| c.unwrap().transaction()) })
.then(|t| t.unwrap().batch_execute("INSERT INTO foo (name) VALUES ('joe');")) .then(|c| c.unwrap().transaction())
.then(|t| t.unwrap().rollback()) .then(|t| {
.then(|c| c.unwrap().transaction()) t.unwrap().batch_execute(
.then(|t| t.unwrap().batch_execute("INSERT INTO foo (name) VALUES ('bob');")) "INSERT INTO foo (name) VALUES ('joe');",
.then(|t| t.unwrap().commit()) )
.then(|c| c.unwrap().prepare("SELECT name FROM foo")) })
.and_then(|(s, c)| c.query(&s, &[]).collect()) .then(|t| t.unwrap().rollback())
.map(|(r, _)| { .then(|c| c.unwrap().transaction())
assert_eq!(r.len(), 1); .then(|t| {
assert_eq!(r[0].get::<String, _>("name"), "bob"); t.unwrap().batch_execute(
}); "INSERT INTO foo (name) VALUES ('bob');",
)
})
.then(|t| t.unwrap().commit())
.then(|c| c.unwrap().prepare("SELECT name FROM foo"))
.and_then(|(s, c)| c.query(&s, &[]).collect())
.map(|(r, _)| {
assert_eq!(r.len(), 1);
assert_eq!(r[0].get::<String, _>("name"), "bob");
});
l.run(done).unwrap(); l.run(done).unwrap();
} }
@ -195,9 +225,11 @@ fn unix_socket() {
.and_then(|(s, c)| c.query(&s, &[]).collect()) .and_then(|(s, c)| c.query(&s, &[]).collect())
.then(|r| { .then(|r| {
let r = r.unwrap().0; let r = r.unwrap().0;
let params = ConnectParams::builder() let params = ConnectParams::builder().user("postgres", None).build(
.user("postgres", None) Host::Unix(
.build(Host::Unix(PathBuf::from(r[0].get::<String, _>(0)))); PathBuf::from(r[0].get::<String, _>(0)),
),
);
Connection::connect(params, TlsMode::None, &handle) Connection::connect(params, TlsMode::None, &handle)
}) })
.then(|c| c.unwrap().batch_execute("")); .then(|c| c.unwrap().batch_execute(""));
@ -209,9 +241,11 @@ fn ssl_user_ssl_required() {
let mut l = Core::new().unwrap(); let mut l = Core::new().unwrap();
let handle = l.handle(); let handle = l.handle();
let done = Connection::connect("postgres://ssl_user@localhost/postgres", let done = Connection::connect(
TlsMode::None, "postgres://ssl_user@localhost/postgres",
&handle); TlsMode::None,
&handle,
);
match l.run(done) { match l.run(done) {
Err(ConnectError::Db(e)) => assert!(e.code == SqlState::InvalidAuthorizationSpecification), Err(ConnectError::Db(e)) => assert!(e.code == SqlState::InvalidAuthorizationSpecification),
@ -227,14 +261,18 @@ fn openssl_required() {
use tls::openssl::OpenSsl; use tls::openssl::OpenSsl;
let mut builder = SslConnectorBuilder::new(SslMethod::tls()).unwrap(); let mut builder = SslConnectorBuilder::new(SslMethod::tls()).unwrap();
builder.builder_mut().set_ca_file("../.travis/server.crt").unwrap(); builder
.builder_mut()
.set_ca_file("../.travis/server.crt")
.unwrap();
let negotiator = OpenSsl::from(builder.build()); let negotiator = OpenSsl::from(builder.build());
let mut l = Core::new().unwrap(); let mut l = Core::new().unwrap();
let done = Connection::connect("postgres://ssl_user@localhost/postgres", let done = Connection::connect(
TlsMode::Require(Box::new(negotiator)), "postgres://ssl_user@localhost/postgres",
&l.handle()) TlsMode::Require(Box::new(negotiator)),
.then(|c| c.unwrap().prepare("SELECT 1")) &l.handle(),
).then(|c| c.unwrap().prepare("SELECT 1"))
.and_then(|(s, c)| c.query(&s, &[]).collect()) .and_then(|(s, c)| c.query(&s, &[]).collect())
.map(|(r, _)| assert_eq!(r[0].get::<i32, _>(0), 1)); .map(|(r, _)| assert_eq!(r[0].get::<i32, _>(0), 1));
l.run(done).unwrap(); l.run(done).unwrap();
@ -246,10 +284,11 @@ fn domain() {
struct SessionId(Vec<u8>); struct SessionId(Vec<u8>);
impl ToSql for SessionId { impl ToSql for SessionId {
fn to_sql(&self, fn to_sql(
ty: &Type, &self,
out: &mut Vec<u8>) ty: &Type,
-> Result<IsNull, Box<StdError + Sync + Send>> { out: &mut Vec<u8>,
) -> Result<IsNull, Box<StdError + Sync + Send>> {
let inner = match *ty.kind() { let inner = match *ty.kind() {
Kind::Domain(ref inner) => inner, Kind::Domain(ref inner) => inner,
_ => unreachable!(), _ => unreachable!(),
@ -282,10 +321,12 @@ fn domain() {
let handle = l.handle(); let handle = l.handle();
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &handle) let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &handle)
.then(|c| { .then(|c| {
c.unwrap().batch_execute("CREATE DOMAIN pg_temp.session_id AS bytea \ c.unwrap().batch_execute(
"CREATE DOMAIN pg_temp.session_id AS bytea \
CHECK(octet_length(VALUE) = 16); CHECK(octet_length(VALUE) = 16);
CREATE \ CREATE \
TABLE pg_temp.foo (id pg_temp.session_id);") TABLE pg_temp.foo (id pg_temp.session_id);",
)
}) })
.and_then(|c| c.prepare("INSERT INTO pg_temp.foo (id) VALUES ($1)")) .and_then(|c| c.prepare("INSERT INTO pg_temp.foo (id) VALUES ($1)"))
.and_then(|(s, c)| { .and_then(|(s, c)| {
@ -309,11 +350,13 @@ fn composite() {
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &handle) let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &handle)
.then(|c| { .then(|c| {
c.unwrap().batch_execute("CREATE TYPE pg_temp.inventory_item AS ( c.unwrap().batch_execute(
"CREATE TYPE pg_temp.inventory_item AS (
name TEXT, name TEXT,
supplier INTEGER, supplier INTEGER,
price NUMERIC price NUMERIC
)") )",
)
}) })
.and_then(|c| c.prepare("SELECT $1::inventory_item")) .and_then(|c| c.prepare("SELECT $1::inventory_item"))
.map(|(s, _)| { .map(|(s, _)| {
@ -341,7 +384,9 @@ fn enum_() {
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &handle) let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &handle)
.then(|c| { .then(|c| {
c.unwrap().batch_execute("CREATE TYPE pg_temp.mood AS ENUM ('sad', 'ok', 'happy');") c.unwrap().batch_execute(
"CREATE TYPE pg_temp.mood AS ENUM ('sad', 'ok', 'happy');",
)
}) })
.and_then(|c| c.prepare("SELECT $1::mood")) .and_then(|c| c.prepare("SELECT $1::mood"))
.map(|(s, _)| { .map(|(s, _)| {
@ -349,8 +394,10 @@ fn enum_() {
assert_eq!(type_.name(), "mood"); assert_eq!(type_.name(), "mood");
match *type_.kind() { match *type_.kind() {
Kind::Enum(ref variants) => { Kind::Enum(ref variants) => {
assert_eq!(variants, assert_eq!(
&["sad".to_owned(), "ok".to_owned(), "happy".to_owned()]); variants,
&["sad".to_owned(), "ok".to_owned(), "happy".to_owned()]
);
} }
_ => panic!("bad type"), _ => panic!("bad type"),
} }
@ -373,10 +420,12 @@ fn cancel() {
.into_future() .into_future()
.then(move |r| { .then(move |r| {
assert!(r.is_ok()); assert!(r.is_ok());
cancel_query("postgres://postgres@localhost", cancel_query(
TlsMode::None, "postgres://postgres@localhost",
cancel_data, TlsMode::None,
&handle) cancel_data,
&handle,
)
}) })
.then(Ok::<_, ()>); .then(Ok::<_, ()>);
c.batch_execute("SELECT pg_sleep(10)") c.batch_execute("SELECT pg_sleep(10)")
@ -401,10 +450,13 @@ fn notifications() {
let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &handle) let done = Connection::connect("postgres://postgres@localhost", TlsMode::None, &handle)
.then(|c| c.unwrap().batch_execute("LISTEN test_notifications")) .then(|c| c.unwrap().batch_execute("LISTEN test_notifications"))
.and_then(|c1| { .and_then(|c1| {
Connection::connect("postgres://postgres@localhost", TlsMode::None, &handle) Connection::connect("postgres://postgres@localhost", TlsMode::None, &handle).then(
.then(|c2| { |c2| {
c2.unwrap().batch_execute("NOTIFY test_notifications, 'foo'").map(|_| c1) c2.unwrap()
}) .batch_execute("NOTIFY test_notifications, 'foo'")
.map(|_| c1)
},
)
}) })
.and_then(|c| c.notifications().into_future().map_err(|(e, _)| e)) .and_then(|c| c.notifications().into_future().map_err(|(e, _)| e))
.map(|(n, _)| { .map(|(n, _)| {

View File

@ -31,8 +31,9 @@ impl TlsStream for Stream {
/// A trait implemented by types that can manage TLS encryption for a stream. /// A trait implemented by types that can manage TLS encryption for a stream.
pub trait Handshake: 'static + Sync + Send { pub trait Handshake: 'static + Sync + Send {
/// Performs a TLS handshake, returning a wrapped stream. /// Performs a TLS handshake, returning a wrapped stream.
fn handshake(self: Box<Self>, fn handshake(
host: &str, self: Box<Self>,
stream: Stream) host: &str,
-> BoxFuture<Box<TlsStream>, Box<Error + Sync + Send>>; stream: Stream,
) -> BoxFuture<Box<TlsStream>, Box<Error + Sync + Send>>;
} }

View File

@ -38,10 +38,11 @@ impl From<SslConnector> for OpenSsl {
} }
impl Handshake for OpenSsl { impl Handshake for OpenSsl {
fn handshake(self: Box<Self>, fn handshake(
host: &str, self: Box<Self>,
stream: Stream) host: &str,
-> BoxFuture<Box<TlsStream>, Box<Error + Sync + Send>> { stream: Stream,
) -> BoxFuture<Box<TlsStream>, Box<Error + Sync + Send>> {
self.0 self.0
.connect_async(host, stream) .connect_async(host, stream)
.map(|s| { .map(|s| {

View File

@ -39,10 +39,11 @@ impl Transaction {
} }
/// Like `Connection::execute`. /// Like `Connection::execute`.
pub fn execute(self, pub fn execute(
statement: &Statement, self,
params: &[&ToSql]) statement: &Statement,
-> BoxFuture<(u64, Transaction), Error<Transaction>> { params: &[&ToSql],
) -> BoxFuture<(u64, Transaction), Error<Transaction>> {
self.0 self.0
.execute(statement, params) .execute(statement, params)
.map(|(n, c)| (n, Transaction(c))) .map(|(n, c)| (n, Transaction(c)))
@ -51,10 +52,11 @@ impl Transaction {
} }
/// Like `Connection::query`. /// Like `Connection::query`.
pub fn query(self, pub fn query(
statement: &Statement, self,
params: &[&ToSql]) statement: &Statement,
-> BoxStateStream<Row, Transaction, Error<Transaction>> { params: &[&ToSql],
) -> BoxStateStream<Row, Transaction, Error<Transaction>> {
self.0 self.0
.query(statement, params) .query(statement, params)
.map_state(Transaction) .map_state(Transaction)
@ -73,10 +75,7 @@ impl Transaction {
} }
fn finish(self, query: &str) -> BoxFuture<Connection, Error> { fn finish(self, query: &str) -> BoxFuture<Connection, Error> {
self.0 self.0.simple_query(query).map(|(_, c)| c).boxed()
.simple_query(query)
.map(|(_, c)| c)
.boxed()
} }
} }

View File

@ -1,8 +1,8 @@
//! Postgres types //! Postgres types
#[doc(inline)] #[doc(inline)]
pub use postgres_shared::types::{Oid, Type, Date, Timestamp, Kind, Field, Other, WasNull, WrongType, pub use postgres_shared::types::{Oid, Type, Date, Timestamp, Kind, Field, Other, WasNull,
FromSql, IsNull, ToSql}; WrongType, FromSql, IsNull, ToSql};
#[doc(hidden)] #[doc(hidden)]
pub use postgres_shared::types::__to_sql_checked; pub use postgres_shared::types::__to_sql_checked;