Merge pull request #363 from sfackler/tokio-rewrite

Tokio rewrite
This commit is contained in:
Steven Fackler 2018-06-26 00:22:40 -04:00 committed by GitHub
commit 7b5fa05a30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 2027 additions and 2436 deletions

View File

@ -21,9 +21,8 @@ save_deps: &SAVE_DEPS
version: 2 version: 2
jobs: jobs:
build: build:
working_directory: ~/build
docker: docker:
- image: rust:1.23.0 - image: rust:1.26.2
environment: environment:
RUSTFLAGS: -D warnings RUSTFLAGS: -D warnings
- image: sfackler/rust-postgres-test:4 - image: sfackler/rust-postgres-test:4
@ -31,8 +30,6 @@ jobs:
- checkout - checkout
- *RESTORE_REGISTRY - *RESTORE_REGISTRY
- run: cargo generate-lockfile - run: cargo generate-lockfile
- run: cargo update -p nalgebra --precise 0.14.3 # 0.14.4 requires 1.26 :(
- run: cargo update -p geo-types --precise 0.1.0
- *SAVE_REGISTRY - *SAVE_REGISTRY
- run: rustc --version > ~/rust-version - run: rustc --version > ~/rust-version
- *RESTORE_DEPS - *RESTORE_DEPS

View File

@ -8,3 +8,11 @@ members = [
"postgres-native-tls", "postgres-native-tls",
"tokio-postgres", "tokio-postgres",
] ]
[patch.crates-io]
tokio-uds = { git = "https://github.com/sfackler/tokio" }
tokio-io = { git = "https://github.com/sfackler/tokio" }
tokio-timer = { git = "https://github.com/sfackler/tokio" }
tokio-codec = { git = "https://github.com/sfackler/tokio" }
tokio-reactor = { git = "https://github.com/sfackler/tokio" }
tokio-executor = { git = "https://github.com/sfackler/tokio" }

View File

@ -1,9 +1,9 @@
//! Errors. //! Errors.
use fallible_iterator::FallibleIterator; use fallible_iterator::FallibleIterator;
use postgres_protocol::message::backend::ErrorFields; use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody};
use std::error;
use std::convert::From; use std::convert::From;
use std::error;
use std::fmt; use std::fmt;
use std::io; use std::io;
@ -214,36 +214,29 @@ impl DbError {
} }
Ok(DbError { Ok(DbError {
severity: severity.ok_or_else(|| { severity: severity
io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing") .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing"))?,
})?,
parsed_severity: parsed_severity, parsed_severity: parsed_severity,
code: code.ok_or_else(|| { code: code
io::Error::new(io::ErrorKind::InvalidInput, "`C` field missing") .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`C` field missing"))?,
})?, message: message
message: message.ok_or_else(|| { .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`M` field missing"))?,
io::Error::new(io::ErrorKind::InvalidInput, "`M` field missing")
})?,
detail: detail, detail: detail,
hint: hint, hint: hint,
position: match normal_position { position: match normal_position {
Some(position) => Some(ErrorPosition::Normal(position)), Some(position) => Some(ErrorPosition::Normal(position)),
None => { None => match internal_position {
match internal_position { Some(position) => Some(ErrorPosition::Internal {
Some(position) => { position: position,
Some(ErrorPosition::Internal { query: internal_query.ok_or_else(|| {
position: position, io::Error::new(
query: internal_query.ok_or_else(|| { io::ErrorKind::InvalidInput,
io::Error::new( "`q` field missing but `p` field present",
io::ErrorKind::InvalidInput, )
"`q` field missing but `p` field present", })?,
) }),
})?, None => None,
}) },
}
None => None,
}
}
}, },
where_: where_, where_: where_,
schema: schema, schema: schema,
@ -324,6 +317,14 @@ pub fn db(e: DbError) -> Error {
Error(Box::new(ErrorKind::Db(e))) Error(Box::new(ErrorKind::Db(e)))
} }
#[doc(hidden)]
pub fn __db(e: ErrorResponseBody) -> Error {
match DbError::new(&mut e.fields()) {
Ok(e) => Error(Box::new(ErrorKind::Db(e))),
Err(e) => Error(Box::new(ErrorKind::Io(e))),
}
}
#[doc(hidden)] #[doc(hidden)]
pub fn io(e: io::Error) -> Error { pub fn io(e: io::Error) -> Error {
Error(Box::new(ErrorKind::Io(e))) Error(Box::new(ErrorKind::Io(e)))
@ -401,7 +402,7 @@ impl Error {
pub fn as_db(&self) -> Option<&DbError> { pub fn as_db(&self) -> Option<&DbError> {
match *self.0 { match *self.0 {
ErrorKind::Db(ref err) => Some(err), ErrorKind::Db(ref err) => Some(err),
_ => None _ => None,
} }
} }

View File

@ -2,8 +2,10 @@
use std::error::Error; use std::error::Error;
use std::mem; use std::mem;
use std::path::PathBuf; use std::path::PathBuf;
use std::str::FromStr;
use std::time::Duration; use std::time::Duration;
use error;
use params::url::Url; use params::url::Url;
mod url; mod url;
@ -96,6 +98,14 @@ impl ConnectParams {
} }
} }
impl FromStr for ConnectParams {
type Err = error::Error;
fn from_str(s: &str) -> Result<ConnectParams, error::Error> {
s.into_connect_params().map_err(error::connect)
}
}
/// A builder for `ConnectParams`. /// A builder for `ConnectParams`.
pub struct Builder { pub struct Builder {
port: u16, port: u16,

View File

@ -31,21 +31,24 @@ circle-ci = { repository = "sfackler/rust-postgres" }
"with-serde_json-1" = ["postgres-shared/with-serde_json-1"] "with-serde_json-1" = ["postgres-shared/with-serde_json-1"]
"with-uuid-0.6" = ["postgres-shared/with-uuid-0.6"] "with-uuid-0.6" = ["postgres-shared/with-uuid-0.6"]
with-openssl = ["tokio-openssl", "openssl"]
[dependencies] [dependencies]
bytes = "0.4" bytes = "0.4"
fallible-iterator = "0.1.3" fallible-iterator = "0.1.3"
futures = "0.1.7" futures = "0.1.7"
futures-state-stream = "0.2" futures-cpupool = "0.1"
lazy_static = "1.0"
log = "0.4"
postgres-protocol = { version = "0.3.0", path = "../postgres-protocol" } postgres-protocol = { version = "0.3.0", path = "../postgres-protocol" }
postgres-shared = { version = "0.4.0", path = "../postgres-shared" } postgres-shared = { version = "0.4.0", path = "../postgres-shared" }
tokio-core = "0.1.8" state_machine_future = "0.1.7"
tokio-dns-unofficial = "0.1" tokio-codec = "0.1"
tokio-io = "0.1" tokio-io = "0.1"
tokio-tcp = "0.1"
tokio-openssl = { version = "0.2", optional = true } tokio-timer = "0.2"
openssl = { version = "0.10", optional = true }
[target.'cfg(unix)'.dependencies] [target.'cfg(unix)'.dependencies]
tokio-uds = "0.1" tokio-uds = "0.2"
[dev-dependencies]
tokio = "0.1.7"
env_logger = "0.5"

File diff suppressed because it is too large Load Diff

View File

@ -1,31 +0,0 @@
/// Generates a simple implementation of `ToSql::accepts` which accepts the
/// types passed to it.
#[macro_export]
macro_rules! accepts {
($($expected:pat),+) => (
fn accepts(ty: &$crate::types::Type) -> bool {
match *ty {
$($expected)|+ => true,
_ => false
}
}
)
}
/// Generates an implementation of `ToSql::to_sql_checked`.
///
/// All `ToSql` implementations should use this macro.
#[macro_export]
macro_rules! to_sql_checked {
() => {
fn to_sql_checked(&self,
ty: &$crate::types::Type,
out: &mut ::std::vec::Vec<u8>)
-> ::std::result::Result<$crate::types::IsNull,
Box<::std::error::Error +
::std::marker::Sync +
::std::marker::Send>> {
$crate::types::__to_sql_checked(self, ty, out)
}
}
}

View File

@ -0,0 +1,97 @@
use futures::sync::mpsc;
use postgres_protocol;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use disconnected;
use error::{self, Error};
use proto::connection::Request;
use proto::execute::ExecuteFuture;
use proto::prepare::PrepareFuture;
use proto::query::QueryStream;
use proto::statement::Statement;
use types::{IsNull, ToSql, Type};
pub struct PendingRequest {
sender: mpsc::UnboundedSender<Request>,
messages: Result<Vec<u8>, Error>,
}
impl PendingRequest {
pub fn send(self) -> Result<mpsc::Receiver<Message>, Error> {
let messages = self.messages?;
let (sender, receiver) = mpsc::channel(0);
self.sender
.unbounded_send(Request { messages, sender })
.map(|_| receiver)
.map_err(|_| disconnected())
}
}
pub struct Client {
sender: mpsc::UnboundedSender<Request>,
}
impl Client {
pub fn new(sender: mpsc::UnboundedSender<Request>) -> Client {
Client { sender }
}
pub fn prepare(&mut self, name: String, query: &str, param_types: &[Type]) -> PrepareFuture {
let pending = self.pending(|buf| {
frontend::parse(&name, query, param_types.iter().map(|t| t.oid()), buf)?;
frontend::describe(b'S', &name, buf)?;
frontend::sync(buf);
Ok(())
});
PrepareFuture::new(pending, self.sender.clone(), name)
}
pub fn execute(&mut self, statement: &Statement, params: &[&ToSql]) -> ExecuteFuture {
let pending = self.pending_execute(statement, params);
ExecuteFuture::new(pending, statement.clone())
}
pub fn query(&mut self, statement: &Statement, params: &[&ToSql]) -> QueryStream {
let pending = self.pending_execute(statement, params);
QueryStream::new(pending, statement.clone())
}
fn pending_execute(&self, statement: &Statement, params: &[&ToSql]) -> PendingRequest {
self.pending(|buf| {
let r = frontend::bind(
"",
statement.name(),
Some(1),
params.iter().zip(statement.params()),
|(param, ty), buf| match param.to_sql_checked(ty, buf) {
Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
Err(e) => Err(e),
},
Some(1),
buf,
);
match r {
Ok(()) => {}
Err(frontend::BindError::Conversion(e)) => return Err(error::conversion(e)),
Err(frontend::BindError::Serialization(e)) => return Err(Error::from(e)),
}
frontend::execute("", 0, buf)?;
frontend::sync(buf);
Ok(())
})
}
fn pending<F>(&self, messages: F) -> PendingRequest
where
F: FnOnce(&mut Vec<u8>) -> Result<(), Error>,
{
let mut buf = vec![];
PendingRequest {
sender: self.sender.clone(),
messages: messages(&mut buf).map(|()| buf),
}
}
}

View File

@ -0,0 +1,25 @@
use bytes::BytesMut;
use postgres_protocol::message::backend;
use std::io;
use tokio_codec::{Decoder, Encoder};
pub struct PostgresCodec;
impl Encoder for PostgresCodec {
type Item = Vec<u8>;
type Error = io::Error;
fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), io::Error> {
dst.extend_from_slice(&item);
Ok(())
}
}
impl Decoder for PostgresCodec {
type Item = backend::Message;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<backend::Message>, io::Error> {
backend::Message::parse(src)
}
}

View File

@ -0,0 +1,242 @@
use futures::sync::mpsc;
use futures::{Async, AsyncSink, Future, Poll, Sink, Stream};
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use std::collections::{HashMap, VecDeque};
use std::io;
use tokio_codec::Framed;
use disconnected;
use error::{self, Error};
use proto::codec::PostgresCodec;
use tls::TlsStream;
use {bad_response, CancelData};
pub struct Request {
pub messages: Vec<u8>,
pub sender: mpsc::Sender<Message>,
}
#[derive(PartialEq, Debug)]
enum State {
Active,
Terminating,
Closing,
}
pub struct Connection {
stream: Framed<Box<TlsStream>, PostgresCodec>,
cancel_data: CancelData,
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
pending_request: Option<Vec<u8>>,
pending_response: Option<Message>,
responses: VecDeque<mpsc::Sender<Message>>,
state: State,
}
impl Connection {
pub fn new(
stream: Framed<Box<TlsStream>, PostgresCodec>,
cancel_data: CancelData,
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
) -> Connection {
Connection {
stream,
cancel_data,
parameters,
receiver,
pending_request: None,
pending_response: None,
responses: VecDeque::new(),
state: State::Active,
}
}
pub fn cancel_data(&self) -> CancelData {
self.cancel_data
}
pub fn parameter(&self, name: &str) -> Option<&str> {
self.parameters.get(name).map(|s| &**s)
}
fn poll_response(&mut self) -> Poll<Option<Message>, io::Error> {
if let Some(message) = self.pending_response.take() {
trace!("retrying pending response");
return Ok(Async::Ready(Some(message)));
}
self.stream.poll()
}
fn poll_read(&mut self) -> Result<(), Error> {
if self.state != State::Active {
trace!("poll_read: done");
return Ok(());
}
loop {
let message = match self.poll_response()? {
Async::Ready(Some(message)) => message,
Async::Ready(None) => {
return Err(disconnected());
}
Async::NotReady => {
trace!("poll_read: waiting on response");
return Ok(());
}
};
let message = match message {
Message::NoticeResponse(_) | Message::NotificationResponse(_) => {
// FIXME handle these
continue;
}
Message::ParameterStatus(body) => {
self.parameters
.insert(body.name()?.to_string(), body.value()?.to_string());
continue;
}
m => m,
};
let mut sender = match self.responses.pop_front() {
Some(sender) => sender,
None => match message {
Message::ErrorResponse(error) => return Err(error::__db(error)),
_ => return Err(bad_response()),
},
};
let request_complete = match message {
Message::ReadyForQuery(_) => true,
_ => false,
};
match sender.start_send(message) {
// if the receiver's hung up we still need to page through the rest of the messages
// designated to it
Ok(AsyncSink::Ready) | Err(_) => {
if !request_complete {
self.responses.push_front(sender);
}
}
Ok(AsyncSink::NotReady(message)) => {
self.responses.push_front(sender);
self.pending_response = Some(message);
trace!("poll_read: waiting on socket");
return Ok(());
}
}
}
}
fn poll_request(&mut self) -> Poll<Option<Vec<u8>>, Error> {
if let Some(message) = self.pending_request.take() {
trace!("retrying pending request");
return Ok(Async::Ready(Some(message)));
}
match try_receive!(self.receiver.poll()) {
Some(request) => {
trace!("polled new request");
self.responses.push_back(request.sender);
Ok(Async::Ready(Some(request.messages)))
}
None => Ok(Async::Ready(None)),
}
}
fn poll_write(&mut self) -> Result<bool, Error> {
loop {
if self.state == State::Closing {
trace!("poll_write: done");
return Ok(false);
}
let request = match self.poll_request()? {
Async::Ready(Some(request)) => request,
Async::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
trace!("poll_write: at eof, terminating");
self.state = State::Terminating;
let mut request = vec![];
frontend::terminate(&mut request);
request
}
Async::Ready(None) => {
trace!(
"poll_write: at eof, pending responses {}",
self.responses.len(),
);
return Ok(true);
}
Async::NotReady => {
trace!("poll_write: waiting on request");
return Ok(true);
}
};
match self.stream.start_send(request)? {
AsyncSink::Ready => {
if self.state == State::Terminating {
trace!("poll_write: sent eof, closing");
self.state = State::Closing;
}
}
AsyncSink::NotReady(request) => {
trace!("poll_write: waiting on socket");
self.pending_request = Some(request);
return Ok(false);
}
}
}
}
fn poll_flush(&mut self) -> Result<(), Error> {
match self.stream.poll_complete() {
Ok(Async::Ready(())) => {
trace!("poll_flush: flushed");
Ok(())
}
Ok(Async::NotReady) => {
trace!("poll_flush: waiting on socket");
Ok(())
}
Err(e) => Err(Error::from(e)),
}
}
fn poll_shutdown(&mut self) -> Poll<(), Error> {
if self.state != State::Closing {
return Ok(Async::NotReady);
}
match self.stream.close() {
Ok(Async::Ready(())) => {
trace!("poll_shutdown: complete");
Ok(Async::Ready(()))
}
Ok(Async::NotReady) => {
trace!("poll_shutdown: waiting on socket");
Ok(Async::NotReady)
}
Err(e) => Err(Error::from(e)),
}
}
}
impl Future for Connection {
type Item = ();
type Error = Error;
fn poll(&mut self) -> Poll<(), Error> {
self.poll_read()?;
let want_flush = self.poll_write()?;
if want_flush {
self.poll_flush()?;
}
self.poll_shutdown()
}
}

View File

@ -0,0 +1,88 @@
use futures::sync::mpsc;
use futures::{Poll, Stream};
use postgres_protocol::message::backend::Message;
use state_machine_future::RentToOwn;
use error::{self, Error};
use proto::client::PendingRequest;
use proto::statement::Statement;
use {bad_response, disconnected};
#[derive(StateMachineFuture)]
pub enum Execute {
#[state_machine_future(start, transitions(ReadResponse))]
Start {
request: PendingRequest,
statement: Statement,
},
#[state_machine_future(transitions(ReadReadyForQuery))]
ReadResponse { receiver: mpsc::Receiver<Message> },
#[state_machine_future(transitions(Finished))]
ReadReadyForQuery {
receiver: mpsc::Receiver<Message>,
rows: u64,
},
#[state_machine_future(ready)]
Finished(u64),
#[state_machine_future(error)]
Failed(Error),
}
impl PollExecute for Execute {
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll<AfterStart, Error> {
let state = state.take();
let receiver = state.request.send()?;
// the statement can drop after this point, since its close will queue up after the execution
transition!(ReadResponse { receiver })
}
fn poll_read_response<'a>(
state: &'a mut RentToOwn<'a, ReadResponse>,
) -> Poll<AfterReadResponse, Error> {
loop {
let message = try_receive!(state.receiver.poll());
match message {
Some(Message::BindComplete) => {}
Some(Message::DataRow(_)) => {}
Some(Message::ErrorResponse(body)) => return Err(error::__db(body)),
Some(Message::CommandComplete(body)) => {
let rows = body.tag()?.rsplit(' ').next().unwrap().parse().unwrap_or(0);
let state = state.take();
transition!(ReadReadyForQuery {
receiver: state.receiver,
rows,
});
}
Some(Message::EmptyQueryResponse) => {
let state = state.take();
transition!(ReadReadyForQuery {
receiver: state.receiver,
rows: 0,
});
}
Some(_) => return Err(bad_response()),
None => return Err(disconnected()),
}
}
}
fn poll_read_ready_for_query<'a>(
state: &'a mut RentToOwn<'a, ReadReadyForQuery>,
) -> Poll<AfterReadReadyForQuery, Error> {
let message = try_receive!(state.receiver.poll());
match message {
Some(Message::ReadyForQuery(_)) => transition!(Finished(state.rows)),
Some(_) => Err(bad_response()),
None => Err(disconnected()),
}
}
}
impl ExecuteFuture {
pub fn new(request: PendingRequest, statement: Statement) -> ExecuteFuture {
Execute::start(request, statement)
}
}

