Copy out support
This commit is contained in:
parent
9e399aa93f
commit
7056e3ec24
@ -1,6 +1,6 @@
|
|||||||
#![allow(missing_docs)]
|
#![allow(missing_docs)]
|
||||||
|
|
||||||
use byteorder::{ReadBytesExt, BigEndian};
|
use byteorder::{BigEndian, ReadBytesExt};
|
||||||
use bytes::{Bytes, BytesMut};
|
use bytes::{Bytes, BytesMut};
|
||||||
use fallible_iterator::FallibleIterator;
|
use fallible_iterator::FallibleIterator;
|
||||||
use memchr::memchr;
|
use memchr::memchr;
|
||||||
@ -148,17 +148,14 @@ impl Message {
|
|||||||
let storage = buf.read_all();
|
let storage = buf.read_all();
|
||||||
Message::NoticeResponse(NoticeResponseBody { storage: storage })
|
Message::NoticeResponse(NoticeResponseBody { storage: storage })
|
||||||
}
|
}
|
||||||
b'R' => {
|
b'R' => match buf.read_i32::<BigEndian>()? {
|
||||||
match buf.read_i32::<BigEndian>()? {
|
|
||||||
0 => Message::AuthenticationOk,
|
0 => Message::AuthenticationOk,
|
||||||
2 => Message::AuthenticationKerberosV5,
|
2 => Message::AuthenticationKerberosV5,
|
||||||
3 => Message::AuthenticationCleartextPassword,
|
3 => Message::AuthenticationCleartextPassword,
|
||||||
5 => {
|
5 => {
|
||||||
let mut salt = [0; 4];
|
let mut salt = [0; 4];
|
||||||
buf.read_exact(&mut salt)?;
|
buf.read_exact(&mut salt)?;
|
||||||
Message::AuthenticationMd5Password(
|
Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { salt: salt })
|
||||||
AuthenticationMd5PasswordBody { salt: salt },
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
6 => Message::AuthenticationScmCredential,
|
6 => Message::AuthenticationScmCredential,
|
||||||
7 => Message::AuthenticationGss,
|
7 => Message::AuthenticationGss,
|
||||||
@ -185,8 +182,7 @@ impl Message {
|
|||||||
format!("unknown authentication tag `{}`", tag),
|
format!("unknown authentication tag `{}`", tag),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
|
||||||
b's' => Message::PortalSuspended,
|
b's' => Message::PortalSuspended,
|
||||||
b'S' => {
|
b'S' => {
|
||||||
let name = buf.read_cstr()?;
|
let name = buf.read_cstr()?;
|
||||||
@ -394,6 +390,11 @@ impl CopyDataBody {
|
|||||||
pub fn data(&self) -> &[u8] {
|
pub fn data(&self) -> &[u8] {
|
||||||
&self.storage
|
&self.storage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn into_bytes(self) -> Bytes {
|
||||||
|
self.storage
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct CopyInResponseBody {
|
pub struct CopyInResponseBody {
|
||||||
|
@ -21,6 +21,7 @@ extern crate state_machine_future;
|
|||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
extern crate tokio_uds;
|
extern crate tokio_uds;
|
||||||
|
|
||||||
|
use bytes::Bytes;
|
||||||
use futures::{Async, Future, Poll, Stream};
|
use futures::{Async, Future, Poll, Stream};
|
||||||
use postgres_shared::rows::RowIndex;
|
use postgres_shared::rows::RowIndex;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
@ -95,6 +96,10 @@ impl Client {
|
|||||||
Query(self.0.query(&statement.0, params))
|
Query(self.0.query(&statement.0, params))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn copy_out(&mut self, statement: &Statement, params: &[&ToSql]) -> CopyOut {
|
||||||
|
CopyOut(self.0.copy_out(&statement.0, params))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn transaction<T>(&mut self, future: T) -> Transaction<T>
|
pub fn transaction<T>(&mut self, future: T) -> Transaction<T>
|
||||||
where
|
where
|
||||||
T: Future,
|
T: Future,
|
||||||
@ -222,6 +227,18 @@ impl Stream for Query {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use = "streams do nothing unless polled"]
|
||||||
|
pub struct CopyOut(proto::CopyOutStream);
|
||||||
|
|
||||||
|
impl Stream for CopyOut {
|
||||||
|
type Item = Bytes;
|
||||||
|
type Error = Error;
|
||||||
|
|
||||||
|
fn poll(&mut self) -> Poll<Option<Bytes>, Error> {
|
||||||
|
self.0.poll()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct Row(proto::Row);
|
pub struct Row(proto::Row);
|
||||||
|
|
||||||
impl Row {
|
impl Row {
|
||||||
|
@ -9,6 +9,7 @@ use std::sync::{Arc, Weak};
|
|||||||
use disconnected;
|
use disconnected;
|
||||||
use error::{self, Error};
|
use error::{self, Error};
|
||||||
use proto::connection::Request;
|
use proto::connection::Request;
|
||||||
|
use proto::copy_out::CopyOutStream;
|
||||||
use proto::execute::ExecuteFuture;
|
use proto::execute::ExecuteFuture;
|
||||||
use proto::prepare::PrepareFuture;
|
use proto::prepare::PrepareFuture;
|
||||||
use proto::query::QueryStream;
|
use proto::query::QueryStream;
|
||||||
@ -130,6 +131,11 @@ impl Client {
|
|||||||
QueryStream::new(self.clone(), pending, statement.clone())
|
QueryStream::new(self.clone(), pending, statement.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn copy_out(&self, statement: &Statement, params: &[&ToSql]) -> CopyOutStream {
|
||||||
|
let pending = self.pending_execute(statement, params);
|
||||||
|
CopyOutStream::new(self.clone(), pending, statement.clone())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn close_statement(&self, name: &str) {
|
pub fn close_statement(&self, name: &str) {
|
||||||
let mut buf = vec![];
|
let mut buf = vec![];
|
||||||
frontend::close(b'S', name, &mut buf).expect("statement name not valid");
|
frontend::close(b'S', name, &mut buf).expect("statement name not valid");
|
||||||
|
106
tokio-postgres/src/proto/copy_out.rs
Normal file
106
tokio-postgres/src/proto/copy_out.rs
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
use bytes::Bytes;
|
||||||
|
use futures::sync::mpsc;
|
||||||
|
use futures::{Async, Poll, Stream};
|
||||||
|
use postgres_protocol::message::backend::Message;
|
||||||
|
use std::mem;
|
||||||
|
|
||||||
|
use error::{self, Error};
|
||||||
|
use proto::client::{Client, PendingRequest};
|
||||||
|
use proto::statement::Statement;
|
||||||
|
use {bad_response, disconnected};
|
||||||
|
|
||||||
|
enum State {
|
||||||
|
Start {
|
||||||
|
client: Client,
|
||||||
|
request: PendingRequest,
|
||||||
|
statement: Statement,
|
||||||
|
},
|
||||||
|
ReadingCopyOutResponse {
|
||||||
|
receiver: mpsc::Receiver<Message>,
|
||||||
|
},
|
||||||
|
ReadingCopyData {
|
||||||
|
receiver: mpsc::Receiver<Message>,
|
||||||
|
},
|
||||||
|
Done,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CopyOutStream(State);
|
||||||
|
|
||||||
|
impl Stream for CopyOutStream {
|
||||||
|
type Item = Bytes;
|
||||||
|
type Error = Error;
|
||||||
|
|
||||||
|
fn poll(&mut self) -> Poll<Option<Bytes>, Error> {
|
||||||
|
loop {
|
||||||
|
match mem::replace(&mut self.0, State::Done) {
|
||||||
|
State::Start {
|
||||||
|
client,
|
||||||
|
request,
|
||||||
|
statement,
|
||||||
|
} => {
|
||||||
|
let receiver = client.send(request)?;
|
||||||
|
// it's ok for the statement to close now that we've queued the query
|
||||||
|
drop(statement);
|
||||||
|
self.0 = State::ReadingCopyOutResponse { receiver };
|
||||||
|
}
|
||||||
|
State::ReadingCopyOutResponse { mut receiver } => {
|
||||||
|
let message = match receiver.poll() {
|
||||||
|
Ok(Async::Ready(message)) => message,
|
||||||
|
Ok(Async::NotReady) => {
|
||||||
|
self.0 = State::ReadingCopyOutResponse { receiver };
|
||||||
|
break Ok(Async::NotReady);
|
||||||
|
}
|
||||||
|
Err(()) => unreachable!("mpsc::Receiver doesn't return errors"),
|
||||||
|
};
|
||||||
|
|
||||||
|
match message {
|
||||||
|
Some(Message::BindComplete) => {
|
||||||
|
self.0 = State::ReadingCopyOutResponse { receiver };
|
||||||
|
}
|
||||||
|
Some(Message::CopyOutResponse(_)) => {
|
||||||
|
self.0 = State::ReadingCopyData { receiver };
|
||||||
|
}
|
||||||
|
Some(Message::ErrorResponse(body)) => break Err(error::__db(body)),
|
||||||
|
Some(_) => break Err(bad_response()),
|
||||||
|
None => break Err(disconnected()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
State::ReadingCopyData { mut receiver } => {
|
||||||
|
let message = match receiver.poll() {
|
||||||
|
Ok(Async::Ready(message)) => message,
|
||||||
|
Ok(Async::NotReady) => {
|
||||||
|
self.0 = State::ReadingCopyData { receiver };
|
||||||
|
break Ok(Async::NotReady);
|
||||||
|
}
|
||||||
|
Err(()) => unreachable!("mpsc::Reciever doesn't return errors"),
|
||||||
|
};
|
||||||
|
|
||||||
|
match message {
|
||||||
|
Some(Message::CopyData(body)) => {
|
||||||
|
self.0 = State::ReadingCopyData { receiver };
|
||||||
|
break Ok(Async::Ready(Some(body.into_bytes())));
|
||||||
|
}
|
||||||
|
Some(Message::CopyDone) | Some(Message::CommandComplete(_)) => {
|
||||||
|
self.0 = State::ReadingCopyData { receiver };
|
||||||
|
}
|
||||||
|
Some(Message::ReadyForQuery(_)) => break Ok(Async::Ready(None)),
|
||||||
|
Some(Message::ErrorResponse(body)) => break Err(error::__db(body)),
|
||||||
|
Some(_) => break Err(bad_response()),
|
||||||
|
None => break Err(disconnected()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
State::Done => break Ok(Async::Ready(None)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CopyOutStream {
|
||||||
|
pub fn new(client: Client, request: PendingRequest, statement: Statement) -> CopyOutStream {
|
||||||
|
CopyOutStream(State::Start {
|
||||||
|
client,
|
||||||
|
request,
|
||||||
|
statement,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -13,6 +13,7 @@ mod client;
|
|||||||
mod codec;
|
mod codec;
|
||||||
mod connect;
|
mod connect;
|
||||||
mod connection;
|
mod connection;
|
||||||
|
mod copy_out;
|
||||||
mod execute;
|
mod execute;
|
||||||
mod handshake;
|
mod handshake;
|
||||||
mod prepare;
|
mod prepare;
|
||||||
@ -30,6 +31,7 @@ pub use proto::cancel::CancelFuture;
|
|||||||
pub use proto::client::Client;
|
pub use proto::client::Client;
|
||||||
pub use proto::codec::PostgresCodec;
|
pub use proto::codec::PostgresCodec;
|
||||||
pub use proto::connection::Connection;
|
pub use proto::connection::Connection;
|
||||||
|
pub use proto::copy_out::CopyOutStream;
|
||||||
pub use proto::execute::ExecuteFuture;
|
pub use proto::execute::ExecuteFuture;
|
||||||
pub use proto::handshake::HandshakeFuture;
|
pub use proto::handshake::HandshakeFuture;
|
||||||
pub use proto::prepare::PrepareFuture;
|
pub use proto::prepare::PrepareFuture;
|
||||||
|
@ -480,7 +480,7 @@ fn notifications() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_transaction_commit() {
|
fn transaction_commit() {
|
||||||
let _ = env_logger::try_init();
|
let _ = env_logger::try_init();
|
||||||
let mut runtime = Runtime::new().unwrap();
|
let mut runtime = Runtime::new().unwrap();
|
||||||
|
|
||||||
@ -518,7 +518,7 @@ fn test_transaction_commit() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_transaction_abort() {
|
fn transaction_abort() {
|
||||||
let _ = env_logger::try_init();
|
let _ = env_logger::try_init();
|
||||||
let mut runtime = Runtime::new().unwrap();
|
let mut runtime = Runtime::new().unwrap();
|
||||||
|
|
||||||
@ -556,3 +556,37 @@ fn test_transaction_abort() {
|
|||||||
|
|
||||||
assert_eq!(rows.len(), 0);
|
assert_eq!(rows.len(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn copy_out() {
|
||||||
|
let _ = env_logger::try_init();
|
||||||
|
let mut runtime = Runtime::new().unwrap();
|
||||||
|
|
||||||
|
let (mut client, connection) = runtime
|
||||||
|
.block_on(tokio_postgres::connect(
|
||||||
|
"postgres://postgres@localhost:5433".parse().unwrap(),
|
||||||
|
TlsMode::None,
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
let connection = connection.map_err(|e| panic!("{}", e));
|
||||||
|
runtime.handle().spawn(connection).unwrap();
|
||||||
|
|
||||||
|
runtime
|
||||||
|
.block_on(client.batch_execute(
|
||||||
|
"CREATE TEMPORARY TABLE foo (
|
||||||
|
id SERIAL,
|
||||||
|
name TEXT
|
||||||
|
);
|
||||||
|
INSERT INTO foo (name) VALUES ('jim'), ('joe');",
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let data = runtime
|
||||||
|
.block_on(
|
||||||
|
client
|
||||||
|
.prepare("COPY foo TO STDOUT")
|
||||||
|
.and_then(|s| client.copy_out(&s, &[]).concat2()),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(&data[..], b"1\tjim\n2\tjoe\n");
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user