From 92ef98ddcd8910d540ec35f2f2ec99f59959888a Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Tue, 20 Dec 2016 15:33:16 -0800 Subject: [PATCH] Add a postgres-shared crate --- Cargo.toml | 2 +- postgres-shared/Cargo.toml | 7 + postgres-shared/src/lib.rs | 3 + postgres-shared/src/params/mod.rs | 203 ++++++++++++++ postgres-shared/src/params/url.rs | 428 ++++++++++++++++++++++++++++++ 5 files changed, 642 insertions(+), 1 deletion(-) create mode 100644 postgres-shared/Cargo.toml create mode 100644 postgres-shared/src/lib.rs create mode 100644 postgres-shared/src/params/mod.rs create mode 100644 postgres-shared/src/params/url.rs diff --git a/Cargo.toml b/Cargo.toml index 383d5017..ad710159 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,2 +1,2 @@ [workspace] -members = ["postgres", "codegen"] +members = ["codegen", "postgres", "postgres-shared"] diff --git a/postgres-shared/Cargo.toml b/postgres-shared/Cargo.toml new file mode 100644 index 00000000..80e16b7c --- /dev/null +++ b/postgres-shared/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "postgres-shared" +version = "0.1.0" +authors = ["Steven Fackler "] + +[dependencies] +hex = "0.2" diff --git a/postgres-shared/src/lib.rs b/postgres-shared/src/lib.rs new file mode 100644 index 00000000..4027c423 --- /dev/null +++ b/postgres-shared/src/lib.rs @@ -0,0 +1,3 @@ +extern crate hex; + +pub mod params; diff --git a/postgres-shared/src/params/mod.rs b/postgres-shared/src/params/mod.rs new file mode 100644 index 00000000..990b0521 --- /dev/null +++ b/postgres-shared/src/params/mod.rs @@ -0,0 +1,203 @@ +//! Postgres connection parameters. + +use std::error::Error; +use std::path::{Path, PathBuf}; +use std::mem; + +use url::Url; + +mod url; + +/// The host. +#[derive(Clone, Debug)] +pub enum Host { + /// A TCP hostname. + Tcp(String), + + /// The path to a directory containing the server's Unix socket. + Unix(PathBuf), +} + +/// Authentication information. +#[derive(Clone, Debug)] +pub struct User { + name: String, + password: Option, +} + +impl User { + /// The username. + pub fn name(&self) -> &str { + &self.name + } + + /// An optional password. + pub fn password(&self) -> Option<&str> { + self.password.as_ref().map(|p| &**p) + } +} + +/// Information necessary to open a new connection to a Postgres server. +#[derive(Clone, Debug)] +pub struct ConnectParams { + host: Host, + port: u16, + user: Option, + database: Option, + options: Vec<(String, String)>, +} + +impl ConnectParams { + /// Returns a new builder. + pub fn builder() -> Builder { + Builder { + port: 5432, + user: None, + database: None, + options: vec![], + } + } + + /// The target server. + pub fn host(&self) -> &Host { + &self.host + } + + /// The target port. + /// + /// Defaults to 5432. + pub fn port(&self) -> u16 { + self.port + } + + /// The user to login as. + /// + /// Connection requires a user but query cancellation does not. + pub fn user(&self) -> Option<&User> { + self.user.as_ref() + } + + /// The database to connect to. + /// + /// Defaults to the username. + pub fn database(&self) -> Option<&str> { + self.database.as_ref().map(|d| &**d) + } + + /// Runtime parameters to be passed to the Postgres backend. + pub fn options(&self) -> &[(String, String)] { + &self.options + } +} + +/// A builder type for `ConnectParams`. +pub struct Builder { + port: u16, + user: Option, + database: Option, + options: Vec<(String, String)>, +} + +impl Builder { + pub fn port(&mut self, port: u16) -> &mut Builder { + self.port = port; + self + } + + pub fn user(&mut self, name: &str, password: Option<&str>) -> &mut Builder { + self.user = Some(User { + name: name.to_owned(), + password: password.map(ToOwned::to_owned), + }); + self + } + + pub fn database(&mut self, database: &str) -> &mut Builder { + self.database = Some(database.to_owned()); + self + } + + pub fn option(&mut self, name: &str, value: &str) -> &mut Builder { + self.options.push((name.to_owned(), value.to_owned())); + self + } + + pub fn build_tcp(&mut self, host: &str) -> ConnectParams { + self.build(Host::Tcp(host.to_owned())) + } + + pub fn build_unix