View File

@ -0,0 +1,413 @@
use fallible_iterator::FallibleIterator;
use futures::sink;
use futures::sync::mpsc;
use futures::{Future, Poll, Sink, Stream};
use postgres_protocol::authentication;
use postgres_protocol::authentication::sasl::{self, ChannelBinding, ScramSha256};
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use state_machine_future::RentToOwn;
use std::collections::HashMap;
use std::error::Error as StdError;
use std::io;
use tokio_codec::Framed;
use tokio_io::io::{read_exact, write_all, ReadExact, WriteAll};
use error::{self, Error};
use params::{ConnectParams, Host, User};
use proto::client::Client;
use proto::codec::PostgresCodec;
use proto::connection::Connection;
use proto::socket::{ConnectFuture, Socket};
use tls::{self, TlsConnect, TlsStream};
use {bad_response, disconnected, CancelData, TlsMode};
#[derive(StateMachineFuture)]
pub enum Handshake {
#[state_machine_future(start, transitions(BuildingStartup, SendingSsl))]
Start {
future: ConnectFuture,
params: ConnectParams,
tls: TlsMode,
},
#[state_machine_future(transitions(ReadingSsl))]
SendingSsl {
future: WriteAll<Socket, Vec<u8>>,
params: ConnectParams,
connector: Box<TlsConnect>,
required: bool,
},
#[state_machine_future(transitions(ConnectingTls, BuildingStartup))]
ReadingSsl {
future: ReadExact<Socket, [u8; 1]>,
params: ConnectParams,
connector: Box<TlsConnect>,
required: bool,
},
#[state_machine_future(transitions(BuildingStartup))]
ConnectingTls {
future:
Box<Future<Item = Box<TlsStream>, Error = Box<StdError + Sync + Send>> + Sync + Send>,
params: ConnectParams,
},
#[state_machine_future(transitions(SendingStartup))]
BuildingStartup {
stream: Framed<Box<TlsStream>, PostgresCodec>,
params: ConnectParams,
},
#[state_machine_future(transitions(ReadingAuth))]
SendingStartup {
future: sink::Send<Framed<Box<TlsStream>, PostgresCodec>>,
user: User,
},
#[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))]
ReadingAuth {
stream: Framed<Box<TlsStream>, PostgresCodec>,
user: User,
},
#[state_machine_future(transitions(ReadingAuthCompletion))]
SendingPassword {
future: sink::Send<Framed<Box<TlsStream>, PostgresCodec>>,
},
#[state_machine_future(transitions(ReadingSasl))]
SendingSasl {
future: sink::Send<Framed<Box<TlsStream>, PostgresCodec>>,
scram: ScramSha256,
},
#[state_machine_future(transitions(SendingSasl, ReadingAuthCompletion))]
ReadingSasl {
stream: Framed<Box<TlsStream>, PostgresCodec>,
scram: ScramSha256,
},
#[state_machine_future(transitions(ReadingInfo))]
ReadingAuthCompletion {
stream: Framed<Box<TlsStream>, PostgresCodec>,
},
#[state_machine_future(transitions(Finished))]
ReadingInfo {
stream: Framed<Box<TlsStream>, PostgresCodec>,
cancel_data: Option<CancelData>,
parameters: HashMap<String, String>,
},
#[state_machine_future(ready)]
Finished((Client, Connection)),
#[state_machine_future(error)]
Failed(Error),
}
impl PollHandshake for Handshake {
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll<AfterStart, Error> {
let stream = try_ready!(state.future.poll());
let state = state.take();
let (connector, required) = match state.tls {
TlsMode::None => {
transition!(BuildingStartup {
stream: Framed::new(Box::new(stream), PostgresCodec),
params: state.params,
});
}
TlsMode::Prefer(connector) => (connector, false),
TlsMode::Require(connector) => (connector, true),
};
let mut buf = vec![];
frontend::ssl_request(&mut buf);
transition!(SendingSsl {
future: write_all(stream, buf),
params: state.params,
connector,
required,
})
}
fn poll_sending_ssl<'a>(
state: &'a mut RentToOwn<'a, SendingSsl>,
) -> Poll<AfterSendingSsl, Error> {
let (stream, _) = try_ready!(state.future.poll());
let state = state.take();
transition!(ReadingSsl {
future: read_exact(stream, [0]),
params: state.params,
connector: state.connector,
required: state.required,
})
}
fn poll_reading_ssl<'a>(
state: &'a mut RentToOwn<'a, ReadingSsl>,
) -> Poll<AfterReadingSsl, Error> {
let (stream, buf) = try_ready!(state.future.poll());
let state = state.take();
match buf[0] {
b'S' => {
let future = match state.params.host() {
Host::Tcp(domain) => state.connector.connect(domain, tls::Socket(stream)),
Host::Unix(_) => {
return Err(error::tls("TLS over unix sockets not supported".into()))
}
};
transition!(ConnectingTls {
future,
params: state.params,
})
}
b'N' if !state.required => transition!(BuildingStartup {
stream: Framed::new(Box::new(stream), PostgresCodec),
params: state.params,
}),
b'N' => Err(error::tls("TLS was required but not supported".into())),
_ => Err(bad_response()),
}
}
fn poll_connecting_tls<'a>(
state: &'a mut RentToOwn<'a, ConnectingTls>,
) -> Poll<AfterConnectingTls, Error> {
let stream = try_ready!(state.future.poll().map_err(error::tls));
let state = state.take();
transition!(BuildingStartup {
stream: Framed::new(stream, PostgresCodec),
params: state.params,
})
}
fn poll_building_startup<'a>(
state: &'a mut RentToOwn<'a, BuildingStartup>,
) -> Poll<AfterBuildingStartup, Error> {
let state = state.take();
let user = match state.params.user() {
Some(user) => user.clone(),
None => {
return Err(error::connect(
"user missing from connection parameters".into(),
))
}
};
let mut buf = vec![];
{
let options = state
.params
.options()
.iter()
.map(|&(ref key, ref value)| (&**key, &**value));
let client_encoding = Some(("client_encoding", "UTF8"));
let timezone = Some(("timezone", "GMT"));
let user = Some(("user", user.name()));
let database = state.params.database().map(|s| ("database", s));
frontend::startup_message(
options
.chain(client_encoding)
.chain(timezone)
.chain(user)
.chain(database),
&mut buf,
)?;
}
transition!(SendingStartup {
future: state.stream.send(buf),
user,
})
}
fn poll_sending_startup<'a>(
state: &'a mut RentToOwn<'a, SendingStartup>,
) -> Poll<AfterSendingStartup, Error> {
let stream = try_ready!(state.future.poll());
let state = state.take();
transition!(ReadingAuth {
stream,
user: state.user,
})
}
fn poll_reading_auth<'a>(
state: &'a mut RentToOwn<'a, ReadingAuth>,
) -> Poll<AfterReadingAuth, Error> {
let message = try_ready!(state.stream.poll());
let state = state.take();
match message {
Some(Message::AuthenticationOk) => transition!(ReadingInfo {
stream: state.stream,
cancel_data: None,
parameters: HashMap::new(),
}),
Some(Message::AuthenticationCleartextPassword) => {
let pass = state.user.password().ok_or_else(missing_password)?;
let mut buf = vec![];
frontend::password_message(pass, &mut buf)?;
transition!(SendingPassword {
future: state.stream.send(buf)
})
}
Some(Message::AuthenticationMd5Password(body)) => {
let pass = state.user.password().ok_or_else(missing_password)?;
let output = authentication::md5_hash(
state.user.name().as_bytes(),
pass.as_bytes(),
body.salt(),
);
let mut buf = vec![];
frontend::password_message(&output, &mut buf)?;
transition!(SendingPassword {
future: state.stream.send(buf)
})
}
Some(Message::AuthenticationSasl(body)) => {
let pass = state.user.password().ok_or_else(missing_password)?;
let mut has_scram = false;
let mut mechanisms = body.mechanisms();
while let Some(mechanism) = mechanisms.next()? {
match mechanism {
sasl::SCRAM_SHA_256 => has_scram = true,
_ => {}
}
}
if !has_scram {
return Err(io::Error::new(
io::ErrorKind::Other,
"unsupported SASL authentication",
).into());
}
let mut scram = ScramSha256::new(pass.as_bytes(), ChannelBinding::unsupported())?;
let mut buf = vec![];
frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), &mut buf)?;
transition!(SendingSasl {
future: state.stream.send(buf),
scram,
})
}
Some(Message::AuthenticationKerberosV5)
| Some(Message::AuthenticationScmCredential)
| Some(Message::AuthenticationGss)
| Some(Message::AuthenticationSspi) => Err(io::Error::new(
io::ErrorKind::Other,
"unsupported authentication method",
).into()),
Some(Message::ErrorResponse(body)) => Err(error::__db(body)),
Some(_) => Err(bad_response()),
None => Err(disconnected()),
}
}
fn poll_sending_password<'a>(
state: &'a mut RentToOwn<'a, SendingPassword>,
) -> Poll<AfterSendingPassword, Error> {
let stream = try_ready!(state.future.poll());
transition!(ReadingAuthCompletion { stream })
}
fn poll_sending_sasl<'a>(
state: &'a mut RentToOwn<'a, SendingSasl>,
) -> Poll<AfterSendingSasl, Error> {
let stream = try_ready!(state.future.poll());
let state = state.take();
transition!(ReadingSasl {
stream,
scram: state.scram
})
}
fn poll_reading_sasl<'a>(
state: &'a mut RentToOwn<'a, ReadingSasl>,
) -> Poll<AfterReadingSasl, Error> {
let message = try_ready!(state.stream.poll());
let mut state = state.take();
match message {
Some(Message::AuthenticationSaslContinue(body)) => {
state.scram.update(body.data())?;
let mut buf = vec![];
frontend::sasl_response(state.scram.message(), &mut buf)?;
transition!(SendingSasl {
future: state.stream.send(buf),
scram: state.scram,
})
}
Some(Message::AuthenticationSaslFinal(body)) => {
state.scram.finish(body.data())?;
transition!(ReadingAuthCompletion {
stream: state.stream,
})
}
Some(Message::ErrorResponse(body)) => Err(error::__db(body)),
Some(_) => Err(bad_response()),
None => Err(disconnected()),
}
}
fn poll_reading_auth_completion<'a>(
state: &'a mut RentToOwn<'a, ReadingAuthCompletion>,
) -> Poll<AfterReadingAuthCompletion, Error> {
let message = try_ready!(state.stream.poll());
let state = state.take();
match message {
Some(Message::AuthenticationOk) => transition!(ReadingInfo {
stream: state.stream,
cancel_data: None,
parameters: HashMap::new(),
}),
Some(Message::ErrorResponse(body)) => Err(error::__db(body)),
Some(_) => Err(bad_response()),
None => Err(disconnected()),
}
}
fn poll_reading_info<'a>(
state: &'a mut RentToOwn<'a, ReadingInfo>,
) -> Poll<AfterReadingInfo, Error> {
loop {
let message = try_ready!(state.stream.poll());
match message {
Some(Message::BackendKeyData(body)) => {
state.cancel_data = Some(CancelData {
process_id: body.process_id(),
secret_key: body.secret_key(),
});
}
Some(Message::ParameterStatus(body)) => {
state
.parameters
.insert(body.name()?.to_string(), body.value()?.to_string());
}
Some(Message::ReadyForQuery(_)) => {
let state = state.take();
let cancel_data = state.cancel_data.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidData, "BackendKeyData message missing")
})?;
let (sender, receiver) = mpsc::unbounded();
let client = Client::new(sender);
let connection =
Connection::new(state.stream, cancel_data, state.parameters, receiver);
transition!(Finished((client, connection)))
}
Some(Message::ErrorResponse(body)) => return Err(error::__db(body)),
Some(Message::NoticeResponse(_)) => {}
Some(_) => return Err(bad_response()),
None => return Err(disconnected()),
}
}
}
}
impl HandshakeFuture {
pub fn new(params: ConnectParams, tls: TlsMode) -> HandshakeFuture {
Handshake::start(Socket::connect(&params), params, tls)
}
}
fn missing_password() -> Error {
error::connect("a password was requested but not provided".into())
}

