// rust-postgres/postgres-tokio/src/lib.rs
// 2016-12-20 19:50:44 -08:00
//
// 365 lines
// 13 KiB
// Rust

extern crate fallible_iterator;
extern crate futures;
extern crate postgres_shared;
extern crate postgres_protocol;
extern crate tokio_core;
extern crate tokio_dns;
extern crate tokio_uds;
use fallible_iterator::FallibleIterator;
use futures::{Future, IntoFuture, BoxFuture, Stream, Sink, Poll, StartSend};
use futures::future::Either;
use postgres_protocol::authentication;
use postgres_protocol::message::{backend, frontend};
use postgres_protocol::message::backend::ErrorFields;
use postgres_shared::RowData;
use std::collections::HashMap;
use std::fmt;
use std::io;
use tokio_core::reactor::Handle;
#[doc(inline)]
pub use postgres_shared::params;
use error::{ConnectError, Error, DbError};
use params::{ConnectParams, IntoConnectParams};
use stream::PostgresStream;
pub mod error;
mod stream;
#[cfg(test)]
mod test;
/// Server-issued identifiers for a connection's backend.
///
/// Filled in from the `BackendKeyData` message received during startup; both fields stay
/// zero if the server never sends that message. Presumably used to issue cancel requests.
#[derive(Clone, Copy, Debug)]
pub struct CancelData {
    /// Process ID of the server backend, as reported in `BackendKeyData`.
    pub process_id: i32,
    /// Secret key of the server backend, as reported in `BackendKeyData`.
    pub secret_key: i32,
}
/// Connection state threaded through every protocol step: the message stream plus the
/// per-session data the server reports.
struct InnerConnection {
    /// Source of backend messages and sink for encoded frontend messages.
    stream: PostgresStream,
    /// Runtime parameters recorded from `ParameterStatus` messages.
    parameters: HashMap<String, String>,
    /// Backend process ID and secret key recorded from `BackendKeyData`.
    cancel_data: CancelData,
}
impl InnerConnection {
    /// Reads the next server message that callers need to act on.
    ///
    /// `ParameterStatus` messages are stored in `parameters` and `NoticeResponse` messages are
    /// discarded; in both cases reading continues with the following message. On success the
    /// message is returned with the connection; on failure the error is returned with the
    /// connection. End of stream is reported as an `UnexpectedEof` error.
    fn read(self) -> BoxFuture<(backend::Message<Vec<u8>>, InnerConnection), (io::Error, InnerConnection)> {
        self.into_future()
            .then(|next| {
                match next {
                    Err((e, conn)) => Either::A(Err((e, conn)).into_future()),
                    Ok((None, conn)) => Either::A(Err((eof(), conn)).into_future()),
                    Ok((Some(backend::Message::ParameterStatus(body)), mut conn)) => {
                        // Name is decoded before value, so a bad name is the error reported.
                        let param = body.name().and_then(|name| {
                            body.value().map(|value| (name.to_owned(), value.to_owned()))
                        });
                        match param {
                            Ok((name, value)) => {
                                conn.parameters.insert(name, value);
                                Either::B(conn.read())
                            }
                            Err(e) => Either::A(Err((e, conn)).into_future()),
                        }
                    }
                    Ok((Some(backend::Message::NoticeResponse(_)), conn)) => {
                        // TODO forward the notice to the caller instead of dropping it
                        Either::B(conn.read())
                    }
                    Ok((Some(msg), conn)) => Either::A(Ok((msg, conn)).into_future()),
                }
            })
            .boxed()
    }
}
impl Stream for InnerConnection {
type Item = backend::Message<Vec<u8>>;
type Error = io::Error;
fn poll(&mut self) -> Poll<Option<backend::Message<Vec<u8>>>, io::Error> {
self.stream.poll()
}
}
impl Sink for InnerConnection {
type SinkItem = Vec<u8>;
type SinkError = io::Error;
fn start_send(&mut self, item: Vec<u8>) -> StartSend<Vec<u8>, io::Error> {
self.stream.start_send(item)
}
fn poll_complete(&mut self) -> Poll<(), io::Error> {
self.stream.poll_complete()
}
}
/// An asynchronous connection to a Postgres server.
///
/// Operations take the connection by value and hand it back when their future resolves.
pub struct Connection(InnerConnection);

