Merge pull request #451 from sfackler/less-copy-copies

Avoid copies in copy_in
This commit is contained in:
Steven Fackler 2019-06-25 19:27:21 -07:00 committed by GitHub
commit eaef62c340
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 180 additions and 75 deletions

View File

@ -2,6 +2,8 @@
#![allow(missing_docs)]
use byteorder::{BigEndian, ByteOrder, WriteBytesExt};
use bytes::{Buf, BufMut, BytesMut, IntoBuf};
use std::convert::TryFrom;
use std::error::Error;
use std::io;
use std::marker;
@ -263,6 +265,40 @@ pub fn copy_data(data: &[u8], buf: &mut Vec<u8>) -> io::Result<()> {
})
}
/// A CopyData message under construction: a payload buffer plus its precomputed length.
///
/// Values are created with `CopyData::new` and serialized with `CopyData::write`.
pub struct CopyData<T> {
/// Payload source; `write` appends it after the tag byte and the length.
buf: T,
/// Payload byte count plus 4, computed and range-checked in `new`.
// NOTE(review): the extra 4 is presumably the size of the length field itself — confirm.
len: i32,
}
impl<T> CopyData<T>
where
    T: Buf,
{
    /// Wraps `buf` as the payload of a CopyData message.
    ///
    /// The recorded length is the number of bytes remaining in the buffer plus 4.
    ///
    /// # Errors
    ///
    /// Returns an `io::ErrorKind::InvalidInput` error when that length overflows
    /// `usize` or does not fit in an `i32`.
    pub fn new<U>(buf: U) -> io::Result<CopyData<T>>
    where
        U: IntoBuf<Buf = T>,
    {
        let buf = buf.into_buf();

        let len = match buf.remaining().checked_add(4).map(i32::try_from) {
            Some(Ok(len)) => len,
            Some(Err(_)) | None => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "message length overflow",
                ));
            }
        };

        Ok(CopyData { buf, len })
    }

    /// Appends the message to `out`: the tag byte `b'd'`, the length as a
    /// big-endian `i32`, then the payload bytes.
    pub fn write(self, out: &mut BytesMut) {
        let CopyData { buf, len } = self;

        // One extra byte for the tag, which the stored length does not count.
        out.reserve(len as usize + 1);
        out.put_u8(b'd');
        out.put_i32_be(len);
        out.put(buf);
    }
}
#[inline]
pub fn copy_done(buf: &mut Vec<u8>) {
buf.push(b'c');

View File

@ -170,7 +170,7 @@ pub struct CopyIn<S>(pub(crate) proto::CopyInFuture<S>)
where
S: Stream,
S::Item: IntoBuf,
<S::Item as IntoBuf>::Buf: Send,
<S::Item as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn error::Error + Sync + Send>>;
impl<S> Future for CopyIn<S>

View File

@ -242,7 +242,7 @@ impl Client {
where
S: Stream,
S::Item: IntoBuf,
<S::Item as IntoBuf>::Buf: Send,
<S::Item as IntoBuf>::Buf: 'static + Send,
// FIXME error type?
S::Error: Into<Box<dyn StdError + Sync + Send>>,
{

View File

@ -11,6 +11,7 @@ use std::sync::{Arc, Weak};
use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto::bind::BindFuture;
use crate::proto::codec::FrontendMessage;
use crate::proto::connection::{Request, RequestMessages};
use crate::proto::copy_in::{CopyInFuture, CopyInReceiver, CopyMessage};
use crate::proto::copy_out::CopyOutStream;
@ -185,8 +186,12 @@ impl Client {
if let Ok(ref mut buf) = buf {
frontend::sync(buf);
}
let pending =
PendingRequest(buf.map(|m| (RequestMessages::Single(m), self.0.idle.guard())));
let pending = PendingRequest(buf.map(|m| {
(
RequestMessages::Single(FrontendMessage::Raw(m)),
self.0.idle.guard(),
)
}));
BindFuture::new(self.clone(), pending, name, statement.clone())
}
@ -208,12 +213,12 @@ impl Client {
where
S: Stream,
S::Item: IntoBuf,
<S::Item as IntoBuf>::Buf: Send,
<S::Item as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn StdError + Sync + Send>>,
{
let (mut sender, receiver) = mpsc::channel(1);
let pending = PendingRequest(self.excecute_message(statement, params).map(|data| {
match sender.start_send(CopyMessage { data, done: false }) {
match sender.start_send(CopyMessage::Message(data)) {
Ok(AsyncSink::Ready) => {}
_ => unreachable!("channel should have capacity"),
}
@ -278,7 +283,7 @@ impl Client {
frontend::sync(&mut buf);
let (sender, _) = mpsc::channel(0);
let _ = self.0.sender.unbounded_send(Request {
messages: RequestMessages::Single(buf),
messages: RequestMessages::Single(FrontendMessage::Raw(buf)),
sender,
idle: None,
});
@ -326,11 +331,11 @@ impl Client {
&self,
statement: &Statement,
params: &[&dyn ToSql],
) -> Result<Vec<u8>, Error> {
) -> Result<FrontendMessage, Error> {
let mut buf = self.bind_message(statement, "", params)?;
frontend::execute("", 0, &mut buf).map_err(Error::parse)?;
frontend::sync(&mut buf);
Ok(buf)
Ok(FrontendMessage::Raw(buf))
}
fn pending<F>(&self, messages: F) -> PendingRequest
@ -338,8 +343,11 @@ impl Client {
F: FnOnce(&mut Vec<u8>) -> Result<(), Error>,
{
let mut buf = vec![];
PendingRequest(
messages(&mut buf).map(|()| (RequestMessages::Single(buf), self.0.idle.guard())),
)
PendingRequest(messages(&mut buf).map(|()| {
(
RequestMessages::Single(FrontendMessage::Raw(buf)),
self.0.idle.guard(),
)
}))
}
}

View File

@ -1,16 +1,26 @@
use bytes::BytesMut;
use bytes::{Buf, BytesMut};
use postgres_protocol::message::backend;
use postgres_protocol::message::frontend::CopyData;
use std::io;
use tokio_codec::{Decoder, Encoder};
/// A message that the `PostgresCodec` encoder can write to its output buffer.
pub enum FrontendMessage {
/// Already-serialized bytes; the encoder appends them to the output unchanged.
Raw(Vec<u8>),
/// A CopyData message over a boxed buffer; the encoder serializes it with `CopyData::write`.
CopyData(CopyData<Box<dyn Buf + Send>>),
}
/// Unit type that carries the `tokio_codec::Encoder` implementation found in this file.
pub struct PostgresCodec;
impl Encoder for PostgresCodec {
type Item = Vec<u8>;
type Item = FrontendMessage;
type Error = io::Error;
fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), io::Error> {
dst.extend_from_slice(&item);
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> Result<(), io::Error> {
match item {
FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
FrontendMessage::CopyData(data) => data.write(dst),
}
Ok(())
}
}

View File

@ -11,7 +11,7 @@ use std::collections::HashMap;
use tokio_codec::Framed;
use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto::{Client, Connection, MaybeTlsStream, PostgresCodec, TlsFuture};
use crate::proto::{Client, Connection, FrontendMessage, MaybeTlsStream, PostgresCodec, TlsFuture};
use crate::tls::ChannelBinding;
use crate::{Config, Error, TlsConnect};
@ -111,7 +111,7 @@ where
let stream = Framed::new(stream, PostgresCodec);
transition!(SendingStartup {
future: stream.send(buf),
future: stream.send(FrontendMessage::Raw(buf)),
config: state.config,
idx: state.idx,
channel_binding,
@ -156,7 +156,7 @@ where
let mut buf = vec![];
frontend::password_message(pass, &mut buf).map_err(Error::encode)?;
transition!(SendingPassword {
future: state.stream.send(buf),
future: state.stream.send(FrontendMessage::Raw(buf)),
config: state.config,
idx: state.idx,
})
@ -178,7 +178,7 @@ where
let mut buf = vec![];
frontend::password_message(output.as_bytes(), &mut buf).map_err(Error::encode)?;
transition!(SendingPassword {
future: state.stream.send(buf),
future: state.stream.send(FrontendMessage::Raw(buf)),
config: state.config,
idx: state.idx,
})
@ -235,7 +235,7 @@ where
.map_err(Error::encode)?;
transition!(SendingSasl {
future: state.stream.send(buf),
future: state.stream.send(FrontendMessage::Raw(buf)),
scram,
config: state.config,
idx: state.idx,
@ -293,7 +293,7 @@ where
let mut buf = vec![];
frontend::sasl_response(state.scram.message(), &mut buf).map_err(Error::encode)?;
transition!(SendingSasl {
future: state.stream.send(buf),
future: state.stream.send(FrontendMessage::Raw(buf)),
scram: state.scram,
config: state.config,
idx: state.idx,

View File

@ -8,17 +8,17 @@ use std::io;
use tokio_codec::Framed;
use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto::codec::PostgresCodec;
use crate::proto::codec::{FrontendMessage, PostgresCodec};
use crate::proto::copy_in::CopyInReceiver;
use crate::proto::idle::IdleGuard;
use crate::{AsyncMessage, Notification};
use crate::{DbError, Error};
pub enum RequestMessages {
Single(Vec<u8>),
Single(FrontendMessage),
CopyIn {
receiver: CopyInReceiver,
pending_message: Option<Vec<u8>>,
pending_message: Option<FrontendMessage>,
},
}
@ -188,7 +188,7 @@ where
self.state = State::Terminating;
let mut request = vec![];
frontend::terminate(&mut request);
RequestMessages::Single(request)
RequestMessages::Single(FrontendMessage::Raw(request))
}
Async::Ready(None) => {
trace!(

View File

@ -1,20 +1,21 @@
use bytes::{Buf, IntoBuf};
use bytes::{Buf, BufMut, BytesMut, IntoBuf};
use futures::sink;
use futures::stream;
use futures::sync::mpsc;
use futures::{try_ready, Async, AsyncSink, Future, Poll, Sink, Stream};
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use postgres_protocol::message::frontend::{self, CopyData};
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use std::error::Error as StdError;
use std::mem;
use crate::proto::client::{Client, PendingRequest};
use crate::proto::codec::FrontendMessage;
use crate::proto::statement::Statement;
use crate::Error;
pub struct CopyMessage {
pub data: Vec<u8>,
pub done: bool,
/// Item received and interpreted by `CopyInReceiver` when it is polled.
pub enum CopyMessage {
/// A frontend message that `CopyInReceiver` yields unchanged.
Message(FrontendMessage),
/// End-of-copy marker: `CopyInReceiver` answers it with one raw message built from
/// `frontend::copy_done` followed by `frontend::sync`, then ends its stream.
Done,
}
pub struct CopyInReceiver {
@ -32,30 +33,29 @@ impl CopyInReceiver {
}
impl Stream for CopyInReceiver {
type Item = Vec<u8>;
type Item = FrontendMessage;
type Error = ();
fn poll(&mut self) -> Poll<Option<Vec<u8>>, ()> {
fn poll(&mut self) -> Poll<Option<FrontendMessage>, ()> {
if self.done {
return Ok(Async::Ready(None));
}
match self.receiver.poll()? {
Async::Ready(Some(mut data)) => {
if data.done {
self.done = true;
frontend::copy_done(&mut data.data);
frontend::sync(&mut data.data);
}
Ok(Async::Ready(Some(data.data)))
Async::Ready(Some(CopyMessage::Message(message))) => Ok(Async::Ready(Some(message))),
Async::Ready(Some(CopyMessage::Done)) => {
self.done = true;
let mut buf = vec![];
frontend::copy_done(&mut buf);
frontend::sync(&mut buf);
Ok(Async::Ready(Some(FrontendMessage::Raw(buf))))
}
Async::Ready(None) => {
self.done = true;
let mut buf = vec![];
frontend::copy_fail("", &mut buf).unwrap();
frontend::sync(&mut buf);
Ok(Async::Ready(Some(buf)))
Ok(Async::Ready(Some(FrontendMessage::Raw(buf))))
}
Async::NotReady => Ok(Async::NotReady),
}
@ -67,7 +67,7 @@ pub enum CopyIn<S>
where
S: Stream,
S::Item: IntoBuf,
<S::Item as IntoBuf>::Buf: Send,
<S::Item as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn StdError + Sync + Send>>,
{
#[state_machine_future(start, transitions(ReadCopyInResponse))]
@ -86,8 +86,8 @@ where
},
#[state_machine_future(transitions(WriteCopyDone))]
WriteCopyData {
stream: S,
buf: Vec<u8>,
stream: stream::Fuse<S>,
buf: BytesMut,
pending_message: Option<CopyMessage>,
sender: mpsc::Sender<CopyMessage>,
receiver: mpsc::Receiver<Message>,
@ -109,7 +109,7 @@ impl<S> PollCopyIn<S> for CopyIn<S>
where
S: Stream,
S::Item: IntoBuf,
<S::Item as IntoBuf>::Buf: Send,
<S::Item as IntoBuf>::Buf: 'static + Send,
S::Error: Into<Box<dyn StdError + Sync + Send>>,
{
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<S>>) -> Poll<AfterStart<S>, Error> {
@ -135,8 +135,8 @@ where
Some(Message::CopyInResponse(_)) => {
let state = state.take();
transition!(WriteCopyData {
stream: state.stream,
buf: vec![],
stream: state.stream.fuse(),
buf: BytesMut::new(),
pending_message: None,
sender: state.sender,
receiver: state.receiver
@ -167,44 +167,51 @@ where
}
loop {
let done = loop {
let buf: Box<dyn Buf + Send> = loop {
match try_ready!(state.stream.poll().map_err(Error::copy_in_stream)) {
Some(data) => {
// FIXME avoid collect
frontend::copy_data(&data.into_buf().collect::<Vec<_>>(), &mut state.buf)
.map_err(Error::encode)?;
Some(buf) => {
let buf = buf.into_buf();
if buf.remaining() > 4096 {
if state.buf.is_empty() {
break Box::new(buf);
} else {
let cur_buf = state.buf.take().freeze().into_buf();
break Box::new(cur_buf.chain(buf));
}
}
state.buf.reserve(buf.remaining());
state.buf.put(buf);
if state.buf.len() > 4096 {
break false;
break Box::new(state.buf.take().freeze().into_buf());
}
}
None => break true,
None => break Box::new(state.buf.take().freeze().into_buf()),
}
};
let message = CopyMessage {
data: mem::replace(&mut state.buf, vec![]),
done,
};
if buf.has_remaining() {
let data = CopyData::new(buf).map_err(Error::encode)?;
let message = CopyMessage::Message(FrontendMessage::CopyData(data));
if done {
match state
.sender
.start_send(message)
.map_err(|_| Error::closed())?
{
AsyncSink::Ready => {}
AsyncSink::NotReady(message) => {
state.pending_message = Some(message);
return Ok(Async::NotReady);
}
}
} else {
let state = state.take();
transition!(WriteCopyDone {
future: state.sender.send(message),
future: state.sender.send(CopyMessage::Done),
receiver: state.receiver,
});
}
match state
.sender
.start_send(message)
.map_err(|_| Error::closed())?
{
AsyncSink::Ready => {}
AsyncSink::NotReady(message) => {
state.pending_message = Some(message);
return Ok(Async::NotReady);
}
}
}
}

View File

@ -53,7 +53,7 @@ pub use crate::proto::bind::BindFuture;
pub use crate::proto::cancel_query::CancelQueryFuture;
pub use crate::proto::cancel_query_raw::CancelQueryRawFuture;
pub use crate::proto::client::Client;
pub use crate::proto::codec::PostgresCodec;
pub use crate::proto::codec::{FrontendMessage, PostgresCodec};
#[cfg(feature = "runtime")]
pub use crate::proto::connect::ConnectFuture;
#[cfg(feature = "runtime")]

View File

@ -4,6 +4,7 @@ use futures::sync::mpsc;
use futures::{future, stream, try_ready};
use log::debug;
use std::error::Error;
use std::fmt::Write;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
@ -616,6 +617,49 @@ fn copy_in() {
assert_eq!(rows[1].get::<_, &str>(1), "joe");
}
/// Copies 10,000 rows into a temporary table through `copy_in`, feeding the data as
/// three chunks of very different sizes, and checks the reported row count.
#[test]
fn copy_in_large() {
    let _ = env_logger::try_init();
    let mut runtime = Runtime::new().unwrap();

    let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
    // Drive the connection in the background; a connection error aborts the test.
    runtime
        .handle()
        .spawn(connection.map_err(|e| panic!("{}", e)))
        .unwrap();

    let create_table = client
        .simple_query(
            "CREATE TEMPORARY TABLE foo (
id INTEGER,
name TEXT
)",
        )
        .for_each(|_| Ok(()));
    runtime.block_on(create_table).unwrap();

    // One single-row chunk followed by two chunks of roughly 5,000 rows each.
    let head = "0\tname0\n".to_string();
    let middle = (1..5_000).fold(String::new(), |mut text, i| {
        writeln!(text, "{0}\tname{0}", i).unwrap();
        text
    });
    let tail = (5_000..10_000).fold(String::new(), |mut text, i| {
        writeln!(text, "{0}\tname{0}", i).unwrap();
        text
    });
    let stream = stream::iter_ok::<_, String>(vec![head, middle, tail]);

    let copy = client
        .prepare("COPY foo FROM STDIN")
        .and_then(|s| client.copy_in(&s, &[], stream));
    let rows = runtime.block_on(copy).unwrap();

    assert_eq!(rows, 10_000);
}
#[test]
fn copy_in_error() {
let _ = env_logger::try_init();