diff --git a/Cargo.toml b/Cargo.toml index ad710159..9681338d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,2 +1,2 @@ [workspace] -members = ["codegen", "postgres", "postgres-shared"] +members = ["codegen", "postgres", "postgres-shared", "postgres-tokio"] diff --git a/postgres-tokio/Cargo.toml b/postgres-tokio/Cargo.toml new file mode 100644 index 00000000..bd28c39d --- /dev/null +++ b/postgres-tokio/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "postgres-tokio" +version = "0.1.0" +authors = ["Steven Fackler "] + +[dependencies] +futures = "0.1.7" +postgres-shared = { path = "../postgres-shared" } +postgres-protocol = "0.2" +tokio-core = "0.1" +tokio-dns-unofficial = "0.1" +tokio-uds = "0.1" diff --git a/postgres-tokio/src/lib.rs b/postgres-tokio/src/lib.rs new file mode 100644 index 00000000..9ccc383f --- /dev/null +++ b/postgres-tokio/src/lib.rs @@ -0,0 +1,260 @@ +extern crate postgres_shared; +extern crate postgres_protocol; +extern crate tokio_core; +extern crate tokio_dns; +extern crate tokio_uds; + +#[macro_use] +extern crate futures; + +use futures::{Future, IntoFuture, BoxFuture, Stream, Sink, Poll, StartSend}; +use futures::future::Either; +use postgres_shared::params::{ConnectParams, IntoConnectParams}; +use postgres_protocol::authentication; +use postgres_protocol::message::{backend, frontend}; +use std::collections::HashMap; +use std::error::Error; +use std::io; +use tokio_core::reactor::Handle; + +use stream::PostgresStream; + +mod stream; + +#[cfg(test)] +mod test; + +#[derive(Debug)] +pub enum ConnectError { + Params(Box), + Io(io::Error), +} + +#[derive(Debug, Copy, Clone)] +pub struct CancelData { + pub process_id: i32, + pub secret_key: i32, +} + +impl From for ConnectError { + fn from(e: io::Error) -> ConnectError { + ConnectError::Io(e) + } +} + +struct InnerConnectionState { + parameters: HashMap, + cancel_data: CancelData, +} + +struct InnerConnection { + stream: PostgresStream, + state: InnerConnectionState, +} + +impl InnerConnection { + fn read(self) -> BoxFuture<(backend::Message>, InnerConnection), (io::Error, InnerConnection)> { + self.into_future() + .then(|r| { + let (m, mut s) = match r { + Ok((m, s)) => (m, s), + Err((e, s)) => return Either::A(Err((e, s)).into_future()), + }; + + match m { + Some(backend::Message::ParameterStatus(body)) => { + let name = match body.name() { + Ok(name) => name.to_owned(), + Err(e) => return Either::A(Err((e, s)).into_future()), + }; + let value = match body.value() { + Ok(value) => value.to_owned(), + Err(e) => return Either::A(Err((e, s)).into_future()), + }; + s.state.parameters.insert(name, value); + Either::B(s.read()) + } + Some(backend::Message::NoticeResponse(_)) => { + // TODO forward the error + Either::B(s.read()) + } + Some(m) => Either::A(Ok((m, s)).into_future()), + None => Either::A(Err((eof(), s)).into_future()), + } + }) + .boxed() + } +} + +impl Stream for InnerConnection { + type Item = backend::Message>; + type Error = io::Error; + + fn poll(&mut self) -> Poll>>, io::Error> { + self.stream.poll() + } +} + +impl Sink for InnerConnection { + type SinkItem = Vec; + type SinkError = io::Error; + + fn start_send(&mut self, item: Vec) -> StartSend, io::Error> { + self.stream.start_send(item) + } + + fn poll_complete(&mut self) -> Poll<(), io::Error> { + self.stream.poll_complete() + } +} + +pub struct Connection(InnerConnection); + +impl Connection { + pub fn connect(params: T, handle: &Handle) -> BoxFuture + where T: IntoConnectParams + { + let params = match params.into_connect_params() { + Ok(params) => params, + Err(e) => return futures::failed(ConnectError::Params(e)).boxed(), + }; + + stream::connect(params.host(), params.port(), handle) + .map_err(ConnectError::Io) + .map(|s| { + Connection(InnerConnection { + stream: s, + state: InnerConnectionState { + parameters: HashMap::new(), + cancel_data: CancelData { + process_id: 0, + secret_key: 0, + } + } + }) + }) + .and_then(|s| s.startup(params)) + .and_then(|(s, params)| s.handle_auth(params)) + .and_then(|s| s.finish_startup()) + .boxed() + } + + fn startup(self, params: ConnectParams) -> BoxFuture<(Connection, ConnectParams), ConnectError> { + let mut buf = vec![]; + let result = { + let options = [("client_encoding", "UTF8"), ("timezone", "GMT")]; + let options = options.iter().cloned(); + let options = options.chain(params.user().map(|u| ("user", u.name()))); + let options = options.chain(params.database().map(|d| ("database", d))); + let options = options.chain(params.options().iter().map(|e| (&*e.0, &*e.1))); + + frontend::startup_message(options, &mut buf) + }; + + result + .into_future() + .and_then(move |()| self.0.send(buf)) + .and_then(|s| s.flush()) + .map_err(ConnectError::Io) + .map(move |s| (Connection(s), params)) + .boxed() + } + + fn handle_auth(self, params: ConnectParams) -> BoxFuture { + self.0.read() + .map_err(|e| e.0.into()) + .and_then(move |(m, s)| { + let response = match m { + backend::Message::AuthenticationOk => Ok(None), + backend::Message::AuthenticationCleartextPassword => { + match params.user().and_then(|u| u.password()) { + Some(pass) => { + let mut buf = vec![]; + frontend::password_message(pass, &mut buf) + .map(|()| Some(buf)) + .map_err(Into::into) + } + None => { + Err(ConnectError::Params( + "password was required but not provided".into())) + } + } + } + backend::Message::AuthenticationMd5Password(body) => { + match params.user().and_then(|u| u.password().map(|p| (u.name(), p))) { + Some((user, pass)) => { + let pass = authentication::md5_hash(user.as_bytes(), + pass.as_bytes(), + body.salt()); + let mut buf = vec![]; + frontend::password_message(&pass, &mut buf) + .map(|()| Some(buf)) + .map_err(Into::into) + } + None => { + Err(ConnectError::Params( + "password was required but not provided".into())) + } + } + } + _ => Err(bad_message()), + }; + + response.map(|m| (m, Connection(s))) + }) + .and_then(|(m, s)| { + match m { + Some(m) => Either::A(s.handle_auth_response(m)), + None => Either::B(Ok(s).into_future()) + } + }) + .boxed() + } + + fn handle_auth_response(self, message: Vec) -> BoxFuture { + self.0.send(message) + .and_then(|s| s.flush()) + .and_then(|s| s.read().map_err(|e| e.0)) + .map_err(ConnectError::Io) + .and_then(|(m, s)| { + match m { + backend::Message::AuthenticationOk => Ok(Connection(s)), + _ => Err(bad_message()), + } + }) + .boxed() + } + + fn finish_startup(self) -> BoxFuture { + self.0.read() + .map_err(|e| ConnectError::Io(e.0)) + .and_then(|(m, mut s)| { + match m { + backend::Message::BackendKeyData(body) => { + s.state.cancel_data.process_id = body.process_id(); + s.state.cancel_data.secret_key = body.secret_key(); + Either::A(Connection(s).finish_startup()) + } + backend::Message::ReadyForQuery(_) => Either::B(Ok(Connection(s)).into_future()), + _ => Either::B(Err(bad_message()).into_future()), + } + }) + .boxed() + } + + pub fn cancel_data(&self) -> CancelData { + self.0.state.cancel_data + } +} + +fn bad_message() -> T + where T: From +{ + io::Error::new(io::ErrorKind::InvalidInput, "unexpected message").into() +} + +fn eof() -> T + where T: From +{ + io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF").into() +} diff --git a/postgres-tokio/src/stream.rs b/postgres-tokio/src/stream.rs new file mode 100644 index 00000000..daaa6083 --- /dev/null +++ b/postgres-tokio/src/stream.rs @@ -0,0 +1,100 @@ +use futures::{BoxFuture, Future, IntoFuture, Async}; +use postgres_shared::params::Host; +use postgres_protocol::message::backend::{self, ParseResult}; +use std::io::{self, Read, Write}; +use tokio_core::io::{Io, Codec, EasyBuf, Framed}; +use tokio_core::net::TcpStream; +use tokio_core::reactor::Handle; +use tokio_dns; +use tokio_uds::UnixStream; + +pub type PostgresStream = Framed; + +pub fn connect(host: &Host, + port: u16, + handle: &Handle) + -> BoxFuture { + match *host { + Host::Tcp(ref host) => { + tokio_dns::tcp_connect((&**host, port), handle.remote().clone()) + .map(|s| InnerStream::Tcp(s).framed(PostgresCodec)) + .boxed() + } + Host::Unix(ref host) => { + let addr = host.join(format!(".s.PGSQL.{}", port)); + UnixStream::connect(addr, handle) + .map(|s| InnerStream::Unix(s).framed(PostgresCodec)) + .into_future() + .boxed() + } + } +} + +pub enum InnerStream { + Tcp(TcpStream), + Unix(UnixStream), +} + +impl Read for InnerStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match *self { + InnerStream::Tcp(ref mut s) => s.read(buf), + InnerStream::Unix(ref mut s) => s.read(buf), + } + } +} + +impl Write for InnerStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + match *self { + InnerStream::Tcp(ref mut s) => s.write(buf), + InnerStream::Unix(ref mut s) => s.write(buf), + } + } + + fn flush(&mut self) -> io::Result<()> { + match *self { + InnerStream::Tcp(ref mut s) => s.flush(), + InnerStream::Unix(ref mut s) => s.flush(), + } + } +} + +impl Io for InnerStream { + fn poll_read(&mut self) -> Async<()> { + match *self { + InnerStream::Tcp(ref mut s) => s.poll_read(), + InnerStream::Unix(ref mut s) => s.poll_read(), + } + } + + fn poll_write(&mut self) -> Async<()> { + match *self { + InnerStream::Tcp(ref mut s) => s.poll_write(), + InnerStream::Unix(ref mut s) => s.poll_write(), + } + } +} + +pub struct PostgresCodec; + +impl Codec for PostgresCodec { + type In = backend::Message>; + type Out = Vec; + + // FIXME ideally we'd avoid re-copying the data + fn decode(&mut self, buf: &mut EasyBuf) -> io::Result> { + match try!(backend::Message::parse_owned(buf.as_ref())) { + ParseResult::Complete { message, consumed } => { + buf.drain_to(consumed); + Ok(Some(message)) + } + ParseResult::Incomplete { .. } => Ok(None) + } + } + + fn encode(&mut self, msg: Vec, buf: &mut Vec) -> io::Result<()> { + buf.extend_from_slice(&msg); + Ok(()) + } +} diff --git a/postgres-tokio/src/test.rs b/postgres-tokio/src/test.rs new file mode 100644 index 00000000..973f29a7 --- /dev/null +++ b/postgres-tokio/src/test.rs @@ -0,0 +1,28 @@ +use tokio_core::reactor::Core; + +use super::*; + +#[test] +fn basic() { + let mut l = Core::new().unwrap(); + let handle = l.handle(); + let done = Connection::connect("postgres://postgres@localhost", &handle); + let conn = l.run(done).unwrap(); + assert!(conn.cancel_data().process_id != 0); +} + +#[test] +fn md5_user() { + let mut l = Core::new().unwrap(); + let handle = l.handle(); + let done = Connection::connect("postgres://md5_user:password@localhost/postgres", &handle); + l.run(done).unwrap(); +} + +#[test] +fn pass_user() { + let mut l = Core::new().unwrap(); + let handle = l.handle(); + let done = Connection::connect("postgres://pass_user:password@localhost/postgres", &handle); + l.run(done).unwrap(); +} \ No newline at end of file