Fix parameter parsing and add test

Our behavior matches libpq's - in particular it allows any escape
sequence and trailing \'s...
This commit is contained in:
Steven Fackler 2018-12-16 18:11:52 -08:00
parent 7297661cef
commit 707b87a18e
3 changed files with 109 additions and 39 deletions

View File

@ -1,5 +1,4 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::collections::hash_map::{self, HashMap};
use std::iter;
use std::str::{self, FromStr};
use tokio_io::{AsyncRead, AsyncWrite};
@ -44,6 +43,11 @@ impl Builder {
self
}
/// FIXME do we want this?
pub fn iter(&self) -> Iter<'_> {
Iter(self.params.iter())
}
pub fn connect<S, T>(&self, stream: S, tls_mode: T) -> Connect<S, T>
where
S: AsyncRead + AsyncWrite,
@ -61,13 +65,30 @@ impl FromStr for Builder {
let mut builder = Builder::new();
while let Some((key, value)) = parser.parameter()? {
builder.param(key, &value);
builder.params.insert(key.to_string(), value);
}
Ok(builder)
}
}
#[derive(Debug, Clone)]
pub struct Iter<'a>(hash_map::Iter<'a, String, String>);
impl<'a> Iterator for Iter<'a> {
type Item = (&'a str, &'a str);
fn next(&mut self) -> Option<(&'a str, &'a str)> {
self.0.next().map(|(k, v)| (&**k, &**v))
}
}
impl<'a> ExactSizeIterator for Iter<'a> {
fn len(&self) -> usize {
self.0.len()
}
}
struct Parser<'a> {
s: &'a str,
it: iter::Peekable<str::CharIndices<'a>>,
@ -82,9 +103,7 @@ impl<'a> Parser<'a> {
}
fn skip_ws(&mut self) {
while let Some(&(_, ' ')) = self.it.peek() {
self.it.next();
}
self.take_while(|c| c.is_whitespace());
}
fn take_while<F>(&mut self, f: F) -> &'a str
@ -133,7 +152,8 @@ impl<'a> Parser<'a> {
fn keyword(&mut self) -> Option<&'a str> {
let s = self.take_while(|c| match c {
' ' | '=' => false,
c if c.is_whitespace() => false,
'=' => false,
_ => true,
});
@ -144,52 +164,67 @@ impl<'a> Parser<'a> {
}
}
fn value(&mut self) -> Result<Cow<'a, str>, Error> {
let raw = if self.eat_if('\'') {
let s = self.take_while(|c| c != '\'');
fn value(&mut self) -> Result<String, Error> {
let value = if self.eat_if('\'') {
let value = self.quoted_value()?;
self.eat('\'')?;
s
value
} else {
let s = self.take_while(|c| c != ' ');
if s.is_empty() {
return Err(Error::connection_syntax("unexpected EOF".into()));
}
s
self.simple_value()?
};
self.unescape_value(raw)
Ok(value)
}
fn unescape_value(&mut self, raw: &'a str) -> Result<Cow<'a, str>, Error> {
if !raw.contains('\\') {
return Ok(Cow::Borrowed(raw));
}
fn simple_value(&mut self) -> Result<String, Error> {
let mut value = String::new();
let mut s = String::with_capacity(raw.len());
while let Some(&(_, c)) = self.it.peek() {
if c.is_whitespace() {
break;
}
let mut it = raw.chars();
while let Some(c) = it.next() {
let to_push = if c == '\\' {
match it.next() {
Some('\'') => '\'',
Some('\\') => '\\',
Some(c) => {
return Err(Error::connection_syntax(
format!("invalid escape `\\{}`", c).into(),
));
}
None => return Err(Error::connection_syntax("unexpected EOF".into())),
self.it.next();
if c == '\\' {
if let Some((_, c2)) = self.it.next() {
value.push(c2);
}
} else {
c
};
s.push(to_push);
value.push(c);
}
}
Ok(Cow::Owned(s))
if value.is_empty() {
return Err(Error::connection_syntax("unexpected EOF".into()));
}
Ok(value)
}
fn parameter(&mut self) -> Result<Option<(&'a str, Cow<'a, str>)>, Error> {
fn quoted_value(&mut self) -> Result<String, Error> {
let mut value = String::new();
while let Some(&(_, c)) = self.it.peek() {
if c == '\'' {
return Ok(value);
}
self.it.next();
if c == '\\' {
if let Some((_, c2)) = self.it.next() {
value.push(c2);
}
} else {
value.push(c);
}
}
Err(Error::connection_syntax(
"unterminated quoted connection parameter value".into(),
))
}
fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
self.skip_ws();
let keyword = match self.keyword() {
Some(keyword) => keyword,

View File

@ -13,6 +13,7 @@ use tokio_postgres::error::SqlState;
use tokio_postgres::types::{Kind, Type};
use tokio_postgres::{AsyncMessage, Client, Connection, NoTls};
mod parse;
mod types;
fn connect(

View File

@ -0,0 +1,34 @@
use std::collections::HashMap;
#[test]
fn pairs_ok() {
let params = r"user=foo password=' fizz \'buzz\\ ' thing = ''"
.parse::<tokio_postgres::Builder>()
.unwrap();
let params = params.iter().collect::<HashMap<_, _>>();
let mut expected = HashMap::new();
expected.insert("user", "foo");
expected.insert("password", r" fizz 'buzz\ ");
expected.insert("thing", "");
expected.insert("client_encoding", "UTF8");
expected.insert("timezone", "GMT");
assert_eq!(params, expected);
}
#[test]
fn pairs_ws() {
let params = " user\t=\r\n\x0bfoo \t password = hunter2 "
.parse::<tokio_postgres::Builder>()
.unwrap();;
let params = params.iter().collect::<HashMap<_, _>>();
let mut expected = HashMap::new();
expected.insert("user", "foo");
expected.insert("password", r"hunter2");
expected.insert("client_encoding", "UTF8");
expected.insert("timezone", "GMT");
assert_eq!(params, expected);
}