Avoid copies in copy_in
copy_in data was previously copied about three times: once into the copy_in buffer, once more to frame it into a CopyData frame, and once to write that frame into the stream. Our Codec is now a bit more interesting: rather than just writing out pre-encoded data, we can also send along unencoded CopyData so that it can be framed directly into the stream output buffer. In the future we can extend this to, e.g., avoid allocating for simple commands like Sync. This also allows us to pass large copy_in input straight through without rebuffering it.
This commit is contained in:
parent
bcb4ca0713
commit
db462eb018
@ -2,6 +2,8 @@
|
||||
#![allow(missing_docs)]
|
||||
|
||||
use byteorder::{BigEndian, ByteOrder, WriteBytesExt};
|
||||
use bytes::{Buf, BufMut, BytesMut, IntoBuf};
|
||||
use std::convert::TryFrom;
|
||||
use std::error::Error;
|
||||
use std::io;
|
||||
use std::marker;
|
||||
@ -263,6 +265,40 @@ pub fn copy_data(data: &[u8], buf: &mut Vec<u8>) -> io::Result<()> {
|
||||
})
|
||||
}
|
||||
|
||||
pub struct CopyData<T> {
|
||||
buf: T,
|
||||
len: i32,
|
||||
}
|
||||
|
||||
impl<T> CopyData<T>
|
||||
where
|
||||
T: Buf,
|
||||
{
|
||||
pub fn new<U>(buf: U) -> io::Result<CopyData<T>>
|
||||
where
|
||||
U: IntoBuf<Buf = T>,
|
||||
{
|
||||
let buf = buf.into_buf();
|
||||
|
||||
let len = buf
|
||||
.remaining()
|
||||
.checked_add(4)
|
||||
.and_then(|l| i32::try_from(l).ok())
|
||||
.ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::InvalidInput, "message length overflow")
|
||||
})?;
|
||||
|
||||
Ok(CopyData { buf, len })
|
||||
}
|
||||
|
||||
pub fn write(self, out: &mut BytesMut) {
|
||||
out.reserve(self.len as usize + 1);
|
||||
out.put_u8(b'd');
|
||||
out.put_i32_be(self.len);
|
||||
out.put(self.buf);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn copy_done(buf: &mut Vec<u8>) {
|
||||
buf.push(b'c');
|
||||
|
@ -170,7 +170,7 @@ pub struct CopyIn<S>(pub(crate) proto::CopyInFuture<S>)
|
||||
where
|
||||
S: Stream,
|
||||
S::Item: IntoBuf,
|
||||
<S::Item as IntoBuf>::Buf: Send,
|
||||
<S::Item as IntoBuf>::Buf: 'static + Send,
|
||||
S::Error: Into<Box<dyn error::Error + Sync + Send>>;
|
||||
|
||||
impl<S> Future for CopyIn<S>
|
||||
|
@ -242,7 +242,7 @@ impl Client {
|
||||
where
|
||||
S: Stream,
|
||||
S::Item: IntoBuf,
|
||||
<S::Item as IntoBuf>::Buf: Send,
|
||||
<S::Item as IntoBuf>::Buf: 'static + Send,
|
||||
// FIXME error type?
|
||||
S::Error: Into<Box<dyn StdError + Sync + Send>>,
|
||||
{
|
||||
|
@ -11,6 +11,7 @@ use std::sync::{Arc, Weak};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::proto::bind::BindFuture;
|
||||
use crate::proto::codec::FrontendMessage;
|
||||
use crate::proto::connection::{Request, RequestMessages};
|
||||
use crate::proto::copy_in::{CopyInFuture, CopyInReceiver, CopyMessage};
|
||||
use crate::proto::copy_out::CopyOutStream;
|
||||
@ -185,8 +186,12 @@ impl Client {
|
||||
if let Ok(ref mut buf) = buf {
|
||||
frontend::sync(buf);
|
||||
}
|
||||
let pending =
|
||||
PendingRequest(buf.map(|m| (RequestMessages::Single(m), self.0.idle.guard())));
|
||||
let pending = PendingRequest(buf.map(|m| {
|
||||
(
|
||||
RequestMessages::Single(FrontendMessage::Raw(m)),
|
||||
self.0.idle.guard(),
|
||||
)
|
||||
}));
|
||||
BindFuture::new(self.clone(), pending, name, statement.clone())
|
||||
}
|
||||
|
||||
@ -208,12 +213,12 @@ impl Client {
|
||||
where
|
||||
S: Stream,
|
||||
S::Item: IntoBuf,
|
||||
<S::Item as IntoBuf>::Buf: Send,
|
||||
<S::Item as IntoBuf>::Buf: 'static + Send,
|
||||
S::Error: Into<Box<dyn StdError + Sync + Send>>,
|
||||
{
|
||||
let (mut sender, receiver) = mpsc::channel(1);
|
||||
let pending = PendingRequest(self.excecute_message(statement, params).map(|data| {
|
||||
match sender.start_send(CopyMessage { data, done: false }) {
|
||||
match sender.start_send(CopyMessage::Message(data)) {
|
||||
Ok(AsyncSink::Ready) => {}
|
||||
_ => unreachable!("channel should have capacity"),
|
||||
}
|
||||
@ -278,7 +283,7 @@ impl Client {
|
||||
frontend::sync(&mut buf);
|
||||
let (sender, _) = mpsc::channel(0);
|
||||
let _ = self.0.sender.unbounded_send(Request {
|
||||
messages: RequestMessages::Single(buf),
|
||||
messages: RequestMessages::Single(FrontendMessage::Raw(buf)),
|
||||
sender,
|
||||
idle: None,
|
||||
});
|
||||
@ -326,11 +331,11 @@ impl Client {
|
||||
&self,
|
||||
statement: &Statement,
|
||||
params: &[&dyn ToSql],
|
||||
) -> Result<Vec<u8>, Error> {
|
||||
) -> Result<FrontendMessage, Error> {
|
||||
let mut buf = self.bind_message(statement, "", params)?;
|
||||
frontend::execute("", 0, &mut buf).map_err(Error::parse)?;
|
||||
frontend::sync(&mut buf);
|
||||
Ok(buf)
|
||||
Ok(FrontendMessage::Raw(buf))
|
||||
}
|
||||
|
||||
fn pending<F>(&self, messages: F) -> PendingRequest
|
||||
@ -338,8 +343,11 @@ impl Client {
|
||||
F: FnOnce(&mut Vec<u8>) -> Result<(), Error>,
|
||||
{
|
||||
let mut buf = vec![];
|
||||
PendingRequest(
|
||||
messages(&mut buf).map(|()| (RequestMessages::Single(buf), self.0.idle.guard())),
|
||||
)
|
||||
PendingRequest(messages(&mut buf).map(|()| {
|
||||
(
|
||||
RequestMessages::Single(FrontendMessage::Raw(buf)),
|
||||
self.0.idle.guard(),
|
||||
)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
@ -1,16 +1,26 @@
|
||||
use bytes::BytesMut;
|
||||
use bytes::{Buf, BytesMut};
|
||||
use postgres_protocol::message::backend;
|
||||
use postgres_protocol::message::frontend::CopyData;
|
||||
use std::io;
|
||||
use tokio_codec::{Decoder, Encoder};
|
||||
|
||||
pub enum FrontendMessage {
|
||||
Raw(Vec<u8>),
|
||||
CopyData(CopyData<Box<dyn Buf + Send>>),
|
||||
}
|
||||
|
||||
pub struct PostgresCodec;
|
||||
|
||||
impl Encoder for PostgresCodec {
|
||||
type Item = Vec<u8>;
|
||||
type Item = FrontendMessage;
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), io::Error> {
|
||||
dst.extend_from_slice(&item);
|
||||
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
|
||||
match item {
|
||||
FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
|
||||
FrontendMessage::CopyData(data) => data.write(dst),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -11,7 +11,7 @@ use std::collections::HashMap;
|
||||
use tokio_codec::Framed;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::proto::{Client, Connection, MaybeTlsStream, PostgresCodec, TlsFuture};
|
||||
use crate::proto::{Client, Connection, MaybeTlsStream, PostgresCodec, TlsFuture, FrontendMessage};
|
||||
use crate::tls::ChannelBinding;
|
||||
use crate::{Config, Error, TlsConnect};
|
||||
|
||||
@ -111,7 +111,7 @@ where
|
||||
let stream = Framed::new(stream, PostgresCodec);
|
||||
|
||||
transition!(SendingStartup {
|
||||
future: stream.send(buf),
|
||||
future: stream.send(FrontendMessage::Raw(buf)),
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
channel_binding,
|
||||
@ -156,7 +156,7 @@ where
|
||||
let mut buf = vec![];
|
||||
frontend::password_message(pass, &mut buf).map_err(Error::encode)?;
|
||||
transition!(SendingPassword {
|
||||
future: state.stream.send(buf),
|
||||
future: state.stream.send(FrontendMessage::Raw(buf)),
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
})
|
||||
@ -178,7 +178,7 @@ where
|
||||
let mut buf = vec![];
|
||||
frontend::password_message(output.as_bytes(), &mut buf).map_err(Error::encode)?;
|
||||
transition!(SendingPassword {
|
||||
future: state.stream.send(buf),
|
||||
future: state.stream.send(FrontendMessage::Raw(buf)),
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
})
|
||||
@ -235,7 +235,7 @@ where
|
||||
.map_err(Error::encode)?;
|
||||
|
||||
transition!(SendingSasl {
|
||||
future: state.stream.send(buf),
|
||||
future: state.stream.send(FrontendMessage::Raw(buf)),
|
||||
scram,
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
@ -293,7 +293,7 @@ where
|
||||
let mut buf = vec![];
|
||||
frontend::sasl_response(state.scram.message(), &mut buf).map_err(Error::encode)?;
|
||||
transition!(SendingSasl {
|
||||
future: state.stream.send(buf),
|
||||
future: state.stream.send(FrontendMessage::Raw(buf)),
|
||||
scram: state.scram,
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
|
@ -8,17 +8,17 @@ use std::io;
|
||||
use tokio_codec::Framed;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::proto::codec::PostgresCodec;
|
||||
use crate::proto::codec::{FrontendMessage, PostgresCodec};
|
||||
use crate::proto::copy_in::CopyInReceiver;
|
||||
use crate::proto::idle::IdleGuard;
|
||||
use crate::{AsyncMessage, Notification};
|
||||
use crate::{DbError, Error};
|
||||
|
||||
pub enum RequestMessages {
|
||||
Single(Vec<u8>),
|
||||
Single(FrontendMessage),
|
||||
CopyIn {
|
||||
receiver: CopyInReceiver,
|
||||
pending_message: Option<Vec<u8>>,
|
||||
pending_message: Option<FrontendMessage>,
|
||||
},
|
||||
}
|
||||
|
||||
@ -188,7 +188,7 @@ where
|
||||
self.state = State::Terminating;
|
||||
let mut request = vec![];
|
||||
frontend::terminate(&mut request);
|
||||
RequestMessages::Single(request)
|
||||
RequestMessages::Single(FrontendMessage::Raw(request))
|
||||
}
|
||||
Async::Ready(None) => {
|
||||
trace!(
|
||||
|
@ -1,20 +1,21 @@
|
||||
use bytes::{Buf, IntoBuf};
|
||||
use bytes::{Buf, BufMut, BytesMut, IntoBuf};
|
||||
use futures::sink;
|
||||
use futures::stream;
|
||||
use futures::sync::mpsc;
|
||||
use futures::{try_ready, Async, AsyncSink, Future, Poll, Sink, Stream};
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
use postgres_protocol::message::frontend::{self, CopyData};
|
||||
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
|
||||
use std::error::Error as StdError;
|
||||
use std::mem;
|
||||
|
||||
use crate::proto::client::{Client, PendingRequest};
|
||||
use crate::proto::codec::FrontendMessage;
|
||||
use crate::proto::statement::Statement;
|
||||
use crate::Error;
|
||||
|
||||
pub struct CopyMessage {
|
||||
pub data: Vec<u8>,
|
||||
pub done: bool,
|
||||
pub enum CopyMessage {
|
||||
Message(FrontendMessage),
|
||||
Done,
|
||||
}
|
||||
|
||||
pub struct CopyInReceiver {
|
||||
@ -32,30 +33,29 @@ impl CopyInReceiver {
|
||||
}
|
||||
|
||||
impl Stream for CopyInReceiver {
|
||||
type Item = Vec<u8>;
|
||||
type Item = FrontendMessage;
|
||||
type Error = ();
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Vec<u8>>, ()> {
|
||||
fn poll(&mut self) -> Poll<Option<FrontendMessage>, ()> {
|
||||
if self.done {
|
||||
return Ok(Async::Ready(None));
|
||||
}
|
||||
|
||||
match self.receiver.poll()? {
|
||||
Async::Ready(Some(mut data)) => {
|
||||
if data.done {
|
||||
self.done = true;
|
||||
frontend::copy_done(&mut data.data);
|
||||
frontend::sync(&mut data.data);
|
||||
}
|
||||
|
||||
Ok(Async::Ready(Some(data.data)))
|
||||
Async::Ready(Some(CopyMessage::Message(message))) => Ok(Async::Ready(Some(message))),
|
||||
Async::Ready(Some(CopyMessage::Done)) => {
|
||||
self.done = true;
|
||||
let mut buf = vec![];
|
||||
frontend::copy_done(&mut buf);
|
||||
frontend::sync(&mut buf);
|
||||
Ok(Async::Ready(Some(FrontendMessage::Raw(buf))))
|
||||
}
|
||||
Async::Ready(None) => {
|
||||
self.done = true;
|
||||
let mut buf = vec![];
|
||||
frontend::copy_fail("", &mut buf).unwrap();
|
||||
frontend::sync(&mut buf);
|
||||
Ok(Async::Ready(Some(buf)))
|
||||
Ok(Async::Ready(Some(FrontendMessage::Raw(buf))))
|
||||
}
|
||||
Async::NotReady => Ok(Async::NotReady),
|
||||
}
|
||||
@ -67,7 +67,7 @@ pub enum CopyIn<S>
|
||||
where
|
||||
S: Stream,
|
||||
S::Item: IntoBuf,
|
||||
<S::Item as IntoBuf>::Buf: Send,
|
||||
<S::Item as IntoBuf>::Buf: 'static + Send,
|
||||
S::Error: Into<Box<dyn StdError + Sync + Send>>,
|
||||
{
|
||||
#[state_machine_future(start, transitions(ReadCopyInResponse))]
|
||||
@ -86,8 +86,8 @@ where
|
||||
},
|
||||
#[state_machine_future(transitions(WriteCopyDone))]
|
||||
WriteCopyData {
|
||||
stream: S,
|
||||
buf: Vec<u8>,
|
||||
stream: stream::Fuse<S>,
|
||||
buf: BytesMut,
|
||||
pending_message: Option<CopyMessage>,
|
||||
sender: mpsc::Sender<CopyMessage>,
|
||||
receiver: mpsc::Receiver<Message>,
|
||||
@ -109,7 +109,7 @@ impl<S> PollCopyIn<S> for CopyIn<S>
|
||||
where
|
||||
S: Stream,
|
||||
S::Item: IntoBuf,
|
||||
<S::Item as IntoBuf>::Buf: Send,
|
||||
<S::Item as IntoBuf>::Buf: 'static + Send,
|
||||
S::Error: Into<Box<dyn StdError + Sync + Send>>,
|
||||
{
|
||||
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<S>>) -> Poll<AfterStart<S>, Error> {
|
||||
@ -135,8 +135,8 @@ where
|
||||
Some(Message::CopyInResponse(_)) => {
|
||||
let state = state.take();
|
||||
transition!(WriteCopyData {
|
||||
stream: state.stream,
|
||||
buf: vec![],
|
||||
stream: state.stream.fuse(),
|
||||
buf: BytesMut::new(),
|
||||
pending_message: None,
|
||||
sender: state.sender,
|
||||
receiver: state.receiver
|
||||
@ -167,44 +167,51 @@ where
|
||||
}
|
||||
|
||||
loop {
|
||||
let done = loop {
|
||||
let buf: Box<dyn Buf + Send> = loop {
|
||||
match try_ready!(state.stream.poll().map_err(Error::copy_in_stream)) {
|
||||
Some(data) => {
|
||||
// FIXME avoid collect
|
||||
frontend::copy_data(&data.into_buf().collect::<Vec<_>>(), &mut state.buf)
|
||||
.map_err(Error::encode)?;
|
||||
Some(buf) => {
|
||||
let buf = buf.into_buf();
|
||||
if buf.remaining() > 4096 {
|
||||
if state.buf.is_empty() {
|
||||
break Box::new(buf);
|
||||
} else {
|
||||
let cur_buf = state.buf.take().freeze().into_buf();
|
||||
break Box::new(cur_buf.chain(buf));
|
||||
}
|
||||
}
|
||||
|
||||
state.buf.reserve(buf.remaining());
|
||||
state.buf.put(buf);
|
||||
if state.buf.len() > 4096 {
|
||||
break false;
|
||||
break Box::new(state.buf.take().freeze().into_buf());
|
||||
}
|
||||
}
|
||||
None => break true,
|
||||
None => break Box::new(state.buf.take().freeze().into_buf()),
|
||||
}
|
||||
};
|
||||
|
||||
let message = CopyMessage {
|
||||
data: mem::replace(&mut state.buf, vec![]),
|
||||
done,
|
||||
};
|
||||
if buf.has_remaining() {
|
||||
let data = CopyData::new(buf).map_err(Error::encode)?;
|
||||
let message = CopyMessage::Message(FrontendMessage::CopyData(data));
|
||||
|
||||
if done {
|
||||
match state
|
||||
.sender
|
||||
.start_send(message)
|
||||
.map_err(|_| Error::closed())?
|
||||
{
|
||||
AsyncSink::Ready => {}
|
||||
AsyncSink::NotReady(message) => {
|
||||
state.pending_message = Some(message);
|
||||
return Ok(Async::NotReady);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let state = state.take();
|
||||
transition!(WriteCopyDone {
|
||||
future: state.sender.send(message),
|
||||
future: state.sender.send(CopyMessage::Done),
|
||||
receiver: state.receiver,
|
||||
});
|
||||
}
|
||||
|
||||
match state
|
||||
.sender
|
||||
.start_send(message)
|
||||
.map_err(|_| Error::closed())?
|
||||
{
|
||||
AsyncSink::Ready => {}
|
||||
AsyncSink::NotReady(message) => {
|
||||
state.pending_message = Some(message);
|
||||
return Ok(Async::NotReady);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -53,7 +53,7 @@ pub use crate::proto::bind::BindFuture;
|
||||
pub use crate::proto::cancel_query::CancelQueryFuture;
|
||||
pub use crate::proto::cancel_query_raw::CancelQueryRawFuture;
|
||||
pub use crate::proto::client::Client;
|
||||
pub use crate::proto::codec::PostgresCodec;
|
||||
pub use crate::proto::codec::{FrontendMessage, PostgresCodec};
|
||||
#[cfg(feature = "runtime")]
|
||||
pub use crate::proto::connect::ConnectFuture;
|
||||
#[cfg(feature = "runtime")]
|
||||
|
@ -4,6 +4,7 @@ use futures::sync::mpsc;
|
||||
use futures::{future, stream, try_ready};
|
||||
use log::debug;
|
||||
use std::error::Error;
|
||||
use std::fmt::Write;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::net::TcpStream;
|
||||
@ -616,6 +617,49 @@ fn copy_in() {
|
||||
assert_eq!(rows[1].get::<_, &str>(1), "joe");
|
||||
}
|
||||
|
||||
/// Copies 10,000 rows into a temporary table through `copy_in`, feeding the
/// data as a three-item stream: one tiny chunk followed by two chunks of
/// several tens of kilobytes each.
#[test]
fn copy_in_large() {
    let _ = env_logger::try_init();
    let mut runtime = Runtime::new().unwrap();

    let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
    // Drive the connection in the background; any connection error fails the test.
    runtime
        .handle()
        .spawn(connection.map_err(|e| panic!("{}", e)))
        .unwrap();

    let create_table = client
        .simple_query(
            "CREATE TEMPORARY TABLE foo (
id INTEGER,
name TEXT
)",
        )
        .for_each(|_| Ok(()));
    runtime.block_on(create_table).unwrap();

    // Renders one tab-separated `id\tname<id>` line per id in the range.
    let render = |ids: std::ops::Range<i32>| {
        ids.fold(String::new(), |mut text, id| {
            writeln!(text, "{0}\tname{0}", id).unwrap();
            text
        })
    };

    let chunks = vec![
        String::from("0\tname0\n"),
        render(1..5_000),
        render(5_000..10_000),
    ];
    let stream = stream::iter_ok::<_, String>(chunks);

    let copy = client
        .prepare("COPY foo FROM STDIN")
        .and_then(|s| client.copy_in(&s, &[], stream));
    let rows = runtime.block_on(copy).unwrap();

    assert_eq!(rows, 10_000);
}
|
||||
|
||||
#[test]
|
||||
fn copy_in_error() {
|
||||
let _ = env_logger::try_init();
|
||||
|
Loading…
Reference in New Issue
Block a user