Avoid copies in copy_in

copy_in data was previously copied ~3 times - once into the copy_in
buffer, once more to frame it into a CopyData frame, and once to write
that into the stream.

Our Codec is now a bit more interesting. Rather than just writing out
pre-encoded data, we can also send along unencoded CopyData so they can
be framed directly into the stream output buffer. In the future we can
extend this to e.g. avoid allocating for simple commands like Sync.

This also allows us to pass large copy_in input directly through without
rebuffering it.
This commit is contained in:
Steven Fackler 2019-06-25 18:54:17 -07:00
parent bcb4ca0713
commit db462eb018
10 changed files with 180 additions and 75 deletions

View File

@ -2,6 +2,8 @@
#![allow(missing_docs)]
use byteorder::{BigEndian, ByteOrder, WriteBytesExt};
use bytes::{Buf, BufMut, BytesMut, IntoBuf};
use std::convert::TryFrom;
use std::error::Error;
use std::io;
use std::marker;
@ -263,6 +265,40 @@ pub fn copy_data(data: &[u8], buf: &mut Vec<u8>) -> io::Result<()> {
})
}
/// A CopyData frontend message whose payload is kept as a buffer value
/// instead of being encoded into a byte vector up front.
pub struct CopyData<T> {
// The payload that `write` appends after the message header.
buf: T,
// Value of the message's length field: the payload size plus the 4 bytes
// of the length field itself (computed in `new`).
len: i32,
}
/// Construction and serialization of a `CopyData` message over any `Buf`.
impl<T> CopyData<T>
where
T: Buf,
{
/// Wraps `buf` as the payload of a CopyData message.
///
/// The length field is computed once here so that `write` cannot fail.
///
/// # Errors
///
/// Returns an `io::ErrorKind::InvalidInput` error ("message length
/// overflow") when the payload's remaining byte count plus 4 overflows
/// `usize` or does not fit in an `i32`.
pub fn new<U>(buf: U) -> io::Result<CopyData<T>>
where
U: IntoBuf<Buf = T>,
{
let buf = buf.into_buf();
// The length field counts itself (4 bytes) plus the payload; both the
// addition and the narrowing to `i32` are checked.
let len = buf
.remaining()
.checked_add(4)
.and_then(|l| i32::try_from(l).ok())
.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "message length overflow")
})?;
Ok(CopyData { buf, len })
}
/// Appends the framed message to `out`: the tag byte `b'd'`, the length
/// as a big-endian `i32`, then the payload bytes. Consumes `self`.
pub fn write(self, out: &mut BytesMut) {
// Reserve the whole frame at once: `len` bytes (length field + payload)
// plus 1 for the tag byte.
out.reserve(self.len as usize + 1);
out.put_u8(b'd');
out.put_i32_be(self.len);
out.put(self.buf);
}
}
#[inline]
pub fn copy_done(buf: &mut Vec<u8>) {
buf.push(b'c');

View File

@ -170,7 +170,7 @@ pub struct CopyIn<S>(pub(crate) proto::CopyInFuture<S>)
where
S: Stream,
S::Item: IntoBuf,
<S::Item as IntoBuf>::Buf: Send,
<S::Item as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn error::Error + Sync + Send>>;
impl<S> Future for CopyIn<S>

View File

@ -242,7 +242,7 @@ impl Client {
where
S: Stream,
S::Item: IntoBuf,
<S::Item as IntoBuf>::Buf: Send,
<S::Item as IntoBuf>::Buf: 'static + Send,
// FIXME error type?
S::Error: Into<Box<dyn StdError + Sync + Send>>,
{

View File

@ -11,6 +11,7 @@ use std::sync::{Arc, Weak};
use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto::bind::BindFuture;
use crate::proto::codec::FrontendMessage;
use crate::proto::connection::{Request, RequestMessages};
use crate::proto::copy_in::{CopyInFuture, CopyInReceiver, CopyMessage};
use crate::proto::copy_out::CopyOutStream;
@ -185,8 +186,12 @@ impl Client {
if let Ok(ref mut buf) = buf {
frontend::sync(buf);
}
let pending =
PendingRequest(buf.map(|m| (RequestMessages::Single(m), self.0.idle.guard())));
let pending = PendingRequest(buf.map(|m| {
(
RequestMessages::Single(FrontendMessage::Raw(m)),
self.0.idle.guard(),
)
}));
BindFuture::new(self.clone(), pending, name, statement.clone())
}
@ -208,12 +213,12 @@ impl Client {
where
S: Stream,
S::Item: IntoBuf,
<S::Item as IntoBuf>::Buf: Send,
<S::Item as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn StdError + Sync + Send>>,
{
let (mut sender, receiver) = mpsc::channel(1);
let pending = PendingRequest(self.excecute_message(statement, params).map(|data| {
match sender.start_send(CopyMessage { data, done: false }) {
match sender.start_send(CopyMessage::Message(data)) {
Ok(AsyncSink::Ready) => {}
_ => unreachable!("channel should have capacity"),
}
@ -278,7 +283,7 @@ impl Client {
frontend::sync(&mut buf);
let (sender, _) = mpsc::channel(0);
let _ = self.0.sender.unbounded_send(Request {
messages: RequestMessages::Single(buf),
messages: RequestMessages::Single(FrontendMessage::Raw(buf)),
sender,
idle: None,
});
@ -326,11 +331,11 @@ impl Client {
&self,
statement: &Statement,
params: &[&dyn ToSql],
) -> Result<Vec<u8>, Error> {
) -> Result<FrontendMessage, Error> {
let mut buf = self.bind_message(statement, "", params)?;
frontend::execute("", 0, &mut buf).map_err(Error::parse)?;
frontend::sync(&mut buf);
Ok(buf)
Ok(FrontendMessage::Raw(buf))
}
fn pending<F>(&self, messages: F) -> PendingRequest
@ -338,8 +343,11 @@ impl Client {
F: FnOnce(&mut Vec<u8>) -> Result<(), Error>,
{
let mut buf = vec![];
PendingRequest(
messages(&mut buf).map(|()| (RequestMessages::Single(buf), self.0.idle.guard())),
)
PendingRequest(messages(&mut buf).map(|()| {
(
RequestMessages::Single(FrontendMessage::Raw(buf)),
self.0.idle.guard(),
)
}))
}
}

View File

@ -1,16 +1,26 @@
use bytes::BytesMut;
use bytes::{Buf, BytesMut};
use postgres_protocol::message::backend;
use postgres_protocol::message::frontend::CopyData;
use std::io;
use tokio_codec::{Decoder, Encoder};
/// A message handed to `PostgresCodec` for encoding.
pub enum FrontendMessage {
/// Bytes that the encoder copies into the output buffer unchanged.
Raw(Vec<u8>),
/// A CopyData message that the encoder writes into the output buffer
/// via `CopyData::write`; the payload is a boxed `Buf`.
CopyData(CopyData<Box<dyn Buf + Send>>),
}
/// Stateless codec type; its `Encoder` implementation writes frontend
/// messages into the destination `BytesMut`.
pub struct PostgresCodec;
impl Encoder for PostgresCodec {
type Item = Vec<u8>;
type Item = FrontendMessage;
type Error = io::Error;
fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), io::Error> {
dst.extend_from_slice(&item);
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
match item {
FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
FrontendMessage::CopyData(data) => data.write(dst),
}
Ok(())
}
}

View File

@ -11,7 +11,7 @@ use std::collections::HashMap;
use tokio_codec::Framed;
use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto::{Client, Connection, MaybeTlsStream, PostgresCodec, TlsFuture};
use crate::proto::{Client, Connection, MaybeTlsStream, PostgresCodec, TlsFuture, FrontendMessage};
use crate::tls::ChannelBinding;
use crate::{Config, Error, TlsConnect};
@ -111,7 +111,7 @@ where
let stream = Framed::new(stream, PostgresCodec);
transition!(SendingStartup {
future: stream.send(buf),
future: stream.send(FrontendMessage::Raw(buf)),
config: state.config,
idx: state.idx,
channel_binding,
@ -156,7 +156,7 @@ where
let mut buf = vec![];
frontend::password_message(pass, &mut buf).map_err(Error::encode)?;
transition!(SendingPassword {
future: state.stream.send(buf),
future: state.stream.send(FrontendMessage::Raw(buf)),
config: state.config,
idx: state.idx,
})
@ -178,7 +178,7 @@ where
let mut buf = vec![];
frontend::password_message(output.as_bytes(), &mut buf).map_err(Error::encode)?;
transition!(SendingPassword {
future: state.stream.send(buf),
future: state.stream.send(FrontendMessage::Raw(buf)),
config: state.config,
idx: state.idx,
})
@ -235,7 +235,7 @@ where
.map_err(Error::encode)?;
transition!(SendingSasl {
future: state.stream.send(buf),
future: state.stream.send(FrontendMessage::Raw(buf)),
scram,
config: state.config,
idx: state.idx,
@ -293,7 +293,7 @@ where
let mut buf = vec![];
frontend::sasl_response(state.scram.message(), &mut buf).map_err(Error::encode)?;
transition!(SendingSasl {
future: state.stream.send(buf),
future: state.stream.send(FrontendMessage::Raw(buf)),
scram: state.scram,
config: state.config,
idx: state.idx,

View File

@ -8,17 +8,17 @@ use std::io;
use tokio_codec::Framed;
use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto::codec::PostgresCodec;
use crate::proto::codec::{FrontendMessage, PostgresCodec};
use crate::proto::copy_in::CopyInReceiver;
use crate::proto::idle::IdleGuard;
use crate::{AsyncMessage, Notification};
use crate::{DbError, Error};
pub enum RequestMessages {
Single(Vec<u8>),
Single(FrontendMessage),
CopyIn {
receiver: CopyInReceiver,
pending_message: Option<Vec<u8>>,
pending_message: Option<FrontendMessage>,
},
}
@ -188,7 +188,7 @@ where
self.state = State::Terminating;
let mut request = vec![];
frontend::terminate(&mut request);
RequestMessages::Single(request)
RequestMessages::Single(FrontendMessage::Raw(request))
}
Async::Ready(None) => {
trace!(

View File

@ -1,20 +1,21 @@
use bytes::{Buf, IntoBuf};
use bytes::{Buf, BufMut, BytesMut, IntoBuf};
use futures::sink;
use futures::stream;
use futures::sync::mpsc;
use futures::{try_ready, Async, AsyncSink, Future, Poll, Sink, Stream};
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use postgres_protocol::message::frontend::{self, CopyData};
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use std::error::Error as StdError;
use std::mem;
use crate::proto::client::{Client, PendingRequest};
use crate::proto::codec::FrontendMessage;
use crate::proto::statement::Statement;
use crate::Error;
pub struct CopyMessage {
pub data: Vec<u8>,
pub done: bool,
pub enum CopyMessage {
Message(FrontendMessage),
Done,
}
pub struct CopyInReceiver {
@ -32,30 +33,29 @@ impl CopyInReceiver {
}
impl Stream for CopyInReceiver {
type Item = Vec<u8>;
type Item = FrontendMessage;
type Error = ();
fn poll(&mut self) -> Poll<Option<Vec<u8>>, ()> {
fn poll(&mut self) -> Poll<Option<FrontendMessage>, ()> {
if self.done {
return Ok(Async::Ready(None));
}
match self.receiver.poll()? {
Async::Ready(Some(mut data)) => {
if data.done {
self.done = true;
frontend::copy_done(&mut data.data);
frontend::sync(&mut data.data);
}
Ok(Async::Ready(Some(data.data)))
Async::Ready(Some(CopyMessage::Message(message))) => Ok(Async::Ready(Some(message))),
Async::Ready(Some(CopyMessage::Done)) => {
self.done = true;
let mut buf = vec![];
frontend::copy_done(&mut buf);
frontend::sync(&mut buf);
Ok(Async::Ready(Some(FrontendMessage::Raw(buf))))
}
Async::Ready(None) => {
self.done = true;
let mut buf = vec![];
frontend::copy_fail("", &mut buf).unwrap();
frontend::sync(&mut buf);
Ok(Async::Ready(Some(buf)))
Ok(Async::Ready(Some(FrontendMessage::Raw(buf))))
}
Async::NotReady => Ok(Async::NotReady),
}
@ -67,7 +67,7 @@ pub enum CopyIn<S>
where
S: Stream,
S::Item: IntoBuf,
<S::Item as IntoBuf>::Buf: Send,
<S::Item as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn StdError + Sync + Send>>,
{
#[state_machine_future(start, transitions(ReadCopyInResponse))]
@ -86,8 +86,8 @@ where
},
#[state_machine_future(transitions(WriteCopyDone))]
WriteCopyData {
stream: S,
buf: Vec<u8>,
stream: stream::Fuse<S>,
buf: BytesMut,
pending_message: Option<CopyMessage>,
sender: mpsc::Sender<CopyMessage>,
receiver: mpsc::Receiver<Message>,
@ -109,7 +109,7 @@ impl<S> PollCopyIn<S> for CopyIn<S>
where
S: Stream,
S::Item: IntoBuf,
<S::Item as IntoBuf>::Buf: Send,
<S::Item as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn StdError + Sync + Send>>,
{
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<S>>) -> Poll<AfterStart<S>, Error> {
@ -135,8 +135,8 @@ where
Some(Message::CopyInResponse(_)) => {
let state = state.take();
transition!(WriteCopyData {
stream: state.stream,
buf: vec![],
stream: state.stream.fuse(),
buf: BytesMut::new(),
pending_message: None,
sender: state.sender,
receiver: state.receiver
@ -167,44 +167,51 @@ where
}
loop {
let done = loop {
let buf: Box<dyn Buf + Send> = loop {
match try_ready!(state.stream.poll().map_err(Error::copy_in_stream)) {
Some(data) => {
// FIXME avoid collect
frontend::copy_data(&data.into_buf().collect::<Vec<_>>(), &mut state.buf)
.map_err(Error::encode)?;
Some(buf) => {
let buf = buf.into_buf();
if buf.remaining() > 4096 {
if state.buf.is_empty() {
break Box::new(buf);
} else {
let cur_buf = state.buf.take().freeze().into_buf();
break Box::new(cur_buf.chain(buf));
}
}
state.buf.reserve(buf.remaining());
state.buf.put(buf);
if state.buf.len() > 4096 {
break false;
break Box::new(state.buf.take().freeze().into_buf());
}
}
None => break true,
None => break Box::new(state.buf.take().freeze().into_buf()),
}
};
let message = CopyMessage {
data: mem::replace(&mut state.buf, vec![]),
done,
};
if buf.has_remaining() {
let data = CopyData::new(buf).map_err(Error::encode)?;
let message = CopyMessage::Message(FrontendMessage::CopyData(data));
if done {
match state
.sender
.start_send(message)
.map_err(|_| Error::closed())?
{
AsyncSink::Ready => {}
AsyncSink::NotReady(message) => {
state.pending_message = Some(message);
return Ok(Async::NotReady);
}
}
} else {
let state = state.take();
transition!(WriteCopyDone {
future: state.sender.send(message),
future: state.sender.send(CopyMessage::Done),
receiver: state.receiver,
});
}
match state
.sender
.start_send(message)
.map_err(|_| Error::closed())?
{
AsyncSink::Ready => {}
AsyncSink::NotReady(message) => {
state.pending_message = Some(message);
return Ok(Async::NotReady);
}
}
}
}

View File

@ -53,7 +53,7 @@ pub use crate::proto::bind::BindFuture;
pub use crate::proto::cancel_query::CancelQueryFuture;
pub use crate::proto::cancel_query_raw::CancelQueryRawFuture;
pub use crate::proto::client::Client;
pub use crate::proto::codec::PostgresCodec;
pub use crate::proto::codec::{FrontendMessage, PostgresCodec};
#[cfg(feature = "runtime")]
pub use crate::proto::connect::ConnectFuture;
#[cfg(feature = "runtime")]

View File

@ -4,6 +4,7 @@ use futures::sync::mpsc;
use futures::{future, stream, try_ready};
use log::debug;
use std::error::Error;
use std::fmt::Write;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
@ -616,6 +617,49 @@ fn copy_in() {
assert_eq!(rows[1].get::<_, &str>(1), "joe");
}
// Exercises `copy_in` with a stream of three chunks of very different sizes
// and checks that every row is reported as copied.
#[test]
fn copy_in_large() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
// Drive the connection on the runtime; any connection error fails the test.
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
runtime
.block_on(
client
.simple_query(
"CREATE TEMPORARY TABLE foo (
id INTEGER,
name TEXT
)",
)
.for_each(|_| Ok(())),
)
.unwrap();
// `a` is a single short row (id 0).
let a = "0\tname0\n".to_string();
// `b` holds rows 1..=4_999 in one large chunk.
let mut b = String::new();
for i in 1..5_000 {
writeln!(b, "{0}\tname{0}", i).unwrap();
}
// `c` holds rows 5_000..=9_999 in a second large chunk.
let mut c = String::new();
for i in 5_000..10_000 {
writeln!(c, "{0}\tname{0}", i).unwrap();
}
let stream = stream::iter_ok::<_, String>(vec![a, b, c]);
let rows = runtime
.block_on(
client
.prepare("COPY foo FROM STDIN")
.and_then(|s| client.copy_in(&s, &[], stream)),
)
.unwrap();
// 1 + 4_999 + 5_000 rows were supplied in total.
assert_eq!(rows, 10_000);
}
#[test]
fn copy_in_error() {
let _ = env_logger::try_init();