View File

@ -0,0 +1,31 @@
macro_rules! try_receive {
($e:expr) => {
match $e {
Ok(::futures::Async::Ready(v)) => v,
Ok(::futures::Async::NotReady) => return Ok(::futures::Async::NotReady),
Err(()) => unreachable!("mpsc::Receiver doesn't return errors"),
}
};
}
mod client;
mod codec;
mod connection;
mod execute;
mod handshake;
mod prepare;
mod query;
mod row;
mod socket;
mod statement;
pub use proto::client::Client;
pub use proto::codec::PostgresCodec;
pub use proto::connection::Connection;
pub use proto::execute::ExecuteFuture;
pub use proto::handshake::HandshakeFuture;
pub use proto::prepare::PrepareFuture;
pub use proto::query::QueryStream;
pub use proto::row::Row;
pub use proto::socket::Socket;
pub use proto::statement::Statement;

View File

@ -0,0 +1,171 @@
use fallible_iterator::FallibleIterator;
use futures::sync::mpsc;
use futures::{Poll, Stream};
use postgres_protocol::message::backend::{Message, ParameterDescriptionBody, RowDescriptionBody};
use state_machine_future::RentToOwn;
use error::{self, Error};
use proto::client::PendingRequest;
use proto::connection::Request;
use proto::statement::Statement;
use types::Type;
use Column;
use {bad_response, disconnected};
#[derive(StateMachineFuture)]
pub enum Prepare {
#[state_machine_future(start, transitions(ReadParseComplete))]
Start {
request: PendingRequest,
sender: mpsc::UnboundedSender<Request>,
name: String,
},
#[state_machine_future(transitions(ReadParameterDescription))]
ReadParseComplete {
sender: mpsc::UnboundedSender<Request>,
receiver: mpsc::Receiver<Message>,
name: String,
},
#[state_machine_future(transitions(ReadRowDescription))]
ReadParameterDescription {
sender: mpsc::UnboundedSender<Request>,
receiver: mpsc::Receiver<Message>,
name: String,
},
#[state_machine_future(transitions(ReadReadyForQuery))]
ReadRowDescription {
sender: mpsc::UnboundedSender<Request>,
receiver: mpsc::Receiver<Message>,
name: String,
parameters: ParameterDescriptionBody,
},
#[state_machine_future(transitions(Finished))]
ReadReadyForQuery {
sender: mpsc::UnboundedSender<Request>,
receiver: mpsc::Receiver<Message>,
name: String,
parameters: ParameterDescriptionBody,
columns: Option<RowDescriptionBody>,
},
#[state_machine_future(ready)]
Finished(Statement),
#[state_machine_future(error)]
Failed(Error),
}
impl PollPrepare for Prepare {
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll<AfterStart, Error> {
let state = state.take();
let receiver = state.request.send()?;
transition!(ReadParseComplete {
sender: state.sender,
receiver,
name: state.name,
})
}
fn poll_read_parse_complete<'a>(
state: &'a mut RentToOwn<'a, ReadParseComplete>,
) -> Poll<AfterReadParseComplete, Error> {
let message = try_receive!(state.receiver.poll());
let state = state.take();
match message {
Some(Message::ParseComplete) => transition!(ReadParameterDescription {
sender: state.sender,
receiver: state.receiver,
name: state.name,
}),
Some(Message::ErrorResponse(body)) => Err(error::__db(body)),
Some(_) => Err(bad_response()),
None => Err(disconnected()),
}
}
fn poll_read_parameter_description<'a>(
state: &'a mut RentToOwn<'a, ReadParameterDescription>,
) -> Poll<AfterReadParameterDescription, Error> {
let message = try_receive!(state.receiver.poll());
let state = state.take();
match message {
Some(Message::ParameterDescription(body)) => transition!(ReadRowDescription {
sender: state.sender,
receiver: state.receiver,
name: state.name,
parameters: body,
}),
Some(_) => Err(bad_response()),
None => Err(disconnected()),
}
}
fn poll_read_row_description<'a>(
state: &'a mut RentToOwn<'a, ReadRowDescription>,
) -> Poll<AfterReadRowDescription, Error> {
let message = try_receive!(state.receiver.poll());
let state = state.take();
let body = match message {
Some(Message::RowDescription(body)) => Some(body),
Some(Message::NoData) => None,
Some(_) => return Err(bad_response()),
None => return Err(disconnected()),
};
transition!(ReadReadyForQuery {
sender: state.sender,
receiver: state.receiver,
name: state.name,
parameters: state.parameters,
columns: body,
})
}
fn poll_read_ready_for_query<'a>(
state: &'a mut RentToOwn<'a, ReadReadyForQuery>,
) -> Poll<AfterReadReadyForQuery, Error> {
let message = try_receive!(state.receiver.poll());
let state = state.take();
match message {
Some(Message::ReadyForQuery(_)) => {
// FIXME handle custom types
let parameters = state
.parameters
.parameters()
.map(|oid| Type::from_oid(oid).unwrap())
.collect()?;
let columns = match state.columns {
Some(body) => body
.fields()
.map(|f| {
Column::new(f.name().to_string(), Type::from_oid(f.type_oid()).unwrap())
})
.collect()?,
None => vec![],
};
transition!(Finished(Statement::new(
state.sender,
state.name,
parameters,
columns
)))
}
Some(_) => Err(bad_response()),
None => Err(disconnected()),
}
}
}
impl PrepareFuture {
pub fn new(
request: PendingRequest,
sender: mpsc::UnboundedSender<Request>,
name: String,
) -> PrepareFuture {
Prepare::start(request, sender, name)
}
}