/// Shows the session data the server has reported. The stream itself is omitted.
impl fmt::Debug for Connection {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.debug_struct("Connection")
            .field("cancel_data", &self.0.cancel_data)
            .field("parameters", &self.0.parameters)
            .finish()
    }
}
impl Connection {
    /// Connects to a Postgres server.
    ///
    /// `params` is converted with `IntoConnectParams`. If that fails, the returned future
    /// resolves to `ConnectError::ConnectParams`. Otherwise the future:
    ///
    /// 1. opens the stream to `params.host()`/`params.port()`;
    /// 2. sends the startup message;
    /// 3. answers the server's authentication request;
    /// 4. resolves once the server reports that it is ready for queries.
    pub fn connect<T>(params: T, handle: &Handle) -> BoxFuture<Connection, ConnectError>
        where T: IntoConnectParams
    {
        let params = match params.into_connect_params() {
            Ok(params) => params,
            Err(e) => return futures::failed(ConnectError::ConnectParams(e)).boxed(),
        };

        stream::connect(params.host(), params.port(), handle)
            .map_err(ConnectError::Io)
            .map(|s| {
                Connection(InnerConnection {
                    stream: s,
                    parameters: HashMap::new(),
                    // Placeholder until BackendKeyData arrives in finish_startup.
                    cancel_data: CancelData {
                        process_id: 0,
                        secret_key: 0,
                    },
                })
            })
            .and_then(|s| s.startup(params))
            .and_then(|(s, params)| s.handle_auth(params))
            .and_then(|s| s.finish_startup())
            .boxed()
    }

    /// Sends the startup message built from `params` and flushes it.
    ///
    /// The message always requests `client_encoding=UTF8` and `timezone=GMT`. It then adds the
    /// user name and database from `params` when present, followed by any extra options.
    /// `params` is returned with the connection because authentication needs the password.
    fn startup(self, params: ConnectParams) -> BoxFuture<(Connection, ConnectParams), ConnectError> {
        let mut buf = vec![];
        let result = {
            let options = [("client_encoding", "UTF8"), ("timezone", "GMT")];
            let options = options.iter().cloned();
            let options = options.chain(params.user().map(|u| ("user", u.name())));
            let options = options.chain(params.database().map(|d| ("database", d)));
            let options = options.chain(params.options().iter().map(|e| (&*e.0, &*e.1)));
            frontend::startup_message(options, &mut buf)
        };

        result
            .into_future()
            .and_then(move |()| self.0.send(buf))
            .and_then(|s| s.flush())
            .map_err(ConnectError::Io)
            .map(move |s| (Connection(s), params))
            .boxed()
    }

    /// Reads the server's authentication request and answers it.
    ///
    /// Handles the following requests:
    ///
    /// - `AuthenticationOk`: no response is needed.
    /// - Cleartext password: the password from `params` is sent as-is.
    /// - MD5 password: the password is hashed with the user name and the server's salt.
    ///
    /// A password request without a password in `params` fails with
    /// `ConnectError::ConnectParams`. An `ErrorResponse` becomes a database error. Any other
    /// message, including other authentication methods, is an "unexpected message" I/O error.
    fn handle_auth(self, params: ConnectParams) -> BoxFuture<Connection, ConnectError> {
        self.0.read()
            .map_err(|e| e.0.into())
            .and_then(move |(m, s)| {
                // Some(bytes) is a password message still to be sent; None means we're in.
                let response = match m {
                    backend::Message::AuthenticationOk => Ok(None),
                    backend::Message::AuthenticationCleartextPassword => {
                        match params.user().and_then(|u| u.password()) {
                            Some(pass) => {
                                let mut buf = vec![];
                                frontend::password_message(pass, &mut buf)
                                    .map(|()| Some(buf))
                                    .map_err(Into::into)
                            }
                            None => {
                                Err(ConnectError::ConnectParams(
                                    "password was required but not provided".into()))
                            }
                        }
                    }
                    backend::Message::AuthenticationMd5Password(body) => {
                        match params.user().and_then(|u| u.password().map(|p| (u.name(), p))) {
                            Some((user, pass)) => {
                                let pass = authentication::md5_hash(user.as_bytes(),
                                                                    pass.as_bytes(),
                                                                    body.salt());
                                let mut buf = vec![];
                                frontend::password_message(&pass, &mut buf)
                                    .map(|()| Some(buf))
                                    .map_err(Into::into)
                            }
                            None => {
                                Err(ConnectError::ConnectParams(
                                    "password was required but not provided".into()))
                            }
                        }
                    }
                    backend::Message::ErrorResponse(body) => Err(connect_err(&mut body.fields())),
                    _ => Err(bad_message()),
                };

                response.map(|m| (m, Connection(s)))
            })
            .and_then(|(m, s)| {
                match m {
                    Some(m) => Either::A(s.handle_auth_response(m)),
                    None => Either::B(Ok(s).into_future())
                }
            })
            .boxed()
    }

