From db462eb0180df8561a2a8e62c7d4e2d8b1efde28 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Tue, 25 Jun 2019 18:54:17 -0700 Subject: [PATCH] Avoid copies in copy_in copy_in data was previously copied ~3 times - once into the copy_in buffer, once more to frame it into a CopyData frame, and once to write that into the stream. Our Codec is now a bit more interesting. Rather than just writing out pre-encoded data, we can also send along unencoded CopyData so they can be framed directly into the stream output buffer. In the future we can extend this to e.g. avoid allocating for simple commands like Sync. This also allows us to directly pass large copy_in input directly through without rebuffering it. --- postgres-protocol/src/message/frontend.rs | 36 ++++++++ tokio-postgres/src/impls.rs | 2 +- tokio-postgres/src/lib.rs | 2 +- tokio-postgres/src/proto/client.rs | 28 +++--- tokio-postgres/src/proto/codec.rs | 18 +++- tokio-postgres/src/proto/connect_raw.rs | 12 +-- tokio-postgres/src/proto/connection.rs | 8 +- tokio-postgres/src/proto/copy_in.rs | 103 ++++++++++++---------- tokio-postgres/src/proto/mod.rs | 2 +- tokio-postgres/tests/test/main.rs | 44 +++++++++ 10 files changed, 180 insertions(+), 75 deletions(-) diff --git a/postgres-protocol/src/message/frontend.rs b/postgres-protocol/src/message/frontend.rs index 0ff0ddb6..6daed7d9 100644 --- a/postgres-protocol/src/message/frontend.rs +++ b/postgres-protocol/src/message/frontend.rs @@ -2,6 +2,8 @@ #![allow(missing_docs)] use byteorder::{BigEndian, ByteOrder, WriteBytesExt}; +use bytes::{Buf, BufMut, BytesMut, IntoBuf}; +use std::convert::TryFrom; use std::error::Error; use std::io; use std::marker; @@ -263,6 +265,40 @@ pub fn copy_data(data: &[u8], buf: &mut Vec) -> io::Result<()> { }) } +pub struct CopyData { + buf: T, + len: i32, +} + +impl CopyData +where + T: Buf, +{ + pub fn new(buf: U) -> io::Result> + where + U: IntoBuf, + { + let buf = buf.into_buf(); + + let len = buf + .remaining() + .checked_add(4) + .and_then(|l| i32::try_from(l).ok()) + .ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "message length overflow") + })?; + + Ok(CopyData { buf, len }) + } + + pub fn write(self, out: &mut BytesMut) { + out.reserve(self.len as usize + 1); + out.put_u8(b'd'); + out.put_i32_be(self.len); + out.put(self.buf); + } +} + #[inline] pub fn copy_done(buf: &mut Vec) { buf.push(b'c'); diff --git a/tokio-postgres/src/impls.rs b/tokio-postgres/src/impls.rs index 546c3d7d..520a4aff 100644 --- a/tokio-postgres/src/impls.rs +++ b/tokio-postgres/src/impls.rs @@ -170,7 +170,7 @@ pub struct CopyIn(pub(crate) proto::CopyInFuture) where S: Stream, S::Item: IntoBuf, - ::Buf: Send, + ::Buf: 'static + Send, S::Error: Into>; impl Future for CopyIn diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 060e0c5d..78c65331 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -242,7 +242,7 @@ impl Client { where S: Stream, S::Item: IntoBuf, - ::Buf: Send, + ::Buf: 'static + Send, // FIXME error type? S::Error: Into>, { diff --git a/tokio-postgres/src/proto/client.rs b/tokio-postgres/src/proto/client.rs index a328e90d..abb2f534 100644 --- a/tokio-postgres/src/proto/client.rs +++ b/tokio-postgres/src/proto/client.rs @@ -11,6 +11,7 @@ use std::sync::{Arc, Weak}; use tokio_io::{AsyncRead, AsyncWrite}; use crate::proto::bind::BindFuture; +use crate::proto::codec::FrontendMessage; use crate::proto::connection::{Request, RequestMessages}; use crate::proto::copy_in::{CopyInFuture, CopyInReceiver, CopyMessage}; use crate::proto::copy_out::CopyOutStream; @@ -185,8 +186,12 @@ impl Client { if let Ok(ref mut buf) = buf { frontend::sync(buf); } - let pending = - PendingRequest(buf.map(|m| (RequestMessages::Single(m), self.0.idle.guard()))); + let pending = PendingRequest(buf.map(|m| { + ( + RequestMessages::Single(FrontendMessage::Raw(m)), + self.0.idle.guard(), + ) + })); BindFuture::new(self.clone(), pending, name, statement.clone()) } @@ -208,12 +213,12 @@ impl Client { where S: Stream, S::Item: IntoBuf, - ::Buf: Send, + ::Buf: 'static + Send, S::Error: Into>, { let (mut sender, receiver) = mpsc::channel(1); let pending = PendingRequest(self.excecute_message(statement, params).map(|data| { - match sender.start_send(CopyMessage { data, done: false }) { + match sender.start_send(CopyMessage::Message(data)) { Ok(AsyncSink::Ready) => {} _ => unreachable!("channel should have capacity"), } @@ -278,7 +283,7 @@ impl Client { frontend::sync(&mut buf); let (sender, _) = mpsc::channel(0); let _ = self.0.sender.unbounded_send(Request { - messages: RequestMessages::Single(buf), + messages: RequestMessages::Single(FrontendMessage::Raw(buf)), sender, idle: None, }); @@ -326,11 +331,11 @@ impl Client { &self, statement: &Statement, params: &[&dyn ToSql], - ) -> Result, Error> { + ) -> Result { let mut buf = self.bind_message(statement, "", params)?; frontend::execute("", 0, &mut buf).map_err(Error::parse)?; frontend::sync(&mut buf); - Ok(buf) + Ok(FrontendMessage::Raw(buf)) } fn pending(&self, messages: F) -> PendingRequest @@ -338,8 +343,11 @@ impl Client { F: FnOnce(&mut Vec) -> Result<(), Error>, { let mut buf = vec![]; - PendingRequest( - messages(&mut buf).map(|()| (RequestMessages::Single(buf), self.0.idle.guard())), - ) + PendingRequest(messages(&mut buf).map(|()| { + ( + RequestMessages::Single(FrontendMessage::Raw(buf)), + self.0.idle.guard(), + ) + })) } } diff --git a/tokio-postgres/src/proto/codec.rs b/tokio-postgres/src/proto/codec.rs index 4e37ab60..c7c6d904 100644 --- a/tokio-postgres/src/proto/codec.rs +++ b/tokio-postgres/src/proto/codec.rs @@ -1,16 +1,26 @@ -use bytes::BytesMut; +use bytes::{Buf, BytesMut}; use postgres_protocol::message::backend; +use postgres_protocol::message::frontend::CopyData; use std::io; use tokio_codec::{Decoder, Encoder}; +pub enum FrontendMessage { + Raw(Vec), + CopyData(CopyData>), +} + pub struct PostgresCodec; impl Encoder for PostgresCodec { - type Item = Vec; + type Item = FrontendMessage; type Error = io::Error; - fn encode(&mut self, item: Vec, dst: &mut BytesMut) -> Result<(), io::Error> { - dst.extend_from_slice(&item); + fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> Result<(), io::Error> { + match item { + FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf), + FrontendMessage::CopyData(data) => data.write(dst), + } + Ok(()) } } diff --git a/tokio-postgres/src/proto/connect_raw.rs b/tokio-postgres/src/proto/connect_raw.rs index 45b2c3e5..85368cd9 100644 --- a/tokio-postgres/src/proto/connect_raw.rs +++ b/tokio-postgres/src/proto/connect_raw.rs @@ -11,7 +11,7 @@ use std::collections::HashMap; use tokio_codec::Framed; use tokio_io::{AsyncRead, AsyncWrite}; -use crate::proto::{Client, Connection, MaybeTlsStream, PostgresCodec, TlsFuture}; +use crate::proto::{Client, Connection, MaybeTlsStream, PostgresCodec, TlsFuture, FrontendMessage}; use crate::tls::ChannelBinding; use crate::{Config, Error, TlsConnect}; @@ -111,7 +111,7 @@ where let stream = Framed::new(stream, PostgresCodec); transition!(SendingStartup { - future: stream.send(buf), + future: stream.send(FrontendMessage::Raw(buf)), config: state.config, idx: state.idx, channel_binding, @@ -156,7 +156,7 @@ where let mut buf = vec![]; frontend::password_message(pass, &mut buf).map_err(Error::encode)?; transition!(SendingPassword { - future: state.stream.send(buf), + future: state.stream.send(FrontendMessage::Raw(buf)), config: state.config, idx: state.idx, }) @@ -178,7 +178,7 @@ where let mut buf = vec![]; frontend::password_message(output.as_bytes(), &mut buf).map_err(Error::encode)?; transition!(SendingPassword { - future: state.stream.send(buf), + future: state.stream.send(FrontendMessage::Raw(buf)), config: state.config, idx: state.idx, }) @@ -235,7 +235,7 @@ where .map_err(Error::encode)?; transition!(SendingSasl { - future: state.stream.send(buf), + future: state.stream.send(FrontendMessage::Raw(buf)), scram, config: state.config, idx: state.idx, @@ -293,7 +293,7 @@ where let mut buf = vec![]; frontend::sasl_response(state.scram.message(), &mut buf).map_err(Error::encode)?; transition!(SendingSasl { - future: state.stream.send(buf), + future: state.stream.send(FrontendMessage::Raw(buf)), scram: state.scram, config: state.config, idx: state.idx, diff --git a/tokio-postgres/src/proto/connection.rs b/tokio-postgres/src/proto/connection.rs index dd9f30fe..14559fd0 100644 --- a/tokio-postgres/src/proto/connection.rs +++ b/tokio-postgres/src/proto/connection.rs @@ -8,17 +8,17 @@ use std::io; use tokio_codec::Framed; use tokio_io::{AsyncRead, AsyncWrite}; -use crate::proto::codec::PostgresCodec; +use crate::proto::codec::{FrontendMessage, PostgresCodec}; use crate::proto::copy_in::CopyInReceiver; use crate::proto::idle::IdleGuard; use crate::{AsyncMessage, Notification}; use crate::{DbError, Error}; pub enum RequestMessages { - Single(Vec), + Single(FrontendMessage), CopyIn { receiver: CopyInReceiver, - pending_message: Option>, + pending_message: Option, }, } @@ -188,7 +188,7 @@ where self.state = State::Terminating; let mut request = vec![]; frontend::terminate(&mut request); - RequestMessages::Single(request) + RequestMessages::Single(FrontendMessage::Raw(request)) } Async::Ready(None) => { trace!( diff --git a/tokio-postgres/src/proto/copy_in.rs b/tokio-postgres/src/proto/copy_in.rs index 33018aa1..80f09c52 100644 --- a/tokio-postgres/src/proto/copy_in.rs +++ b/tokio-postgres/src/proto/copy_in.rs @@ -1,20 +1,21 @@ -use bytes::{Buf, IntoBuf}; +use bytes::{Buf, BufMut, BytesMut, IntoBuf}; use futures::sink; +use futures::stream; use futures::sync::mpsc; use futures::{try_ready, Async, AsyncSink, Future, Poll, Sink, Stream}; use postgres_protocol::message::backend::Message; -use postgres_protocol::message::frontend; +use postgres_protocol::message::frontend::{self, CopyData}; use state_machine_future::{transition, RentToOwn, StateMachineFuture}; use std::error::Error as StdError; -use std::mem; use crate::proto::client::{Client, PendingRequest}; +use crate::proto::codec::FrontendMessage; use crate::proto::statement::Statement; use crate::Error; -pub struct CopyMessage { - pub data: Vec, - pub done: bool, +pub enum CopyMessage { + Message(FrontendMessage), + Done, } pub struct CopyInReceiver { @@ -32,30 +33,29 @@ impl CopyInReceiver { } impl Stream for CopyInReceiver { - type Item = Vec; + type Item = FrontendMessage; type Error = (); - fn poll(&mut self) -> Poll>, ()> { + fn poll(&mut self) -> Poll, ()> { if self.done { return Ok(Async::Ready(None)); } match self.receiver.poll()? { - Async::Ready(Some(mut data)) => { - if data.done { - self.done = true; - frontend::copy_done(&mut data.data); - frontend::sync(&mut data.data); - } - - Ok(Async::Ready(Some(data.data))) + Async::Ready(Some(CopyMessage::Message(message))) => Ok(Async::Ready(Some(message))), + Async::Ready(Some(CopyMessage::Done)) => { + self.done = true; + let mut buf = vec![]; + frontend::copy_done(&mut buf); + frontend::sync(&mut buf); + Ok(Async::Ready(Some(FrontendMessage::Raw(buf)))) } Async::Ready(None) => { self.done = true; let mut buf = vec![]; frontend::copy_fail("", &mut buf).unwrap(); frontend::sync(&mut buf); - Ok(Async::Ready(Some(buf))) + Ok(Async::Ready(Some(FrontendMessage::Raw(buf)))) } Async::NotReady => Ok(Async::NotReady), } @@ -67,7 +67,7 @@ pub enum CopyIn where S: Stream, S::Item: IntoBuf, - ::Buf: Send, + ::Buf: 'static + Send, S::Error: Into>, { #[state_machine_future(start, transitions(ReadCopyInResponse))] @@ -86,8 +86,8 @@ where }, #[state_machine_future(transitions(WriteCopyDone))] WriteCopyData { - stream: S, - buf: Vec, + stream: stream::Fuse, + buf: BytesMut, pending_message: Option, sender: mpsc::Sender, receiver: mpsc::Receiver, @@ -109,7 +109,7 @@ impl PollCopyIn for CopyIn where S: Stream, S::Item: IntoBuf, - ::Buf: Send, + ::Buf: 'static + Send, S::Error: Into>, { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll, Error> { @@ -135,8 +135,8 @@ where Some(Message::CopyInResponse(_)) => { let state = state.take(); transition!(WriteCopyData { - stream: state.stream, - buf: vec![], + stream: state.stream.fuse(), + buf: BytesMut::new(), pending_message: None, sender: state.sender, receiver: state.receiver @@ -167,44 +167,51 @@ where } loop { - let done = loop { + let buf: Box = loop { match try_ready!(state.stream.poll().map_err(Error::copy_in_stream)) { - Some(data) => { - // FIXME avoid collect - frontend::copy_data(&data.into_buf().collect::>(), &mut state.buf) - .map_err(Error::encode)?; + Some(buf) => { + let buf = buf.into_buf(); + if buf.remaining() > 4096 { + if state.buf.is_empty() { + break Box::new(buf); + } else { + let cur_buf = state.buf.take().freeze().into_buf(); + break Box::new(cur_buf.chain(buf)); + } + } + + state.buf.reserve(buf.remaining()); + state.buf.put(buf); if state.buf.len() > 4096 { - break false; + break Box::new(state.buf.take().freeze().into_buf()); } } - None => break true, + None => break Box::new(state.buf.take().freeze().into_buf()), } }; - let message = CopyMessage { - data: mem::replace(&mut state.buf, vec![]), - done, - }; + if buf.has_remaining() { + let data = CopyData::new(buf).map_err(Error::encode)?; + let message = CopyMessage::Message(FrontendMessage::CopyData(data)); - if done { + match state + .sender + .start_send(message) + .map_err(|_| Error::closed())? + { + AsyncSink::Ready => {} + AsyncSink::NotReady(message) => { + state.pending_message = Some(message); + return Ok(Async::NotReady); + } + } + } else { let state = state.take(); transition!(WriteCopyDone { - future: state.sender.send(message), + future: state.sender.send(CopyMessage::Done), receiver: state.receiver, }); } - - match state - .sender - .start_send(message) - .map_err(|_| Error::closed())? - { - AsyncSink::Ready => {} - AsyncSink::NotReady(message) => { - state.pending_message = Some(message); - return Ok(Async::NotReady); - } - } } } diff --git a/tokio-postgres/src/proto/mod.rs b/tokio-postgres/src/proto/mod.rs index 7667901c..7c30cda8 100644 --- a/tokio-postgres/src/proto/mod.rs +++ b/tokio-postgres/src/proto/mod.rs @@ -53,7 +53,7 @@ pub use crate::proto::bind::BindFuture; pub use crate::proto::cancel_query::CancelQueryFuture; pub use crate::proto::cancel_query_raw::CancelQueryRawFuture; pub use crate::proto::client::Client; -pub use crate::proto::codec::PostgresCodec; +pub use crate::proto::codec::{FrontendMessage, PostgresCodec}; #[cfg(feature = "runtime")] pub use crate::proto::connect::ConnectFuture; #[cfg(feature = "runtime")] diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 7930b1b9..06a6d738 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -4,6 +4,7 @@ use futures::sync::mpsc; use futures::{future, stream, try_ready}; use log::debug; use std::error::Error; +use std::fmt::Write; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::{Duration, Instant}; use tokio::net::TcpStream; @@ -616,6 +617,49 @@ fn copy_in() { assert_eq!(rows[1].get::<_, &str>(1), "joe"); } +#[test] +fn copy_in_large() { + let _ = env_logger::try_init(); + let mut runtime = Runtime::new().unwrap(); + + let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap(); + let connection = connection.map_err(|e| panic!("{}", e)); + runtime.handle().spawn(connection).unwrap(); + + runtime + .block_on( + client + .simple_query( + "CREATE TEMPORARY TABLE foo ( + id INTEGER, + name TEXT + )", + ) + .for_each(|_| Ok(())), + ) + .unwrap(); + + let a = "0\tname0\n".to_string(); + let mut b = String::new(); + for i in 1..5_000 { + writeln!(b, "{0}\tname{0}", i).unwrap(); + } + let mut c = String::new(); + for i in 5_000..10_000 { + writeln!(c, "{0}\tname{0}", i).unwrap(); + } + + let stream = stream::iter_ok::<_, String>(vec![a, b, c]); + let rows = runtime + .block_on( + client + .prepare("COPY foo FROM STDIN") + .and_then(|s| client.copy_in(&s, &[], stream)), + ) + .unwrap(); + assert_eq!(rows, 10_000); +} + #[test] fn copy_in_error() { let _ = env_logger::try_init();