View File

@ -0,0 +1,108 @@
use futures::sync::mpsc;
use futures::{Async, Poll, Stream};
use postgres_protocol::message::backend::Message;
use std::mem;
use error::{self, Error};
use proto::client::PendingRequest;
use proto::row::Row;
use proto::statement::Statement;
use {bad_response, disconnected};
enum State {
Start {
request: PendingRequest,
statement: Statement,
},
ReadingResponse {
receiver: mpsc::Receiver<Message>,
statement: Statement,
},
ReadingReadyForQuery {
receiver: mpsc::Receiver<Message>,
},
Done,
}
pub struct QueryStream(State);
impl Stream for QueryStream {
type Item = Row;
type Error = Error;
fn poll(&mut self) -> Poll<Option<Row>, Error> {
loop {
match mem::replace(&mut self.0, State::Done) {
State::Start { request, statement } => {
let receiver = request.send()?;
self.0 = State::ReadingResponse {
receiver,
statement,
};
}
State::ReadingResponse {
mut receiver,
statement,
} => {
let message = match receiver.poll() {
Ok(Async::Ready(message)) => message,
Ok(Async::NotReady) => {
self.0 = State::ReadingResponse {
receiver,
statement,
};
break Ok(Async::NotReady);
}
Err(()) => unreachable!("mpsc::Receiver doesn't return errors"),
};
match message {
Some(Message::BindComplete) => {
self.0 = State::ReadingResponse {
receiver,
statement,
};
}
Some(Message::ErrorResponse(body)) => break Err(error::__db(body)),
Some(Message::DataRow(body)) => {
let row = Row::new(statement.clone(), body)?;
self.0 = State::ReadingResponse {
receiver,
statement,
};
break Ok(Async::Ready(Some(row)));
}
Some(Message::EmptyQueryResponse) | Some(Message::CommandComplete(_)) => {
self.0 = State::ReadingReadyForQuery { receiver };
}
Some(_) => break Err(bad_response()),
None => break Err(disconnected()),
}
}
State::ReadingReadyForQuery { mut receiver } => {
let message = match receiver.poll() {
Ok(Async::Ready(message)) => message,
Ok(Async::NotReady) => {
self.0 = State::ReadingReadyForQuery { receiver };
break Ok(Async::NotReady);
}
Err(()) => unreachable!("mpsc::Receiver doesn't return errors"),
};
match message {
Some(Message::ReadyForQuery(_)) => break Ok(Async::Ready(None)),
Some(_) => break Err(bad_response()),
None => break Err(disconnected()),
}
}
State::Done => break Ok(Async::Ready(None)),
}
}
}
}
impl QueryStream {
pub fn new(request: PendingRequest, statement: Statement) -> QueryStream {
QueryStream(State::Start { request, statement })
}
}

View File

@ -0,0 +1,66 @@
use postgres_protocol::message::backend::DataRowBody;
use postgres_shared::rows::{RowData, RowIndex};
use std::fmt;
use error::{self, Error};
use proto::statement::Statement;
use types::{FromSql, WrongType};
use Column;
pub struct Row {
statement: Statement,
data: RowData,
}
impl Row {
pub fn new(statement: Statement, data: DataRowBody) -> Result<Row, Error> {
let data = RowData::new(data)?;
Ok(Row { statement, data })
}
pub fn columns(&self) -> &[Column] {
self.statement.columns()
}
pub fn len(&self) -> usize {
self.columns().len()
}
pub fn get<'b, I, T>(&'b self, idx: I) -> T
where
I: RowIndex + fmt::Debug,
T: FromSql<'b>,
{
match self.get_inner(&idx) {
Ok(Some(ok)) => ok,
Err(err) => panic!("error retrieving column {:?}: {:?}", idx, err),
Ok(None) => panic!("no such column {:?}", idx),
}
}
pub fn try_get<'b, I, T>(&'b self, idx: I) -> Result<Option<T>, Error>
where
I: RowIndex,
T: FromSql<'b>,
{
self.get_inner(&idx)
}
fn get_inner<'b, I, T>(&'b self, idx: &I) -> Result<Option<T>, Error>
where
I: RowIndex,
T: FromSql<'b>,
{
let idx = match idx.__idx(&self.columns()) {
Some(idx) => idx,
None => return Ok(None),
};
let ty = self.statement.columns()[idx].type_();
if !<T as FromSql>::accepts(ty) {
return Err(error::conversion(Box::new(WrongType::new(ty.clone()))));
}
let value = FromSql::from_sql_nullable(ty, self.data.get(idx));
value.map(Some).map_err(error::conversion)
}
}

View File

