diff --git a/postgres-protocol/src/message/frontend.rs b/postgres-protocol/src/message/frontend.rs index 0ff0ddb6..6daed7d9 100644 --- a/postgres-protocol/src/message/frontend.rs +++ b/postgres-protocol/src/message/frontend.rs @@ -2,6 +2,8 @@ #![allow(missing_docs)] use byteorder::{BigEndian, ByteOrder, WriteBytesExt}; +use bytes::{Buf, BufMut, BytesMut, IntoBuf}; +use std::convert::TryFrom; use std::error::Error; use std::io; use std::marker; @@ -263,6 +265,40 @@ pub fn copy_data(data: &[u8], buf: &mut Vec) -> io::Result<()> { }) } +pub struct CopyData { + buf: T, + len: i32, +} + +impl CopyData +where + T: Buf, +{ + pub fn new(buf: U) -> io::Result> + where + U: IntoBuf, + { + let buf = buf.into_buf(); + + let len = buf + .remaining() + .checked_add(4) + .and_then(|l| i32::try_from(l).ok()) + .ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "message length overflow") + })?; + + Ok(CopyData { buf, len }) + } + + pub fn write(self, out: &mut BytesMut) { + out.reserve(self.len as usize + 1); + out.put_u8(b'd'); + out.put_i32_be(self.len); + out.put(self.buf); + } +} + #[inline] pub fn copy_done(buf: &mut Vec) { buf.push(b'c'); diff --git a/tokio-postgres/src/impls.rs b/tokio-postgres/src/impls.rs index 546c3d7d..520a4aff 100644 --- a/tokio-postgres/src/impls.rs +++ b/tokio-postgres/src/impls.rs @@ -170,7 +170,7 @@ pub struct CopyIn(pub(crate) proto::CopyInFuture) where S: Stream, S::Item: IntoBuf, - ::Buf: Send, + ::Buf: 'static + Send, S::Error: Into>; impl Future for CopyIn diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 060e0c5d..78c65331 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -242,7 +242,7 @@ impl Client { where S: Stream, S::Item: IntoBuf, - ::Buf: Send, + ::Buf: 'static + Send, // FIXME error type? S::Error: Into>, { diff --git a/tokio-postgres/src/proto/client.rs b/tokio-postgres/src/proto/client.rs index a328e90d..abb2f534 100644 --- a/tokio-postgres/src/proto/client.rs +++ b/tokio-postgres/src/proto/client.rs @@ -11,6 +11,7 @@ use std::sync::{Arc, Weak}; use tokio_io::{AsyncRead, AsyncWrite}; use crate::proto::bind::BindFuture; +use crate::proto::codec::FrontendMessage; use crate::proto::connection::{Request, RequestMessages}; use crate::proto::copy_in::{CopyInFuture, CopyInReceiver, CopyMessage}; use crate::proto::copy_out::CopyOutStream; @@ -185,8 +186,12 @@ impl Client { if let Ok(ref mut buf) = buf { frontend::sync(buf); } - let pending = - PendingRequest(buf.map(|m| (RequestMessages::Single(m), self.0.idle.guard()))); + let pending = PendingRequest(buf.map(|m| { + ( + RequestMessages::Single(FrontendMessage::Raw(m)), + self.0.idle.guard(), + ) + })); BindFuture::new(self.clone(), pending, name, statement.clone()) } @@ -208,12 +213,12 @@ impl Client { where S: Stream, S::Item: IntoBuf, - ::Buf: Send, + ::Buf: 'static + Send, S::Error: Into>, { let (mut sender, receiver) = mpsc::channel(1); let pending = PendingRequest(self.excecute_message(statement, params).map(|data| { - match sender.start_send(CopyMessage { data, done: false }) { + match sender.start_send(CopyMessage::Message(data)) { Ok(AsyncSink::Ready) => {} _ => unreachable!("channel should have capacity"), } @@ -278,7 +283,7 @@ impl Client { frontend::sync(&mut buf); let (sender, _) = mpsc::channel(0); let _ = self.0.sender.unbounded_send(Request { - messages: RequestMessages::Single(buf), + messages: RequestMessages::Single(FrontendMessage::Raw(buf)), sender, idle: None, }); @@ -326,11 +331,11 @@ impl Client { &self, statement: &Statement, params: &[&dyn ToSql], - ) -> Result, Error> { + ) -> Result { let mut buf = self.bind_message(statement, "", params)?; frontend::execute("", 0, &mut buf).map_err(Error::parse)?; frontend::sync(&mut buf); - Ok(buf) + Ok(FrontendMessage::Raw(buf)) } fn pending(&self, messages: F) -> PendingRequest @@ -338,8 +343,11 @@ impl Client { F: FnOnce(&mut Vec) -> Result<(), Error>, { let mut buf = vec![]; - PendingRequest( - messages(&mut buf).map(|()| (RequestMessages::Single(buf), self.0.idle.guard())), - ) + PendingRequest(messages(&mut buf).map(|()| { + ( + RequestMessages::Single(FrontendMessage::Raw(buf)), + self.0.idle.guard(), + ) + })) } } diff --git a/tokio-postgres/src/proto/codec.rs b/tokio-postgres/src/proto/codec.rs index 4e37ab60..c7c6d904 100644 --- a/tokio-postgres/src/proto/codec.rs +++ b/tokio-postgres/src/proto/codec.rs @@ -1,16 +1,26 @@ -use bytes::BytesMut; +use bytes::{Buf, BytesMut}; use postgres_protocol::message::backend; +use postgres_protocol::message::frontend::CopyData; use std::io; use tokio_codec::{Decoder, Encoder}; +pub enum FrontendMessage { + Raw(Vec), + CopyData(CopyData>), +} + pub struct PostgresCodec; impl Encoder for PostgresCodec { - type Item = Vec; + type Item = FrontendMessage; type Error = io::Error; - fn encode(&mut self, item: Vec, dst: &mut BytesMut) -> Result<(), io::Error> { - dst.extend_from_slice(&item); + fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> Result<(), io::Error> { + match item { + FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf), + FrontendMessage::CopyData(data) => data.write(dst), + } + Ok(()) } } diff --git a/tokio-postgres/src/proto/connect_raw.rs b/tokio-postgres/src/proto/connect_raw.rs index 45b2c3e5..85368cd9 100644 --- a/tokio-postgres/src/proto/connect_raw.rs +++ b/tokio-postgres/src/proto/connect_raw.rs @@ -11,7 +11,7 @@ use std::collections::HashMap; use tokio_codec::Framed; use tokio_io::{AsyncRead, AsyncWrite}; -use crate::proto::{Client, Connection, MaybeTlsStream, PostgresCodec, TlsFuture}; +use crate::proto::{Client, Connection, MaybeTlsStream, PostgresCodec, TlsFuture, FrontendMessage}; use crate::tls::ChannelBinding; use crate::{Config, Error, TlsConnect}; @@ -111,7 +111,7 @@ where let stream = Framed::new(stream, PostgresCodec); transition!(SendingStartup { - future: stream.send(buf), + future: stream.send(FrontendMessage::Raw(buf)), config: state.config, idx: state.idx, channel_binding, @@ -156,7 +156,7 @@ where let mut buf = vec![]; frontend::password_message(pass, &mut buf).map_err(Error::encode)?; transition!(SendingPassword { - future: state.stream.send(buf), + future: state.stream.send(FrontendMessage::Raw(buf)), config: state.config, idx: state.idx, }) @@ -178,7 +178,7 @@ where let mut buf = vec![]; frontend::password_message(output.as_bytes(), &mut buf).map_err(Error::encode)?; transition!(SendingPassword { - future: state.stream.send(buf), + future: state.stream.send(FrontendMessage::Raw(buf)), config: state.config, idx: state.idx, }) @@ -235,7 +235,7 @@ where .map_err(Error::encode)?; transition!(SendingSasl { - future: state.stream.send(buf), + future: state.stream.send(FrontendMessage::Raw(buf)), scram, config: state.config, idx: state.idx, @@ -293,7 +293,7 @@ where let mut buf = vec![]; frontend::sasl_response(state.scram.message(), &mut buf).map_err(Error::encode)?; transition!(SendingSasl { - future: state.stream.send(buf), + future: state.stream.send(FrontendMessage::Raw(buf)), scram: state.scram, config: state.config, idx: state.idx, diff --git a/tokio-postgres/src/proto/connection.rs b/tokio-postgres/src/proto/connection.rs index dd9f30fe..14559fd0 100644 --- a/tokio-postgres/src/proto/connection.rs +++ b/tokio-postgres/src/proto/connection.rs @@ -8,17 +8,17 @@ use std::io; use tokio_codec::Framed; use tokio_io::{AsyncRead, AsyncWrite}; -use crate::proto::codec::PostgresCodec; +use crate::proto::codec::{FrontendMessage, PostgresCodec}; use crate::proto::copy_in::CopyInReceiver; use crate::proto::idle::IdleGuard; use crate::{AsyncMessage, Notification}; use crate::{DbError, Error}; pub enum RequestMessages { - Single(Vec), + Single(FrontendMessage), CopyIn { receiver: CopyInReceiver, - pending_message: Option>, + pending_message: Option, }, } @@ -188,7 +188,7 @@ where self.state = State::Terminating; let mut request = vec![]; frontend::terminate(&mut request); - RequestMessages::Single(request) + RequestMessages::Single(FrontendMessage::Raw(request)) } Async::Ready(None) => { trace!( diff --git a/tokio-postgres/src/proto/copy_in.rs b/tokio-postgres/src/proto/copy_in.rs index 33018aa1..80f09c52 100644 --- a/tokio-postgres/src/proto/copy_in.rs +++ b/tokio-postgres/src/proto/copy_in.rs @@ -1,20 +1,21 @@ -use bytes::{Buf, IntoBuf}; +use bytes::{Buf, BufMut, BytesMut, IntoBuf}; use futures::sink; +use futures::stream; use futures::sync::mpsc; use futures::{try_ready, Async, AsyncSink, Future, Poll, Sink, Stream}; use postgres_protocol::message::backend::Message; -use postgres_protocol::message::frontend; +use postgres_protocol::message::frontend::{self, CopyData}; use state_machine_future::{transition, RentToOwn, StateMachineFuture}; use std::error::Error as StdError; -use std::mem; use crate::proto::client::{Client, PendingRequest}; +use crate::proto::codec::FrontendMessage; use crate::proto::statement::Statement; use crate::Error; -pub struct CopyMessage { - pub data: Vec, - pub done: bool, +pub enum CopyMessage { + Message(FrontendMessage), + Done, } pub struct CopyInReceiver { @@ -32,30 +33,29 @@ impl CopyInReceiver { } impl Stream for CopyInReceiver { - type Item = Vec; + type Item = FrontendMessage; type Error = (); - fn poll(&mut self) -> Poll>, ()> { + fn poll(&mut self) -> Poll, ()> { if self.done { return Ok(Async::Ready(None)); } match self.receiver.poll()? { - Async::Ready(Some(mut data)) => { - if data.done { - self.done = true; - frontend::copy_done(&mut data.data); - frontend::sync(&mut data.data); - } - - Ok(Async::Ready(Some(data.data))) + Async::Ready(Some(CopyMessage::Message(message))) => Ok(Async::Ready(Some(message))), + Async::Ready(Some(CopyMessage::Done)) => { + self.done = true; + let mut buf = vec![]; + frontend::copy_done(&mut buf); + frontend::sync(&mut buf); + Ok(Async::Ready(Some(FrontendMessage::Raw(buf)))) } Async::Ready(None) => { self.done = true; let mut buf = vec![]; frontend::copy_fail("", &mut buf).unwrap(); frontend::sync(&mut buf); - Ok(Async::Ready(Some(buf))) + Ok(Async::Ready(Some(FrontendMessage::Raw(buf)))) } Async::NotReady => Ok(Async::NotReady), } @@ -67,7 +67,7 @@ pub enum CopyIn where S: Stream, S::Item: IntoBuf, - ::Buf: Send, + ::Buf: 'static + Send, S::Error: Into>, { #[state_machine_future(start, transitions(ReadCopyInResponse))] @@ -86,8 +86,8 @@ where }, #[state_machine_future(transitions(WriteCopyDone))] WriteCopyData { - stream: S, - buf: Vec, + stream: stream::Fuse, + buf: BytesMut, pending_message: Option, sender: mpsc::Sender, receiver: mpsc::Receiver, @@ -109,7 +109,7 @@ impl PollCopyIn for CopyIn where S: Stream, S::Item: IntoBuf, - ::Buf: Send, + ::Buf: 'static + Send, S::Error: Into>, { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll, Error> { @@ -135,8 +135,8 @@ where Some(Message::CopyInResponse(_)) => { let state = state.take(); transition!(WriteCopyData { - stream: state.stream, - buf: vec![], + stream: state.stream.fuse(), + buf: BytesMut::new(), pending_message: None, sender: state.sender, receiver: state.receiver @@ -167,44 +167,51 @@ where } loop { - let done = loop { + let buf: Box = loop { match try_ready!(state.stream.poll().map_err(Error::copy_in_stream)) { - Some(data) => { - // FIXME avoid collect - frontend::copy_data(&data.into_buf().collect::>(), &mut state.buf) - .map_err(Error::encode)?; + Some(buf) => { + let buf = buf.into_buf(); + if buf.remaining() > 4096 { + if state.buf.is_empty() { + break Box::new(buf); + } else { + let cur_buf = state.buf.take().freeze().into_buf(); + break Box::new(cur_buf.chain(buf)); + } + } + + state.buf.reserve(buf.remaining()); + state.buf.put(buf); if state.buf.len() > 4096 { - break false; + break Box::new(state.buf.take().freeze().into_buf()); } } - None => break true, + None => break Box::new(state.buf.take().freeze().into_buf()), } }; - let message = CopyMessage { - data: mem::replace(&mut state.buf, vec![]), - done, - }; + if buf.has_remaining() { + let data = CopyData::new(buf).map_err(Error::encode)?; + let message = CopyMessage::Message(FrontendMessage::CopyData(data)); - if done { + match state + .sender + .start_send(message) + .map_err(|_| Error::closed())? + { + AsyncSink::Ready => {} + AsyncSink::NotReady(message) => { + state.pending_message = Some(message); + return Ok(Async::NotReady); + } + } + } else { let state = state.take(); transition!(WriteCopyDone { - future: state.sender.send(message), + future: state.sender.send(CopyMessage::Done), receiver: state.receiver, }); } - - match state - .sender - .start_send(message) - .map_err(|_| Error::closed())? - { - AsyncSink::Ready => {} - AsyncSink::NotReady(message) => { - state.pending_message = Some(message); - return Ok(Async::NotReady); - } - } } } diff --git a/tokio-postgres/src/proto/mod.rs b/tokio-postgres/src/proto/mod.rs index 7667901c..7c30cda8 100644 --- a/tokio-postgres/src/proto/mod.rs +++ b/tokio-postgres/src/proto/mod.rs @@ -53,7 +53,7 @@ pub use crate::proto::bind::BindFuture; pub use crate::proto::cancel_query::CancelQueryFuture; pub use crate::proto::cancel_query_raw::CancelQueryRawFuture; pub use crate::proto::client::Client; -pub use crate::proto::codec::PostgresCodec; +pub use crate::proto::codec::{FrontendMessage, PostgresCodec}; #[cfg(feature = "runtime")] pub use crate::proto::connect::ConnectFuture; #[cfg(feature = "runtime")] diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 7930b1b9..06a6d738 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -4,6 +4,7 @@ use futures::sync::mpsc; use futures::{future, stream, try_ready}; use log::debug; use std::error::Error; +use std::fmt::Write; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::{Duration, Instant}; use tokio::net::TcpStream; @@ -616,6 +617,49 @@ fn copy_in() { assert_eq!(rows[1].get::<_, &str>(1), "joe"); } +#[test] +fn copy_in_large() { + let _ = env_logger::try_init(); + let mut runtime = Runtime::new().unwrap(); + + let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap(); + let connection = connection.map_err(|e| panic!("{}", e)); + runtime.handle().spawn(connection).unwrap(); + + runtime + .block_on( + client + .simple_query( + "CREATE TEMPORARY TABLE foo ( + id INTEGER, + name TEXT + )", + ) + .for_each(|_| Ok(())), + ) + .unwrap(); + + let a = "0\tname0\n".to_string(); + let mut b = String::new(); + for i in 1..5_000 { + writeln!(b, "{0}\tname{0}", i).unwrap(); + } + let mut c = String::new(); + for i in 5_000..10_000 { + writeln!(c, "{0}\tname{0}", i).unwrap(); + } + + let stream = stream::iter_ok::<_, String>(vec![a, b, c]); + let rows = runtime + .block_on( + client + .prepare("COPY foo FROM STDIN") + .and_then(|s| client.copy_in(&s, &[], stream)), + ) + .unwrap(); + assert_eq!(rows, 10_000); +} + #[test] fn copy_in_error() { let _ = env_logger::try_init();