Merge pull request #451 from sfackler/less-copy-copies
Avoid copies in copy_in
This commit is contained in:
commit
eaef62c340
@ -2,6 +2,8 @@
|
||||
#![allow(missing_docs)]
|
||||
|
||||
use byteorder::{BigEndian, ByteOrder, WriteBytesExt};
|
||||
use bytes::{Buf, BufMut, BytesMut, IntoBuf};
|
||||
use std::convert::TryFrom;
|
||||
use std::error::Error;
|
||||
use std::io;
|
||||
use std::marker;
|
||||
@ -263,6 +265,40 @@ pub fn copy_data(data: &[u8], buf: &mut Vec<u8>) -> io::Result<()> {
|
||||
})
|
||||
}
|
||||
|
||||
pub struct CopyData<T> {
|
||||
buf: T,
|
||||
len: i32,
|
||||
}
|
||||
|
||||
impl<T> CopyData<T>
|
||||
where
|
||||
T: Buf,
|
||||
{
|
||||
pub fn new<U>(buf: U) -> io::Result<CopyData<T>>
|
||||
where
|
||||
U: IntoBuf<Buf = T>,
|
||||
{
|
||||
let buf = buf.into_buf();
|
||||
|
||||
let len = buf
|
||||
.remaining()
|
||||
.checked_add(4)
|
||||
.and_then(|l| i32::try_from(l).ok())
|
||||
.ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::InvalidInput, "message length overflow")
|
||||
})?;
|
||||
|
||||
Ok(CopyData { buf, len })
|
||||
}
|
||||
|
||||
pub fn write(self, out: &mut BytesMut) {
|
||||
out.reserve(self.len as usize + 1);
|
||||
out.put_u8(b'd');
|
||||
out.put_i32_be(self.len);
|
||||
out.put(self.buf);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn copy_done(buf: &mut Vec<u8>) {
|
||||
buf.push(b'c');
|
||||
|
@ -170,7 +170,7 @@ pub struct CopyIn<S>(pub(crate) proto::CopyInFuture<S>)
|
||||
where
|
||||
S: Stream,
|
||||
S::Item: IntoBuf,
|
||||
<S::Item as IntoBuf>::Buf: Send,
|
||||
<S::Item as IntoBuf>::Buf: 'static + Send,
|
||||
S::Error: Into<Box<dyn error::Error + Sync + Send>>;
|
||||
|
||||
impl<S> Future for CopyIn<S>
|
||||
|
@ -242,7 +242,7 @@ impl Client {
|
||||
where
|
||||
S: Stream,
|
||||
S::Item: IntoBuf,
|
||||
<S::Item as IntoBuf>::Buf: Send,
|
||||
<S::Item as IntoBuf>::Buf: 'static + Send,
|
||||
// FIXME error type?
|
||||
S::Error: Into<Box<dyn StdError + Sync + Send>>,
|
||||
{
|
||||
|
@ -11,6 +11,7 @@ use std::sync::{Arc, Weak};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::proto::bind::BindFuture;
|
||||
use crate::proto::codec::FrontendMessage;
|
||||
use crate::proto::connection::{Request, RequestMessages};
|
||||
use crate::proto::copy_in::{CopyInFuture, CopyInReceiver, CopyMessage};
|
||||
use crate::proto::copy_out::CopyOutStream;
|
||||
@ -185,8 +186,12 @@ impl Client {
|
||||
if let Ok(ref mut buf) = buf {
|
||||
frontend::sync(buf);
|
||||
}
|
||||
let pending =
|
||||
PendingRequest(buf.map(|m| (RequestMessages::Single(m), self.0.idle.guard())));
|
||||
let pending = PendingRequest(buf.map(|m| {
|
||||
(
|
||||
RequestMessages::Single(FrontendMessage::Raw(m)),
|
||||
self.0.idle.guard(),
|
||||
)
|
||||
}));
|
||||
BindFuture::new(self.clone(), pending, name, statement.clone())
|
||||
}
|
||||
|
||||
@ -208,12 +213,12 @@ impl Client {
|
||||
where
|
||||
S: Stream,
|
||||
S::Item: IntoBuf,
|
||||
<S::Item as IntoBuf>::Buf: Send,
|
||||
<S::Item as IntoBuf>::Buf: 'static + Send,
|
||||
S::Error: Into<Box<dyn StdError + Sync + Send>>,
|
||||
{
|
||||
let (mut sender, receiver) = mpsc::channel(1);
|
||||
let pending = PendingRequest(self.excecute_message(statement, params).map(|data| {
|
||||
match sender.start_send(CopyMessage { data, done: false }) {
|
||||
match sender.start_send(CopyMessage::Message(data)) {
|
||||
Ok(AsyncSink::Ready) => {}
|
||||
_ => unreachable!("channel should have capacity"),
|
||||
}
|
||||
@ -278,7 +283,7 @@ impl Client {
|
||||
frontend::sync(&mut buf);
|
||||
let (sender, _) = mpsc::channel(0);
|
||||
let _ = self.0.sender.unbounded_send(Request {
|
||||
messages: RequestMessages::Single(buf),
|
||||
messages: RequestMessages::Single(FrontendMessage::Raw(buf)),
|
||||
sender,
|
||||
idle: None,
|
||||
});
|
||||
@ -326,11 +331,11 @@ impl Client {
|
||||
&self,
|
||||
statement: &Statement,
|
||||
params: &[&dyn ToSql],
|
||||
) -> Result<Vec<u8>, Error> {
|
||||
) -> Result<FrontendMessage, Error> {
|
||||
let mut buf = self.bind_message(statement, "", params)?;
|
||||
frontend::execute("", 0, &mut buf).map_err(Error::parse)?;
|
||||
frontend::sync(&mut buf);
|
||||
Ok(buf)
|
||||
Ok(FrontendMessage::Raw(buf))
|
||||
}
|
||||
|
||||
fn pending<F>(&self, messages: F) -> PendingRequest
|
||||
@ -338,8 +343,11 @@ impl Client {
|
||||
F: FnOnce(&mut Vec<u8>) -> Result<(), Error>,
|
||||
{
|
||||
let mut buf = vec![];
|
||||
PendingRequest(
|
||||
messages(&mut buf).map(|()| (RequestMessages::Single(buf), self.0.idle.guard())),
|
||||
)
|
||||
PendingRequest(messages(&mut buf).map(|()| {
|
||||
(
|
||||
RequestMessages::Single(FrontendMessage::Raw(buf)),
|
||||
self.0.idle.guard(),
|
||||
)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
@ -1,16 +1,26 @@
|
||||
use bytes::BytesMut;
|
||||
use bytes::{Buf, BytesMut};
|
||||
use postgres_protocol::message::backend;
|
||||
use postgres_protocol::message::frontend::CopyData;
|
||||
use std::io;
|
||||
use tokio_codec::{Decoder, Encoder};
|
||||
|
||||
pub enum FrontendMessage {
|
||||
Raw(Vec<u8>),
|
||||
CopyData(CopyData<Box<dyn Buf + Send>>),
|
||||
}
|
||||
|
||||
pub struct PostgresCodec;
|
||||
|
||||
impl Encoder for PostgresCodec {
|
||||
type Item = Vec<u8>;
|
||||
type Item = FrontendMessage;
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), io::Error> {
|
||||
dst.extend_from_slice(&item);
|
||||
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
|
||||
match item {
|
||||
FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
|
||||
FrontendMessage::CopyData(data) => data.write(dst),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -11,7 +11,7 @@ use std::collections::HashMap;
|
||||
use tokio_codec::Framed;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::proto::{Client, Connection, MaybeTlsStream, PostgresCodec, TlsFuture};
|
||||
use crate::proto::{Client, Connection, FrontendMessage, MaybeTlsStream, PostgresCodec, TlsFuture};
|
||||
use crate::tls::ChannelBinding;
|
||||
use crate::{Config, Error, TlsConnect};
|
||||
|
||||
@ -111,7 +111,7 @@ where
|
||||
let stream = Framed::new(stream, PostgresCodec);
|
||||
|
||||
transition!(SendingStartup {
|
||||
future: stream.send(buf),
|
||||
future: stream.send(FrontendMessage::Raw(buf)),
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
channel_binding,
|
||||
@ -156,7 +156,7 @@ where
|
||||
let mut buf = vec![];
|
||||
frontend::password_message(pass, &mut buf).map_err(Error::encode)?;
|
||||
transition!(SendingPassword {
|
||||
future: state.stream.send(buf),
|
||||
future: state.stream.send(FrontendMessage::Raw(buf)),
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
})
|
||||
@ -178,7 +178,7 @@ where
|
||||
let mut buf = vec![];
|
||||
frontend::password_message(output.as_bytes(), &mut buf).map_err(Error::encode)?;
|
||||
transition!(SendingPassword {
|
||||
future: state.stream.send(buf),
|
||||
future: state.stream.send(FrontendMessage::Raw(buf)),
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
})
|
||||
@ -235,7 +235,7 @@ where
|
||||
.map_err(Error::encode)?;
|
||||
|
||||
transition!(SendingSasl {
|
||||
future: state.stream.send(buf),
|
||||
future: state.stream.send(FrontendMessage::Raw(buf)),
|
||||
scram,
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
@ -293,7 +293,7 @@ where
|
||||
let mut buf = vec![];
|
||||
frontend::sasl_response(state.scram.message(), &mut buf).map_err(Error::encode)?;
|
||||
transition!(SendingSasl {
|
||||
future: state.stream.send(buf),
|
||||
future: state.stream.send(FrontendMessage::Raw(buf)),
|
||||
scram: state.scram,
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
|
@ -8,17 +8,17 @@ use std::io;
|
||||
use tokio_codec::Framed;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::proto::codec::PostgresCodec;
|
||||
use crate::proto::codec::{FrontendMessage, PostgresCodec};
|
||||
use crate::proto::copy_in::CopyInReceiver;
|
||||
use crate::proto::idle::IdleGuard;
|
||||
use crate::{AsyncMessage, Notification};
|
||||
use crate::{DbError, Error};
|
||||
|
||||
pub enum RequestMessages {
|
||||
Single(Vec<u8>),
|
||||
Single(FrontendMessage),
|
||||
CopyIn {
|
||||
receiver: CopyInReceiver,
|
||||
pending_message: Option<Vec<u8>>,
|
||||
pending_message: Option<FrontendMessage>,
|
||||
},
|
||||
}
|
||||
|
||||
@ -188,7 +188,7 @@ where
|
||||
self.state = State::Terminating;
|
||||
let mut request = vec![];
|
||||
frontend::terminate(&mut request);
|
||||
RequestMessages::Single(request)
|
||||
RequestMessages::Single(FrontendMessage::Raw(request))
|
||||
}
|
||||
Async::Ready(None) => {
|
||||
trace!(
|
||||
|
@ -1,20 +1,21 @@
|
||||
use bytes::{Buf, IntoBuf};
|
||||
use bytes::{Buf, BufMut, BytesMut, IntoBuf};
|
||||
use futures::sink;
|
||||
use futures::stream;
|
||||
use futures::sync::mpsc;
|
||||
use futures::{try_ready, Async, AsyncSink, Future, Poll, Sink, Stream};
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
use postgres_protocol::message::frontend::{self, CopyData};
|
||||
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
|
||||
use std::error::Error as StdError;
|
||||
use std::mem;
|
||||
|
||||
use crate::proto::client::{Client, PendingRequest};
|
||||
use crate::proto::codec::FrontendMessage;
|
||||
use crate::proto::statement::Statement;
|
||||
use crate::Error;
|
||||
|
||||
pub struct CopyMessage {
|
||||
pub data: Vec<u8>,
|
||||
pub done: bool,
|
||||
pub enum CopyMessage {
|
||||
Message(FrontendMessage),
|
||||
Done,
|
||||
}
|
||||
|
||||
pub struct CopyInReceiver {
|
||||
@ -32,30 +33,29 @@ impl CopyInReceiver {
|
||||
}
|
||||
|
||||
impl Stream for CopyInReceiver {
|
||||
type Item = Vec<u8>;
|
||||
type Item = FrontendMessage;
|
||||
type Error = ();
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Vec<u8>>, ()> {
|
||||
fn poll(&mut self) -> Poll<Option<FrontendMessage>, ()> {
|
||||
if self.done {
|
||||
return Ok(Async::Ready(None));
|
||||
}
|
||||
|
||||
match self.receiver.poll()? {
|
||||
Async::Ready(Some(mut data)) => {
|
||||
if data.done {
|
||||
self.done = true;
|
||||
frontend::copy_done(&mut data.data);
|
||||
frontend::sync(&mut data.data);
|
||||
}
|
||||
|
||||
Ok(Async::Ready(Some(data.data)))
|
||||
Async::Ready(Some(CopyMessage::Message(message))) => Ok(Async::Ready(Some(message))),
|
||||
Async::Ready(Some(CopyMessage::Done)) => {
|
||||
self.done = true;
|
||||
let mut buf = vec![];
|
||||
frontend::copy_done(&mut buf);
|
||||
frontend::sync(&mut buf);
|
||||
Ok(Async::Ready(Some(FrontendMessage::Raw(buf))))
|
||||
}
|
||||
Async::Ready(None) => {
|
||||
self.done = true;
|
||||
let mut buf = vec![];
|
||||
frontend::copy_fail("", &mut buf).unwrap();
|
||||
frontend::sync(&mut buf);
|
||||
Ok(Async::Ready(Some(buf)))
|
||||
Ok(Async::Ready(Some(FrontendMessage::Raw(buf))))
|
||||
}
|
||||
Async::NotReady => Ok(Async::NotReady),
|
||||
}
|
||||
@ -67,7 +67,7 @@ pub enum CopyIn<S>
|
||||
where
|
||||
S: Stream,
|
||||
S::Item: IntoBuf,
|
||||
<S::Item as IntoBuf>::Buf: Send,
|
||||
<S::Item as IntoBuf>::Buf: 'static + Send,
|
||||
S::Error: Into<Box<dyn StdError + Sync + Send>>,
|
||||
{
|
||||
#[state_machine_future(start, transitions(ReadCopyInResponse))]
|
||||
@ -86,8 +86,8 @@ where
|
||||
},
|
||||
#[state_machine_future(transitions(WriteCopyDone))]
|
||||
WriteCopyData {
|
||||
stream: S,
|
||||
buf: Vec<u8>,
|
||||
stream: stream::Fuse<S>,
|
||||
buf: BytesMut,
|
||||
pending_message: Option<CopyMessage>,
|
||||
sender: mpsc::Sender<CopyMessage>,
|
||||
receiver: mpsc::Receiver<Message>,
|
||||
@ -109,7 +109,7 @@ impl<S> PollCopyIn<S> for CopyIn<S>
|
||||
where
|
||||
S: Stream,
|
||||
S::Item: IntoBuf,
|
||||
<S::Item as IntoBuf>::Buf: Send,
|
||||
<S::Item as IntoBuf>::Buf: 'static + Send,
|
||||
S::Error: Into<Box<dyn StdError + Sync + Send>>,
|
||||
{
|
||||
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<S>>) -> Poll<AfterStart<S>, Error> {
|
||||
@ -135,8 +135,8 @@ where
|
||||
Some(Message::CopyInResponse(_)) => {
|
||||
let state = state.take();
|
||||
transition!(WriteCopyData {
|
||||
stream: state.stream,
|
||||
buf: vec![],
|
||||
stream: state.stream.fuse(),
|
||||
buf: BytesMut::new(),
|
||||
pending_message: None,
|
||||
sender: state.sender,
|
||||
receiver: state.receiver
|
||||
@ -167,44 +167,51 @@ where
|
||||
}
|
||||
|
||||
loop {
|
||||
let done = loop {
|
||||
let buf: Box<dyn Buf + Send> = loop {
|
||||
match try_ready!(state.stream.poll().map_err(Error::copy_in_stream)) {
|
||||
Some(data) => {
|
||||
// FIXME avoid collect
|
||||
frontend::copy_data(&data.into_buf().collect::<Vec<_>>(), &mut state.buf)
|
||||
.map_err(Error::encode)?;
|
||||
Some(buf) => {
|
||||
let buf = buf.into_buf();
|
||||
if buf.remaining() > 4096 {
|
||||
if state.buf.is_empty() {
|
||||
break Box::new(buf);
|
||||
} else {
|
||||
let cur_buf = state.buf.take().freeze().into_buf();
|
||||
break Box::new(cur_buf.chain(buf));
|
||||
}
|
||||
}
|
||||
|
||||
state.buf.reserve(buf.remaining());
|
||||
state.buf.put(buf);
|
||||
if state.buf.len() > 4096 {
|
||||
break false;
|
||||
break Box::new(state.buf.take().freeze().into_buf());
|
||||
}
|
||||
}
|
||||
None => break true,
|
||||
None => break Box::new(state.buf.take().freeze().into_buf()),
|
||||
}
|
||||
};
|
||||
|
||||
let message = CopyMessage {
|
||||
data: mem::replace(&mut state.buf, vec![]),
|
||||
done,
|
||||
};
|
||||
if buf.has_remaining() {
|
||||
let data = CopyData::new(buf).map_err(Error::encode)?;
|
||||
let message = CopyMessage::Message(FrontendMessage::CopyData(data));
|
||||
|
||||
if done {
|
||||
match state
|
||||
.sender
|
||||
.start_send(message)
|
||||
.map_err(|_| Error::closed())?
|
||||
{
|
||||
AsyncSink::Ready => {}
|
||||
AsyncSink::NotReady(message) => {
|
||||
state.pending_message = Some(message);
|
||||
return Ok(Async::NotReady);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let state = state.take();
|
||||
transition!(WriteCopyDone {
|
||||
future: state.sender.send(message),
|
||||
future: state.sender.send(CopyMessage::Done),
|
||||
receiver: state.receiver,
|
||||
});
|
||||
}
|
||||
|
||||
match state
|
||||
.sender
|
||||
.start_send(message)
|
||||
.map_err(|_| Error::closed())?
|
||||
{
|
||||
AsyncSink::Ready => {}
|
||||
AsyncSink::NotReady(message) => {
|
||||
state.pending_message = Some(message);
|
||||
return Ok(Async::NotReady);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -53,7 +53,7 @@ pub use crate::proto::bind::BindFuture;
|
||||
pub use crate::proto::cancel_query::CancelQueryFuture;
|
||||
pub use crate::proto::cancel_query_raw::CancelQueryRawFuture;
|
||||
pub use crate::proto::client::Client;
|
||||
pub use crate::proto::codec::PostgresCodec;
|
||||
pub use crate::proto::codec::{FrontendMessage, PostgresCodec};
|
||||
#[cfg(feature = "runtime")]
|
||||
pub use crate::proto::connect::ConnectFuture;
|
||||
#[cfg(feature = "runtime")]
|
||||
|
@ -4,6 +4,7 @@ use futures::sync::mpsc;
|
||||
use futures::{future, stream, try_ready};
|
||||
use log::debug;
|
||||
use std::error::Error;
|
||||
use std::fmt::Write;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::net::TcpStream;
|
||||
@ -616,6 +617,49 @@ fn copy_in() {
|
||||
assert_eq!(rows[1].get::<_, &str>(1), "joe");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn copy_in_large() {
|
||||
let _ = env_logger::try_init();
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
runtime.handle().spawn(connection).unwrap();
|
||||
|
||||
runtime
|
||||
.block_on(
|
||||
client
|
||||
.simple_query(
|
||||
"CREATE TEMPORARY TABLE foo (
|
||||
id INTEGER,
|
||||
name TEXT
|
||||
)",
|
||||
)
|
||||
.for_each(|_| Ok(())),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let a = "0\tname0\n".to_string();
|
||||
let mut b = String::new();
|
||||
for i in 1..5_000 {
|
||||
writeln!(b, "{0}\tname{0}", i).unwrap();
|
||||
}
|
||||
let mut c = String::new();
|
||||
for i in 5_000..10_000 {
|
||||
writeln!(c, "{0}\tname{0}", i).unwrap();
|
||||
}
|
||||
|
||||
let stream = stream::iter_ok::<_, String>(vec![a, b, c]);
|
||||
let rows = runtime
|
||||
.block_on(
|
||||
client
|
||||
.prepare("COPY foo FROM STDIN")
|
||||
.and_then(|s| client.copy_in(&s, &[], stream)),
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(rows, 10_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn copy_in_error() {
|
||||
let _ = env_logger::try_init();
|
||||
|
Loading…
Reference in New Issue
Block a user