copy in support

This commit is contained in:
Steven Fackler 2018-08-11 15:32:17 -06:00
parent daeb5389ed
commit b74f5c80d0
7 changed files with 483 additions and 73 deletions

View File

@ -325,6 +325,14 @@ pub fn __db(e: ErrorResponseBody) -> Error {
}
}
#[doc(hidden)]
pub fn __user<T>(e: T) -> Error
where
T: Into<Box<error::Error + Sync + Send>>,
{
Error(Box::new(ErrorKind::Conversion(e.into())))
}
#[doc(hidden)]
pub fn io(e: io::Error) -> Error {
Error(Box::new(ErrorKind::Io(e)))

View File

@ -24,6 +24,7 @@ extern crate tokio_uds;
use bytes::Bytes;
use futures::{Async, Future, Poll, Stream};
use postgres_shared::rows::RowIndex;
use std::error::Error as StdError;
use std::fmt;
use std::io;
use std::sync::atomic::{AtomicUsize, Ordering};
@ -96,6 +97,14 @@ impl Client {
Query(self.0.query(&statement.0, params))
}
pub fn copy_in<S>(&mut self, statement: &Statement, params: &[&ToSql], stream: S) -> CopyIn<S>
where
S: Stream<Item = Vec<u8>>,
S::Error: Into<Box<StdError + Sync + Send>>,
{
CopyIn(self.0.copy_in(&statement.0, params, stream))
}
pub fn copy_out(&mut self, statement: &Statement, params: &[&ToSql]) -> CopyOut {
CopyOut(self.0.copy_out(&statement.0, params))
}
@ -227,6 +236,25 @@ impl Stream for Query {
}
}
#[must_use = "futures do nothing unless polled"]
pub struct CopyIn<S>(proto::CopyInFuture<S>)
where
S: Stream<Item = Vec<u8>>,
S::Error: Into<Box<StdError + Sync + Send>>;
impl<S> Future for CopyIn<S>
where
S: Stream<Item = Vec<u8>>,
S::Error: Into<Box<StdError + Sync + Send>>,
{
type Item = u64;
type Error = Error;
fn poll(&mut self) -> Poll<u64, Error> {
self.0.poll()
}
}
#[must_use = "streams do nothing unless polled"]
pub struct CopyOut(proto::CopyOutStream);

View File

