Copy out support
This commit is contained in:
parent
9e399aa93f
commit
7056e3ec24
@ -1,6 +1,6 @@
|
||||
#![allow(missing_docs)]
|
||||
|
||||
use byteorder::{ReadBytesExt, BigEndian};
|
||||
use byteorder::{BigEndian, ReadBytesExt};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use memchr::memchr;
|
||||
@ -148,17 +148,14 @@ impl Message {
|
||||
let storage = buf.read_all();
|
||||
Message::NoticeResponse(NoticeResponseBody { storage: storage })
|
||||
}
|
||||
b'R' => {
|
||||
match buf.read_i32::<BigEndian>()? {
|
||||
b'R' => match buf.read_i32::<BigEndian>()? {
|
||||
0 => Message::AuthenticationOk,
|
||||
2 => Message::AuthenticationKerberosV5,
|
||||
3 => Message::AuthenticationCleartextPassword,
|
||||
5 => {
|
||||
let mut salt = [0; 4];
|
||||
buf.read_exact(&mut salt)?;
|
||||
Message::AuthenticationMd5Password(
|
||||
AuthenticationMd5PasswordBody { salt: salt },
|
||||
)
|
||||
Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { salt: salt })
|
||||
}
|
||||
6 => Message::AuthenticationScmCredential,
|
||||
7 => Message::AuthenticationGss,
|
||||
@ -185,8 +182,7 @@ impl Message {
|
||||
format!("unknown authentication tag `{}`", tag),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
b's' => Message::PortalSuspended,
|
||||
b'S' => {
|
||||
let name = buf.read_cstr()?;
|
||||
@ -394,6 +390,11 @@ impl CopyDataBody {
|
||||
pub fn data(&self) -> &[u8] {
|
||||
&self.storage
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn into_bytes(self) -> Bytes {
|
||||
self.storage
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CopyInResponseBody {
|
||||
|
@ -21,6 +21,7 @@ extern crate state_machine_future;
|
||||
#[cfg(unix)]
|
||||
extern crate tokio_uds;
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::{Async, Future, Poll, Stream};
|
||||
use postgres_shared::rows::RowIndex;
|
||||
use std::fmt;
|
||||
@ -95,6 +96,10 @@ impl Client {
|
||||
Query(self.0.query(&statement.0, params))
|
||||
}
|
||||
|
||||
pub fn copy_out(&mut self, statement: &Statement, params: &[&ToSql]) -> CopyOut {
|
||||
CopyOut(self.0.copy_out(&statement.0, params))
|
||||
}
|
||||
|
||||
pub fn transaction<T>(&mut self, future: T) -> Transaction<T>
|
||||
where
|
||||
T: Future,
|
||||
@ -222,6 +227,18 @@ impl Stream for Query {
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use = "streams do nothing unless polled"]
|
||||
pub struct CopyOut(proto::CopyOutStream);
|
||||
|
||||
impl Stream for CopyOut {
|
||||
type Item = Bytes;
|
||||
type Error = Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Bytes>, Error> {
|
||||
self.0.poll()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Row(proto::Row);
|
||||
|
||||
impl Row {
|
||||
|
@ -9,6 +9,7 @@ use std::sync::{Arc, Weak};
|
||||
use disconnected;
|
||||
use error::{self, Error};
|
||||
use proto::connection::Request;
|
||||
use proto::copy_out::CopyOutStream;
|
||||
use proto::execute::ExecuteFuture;
|
||||
use proto::prepare::PrepareFuture;
|
||||
use proto::query::QueryStream;
|
||||
@ -130,6 +131,11 @@ impl Client {
|
||||
QueryStream::new(self.clone(), pending, statement.clone())
|
||||
}
|
||||
|
||||
pub fn copy_out(&self, statement: &Statement, params: &[&ToSql]) -> CopyOutStream {
|
||||
let pending = self.pending_execute(statement, params);
|
||||
CopyOutStream::new(self.clone(), pending, statement.clone())
|
||||
}
|
||||
|
||||
pub fn close_statement(&self, name: &str) {
|
||||
let mut buf = vec![];
|
||||
frontend::close(b'S', name, &mut buf).expect("statement name not valid");
|
||||
|
106
tokio-postgres/src/proto/copy_out.rs
Normal file
106
tokio-postgres/src/proto/copy_out.rs
Normal file
@ -0,0 +1,106 @@
|
||||
use bytes::Bytes;
|
||||
use futures::sync::mpsc;
|
||||
use futures::{Async, Poll, Stream};
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use std::mem;
|
||||
|
||||
use error::{self, Error};
|
||||
use proto::client::{Client, PendingRequest};
|
||||
use proto::statement::Statement;
|
||||
use {bad_response, disconnected};
|
||||
|
||||
enum State {
|
||||
Start {
|
||||
client: Client,
|
||||
request: PendingRequest,
|
||||
statement: Statement,
|
||||
},
|
||||
ReadingCopyOutResponse {
|
||||
receiver: mpsc::Receiver<Message>,
|
||||
},
|
||||
ReadingCopyData {
|
||||
receiver: mpsc::Receiver<Message>,
|
||||
},
|
||||
Done,
|
||||
}
|
||||
|
||||
pub struct CopyOutStream(State);
|
||||
|
||||
impl Stream for CopyOutStream {
|
||||
type Item = Bytes;
|
||||
type Error = Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Bytes>, Error> {
|
||||
loop {
|
||||
match mem::replace(&mut self.0, State::Done) {
|
||||
State::Start {
|
||||
client,
|
||||
request,
|
||||
statement,
|
||||
} => {
|
||||
let receiver = client.send(request)?;
|
||||
// it's ok for the statement to close now that we've queued the query
|
||||
drop(statement);
|
||||
self.0 = State::ReadingCopyOutResponse { receiver };
|
||||
}
|
||||
State::ReadingCopyOutResponse { mut receiver } => {
|
||||
let message = match receiver.poll() {
|
||||
Ok(Async::Ready(message)) => message,
|
||||
Ok(Async::NotReady) => {
|
||||
self.0 = State::ReadingCopyOutResponse { receiver };
|
||||
break Ok(Async::NotReady);
|
||||
}
|
||||
Err(()) => unreachable!("mpsc::Receiver doesn't return errors"),
|
||||
};
|
||||
|
||||
match message {
|
||||
Some(Message::BindComplete) => {
|
||||
self.0 = State::ReadingCopyOutResponse { receiver };
|
||||
}
|
||||
Some(Message::CopyOutResponse(_)) => {
|
||||
self.0 = State::ReadingCopyData { receiver };
|
||||
}
|
||||
Some(Message::ErrorResponse(body)) => break Err(error::__db(body)),
|
||||
Some(_) => break Err(bad_response()),
|
||||
None => break Err(disconnected()),
|
||||
}
|
||||
}
|
||||
State::ReadingCopyData { mut receiver } => {
|
||||
let message = match receiver.poll() {
|
||||
Ok(Async::Ready(message)) => message,
|
||||
Ok(Async::NotReady) => {
|
||||
self.0 = State::ReadingCopyData { receiver };
|
||||
break Ok(Async::NotReady);
|
||||
}
|
||||
Err(()) => unreachable!("mpsc::Reciever doesn't return errors"),
|
||||
};
|
||||
|
||||
match message {
|
||||
Some(Message::CopyData(body)) => {
|
||||
self.0 = State::ReadingCopyData { receiver };
|
||||
break Ok(Async::Ready(Some(body.into_bytes())));
|
||||
}
|
||||
Some(Message::CopyDone) | Some(Message::CommandComplete(_)) => {
|
||||
self.0 = State::ReadingCopyData { receiver };
|
||||
}
|
||||
Some(Message::ReadyForQuery(_)) => break Ok(Async::Ready(None)),
|
||||
Some(Message::ErrorResponse(body)) => break Err(error::__db(body)),
|
||||
Some(_) => break Err(bad_response()),
|
||||
None => break Err(disconnected()),
|
||||
}
|
||||
}
|
||||
State::Done => break Ok(Async::Ready(None)),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CopyOutStream {
|
||||
pub fn new(client: Client, request: PendingRequest, statement: Statement) -> CopyOutStream {
|
||||
CopyOutStream(State::Start {
|
||||
client,
|
||||
request,
|
||||
statement,
|
||||
})
|
||||
}
|
||||
}
|
@ -13,6 +13,7 @@ mod client;
|
||||
mod codec;
|
||||
mod connect;
|
||||
mod connection;
|
||||
mod copy_out;
|
||||
mod execute;
|
||||
mod handshake;
|
||||
mod prepare;
|
||||
@ -30,6 +31,7 @@ pub use proto::cancel::CancelFuture;
|
||||
pub use proto::client::Client;
|
||||
pub use proto::codec::PostgresCodec;
|
||||
pub use proto::connection::Connection;
|
||||
pub use proto::copy_out::CopyOutStream;
|
||||
pub use proto::execute::ExecuteFuture;
|
||||
pub use proto::handshake::HandshakeFuture;
|
||||
pub use proto::prepare::PrepareFuture;
|
||||
|
@ -480,7 +480,7 @@ fn notifications() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transaction_commit() {
|
||||
fn transaction_commit() {
|
||||
let _ = env_logger::try_init();
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
@ -518,7 +518,7 @@ fn test_transaction_commit() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transaction_abort() {
|
||||
fn transaction_abort() {
|
||||
let _ = env_logger::try_init();
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
@ -556,3 +556,37 @@ fn test_transaction_abort() {
|
||||
|
||||
assert_eq!(rows.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn copy_out() {
|
||||
let _ = env_logger::try_init();
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let (mut client, connection) = runtime
|
||||
.block_on(tokio_postgres::connect(
|
||||
"postgres://postgres@localhost:5433".parse().unwrap(),
|
||||
TlsMode::None,
|
||||
))
|
||||
.unwrap();
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
runtime.handle().spawn(connection).unwrap();
|
||||
|
||||
runtime
|
||||
.block_on(client.batch_execute(
|
||||
"CREATE TEMPORARY TABLE foo (
|
||||
id SERIAL,
|
||||
name TEXT
|
||||
);
|
||||
INSERT INTO foo (name) VALUES ('jim'), ('joe');",
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
let data = runtime
|
||||
.block_on(
|
||||
client
|
||||
.prepare("COPY foo TO STDOUT")
|
||||
.and_then(|s| client.copy_out(&s, &[]).concat2()),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(&data[..], b"1\tjim\n2\tjoe\n");
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user