Change the copy_in interface

Rather than taking in a Stream and advancing it internally, return a
Sink that can be advanced by the calling code. This significantly
simplifies encoding logic for things like tokio-postgres-binary-copy.

Similarly, the blocking interface returns a Writer.

Closes #489
This commit is contained in:
Steven Fackler 2019-11-30 11:04:59 -05:00
parent a5428e6a03
commit e5e03b0064
16 changed files with 367 additions and 335 deletions

View File

@ -1,19 +1,15 @@
use crate::iter::Iter;
#[cfg(feature = "runtime")]
use crate::Config;
use crate::{CopyInWriter, CopyOutReader, Statement, ToStatement, Transaction};
use fallible_iterator::FallibleIterator;
use futures::executor;
use std::io::{BufRead, Read};
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
use tokio_postgres::types::{ToSql, Type};
#[cfg(feature = "runtime")]
use tokio_postgres::Socket;
use tokio_postgres::{Error, Row, SimpleQueryMessage};
use crate::copy_in_stream::CopyInStream;
use crate::copy_out_reader::CopyOutReader;
use crate::iter::Iter;
#[cfg(feature = "runtime")]
use crate::Config;
use crate::{Statement, ToStatement, Transaction};
/// A synchronous PostgreSQL client.
///
/// This is a lightweight wrapper over the asynchronous tokio_postgres `Client`.
@ -264,29 +260,33 @@ impl Client {
/// The `query` argument can either be a `Statement`, or a raw query string. The data in the provided reader is
/// passed along to the server verbatim; it is the caller's responsibility to ensure it uses the proper format.
///
/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
///
/// # Examples
///
/// ```no_run
/// use postgres::{Client, NoTls};
/// use std::io::Write;
///
/// # fn main() -> Result<(), postgres::Error> {
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let mut client = Client::connect("host=localhost user=postgres", NoTls)?;
///
/// client.copy_in("COPY people FROM stdin", &[], &mut "1\tjohn\n2\tjane\n".as_bytes())?;
/// let mut writer = client.copy_in("COPY people FROM stdin", &[])?;
/// writer.write_all(b"1\tjohn\n2\tjane\n")?;
/// writer.finish()?;
/// # Ok(())
/// # }
/// ```
pub fn copy_in<T, R>(
pub fn copy_in<T>(
&mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
reader: R,
) -> Result<u64, Error>
) -> Result<CopyInWriter<'_>, Error>
where
T: ?Sized + ToStatement,
R: Read + Unpin,
{
executor::block_on(self.0.copy_in(query, params, CopyInStream(reader)))
let sink = executor::block_on(self.0.copy_in(query, params))?;
Ok(CopyInWriter::new(sink))
}
/// Executes a `COPY TO STDOUT` statement, returning a reader of the resulting data.
@ -312,7 +312,7 @@ impl Client {
&mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<impl BufRead, Error>
) -> Result<CopyOutReader<'_>, Error>
where
T: ?Sized + ToStatement,
{

View File

@ -1,24 +0,0 @@
use futures::Stream;
use std::io::{self, Cursor, Read};
use std::pin::Pin;
use std::task::{Context, Poll};
pub struct CopyInStream<R>(pub R);
impl<R> Stream for CopyInStream<R>
where
R: Read + Unpin,
{
type Item = io::Result<Cursor<Vec<u8>>>;
fn poll_next(
mut self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Option<io::Result<Cursor<Vec<u8>>>>> {
let mut buf = vec![];
match self.0.by_ref().take(4096).read_to_end(&mut buf)? {
0 => Poll::Ready(None),
_ => Poll::Ready(Some(Ok(Cursor::new(buf)))),
}
}
}

View File

@ -0,0 +1,63 @@
use bytes::{Bytes, BytesMut};
use futures::{executor, SinkExt};
use std::io;
use std::io::Write;
use std::marker::PhantomData;
use std::pin::Pin;
use tokio_postgres::{CopyInSink, Error};
/// The writer returned by the `copy_in` method.
///
/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
pub struct CopyInWriter<'a> {
sink: Pin<Box<CopyInSink<Bytes>>>,
buf: BytesMut,
_p: PhantomData<&'a mut ()>,
}
// no-op impl to extend borrow until drop
impl Drop for CopyInWriter<'_> {
fn drop(&mut self) {}
}
impl<'a> CopyInWriter<'a> {
pub(crate) fn new(sink: CopyInSink<Bytes>) -> CopyInWriter<'a> {
CopyInWriter {
sink: Box::pin(sink),
buf: BytesMut::new(),
_p: PhantomData,
}
}
/// Completes the copy, returning the number of rows written.
///
/// If this is not called, the copy will be aborted.
pub fn finish(mut self) -> Result<u64, Error> {
self.flush_inner()?;
executor::block_on(self.sink.as_mut().finish())
}
fn flush_inner(&mut self) -> Result<(), Error> {
if self.buf.is_empty() {
return Ok(());
}
executor::block_on(self.sink.as_mut().send(self.buf.split().freeze()))
}
}
impl Write for CopyInWriter<'_> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if self.buf.len() > 4096 {
self.flush()?;
}
self.buf.extend_from_slice(buf);
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
self.flush_inner()
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
}