(&mut self, host: P) -> ConnectParams + where P: AsRef + { + self.build(Host::Unix(host.as_ref().to_owned())) + } + + pub fn build(&mut self, host: Host) -> ConnectParams { + ConnectParams { + host: host, + port: self.port, + database: self.database.take(), + user: self.user.take(), + options: mem::replace(&mut self.options, vec![]), + } + } +} + +/// A trait implemented by types that can be converted into a `ConnectParams`. +pub trait IntoConnectParams { + /// Converts the value of `self` into a `ConnectParams`. + fn into_connect_params(self) -> Result>; +} + +impl IntoConnectParams for ConnectParams { + fn into_connect_params(self) -> Result> { + Ok(self) + } +} + +impl<'a> IntoConnectParams for &'a str { + fn into_connect_params(self) -> Result> { + match Url::parse(self) { + Ok(url) => url.into_connect_params(), + Err(err) => Err(err.into()), + } + } +} + +impl IntoConnectParams for String { + fn into_connect_params(self) -> Result> { + self.as_str().into_connect_params() + } +} + +impl IntoConnectParams for Url { + fn into_connect_params(self) -> Result> { + let Url { host, port, user, path: url::Path { path, query: options, .. }, .. } = self; + + let mut builder = ConnectParams::builder(); + + if let Some(port) = port { + builder.port(port); + } + + if let Some(info) = user { + builder.user(&info.user, info.pass.as_ref().map(|p| &**p)); + } + + if !path.is_empty() { + // path contains the leading / + builder.database(&path[1..]); + } + + for (name, value) in options { + builder.option(&name, &value); + } + + let maybe_path = try!(url::decode_component(&host)); + if maybe_path.starts_with('/') { + Ok(builder.build_unix(maybe_path)) + } else { + Ok(builder.build_tcp(&maybe_path)) + } + } +} diff --git a/postgres-shared/src/params/url.rs b/postgres-shared/src/params/url.rs new file mode 100644 index 00000000..34b77c20 --- /dev/null +++ b/postgres-shared/src/params/url.rs @@ -0,0 +1,428 @@ +// Copyright 2012-2014 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. +use std::str::FromStr; +use hex::FromHex; + +pub struct Url { + pub scheme: String, + pub user: Option, + pub host: String, + pub port: Option, + pub path: Path, +} + +pub struct Path { + pub path: String, + pub query: Query, + pub fragment: Option, +} + +pub struct UserInfo { + pub user: String, + pub pass: Option, +} + +pub type Query = Vec<(String, String)>; + +impl Url { + pub fn new(scheme: String, + user: Option, + host: String, + port: Option, + path: String, + query: Query, + fragment: Option) + -> Url { + Url { + scheme: scheme, + user: user, + host: host, + port: port, + path: Path::new(path, query, fragment), + } + } + + pub fn parse(rawurl: &str) -> DecodeResult { + // scheme + let (scheme, rest) = try!(get_scheme(rawurl)); + + // authority + let (userinfo, host, port, rest) = try!(get_authority(rest)); + + // path + let has_authority = !host.is_empty(); + let (path, rest) = try!(get_path(rest, has_authority)); + + // query and fragment + let (query, fragment) = try!(get_query_fragment(rest)); + + let url = Url::new(scheme.to_owned(), + userinfo, + host.to_owned(), + port, + path, + query, + fragment); + Ok(url) + } +} + +impl Path { + pub fn new(path: String, query: Query, fragment: Option) -> Path { + Path { + path: path, + query: query, + fragment: fragment, + } + } + + pub fn parse(rawpath: &str) -> DecodeResult { + let (path, rest) = try!(get_path(rawpath, false)); + + // query and fragment + let (query, fragment) = try!(get_query_fragment(&rest)); + + Ok(Path { + path: path, + query: query, + fragment: fragment, + }) + } +} + +impl UserInfo { + #[inline] + pub fn new(user: String, pass: Option) -> UserInfo { + UserInfo { + user: user, + pass: pass, + } + } +} + +pub type DecodeResult = Result; + +pub fn decode_component(container: &str) -> DecodeResult { + decode_inner(container, false) +} + +fn decode_inner(c: &str, full_url: bool) -> DecodeResult { + let mut out = String::new(); + let mut iter = c.as_bytes().iter().cloned(); + + loop { + match iter.next() { + Some(b) => { + match b as char { + '%' => { + let bytes = match (iter.next(), iter.next()) { + (Some(one), Some(two)) => [one, two], + _ => { + return Err("Malformed input: found '%' without two \ + trailing bytes" + .to_owned()) + } + }; + + let bytes_from_hex = match Vec::::from_hex(&bytes) { + Ok(b) => b, + _ => { + return Err("Malformed input: found '%' followed by \ + invalid hex values. Character '%' must \ + escaped." + .to_owned()) + } + }; + + // Only decode some characters if full_url: + match bytes_from_hex[0] as char { + // gen-delims: + ':' | '/' | '?' | '#' | '[' | ']' | '@' | '!' | '$' | '&' | '"' | + '(' | ')' | '*' | '+' | ',' | ';' | '=' if full_url => { + out.push('%'); + out.push(bytes[0] as char); + out.push(bytes[1] as char); + } + + ch => out.push(ch), + } + } + ch => out.push(ch), + } + } + None => return Ok(out), + } + } +} + +fn split_char_first(s: &str, c: char) -> (&str, &str) { + let mut iter = s.splitn(2, c); + + match (iter.next(), iter.next()) { + (Some(a), Some(b)) => (a, b), + (Some(a), None) => (a, ""), + (None, _) => unreachable!(), + } +} + +fn query_from_str(rawquery: &str) -> DecodeResult { + let mut query: Query = vec![]; + if !rawquery.is_empty() { + for p in rawquery.split('&') { + let (k, v) = split_char_first(p, '='); + query.push((try!(decode_component(k)), try!(decode_component(v)))); + } + } + + Ok(query) +} + +pub fn get_scheme(rawurl: &str) -> DecodeResult<(&str, &str)> { + for (i, c) in rawurl.chars().enumerate() { + let result = match c { + 'A'...'Z' | 'a'...'z' => continue, + '0'...'9' | '+' | '-' | '.' => { + if i != 0 { + continue; + } + + Err("url: Scheme must begin with a letter.".to_owned()) + } + ':' => { + if i == 0 { + Err("url: Scheme cannot be empty.".to_owned()) + } else { + Ok((&rawurl[0..i], &rawurl[i + 1..rawurl.len()])) + } + } + _ => Err("url: Invalid character in scheme.".to_owned()), + }; + + return result; + } + + Err("url: Scheme must be terminated with a colon.".to_owned()) +} + +// returns userinfo, host, port, and unparsed part, or an error +fn get_authority(rawurl: &str) -> DecodeResult<(Option, &str, Option, &str)> { + enum State { + Start, // starting state + PassHostPort, // could be in user or port + Ip6Port, // either in ipv6 host or port + Ip6Host, // are in an ipv6 host + InHost, // are in a host - may be ipv6, but don't know yet + InPort, // are in port + } + + #[derive(Clone, PartialEq)] + enum Input { + Digit, // all digits + Hex, // digits and letters a-f + Unreserved, // all other legal characters + } + + if !rawurl.starts_with("//") { + // there is no authority. + return Ok((None, "", None, rawurl)); + } + + let len = rawurl.len(); + let mut st = State::Start; + let mut input = Input::Digit; // most restricted, start here. + + let mut userinfo = None; + let mut host = ""; + let mut port = None; + + let mut colon_count = 0usize; + let mut pos = 0; + let mut begin = 2; + let mut end = len; + + for (i, c) in rawurl.chars() + .enumerate() + .skip(2) { + // deal with input class first + match c { + '0'...'9' => (), + 'A'...'F' | 'a'...'f' => { + if input == Input::Digit { + input = Input::Hex; + } + } + 'G'...'Z' | 'g'...'z' | '-' | '.' | '_' | '~' | '%' | '&' | '\'' | '(' | ')' | + '+' | '!' | '*' | ',' | ';' | '=' => input = Input::Unreserved, + ':' | '@' | '?' | '#' | '/' => { + // separators, don't change anything + } + _ => return Err("Illegal character in authority".to_owned()), + } + + // now process states + match c { + ':' => { + colon_count += 1; + match st { + State::Start => { + pos = i; + st = State::PassHostPort; + } + State::PassHostPort => { + // multiple colons means ipv6 address. + if input == Input::Unreserved { + return Err("Illegal characters in IPv6 address.".to_owned()); + } + st = State::Ip6Host; + } + State::InHost => { + pos = i; + if input == Input::Unreserved { + // must be port + host = &rawurl[begin..i]; + st = State::InPort; + } else { + // can't be sure whether this is an ipv6 address or a port + st = State::Ip6Port; + } + } + State::Ip6Port => { + if input == Input::Unreserved { + return Err("Illegal characters in authority.".to_owned()); + } + st = State::Ip6Host; + } + State::Ip6Host => { + if colon_count > 7 { + host = &rawurl[begin..i]; + pos = i; + st = State::InPort; + } + } + _ => return Err("Invalid ':' in authority.".to_owned()), + } + input = Input::Digit; // reset input class + } + + '@' => { + input = Input::Digit; // reset input class + colon_count = 0; // reset count + match st { + State::Start => { + let user = try!(decode_component(&rawurl[begin..i])); + userinfo = Some(UserInfo::new(user, None)); + st = State::InHost; + } + State::PassHostPort => { + let user = try!(decode_component(&rawurl[begin..pos])); + let pass = try!(decode_component(&rawurl[pos + 1..i])); + userinfo = Some(UserInfo::new(user, Some(pass))); + st = State::InHost; + } + _ => return Err("Invalid '@' in authority.".to_owned()), + } + begin = i + 1; + } + + '?' | '#' | '/' => { + end = i; + break; + } + _ => (), + } + } + + // finish up + match st { + State::PassHostPort | State::Ip6Port => { + if input != Input::Digit { + return Err("Non-digit characters in port.".to_owned()); + } + host = &rawurl[begin..pos]; + port = Some(&rawurl[pos + 1..end]); + } + State::Ip6Host | State::InHost | State::Start => host = &rawurl[begin..end], + State::InPort => { + if input != Input::Digit { + return Err("Non-digit characters in port.".to_owned()); + } + port = Some(&rawurl[pos + 1..end]); + } + } + + let rest = &rawurl[end..len]; + // If we have a port string, ensure it parses to u16. + let port = match port { + None => None, + opt => { + match opt.and_then(|p| FromStr::from_str(p).ok()) { + None => return Err(format!("Failed to parse port: {:?}", port)), + opt => opt, + } + } + }; + + Ok((userinfo, host, port, rest)) +} + + +// returns the path and unparsed part of url, or an error +fn get_path(rawurl: &str, is_authority: bool) -> DecodeResult<(String, &str)> { + let len = rawurl.len(); + let mut end = len; + for (i, c) in rawurl.chars().enumerate() { + match c { + 'A'...'Z' | 'a'...'z' | '0'...'9' | '&' | '\'' | '(' | ')' | '.' | '@' | ':' | + '%' | '/' | '+' | '!' | '*' | ',' | ';' | '=' | '_' | '-' | '~' => continue, + '?' | '#' => { + end = i; + break; + } + _ => return Err("Invalid character in path.".to_owned()), + } + } + + if is_authority && end != 0 && !rawurl.starts_with('/') { + Err("Non-empty path must begin with '/' in presence of authority.".to_owned()) + } else { + Ok((try!(decode_component(&rawurl[0..end])), &rawurl[end..len])) + } +} + +// returns the parsed query and the fragment, if present +fn get_query_fragment(rawurl: &str) -> DecodeResult<(Query, Option)> { + let (before_fragment, raw_fragment) = split_char_first(rawurl, '#'); + + // Parse the fragment if available + let fragment = match raw_fragment { + "" => None, + raw => Some(try!(decode_component(raw))), + }; + + match before_fragment.chars().next() { + Some('?') => Ok((try!(query_from_str(&before_fragment[1..])), fragment)), + None => Ok((vec![], fragment)), + _ => Err(format!("Query didn't start with '?': '{}..'", before_fragment)), + } +} + +impl FromStr for Url { + type Err = String; + fn from_str(s: &str) -> Result { + Url::parse(s) + } +} + +impl FromStr for Path { + type Err = String; + fn from_str(s: &str) -> Result { + Path::parse(s) + } +}