    /// Sends the encoded password `message` and waits for `AuthenticationOk`.
    ///
    /// An `ErrorResponse` becomes a database error. Any other reply is an
    /// "unexpected message" I/O error.
    fn handle_auth_response(self, message: Vec<u8>) -> BoxFuture<Connection, ConnectError> {
        self.0.send(message)
            .and_then(|s| s.flush())
            .and_then(|s| s.read().map_err(|e| e.0))
            .map_err(ConnectError::Io)
            .and_then(|(m, s)| {
                match m {
                    backend::Message::AuthenticationOk => Ok(Connection(s)),
                    backend::Message::ErrorResponse(body) => Err(connect_err(&mut body.fields())),
                    _ => Err(bad_message()),
                }
            })
            .boxed()
    }

    /// Reads post-authentication messages until `ReadyForQuery`.
    ///
    /// Any `BackendKeyData` seen along the way is recorded as the connection's `cancel_data`.
    fn finish_startup(self) -> BoxFuture<Connection, ConnectError> {
        self.0.read()
            .map_err(|e| ConnectError::Io(e.0))
            .and_then(|(m, mut s)| {
                match m {
                    backend::Message::BackendKeyData(body) => {
                        s.cancel_data.process_id = body.process_id();
                        s.cancel_data.secret_key = body.secret_key();
                        Either::A(Connection(s).finish_startup())
                    }
                    backend::Message::ReadyForQuery(_) => Either::B(Ok(Connection(s)).into_future()),
                    backend::Message::ErrorResponse(body) => {
                        Either::B(Err(connect_err(&mut body.fields())).into_future())
                    }
                    _ => Either::B(Err(bad_message()).into_future()),
                }
            })
            .boxed()
    }

    /// Sends `query` with the simple query protocol and collects every row it returns.
    fn simple_query(self, query: &str) -> BoxFuture<(Vec<RowData>, Connection), Error> {
        let mut buf = vec![];
        frontend::query(query, &mut buf)
            .map(|()| buf)
            .into_future()
            .and_then(move |buf| self.0.send(buf))
            .and_then(|s| s.flush())
            .map_err(Error::Io)
            .and_then(|s| Connection(s).simple_read_rows(vec![]))
            .boxed()
    }

    /// Reads simple-query results until `ReadyForQuery`, appending each `DataRow` to `rows`.
    ///
    /// This has its own loop, separate from `read_rows`, because one query string can contain
    /// several statements. `RowDescription`, `CommandComplete` and `EmptyQueryResponse` may
    /// therefore each appear many times before the final `ReadyForQuery`.
    ///
    /// On `ErrorResponse`, the `ReadyForQuery` that the server always sends afterwards is
    /// consumed first. The connection carried by `Error::Db` is then ready for the next query.
    fn simple_read_rows(self, mut rows: Vec<RowData>) -> BoxFuture<(Vec<RowData>, Connection), Error> {
        self.0.read()
            .map_err(|e| Error::Io(e.0))
            .and_then(|(m, s)| {
                match m {
                    backend::Message::ReadyForQuery(_) => {
                        Either::A(Ok((rows, Connection(s))).into_future())
                    }
                    backend::Message::DataRow(body) => {
                        match body.values().collect() {
                            Ok(row) => {
                                rows.push(row);
                                Either::B(Connection(s).simple_read_rows(rows))
                            }
                            Err(e) => Either::A(Err(Error::Io(e)).into_future()),
                        }
                    }
                    // RowDescription precedes the DataRows of any row-returning statement
                    // (e.g. SELECT); this path does not need column metadata.
                    backend::Message::RowDescription(_) |
                    backend::Message::EmptyQueryResponse |
                    backend::Message::CommandComplete(_) => {
                        Either::B(Connection(s).simple_read_rows(rows))
                    }
                    backend::Message::ErrorResponse(body) => {
                        // Consume the trailing ReadyForQuery so that the connection handed back
                        // inside the error is not left out of sync with the server.
                        Either::B(Connection(s)
                            .ready(body)
                            .and_then(|(body, s)| -> Result<(Vec<RowData>, Connection), Error> {
                                let e = err(&mut body.fields(), s);
                                Err(e)
                            })
                            .boxed())
                    }
                    _ => Either::A(Err(bad_message()).into_future()),
                }
            })
            .boxed()
    }

