From 7056e3ec24c54331a6ddd2b70010cf5b648623b6 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 15 Jul 2018 19:40:15 -0700 Subject: [PATCH] Copy out support --- postgres-protocol/src/message/backend.rs | 79 ++++++++--------- tokio-postgres/src/lib.rs | 17 ++++ tokio-postgres/src/proto/client.rs | 6 ++ tokio-postgres/src/proto/copy_out.rs | 106 +++++++++++++++++++++++ tokio-postgres/src/proto/mod.rs | 2 + tokio-postgres/tests/test.rs | 38 +++++++- 6 files changed, 207 insertions(+), 41 deletions(-) create mode 100644 tokio-postgres/src/proto/copy_out.rs diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index 3f136c31..eacb5da4 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -1,6 +1,6 @@ #![allow(missing_docs)] -use byteorder::{ReadBytesExt, BigEndian}; +use byteorder::{BigEndian, ReadBytesExt}; use bytes::{Bytes, BytesMut}; use fallible_iterator::FallibleIterator; use memchr::memchr; @@ -148,45 +148,41 @@ impl Message { let storage = buf.read_all(); Message::NoticeResponse(NoticeResponseBody { storage: storage }) } - b'R' => { - match buf.read_i32::()? { - 0 => Message::AuthenticationOk, - 2 => Message::AuthenticationKerberosV5, - 3 => Message::AuthenticationCleartextPassword, - 5 => { - let mut salt = [0; 4]; - buf.read_exact(&mut salt)?; - Message::AuthenticationMd5Password( - AuthenticationMd5PasswordBody { salt: salt }, - ) - } - 6 => Message::AuthenticationScmCredential, - 7 => Message::AuthenticationGss, - 8 => { - let storage = buf.read_all(); - Message::AuthenticationGssContinue(AuthenticationGssContinueBody(storage)) - } - 9 => Message::AuthenticationSspi, - 10 => { - let storage = buf.read_all(); - Message::AuthenticationSasl(AuthenticationSaslBody(storage)) - } - 11 => { - let storage = buf.read_all(); - Message::AuthenticationSaslContinue(AuthenticationSaslContinueBody(storage)) - } - 12 => { - let storage = buf.read_all(); - Message::AuthenticationSaslFinal(AuthenticationSaslFinalBody(storage)) - } - tag => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - format!("unknown authentication tag `{}`", tag), - )); - } + b'R' => match buf.read_i32::()? { + 0 => Message::AuthenticationOk, + 2 => Message::AuthenticationKerberosV5, + 3 => Message::AuthenticationCleartextPassword, + 5 => { + let mut salt = [0; 4]; + buf.read_exact(&mut salt)?; + Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { salt: salt }) } - } + 6 => Message::AuthenticationScmCredential, + 7 => Message::AuthenticationGss, + 8 => { + let storage = buf.read_all(); + Message::AuthenticationGssContinue(AuthenticationGssContinueBody(storage)) + } + 9 => Message::AuthenticationSspi, + 10 => { + let storage = buf.read_all(); + Message::AuthenticationSasl(AuthenticationSaslBody(storage)) + } + 11 => { + let storage = buf.read_all(); + Message::AuthenticationSaslContinue(AuthenticationSaslContinueBody(storage)) + } + 12 => { + let storage = buf.read_all(); + Message::AuthenticationSaslFinal(AuthenticationSaslFinalBody(storage)) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown authentication tag `{}`", tag), + )); + } + }, b's' => Message::PortalSuspended, b'S' => { let name = buf.read_cstr()?; @@ -394,6 +390,11 @@ impl CopyDataBody { pub fn data(&self) -> &[u8] { &self.storage } + + #[inline] + pub fn into_bytes(self) -> Bytes { + self.storage + } } pub struct CopyInResponseBody { diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index a5089d3d..309c6140 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -21,6 +21,7 @@ extern crate state_machine_future; #[cfg(unix)] extern crate tokio_uds; +use bytes::Bytes; use futures::{Async, Future, Poll, Stream}; use postgres_shared::rows::RowIndex; use std::fmt; @@ -95,6 +96,10 @@ impl Client { Query(self.0.query(&statement.0, params)) } + pub fn copy_out(&mut self, statement: &Statement, params: &[&ToSql]) -> CopyOut { + CopyOut(self.0.copy_out(&statement.0, params)) + } + pub fn transaction(&mut self, future: T) -> Transaction where T: Future, @@ -222,6 +227,18 @@ impl Stream for Query { } } +#[must_use = "streams do nothing unless polled"] +pub struct CopyOut(proto::CopyOutStream); + +impl Stream for CopyOut { + type Item = Bytes; + type Error = Error; + + fn poll(&mut self) -> Poll, Error> { + self.0.poll() + } +} + pub struct Row(proto::Row); impl Row { diff --git a/tokio-postgres/src/proto/client.rs b/tokio-postgres/src/proto/client.rs index e34b64fe..1e37edec 100644 --- a/tokio-postgres/src/proto/client.rs +++ b/tokio-postgres/src/proto/client.rs @@ -9,6 +9,7 @@ use std::sync::{Arc, Weak}; use disconnected; use error::{self, Error}; use proto::connection::Request; +use proto::copy_out::CopyOutStream; use proto::execute::ExecuteFuture; use proto::prepare::PrepareFuture; use proto::query::QueryStream; @@ -130,6 +131,11 @@ impl Client { QueryStream::new(self.clone(), pending, statement.clone()) } + pub fn copy_out(&self, statement: &Statement, params: &[&ToSql]) -> CopyOutStream { + let pending = self.pending_execute(statement, params); + CopyOutStream::new(self.clone(), pending, statement.clone()) + } + pub fn close_statement(&self, name: &str) { let mut buf = vec![]; frontend::close(b'S', name, &mut buf).expect("statement name not valid"); diff --git a/tokio-postgres/src/proto/copy_out.rs b/tokio-postgres/src/proto/copy_out.rs new file mode 100644 index 00000000..b8909d5d --- /dev/null +++ b/tokio-postgres/src/proto/copy_out.rs @@ -0,0 +1,106 @@ +use bytes::Bytes; +use futures::sync::mpsc; +use futures::{Async, Poll, Stream}; +use postgres_protocol::message::backend::Message; +use std::mem; + +use error::{self, Error}; +use proto::client::{Client, PendingRequest}; +use proto::statement::Statement; +use {bad_response, disconnected}; + +enum State { + Start { + client: Client, + request: PendingRequest, + statement: Statement, + }, + ReadingCopyOutResponse { + receiver: mpsc::Receiver, + }, + ReadingCopyData { + receiver: mpsc::Receiver, + }, + Done, +} + +pub struct CopyOutStream(State); + +impl Stream for CopyOutStream { + type Item = Bytes; + type Error = Error; + + fn poll(&mut self) -> Poll, Error> { + loop { + match mem::replace(&mut self.0, State::Done) { + State::Start { + client, + request, + statement, + } => { + let receiver = client.send(request)?; + // it's ok for the statement to close now that we've queued the query + drop(statement); + self.0 = State::ReadingCopyOutResponse { receiver }; + } + State::ReadingCopyOutResponse { mut receiver } => { + let message = match receiver.poll() { + Ok(Async::Ready(message)) => message, + Ok(Async::NotReady) => { + self.0 = State::ReadingCopyOutResponse { receiver }; + break Ok(Async::NotReady); + } + Err(()) => unreachable!("mpsc::Receiver doesn't return errors"), + }; + + match message { + Some(Message::BindComplete) => { + self.0 = State::ReadingCopyOutResponse { receiver }; + } + Some(Message::CopyOutResponse(_)) => { + self.0 = State::ReadingCopyData { receiver }; + } + Some(Message::ErrorResponse(body)) => break Err(error::__db(body)), + Some(_) => break Err(bad_response()), + None => break Err(disconnected()), + } + } + State::ReadingCopyData { mut receiver } => { + let message = match receiver.poll() { + Ok(Async::Ready(message)) => message, + Ok(Async::NotReady) => { + self.0 = State::ReadingCopyData { receiver }; + break Ok(Async::NotReady); + } + Err(()) => unreachable!("mpsc::Reciever doesn't return errors"), + }; + + match message { + Some(Message::CopyData(body)) => { + self.0 = State::ReadingCopyData { receiver }; + break Ok(Async::Ready(Some(body.into_bytes()))); + } + Some(Message::CopyDone) | Some(Message::CommandComplete(_)) => { + self.0 = State::ReadingCopyData { receiver }; + } + Some(Message::ReadyForQuery(_)) => break Ok(Async::Ready(None)), + Some(Message::ErrorResponse(body)) => break Err(error::__db(body)), + Some(_) => break Err(bad_response()), + None => break Err(disconnected()), + } + } + State::Done => break Ok(Async::Ready(None)), + } + } + } +} + +impl CopyOutStream { + pub fn new(client: Client, request: PendingRequest, statement: Statement) -> CopyOutStream { + CopyOutStream(State::Start { + client, + request, + statement, + }) + } +} diff --git a/tokio-postgres/src/proto/mod.rs b/tokio-postgres/src/proto/mod.rs index 4e9d6006..7aeb25ba 100644 --- a/tokio-postgres/src/proto/mod.rs +++ b/tokio-postgres/src/proto/mod.rs @@ -13,6 +13,7 @@ mod client; mod codec; mod connect; mod connection; +mod copy_out; mod execute; mod handshake; mod prepare; @@ -30,6 +31,7 @@ pub use proto::cancel::CancelFuture; pub use proto::client::Client; pub use proto::codec::PostgresCodec; pub use proto::connection::Connection; +pub use proto::copy_out::CopyOutStream; pub use proto::execute::ExecuteFuture; pub use proto::handshake::HandshakeFuture; pub use proto::prepare::PrepareFuture; diff --git a/tokio-postgres/tests/test.rs b/tokio-postgres/tests/test.rs index bd946c6c..441d321e 100644 --- a/tokio-postgres/tests/test.rs +++ b/tokio-postgres/tests/test.rs @@ -480,7 +480,7 @@ fn notifications() { } #[test] -fn test_transaction_commit() { +fn transaction_commit() { let _ = env_logger::try_init(); let mut runtime = Runtime::new().unwrap(); @@ -518,7 +518,7 @@ fn test_transaction_commit() { } #[test] -fn test_transaction_abort() { +fn transaction_abort() { let _ = env_logger::try_init(); let mut runtime = Runtime::new().unwrap(); @@ -556,3 +556,37 @@ fn test_transaction_abort() { assert_eq!(rows.len(), 0); } + +#[test] +fn copy_out() { + let _ = env_logger::try_init(); + let mut runtime = Runtime::new().unwrap(); + + let (mut client, connection) = runtime + .block_on(tokio_postgres::connect( + "postgres://postgres@localhost:5433".parse().unwrap(), + TlsMode::None, + )) + .unwrap(); + let connection = connection.map_err(|e| panic!("{}", e)); + runtime.handle().spawn(connection).unwrap(); + + runtime + .block_on(client.batch_execute( + "CREATE TEMPORARY TABLE foo ( + id SERIAL, + name TEXT + ); + INSERT INTO foo (name) VALUES ('jim'), ('joe');", + )) + .unwrap(); + + let data = runtime + .block_on( + client + .prepare("COPY foo TO STDOUT") + .and_then(|s| client.copy_out(&s, &[]).concat2()), + ) + .unwrap(); + assert_eq!(&data[..], b"1\tjim\n2\tjoe\n"); +}