diff --git a/postgres/src/client.rs b/postgres/src/client.rs index e680f8c5..4047f1b1 100644 --- a/postgres/src/client.rs +++ b/postgres/src/client.rs @@ -1,19 +1,15 @@ +use crate::iter::Iter; +#[cfg(feature = "runtime")] +use crate::Config; +use crate::{CopyInWriter, CopyOutReader, Statement, ToStatement, Transaction}; use fallible_iterator::FallibleIterator; use futures::executor; -use std::io::{BufRead, Read}; use tokio_postgres::tls::{MakeTlsConnect, TlsConnect}; use tokio_postgres::types::{ToSql, Type}; #[cfg(feature = "runtime")] use tokio_postgres::Socket; use tokio_postgres::{Error, Row, SimpleQueryMessage}; -use crate::copy_in_stream::CopyInStream; -use crate::copy_out_reader::CopyOutReader; -use crate::iter::Iter; -#[cfg(feature = "runtime")] -use crate::Config; -use crate::{Statement, ToStatement, Transaction}; - /// A synchronous PostgreSQL client. /// /// This is a lightweight wrapper over the asynchronous tokio_postgres `Client`. @@ -264,29 +260,33 @@ impl Client { /// The `query` argument can either be a `Statement`, or a raw query string. The data in the provided reader is /// passed along to the server verbatim; it is the caller's responsibility to ensure it uses the proper format. /// + /// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted. + /// /// # Examples /// /// ```no_run /// use postgres::{Client, NoTls}; + /// use std::io::Write; /// - /// # fn main() -> Result<(), postgres::Error> { + /// # fn main() -> Result<(), Box> { /// let mut client = Client::connect("host=localhost user=postgres", NoTls)?; /// - /// client.copy_in("COPY people FROM stdin", &[], &mut "1\tjohn\n2\tjane\n".as_bytes())?; + /// let mut writer = client.copy_in("COPY people FROM stdin", &[])?; + /// writer.write_all(b"1\tjohn\n2\tjane\n")?; + /// writer.finish()?; /// # Ok(()) /// # } /// ``` - pub fn copy_in( + pub fn copy_in( &mut self, query: &T, params: &[&(dyn ToSql + Sync)], - reader: R, - ) -> Result + ) -> Result, Error> where T: ?Sized + ToStatement, - R: Read + Unpin, { - executor::block_on(self.0.copy_in(query, params, CopyInStream(reader))) + let sink = executor::block_on(self.0.copy_in(query, params))?; + Ok(CopyInWriter::new(sink)) } /// Executes a `COPY TO STDOUT` statement, returning a reader of the resulting data. @@ -312,7 +312,7 @@ impl Client { &mut self, query: &T, params: &[&(dyn ToSql + Sync)], - ) -> Result + ) -> Result, Error> where T: ?Sized + ToStatement, { diff --git a/postgres/src/copy_in_stream.rs b/postgres/src/copy_in_stream.rs deleted file mode 100644 index 6bda3e5d..00000000 --- a/postgres/src/copy_in_stream.rs +++ /dev/null @@ -1,24 +0,0 @@ -use futures::Stream; -use std::io::{self, Cursor, Read}; -use std::pin::Pin; -use std::task::{Context, Poll}; - -pub struct CopyInStream(pub R); - -impl Stream for CopyInStream -where - R: Read + Unpin, -{ - type Item = io::Result>>; - - fn poll_next( - mut self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>>>> { - let mut buf = vec![]; - match self.0.by_ref().take(4096).read_to_end(&mut buf)? { - 0 => Poll::Ready(None), - _ => Poll::Ready(Some(Ok(Cursor::new(buf)))), - } - } -} diff --git a/postgres/src/copy_in_writer.rs b/postgres/src/copy_in_writer.rs index e69de29b..b7a2a009 100644 --- a/postgres/src/copy_in_writer.rs +++ b/postgres/src/copy_in_writer.rs @@ -0,0 +1,63 @@ +use bytes::{Bytes, BytesMut}; +use futures::{executor, SinkExt}; +use std::io; +use std::io::Write; +use std::marker::PhantomData; +use std::pin::Pin; +use tokio_postgres::{CopyInSink, Error}; + +/// The writer returned by the `copy_in` method. +/// +/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted. +pub struct CopyInWriter<'a> { + sink: Pin>>, + buf: BytesMut, + _p: PhantomData<&'a mut ()>, +} + +// no-op impl to extend borrow until drop +impl Drop for CopyInWriter<'_> { + fn drop(&mut self) {} +} + +impl<'a> CopyInWriter<'a> { + pub(crate) fn new(sink: CopyInSink) -> CopyInWriter<'a> { + CopyInWriter { + sink: Box::pin(sink), + buf: BytesMut::new(), + _p: PhantomData, + } + } + + /// Completes the copy, returning the number of rows written. + /// + /// If this is not called, the copy will be aborted. + pub fn finish(mut self) -> Result { + self.flush_inner()?; + executor::block_on(self.sink.as_mut().finish()) + } + + fn flush_inner(&mut self) -> Result<(), Error> { + if self.buf.is_empty() { + return Ok(()); + } + + executor::block_on(self.sink.as_mut().send(self.buf.split().freeze())) + } +} + +impl Write for CopyInWriter<'_> { + fn write(&mut self, buf: &[u8]) -> io::Result { + if self.buf.len() > 4096 { + self.flush()?; + } + + self.buf.extend_from_slice(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + self.flush_inner() + .map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + } +} diff --git a/postgres/src/copy_out_reader.rs b/postgres/src/copy_out_reader.rs index 680d4d31..f1309fb8 100644 --- a/postgres/src/copy_out_reader.rs +++ b/postgres/src/copy_out_reader.rs @@ -1,33 +1,24 @@ use bytes::{Buf, Bytes}; -use futures::{executor, Stream}; +use futures::executor; use std::io::{self, BufRead, Cursor, Read}; use std::marker::PhantomData; use std::pin::Pin; -use tokio_postgres::Error; +use tokio_postgres::{CopyStream, Error}; /// The reader returned by the `copy_out` method. -pub struct CopyOutReader<'a, S> -where - S: Stream, -{ - it: executor::BlockingStream>>, +pub struct CopyOutReader<'a> { + it: executor::BlockingStream>>, cur: Cursor, _p: PhantomData<&'a mut ()>, } // no-op impl to extend borrow until drop -impl<'a, S> Drop for CopyOutReader<'a, S> -where - S: Stream, -{ +impl Drop for CopyOutReader<'_> { fn drop(&mut self) {} } -impl<'a, S> CopyOutReader<'a, S> -where - S: Stream>, -{ - pub(crate) fn new(stream: S) -> Result, Error> { +impl<'a> CopyOutReader<'a> { + pub(crate) fn new(stream: CopyStream) -> Result, Error> { let mut it = executor::block_on_stream(Box::pin(stream)); let cur = match it.next() { Some(Ok(cur)) => cur, @@ -43,10 +34,7 @@ where } } -impl<'a, S> Read for CopyOutReader<'a, S> -where - S: Stream>, -{ +impl Read for CopyOutReader<'_> { fn read(&mut self, buf: &mut [u8]) -> io::Result { let b = self.fill_buf()?; let len = usize::min(buf.len(), b.len()); @@ -56,10 +44,7 @@ where } } -impl<'a, S> BufRead for CopyOutReader<'a, S> -where - S: Stream>, -{ +impl BufRead for CopyOutReader<'_> { fn fill_buf(&mut self) -> io::Result<&[u8]> { if self.cur.remaining() == 0 { match self.it.next() { diff --git a/postgres/src/lib.rs b/postgres/src/lib.rs index 1cf9e324..d63e10a7 100644 --- a/postgres/src/lib.rs +++ b/postgres/src/lib.rs @@ -69,6 +69,8 @@ pub use tokio_postgres::{ pub use crate::client::*; #[cfg(feature = "runtime")] pub use crate::config::Config; +pub use crate::copy_in_writer::CopyInWriter; +pub use crate::copy_out_reader::CopyOutReader; #[doc(no_inline)] pub use crate::error::Error; #[doc(no_inline)] @@ -80,7 +82,7 @@ pub use crate::transaction::*; mod client; #[cfg(feature = "runtime")] pub mod config; -mod copy_in_stream; +mod copy_in_writer; mod copy_out_reader; mod iter; mod transaction; diff --git a/postgres/src/test.rs b/postgres/src/test.rs index f7d84a88..d376d186 100644 --- a/postgres/src/test.rs +++ b/postgres/src/test.rs @@ -1,4 +1,4 @@ -use std::io::Read; +use std::io::{Read, Write}; use tokio_postgres::types::Type; use tokio_postgres::NoTls; @@ -154,13 +154,9 @@ fn copy_in() { .simple_query("CREATE TEMPORARY TABLE foo (id INT, name TEXT)") .unwrap(); - client - .copy_in( - "COPY foo FROM stdin", - &[], - &mut &b"1\tsteven\n2\ttimothy"[..], - ) - .unwrap(); + let mut writer = client.copy_in("COPY foo FROM stdin", &[]).unwrap(); + writer.write_all(b"1\tsteven\n2\ttimothy").unwrap(); + writer.finish().unwrap(); let rows = client .query("SELECT id, name FROM foo ORDER BY id", &[]) @@ -173,6 +169,25 @@ fn copy_in() { assert_eq!(rows[1].get::<_, &str>(1), "timothy"); } +#[test] +fn copy_in_abort() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .simple_query("CREATE TEMPORARY TABLE foo (id INT, name TEXT)") + .unwrap(); + + let mut writer = client.copy_in("COPY foo FROM stdin", &[]).unwrap(); + writer.write_all(b"1\tsteven\n2\ttimothy").unwrap(); + drop(writer); + + let rows = client + .query("SELECT id, name FROM foo ORDER BY id", &[]) + .unwrap(); + + assert_eq!(rows.len(), 0); +} + #[test] fn copy_out() { let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); diff --git a/postgres/src/transaction.rs b/postgres/src/transaction.rs index 8b857bb0..17631c79 100644 --- a/postgres/src/transaction.rs +++ b/postgres/src/transaction.rs @@ -1,14 +1,10 @@ +use crate::iter::Iter; +use crate::{CopyInWriter, CopyOutReader, Portal, Statement, ToStatement}; use fallible_iterator::FallibleIterator; use futures::executor; -use std::io::{BufRead, Read}; use tokio_postgres::types::{ToSql, Type}; use tokio_postgres::{Error, Row, SimpleQueryMessage}; -use crate::copy_in_stream::CopyInStream; -use crate::copy_out_reader::CopyOutReader; -use crate::iter::Iter; -use crate::{Portal, Statement, ToStatement}; - /// A representation of a PostgreSQL database transaction. /// /// Transactions will implicitly roll back by default when dropped. Use the `commit` method to commit the changes made @@ -117,17 +113,16 @@ impl<'a> Transaction<'a> { } /// Like `Client::copy_in`. - pub fn copy_in( + pub fn copy_in( &mut self, query: &T, params: &[&(dyn ToSql + Sync)], - reader: R, - ) -> Result + ) -> Result, Error> where T: ?Sized + ToStatement, - R: Read + Unpin, { - executor::block_on(self.0.copy_in(query, params, CopyInStream(reader))) + let sink = executor::block_on(self.0.copy_in(query, params))?; + Ok(CopyInWriter::new(sink)) } /// Like `Client::copy_out`. @@ -135,7 +130,7 @@ impl<'a> Transaction<'a> { &mut self, query: &T, params: &[&(dyn ToSql + Sync)], - ) -> Result + ) -> Result, Error> where T: ?Sized + ToStatement, { diff --git a/tokio-postgres-binary-copy/Cargo.toml b/tokio-postgres-binary-copy/Cargo.toml index 0b64e4f2..a8d44bf5 100644 --- a/tokio-postgres-binary-copy/Cargo.toml +++ b/tokio-postgres-binary-copy/Cargo.toml @@ -8,7 +8,6 @@ edition = "2018" byteorder = "1.0" bytes = "0.5" futures = "0.3" -parking_lot = "0.10" pin-project-lite = "0.1" tokio-postgres = { version = "=0.5.0-alpha.2", default-features = false, path = "../tokio-postgres" } diff --git a/tokio-postgres-binary-copy/src/lib.rs b/tokio-postgres-binary-copy/src/lib.rs index 59acf746..b90768b2 100644 --- a/tokio-postgres-binary-copy/src/lib.rs +++ b/tokio-postgres-binary-copy/src/lib.rs @@ -1,145 +1,95 @@ use bytes::{BufMut, Bytes, BytesMut, Buf}; -use futures::{future, ready, Stream}; -use parking_lot::Mutex; +use futures::{ready, Stream, SinkExt}; use pin_project_lite::pin_project; use std::convert::TryFrom; use std::error::Error; -use std::future::Future; use std::ops::Range; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use tokio_postgres::types::{IsNull, ToSql, Type, FromSql, WrongType}; -use tokio_postgres::CopyStream; +use tokio_postgres::{CopyStream, CopyInSink}; use std::io::Cursor; use byteorder::{ByteOrder, BigEndian}; #[cfg(test)] mod test; -const BLOCK_SIZE: usize = 4096; const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0"; const HEADER_LEN: usize = MAGIC.len() + 4 + 4; pin_project! { - pub struct BinaryCopyInStream { + pub struct BinaryCopyInWriter { #[pin] - future: F, - buf: Arc>, - done: bool, + sink: CopyInSink, + types: Vec, + buf: BytesMut, } } -impl BinaryCopyInStream -where - F: Future>>, -{ - pub fn new(types: &[Type], write_values: M) -> BinaryCopyInStream - where - M: FnOnce(BinaryCopyInWriter) -> F, - { +impl BinaryCopyInWriter { + pub fn new(sink: CopyInSink, types: &[Type]) -> BinaryCopyInWriter { let mut buf = BytesMut::new(); buf.reserve(HEADER_LEN); buf.put_slice(MAGIC); // magic buf.put_i32(0); // flags buf.put_i32(0); // header extension - let buf = Arc::new(Mutex::new(buf)); - let writer = BinaryCopyInWriter { - buf: buf.clone(), + BinaryCopyInWriter { + sink, types: types.to_vec(), - }; - - BinaryCopyInStream { - future: write_values(writer), buf, - done: false, } } -} -impl Stream for BinaryCopyInStream -where - F: Future>>, -{ - type Item = Result>; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); - - if *this.done { - return Poll::Ready(None); - } - - *this.done = this.future.poll(cx)?.is_ready(); - - let mut buf = this.buf.lock(); - if *this.done { - buf.reserve(2); - buf.put_i16(-1); - Poll::Ready(Some(Ok(buf.split().freeze()))) - } else if buf.len() > BLOCK_SIZE { - Poll::Ready(Some(Ok(buf.split().freeze()))) - } else { - Poll::Pending - } - } -} - -// FIXME this should really just take a reference to the buffer, but that requires HKT :( -pub struct BinaryCopyInWriter { - buf: Arc>, - types: Vec, -} - -impl BinaryCopyInWriter { pub async fn write( - &mut self, + self: Pin<&mut Self>, values: &[&(dyn ToSql + Send)], ) -> Result<(), Box> { self.write_raw(values.iter().cloned()).await } - pub async fn write_raw<'a, I>(&mut self, values: I) -> Result<(), Box> - where - I: IntoIterator, - I::IntoIter: ExactSizeIterator, + pub async fn write_raw<'a, I>(self: Pin<&mut Self>, values: I) -> Result<(), Box> + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, { + let mut this = self.project(); + let values = values.into_iter(); assert!( - values.len() == self.types.len(), + values.len() == this.types.len(), "expected {} values but got {}", - self.types.len(), + this.types.len(), values.len(), ); - future::poll_fn(|_| { - if self.buf.lock().len() > BLOCK_SIZE { - Poll::Pending - } else { - Poll::Ready(()) - } - }) - .await; + this.buf.put_i16(this.types.len() as i16); - let mut buf = self.buf.lock(); - - buf.reserve(2); - buf.put_u16(self.types.len() as u16); - - for (value, type_) in values.zip(&self.types) { - let idx = buf.len(); - buf.reserve(4); - buf.put_i32(0); - let len = match value.to_sql_checked(type_, &mut buf)? { + for (value, type_) in values.zip(this.types) { + let idx = this.buf.len(); + this.buf.put_i32(0); + let len = match value.to_sql_checked(type_, this.buf)? { IsNull::Yes => -1, - IsNull::No => i32::try_from(buf.len() - idx - 4)?, + IsNull::No => i32::try_from(this.buf.len() - idx - 4)?, }; - BigEndian::write_i32(&mut buf[idx..], len); + BigEndian::write_i32(&mut this.buf[idx..], len); + } + + if this.buf.len() > 4096 { + this.sink.send(this.buf.split().freeze()).await?; } Ok(()) } + + pub async fn finish(self: Pin<&mut Self>) -> Result { + let mut this = self.project(); + + this.buf.put_i16(-1); + this.sink.send(this.buf.split().freeze()).await?; + this.sink.finish().await + } } struct Header { diff --git a/tokio-postgres-binary-copy/src/test.rs b/tokio-postgres-binary-copy/src/test.rs index f4d19351..7d8bcd30 100644 --- a/tokio-postgres-binary-copy/src/test.rs +++ b/tokio-postgres-binary-copy/src/test.rs @@ -1,7 +1,7 @@ -use crate::{BinaryCopyInStream, BinaryCopyOutStream}; +use crate::{BinaryCopyInWriter, BinaryCopyOutStream}; use tokio_postgres::types::Type; use tokio_postgres::{Client, NoTls}; -use futures::TryStreamExt; +use futures::{TryStreamExt, pin_mut}; async fn connect() -> Client { let (client, connection) = @@ -23,19 +23,12 @@ async fn write_basic() { .await .unwrap(); - let stream = BinaryCopyInStream::new(&[Type::INT4, Type::TEXT], |mut w| { - async move { - w.write(&[&1i32, &"foobar"]).await?; - w.write(&[&2i32, &None::<&str>]).await?; - - Ok(()) - } - }); - - client - .copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream) - .await - .unwrap(); + let sink = client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[]).await.unwrap(); + let writer = BinaryCopyInWriter::new(sink, &[Type::INT4, Type::TEXT]); + pin_mut!(writer); + writer.as_mut().write(&[&1i32, &"foobar"]).await.unwrap(); + writer.as_mut().write(&[&2i32, &None::<&str>]).await.unwrap(); + writer.finish().await.unwrap(); let rows = client .query("SELECT id, bar FROM foo ORDER BY id", &[]) @@ -57,20 +50,15 @@ async fn write_many_rows() { .await .unwrap(); - let stream = BinaryCopyInStream::new(&[Type::INT4, Type::TEXT], |mut w| { - async move { - for i in 0..10_000i32 { - w.write(&[&i, &format!("the value for {}", i)]).await?; - } + let sink = client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[]).await.unwrap(); + let writer = BinaryCopyInWriter::new(sink, &[Type::INT4, Type::TEXT]); + pin_mut!(writer); - Ok(()) - } - }); + for i in 0..10_000i32 { + writer.as_mut().write(&[&i, &format!("the value for {}", i)]).await.unwrap(); + } - client - .copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream) - .await - .unwrap(); + writer.finish().await.unwrap(); let rows = client .query("SELECT id, bar FROM foo ORDER BY id", &[]) @@ -91,20 +79,15 @@ async fn write_big_rows() { .await .unwrap(); - let stream = BinaryCopyInStream::new(&[Type::INT4, Type::BYTEA], |mut w| { - async move { - for i in 0..2i32 { - w.write(&[&i, &vec![i as u8; 128 * 1024]]).await?; - } + let sink = client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[]).await.unwrap(); + let writer = BinaryCopyInWriter::new(sink, &[Type::INT4, Type::BYTEA]); + pin_mut!(writer); - Ok(()) - } - }); + for i in 0..2i32 { + writer.as_mut().write(&[&i, &vec![i as u8; 128 * 1024]]).await.unwrap(); + } - client - .copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream) - .await - .unwrap(); + writer.finish().await.unwrap(); let rows = client .query("SELECT id, bar FROM foo ORDER BY id", &[]) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 4d7c2053..984d401b 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -14,18 +14,17 @@ use crate::to_statement::ToStatement; use crate::types::{Oid, ToSql, Type}; #[cfg(feature = "runtime")] use crate::Socket; -use crate::{cancel_query_raw, copy_in, copy_out, query, Transaction}; +use crate::{cancel_query_raw, copy_in, copy_out, query, CopyInSink, Transaction}; use crate::{prepare, SimpleQueryMessage}; use crate::{simple_query, Row}; use crate::{Error, Statement}; use bytes::{Buf, BytesMut}; use fallible_iterator::FallibleIterator; use futures::channel::mpsc; -use futures::{future, pin_mut, ready, StreamExt, TryStream, TryStreamExt}; +use futures::{future, pin_mut, ready, StreamExt, TryStreamExt}; use parking_lot::Mutex; use postgres_protocol::message::backend::Message; use std::collections::HashMap; -use std::error; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; @@ -340,29 +339,26 @@ impl Client { query::execute(self.inner(), statement, params).await } - /// Executes a `COPY FROM STDIN` statement, returning the number of rows created. + /// Executes a `COPY FROM STDIN` statement, returning a sink used to write the copy data. /// - /// The data in the provided stream is passed along to the server verbatim; it is the caller's responsibility to - /// ensure it uses the proper format. + /// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is + /// not, the copy will be aborted. /// /// # Panics /// /// Panics if the number of parameters provided does not match the number expected. - pub async fn copy_in( + pub async fn copy_in( &self, statement: &T, params: &[&(dyn ToSql + Sync)], - stream: S, - ) -> Result + ) -> Result, Error> where T: ?Sized + ToStatement, - S: TryStream, - S::Ok: Buf + 'static + Send, - S::Error: Into>, + U: Buf + 'static + Send, { let statement = statement.__convert().into_statement(self).await?; let params = slice_iter(params); - copy_in::copy_in(self.inner(), statement, params, stream).await + copy_in::copy_in(self.inner(), statement, params).await } /// Executes a `COPY TO STDOUT` statement, returning a stream of the resulting data. diff --git a/tokio-postgres/src/copy_in.rs b/tokio-postgres/src/copy_in.rs index ebacb6cf..b1cdae59 100644 --- a/tokio-postgres/src/copy_in.rs +++ b/tokio-postgres/src/copy_in.rs @@ -1,4 +1,4 @@ -use crate::client::InnerClient; +use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::ToSql; @@ -6,11 +6,13 @@ use crate::{query, Error, Statement}; use bytes::buf::BufExt; use bytes::{Buf, BufMut, BytesMut}; use futures::channel::mpsc; -use futures::{pin_mut, ready, SinkExt, Stream, StreamExt, TryStream, TryStreamExt}; +use futures::{ready, Sink, SinkExt, Stream, StreamExt}; +use futures::future; +use pin_project_lite::pin_project; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; use postgres_protocol::message::frontend::CopyData; -use std::error; +use std::marker::{PhantomPinned, PhantomData}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -61,18 +63,148 @@ impl Stream for CopyInReceiver { } } -pub async fn copy_in<'a, I, S>( +enum SinkState { + Active, + Closing, + Reading, +} + +pin_project! { + /// A sink for `COPY ... FROM STDIN` query data. + /// + /// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is + /// not, the copy will be aborted. + pub struct CopyInSink { + #[pin] + sender: mpsc::Sender, + responses: Responses, + buf: BytesMut, + state: SinkState, + #[pin] + _p: PhantomPinned, + _p2: PhantomData, + } +} + +impl CopyInSink +where + T: Buf + 'static + Send, +{ + /// A poll-based version of `finish`. + pub fn poll_finish(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match self.state { + SinkState::Active => { + ready!(self.as_mut().poll_flush(cx))?; + let mut this = self.as_mut().project(); + ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; + this.sender + .start_send(CopyInMessage::Done) + .map_err(|_| Error::closed())?; + *this.state = SinkState::Closing; + } + SinkState::Closing => { + let this = self.as_mut().project(); + ready!(this.sender.poll_close(cx)).map_err(|_| Error::closed())?; + *this.state = SinkState::Reading; + } + SinkState::Reading => { + let this = self.as_mut().project(); + match ready!(this.responses.poll_next(cx))? { + Message::CommandComplete(body) => { + let rows = body + .tag() + .map_err(Error::parse)? + .rsplit(' ') + .next() + .unwrap() + .parse() + .unwrap_or(0); + return Poll::Ready(Ok(rows)); + } + _ => return Poll::Ready(Err(Error::unexpected_message())), + } + } + } + } + } + + /// Completes the copy, returning the number of rows inserted. + /// + /// The `Sink::close` method is equivalent to `finish`, except that it does not return the + /// number of rows. + pub async fn finish(mut self: Pin<&mut Self>) -> Result { + future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await + } +} + +impl Sink for CopyInSink +where + T: Buf + 'static + Send, +{ + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .sender + .poll_ready(cx) + .map_err(|_| Error::closed()) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> { + let this = self.project(); + + let data: Box = if item.remaining() > 4096 { + if this.buf.is_empty() { + Box::new(item) + } else { + Box::new(this.buf.split().freeze().chain(item)) + } + } else { + this.buf.put(item); + if this.buf.len() > 4096 { + Box::new(this.buf.split().freeze()) + } else { + return Ok(()); + } + }; + + let data = CopyData::new(data).map_err(Error::encode)?; + this.sender + .start_send(CopyInMessage::Message(FrontendMessage::CopyData(data))) + .map_err(|_| Error::closed()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + if !this.buf.is_empty() { + ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; + let data: Box = Box::new(this.buf.split().freeze()); + let data = CopyData::new(data).map_err(Error::encode)?; + this.sender + .as_mut() + .start_send(CopyInMessage::Message(FrontendMessage::CopyData(data))) + .map_err(|_| Error::closed())?; + } + + this.sender.poll_flush(cx).map_err(|_| Error::closed()) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_finish(cx).map_ok(|_| ()) + } +} + +pub async fn copy_in<'a, I, T>( client: &InnerClient, statement: Statement, params: I, - stream: S, -) -> Result +) -> Result, Error> where I: IntoIterator, I::IntoIter: ExactSizeIterator, - S: TryStream, - S::Ok: Buf + 'static + Send, - S::Error: Into>, + T: Buf + 'static + Send, { let buf = query::encode(client, &statement, params)?; @@ -95,60 +227,12 @@ where _ => return Err(Error::unexpected_message()), } - let mut bytes = BytesMut::new(); - let stream = stream.into_stream(); - pin_mut!(stream); - - while let Some(buf) = stream.try_next().await.map_err(Error::copy_in_stream)? { - let data: Box = if buf.remaining() > 4096 { - if bytes.is_empty() { - Box::new(buf) - } else { - Box::new(bytes.split().freeze().chain(buf)) - } - } else { - bytes.reserve(buf.remaining()); - bytes.put(buf); - if bytes.len() > 4096 { - Box::new(bytes.split().freeze()) - } else { - continue; - } - }; - - let data = CopyData::new(data).map_err(Error::encode)?; - sender - .send(CopyInMessage::Message(FrontendMessage::CopyData(data))) - .await - .map_err(|_| Error::closed())?; - } - - if !bytes.is_empty() { - let data: Box = Box::new(bytes.freeze()); - let data = CopyData::new(data).map_err(Error::encode)?; - sender - .send(CopyInMessage::Message(FrontendMessage::CopyData(data))) - .await - .map_err(|_| Error::closed())?; - } - - sender - .send(CopyInMessage::Done) - .await - .map_err(|_| Error::closed())?; - - match responses.next().await? { - Message::CommandComplete(body) => { - let rows = body - .tag() - .map_err(Error::parse)? - .rsplit(' ') - .next() - .unwrap() - .parse() - .unwrap_or(0); - Ok(rows) - } - _ => Err(Error::unexpected_message()), - } + Ok(CopyInSink { + sender, + responses, + buf: BytesMut::new(), + state: SinkState::Active, + _p: PhantomPinned, + _p2: PhantomData, + }) } diff --git a/tokio-postgres/src/error/mod.rs b/tokio-postgres/src/error/mod.rs index 4dde62f7..788e70cf 100644 --- a/tokio-postgres/src/error/mod.rs +++ b/tokio-postgres/src/error/mod.rs @@ -337,7 +337,6 @@ enum Kind { ToSql(usize), FromSql(usize), Column(String), - CopyInStream, Closed, Db, Parse, @@ -376,7 +375,6 @@ impl fmt::Display for Error { Kind::ToSql(idx) => write!(fmt, "error serializing parameter {}", idx)?, Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?, Kind::Column(column) => write!(fmt, "invalid column `{}`", column)?, - Kind::CopyInStream => fmt.write_str("error from a copy_in stream")?, Kind::Closed => fmt.write_str("connection closed")?, Kind::Db => fmt.write_str("db error")?, Kind::Parse => fmt.write_str("error parsing response from server")?, @@ -458,13 +456,6 @@ impl Error { Error::new(Kind::Column(column), None) } - pub(crate) fn copy_in_stream(e: E) -> Error - where - E: Into>, - { - Error::new(Kind::CopyInStream, Some(e.into())) - } - pub(crate) fn tls(e: Box) -> Error { Error::new(Kind::Tls, Some(e)) } diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index dc88389f..61367f29 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -105,6 +105,7 @@ pub use crate::client::Client; pub use crate::config::Config; pub use crate::connection::Connection; +pub use crate::copy_in::CopyInSink; pub use crate::copy_out::CopyStream; use crate::error::DbError; pub use crate::error::Error; diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 81c5d460..ac44a841 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -9,12 +9,12 @@ use crate::types::{ToSql, Type}; #[cfg(feature = "runtime")] use crate::Socket; use crate::{ - bind, query, slice_iter, Client, Error, Portal, Row, SimpleQueryMessage, Statement, ToStatement, + bind, query, slice_iter, Client, CopyInSink, Error, Portal, Row, SimpleQueryMessage, Statement, + ToStatement, }; use bytes::Buf; -use futures::{TryStream, TryStreamExt}; +use futures::{TryStreamExt}; use postgres_protocol::message::frontend; -use std::error; use tokio::io::{AsyncRead, AsyncWrite}; /// A representation of a PostgreSQL database transaction. @@ -209,19 +209,16 @@ impl<'a> Transaction<'a> { } /// Like `Client::copy_in`. - pub async fn copy_in( + pub async fn copy_in( &self, statement: &T, params: &[&(dyn ToSql + Sync)], - stream: S, - ) -> Result + ) -> Result, Error> where T: ?Sized + ToStatement, - S: TryStream, - S::Ok: Buf + 'static + Send, - S::Error: Into>, + U: Buf + 'static + Send, { - self.client.copy_in(statement, params, stream).await + self.client.copy_in(statement, params).await } /// Like `Client::copy_out`. diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index ce9cc9d8..4e6086f4 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -2,8 +2,7 @@ use bytes::{Bytes, BytesMut}; use futures::channel::mpsc; -use futures::{future, stream, StreamExt}; -use futures::{join, try_join, FutureExt, TryStreamExt}; +use futures::{future, stream, StreamExt, SinkExt, pin_mut, join, try_join, FutureExt, TryStreamExt}; use std::fmt::Write; use std::time::Duration; use tokio::net::TcpStream; @@ -409,23 +408,21 @@ async fn copy_in() { .await .unwrap(); - let stmt = client.prepare("COPY foo FROM STDIN").await.unwrap(); - let stream = stream::iter( + let mut stream = stream::iter( vec![ Bytes::from_static(b"1\tjim\n"), Bytes::from_static(b"2\tjoe\n"), ] .into_iter() - .map(Ok::<_, String>), + .map(Ok::<_, Error>), ); - let rows = client.copy_in(&stmt, &[], stream).await.unwrap(); + let sink = client.copy_in("COPY foo FROM STDIN", &[]).await.unwrap(); + pin_mut!(sink); + sink.send_all(&mut stream).await.unwrap(); + let rows = sink.finish().await.unwrap(); assert_eq!(rows, 2); - let stmt = client - .prepare("SELECT id, name FROM foo ORDER BY id") - .await - .unwrap(); - let rows = client.query(&stmt, &[]).await.unwrap(); + let rows = client.query("SELECT id, name FROM foo ORDER BY id", &[]).await.unwrap(); assert_eq!(rows.len(), 2); assert_eq!(rows[0].get::<_, i32>(0), 1); @@ -448,8 +445,6 @@ async fn copy_in_large() { .await .unwrap(); - let stmt = client.prepare("COPY foo FROM STDIN").await.unwrap(); - let a = Bytes::from_static(b"0\tname0\n"); let mut b = BytesMut::new(); for i in 1..5_000 { @@ -459,13 +454,16 @@ async fn copy_in_large() { for i in 5_000..10_000 { writeln!(c, "{0}\tname{0}", i).unwrap(); } - let stream = stream::iter( + let mut stream = stream::iter( vec![a, b.freeze(), c.freeze()] .into_iter() - .map(Ok::<_, String>), + .map(Ok::<_, Error>), ); - let rows = client.copy_in(&stmt, &[], stream).await.unwrap(); + let sink = client.copy_in("COPY foo FROM STDIN", &[]).await.unwrap(); + pin_mut!(sink); + sink.send_all(&mut stream).await.unwrap(); + let rows = sink.finish().await.unwrap(); assert_eq!(rows, 10_000); } @@ -483,16 +481,13 @@ async fn copy_in_error() { .await .unwrap(); - let stmt = client.prepare("COPY foo FROM STDIN").await.unwrap(); - let stream = stream::iter(vec![Ok(Bytes::from_static(b"1\tjim\n")), Err("asdf")]); - let error = client.copy_in(&stmt, &[], stream).await.unwrap_err(); - assert!(error.to_string().contains("asdf")); + { + let sink = client.copy_in("COPY foo FROM STDIN", &[]).await.unwrap(); + pin_mut!(sink); + sink.send(Bytes::from_static(b"1\tsteven")).await.unwrap(); + } - let stmt = client - .prepare("SELECT id, name FROM foo ORDER BY id") - .await - .unwrap(); - let rows = client.query(&stmt, &[]).await.unwrap(); + let rows = client.query("SELECT id, name FROM foo ORDER BY id", &[]).await.unwrap(); assert_eq!(rows.len(), 0); }