diff --git a/postgres-shared/src/error/mod.rs b/postgres-shared/src/error/mod.rs index 9d250d20..c437219c 100644 --- a/postgres-shared/src/error/mod.rs +++ b/postgres-shared/src/error/mod.rs @@ -325,6 +325,14 @@ pub fn __db(e: ErrorResponseBody) -> Error { } } +#[doc(hidden)] +pub fn __user(e: T) -> Error +where + T: Into>, +{ + Error(Box::new(ErrorKind::Conversion(e.into()))) +} + #[doc(hidden)] pub fn io(e: io::Error) -> Error { Error(Box::new(ErrorKind::Io(e))) diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 309c6140..5fdd9fce 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -24,6 +24,7 @@ extern crate tokio_uds; use bytes::Bytes; use futures::{Async, Future, Poll, Stream}; use postgres_shared::rows::RowIndex; +use std::error::Error as StdError; use std::fmt; use std::io; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -96,6 +97,14 @@ impl Client { Query(self.0.query(&statement.0, params)) } + pub fn copy_in(&mut self, statement: &Statement, params: &[&ToSql], stream: S) -> CopyIn + where + S: Stream>, + S::Error: Into>, + { + CopyIn(self.0.copy_in(&statement.0, params, stream)) + } + pub fn copy_out(&mut self, statement: &Statement, params: &[&ToSql]) -> CopyOut { CopyOut(self.0.copy_out(&statement.0, params)) } @@ -227,6 +236,25 @@ impl Stream for Query { } } +#[must_use = "futures do nothing unless polled"] +pub struct CopyIn(proto::CopyInFuture) +where + S: Stream>, + S::Error: Into>; + +impl Future for CopyIn +where + S: Stream>, + S::Error: Into>, +{ + type Item = u64; + type Error = Error; + + fn poll(&mut self) -> Poll { + self.0.poll() + } +} + #[must_use = "streams do nothing unless polled"] pub struct CopyOut(proto::CopyOutStream); diff --git a/tokio-postgres/src/proto/client.rs b/tokio-postgres/src/proto/client.rs index 1e37edec..4ceb922a 100644 --- a/tokio-postgres/src/proto/client.rs +++ b/tokio-postgres/src/proto/client.rs @@ -1,14 +1,17 @@ use antidote::Mutex; use futures::sync::mpsc; +use futures::{AsyncSink, Sink, Stream}; use postgres_protocol; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; use std::collections::HashMap; +use std::error::Error as StdError; use std::sync::{Arc, Weak}; use disconnected; use error::{self, Error}; -use proto::connection::Request; +use proto::connection::{Request, RequestMessages}; +use proto::copy_in::{CopyInFuture, CopyInReceiver, CopyMessage}; use proto::copy_out::CopyOutStream; use proto::execute::ExecuteFuture; use proto::prepare::PrepareFuture; @@ -17,7 +20,7 @@ use proto::simple_query::SimpleQueryFuture; use proto::statement::Statement; use types::{IsNull, Oid, ToSql, Type}; -pub struct PendingRequest(Result, Error>); +pub struct PendingRequest(Result); pub struct WeakClient(Weak); @@ -122,17 +125,45 @@ impl Client { } pub fn execute(&self, statement: &Statement, params: &[&ToSql]) -> ExecuteFuture { - let pending = self.pending_execute(statement, params); + let pending = PendingRequest( + self.excecute_message(statement, params) + .map(RequestMessages::Single), + ); ExecuteFuture::new(self.clone(), pending, statement.clone()) } pub fn query(&self, statement: &Statement, params: &[&ToSql]) -> QueryStream { - let pending = self.pending_execute(statement, params); + let pending = PendingRequest( + self.excecute_message(statement, params) + .map(RequestMessages::Single), + ); QueryStream::new(self.clone(), pending, statement.clone()) } + pub fn copy_in(&self, statement: &Statement, params: &[&ToSql], stream: S) -> CopyInFuture + where + S: Stream>, + S::Error: Into>, + { + let (mut sender, receiver) = mpsc::channel(0); + let pending = PendingRequest(self.excecute_message(statement, params).map(|buf| { + match sender.start_send(CopyMessage::Data(buf)) { + Ok(AsyncSink::Ready) => {} + _ => unreachable!("channel should have capacity"), + } + RequestMessages::CopyIn { + receiver: CopyInReceiver::new(receiver), + pending_message: None, + } + })); + CopyInFuture::new(self.clone(), pending, statement.clone(), stream, sender) + } + pub fn copy_out(&self, statement: &Statement, params: &[&ToSql]) -> CopyOutStream { - let pending = self.pending_execute(statement, params); + let pending = PendingRequest( + self.excecute_message(statement, params) + .map(RequestMessages::Single), + ); CopyOutStream::new(self.clone(), pending, statement.clone()) } @@ -142,35 +173,34 @@ impl Client { frontend::sync(&mut buf); let (sender, _) = mpsc::channel(0); let _ = self.0.sender.unbounded_send(Request { - messages: buf, + messages: RequestMessages::Single(buf), sender, }); } - fn pending_execute(&self, statement: &Statement, params: &[&ToSql]) -> PendingRequest { - self.pending(|buf| { - let r = frontend::bind( - "", - statement.name(), - Some(1), - params.iter().zip(statement.params()), - |(param, ty), buf| match param.to_sql_checked(ty, buf) { - Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No), - Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes), - Err(e) => Err(e), - }, - Some(1), - buf, - ); - match r { - Ok(()) => {} - Err(frontend::BindError::Conversion(e)) => return Err(error::conversion(e)), - Err(frontend::BindError::Serialization(e)) => return Err(Error::from(e)), - } - frontend::execute("", 0, buf)?; - frontend::sync(buf); - Ok(()) - }) + fn excecute_message(&self, statement: &Statement, params: &[&ToSql]) -> Result, Error> { + let mut buf = vec![]; + let r = frontend::bind( + "", + statement.name(), + Some(1), + params.iter().zip(statement.params()), + |(param, ty), buf| match param.to_sql_checked(ty, buf) { + Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No), + Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes), + Err(e) => Err(e), + }, + Some(1), + &mut buf, + ); + match r { + Ok(()) => {} + Err(frontend::BindError::Conversion(e)) => return Err(error::conversion(e)), + Err(frontend::BindError::Serialization(e)) => return Err(Error::from(e)), + } + frontend::execute("", 0, &mut buf)?; + frontend::sync(&mut buf); + Ok(buf) } fn pending(&self, messages: F) -> PendingRequest @@ -178,6 +208,6 @@ impl Client { F: FnOnce(&mut Vec) -> Result<(), Error>, { let mut buf = vec![]; - PendingRequest(messages(&mut buf).map(|()| buf)) + PendingRequest(messages(&mut buf).map(|()| RequestMessages::Single(buf))) } } diff --git a/tokio-postgres/src/proto/connection.rs b/tokio-postgres/src/proto/connection.rs index 828f6087..a2b01c8d 100644 --- a/tokio-postgres/src/proto/connection.rs +++ b/tokio-postgres/src/proto/connection.rs @@ -8,11 +8,20 @@ use tokio_codec::Framed; use error::{self, DbError, Error}; use proto::codec::PostgresCodec; +use proto::copy_in::CopyInReceiver; use tls::TlsStream; use {bad_response, disconnected, AsyncMessage, CancelData, Notification}; +pub enum RequestMessages { + Single(Vec), + CopyIn { + receiver: CopyInReceiver, + pending_message: Option>, + }, +} + pub struct Request { - pub messages: Vec, + pub messages: RequestMessages, pub sender: mpsc::Sender, } @@ -28,7 +37,7 @@ pub struct Connection { cancel_data: CancelData, parameters: HashMap, receiver: mpsc::UnboundedReceiver, - pending_request: Option>, + pending_request: Option, pending_response: Option, responses: VecDeque>, state: State, @@ -140,7 +149,7 @@ impl Connection { } } - fn poll_request(&mut self) -> Poll>, Error> { + fn poll_request(&mut self) -> Poll, Error> { if let Some(message) = self.pending_request.take() { trace!("retrying pending request"); return Ok(Async::Ready(Some(message))); @@ -170,7 +179,7 @@ impl Connection { self.state = State::Terminating; let mut request = vec![]; frontend::terminate(&mut request); - request + RequestMessages::Single(request) } Async::Ready(None) => { trace!( @@ -185,17 +194,60 @@ impl Connection { } }; - match self.stream.start_send(request)? { - AsyncSink::Ready => { - if self.state == State::Terminating { - trace!("poll_write: sent eof, closing"); - self.state = State::Closing; + match request { + RequestMessages::Single(request) => match self.stream.start_send(request)? { + AsyncSink::Ready => { + if self.state == State::Terminating { + trace!("poll_write: sent eof, closing"); + self.state = State::Closing; + } } - } - AsyncSink::NotReady(request) => { - trace!("poll_write: waiting on socket"); - self.pending_request = Some(request); - return Ok(false); + AsyncSink::NotReady(request) => { + trace!("poll_write: waiting on socket"); + self.pending_request = Some(RequestMessages::Single(request)); + return Ok(false); + } + }, + RequestMessages::CopyIn { + mut receiver, + mut pending_message, + } => { + let message = match pending_message.take() { + Some(message) => message, + None => match receiver.poll() { + Ok(Async::Ready(Some(message))) => message, + Ok(Async::Ready(None)) => { + trace!("poll_write: finished copy_in request"); + continue; + } + Ok(Async::NotReady) => { + trace!("poll_write: waiting on copy_in stream"); + self.pending_request = Some(RequestMessages::CopyIn { + receiver, + pending_message, + }); + return Ok(true); + } + Err(()) => unreachable!("mpsc::Receiver doesn't return errors"), + }, + }; + + match self.stream.start_send(message)? { + AsyncSink::Ready => { + self.pending_request = Some(RequestMessages::CopyIn { + receiver, + pending_message: None, + }); + } + AsyncSink::NotReady(message) => { + trace!("poll_write: waiting on socket"); + self.pending_request = Some(RequestMessages::CopyIn { + receiver, + pending_message: Some(message), + }); + return Ok(false); + } + }; } } } diff --git a/tokio-postgres/src/proto/copy_in.rs b/tokio-postgres/src/proto/copy_in.rs new file mode 100644 index 00000000..12fc69ab --- /dev/null +++ b/tokio-postgres/src/proto/copy_in.rs @@ -0,0 +1,219 @@ +use futures::sink; +use futures::sync::mpsc; +use futures::{Async, AsyncSink, Future, Poll, Sink, Stream}; +use postgres_protocol::message::backend::Message; +use postgres_protocol::message::frontend; +use state_machine_future::RentToOwn; +use std::error::Error as StdError; + +use error::{self, Error}; +use proto::client::{Client, PendingRequest}; +use proto::statement::Statement; +use {bad_response, disconnected}; + +pub enum CopyMessage { + Data(Vec), + Done, +} + +pub struct CopyInReceiver { + receiver: mpsc::Receiver, + done: bool, +} + +impl CopyInReceiver { + pub fn new(receiver: mpsc::Receiver) -> CopyInReceiver { + CopyInReceiver { + receiver, + done: false, + } + } +} + +impl Stream for CopyInReceiver { + type Item = Vec; + type Error = (); + + fn poll(&mut self) -> Poll>, ()> { + if self.done { + return Ok(Async::Ready(None)); + } + + match self.receiver.poll()? { + Async::Ready(Some(CopyMessage::Data(buf))) => Ok(Async::Ready(Some(buf))), + Async::Ready(Some(CopyMessage::Done)) => { + self.done = true; + let mut buf = vec![]; + frontend::copy_done(&mut buf); + frontend::sync(&mut buf); + Ok(Async::Ready(Some(buf))) + } + Async::Ready(None) => { + self.done = true; + let mut buf = vec![]; + frontend::copy_fail("", &mut buf).unwrap(); + frontend::sync(&mut buf); + Ok(Async::Ready(Some(buf))) + } + Async::NotReady => Ok(Async::NotReady), + } + } +} + +#[derive(StateMachineFuture)] +pub enum CopyIn +where + S: Stream>, + S::Error: Into>, +{ + #[state_machine_future(start, transitions(ReadCopyInResponse))] + Start { + client: Client, + request: PendingRequest, + statement: Statement, + stream: S, + sender: mpsc::Sender, + }, + #[state_machine_future(transitions(WriteCopyData))] + ReadCopyInResponse { + stream: S, + sender: mpsc::Sender, + receiver: mpsc::Receiver, + }, + #[state_machine_future(transitions(WriteCopyDone))] + WriteCopyData { + stream: S, + pending_message: Option, + sender: mpsc::Sender, + receiver: mpsc::Receiver, + }, + #[state_machine_future(transitions(ReadCommandComplete))] + WriteCopyDone { + future: sink::Send>, + receiver: mpsc::Receiver, + }, + #[state_machine_future(transitions(Finished))] + ReadCommandComplete { receiver: mpsc::Receiver }, + #[state_machine_future(ready)] + Finished(u64), + #[state_machine_future(error)] + Failed(Error), +} + +impl PollCopyIn for CopyIn +where + S: Stream>, + S::Error: Into>, +{ + fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll, Error> { + let state = state.take(); + let receiver = state.client.send(state.request)?; + + // the statement can drop after this point, since its close will queue up after the copy + transition!(ReadCopyInResponse { + stream: state.stream, + sender: state.sender, + receiver + }) + } + + fn poll_read_copy_in_response<'a>( + state: &'a mut RentToOwn<'a, ReadCopyInResponse>, + ) -> Poll, Error> { + loop { + let message = try_receive!(state.receiver.poll()); + + match message { + Some(Message::BindComplete) => {} + Some(Message::CopyInResponse(_)) => { + let state = state.take(); + transition!(WriteCopyData { + stream: state.stream, + pending_message: None, + sender: state.sender, + receiver: state.receiver + }) + } + Some(Message::ErrorResponse(body)) => return Err(error::__db(body)), + Some(_) => return Err(bad_response()), + None => return Err(disconnected()), + } + } + } + + fn poll_write_copy_data<'a>( + state: &'a mut RentToOwn<'a, WriteCopyData>, + ) -> Poll { + loop { + let message = match state.pending_message.take() { + Some(message) => message, + None => match try_ready!(state.stream.poll().map_err(error::__user)) { + Some(data) => { + let mut buf = vec![]; + frontend::copy_data(&data, &mut buf).map_err(error::io)?; + CopyMessage::Data(buf) + } + None => { + let state = state.take(); + transition!(WriteCopyDone { + future: state.sender.send(CopyMessage::Done), + receiver: state.receiver + }) + } + }, + }; + + match state.sender.start_send(message) { + Ok(AsyncSink::Ready) => {} + Ok(AsyncSink::NotReady(message)) => { + state.pending_message = Some(message); + return Ok(Async::NotReady); + } + Err(_) => return Err(disconnected()), + } + } + } + + fn poll_write_copy_done<'a>( + state: &'a mut RentToOwn<'a, WriteCopyDone>, + ) -> Poll { + try_ready!(state.future.poll().map_err(|_| disconnected())); + let state = state.take(); + + transition!(ReadCommandComplete { + receiver: state.receiver + }) + } + + fn poll_read_command_complete<'a>( + state: &'a mut RentToOwn<'a, ReadCommandComplete>, + ) -> Poll { + let message = try_receive!(state.receiver.poll()); + + match message { + Some(Message::CommandComplete(body)) => { + let rows = body.tag()?.rsplit(' ').next().unwrap().parse().unwrap_or(0); + transition!(Finished(rows)) + } + Some(Message::ErrorResponse(body)) => Err(error::__db(body)), + Some(_) => Err(bad_response()), + None => Err(disconnected()), + } + } +} + +impl CopyInFuture +where + S: Stream>, + S::Error: Into>, +{ + pub fn new( + client: Client, + request: PendingRequest, + statement: Statement, + stream: S, + sender: mpsc::Sender, + ) -> CopyInFuture { + CopyIn::start(client, request, statement, stream, sender) + } +} diff --git a/tokio-postgres/src/proto/mod.rs b/tokio-postgres/src/proto/mod.rs index 7aeb25ba..f5d81872 100644 --- a/tokio-postgres/src/proto/mod.rs +++ b/tokio-postgres/src/proto/mod.rs @@ -13,6 +13,7 @@ mod client; mod codec; mod connect; mod connection; +mod copy_in; mod copy_out; mod execute; mod handshake; @@ -31,6 +32,7 @@ pub use proto::cancel::CancelFuture; pub use proto::client::Client; pub use proto::codec::PostgresCodec; pub use proto::connection::Connection; +pub use proto::copy_in::CopyInFuture; pub use proto::copy_out::CopyOutStream; pub use proto::execute::ExecuteFuture; pub use proto::handshake::HandshakeFuture; diff --git a/tokio-postgres/tests/test.rs b/tokio-postgres/tests/test.rs index 441d321e..f15ae4ca 100644 --- a/tokio-postgres/tests/test.rs +++ b/tokio-postgres/tests/test.rs @@ -8,6 +8,7 @@ extern crate futures; extern crate log; use futures::future; +use futures::stream; use futures::sync::mpsc; use std::error::Error; use std::time::{Duration, Instant}; @@ -238,8 +239,7 @@ fn cancel_query() { TlsMode::None, cancel_data, ) - }) - .then(|r| { + }).then(|r| { r.unwrap(); Ok::<(), ()>(()) }); @@ -267,8 +267,7 @@ fn custom_enum() { 'ok', 'happy' )", - )) - .unwrap(); + )).unwrap(); let select = client.prepare("SELECT $1::mood"); let select = runtime.block_on(select).unwrap(); @@ -301,8 +300,7 @@ fn custom_domain() { runtime .block_on(client.batch_execute( "CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16)", - )) - .unwrap(); + )).unwrap(); let select = client.prepare("SELECT $1::session_id"); let select = runtime.block_on(select).unwrap(); @@ -359,8 +357,7 @@ fn custom_composite() { supplier INTEGER, price NUMERIC )", - )) - .unwrap(); + )).unwrap(); let select = client.prepare("SELECT $1::inventory_item"); let select = runtime.block_on(select).unwrap(); @@ -399,8 +396,7 @@ fn custom_range() { subtype = float8, subtype_diff = float8mi )", - )) - .unwrap(); + )).unwrap(); let select = client.prepare("SELECT $1::floatrange"); let select = runtime.block_on(select).unwrap(); @@ -488,8 +484,7 @@ fn transaction_commit() { .block_on(tokio_postgres::connect( "postgres://postgres@localhost:5433".parse().unwrap(), TlsMode::None, - )) - .unwrap(); + )).unwrap(); let connection = connection.map_err(|e| panic!("{}", e)); runtime.handle().spawn(connection).unwrap(); @@ -499,8 +494,7 @@ fn transaction_commit() { id SERIAL, name TEXT )", - )) - .unwrap(); + )).unwrap(); let f = client.batch_execute("INSERT INTO foo (name) VALUES ('steven')"); runtime.block_on(client.transaction(f)).unwrap(); @@ -510,8 +504,7 @@ fn transaction_commit() { client .prepare("SELECT name FROM foo") .and_then(|s| client.query(&s, &[]).collect()), - ) - .unwrap(); + ).unwrap(); assert_eq!(rows.len(), 1); assert_eq!(rows[0].get::<_, &str>(0), "steven"); @@ -526,8 +519,7 @@ fn transaction_abort() { .block_on(tokio_postgres::connect( "postgres://postgres@localhost:5433".parse().unwrap(), TlsMode::None, - )) - .unwrap(); + )).unwrap(); let connection = connection.map_err(|e| panic!("{}", e)); runtime.handle().spawn(connection).unwrap(); @@ -537,8 +529,7 @@ fn transaction_abort() { id SERIAL, name TEXT )", - )) - .unwrap(); + )).unwrap(); let f = client .batch_execute("INSERT INTO foo (name) VALUES ('steven')") @@ -551,8 +542,91 @@ fn transaction_abort() { client .prepare("SELECT name FROM foo") .and_then(|s| client.query(&s, &[]).collect()), - ) - .unwrap(); + ).unwrap(); + + assert_eq!(rows.len(), 0); +} + +#[test] +fn copy_in() { + let _ = env_logger::try_init(); + let mut runtime = Runtime::new().unwrap(); + + let (mut client, connection) = runtime + .block_on(tokio_postgres::connect( + "postgres://postgres@localhost:5433".parse().unwrap(), + TlsMode::None, + )).unwrap(); + let connection = connection.map_err(|e| panic!("{}", e)); + runtime.handle().spawn(connection).unwrap(); + + runtime + .block_on(client.batch_execute( + "CREATE TEMPORARY TABLE foo ( + id INTEGER, + name TEXT + )", + )).unwrap(); + + let stream = stream::iter_ok::<_, String>(vec![b"1\tjim\n".to_vec(), b"2\tjoe\n".to_vec()]); + let rows = runtime + .block_on( + client + .prepare("COPY foo FROM STDIN") + .and_then(|s| client.copy_in(&s, &[], stream)), + ).unwrap(); + assert_eq!(rows, 2); + + let rows = runtime + .block_on( + client + .prepare("SELECT id, name FROM foo ORDER BY id") + .and_then(|s| client.query(&s, &[]).collect()), + ).unwrap(); + + assert_eq!(rows.len(), 2); + assert_eq!(rows[0].get::<_, i32>(0), 1); + assert_eq!(rows[0].get::<_, &str>(1), "jim"); + assert_eq!(rows[1].get::<_, i32>(0), 2); + assert_eq!(rows[1].get::<_, &str>(1), "joe"); +} + +#[test] +fn copy_in_error() { + let _ = env_logger::try_init(); + let mut runtime = Runtime::new().unwrap(); + + let (mut client, connection) = runtime + .block_on(tokio_postgres::connect( + "postgres://postgres@localhost:5433".parse().unwrap(), + TlsMode::None, + )).unwrap(); + let connection = connection.map_err(|e| panic!("{}", e)); + runtime.handle().spawn(connection).unwrap(); + + runtime + .block_on(client.batch_execute( + "CREATE TEMPORARY TABLE foo ( + id INTEGER, + name TEXT + )", + )).unwrap(); + + let stream = stream::iter_result(vec![Ok(b"1\tjim\n".to_vec()), Err("asdf")]); + let error = runtime + .block_on( + client + .prepare("COPY foo FROM STDIN") + .and_then(|s| client.copy_in(&s, &[], stream)), + ).unwrap_err(); + error.to_string().contains("asdf"); + + let rows = runtime + .block_on( + client + .prepare("SELECT id, name FROM foo ORDER BY id") + .and_then(|s| client.query(&s, &[]).collect()), + ).unwrap(); assert_eq!(rows.len(), 0); } @@ -566,8 +640,7 @@ fn copy_out() { .block_on(tokio_postgres::connect( "postgres://postgres@localhost:5433".parse().unwrap(), TlsMode::None, - )) - .unwrap(); + )).unwrap(); let connection = connection.map_err(|e| panic!("{}", e)); runtime.handle().spawn(connection).unwrap(); @@ -578,15 +651,13 @@ fn copy_out() { name TEXT ); INSERT INTO foo (name) VALUES ('jim'), ('joe');", - )) - .unwrap(); + )).unwrap(); let data = runtime .block_on( client .prepare("COPY foo TO STDOUT") .and_then(|s| client.copy_out(&s, &[]).concat2()), - ) - .unwrap(); + ).unwrap(); assert_eq!(&data[..], b"1\tjim\n2\tjoe\n"); }