@ -0,0 +1,233 @@
use bytes::{Buf, BufMut};
use futures::{Async, Future, Poll};
use futures_cpupool::{CpuFuture, CpuPool};
use state_machine_future::RentToOwn;
use std::io::{self, Read, Write};
use std::net::{SocketAddr, ToSocketAddrs};
use std::time::{Duration, Instant};
use std::vec;
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_tcp::{self, TcpStream};
use tokio_timer::Delay;
#[cfg(unix)]
use tokio_uds::{self, UnixStream};
use params::{ConnectParams, Host};
lazy_static! {
static ref DNS_POOL: CpuPool = CpuPool::new(2);
}
pub enum Socket {
Tcp(TcpStream),
#[cfg(unix)]
Unix(UnixStream),
}
impl Read for Socket {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Socket::Tcp(stream) => stream.read(buf),
#[cfg(unix)]
Socket::Unix(stream) => stream.read(buf),
}
}
}
impl AsyncRead for Socket {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
match self {
Socket::Tcp(stream) => stream.prepare_uninitialized_buffer(buf),
#[cfg(unix)]
Socket::Unix(stream) => stream.prepare_uninitialized_buffer(buf),
}
}
fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
where
B: BufMut,
{
match self {
Socket::Tcp(stream) => stream.read_buf(buf),
#[cfg(unix)]
Socket::Unix(stream) => stream.read_buf(buf),
}
}
}
impl Write for Socket {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Socket::Tcp(stream) => stream.write(buf),
#[cfg(unix)]
Socket::Unix(stream) => stream.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Socket::Tcp(stream) => stream.flush(),
#[cfg(unix)]
Socket::Unix(stream) => stream.flush(),
}
}
}
impl AsyncWrite for Socket {
fn shutdown(&mut self) -> Poll<(), io::Error> {
match self {
Socket::Tcp(stream) => stream.shutdown(),
#[cfg(unix)]
Socket::Unix(stream) => stream.shutdown(),
}
}
fn write_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
where
B: Buf,
{
match self {
Socket::Tcp(stream) => stream.write_buf(buf),
#[cfg(unix)]
Socket::Unix(stream) => stream.write_buf(buf),
}
}
}
impl Socket {
pub fn connect(params: &ConnectParams) -> ConnectFuture {
Connect::start(params.clone())
}
}
#[derive(StateMachineFuture)]
pub enum Connect {
#[state_machine_future(start)]
#[cfg_attr(unix, state_machine_future(transitions(ResolvingDns, ConnectingUnix)))]
#[cfg_attr(not(unix), state_machine_future(transitions(ResolvingDns)))]
Start { params: ConnectParams },
#[state_machine_future(transitions(ConnectingTcp))]
ResolvingDns {
future: CpuFuture<vec::IntoIter<SocketAddr>, io::Error>,
timeout: Option<Duration>,
},
#[state_machine_future(transitions(Ready))]
ConnectingTcp {
addrs: vec::IntoIter<SocketAddr>,
future: tokio_tcp::ConnectFuture,
timeout: Option<(Duration, Delay)>,
},
#[cfg(unix)]
#[state_machine_future(transitions(Ready))]
ConnectingUnix {
future: tokio_uds::ConnectFuture,
timeout: Option<Delay>,
},
#[state_machine_future(ready)]
Ready(Socket),
#[state_machine_future(error)]
Failed(io::Error),
}
impl PollConnect for Connect {
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll<AfterStart, io::Error> {
let timeout = state.params.connect_timeout();
let port = state.params.port();
match state.params.host() {
Host::Tcp(ref host) => {
let host = host.clone();
transition!(ResolvingDns {
future: DNS_POOL.spawn_fn(move || (&*host, port).to_socket_addrs()),
timeout,
})
}
#[cfg(unix)]
Host::Unix(ref path) => {
let path = path.join(format!(".s.PGSQL.{}", port));
transition!(ConnectingUnix {
future: UnixStream::connect(path),
timeout: timeout.map(|t| Delay::new(Instant::now() + t))
})
}
}
}
fn poll_resolving_dns<'a>(
state: &'a mut RentToOwn<'a, ResolvingDns>,
) -> Poll<AfterResolvingDns, io::Error> {
let mut addrs = try_ready!(state.future.poll());
let addr = match addrs.next() {
Some(addr) => addr,
None => {
return Err(io::Error::new(
io::ErrorKind::Other,
"resolved to 0 addresses",
))
}
};
transition!(ConnectingTcp {
addrs,
future: TcpStream::connect(&addr),
timeout: state.timeout.map(|t| (t, Delay::new(Instant::now() + t))),
})
}
fn poll_connecting_tcp<'a>(
state: &'a mut RentToOwn<'a, ConnectingTcp>,
) -> Poll<AfterConnectingTcp, io::Error> {
loop {
let error = match state.future.poll() {
Ok(Async::Ready(socket)) => transition!(Ready(Socket::Tcp(socket))),
Ok(Async::NotReady) => match state.timeout {
Some((_, ref mut delay)) => {
try_ready!(
delay
.poll()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
);
io::Error::new(io::ErrorKind::TimedOut, "connection timed out")
}
None => return Ok(Async::NotReady),
},
Err(e) => e,
};
let addr = match state.addrs.next() {
Some(addr) => addr,
None => return Err(error),
};
state.future = TcpStream::connect(&addr);
if let Some((timeout, ref mut delay)) = state.timeout {
delay.reset(Instant::now() + timeout);
}
}
}
#[cfg(unix)]
fn poll_connecting_unix<'a>(
state: &'a mut RentToOwn<'a, ConnectingUnix>,
) -> Poll<AfterConnectingUnix, io::Error> {
match state.future.poll()? {
Async::Ready(socket) => transition!(Ready(Socket::Unix(socket))),
Async::NotReady => match state.timeout {
Some(ref mut delay) => {
try_ready!(
delay
.poll()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
);
Err(io::Error::new(
io::ErrorKind::TimedOut,
"connection timed out",
))
}
None => Ok(Async::NotReady),
},
}
}
}

View File

@ -0,0 +1,58 @@
use futures::sync::mpsc;
use postgres_protocol::message::frontend;
use postgres_shared::stmt::Column;
use std::sync::Arc;
use proto::connection::Request;
use types::Type;
pub struct StatementInner {
sender: mpsc::UnboundedSender<Request>,
name: String,
params: Vec<Type>,
columns: Vec<Column>,
}
impl Drop for StatementInner {
fn drop(&mut self) {
let mut buf = vec![];
frontend::close(b'S', &self.name, &mut buf).expect("statement name not valid");
frontend::sync(&mut buf);
let (sender, _) = mpsc::channel(0);
let _ = self.sender.unbounded_send(Request {
messages: buf,
sender,
});
}
}
#[derive(Clone)]
pub struct Statement(Arc<StatementInner>);
impl Statement {
pub fn new(
sender: mpsc::UnboundedSender<Request>,
name: String,
params: Vec<Type>,
columns: Vec<Column>,
) -> Statement {
Statement(Arc::new(StatementInner {
sender,
name,
params,
columns,
}))
}
pub fn name(&self) -> &str {
&self.0.name
}
pub fn params(&self) -> &[Type] {
&self.0.params
}
pub fn columns(&self) -> &[Column] {
&self.0.columns
}
}

View File

@ -1,84 +0,0 @@
//! Postgres rows.
use postgres_shared::rows::RowData;
use postgres_shared::stmt::Column;
use std::error::Error;
use std::fmt;
use std::sync::Arc;
#[doc(inline)]
pub use postgres_shared::rows::RowIndex;
use types::{FromSql, WrongType};
/// A row from Postgres.
pub struct Row {
columns: Arc<Vec<Column>>,
data: RowData,
}
impl Row {
pub(crate) fn new(columns: Arc<Vec<Column>>, data: RowData) -> Row {
Row {
columns: columns,
data: data,
}
}
/// Returns information about the columns in the row.
pub fn columns(&self) -> &[Column] {
&self.columns
}
/// Returns the number of values in the row
pub fn len(&self) -> usize {
self.columns.len()
}
/// Retrieves the contents of a field of the row.
///
/// A field can be accessed by the name or index of its column, though
/// access by index is more efficient. Rows are 0-indexed.
///
/// # Panics
///
/// Panics if the index does not reference a column or the return type is
/// not compatible with the Postgres type.
pub fn get<'a, T, I>(&'a self, idx: I) -> T
where
T: FromSql<'a>,
I: RowIndex + fmt::Debug,
{
match self.try_get(&idx) {
Ok(Some(v)) => v,
Ok(None) => panic!("no such column {:?}", idx),
Err(e) => panic!("error retrieving row {:?}: {}", idx, e),
}
}
/// Retrieves the contents of a field of the row.
///
/// A field can be accessed by the name or index of its column, though
/// access by index is more efficient. Rows are 0-indexed.
///
/// Returns `None` if the index does not reference a column, `Some(Err(..))`
/// if there was an error converting the result value, and `Some(Ok(..))`
/// on success.
pub fn try_get<'a, T, I>(&'a self, idx: I) -> Result<Option<T>, Box<Error + Sync + Send>>
where
T: FromSql<'a>,
I: RowIndex,
{
let idx = match idx.__idx(&self.columns) {
Some(idx) => idx,
None => return Ok(None),
};
let ty = self.columns[idx].type_();
if !T::accepts(ty) {
return Err(Box::new(WrongType::new(ty.clone())));
}
T::from_sql_nullable(ty, self.data.get(idx)).map(Some)
}
}

View File

@ -1,162 +0,0 @@
use futures::{Sink, Future, Poll, AsyncSink, Async, Stream};
use futures::stream::Fuse;
pub trait SinkExt: Sink {
fn send2(self, item: Self::SinkItem) -> Send2<Self>
where
Self: Sized;
// unlike send_all, this doesn't close the stream.
fn send_all2<S>(self, stream: S) -> SendAll2<Self, S>
where
S: Stream<Item = Self::SinkItem>,
Self::SinkError: From<S::Error>,
Self: Sized;
}
impl<T> SinkExt for T
where
T: Sink,
{
fn send2(self, item: Self::SinkItem) -> Send2<Self>
where
Self: Sized,
{
Send2 {
sink: Some(self),
item: Some(item),
}
}
fn send_all2<S>(self, stream: S) -> SendAll2<Self, S>
where
S: Stream<Item = Self::SinkItem>,
Self::SinkError: From<S::Error>,
Self: Sized,
{
SendAll2 {
sink: Some(self),
stream: Some(stream.fuse()),
buffered: None,
}
}
}
pub struct Send2<T>
where
T: Sink,
{
sink: Option<T>,
item: Option<T::SinkItem>,
}
impl<T> Future for Send2<T>
where
T: Sink,
{
type Item = T;
type Error = (T::SinkError, T);
fn poll(&mut self) -> Poll<T, (T::SinkError, T)> {
let mut sink = self.sink.take().expect("poll called after completion");
if let Some(item) = self.item.take() {
match sink.start_send(item) {
Ok(AsyncSink::NotReady(item)) => {
self.sink = Some(sink);
self.item = Some(item);
return Ok(Async::NotReady);
}
Ok(AsyncSink::Ready) => {}
Err(e) => return Err((e, sink)),
}
}
match sink.poll_complete() {
Ok(Async::Ready(())) => {}
Ok(Async::NotReady) => {
self.sink = Some(sink);
return Ok(Async::NotReady);
}
Err(e) => return Err((e, sink)),
}
Ok(Async::Ready(sink))
}
}
pub struct SendAll2<T, U>
where
U: Stream,
{
sink: Option<T>,
stream: Option<Fuse<U>>,
buffered: Option<U::Item>,
}
impl<T, U> Future for SendAll2<T, U>
where
T: Sink,
U: Stream<Item = T::SinkItem>,
T::SinkError: From<U::Error>,
{
type Item = (T, U);
type Error = (T::SinkError, T, U);
fn poll(&mut self) -> Poll<(T, U), (T::SinkError, T, U)> {
let mut stream = self.stream.take().expect("poll called after completion");
let mut sink = self.sink.take().expect("poll called after completion");
if let Some(item) = self.buffered.take() {
match sink.start_send(item) {
Ok(AsyncSink::Ready) => {}
Ok(AsyncSink::NotReady(item)) => {
self.sink = Some(sink);
self.buffered = Some(item);
self.stream = Some(stream);
return Ok(Async::NotReady);
}
Err(e) => return Err((e, sink, stream.into_inner())),
}
}
loop {
match stream.poll() {
Ok(Async::Ready(Some(item))) => {
match sink.start_send(item) {
Ok(AsyncSink::Ready) => {}
Ok(AsyncSink::NotReady(item)) => {
self.sink = Some(sink);
self.buffered = Some(item);
self.stream = Some(stream);
return Ok(Async::NotReady);
}
Err(e) => return Err((e, sink, stream.into_inner())),
}
}
Ok(Async::Ready(None)) => {
match sink.poll_complete() {
Ok(Async::Ready(())) => return Ok(Async::Ready((sink, stream.into_inner()))),
Ok(Async::NotReady) => {
self.sink = Some(sink);
self.stream = Some(stream);
return Ok(Async::NotReady);
}
Err(e) => return Err((e, sink, stream.into_inner())),
}
}
Ok(Async::NotReady) => {
match sink.poll_complete() {
Ok(Async::Ready(())) | Ok(Async::NotReady) => {
self.sink = Some(sink);
self.stream = Some(stream);
return Ok(Async::NotReady);
}
Err(e) => return Err((e, sink, stream.into_inner())),
}
}
Err(e) => return Err((e.into(), sink, stream.into_inner())),
}
}
}
}