    /// Reads the `DataRow`s of a single statement into `rows`, then expects `ReadyForQuery`.
    ///
    /// On `ErrorResponse`, the `ReadyForQuery` that follows it is consumed first. The
    /// connection carried by `Error::Db` is then ready for reuse.
    fn read_rows(self, mut rows: Vec<RowData>) -> BoxFuture<(Vec<RowData>, Connection), Error> {
        self.0.read()
            .map_err(|e| Error::Io(e.0))
            .and_then(|(m, s)| {
                match m {
                    backend::Message::EmptyQueryResponse |
                    backend::Message::CommandComplete(_) => {
                        Either::B(Connection(s).ready(rows))
                    }
                    backend::Message::DataRow(body) => {
                        match body.values().collect() {
                            Ok(row) => {
                                rows.push(row);
                                Either::B(Connection(s).read_rows(rows))
                            }
                            Err(e) => Either::A(Err(Error::Io(e)).into_future()),
                        }
                    }
                    backend::Message::ErrorResponse(body) => {
                        // Consume the trailing ReadyForQuery so the connection stays in sync.
                        Either::B(Connection(s)
                            .ready(body)
                            .and_then(|(body, s)| -> Result<(Vec<RowData>, Connection), Error> {
                                let e = err(&mut body.fields(), s);
                                Err(e)
                            })
                            .boxed())
                    }
                    _ => Either::A(Err(bad_message()).into_future()),
                }
            })
            .boxed()
    }

    /// Reads the next message, which must be `ReadyForQuery`, and resolves to `(t, self)`.
    ///
    /// Any other message is an "unexpected message" I/O error.
    fn ready<T>(self, t: T) -> BoxFuture<(T, Connection), Error>
        where T: 'static + Send
    {
        self.0.read()
            .map_err(|e| Error::Io(e.0))
            .and_then(|(m, s)| {
                match m {
                    backend::Message::ReadyForQuery(_) => Ok((t, Connection(s))),
                    _ => Err(bad_message())
                }
            })
            .boxed()
    }

    /// Executes `query`, which may contain multiple statements, and discards any returned rows.
    ///
    /// # Errors
    ///
    /// A server error is returned as `Error::Db` together with the connection, which has
    /// already consumed the server's trailing `ReadyForQuery`.
    pub fn batch_execute(self, query: &str) -> BoxFuture<Connection, Error> {
        self.simple_query(query).map(|r| r.1).boxed()
    }

    /// Returns the backend process ID and secret key reported during startup.
    ///
    /// Both are zero if the server did not send `BackendKeyData`.
    pub fn cancel_data(&self) -> CancelData {
        self.0.cancel_data
    }
}
/// Turns the fields of an `ErrorResponse` received while connecting into a `ConnectError`.
///
/// Produces `ConnectError::Db` when the fields parse as a `DbError`. Otherwise the parse
/// failure is reported as `ConnectError::Io`.
fn connect_err(fields: &mut ErrorFields) -> ConnectError {
    DbError::new(fields)
        .map(|db| ConnectError::Db(Box::new(db)))
        .unwrap_or_else(ConnectError::Io)
}
fn err(fields: &mut ErrorFields, conn: Connection) -> Error {
match DbError::new(fields) {
Ok(err) => Error::Db(Box::new(err), conn),
Err(err) => Error::Io(err),
}
}
/// Builds the error used when the server sends a message that is not valid at this point
/// in the protocol.
///
/// The error is an `io::Error` of kind `InvalidInput`, converted into any `T: From<io::Error>`.
fn bad_message<T>() -> T
    where T: From<io::Error>
{
    let unexpected = io::Error::new(io::ErrorKind::InvalidInput, "unexpected message");
    T::from(unexpected)
}
/// Builds the error used when the server closes the stream while a message is expected.
///
/// The error is an `io::Error` of kind `UnexpectedEof`, converted into any `T: From<io::Error>`.
fn eof<T>() -> T
    where T: From<io::Error>
{
    let closed = io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF");
    T::from(closed)
}