View File

@ -1,33 +1,24 @@
use bytes::{Buf, Bytes};
use futures::{executor, Stream};
use futures::executor;
use std::io::{self, BufRead, Cursor, Read};
use std::marker::PhantomData;
use std::pin::Pin;
use tokio_postgres::Error;
use tokio_postgres::{CopyStream, Error};
/// The reader returned by the `copy_out` method.
pub struct CopyOutReader<'a, S>
where
S: Stream,
{
it: executor::BlockingStream<Pin<Box<S>>>,
pub struct CopyOutReader<'a> {
it: executor::BlockingStream<Pin<Box<CopyStream>>>,
cur: Cursor<Bytes>,
_p: PhantomData<&'a mut ()>,
}
// no-op impl to extend borrow until drop
impl<'a, S> Drop for CopyOutReader<'a, S>
where
S: Stream,
{
impl Drop for CopyOutReader<'_> {
fn drop(&mut self) {}
}
impl<'a, S> CopyOutReader<'a, S>
where
S: Stream<Item = Result<Bytes, Error>>,
{
pub(crate) fn new(stream: S) -> Result<CopyOutReader<'a, S>, Error> {
impl<'a> CopyOutReader<'a> {
pub(crate) fn new(stream: CopyStream) -> Result<CopyOutReader<'a>, Error> {
let mut it = executor::block_on_stream(Box::pin(stream));
let cur = match it.next() {
Some(Ok(cur)) => cur,
@ -43,10 +34,7 @@ where
}
}
impl<'a, S> Read for CopyOutReader<'a, S>
where
S: Stream<Item = Result<Bytes, Error>>,
{
impl Read for CopyOutReader<'_> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let b = self.fill_buf()?;
let len = usize::min(buf.len(), b.len());
@ -56,10 +44,7 @@ where
}
}
impl<'a, S> BufRead for CopyOutReader<'a, S>
where
S: Stream<Item = Result<Bytes, Error>>,
{
impl BufRead for CopyOutReader<'_> {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
if self.cur.remaining() == 0 {
match self.it.next() {

View File

@ -69,6 +69,8 @@ pub use tokio_postgres::{
pub use crate::client::*;
#[cfg(feature = "runtime")]
pub use crate::config::Config;
pub use crate::copy_in_writer::CopyInWriter;
pub use crate::copy_out_reader::CopyOutReader;
#[doc(no_inline)]
pub use crate::error::Error;
#[doc(no_inline)]
@ -80,7 +82,7 @@ pub use crate::transaction::*;
mod client;
#[cfg(feature = "runtime")]
pub mod config;
mod copy_in_stream;
mod copy_in_writer;
mod copy_out_reader;
mod iter;
mod transaction;

View File

@ -1,4 +1,4 @@
use std::io::Read;
use std::io::{Read, Write};
use tokio_postgres::types::Type;
use tokio_postgres::NoTls;
@ -154,13 +154,9 @@ fn copy_in() {
.simple_query("CREATE TEMPORARY TABLE foo (id INT, name TEXT)")
.unwrap();
client
.copy_in(
"COPY foo FROM stdin",
&[],
&mut &b"1\tsteven\n2\ttimothy"[..],
)
.unwrap();
let mut writer = client.copy_in("COPY foo FROM stdin", &[]).unwrap();
writer.write_all(b"1\tsteven\n2\ttimothy").unwrap();
writer.finish().unwrap();
let rows = client
.query("SELECT id, name FROM foo ORDER BY id", &[])
@ -173,6 +169,25 @@ fn copy_in() {
assert_eq!(rows[1].get::<_, &str>(1), "timothy");
}
#[test]
fn copy_in_abort() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
client
.simple_query("CREATE TEMPORARY TABLE foo (id INT, name TEXT)")
.unwrap();
let mut writer = client.copy_in("COPY foo FROM stdin", &[]).unwrap();
writer.write_all(b"1\tsteven\n2\ttimothy").unwrap();
drop(writer);
let rows = client
.query("SELECT id, name FROM foo ORDER BY id", &[])
.unwrap();
assert_eq!(rows.len(), 0);
}
#[test]
fn copy_out() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();

View File

@ -1,14 +1,10 @@
use crate::iter::Iter;
use crate::{CopyInWriter, CopyOutReader, Portal, Statement, ToStatement};
use fallible_iterator::FallibleIterator;
use futures::executor;
use std::io::{BufRead, Read};
use tokio_postgres::types::{ToSql, Type};
use tokio_postgres::{Error, Row, SimpleQueryMessage};
use crate::copy_in_stream::CopyInStream;
use crate::copy_out_reader::CopyOutReader;
use crate::iter::Iter;
use crate::{Portal, Statement, ToStatement};
/// A representation of a PostgreSQL database transaction.
///
/// Transactions will implicitly roll back by default when dropped. Use the `commit` method to commit the changes made
@ -117,17 +113,16 @@ impl<'a> Transaction<'a> {
}
/// Like `Client::copy_in`.
pub fn copy_in<T, R>(
pub fn copy_in<T>(
&mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
reader: R,
) -> Result<u64, Error>
) -> Result<CopyInWriter<'_>, Error>
where
T: ?Sized + ToStatement,
R: Read + Unpin,
{
executor::block_on(self.0.copy_in(query, params, CopyInStream(reader)))
let sink = executor::block_on(self.0.copy_in(query, params))?;
Ok(CopyInWriter::new(sink))
}
/// Like `Client::copy_out`.
@ -135,7 +130,7 @@ impl<'a> Transaction<'a> {
&mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<impl BufRead, Error>
) -> Result<CopyOutReader<'_>, Error>
where
T: ?Sized + ToStatement,
{

View File

@ -8,7 +8,6 @@ edition = "2018"
byteorder = "1.0"
bytes = "0.5"
futures = "0.3"
parking_lot = "0.10"
pin-project-lite = "0.1"
tokio-postgres = { version = "=0.5.0-alpha.2", default-features = false, path = "../tokio-postgres" }

View File

@ -1,145 +1,95 @@
use bytes::{BufMut, Bytes, BytesMut, Buf};
use futures::{future, ready, Stream};
use parking_lot::Mutex;
use futures::{ready, Stream, SinkExt};
use pin_project_lite::pin_project;
use std::convert::TryFrom;
use std::error::Error;
use std::future::Future;
use std::ops::Range;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio_postgres::types::{IsNull, ToSql, Type, FromSql, WrongType};
use tokio_postgres::CopyStream;
use tokio_postgres::{CopyStream, CopyInSink};
use std::io::Cursor;
use byteorder::{ByteOrder, BigEndian};
#[cfg(test)]
mod test;
const BLOCK_SIZE: usize = 4096;
const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0";
const HEADER_LEN: usize = MAGIC.len() + 4 + 4;
pin_project! {
pub struct BinaryCopyInStream<F> {
pub struct BinaryCopyInWriter {
#[pin]
future: F,
buf: Arc<Mutex<BytesMut>>,
done: bool,
sink: CopyInSink<Bytes>,
types: Vec<Type>,
buf: BytesMut,
}
}
impl<F> BinaryCopyInStream<F>
where
F: Future<Output = Result<(), Box<dyn Error + Sync + Send>>>,
{
pub fn new<M>(types: &[Type], write_values: M) -> BinaryCopyInStream<F>
where
M: FnOnce(BinaryCopyInWriter) -> F,
{
impl BinaryCopyInWriter {
pub fn new(sink: CopyInSink<Bytes>, types: &[Type]) -> BinaryCopyInWriter {
let mut buf = BytesMut::new();
buf.reserve(HEADER_LEN);
buf.put_slice(MAGIC); // magic
buf.put_i32(0); // flags
buf.put_i32(0); // header extension
let buf = Arc::new(Mutex::new(buf));
let writer = BinaryCopyInWriter {
buf: buf.clone(),
BinaryCopyInWriter {
sink,
types: types.to_vec(),
};
BinaryCopyInStream {
future: write_values(writer),
buf,
done: false,
}
}
}
impl<F> Stream for BinaryCopyInStream<F>
where
F: Future<Output = Result<(), Box<dyn Error + Sync + Send>>>,
{
type Item = Result<Bytes, Box<dyn Error + Sync + Send>>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
if *this.done {
return Poll::Ready(None);
}
*this.done = this.future.poll(cx)?.is_ready();
let mut buf = this.buf.lock();
if *this.done {
buf.reserve(2);
buf.put_i16(-1);
Poll::Ready(Some(Ok(buf.split().freeze())))
} else if buf.len() > BLOCK_SIZE {
Poll::Ready(Some(Ok(buf.split().freeze())))
} else {
Poll::Pending
}
}
}
// FIXME this should really just take a reference to the buffer, but that requires HKT :(
pub struct BinaryCopyInWriter {
buf: Arc<Mutex<BytesMut>>,
types: Vec<Type>,
}
impl BinaryCopyInWriter {
pub async fn write(
&mut self,
self: Pin<&mut Self>,
values: &[&(dyn ToSql + Send)],
) -> Result<(), Box<dyn Error + Sync + Send>> {
self.write_raw(values.iter().cloned()).await
}
pub async fn write_raw<'a, I>(&mut self, values: I) -> Result<(), Box<dyn Error + Sync + Send>>
where
I: IntoIterator<Item = &'a (dyn ToSql + Send)>,
I::IntoIter: ExactSizeIterator,
pub async fn write_raw<'a, I>(self: Pin<&mut Self>, values: I) -> Result<(), Box<dyn Error + Sync + Send>>
where
I: IntoIterator<Item = &'a (dyn ToSql + Send)>,
I::IntoIter: ExactSizeIterator,
{
let mut this = self.project();
let values = values.into_iter();
assert!(
values.len() == self.types.len(),
values.len() == this.types.len(),
"expected {} values but got {}",
self.types.len(),
this.types.len(),
values.len(),
);
future::poll_fn(|_| {
if self.buf.lock().len() > BLOCK_SIZE {
Poll::Pending
} else {
Poll::Ready(())
}
})
.await;
this.buf.put_i16(this.types.len() as i16);
let mut buf = self.buf.lock();
buf.reserve(2);
buf.put_u16(self.types.len() as u16);
for (value, type_) in values.zip(&self.types) {
let idx = buf.len();
buf.reserve(4);
buf.put_i32(0);
let len = match value.to_sql_checked(type_, &mut buf)? {
for (value, type_) in values.zip(this.types) {
let idx = this.buf.len();
this.buf.put_i32(0);
let len = match value.to_sql_checked(type_, this.buf)? {
IsNull::Yes => -1,
IsNull::No => i32::try_from(buf.len() - idx - 4)?,
IsNull::No => i32::try_from(this.buf.len() - idx - 4)?,
};
BigEndian::write_i32(&mut buf[idx..], len);
BigEndian::write_i32(&mut this.buf[idx..], len);
}
if this.buf.len() > 4096 {
this.sink.send(this.buf.split().freeze()).await?;
}
Ok(())
}
pub async fn finish(self: Pin<&mut Self>) -> Result<u64, tokio_postgres::Error> {
let mut this = self.project();
this.buf.put_i16(-1);
this.sink.send(this.buf.split().freeze()).await?;
this.sink.finish().await
}
}
struct Header {

View File

@ -1,7 +1,7 @@
use crate::{BinaryCopyInStream, BinaryCopyOutStream};
use crate::{BinaryCopyInWriter, BinaryCopyOutStream};
use tokio_postgres::types::Type;
use tokio_postgres::{Client, NoTls};
use futures::TryStreamExt;
use futures::{TryStreamExt, pin_mut};
async fn connect() -> Client {
let (client, connection) =
@ -23,19 +23,12 @@ async fn write_basic() {
.await
.unwrap();
let stream = BinaryCopyInStream::new(&[Type::INT4, Type::TEXT], |mut w| {
async move {
w.write(&[&1i32, &"foobar"]).await?;
w.write(&[&2i32, &None::<&str>]).await?;
Ok(())
}
});
client
.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream)
.await
.unwrap();
let sink = client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[]).await.unwrap();
let writer = BinaryCopyInWriter::new(sink, &[Type::INT4, Type::TEXT]);
pin_mut!(writer);
writer.as_mut().write(&[&1i32, &"foobar"]).await.unwrap();
writer.as_mut().write(&[&2i32, &None::<&str>]).await.unwrap();
writer.finish().await.unwrap();
let rows = client
.query("SELECT id, bar FROM foo ORDER BY id", &[])
@ -57,20 +50,15 @@ async fn write_many_rows() {
.await
.unwrap();
let stream = BinaryCopyInStream::new(&[Type::INT4, Type::TEXT], |mut w| {
async move {
for i in 0..10_000i32 {
w.write(&[&i, &format!("the value for {}", i)]).await?;
}
let sink = client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[]).await.unwrap();
let writer = BinaryCopyInWriter::new(sink, &[Type::INT4, Type::TEXT]);
pin_mut!(writer);
Ok(())
}
});
for i in 0..10_000i32 {
writer.as_mut().write(&[&i, &format!("the value for {}", i)]).await.unwrap();
}
client
.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream)
.await
.unwrap();
writer.finish().await.unwrap();
let rows = client
.query("SELECT id, bar FROM foo ORDER BY id", &[])
@ -91,20 +79,15 @@ async fn write_big_rows() {
.await
.unwrap();
let stream = BinaryCopyInStream::new(&[Type::INT4, Type::BYTEA], |mut w| {
async move {
for i in 0..2i32 {
w.write(&[&i, &vec![i as u8; 128 * 1024]]).await?;
}
let sink = client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[]).await.unwrap();
let writer = BinaryCopyInWriter::new(sink, &[Type::INT4, Type::BYTEA]);
pin_mut!(writer);
Ok(())
}
});
for i in 0..2i32 {
writer.as_mut().write(&[&i, &vec![i as u8; 128 * 1024]]).await.unwrap();
}
client
.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream)
.await
.unwrap();
writer.finish().await.unwrap();
let rows = client
.query("SELECT id, bar FROM foo ORDER BY id", &[])

View File

@ -14,18 +14,17 @@ use crate::to_statement::ToStatement;
use crate::types::{Oid, ToSql, Type};
#[cfg(feature = "runtime")]
use crate::Socket;
use crate::{cancel_query_raw, copy_in, copy_out, query, Transaction};
use crate::{cancel_query_raw, copy_in, copy_out, query, CopyInSink, Transaction};
use crate::{prepare, SimpleQueryMessage};
use crate::{simple_query, Row};
use crate::{Error, Statement};
use bytes::{Buf, BytesMut};
use fallible_iterator::FallibleIterator;
use futures::channel::mpsc;
use futures::{future, pin_mut, ready, StreamExt, TryStream, TryStreamExt};
use futures::{future, pin_mut, ready, StreamExt, TryStreamExt};
use parking_lot::Mutex;
use postgres_protocol::message::backend::Message;
use std::collections::HashMap;
use std::error;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
@ -340,29 +339,26 @@ impl Client {
query::execute(self.inner(), statement, params).await
}
/// Executes a `COPY FROM STDIN` statement, returning the number of rows created.
/// Executes a `COPY FROM STDIN` statement, returning a sink used to write the copy data.
///
/// The data in the provided stream is passed along to the server verbatim; it is the caller's responsibility to
/// ensure it uses the proper format.
/// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is
/// not, the copy will be aborted.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub async fn copy_in<T, S>(
pub async fn copy_in<T, U>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
stream: S,
) -> Result<u64, Error>
) -> Result<CopyInSink<U>, Error>
where
T: ?Sized + ToStatement,
S: TryStream,
S::Ok: Buf + 'static + Send,
S::Error: Into<Box<dyn error::Error + Sync + Send>>,
U: Buf + 'static + Send,
{
let statement = statement.__convert().into_statement(self).await?;
let params = slice_iter(params);
copy_in::copy_in(self.inner(), statement, params, stream).await
copy_in::copy_in(self.inner(), statement, params).await
}
/// Executes a `COPY TO STDOUT` statement, returning a stream of the resulting data.

View File

@ -1,4 +1,4 @@
use crate::client::InnerClient;
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::types::ToSql;
@ -6,11 +6,13 @@ use crate::{query, Error, Statement};
use bytes::buf::BufExt;
use bytes::{Buf, BufMut, BytesMut};
use futures::channel::mpsc;
use futures::{pin_mut, ready, SinkExt, Stream, StreamExt, TryStream, TryStreamExt};
use futures::{ready, Sink, SinkExt, Stream, StreamExt};
use futures::future;
use pin_project_lite::pin_project;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use postgres_protocol::message::frontend::CopyData;
use std::error;
use std::marker::{PhantomPinned, PhantomData};
use std::pin::Pin;
use std::task::{Context, Poll};
@ -61,18 +63,148 @@ impl Stream for CopyInReceiver {
}
}
pub async fn copy_in<'a, I, S>(
enum SinkState {
Active,
Closing,
Reading,
}
pin_project! {
/// A sink for `COPY ... FROM STDIN` query data.
///
/// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is
/// not, the copy will be aborted.
pub struct CopyInSink<T> {
#[pin]
sender: mpsc::Sender<CopyInMessage>,
responses: Responses,
buf: BytesMut,
state: SinkState,
#[pin]
_p: PhantomPinned,
_p2: PhantomData<T>,
}
}
impl<T> CopyInSink<T>
where
T: Buf + 'static + Send,
{
/// A poll-based version of `finish`.
pub fn poll_finish(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64, Error>> {
loop {
match self.state {
SinkState::Active => {
ready!(self.as_mut().poll_flush(cx))?;
let mut this = self.as_mut().project();
ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
this.sender
.start_send(CopyInMessage::Done)
.map_err(|_| Error::closed())?;
*this.state = SinkState::Closing;
}
SinkState::Closing => {
let this = self.as_mut().project();
ready!(this.sender.poll_close(cx)).map_err(|_| Error::closed())?;
*this.state = SinkState::Reading;
}
SinkState::Reading => {
let this = self.as_mut().project();
match ready!(this.responses.poll_next(cx))? {
Message::CommandComplete(body) => {
let rows = body
.tag()
.map_err(Error::parse)?
.rsplit(' ')
.next()
.unwrap()
.parse()
.unwrap_or(0);
return Poll::Ready(Ok(rows));
}
_ => return Poll::Ready(Err(Error::unexpected_message())),
}
}
}
}
}
/// Completes the copy, returning the number of rows inserted.
///
/// The `Sink::close` method is equivalent to `finish`, except that it does not return the
/// number of rows.
pub async fn finish(mut self: Pin<&mut Self>) -> Result<u64, Error> {
future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await
}
}
impl<T> Sink<T> for CopyInSink<T>
where
T: Buf + 'static + Send,
{
type Error = Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.project()
.sender
.poll_ready(cx)
.map_err(|_| Error::closed())
}
fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> {
let this = self.project();
let data: Box<dyn Buf + Send> = if item.remaining() > 4096 {
if this.buf.is_empty() {
Box::new(item)
} else {
Box::new(this.buf.split().freeze().chain(item))
}
} else {
this.buf.put(item);
if this.buf.len() > 4096 {
Box::new(this.buf.split().freeze())
} else {
return Ok(());
}
};
let data = CopyData::new(data).map_err(Error::encode)?;
this.sender
.start_send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
.map_err(|_| Error::closed())
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
let mut this = self.project();
if !this.buf.is_empty() {
ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
let data: Box<dyn Buf + Send> = Box::new(this.buf.split().freeze());
let data = CopyData::new(data).map_err(Error::encode)?;
this.sender
.as_mut()
.start_send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
.map_err(|_| Error::closed())?;
}
this.sender.poll_flush(cx).map_err(|_| Error::closed())
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.poll_finish(cx).map_ok(|_| ())
}
}
pub async fn copy_in<'a, I, T>(
client: &InnerClient,
statement: Statement,
params: I,
stream: S,
) -> Result<u64, Error>
) -> Result<CopyInSink<T>, Error>
where
I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator,
S: TryStream,
S::Ok: Buf + 'static + Send,
S::Error: Into<Box<dyn error::Error + Sync + Send>>,
T: Buf + 'static + Send,
{
let buf = query::encode(client, &statement, params)?;
@ -95,60 +227,12 @@ where
_ => return Err(Error::unexpected_message()),
}
let mut bytes = BytesMut::new();
let stream = stream.into_stream();
pin_mut!(stream);
while let Some(buf) = stream.try_next().await.map_err(Error::copy_in_stream)? {
let data: Box<dyn Buf + Send> = if buf.remaining() > 4096 {
if bytes.is_empty() {
Box::new(buf)
} else {
Box::new(bytes.split().freeze().chain(buf))
}
} else {
bytes.reserve(buf.remaining());
bytes.put(buf);
if bytes.len() > 4096 {
Box::new(bytes.split().freeze())
} else {
continue;
}
};
let data = CopyData::new(data).map_err(Error::encode)?;
sender
.send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
.await
.map_err(|_| Error::closed())?;
}
if !bytes.is_empty() {
let data: Box<dyn Buf + Send> = Box::new(bytes.freeze());
let data = CopyData::new(data).map_err(Error::encode)?;
sender
.send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
.await
.map_err(|_| Error::closed())?;
}
sender
.send(CopyInMessage::Done)
.await
.map_err(|_| Error::closed())?;
match responses.next().await? {
Message::CommandComplete(body) => {
let rows = body
.tag()
.map_err(Error::parse)?
.rsplit(' ')
.next()
.unwrap()
.parse()
.unwrap_or(0);
Ok(rows)
}
_ => Err(Error::unexpected_message()),
}
Ok(CopyInSink {
sender,
responses,
buf: BytesMut::new(),
state: SinkState::Active,
_p: PhantomPinned,
_p2: PhantomData,
})
}

View File

@ -337,7 +337,6 @@ enum Kind {
ToSql(usize),
FromSql(usize),
Column(String),
CopyInStream,
Closed,
Db,
Parse,
@ -376,7 +375,6 @@ impl fmt::Display for Error {
Kind::ToSql(idx) => write!(fmt, "error serializing parameter {}", idx)?,
Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?,
Kind::Column(column) => write!(fmt, "invalid column `{}`", column)?,
Kind::CopyInStream => fmt.write_str("error from a copy_in stream")?,
Kind::Closed => fmt.write_str("connection closed")?,
Kind::Db => fmt.write_str("db error")?,
Kind::Parse => fmt.write_str("error parsing response from server")?,
@ -458,13 +456,6 @@ impl Error {
Error::new(Kind::Column(column), None)
}
pub(crate) fn copy_in_stream<E>(e: E) -> Error
where
E: Into<Box<dyn error::Error + Sync + Send>>,
{
Error::new(Kind::CopyInStream, Some(e.into()))
}
pub(crate) fn tls(e: Box<dyn error::Error + Sync + Send>) -> Error {
Error::new(Kind::Tls, Some(e))
}

View File

@ -105,6 +105,7 @@
pub use crate::client::Client;
pub use crate::config::Config;
pub use crate::connection::Connection;
pub use crate::copy_in::CopyInSink;
pub use crate::copy_out::CopyStream;
use crate::error::DbError;
pub use crate::error::Error;

View File

@ -9,12 +9,12 @@ use crate::types::{ToSql, Type};
#[cfg(feature = "runtime")]
use crate::Socket;
use crate::{
bind, query, slice_iter, Client, Error, Portal, Row, SimpleQueryMessage, Statement, ToStatement,
bind, query, slice_iter, Client, CopyInSink, Error, Portal, Row, SimpleQueryMessage, Statement,
ToStatement,
};
use bytes::Buf;
use futures::{TryStream, TryStreamExt};
use futures::{TryStreamExt};
use postgres_protocol::message::frontend;
use std::error;
use tokio::io::{AsyncRead, AsyncWrite};
/// A representation of a PostgreSQL database transaction.
@ -209,19 +209,16 @@ impl<'a> Transaction<'a> {
}
/// Like `Client::copy_in`.
pub async fn copy_in<T, S>(
pub async fn copy_in<T, U>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
stream: S,
) -> Result<u64, Error>
) -> Result<CopyInSink<U>, Error>
where
T: ?Sized + ToStatement,
S: TryStream,
S::Ok: Buf + 'static + Send,
S::Error: Into<Box<dyn error::Error + Sync + Send>>,
U: Buf + 'static + Send,
{
self.client.copy_in(statement, params, stream).await
self.client.copy_in(statement, params).await
}
/// Like `Client::copy_out`.

View File

@ -2,8 +2,7 @@
use bytes::{Bytes, BytesMut};
use futures::channel::mpsc;
use futures::{future, stream, StreamExt};
use futures::{join, try_join, FutureExt, TryStreamExt};
use futures::{future, stream, StreamExt, SinkExt, pin_mut, join, try_join, FutureExt, TryStreamExt};
use std::fmt::Write;
use std::time::Duration;
use tokio::net::TcpStream;
@ -409,23 +408,21 @@ async fn copy_in() {
.await
.unwrap();
let stmt = client.prepare("COPY foo FROM STDIN").await.unwrap();
let stream = stream::iter(
let mut stream = stream::iter(
vec![
Bytes::from_static(b"1\tjim\n"),
Bytes::from_static(b"2\tjoe\n"),
]
.into_iter()
.map(Ok::<_, String>),
.map(Ok::<_, Error>),
);
let rows = client.copy_in(&stmt, &[], stream).await.unwrap();
let sink = client.copy_in("COPY foo FROM STDIN", &[]).await.unwrap();
pin_mut!(sink);
sink.send_all(&mut stream).await.unwrap();
let rows = sink.finish().await.unwrap();
assert_eq!(rows, 2);
let stmt = client
.prepare("SELECT id, name FROM foo ORDER BY id")
.await
.unwrap();
let rows = client.query(&stmt, &[]).await.unwrap();
let rows = client.query("SELECT id, name FROM foo ORDER BY id", &[]).await.unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get::<_, i32>(0), 1);
@ -448,8 +445,6 @@ async fn copy_in_large() {
.await
.unwrap();
let stmt = client.prepare("COPY foo FROM STDIN").await.unwrap();
let a = Bytes::from_static(b"0\tname0\n");
let mut b = BytesMut::new();
for i in 1..5_000 {
@ -459,13 +454,16 @@ async fn copy_in_large() {
for i in 5_000..10_000 {
writeln!(c, "{0}\tname{0}", i).unwrap();
}
let stream = stream::iter(
let mut stream = stream::iter(
vec![a, b.freeze(), c.freeze()]
.into_iter()
.map(Ok::<_, String>),
.map(Ok::<_, Error>),
);
let rows = client.copy_in(&stmt, &[], stream).await.unwrap();
let sink = client.copy_in("COPY foo FROM STDIN", &[]).await.unwrap();
pin_mut!(sink);
sink.send_all(&mut stream).await.unwrap();
let rows = sink.finish().await.unwrap();
assert_eq!(rows, 10_000);
}
@ -483,16 +481,13 @@ async fn copy_in_error() {
.await
.unwrap();
let stmt = client.prepare("COPY foo FROM STDIN").await.unwrap();
let stream = stream::iter(vec![Ok(Bytes::from_static(b"1\tjim\n")), Err("asdf")]);
let error = client.copy_in(&stmt, &[], stream).await.unwrap_err();
assert!(error.to_string().contains("asdf"));
{
let sink = client.copy_in("COPY foo FROM STDIN", &[]).await.unwrap();
pin_mut!(sink);
sink.send(Bytes::from_static(b"1\tsteven")).await.unwrap();
}
let stmt = client
.prepare("SELECT id, name FROM foo ORDER BY id")
.await
.unwrap();
let rows = client.query(&stmt, &[]).await.unwrap();
let rows = client.query("SELECT id, name FROM foo ORDER BY id", &[]).await.unwrap();
assert_eq!(rows.len(), 0);
}