View File

@ -1,59 +0,0 @@
//! Prepared statements.
use std::mem;
use std::sync::Arc;
use std::sync::mpsc::Sender;
#[doc(inline)]
pub use postgres_shared::stmt::Column;
use types::Type;
/// A prepared statement.
pub struct Statement {
close_sender: Sender<(u8, String)>,
name: String,
params: Vec<Type>,
columns: Arc<Vec<Column>>,
}
impl Drop for Statement {
fn drop(&mut self) {
let name = mem::replace(&mut self.name, String::new());
let _ = self.close_sender.send((b'S', name));
}
}
impl Statement {
pub(crate) fn new(
close_sender: Sender<(u8, String)>,
name: String,
params: Vec<Type>,
columns: Arc<Vec<Column>>,
) -> Statement {
Statement {
close_sender: close_sender,
name: name,
params: params,
columns: columns,
}
}
pub(crate) fn columns_arc(&self) -> &Arc<Vec<Column>> {
&self.columns
}
pub(crate) fn name(&self) -> &str {
&self.name
}
/// Returns the types of query parameters for this statement.
pub fn parameters(&self) -> &[Type] {
&self.params
}
/// Returns information about the resulting columns for this statement.
pub fn columns(&self) -> &[Column] {
&self.columns
}
}

View File

@ -1,221 +0,0 @@
use bytes::{BufMut, BytesMut};
use futures::future::Either;
use futures::{Future, IntoFuture, Poll, Sink, Stream as FuturesStream};
use postgres_protocol::message::backend;
use postgres_protocol::message::frontend;
use postgres_shared::params::Host;
use std::io::{self, Read, Write};
use std::time::Duration;
use tokio_core::net::TcpStream;
use tokio_core::reactor::Handle;
use tokio_dns;
use tokio_io::codec::{Decoder, Encoder, Framed};
use tokio_io::{AsyncRead, AsyncWrite};
#[cfg(unix)]
use tokio_uds::UnixStream;
use error;
use tls::TlsStream;
use {BoxedFuture, Error, TlsMode};
pub type PostgresStream = Framed<Box<TlsStream>, PostgresCodec>;
pub fn connect(
host: Host,
port: u16,
keepalive: Option<Duration>,
tls_mode: TlsMode,
handle: &Handle,
) -> Box<Future<Item = PostgresStream, Error = Error> + Send> {
let inner = match host {
Host::Tcp(ref host) => Either::A(
tokio_dns::tcp_connect((&**host, port), handle.remote().clone())
.and_then(move |s| match keepalive {
Some(keepalive) => s.set_keepalive(Some(keepalive)).map(|_| s),
None => Ok(s),
})
.map(|s| Stream(InnerStream::Tcp(s)))
.map_err(error::io),
),
#[cfg(unix)]
Host::Unix(ref host) => {
let addr = host.join(format!(".s.PGSQL.{}", port));
Either::B(
UnixStream::connect(addr, handle)
.map(|s| Stream(InnerStream::Unix(s)))
.map_err(error::io)
.into_future(),
)
}
#[cfg(not(unix))]
Host::Unix(_) => Either::B(
Err(error::connect(
"unix sockets are not supported on this platform".into(),
)).into_future(),
),
};
let (required, handshaker) = match tls_mode {
TlsMode::Require(h) => (true, h),
TlsMode::Prefer(h) => (false, h),
TlsMode::None => {
return inner
.map(|s| {
let s: Box<TlsStream> = Box::new(s);
s.framed(PostgresCodec)
})
.boxed2()
}
};
inner
.map(|s| s.framed(SslCodec))
.and_then(|s| {
let mut buf = vec![];
frontend::ssl_request(&mut buf);
s.send(buf).map_err(error::io)
})
.and_then(|s| s.into_future().map_err(|e| error::io(e.0)))
.and_then(move |(m, s)| {
let s = s.into_inner();
match (m, required) {
(Some(b'N'), true) => Either::A(
Err(error::tls("the server does not support TLS".into())).into_future(),
),
(Some(b'N'), false) => {
let s: Box<TlsStream> = Box::new(s);
Either::A(Ok(s).into_future())
}
(None, _) => Either::A(
Err(error::io(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
))).into_future(),
),
_ => {
let host = match host {
Host::Tcp(ref host) => host,
Host::Unix(_) => unreachable!(),
};
Either::B(handshaker.handshake(host, s).map_err(error::tls))
}
}
})
.map(|s| s.framed(PostgresCodec))
.boxed2()
}
/// A raw connection to the database.
pub struct Stream(InnerStream);
enum InnerStream {
Tcp(TcpStream),
#[cfg(unix)]
Unix(UnixStream),
}
impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self.0 {
InnerStream::Tcp(ref mut s) => s.read(buf),
#[cfg(unix)]
InnerStream::Unix(ref mut s) => s.read(buf),
}
}
}
impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.0 {
InnerStream::Tcp(ref mut s) => s.write(buf),
#[cfg(unix)]
InnerStream::Unix(ref mut s) => s.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self.0 {
InnerStream::Tcp(ref mut s) => s.flush(),
#[cfg(unix)]
InnerStream::Unix(ref mut s) => s.flush(),
}
}
}
impl AsyncRead for Stream {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
match self.0 {
InnerStream::Tcp(ref s) => s.prepare_uninitialized_buffer(buf),
#[cfg(unix)]
InnerStream::Unix(ref s) => s.prepare_uninitialized_buffer(buf),
}
}
fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
where
B: BufMut,
{
match self.0 {
InnerStream::Tcp(ref mut s) => s.read_buf(buf),
#[cfg(unix)]
InnerStream::Unix(ref mut s) => s.read_buf(buf),
}
}
}
impl AsyncWrite for Stream {
fn shutdown(&mut self) -> Poll<(), io::Error> {
match self.0 {
InnerStream::Tcp(ref mut s) => s.shutdown(),
#[cfg(unix)]
InnerStream::Unix(ref mut s) => s.shutdown(),
}
}
}
pub struct PostgresCodec;
impl Decoder for PostgresCodec {
type Item = backend::Message;
type Error = io::Error;
fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<backend::Message>> {
backend::Message::parse(buf)
}
}
impl Encoder for PostgresCodec {
type Item = Vec<u8>;
type Error = io::Error;
fn encode(&mut self, msg: Vec<u8>, buf: &mut BytesMut) -> io::Result<()> {
buf.extend(&msg);
Ok(())
}
}
struct SslCodec;
impl Decoder for SslCodec {
type Item = u8;
type Error = io::Error;
fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u8>> {
if buf.is_empty() {
Ok(None)
} else {
Ok(Some(buf.split_to(1)[0]))
}
}
}
impl Encoder for SslCodec {
type Item = Vec<u8>;
type Error = io::Error;
fn encode(&mut self, msg: Vec<u8>, buf: &mut BytesMut) -> io::Result<()> {
buf.extend(&msg);
Ok(())
}
}

View File