@ -1,14 +1,17 @@
use antidote::Mutex;
use futures::sync::mpsc;
use futures::{AsyncSink, Sink, Stream};
use postgres_protocol;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use std::collections::HashMap;
use std::error::Error as StdError;
use std::sync::{Arc, Weak};
use disconnected;
use error::{self, Error};
use proto::connection::Request;
use proto::connection::{Request, RequestMessages};
use proto::copy_in::{CopyInFuture, CopyInReceiver, CopyMessage};
use proto::copy_out::CopyOutStream;
use proto::execute::ExecuteFuture;
use proto::prepare::PrepareFuture;
@ -17,7 +20,7 @@ use proto::simple_query::SimpleQueryFuture;
use proto::statement::Statement;
use types::{IsNull, Oid, ToSql, Type};
pub struct PendingRequest(Result<Vec<u8>, Error>);
pub struct PendingRequest(Result<RequestMessages, Error>);
pub struct WeakClient(Weak<Inner>);
@ -122,17 +125,45 @@ impl Client {
}
pub fn execute(&self, statement: &Statement, params: &[&ToSql]) -> ExecuteFuture {
let pending = self.pending_execute(statement, params);
let pending = PendingRequest(
self.excecute_message(statement, params)
.map(RequestMessages::Single),
);
ExecuteFuture::new(self.clone(), pending, statement.clone())
}
pub fn query(&self, statement: &Statement, params: &[&ToSql]) -> QueryStream {
let pending = self.pending_execute(statement, params);
let pending = PendingRequest(
self.excecute_message(statement, params)
.map(RequestMessages::Single),
);
QueryStream::new(self.clone(), pending, statement.clone())
}
pub fn copy_in<S>(&self, statement: &Statement, params: &[&ToSql], stream: S) -> CopyInFuture<S>
where
S: Stream<Item = Vec<u8>>,
S::Error: Into<Box<StdError + Sync + Send>>,
{
let (mut sender, receiver) = mpsc::channel(0);
let pending = PendingRequest(self.excecute_message(statement, params).map(|buf| {
match sender.start_send(CopyMessage::Data(buf)) {
Ok(AsyncSink::Ready) => {}
_ => unreachable!("channel should have capacity"),
}
RequestMessages::CopyIn {
receiver: CopyInReceiver::new(receiver),
pending_message: None,
}
}));
CopyInFuture::new(self.clone(), pending, statement.clone(), stream, sender)
}
pub fn copy_out(&self, statement: &Statement, params: &[&ToSql]) -> CopyOutStream {
let pending = self.pending_execute(statement, params);
let pending = PendingRequest(
self.excecute_message(statement, params)
.map(RequestMessages::Single),
);
CopyOutStream::new(self.clone(), pending, statement.clone())
}
@ -142,35 +173,34 @@ impl Client {
frontend::sync(&mut buf);
let (sender, _) = mpsc::channel(0);
let _ = self.0.sender.unbounded_send(Request {
messages: buf,
messages: RequestMessages::Single(buf),
sender,
});
}
fn pending_execute(&self, statement: &Statement, params: &[&ToSql]) -> PendingRequest {
self.pending(|buf| {
let r = frontend::bind(
"",
statement.name(),
Some(1),
params.iter().zip(statement.params()),
|(param, ty), buf| match param.to_sql_checked(ty, buf) {
Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
Err(e) => Err(e),
},
Some(1),
buf,
);
match r {
Ok(()) => {}
Err(frontend::BindError::Conversion(e)) => return Err(error::conversion(e)),
Err(frontend::BindError::Serialization(e)) => return Err(Error::from(e)),
}
frontend::execute("", 0, buf)?;
frontend::sync(buf);
Ok(())
})
fn excecute_message(&self, statement: &Statement, params: &[&ToSql]) -> Result<Vec<u8>, Error> {
let mut buf = vec![];
let r = frontend::bind(
"",
statement.name(),
Some(1),
params.iter().zip(statement.params()),
|(param, ty), buf| match param.to_sql_checked(ty, buf) {
Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
Err(e) => Err(e),
},
Some(1),
&mut buf,
);
match r {
Ok(()) => {}
Err(frontend::BindError::Conversion(e)) => return Err(error::conversion(e)),
Err(frontend::BindError::Serialization(e)) => return Err(Error::from(e)),
}
frontend::execute("", 0, &mut buf)?;
frontend::sync(&mut buf);
Ok(buf)
}
fn pending<F>(&self, messages: F) -> PendingRequest
@ -178,6 +208,6 @@ impl Client {
F: FnOnce(&mut Vec<u8>) -> Result<(), Error>,
{
let mut buf = vec![];
PendingRequest(messages(&mut buf).map(|()| buf))
PendingRequest(messages(&mut buf).map(|()| RequestMessages::Single(buf)))
}
}

View File

@ -8,11 +8,20 @@ use tokio_codec::Framed;
use error::{self, DbError, Error};
use proto::codec::PostgresCodec;
use proto::copy_in::CopyInReceiver;
use tls::TlsStream;
use {bad_response, disconnected, AsyncMessage, CancelData, Notification};
pub enum RequestMessages {
Single(Vec<u8>),
CopyIn {
receiver: CopyInReceiver,
pending_message: Option<Vec<u8>>,
},
}
pub struct Request {
pub messages: Vec<u8>,
pub messages: RequestMessages,
pub sender: mpsc::Sender<Message>,
}
@ -28,7 +37,7 @@ pub struct Connection {
cancel_data: CancelData,
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
pending_request: Option<Vec<u8>>,
pending_request: Option<RequestMessages>,
pending_response: Option<Message>,
responses: VecDeque<mpsc::Sender<Message>>,
state: State,
@ -140,7 +149,7 @@ impl Connection {
}
}
fn poll_request(&mut self) -> Poll<Option<Vec<u8>>, Error> {
fn poll_request(&mut self) -> Poll<Option<RequestMessages>, Error> {
if let Some(message) = self.pending_request.take() {
trace!("retrying pending request");
return Ok(Async::Ready(Some(message)));
@ -170,7 +179,7 @@ impl Connection {
self.state = State::Terminating;
let mut request = vec![];
frontend::terminate(&mut request);
request
RequestMessages::Single(request)
}
Async::Ready(None) => {
trace!(
@ -185,17 +194,60 @@ impl Connection {
}
};
match self.stream.start_send(request)? {
AsyncSink::Ready => {
if self.state == State::Terminating {
trace!("poll_write: sent eof, closing");
self.state = State::Closing;
match request {
RequestMessages::Single(request) => match self.stream.start_send(request)? {
AsyncSink::Ready => {
if self.state == State::Terminating {
trace!("poll_write: sent eof, closing");
self.state = State::Closing;
}
}
}
AsyncSink::NotReady(request) => {
trace!("poll_write: waiting on socket");
self.pending_request = Some(request);
return Ok(false);
AsyncSink::NotReady(request) => {
trace!("poll_write: waiting on socket");
self.pending_request = Some(RequestMessages::Single(request));
return Ok(false);
}
},
RequestMessages::CopyIn {
mut receiver,
mut pending_message,
} => {
let message = match pending_message.take() {
Some(message) => message,
None => match receiver.poll() {
Ok(Async::Ready(Some(message))) => message,
Ok(Async::Ready(None)) => {
trace!("poll_write: finished copy_in request");
continue;
}
Ok(Async::NotReady) => {
trace!("poll_write: waiting on copy_in stream");
self.pending_request = Some(RequestMessages::CopyIn {
receiver,
pending_message,
});
return Ok(true);
}
Err(()) => unreachable!("mpsc::Receiver doesn't return errors"),
},
};
match self.stream.start_send(message)? {
AsyncSink::Ready => {
self.pending_request = Some(RequestMessages::CopyIn {
receiver,
pending_message: None,
});
}
AsyncSink::NotReady(message) => {
trace!("poll_write: waiting on socket");
self.pending_request = Some(RequestMessages::CopyIn {
receiver,
pending_message: Some(message),
});
return Ok(false);
}
};
}
}
}

View File

@ -0,0 +1,219 @@
use futures::sink;
use futures::sync::mpsc;
use futures::{Async, AsyncSink, Future, Poll, Sink, Stream};
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use state_machine_future::RentToOwn;
use std::error::Error as StdError;
use error::{self, Error};
use proto::client::{Client, PendingRequest};
use proto::statement::Statement;
use {bad_response, disconnected};
pub enum CopyMessage {
Data(Vec<u8>),
Done,
}
pub struct CopyInReceiver {
receiver: mpsc::Receiver<CopyMessage>,
done: bool,
}
impl CopyInReceiver {
pub fn new(receiver: mpsc::Receiver<CopyMessage>) -> CopyInReceiver {
CopyInReceiver {
receiver,
done: false,
}
}
}
impl Stream for CopyInReceiver {
type Item = Vec<u8>;
type Error = ();
fn poll(&mut self) -> Poll<Option<Vec<u8>>, ()> {
if self.done {
return Ok(Async::Ready(None));
}
match self.receiver.poll()? {
Async::Ready(Some(CopyMessage::Data(buf))) => Ok(Async::Ready(Some(buf))),
Async::Ready(Some(CopyMessage::Done)) => {
self.done = true;
let mut buf = vec![];
frontend::copy_done(&mut buf);
frontend::sync(&mut buf);
Ok(Async::Ready(Some(buf)))
}
Async::Ready(None) => {
self.done = true;
let mut buf = vec![];
frontend::copy_fail("", &mut buf).unwrap();
frontend::sync(&mut buf);
Ok(Async::Ready(Some(buf)))
}
Async::NotReady => Ok(Async::NotReady),
}
}
}
#[derive(StateMachineFuture)]
pub enum CopyIn<S>
where
S: Stream<Item = Vec<u8>>,
S::Error: Into<Box<StdError + Sync + Send>>,
{
#[state_machine_future(start, transitions(ReadCopyInResponse))]
Start {
client: Client,
request: PendingRequest,
statement: Statement,
stream: S,
sender: mpsc::Sender<CopyMessage>,
},
#[state_machine_future(transitions(WriteCopyData))]
ReadCopyInResponse {
stream: S,
sender: mpsc::Sender<CopyMessage>,
receiver: mpsc::Receiver<Message>,
},
#[state_machine_future(transitions(WriteCopyDone))]
WriteCopyData {
stream: S,
pending_message: Option<CopyMessage>,
sender: mpsc::Sender<CopyMessage>,
receiver: mpsc::Receiver<Message>,
},
#[state_machine_future(transitions(ReadCommandComplete))]
WriteCopyDone {
future: sink::Send<mpsc::Sender<CopyMessage>>,
receiver: mpsc::Receiver<Message>,
},
#[state_machine_future(transitions(Finished))]
ReadCommandComplete { receiver: mpsc::Receiver<Message> },
#[state_machine_future(ready)]
Finished(u64),
#[state_machine_future(error)]
Failed(Error),
}
impl<S> PollCopyIn<S> for CopyIn<S>
where
S: Stream<Item = Vec<u8>>,
S::Error: Into<Box<StdError + Sync + Send>>,
{
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<S>>) -> Poll<AfterStart<S>, Error> {
let state = state.take();
let receiver = state.client.send(state.request)?;
// the statement can drop after this point, since its close will queue up after the copy
transition!(ReadCopyInResponse {
stream: state.stream,
sender: state.sender,
receiver
})
}
fn poll_read_copy_in_response<'a>(
state: &'a mut RentToOwn<'a, ReadCopyInResponse<S>>,
) -> Poll<AfterReadCopyInResponse<S>, Error> {
loop {
let message = try_receive!(state.receiver.poll());
match message {
Some(Message::BindComplete) => {}
Some(Message::CopyInResponse(_)) => {
let state = state.take();
transition!(WriteCopyData {
stream: state.stream,
pending_message: None,
sender: state.sender,
receiver: state.receiver
})
}
Some(Message::ErrorResponse(body)) => return Err(error::__db(body)),
Some(_) => return Err(bad_response()),
None => return Err(disconnected()),
}
}
}
fn poll_write_copy_data<'a>(
state: &'a mut RentToOwn<'a, WriteCopyData<S>>,
) -> Poll<AfterWriteCopyData, Error> {
loop {
let message = match state.pending_message.take() {
Some(message) => message,
None => match try_ready!(state.stream.poll().map_err(error::__user)) {
Some(data) => {
let mut buf = vec![];
frontend::copy_data(&data, &mut buf).map_err(error::io)?;
CopyMessage::Data(buf)
}
None => {
let state = state.take();
transition!(WriteCopyDone {
future: state.sender.send(CopyMessage::Done),
receiver: state.receiver
})
}
},
};
match state.sender.start_send(message) {
Ok(AsyncSink::Ready) => {}
Ok(AsyncSink::NotReady(message)) => {
state.pending_message = Some(message);
return Ok(Async::NotReady);
}
Err(_) => return Err(disconnected()),
}
}
}
fn poll_write_copy_done<'a>(
state: &'a mut RentToOwn<'a, WriteCopyDone>,
) -> Poll<AfterWriteCopyDone, Error> {
try_ready!(state.future.poll().map_err(|_| disconnected()));
let state = state.take();
transition!(ReadCommandComplete {
receiver: state.receiver
})
}
fn poll_read_command_complete<'a>(
state: &'a mut RentToOwn<'a, ReadCommandComplete>,
) -> Poll<AfterReadCommandComplete, Error> {
let message = try_receive!(state.receiver.poll());
match message {
Some(Message::CommandComplete(body)) => {
let rows = body.tag()?.rsplit(' ').next().unwrap().parse().unwrap_or(0);
transition!(Finished(rows))
}
Some(Message::ErrorResponse(body)) => Err(error::__db(body)),
Some(_) => Err(bad_response()),
None => Err(disconnected()),
}
}
}
impl<S> CopyInFuture<S>
where
S: Stream<Item = Vec<u8>>,
S::Error: Into<Box<StdError + Sync + Send>>,
{
pub fn new(
client: Client,
request: PendingRequest,
statement: Statement,
stream: S,
sender: mpsc::Sender<CopyMessage>,
) -> CopyInFuture<S> {
CopyIn::start(client, request, statement, stream, sender)
}
}

View File

@ -13,6 +13,7 @@ mod client;
mod codec;
mod connect;
mod connection;
mod copy_in;
mod copy_out;
mod execute;
mod handshake;
@ -31,6 +32,7 @@ pub use proto::cancel::CancelFuture;
pub use proto::client::Client;
pub use proto::codec::PostgresCodec;
pub use proto::connection::Connection;
pub use proto::copy_in::CopyInFuture;
pub use proto::copy_out::CopyOutStream;
pub use proto::execute::ExecuteFuture;
pub use proto::handshake::HandshakeFuture;

View File

@ -8,6 +8,7 @@ extern crate futures;
extern crate log;
use futures::future;
use futures::stream;
use futures::sync::mpsc;
use std::error::Error;
use std::time::{Duration, Instant};
@ -238,8 +239,7 @@ fn cancel_query() {
TlsMode::None,
cancel_data,
)
})
.then(|r| {
}).then(|r| {
r.unwrap();
Ok::<(), ()>(())
});
@ -267,8 +267,7 @@ fn custom_enum() {
'ok',
'happy'
)",
))
.unwrap();
)).unwrap();
let select = client.prepare("SELECT $1::mood");
let select = runtime.block_on(select).unwrap();
@ -301,8 +300,7 @@ fn custom_domain() {
runtime
.block_on(client.batch_execute(
"CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16)",
))
.unwrap();
)).unwrap();
let select = client.prepare("SELECT $1::session_id");
let select = runtime.block_on(select).unwrap();
@ -359,8 +357,7 @@ fn custom_composite() {
supplier INTEGER,
price NUMERIC
)",
))
.unwrap();
)).unwrap();
let select = client.prepare("SELECT $1::inventory_item");
let select = runtime.block_on(select).unwrap();
@ -399,8 +396,7 @@ fn custom_range() {
subtype = float8,
subtype_diff = float8mi
)",
))
.unwrap();
)).unwrap();
let select = client.prepare("SELECT $1::floatrange");
let select = runtime.block_on(select).unwrap();
@ -488,8 +484,7 @@ fn transaction_commit() {
.block_on(tokio_postgres::connect(
"postgres://postgres@localhost:5433".parse().unwrap(),
TlsMode::None,
))
.unwrap();
)).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
@ -499,8 +494,7 @@ fn transaction_commit() {
id SERIAL,
name TEXT
)",
))
.unwrap();
)).unwrap();
let f = client.batch_execute("INSERT INTO foo (name) VALUES ('steven')");
runtime.block_on(client.transaction(f)).unwrap();
@ -510,8 +504,7 @@ fn transaction_commit() {
client
.prepare("SELECT name FROM foo")
.and_then(|s| client.query(&s, &[]).collect()),
)
.unwrap();
).unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, &str>(0), "steven");
@ -526,8 +519,7 @@ fn transaction_abort() {
.block_on(tokio_postgres::connect(
"postgres://postgres@localhost:5433".parse().unwrap(),
TlsMode::None,
))
.unwrap();
)).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
@ -537,8 +529,7 @@ fn transaction_abort() {
id SERIAL,
name TEXT
)",
))
.unwrap();
)).unwrap();
let f = client
.batch_execute("INSERT INTO foo (name) VALUES ('steven')")
@ -551,8 +542,91 @@ fn transaction_abort() {
client
.prepare("SELECT name FROM foo")
.and_then(|s| client.query(&s, &[]).collect()),
)
.unwrap();
).unwrap();
assert_eq!(rows.len(), 0);
}
#[test]
fn copy_in() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let (mut client, connection) = runtime
.block_on(tokio_postgres::connect(
"postgres://postgres@localhost:5433".parse().unwrap(),
TlsMode::None,
)).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(client.batch_execute(
"CREATE TEMPORARY TABLE foo (
id INTEGER,
name TEXT
)",
)).unwrap();
let stream = stream::iter_ok::<_, String>(vec![b"1\tjim\n".to_vec(), b"2\tjoe\n".to_vec()]);
let rows = runtime
.block_on(
client
.prepare("COPY foo FROM STDIN")
.and_then(|s| client.copy_in(&s, &[], stream)),
).unwrap();
assert_eq!(rows, 2);
let rows = runtime
.block_on(
client
.prepare("SELECT id, name FROM foo ORDER BY id")
.and_then(|s| client.query(&s, &[]).collect()),
).unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get::<_, i32>(0), 1);
assert_eq!(rows[0].get::<_, &str>(1), "jim");
assert_eq!(rows[1].get::<_, i32>(0), 2);
assert_eq!(rows[1].get::<_, &str>(1), "joe");
}
#[test]
fn copy_in_error() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let (mut client, connection) = runtime
.block_on(tokio_postgres::connect(
"postgres://postgres@localhost:5433".parse().unwrap(),
TlsMode::None,
)).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(client.batch_execute(
"CREATE TEMPORARY TABLE foo (
id INTEGER,
name TEXT
)",
)).unwrap();
let stream = stream::iter_result(vec![Ok(b"1\tjim\n".to_vec()), Err("asdf")]);
let error = runtime
.block_on(
client
.prepare("COPY foo FROM STDIN")
.and_then(|s| client.copy_in(&s, &[], stream)),
).unwrap_err();
error.to_string().contains("asdf");
let rows = runtime
.block_on(
client
.prepare("SELECT id, name FROM foo ORDER BY id")
.and_then(|s| client.query(&s, &[]).collect()),
).unwrap();
assert_eq!(rows.len(), 0);
}
@ -566,8 +640,7 @@ fn copy_out() {
.block_on(tokio_postgres::connect(
"postgres://postgres@localhost:5433".parse().unwrap(),
TlsMode::None,
))
.unwrap();
)).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
@ -578,15 +651,13 @@ fn copy_out() {
name TEXT
);
INSERT INTO foo (name) VALUES ('jim'), ('joe');",
))
.unwrap();
)).unwrap();
let data = runtime
.block_on(
client
.prepare("COPY foo TO STDOUT")
.and_then(|s| client.copy_out(&s, &[]).concat2()),
)
.unwrap();
).unwrap();
assert_eq!(&data[..], b"1\tjim\n2\tjoe\n");
}