Change the copy_in interface
Rather than taking in a Stream and advancing it internally, return a Sink that can be advanced by the calling code. This significantly simplifies encoding logic for things like tokio-postgres-binary-copy. Similarly, the blocking interface returns a Writer. Closes #489
This commit is contained in:
parent
a5428e6a03
commit
e5e03b0064
@ -1,19 +1,15 @@
|
||||
use crate::iter::Iter;
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::Config;
|
||||
use crate::{CopyInWriter, CopyOutReader, Statement, ToStatement, Transaction};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures::executor;
|
||||
use std::io::{BufRead, Read};
|
||||
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
|
||||
use tokio_postgres::types::{ToSql, Type};
|
||||
#[cfg(feature = "runtime")]
|
||||
use tokio_postgres::Socket;
|
||||
use tokio_postgres::{Error, Row, SimpleQueryMessage};
|
||||
|
||||
use crate::copy_in_stream::CopyInStream;
|
||||
use crate::copy_out_reader::CopyOutReader;
|
||||
use crate::iter::Iter;
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::Config;
|
||||
use crate::{Statement, ToStatement, Transaction};
|
||||
|
||||
/// A synchronous PostgreSQL client.
|
||||
///
|
||||
/// This is a lightweight wrapper over the asynchronous tokio_postgres `Client`.
|
||||
@ -264,29 +260,33 @@ impl Client {
|
||||
/// The `query` argument can either be a `Statement`, or a raw query string. The data in the provided reader is
|
||||
/// passed along to the server verbatim; it is the caller's responsibility to ensure it uses the proper format.
|
||||
///
|
||||
/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```no_run
|
||||
/// use postgres::{Client, NoTls};
|
||||
/// use std::io::Write;
|
||||
///
|
||||
/// # fn main() -> Result<(), postgres::Error> {
|
||||
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
/// let mut client = Client::connect("host=localhost user=postgres", NoTls)?;
|
||||
///
|
||||
/// client.copy_in("COPY people FROM stdin", &[], &mut "1\tjohn\n2\tjane\n".as_bytes())?;
|
||||
/// let mut writer = client.copy_in("COPY people FROM stdin", &[])?;
|
||||
/// writer.write_all(b"1\tjohn\n2\tjane\n")?;
|
||||
/// writer.finish()?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn copy_in<T, R>(
|
||||
pub fn copy_in<T>(
|
||||
&mut self,
|
||||
query: &T,
|
||||
params: &[&(dyn ToSql + Sync)],
|
||||
reader: R,
|
||||
) -> Result<u64, Error>
|
||||
) -> Result<CopyInWriter<'_>, Error>
|
||||
where
|
||||
T: ?Sized + ToStatement,
|
||||
R: Read + Unpin,
|
||||
{
|
||||
executor::block_on(self.0.copy_in(query, params, CopyInStream(reader)))
|
||||
let sink = executor::block_on(self.0.copy_in(query, params))?;
|
||||
Ok(CopyInWriter::new(sink))
|
||||
}
|
||||
|
||||
/// Executes a `COPY TO STDOUT` statement, returning a reader of the resulting data.
|
||||
@ -312,7 +312,7 @@ impl Client {
|
||||
&mut self,
|
||||
query: &T,
|
||||
params: &[&(dyn ToSql + Sync)],
|
||||
) -> Result<impl BufRead, Error>
|
||||
) -> Result<CopyOutReader<'_>, Error>
|
||||
where
|
||||
T: ?Sized + ToStatement,
|
||||
{
|
||||
|
@ -1,24 +0,0 @@
|
||||
use futures::Stream;
|
||||
use std::io::{self, Cursor, Read};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
pub struct CopyInStream<R>(pub R);
|
||||
|
||||
impl<R> Stream for CopyInStream<R>
|
||||
where
|
||||
R: Read + Unpin,
|
||||
{
|
||||
type Item = io::Result<Cursor<Vec<u8>>>;
|
||||
|
||||
fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
_: &mut Context<'_>,
|
||||
) -> Poll<Option<io::Result<Cursor<Vec<u8>>>>> {
|
||||
let mut buf = vec![];
|
||||
match self.0.by_ref().take(4096).read_to_end(&mut buf)? {
|
||||
0 => Poll::Ready(None),
|
||||
_ => Poll::Ready(Some(Ok(Cursor::new(buf)))),
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,63 @@
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures::{executor, SinkExt};
|
||||
use std::io;
|
||||
use std::io::Write;
|
||||
use std::marker::PhantomData;
|
||||
use std::pin::Pin;
|
||||
use tokio_postgres::{CopyInSink, Error};
|
||||
|
||||
/// The writer returned by the `copy_in` method.
|
||||
///
|
||||
/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
|
||||
pub struct CopyInWriter<'a> {
|
||||
sink: Pin<Box<CopyInSink<Bytes>>>,
|
||||
buf: BytesMut,
|
||||
_p: PhantomData<&'a mut ()>,
|
||||
}
|
||||
|
||||
// no-op impl to extend borrow until drop
|
||||
impl Drop for CopyInWriter<'_> {
|
||||
fn drop(&mut self) {}
|
||||
}
|
||||
|
||||
impl<'a> CopyInWriter<'a> {
|
||||
pub(crate) fn new(sink: CopyInSink<Bytes>) -> CopyInWriter<'a> {
|
||||
CopyInWriter {
|
||||
sink: Box::pin(sink),
|
||||
buf: BytesMut::new(),
|
||||
_p: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Completes the copy, returning the number of rows written.
|
||||
///
|
||||
/// If this is not called, the copy will be aborted.
|
||||
pub fn finish(mut self) -> Result<u64, Error> {
|
||||
self.flush_inner()?;
|
||||
executor::block_on(self.sink.as_mut().finish())
|
||||
}
|
||||
|
||||
fn flush_inner(&mut self) -> Result<(), Error> {
|
||||
if self.buf.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
executor::block_on(self.sink.as_mut().send(self.buf.split().freeze()))
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for CopyInWriter<'_> {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
if self.buf.len() > 4096 {
|
||||
self.flush()?;
|
||||
}
|
||||
|
||||
self.buf.extend_from_slice(buf);
|
||||
Ok(buf.len())
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
self.flush_inner()
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
}
|
||||
}
|
@ -1,33 +1,24 @@
|
||||
use bytes::{Buf, Bytes};
|
||||
use futures::{executor, Stream};
|
||||
use futures::executor;
|
||||
use std::io::{self, BufRead, Cursor, Read};
|
||||
use std::marker::PhantomData;
|
||||
use std::pin::Pin;
|
||||
use tokio_postgres::Error;
|
||||
use tokio_postgres::{CopyStream, Error};
|
||||
|
||||
/// The reader returned by the `copy_out` method.
|
||||
pub struct CopyOutReader<'a, S>
|
||||
where
|
||||
S: Stream,
|
||||
{
|
||||
it: executor::BlockingStream<Pin<Box<S>>>,
|
||||
pub struct CopyOutReader<'a> {
|
||||
it: executor::BlockingStream<Pin<Box<CopyStream>>>,
|
||||
cur: Cursor<Bytes>,
|
||||
_p: PhantomData<&'a mut ()>,
|
||||
}
|
||||
|
||||
// no-op impl to extend borrow until drop
|
||||
impl<'a, S> Drop for CopyOutReader<'a, S>
|
||||
where
|
||||
S: Stream,
|
||||
{
|
||||
impl Drop for CopyOutReader<'_> {
|
||||
fn drop(&mut self) {}
|
||||
}
|
||||
|
||||
impl<'a, S> CopyOutReader<'a, S>
|
||||
where
|
||||
S: Stream<Item = Result<Bytes, Error>>,
|
||||
{
|
||||
pub(crate) fn new(stream: S) -> Result<CopyOutReader<'a, S>, Error> {
|
||||
impl<'a> CopyOutReader<'a> {
|
||||
pub(crate) fn new(stream: CopyStream) -> Result<CopyOutReader<'a>, Error> {
|
||||
let mut it = executor::block_on_stream(Box::pin(stream));
|
||||
let cur = match it.next() {
|
||||
Some(Ok(cur)) => cur,
|
||||
@ -43,10 +34,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, S> Read for CopyOutReader<'a, S>
|
||||
where
|
||||
S: Stream<Item = Result<Bytes, Error>>,
|
||||
{
|
||||
impl Read for CopyOutReader<'_> {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
let b = self.fill_buf()?;
|
||||
let len = usize::min(buf.len(), b.len());
|
||||
@ -56,10 +44,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, S> BufRead for CopyOutReader<'a, S>
|
||||
where
|
||||
S: Stream<Item = Result<Bytes, Error>>,
|
||||
{
|
||||
impl BufRead for CopyOutReader<'_> {
|
||||
fn fill_buf(&mut self) -> io::Result<&[u8]> {
|
||||
if self.cur.remaining() == 0 {
|
||||
match self.it.next() {
|
||||
|
@ -69,6 +69,8 @@ pub use tokio_postgres::{
|
||||
pub use crate::client::*;
|
||||
#[cfg(feature = "runtime")]
|
||||
pub use crate::config::Config;
|
||||
pub use crate::copy_in_writer::CopyInWriter;
|
||||
pub use crate::copy_out_reader::CopyOutReader;
|
||||
#[doc(no_inline)]
|
||||
pub use crate::error::Error;
|
||||
#[doc(no_inline)]
|
||||
@ -80,7 +82,7 @@ pub use crate::transaction::*;
|
||||
mod client;
|
||||
#[cfg(feature = "runtime")]
|
||||
pub mod config;
|
||||
mod copy_in_stream;
|
||||
mod copy_in_writer;
|
||||
mod copy_out_reader;
|
||||
mod iter;
|
||||
mod transaction;
|
||||
|
@ -1,4 +1,4 @@
|
||||
use std::io::Read;
|
||||
use std::io::{Read, Write};
|
||||
use tokio_postgres::types::Type;
|
||||
use tokio_postgres::NoTls;
|
||||
|
||||
@ -154,13 +154,9 @@ fn copy_in() {
|
||||
.simple_query("CREATE TEMPORARY TABLE foo (id INT, name TEXT)")
|
||||
.unwrap();
|
||||
|
||||
client
|
||||
.copy_in(
|
||||
"COPY foo FROM stdin",
|
||||
&[],
|
||||
&mut &b"1\tsteven\n2\ttimothy"[..],
|
||||
)
|
||||
.unwrap();
|
||||
let mut writer = client.copy_in("COPY foo FROM stdin", &[]).unwrap();
|
||||
writer.write_all(b"1\tsteven\n2\ttimothy").unwrap();
|
||||
writer.finish().unwrap();
|
||||
|
||||
let rows = client
|
||||
.query("SELECT id, name FROM foo ORDER BY id", &[])
|
||||
@ -173,6 +169,25 @@ fn copy_in() {
|
||||
assert_eq!(rows[1].get::<_, &str>(1), "timothy");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn copy_in_abort() {
|
||||
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
|
||||
|
||||
client
|
||||
.simple_query("CREATE TEMPORARY TABLE foo (id INT, name TEXT)")
|
||||
.unwrap();
|
||||
|
||||
let mut writer = client.copy_in("COPY foo FROM stdin", &[]).unwrap();
|
||||
writer.write_all(b"1\tsteven\n2\ttimothy").unwrap();
|
||||
drop(writer);
|
||||
|
||||
let rows = client
|
||||
.query("SELECT id, name FROM foo ORDER BY id", &[])
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(rows.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn copy_out() {
|
||||
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
|
||||
|
@ -1,14 +1,10 @@
|
||||
use crate::iter::Iter;
|
||||
use crate::{CopyInWriter, CopyOutReader, Portal, Statement, ToStatement};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures::executor;
|
||||
use std::io::{BufRead, Read};
|
||||
use tokio_postgres::types::{ToSql, Type};
|
||||
use tokio_postgres::{Error, Row, SimpleQueryMessage};
|
||||
|
||||
use crate::copy_in_stream::CopyInStream;
|
||||
use crate::copy_out_reader::CopyOutReader;
|
||||
use crate::iter::Iter;
|
||||
use crate::{Portal, Statement, ToStatement};
|
||||
|
||||
/// A representation of a PostgreSQL database transaction.
|
||||
///
|
||||
/// Transactions will implicitly roll back by default when dropped. Use the `commit` method to commit the changes made
|
||||
@ -117,17 +113,16 @@ impl<'a> Transaction<'a> {
|
||||
}
|
||||
|
||||
/// Like `Client::copy_in`.
|
||||
pub fn copy_in<T, R>(
|
||||
pub fn copy_in<T>(
|
||||
&mut self,
|
||||
query: &T,
|
||||
params: &[&(dyn ToSql + Sync)],
|
||||
reader: R,
|
||||
) -> Result<u64, Error>
|
||||
) -> Result<CopyInWriter<'_>, Error>
|
||||
where
|
||||
T: ?Sized + ToStatement,
|
||||
R: Read + Unpin,
|
||||
{
|
||||
executor::block_on(self.0.copy_in(query, params, CopyInStream(reader)))
|
||||
let sink = executor::block_on(self.0.copy_in(query, params))?;
|
||||
Ok(CopyInWriter::new(sink))
|
||||
}
|
||||
|
||||
/// Like `Client::copy_out`.
|
||||
@ -135,7 +130,7 @@ impl<'a> Transaction<'a> {
|
||||
&mut self,
|
||||
query: &T,
|
||||
params: &[&(dyn ToSql + Sync)],
|
||||
) -> Result<impl BufRead, Error>
|
||||
) -> Result<CopyOutReader<'_>, Error>
|
||||
where
|
||||
T: ?Sized + ToStatement,
|
||||
{
|
||||
|
@ -8,7 +8,6 @@ edition = "2018"
|
||||
byteorder = "1.0"
|
||||
bytes = "0.5"
|
||||
futures = "0.3"
|
||||
parking_lot = "0.10"
|
||||
pin-project-lite = "0.1"
|
||||
tokio-postgres = { version = "=0.5.0-alpha.2", default-features = false, path = "../tokio-postgres" }
|
||||
|
||||
|
@ -1,145 +1,95 @@
|
||||
use bytes::{BufMut, Bytes, BytesMut, Buf};
|
||||
use futures::{future, ready, Stream};
|
||||
use parking_lot::Mutex;
|
||||
use futures::{ready, Stream, SinkExt};
|
||||
use pin_project_lite::pin_project;
|
||||
use std::convert::TryFrom;
|
||||
use std::error::Error;
|
||||
use std::future::Future;
|
||||
use std::ops::Range;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio_postgres::types::{IsNull, ToSql, Type, FromSql, WrongType};
|
||||
use tokio_postgres::CopyStream;
|
||||
use tokio_postgres::{CopyStream, CopyInSink};
|
||||
use std::io::Cursor;
|
||||
use byteorder::{ByteOrder, BigEndian};
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
const BLOCK_SIZE: usize = 4096;
|
||||
const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0";
|
||||
const HEADER_LEN: usize = MAGIC.len() + 4 + 4;
|
||||
|
||||
pin_project! {
|
||||
pub struct BinaryCopyInStream<F> {
|
||||
pub struct BinaryCopyInWriter {
|
||||
#[pin]
|
||||
future: F,
|
||||
buf: Arc<Mutex<BytesMut>>,
|
||||
done: bool,
|
||||
sink: CopyInSink<Bytes>,
|
||||
types: Vec<Type>,
|
||||
buf: BytesMut,
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> BinaryCopyInStream<F>
|
||||
where
|
||||
F: Future<Output = Result<(), Box<dyn Error + Sync + Send>>>,
|
||||
{
|
||||
pub fn new<M>(types: &[Type], write_values: M) -> BinaryCopyInStream<F>
|
||||
where
|
||||
M: FnOnce(BinaryCopyInWriter) -> F,
|
||||
{
|
||||
impl BinaryCopyInWriter {
|
||||
pub fn new(sink: CopyInSink<Bytes>, types: &[Type]) -> BinaryCopyInWriter {
|
||||
let mut buf = BytesMut::new();
|
||||
buf.reserve(HEADER_LEN);
|
||||
buf.put_slice(MAGIC); // magic
|
||||
buf.put_i32(0); // flags
|
||||
buf.put_i32(0); // header extension
|
||||
|
||||
let buf = Arc::new(Mutex::new(buf));
|
||||
let writer = BinaryCopyInWriter {
|
||||
buf: buf.clone(),
|
||||
BinaryCopyInWriter {
|
||||
sink,
|
||||
types: types.to_vec(),
|
||||
};
|
||||
|
||||
BinaryCopyInStream {
|
||||
future: write_values(writer),
|
||||
buf,
|
||||
done: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Stream for BinaryCopyInStream<F>
|
||||
where
|
||||
F: Future<Output = Result<(), Box<dyn Error + Sync + Send>>>,
|
||||
{
|
||||
type Item = Result<Bytes, Box<dyn Error + Sync + Send>>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.project();
|
||||
|
||||
if *this.done {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
|
||||
*this.done = this.future.poll(cx)?.is_ready();
|
||||
|
||||
let mut buf = this.buf.lock();
|
||||
if *this.done {
|
||||
buf.reserve(2);
|
||||
buf.put_i16(-1);
|
||||
Poll::Ready(Some(Ok(buf.split().freeze())))
|
||||
} else if buf.len() > BLOCK_SIZE {
|
||||
Poll::Ready(Some(Ok(buf.split().freeze())))
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME this should really just take a reference to the buffer, but that requires HKT :(
|
||||
pub struct BinaryCopyInWriter {
|
||||
buf: Arc<Mutex<BytesMut>>,
|
||||
types: Vec<Type>,
|
||||
}
|
||||
|
||||
impl BinaryCopyInWriter {
|
||||
pub async fn write(
|
||||
&mut self,
|
||||
self: Pin<&mut Self>,
|
||||
values: &[&(dyn ToSql + Send)],
|
||||
) -> Result<(), Box<dyn Error + Sync + Send>> {
|
||||
self.write_raw(values.iter().cloned()).await
|
||||
}
|
||||
|
||||
pub async fn write_raw<'a, I>(&mut self, values: I) -> Result<(), Box<dyn Error + Sync + Send>>
|
||||
pub async fn write_raw<'a, I>(self: Pin<&mut Self>, values: I) -> Result<(), Box<dyn Error + Sync + Send>>
|
||||
where
|
||||
I: IntoIterator<Item = &'a (dyn ToSql + Send)>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
{
|
||||
let mut this = self.project();
|
||||
|
||||
let values = values.into_iter();
|
||||
assert!(
|
||||
values.len() == self.types.len(),
|
||||
values.len() == this.types.len(),
|
||||
"expected {} values but got {}",
|
||||
self.types.len(),
|
||||
this.types.len(),
|
||||
values.len(),
|
||||
);
|
||||
|
||||
future::poll_fn(|_| {
|
||||
if self.buf.lock().len() > BLOCK_SIZE {
|
||||
Poll::Pending
|
||||
} else {
|
||||
Poll::Ready(())
|
||||
}
|
||||
})
|
||||
.await;
|
||||
this.buf.put_i16(this.types.len() as i16);
|
||||
|
||||
let mut buf = self.buf.lock();
|
||||
|
||||
buf.reserve(2);
|
||||
buf.put_u16(self.types.len() as u16);
|
||||
|
||||
for (value, type_) in values.zip(&self.types) {
|
||||
let idx = buf.len();
|
||||
buf.reserve(4);
|
||||
buf.put_i32(0);
|
||||
let len = match value.to_sql_checked(type_, &mut buf)? {
|
||||
for (value, type_) in values.zip(this.types) {
|
||||
let idx = this.buf.len();
|
||||
this.buf.put_i32(0);
|
||||
let len = match value.to_sql_checked(type_, this.buf)? {
|
||||
IsNull::Yes => -1,
|
||||
IsNull::No => i32::try_from(buf.len() - idx - 4)?,
|
||||
IsNull::No => i32::try_from(this.buf.len() - idx - 4)?,
|
||||
};
|
||||
BigEndian::write_i32(&mut buf[idx..], len);
|
||||
BigEndian::write_i32(&mut this.buf[idx..], len);
|
||||
}
|
||||
|
||||
if this.buf.len() > 4096 {
|
||||
this.sink.send(this.buf.split().freeze()).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn finish(self: Pin<&mut Self>) -> Result<u64, tokio_postgres::Error> {
|
||||
let mut this = self.project();
|
||||
|
||||
this.buf.put_i16(-1);
|
||||
this.sink.send(this.buf.split().freeze()).await?;
|
||||
this.sink.finish().await
|
||||
}
|
||||
}
|
||||
|
||||
struct Header {
|
||||
|
@ -1,7 +1,7 @@
|
||||
use crate::{BinaryCopyInStream, BinaryCopyOutStream};
|
||||
use crate::{BinaryCopyInWriter, BinaryCopyOutStream};
|
||||
use tokio_postgres::types::Type;
|
||||
use tokio_postgres::{Client, NoTls};
|
||||
use futures::TryStreamExt;
|
||||
use futures::{TryStreamExt, pin_mut};
|
||||
|
||||
async fn connect() -> Client {
|
||||
let (client, connection) =
|
||||
@ -23,19 +23,12 @@ async fn write_basic() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let stream = BinaryCopyInStream::new(&[Type::INT4, Type::TEXT], |mut w| {
|
||||
async move {
|
||||
w.write(&[&1i32, &"foobar"]).await?;
|
||||
w.write(&[&2i32, &None::<&str>]).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
});
|
||||
|
||||
client
|
||||
.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream)
|
||||
.await
|
||||
.unwrap();
|
||||
let sink = client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[]).await.unwrap();
|
||||
let writer = BinaryCopyInWriter::new(sink, &[Type::INT4, Type::TEXT]);
|
||||
pin_mut!(writer);
|
||||
writer.as_mut().write(&[&1i32, &"foobar"]).await.unwrap();
|
||||
writer.as_mut().write(&[&2i32, &None::<&str>]).await.unwrap();
|
||||
writer.finish().await.unwrap();
|
||||
|
||||
let rows = client
|
||||
.query("SELECT id, bar FROM foo ORDER BY id", &[])
|
||||
@ -57,20 +50,15 @@ async fn write_many_rows() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let stream = BinaryCopyInStream::new(&[Type::INT4, Type::TEXT], |mut w| {
|
||||
async move {
|
||||
let sink = client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[]).await.unwrap();
|
||||
let writer = BinaryCopyInWriter::new(sink, &[Type::INT4, Type::TEXT]);
|
||||
pin_mut!(writer);
|
||||
|
||||
for i in 0..10_000i32 {
|
||||
w.write(&[&i, &format!("the value for {}", i)]).await?;
|
||||
writer.as_mut().write(&[&i, &format!("the value for {}", i)]).await.unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
});
|
||||
|
||||
client
|
||||
.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream)
|
||||
.await
|
||||
.unwrap();
|
||||
writer.finish().await.unwrap();
|
||||
|
||||
let rows = client
|
||||
.query("SELECT id, bar FROM foo ORDER BY id", &[])
|
||||
@ -91,20 +79,15 @@ async fn write_big_rows() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let stream = BinaryCopyInStream::new(&[Type::INT4, Type::BYTEA], |mut w| {
|
||||
async move {
|
||||
let sink = client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[]).await.unwrap();
|
||||
let writer = BinaryCopyInWriter::new(sink, &[Type::INT4, Type::BYTEA]);
|
||||
pin_mut!(writer);
|
||||
|
||||
for i in 0..2i32 {
|
||||
w.write(&[&i, &vec![i as u8; 128 * 1024]]).await?;
|
||||
writer.as_mut().write(&[&i, &vec![i as u8; 128 * 1024]]).await.unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
});
|
||||
|
||||
client
|
||||
.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream)
|
||||
.await
|
||||
.unwrap();
|
||||
writer.finish().await.unwrap();
|
||||
|
||||
let rows = client
|
||||
.query("SELECT id, bar FROM foo ORDER BY id", &[])
|
||||
|
@ -14,18 +14,17 @@ use crate::to_statement::ToStatement;
|
||||
use crate::types::{Oid, ToSql, Type};
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::Socket;
|
||||
use crate::{cancel_query_raw, copy_in, copy_out, query, Transaction};
|
||||
use crate::{cancel_query_raw, copy_in, copy_out, query, CopyInSink, Transaction};
|
||||
use crate::{prepare, SimpleQueryMessage};
|
||||
use crate::{simple_query, Row};
|
||||
use crate::{Error, Statement};
|
||||
use bytes::{Buf, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures::channel::mpsc;
|
||||
use futures::{future, pin_mut, ready, StreamExt, TryStream, TryStreamExt};
|
||||
use futures::{future, pin_mut, ready, StreamExt, TryStreamExt};
|
||||
use parking_lot::Mutex;
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use std::collections::HashMap;
|
||||
use std::error;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
@ -340,29 +339,26 @@ impl Client {
|
||||
query::execute(self.inner(), statement, params).await
|
||||
}
|
||||
|
||||
/// Executes a `COPY FROM STDIN` statement, returning the number of rows created.
|
||||
/// Executes a `COPY FROM STDIN` statement, returning a sink used to write the copy data.
|
||||
///
|
||||
/// The data in the provided stream is passed along to the server verbatim; it is the caller's responsibility to
|
||||
/// ensure it uses the proper format.
|
||||
/// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is
|
||||
/// not, the copy will be aborted.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the number of parameters provided does not match the number expected.
|
||||
pub async fn copy_in<T, S>(
|
||||
pub async fn copy_in<T, U>(
|
||||
&self,
|
||||
statement: &T,
|
||||
params: &[&(dyn ToSql + Sync)],
|
||||
stream: S,
|
||||
) -> Result<u64, Error>
|
||||
) -> Result<CopyInSink<U>, Error>
|
||||
where
|
||||
T: ?Sized + ToStatement,
|
||||
S: TryStream,
|
||||
S::Ok: Buf + 'static + Send,
|
||||
S::Error: Into<Box<dyn error::Error + Sync + Send>>,
|
||||
U: Buf + 'static + Send,
|
||||
{
|
||||
let statement = statement.__convert().into_statement(self).await?;
|
||||
let params = slice_iter(params);
|
||||
copy_in::copy_in(self.inner(), statement, params, stream).await
|
||||
copy_in::copy_in(self.inner(), statement, params).await
|
||||
}
|
||||
|
||||
/// Executes a `COPY TO STDOUT` statement, returning a stream of the resulting data.
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::client::InnerClient;
|
||||
use crate::client::{InnerClient, Responses};
|
||||
use crate::codec::FrontendMessage;
|
||||
use crate::connection::RequestMessages;
|
||||
use crate::types::ToSql;
|
||||
@ -6,11 +6,13 @@ use crate::{query, Error, Statement};
|
||||
use bytes::buf::BufExt;
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use futures::channel::mpsc;
|
||||
use futures::{pin_mut, ready, SinkExt, Stream, StreamExt, TryStream, TryStreamExt};
|
||||
use futures::{ready, Sink, SinkExt, Stream, StreamExt};
|
||||
use futures::future;
|
||||
use pin_project_lite::pin_project;
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
use postgres_protocol::message::frontend::CopyData;
|
||||
use std::error;
|
||||
use std::marker::{PhantomPinned, PhantomData};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
@ -61,18 +63,148 @@ impl Stream for CopyInReceiver {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn copy_in<'a, I, S>(
|
||||
enum SinkState {
|
||||
Active,
|
||||
Closing,
|
||||
Reading,
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// A sink for `COPY ... FROM STDIN` query data.
|
||||
///
|
||||
/// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is
|
||||
/// not, the copy will be aborted.
|
||||
pub struct CopyInSink<T> {
|
||||
#[pin]
|
||||
sender: mpsc::Sender<CopyInMessage>,
|
||||
responses: Responses,
|
||||
buf: BytesMut,
|
||||
state: SinkState,
|
||||
#[pin]
|
||||
_p: PhantomPinned,
|
||||
_p2: PhantomData<T>,
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> CopyInSink<T>
|
||||
where
|
||||
T: Buf + 'static + Send,
|
||||
{
|
||||
/// A poll-based version of `finish`.
|
||||
pub fn poll_finish(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<u64, Error>> {
|
||||
loop {
|
||||
match self.state {
|
||||
SinkState::Active => {
|
||||
ready!(self.as_mut().poll_flush(cx))?;
|
||||
let mut this = self.as_mut().project();
|
||||
ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
|
||||
this.sender
|
||||
.start_send(CopyInMessage::Done)
|
||||
.map_err(|_| Error::closed())?;
|
||||
*this.state = SinkState::Closing;
|
||||
}
|
||||
SinkState::Closing => {
|
||||
let this = self.as_mut().project();
|
||||
ready!(this.sender.poll_close(cx)).map_err(|_| Error::closed())?;
|
||||
*this.state = SinkState::Reading;
|
||||
}
|
||||
SinkState::Reading => {
|
||||
let this = self.as_mut().project();
|
||||
match ready!(this.responses.poll_next(cx))? {
|
||||
Message::CommandComplete(body) => {
|
||||
let rows = body
|
||||
.tag()
|
||||
.map_err(Error::parse)?
|
||||
.rsplit(' ')
|
||||
.next()
|
||||
.unwrap()
|
||||
.parse()
|
||||
.unwrap_or(0);
|
||||
return Poll::Ready(Ok(rows));
|
||||
}
|
||||
_ => return Poll::Ready(Err(Error::unexpected_message())),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Completes the copy, returning the number of rows inserted.
|
||||
///
|
||||
/// The `Sink::close` method is equivalent to `finish`, except that it does not return the
|
||||
/// number of rows.
|
||||
pub async fn finish(mut self: Pin<&mut Self>) -> Result<u64, Error> {
|
||||
future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Sink<T> for CopyInSink<T>
|
||||
where
|
||||
T: Buf + 'static + Send,
|
||||
{
|
||||
type Error = Error;
|
||||
|
||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
self.project()
|
||||
.sender
|
||||
.poll_ready(cx)
|
||||
.map_err(|_| Error::closed())
|
||||
}
|
||||
|
||||
fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> {
|
||||
let this = self.project();
|
||||
|
||||
let data: Box<dyn Buf + Send> = if item.remaining() > 4096 {
|
||||
if this.buf.is_empty() {
|
||||
Box::new(item)
|
||||
} else {
|
||||
Box::new(this.buf.split().freeze().chain(item))
|
||||
}
|
||||
} else {
|
||||
this.buf.put(item);
|
||||
if this.buf.len() > 4096 {
|
||||
Box::new(this.buf.split().freeze())
|
||||
} else {
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
let data = CopyData::new(data).map_err(Error::encode)?;
|
||||
this.sender
|
||||
.start_send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
|
||||
.map_err(|_| Error::closed())
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
let mut this = self.project();
|
||||
|
||||
if !this.buf.is_empty() {
|
||||
ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
|
||||
let data: Box<dyn Buf + Send> = Box::new(this.buf.split().freeze());
|
||||
let data = CopyData::new(data).map_err(Error::encode)?;
|
||||
this.sender
|
||||
.as_mut()
|
||||
.start_send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
|
||||
.map_err(|_| Error::closed())?;
|
||||
}
|
||||
|
||||
this.sender.poll_flush(cx).map_err(|_| Error::closed())
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
self.poll_finish(cx).map_ok(|_| ())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn copy_in<'a, I, T>(
|
||||
client: &InnerClient,
|
||||
statement: Statement,
|
||||
params: I,
|
||||
stream: S,
|
||||
) -> Result<u64, Error>
|
||||
) -> Result<CopyInSink<T>, Error>
|
||||
where
|
||||
I: IntoIterator<Item = &'a dyn ToSql>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
S: TryStream,
|
||||
S::Ok: Buf + 'static + Send,
|
||||
S::Error: Into<Box<dyn error::Error + Sync + Send>>,
|
||||
T: Buf + 'static + Send,
|
||||
{
|
||||
let buf = query::encode(client, &statement, params)?;
|
||||
|
||||
@ -95,60 +227,12 @@ where
|
||||
_ => return Err(Error::unexpected_message()),
|
||||
}
|
||||
|
||||
let mut bytes = BytesMut::new();
|
||||
let stream = stream.into_stream();
|
||||
pin_mut!(stream);
|
||||
|
||||
while let Some(buf) = stream.try_next().await.map_err(Error::copy_in_stream)? {
|
||||
let data: Box<dyn Buf + Send> = if buf.remaining() > 4096 {
|
||||
if bytes.is_empty() {
|
||||
Box::new(buf)
|
||||
} else {
|
||||
Box::new(bytes.split().freeze().chain(buf))
|
||||
}
|
||||
} else {
|
||||
bytes.reserve(buf.remaining());
|
||||
bytes.put(buf);
|
||||
if bytes.len() > 4096 {
|
||||
Box::new(bytes.split().freeze())
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let data = CopyData::new(data).map_err(Error::encode)?;
|
||||
sender
|
||||
.send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
|
||||
.await
|
||||
.map_err(|_| Error::closed())?;
|
||||
}
|
||||
|
||||
if !bytes.is_empty() {
|
||||
let data: Box<dyn Buf + Send> = Box::new(bytes.freeze());
|
||||
let data = CopyData::new(data).map_err(Error::encode)?;
|
||||
sender
|
||||
.send(CopyInMessage::Message(FrontendMessage::CopyData(data)))
|
||||
.await
|
||||
.map_err(|_| Error::closed())?;
|
||||
}
|
||||
|
||||
sender
|
||||
.send(CopyInMessage::Done)
|
||||
.await
|
||||
.map_err(|_| Error::closed())?;
|
||||
|
||||
match responses.next().await? {
|
||||
Message::CommandComplete(body) => {
|
||||
let rows = body
|
||||
.tag()
|
||||
.map_err(Error::parse)?
|
||||
.rsplit(' ')
|
||||
.next()
|
||||
.unwrap()
|
||||
.parse()
|
||||
.unwrap_or(0);
|
||||
Ok(rows)
|
||||
}
|
||||
_ => Err(Error::unexpected_message()),
|
||||
}
|
||||
Ok(CopyInSink {
|
||||
sender,
|
||||
responses,
|
||||
buf: BytesMut::new(),
|
||||
state: SinkState::Active,
|
||||
_p: PhantomPinned,
|
||||
_p2: PhantomData,
|
||||
})
|
||||
}
|
||||
|
@ -337,7 +337,6 @@ enum Kind {
|
||||
ToSql(usize),
|
||||
FromSql(usize),
|
||||
Column(String),
|
||||
CopyInStream,
|
||||
Closed,
|
||||
Db,
|
||||
Parse,
|
||||
@ -376,7 +375,6 @@ impl fmt::Display for Error {
|
||||
Kind::ToSql(idx) => write!(fmt, "error serializing parameter {}", idx)?,
|
||||
Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?,
|
||||
Kind::Column(column) => write!(fmt, "invalid column `{}`", column)?,
|
||||
Kind::CopyInStream => fmt.write_str("error from a copy_in stream")?,
|
||||
Kind::Closed => fmt.write_str("connection closed")?,
|
||||
Kind::Db => fmt.write_str("db error")?,
|
||||
Kind::Parse => fmt.write_str("error parsing response from server")?,
|
||||
@ -458,13 +456,6 @@ impl Error {
|
||||
Error::new(Kind::Column(column), None)
|
||||
}
|
||||
|
||||
pub(crate) fn copy_in_stream<E>(e: E) -> Error
|
||||
where
|
||||
E: Into<Box<dyn error::Error + Sync + Send>>,
|
||||
{
|
||||
Error::new(Kind::CopyInStream, Some(e.into()))
|
||||
}
|
||||
|
||||
pub(crate) fn tls(e: Box<dyn error::Error + Sync + Send>) -> Error {
|
||||
Error::new(Kind::Tls, Some(e))
|
||||
}
|
||||
|
@ -105,6 +105,7 @@
|
||||
pub use crate::client::Client;
|
||||
pub use crate::config::Config;
|
||||
pub use crate::connection::Connection;
|
||||
pub use crate::copy_in::CopyInSink;
|
||||
pub use crate::copy_out::CopyStream;
|
||||
use crate::error::DbError;
|
||||
pub use crate::error::Error;
|
||||
|
@ -9,12 +9,12 @@ use crate::types::{ToSql, Type};
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::Socket;
|
||||
use crate::{
|
||||
bind, query, slice_iter, Client, Error, Portal, Row, SimpleQueryMessage, Statement, ToStatement,
|
||||
bind, query, slice_iter, Client, CopyInSink, Error, Portal, Row, SimpleQueryMessage, Statement,
|
||||
ToStatement,
|
||||
};
|
||||
use bytes::Buf;
|
||||
use futures::{TryStream, TryStreamExt};
|
||||
use futures::{TryStreamExt};
|
||||
use postgres_protocol::message::frontend;
|
||||
use std::error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
/// A representation of a PostgreSQL database transaction.
|
||||
@ -209,19 +209,16 @@ impl<'a> Transaction<'a> {
|
||||
}
|
||||
|
||||
/// Like `Client::copy_in`.
|
||||
pub async fn copy_in<T, S>(
|
||||
pub async fn copy_in<T, U>(
|
||||
&self,
|
||||
statement: &T,
|
||||
params: &[&(dyn ToSql + Sync)],
|
||||
stream: S,
|
||||
) -> Result<u64, Error>
|
||||
) -> Result<CopyInSink<U>, Error>
|
||||
where
|
||||
T: ?Sized + ToStatement,
|
||||
S: TryStream,
|
||||
S::Ok: Buf + 'static + Send,
|
||||
S::Error: Into<Box<dyn error::Error + Sync + Send>>,
|
||||
U: Buf + 'static + Send,
|
||||
{
|
||||
self.client.copy_in(statement, params, stream).await
|
||||
self.client.copy_in(statement, params).await
|
||||
}
|
||||
|
||||
/// Like `Client::copy_out`.
|
||||
|
@ -2,8 +2,7 @@
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures::channel::mpsc;
|
||||
use futures::{future, stream, StreamExt};
|
||||
use futures::{join, try_join, FutureExt, TryStreamExt};
|
||||
use futures::{future, stream, StreamExt, SinkExt, pin_mut, join, try_join, FutureExt, TryStreamExt};
|
||||
use std::fmt::Write;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
@ -409,23 +408,21 @@ async fn copy_in() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let stmt = client.prepare("COPY foo FROM STDIN").await.unwrap();
|
||||
let stream = stream::iter(
|
||||
let mut stream = stream::iter(
|
||||
vec![
|
||||
Bytes::from_static(b"1\tjim\n"),
|
||||
Bytes::from_static(b"2\tjoe\n"),
|
||||
]
|
||||
.into_iter()
|
||||
.map(Ok::<_, String>),
|
||||
.map(Ok::<_, Error>),
|
||||
);
|
||||
let rows = client.copy_in(&stmt, &[], stream).await.unwrap();
|
||||
let sink = client.copy_in("COPY foo FROM STDIN", &[]).await.unwrap();
|
||||
pin_mut!(sink);
|
||||
sink.send_all(&mut stream).await.unwrap();
|
||||
let rows = sink.finish().await.unwrap();
|
||||
assert_eq!(rows, 2);
|
||||
|
||||
let stmt = client
|
||||
.prepare("SELECT id, name FROM foo ORDER BY id")
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client.query(&stmt, &[]).await.unwrap();
|
||||
let rows = client.query("SELECT id, name FROM foo ORDER BY id", &[]).await.unwrap();
|
||||
|
||||
assert_eq!(rows.len(), 2);
|
||||
assert_eq!(rows[0].get::<_, i32>(0), 1);
|
||||
@ -448,8 +445,6 @@ async fn copy_in_large() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let stmt = client.prepare("COPY foo FROM STDIN").await.unwrap();
|
||||
|
||||
let a = Bytes::from_static(b"0\tname0\n");
|
||||
let mut b = BytesMut::new();
|
||||
for i in 1..5_000 {
|
||||
@ -459,13 +454,16 @@ async fn copy_in_large() {
|
||||
for i in 5_000..10_000 {
|
||||
writeln!(c, "{0}\tname{0}", i).unwrap();
|
||||
}
|
||||
let stream = stream::iter(
|
||||
let mut stream = stream::iter(
|
||||
vec![a, b.freeze(), c.freeze()]
|
||||
.into_iter()
|
||||
.map(Ok::<_, String>),
|
||||
.map(Ok::<_, Error>),
|
||||
);
|
||||
|
||||
let rows = client.copy_in(&stmt, &[], stream).await.unwrap();
|
||||
let sink = client.copy_in("COPY foo FROM STDIN", &[]).await.unwrap();
|
||||
pin_mut!(sink);
|
||||
sink.send_all(&mut stream).await.unwrap();
|
||||
let rows = sink.finish().await.unwrap();
|
||||
assert_eq!(rows, 10_000);
|
||||
}
|
||||
|
||||
@ -483,16 +481,13 @@ async fn copy_in_error() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let stmt = client.prepare("COPY foo FROM STDIN").await.unwrap();
|
||||
let stream = stream::iter(vec![Ok(Bytes::from_static(b"1\tjim\n")), Err("asdf")]);
|
||||
let error = client.copy_in(&stmt, &[], stream).await.unwrap_err();
|
||||
assert!(error.to_string().contains("asdf"));
|
||||
{
|
||||
let sink = client.copy_in("COPY foo FROM STDIN", &[]).await.unwrap();
|
||||
pin_mut!(sink);
|
||||
sink.send(Bytes::from_static(b"1\tsteven")).await.unwrap();
|
||||
}
|
||||
|
||||
let stmt = client
|
||||
.prepare("SELECT id, name FROM foo ORDER BY id")
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = client.query(&stmt, &[]).await.unwrap();
|
||||
let rows = client.query("SELECT id, name FROM foo ORDER BY id", &[]).await.unwrap();
|
||||
assert_eq!(rows.len(), 0);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user