@ -1,481 +0,0 @@
use futures::{Future, Stream};
use futures_state_stream::StateStream;
use std::error::Error as StdError;
use std::path::PathBuf;
use std::time::Duration;
use tokio_core::reactor::{Core, Interval};
use super::*;
use error::SqlState;
use params::{ConnectParams, Host};
use types::{FromSql, IsNull, Kind, ToSql, Type};
#[test]
fn md5_user() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect(
"postgres://md5_user:password@localhost:5433/postgres",
TlsMode::None,
&handle,
);
l.run(done).unwrap();
}
#[test]
fn md5_user_no_pass() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect(
"postgres://md5_user@localhost:5433/postgres",
TlsMode::None,
&handle,
);
match l.run(done) {
Err(ref e) if e.as_connection().is_some() => {}
Err(e) => panic!("unexpected error {}", e),
Ok(_) => panic!("unexpected success"),
}
}
#[test]
fn md5_user_wrong_pass() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect(
"postgres://md5_user:foobar@localhost:5433/postgres",
TlsMode::None,
&handle,
);
match l.run(done) {
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
Err(e) => panic!("unexpected error {}", e),
Ok(_) => panic!("unexpected success"),
}
}
#[test]
fn pass_user() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect(
"postgres://pass_user:password@localhost:5433/postgres",
TlsMode::None,
&handle,
);
l.run(done).unwrap();
}
#[test]
fn pass_user_no_pass() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect(
"postgres://pass_user@localhost:5433/postgres",
TlsMode::None,
&handle,
);
match l.run(done) {
Err(ref e) if e.as_connection().is_some() => {}
Err(e) => panic!("unexpected error {}", e),
Ok(_) => panic!("unexpected success"),
}
}
#[test]
fn pass_user_wrong_pass() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect(
"postgres://pass_user:foobar@localhost:5433/postgres",
TlsMode::None,
&handle,
);
match l.run(done) {
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
Err(e) => panic!("unexpected error {}", e),
Ok(_) => panic!("unexpected success"),
}
}
#[test]
fn batch_execute_ok() {
let mut l = Core::new().unwrap();
let done = Connection::connect(
"postgres://postgres@localhost:5433",
TlsMode::None,
&l.handle(),
).then(|c| {
c.unwrap()
.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL);")
});
l.run(done).unwrap();
}
#[test]
fn batch_execute_err() {
let mut l = Core::new().unwrap();
let done = Connection::connect(
"postgres://postgres@localhost:5433",
TlsMode::None,
&l.handle(),
).then(|r| {
r.unwrap().batch_execute(
"CREATE TEMPORARY TABLE foo (id SERIAL); INSERT INTO foo DEFAULT \
VALUES;",
)
})
.and_then(|c| c.batch_execute("SELECT * FROM bogo"))
.then(|r| match r {
Err((e, s)) => {
assert_eq!(e.code(), Some(&SqlState::UNDEFINED_TABLE));
s.batch_execute("SELECT * FROM foo")
}
Ok(_) => panic!("unexpected success"),
});
l.run(done).unwrap();
}
#[test]
fn prepare_execute() {
let mut l = Core::new().unwrap();
let done = Connection::connect(
"postgres://postgres@localhost:5433",
TlsMode::None,
&l.handle(),
).then(|c| {
c.unwrap()
.prepare("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY, name VARCHAR)")
})
.and_then(|(s, c)| c.execute(&s, &[]))
.and_then(|(n, c)| {
assert_eq!(0, n);
c.prepare("INSERT INTO foo (name) VALUES ($1), ($2)")
})
.and_then(|(s, c)| c.execute(&s, &[&"steven", &"bob"]))
.map(|(n, _)| assert_eq!(n, 2));
l.run(done).unwrap();
}
#[test]
fn prepare_execute_rows() {
let mut l = Core::new().unwrap();
let done = Connection::connect(
"postgres://postgres@localhost:5433",
TlsMode::None,
&l.handle(),
).then(|c| c.unwrap().prepare("SELECT 1"))
.and_then(|(s, c)| c.execute(&s, &[]));
l.run(done).unwrap();
}
#[test]
fn query() {
let mut l = Core::new().unwrap();
let done = Connection::connect(
"postgres://postgres@localhost:5433",
TlsMode::None,
&l.handle(),
).then(|c| {
c.unwrap().batch_execute(
"CREATE TEMPORARY TABLE foo (id SERIAL, name VARCHAR);
INSERT INTO foo (name) VALUES ('joe'), ('bob')",
)
})
.and_then(|c| c.prepare("SELECT id, name FROM foo ORDER BY id"))
.and_then(|(s, c)| c.query(&s, &[]).collect())
.and_then(|(r, c)| {
assert_eq!(r[0].get::<i32, _>("id"), 1);
assert_eq!(r[0].get::<String, _>("name"), "joe");
assert_eq!(r[1].get::<i32, _>("id"), 2);
assert_eq!(r[1].get::<String, _>("name"), "bob");
c.prepare("")
})
.and_then(|(s, c)| c.query(&s, &[]).collect())
.map(|(r, _)| assert!(r.is_empty()));
l.run(done).unwrap();
}
#[test]
fn transaction() {
let mut l = Core::new().unwrap();
let done = Connection::connect(
"postgres://postgres@localhost:5433",
TlsMode::None,
&l.handle(),
).then(|c| {
c.unwrap()
.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name VARCHAR);")
})
.then(|c| c.unwrap().transaction())
.then(|t| {
t.unwrap()
.batch_execute("INSERT INTO foo (name) VALUES ('joe');")
})
.then(|t| t.unwrap().rollback())
.then(|c| c.unwrap().transaction())
.then(|t| {
t.unwrap()
.batch_execute("INSERT INTO foo (name) VALUES ('bob');")
})
.then(|t| t.unwrap().commit())
.then(|c| c.unwrap().prepare("SELECT name FROM foo"))
.and_then(|(s, c)| c.query(&s, &[]).collect())
.map(|(r, _)| {
assert_eq!(r.len(), 1);
assert_eq!(r[0].get::<String, _>("name"), "bob");
});
l.run(done).unwrap();
}
#[test]
#[ignore] // not supported on our CI setup :(
fn unix_socket() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None, &handle)
.then(|c| c.unwrap().prepare("SHOW unix_socket_directories"))
.and_then(|(s, c)| c.query(&s, &[]).collect())
.then(|r| {
let r = r.unwrap().0;
let params = ConnectParams::builder()
.user("postgres", None)
.build(Host::Unix(PathBuf::from(r[0].get::<String, _>(0))));
Connection::connect(params, TlsMode::None, &handle)
})
.then(|c| c.unwrap().batch_execute(""));
l.run(done).unwrap();
}
#[test]
fn ssl_user_ssl_required() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect(
"postgres://ssl_user@localhost:5433/postgres",
TlsMode::None,
&handle,
);
match l.run(done) {
Err(ref e) => assert_eq!(
e.code(),
Some(&SqlState::INVALID_AUTHORIZATION_SPECIFICATION)
),
Ok(_) => panic!("unexpected success"),
}
}
#[cfg(feature = "with-openssl")]
#[test]
fn openssl_required() {
use tls::openssl::openssl::ssl::{SslConnector, SslMethod};
use tls::openssl::OpenSsl;
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_ca_file("../test/server.crt").unwrap();
let negotiator = OpenSsl::from(builder.build());
let mut l = Core::new().unwrap();
let done = Connection::connect(
"postgres://ssl_user@localhost:5433/postgres",
TlsMode::Require(Box::new(negotiator)),
&l.handle(),
).then(|c| c.unwrap().prepare("SELECT 1"))
.and_then(|(s, c)| c.query(&s, &[]).collect())
.map(|(r, _)| assert_eq!(r[0].get::<i32, _>(0), 1));
l.run(done).unwrap();
}
#[test]
fn domain() {
#[derive(Debug, PartialEq)]
struct SessionId(Vec<u8>);
impl ToSql for SessionId {
fn to_sql(
&self,
ty: &Type,
out: &mut Vec<u8>,
) -> Result<IsNull, Box<StdError + Sync + Send>> {
let inner = match *ty.kind() {
Kind::Domain(ref inner) => inner,
_ => unreachable!(),
};
self.0.to_sql(inner, out)
}
fn accepts(ty: &Type) -> bool {
match *ty.kind() {
Kind::Domain(Type::BYTEA) => ty.name() == "session_id",
_ => false,
}
}
to_sql_checked!();
}
impl<'a> FromSql<'a> for SessionId {
fn from_sql(ty: &Type, raw: &[u8]) -> Result<Self, Box<StdError + Sync + Send>> {
Vec::<u8>::from_sql(ty, raw).map(SessionId)
}
fn accepts(ty: &Type) -> bool {
// This is super weird!
<Vec<u8> as FromSql>::accepts(ty)
}
}
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None, &handle)
.then(|c| {
c.unwrap().batch_execute(
"CREATE DOMAIN pg_temp.session_id AS bytea \
CHECK(octet_length(VALUE) = 16);
CREATE \
TABLE pg_temp.foo (id pg_temp.session_id);",
)
})
.and_then(|c| c.prepare("INSERT INTO pg_temp.foo (id) VALUES ($1)"))
.and_then(|(s, c)| {
let id = SessionId(b"0123456789abcdef".to_vec());
c.execute(&s, &[&id])
})
.and_then(|(_, c)| c.prepare("SELECT id FROM pg_temp.foo"))
.and_then(|(s, c)| c.query(&s, &[]).collect())
.map(|(r, _)| {
let id = SessionId(b"0123456789abcdef".to_vec());
assert_eq!(id, r[0].get(0));
});
l.run(done).unwrap();
}
#[test]
fn composite() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None, &handle)
.then(|c| {
c.unwrap().batch_execute(
"CREATE TYPE pg_temp.inventory_item AS (
name TEXT,
supplier INTEGER,
price NUMERIC
)",
)
})
.and_then(|c| c.prepare("SELECT $1::inventory_item"))
.map(|(s, _)| {
let type_ = &s.parameters()[0];
assert_eq!(type_.name(), "inventory_item");
match *type_.kind() {
Kind::Composite(ref fields) => {
assert_eq!(fields[0].name(), "name");
assert_eq!(fields[0].type_(), &Type::TEXT);
assert_eq!(fields[1].name(), "supplier");
assert_eq!(fields[1].type_(), &Type::INT4);
assert_eq!(fields[2].name(), "price");
assert_eq!(fields[2].type_(), &Type::NUMERIC);
}
ref t => panic!("bad type {:?}", t),
}
});
l.run(done).unwrap();
}
#[test]
fn enum_() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None, &handle)
.then(|c| {
c.unwrap()
.batch_execute("CREATE TYPE pg_temp.mood AS ENUM ('sad', 'ok', 'happy');")
})
.and_then(|c| c.prepare("SELECT $1::mood"))
.map(|(s, _)| {
let type_ = &s.parameters()[0];
assert_eq!(type_.name(), "mood");
match *type_.kind() {
Kind::Enum(ref variants) => {
assert_eq!(
variants,
&["sad".to_owned(), "ok".to_owned(), "happy".to_owned()]
);
}
_ => panic!("bad type"),
}
});
l.run(done).unwrap();
}
#[test]
fn cancel() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None, &handle)
.then(move |c| {
let c = c.unwrap();
let cancel_data = c.cancel_data();
let cancel = Interval::new(Duration::from_secs(1), &handle)
.unwrap()
.into_future()
.then(move |r| {
assert!(r.is_ok());
cancel_query(
"postgres://postgres@localhost:5433",
TlsMode::None,
cancel_data,
&handle,
)
})
.then(Ok::<_, ()>);
c.batch_execute("SELECT pg_sleep(10)")
.then(Ok::<_, ()>)
.join(cancel)
});
let (select, cancel) = l.run(done).unwrap();
cancel.unwrap();
match select {
Err((e, _)) => assert_eq!(e.code(), Some(&SqlState::QUERY_CANCELED)),
Ok(_) => panic!("unexpected success"),
}
}
#[test]
fn notifications() {
let mut l = Core::new().unwrap();
let handle = l.handle();
let done = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None, &handle)
.then(|c| c.unwrap().batch_execute("LISTEN test_notifications"))
.and_then(|c1| {
Connection::connect("postgres://postgres@localhost:5433", TlsMode::None, &handle).then(
|c2| {
c2.unwrap()
.batch_execute("NOTIFY test_notifications, 'foo'")
.map(|_| c1)
},
)
})
.and_then(|c| {
c.notifications()
.into_future()
.map_err(|(e, n)| (e, n.into_inner()))
})
.map(|(n, _)| {
let n = n.unwrap();
assert_eq!(n.channel, "test_notifications");
assert_eq!(n.payload, "foo");
});
l.run(done).unwrap();
}

63
tokio-postgres/src/tls.rs Normal file
View File

@ -0,0 +1,63 @@
use bytes::{Buf, BufMut};
use futures::{Future, Poll};
use std::error::Error;
use std::io::{self, Read, Write};
use tokio_io::{AsyncRead, AsyncWrite};
use proto;
pub struct Socket(pub(crate) proto::Socket);
impl Read for Socket {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl AsyncRead for Socket {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.0.prepare_uninitialized_buffer(buf)
}
fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
where
B: BufMut,
{
self.0.read_buf(buf)
}
}
impl Write for Socket {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
impl AsyncWrite for Socket {
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.0.shutdown()
}
fn write_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
where
B: Buf,
{
self.0.write_buf(buf)
}
}
pub trait TlsConnect {
fn connect(
&self,
domain: &str,
socket: Socket,
) -> Box<Future<Item = Box<TlsStream>, Error = Box<Error + Sync + Send>> + Sync + Send>;
}
pub trait TlsStream: 'static + Sync + Send + AsyncRead + AsyncWrite {}
impl TlsStream for proto::Socket {}

View File

@ -1,39 +0,0 @@
//! TLS support.
use futures::Future;
use std::error::Error;
use tokio_io::{AsyncRead, AsyncWrite};
pub use stream::Stream;
#[cfg(feature = "with-openssl")]
pub mod openssl;
/// A trait implemented by streams returned from `Handshake` implementations.
pub trait TlsStream: AsyncRead + AsyncWrite + Send {
/// Returns a shared reference to the inner stream.
fn get_ref(&self) -> &Stream;
/// Returns a mutable reference to the inner stream.
fn get_mut(&mut self) -> &mut Stream;
}
impl TlsStream for Stream {
fn get_ref(&self) -> &Stream {
self
}
fn get_mut(&mut self) -> &mut Stream {
self
}
}
/// A trait implemented by types that can manage TLS encryption for a stream.
pub trait Handshake: 'static + Sync + Send {
/// Performs a TLS handshake, returning a wrapped stream.
fn handshake(
&self,
host: &str,
stream: Stream,
) -> Box<Future<Item = Box<TlsStream>, Error = Box<Error + Sync + Send>> + Send>;
}

View File

@ -1,59 +0,0 @@
//! OpenSSL support.
extern crate tokio_openssl;
pub extern crate openssl;
use futures::Future;
use self::openssl::ssl::{SslMethod, SslConnector};
use self::openssl::error::ErrorStack;
use std::error::Error;
use self::tokio_openssl::{SslConnectorExt, SslStream};
use BoxedFuture;
use tls::{Stream, TlsStream, Handshake};
impl TlsStream for SslStream<Stream> {
fn get_ref(&self) -> &Stream {
self.get_ref().get_ref()
}
fn get_mut(&mut self) -> &mut Stream {
self.get_mut().get_mut()
}
}
/// A `Handshake` implementation using OpenSSL.
pub struct OpenSsl(SslConnector);
impl OpenSsl {
/// Creates a new `OpenSsl` with default settings.
pub fn new() -> Result<OpenSsl, ErrorStack> {
let connector = SslConnector::builder(SslMethod::tls())?.build();
Ok(OpenSsl(connector))
}
}
impl From<SslConnector> for OpenSsl {
fn from(connector: SslConnector) -> OpenSsl {
OpenSsl(connector)
}
}
impl Handshake for OpenSsl {
fn handshake(
&self,
host: &str,
stream: Stream,
) -> Box<Future<Item = Box<TlsStream>, Error = Box<Error + Sync + Send>> + Send> {
self.0
.connect_async(host, stream)
.map(|s| {
let s: Box<TlsStream> = Box::new(s);
s
})
.map_err(|e| {
let e: Box<Error + Sync + Send> = Box::new(e);
e
})
.boxed2()
}
}

View File

@ -1,90 +0,0 @@
//! Transactions.
use futures::Future;
use futures_state_stream::StateStream;
use {Connection, BoxedFuture, BoxedStateStream};
use error::Error;
use stmt::Statement;
use types::ToSql;
use rows::Row;
/// An in progress Postgres transaction.
#[derive(Debug)]
pub struct Transaction(Connection);
impl Transaction {
pub(crate) fn new(c: Connection) -> Transaction {
Transaction(c)
}
/// Like `Connection::batch_execute`.
pub fn batch_execute(
self,
query: &str,
) -> Box<Future<Item = Transaction, Error = (Error, Transaction)> + Send> {
self.0
.batch_execute(query)
.map(Transaction)
.map_err(transaction_err)
.boxed2()
}
/// Like `Connection::prepare`.
pub fn prepare(
self,
query: &str,
) -> Box<Future<Item = (Statement, Transaction), Error = (Error, Transaction)> + Send> {
self.0
.prepare(query)
.map(|(s, c)| (s, Transaction(c)))
.map_err(transaction_err)
.boxed2()
}
/// Like `Connection::execute`.
pub fn execute(
self,
statement: &Statement,
params: &[&ToSql],
) -> Box<Future<Item = (u64, Transaction), Error = (Error, Transaction)> + Send> {
self.0
.execute(statement, params)
.map(|(n, c)| (n, Transaction(c)))
.map_err(transaction_err)
.boxed2()
}
/// Like `Connection::query`.
pub fn query(
self,
statement: &Statement,
params: &[&ToSql],
) -> Box<StateStream<Item = Row, State = Transaction, Error = Error> + Send> {
self.0
.query(statement, params)
.map_state(Transaction)
.boxed2()
}
/// Commits the transaction.
pub fn commit(self) -> Box<Future<Item = Connection, Error = (Error, Connection)> + Send> {
self.finish("COMMIT")
}
/// Rolls back the transaction.
pub fn rollback(self) -> Box<Future<Item = Connection, Error = (Error, Connection)> + Send> {
self.finish("ROLLBACK")
}
fn finish(
self,
query: &str,
) -> Box<Future<Item = Connection, Error = (Error, Connection)> + Send> {
self.0.simple_query(query).map(|(_, c)| c).boxed2()
}
}
fn transaction_err((e, c): (Error, Connection)) -> (Error, Transaction) {
(e, Transaction(c))
}

View File

@ -0,0 +1,216 @@
extern crate env_logger;
extern crate tokio;
extern crate tokio_postgres;
use tokio::prelude::*;
use tokio::runtime::current_thread::Runtime;
use tokio_postgres::error::SqlState;
use tokio_postgres::types::Type;
use tokio_postgres::TlsMode;
fn smoke_test(url: &str) {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake = tokio_postgres::connect(url.parse().unwrap(), TlsMode::None);
let (mut client, connection) = runtime.block_on(handshake).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
let prepare = client.prepare("SELECT 1::INT4");
let statement = runtime.block_on(prepare).unwrap();
let select = client.query(&statement, &[]).collect().map(|rows| {
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, i32>(0), 1);
});
runtime.block_on(select).unwrap();
drop(statement);
drop(client);
runtime.run().unwrap();
}
#[test]
fn plain_password_missing() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake = tokio_postgres::connect(
"postgres://pass_user@localhost:5433".parse().unwrap(),
TlsMode::None,
);
match runtime.block_on(handshake) {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.as_connection().is_some() => {}
Err(e) => panic!("{}", e),
}
}
#[test]
fn plain_password_wrong() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake = tokio_postgres::connect(
"postgres://pass_user:foo@localhost:5433".parse().unwrap(),
TlsMode::None,
);
match runtime.block_on(handshake) {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
Err(e) => panic!("{}", e),
}
}
#[test]
fn plain_password_ok() {
smoke_test("postgres://pass_user:password@localhost:5433/postgres");
}
#[test]
fn md5_password_missing() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake = tokio_postgres::connect(
"postgres://md5_user@localhost:5433".parse().unwrap(),
TlsMode::None,
);
match runtime.block_on(handshake) {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.as_connection().is_some() => {}
Err(e) => panic!("{}", e),
}
}
#[test]
fn md5_password_wrong() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake = tokio_postgres::connect(
"postgres://md5_user:foo@localhost:5433".parse().unwrap(),
TlsMode::None,
);
match runtime.block_on(handshake) {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
Err(e) => panic!("{}", e),
}
}
#[test]
fn md5_password_ok() {
smoke_test("postgres://md5_user:password@localhost:5433/postgres");
}
#[test]
fn scram_password_missing() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake = tokio_postgres::connect(
"postgres://scram_user@localhost:5433".parse().unwrap(),
TlsMode::None,
);
match runtime.block_on(handshake) {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.as_connection().is_some() => {}
Err(e) => panic!("{}", e),
}
}
#[test]
fn scram_password_wrong() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake = tokio_postgres::connect(
"postgres://scram_user:foo@localhost:5433".parse().unwrap(),
TlsMode::None,
);
match runtime.block_on(handshake) {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
Err(e) => panic!("{}", e),
}
}
#[test]
fn scram_password_ok() {
smoke_test("postgres://scram_user:password@localhost:5433/postgres");
}
#[test]
fn pipelined_prepare() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake = tokio_postgres::connect(
"postgres://postgres@localhost:5433".parse().unwrap(),
TlsMode::None,
);
let (mut client, connection) = runtime.block_on(handshake).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
let prepare1 = client.prepare("SELECT 1::BIGINT WHERE $1::BOOL");
let prepare2 = client.prepare("SELECT ''::TEXT, 1::FLOAT4 WHERE $1::VARCHAR IS NOT NULL");
let prepare = prepare1.join(prepare2);
let (statement1, statement2) = runtime.block_on(prepare).unwrap();
assert_eq!(statement1.params(), &[Type::BOOL]);
assert_eq!(statement1.columns().len(), 1);
assert_eq!(statement1.columns()[0].type_(), &Type::INT8);
assert_eq!(statement2.params(), &[Type::VARCHAR]);
assert_eq!(statement2.columns().len(), 2);
assert_eq!(statement2.columns()[0].type_(), &Type::TEXT);
assert_eq!(statement2.columns()[1].type_(), &Type::FLOAT4);
drop(statement1);
drop(statement2);
drop(client);
runtime.run().unwrap();
}
#[test]
fn insert_select() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let handshake = tokio_postgres::connect(
"postgres://postgres@localhost:5433".parse().unwrap(),
TlsMode::None,
);
let (mut client, connection) = runtime.block_on(handshake).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(
client
.prepare("CREATE TEMPORARY TABLE foo (id SERIAL, name TEXT)")
.and_then(|create| client.execute(&create, &[]))
.map(|n| assert_eq!(n, 0)),
)
.unwrap();
let insert = client.prepare("INSERT INTO foo (name) VALUES ($1), ($2)");
let select = client.prepare("SELECT id, name FROM foo ORDER BY id");
let prepare = insert.join(select);
let (insert, select) = runtime.block_on(prepare).unwrap();
let insert = client
.execute(&insert, &[&"alice", &"bob"])
.map(|n| assert_eq!(n, 2));
let select = client.query(&select, &[]).collect().map(|rows| {
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get::<_, i32>(0), 1);
assert_eq!(rows[0].get::<_, &str>(1), "alice");
assert_eq!(rows[1].get::<_, i32>(0), 2);
assert_eq!(rows[1].get::<_, &str>(1), "bob");
});
let tests = insert.join(select);
runtime.block_on(tests).unwrap();
}