Move binary copy stuff directly into main crate

Steven Fackler 2019-12-08 18:30:47 -08:00
parent 0c84ed9f82
commit bf8b335d2b
7 changed files with 91 additions and 76 deletions

Cargo.toml

@@ -9,7 +9,6 @@ members = [
"postgres-protocol",
"postgres-types",
"tokio-postgres",
"tokio-postgres-binary-copy",
]
[profile.release]

tokio-postgres-binary-copy/Cargo.toml

@@ -1,16 +0,0 @@
[package]
name = "tokio-postgres-binary-copy"
version = "0.1.0"
authors = ["Steven Fackler <sfackler@gmail.com>"]
edition = "2018"
[dependencies]
byteorder = "1.0"
bytes = "0.5"
futures = "0.3"
pin-project-lite = "0.1"
tokio-postgres = { version = "=0.5.0-alpha.2", default-features = false, path = "../tokio-postgres" }
[dev-dependencies]
tokio = { version = "0.2", features = ["full"] }
tokio-postgres = { version = "=0.5.0-alpha.2", path = "../tokio-postgres" }

tokio-postgres/Cargo.toml

@@ -37,6 +37,7 @@ with-uuid-0_8 = ["postgres-types/with-uuid-0_8"]
[dependencies]
bytes = "0.5"
byteorder = "1.0"
fallible-iterator = "0.2"
futures = "0.3"
log = "0.4"

tokio-postgres/src/binary_copy.rs

@@ -1,24 +1,26 @@
//! Utilities for working with the PostgreSQL binary copy format.
use crate::types::{FromSql, IsNull, ToSql, Type, WrongType};
use crate::{slice_iter, CopyInSink, CopyOutStream, Error};
use byteorder::{BigEndian, ByteOrder};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::{ready, SinkExt, Stream};
use pin_project_lite::pin_project;
use std::convert::TryFrom;
use std::error::Error;
use std::io;
use std::io::Cursor;
use std::ops::Range;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio_postgres::types::{FromSql, IsNull, ToSql, Type, WrongType};
use tokio_postgres::{CopyInSink, CopyOutStream};
#[cfg(test)]
mod test;
const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0";
const HEADER_LEN: usize = MAGIC.len() + 4 + 4;
pin_project! {
/// A type which serializes rows into the PostgreSQL binary copy format.
///
/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
pub struct BinaryCopyInWriter {
#[pin]
sink: CopyInSink<Bytes>,
@@ -28,10 +30,10 @@ pin_project! {
}
impl BinaryCopyInWriter {
/// Creates a new writer which will write rows of the provided types to the provided sink.
pub fn new(sink: CopyInSink<Bytes>, types: &[Type]) -> BinaryCopyInWriter {
let mut buf = BytesMut::new();
buf.reserve(HEADER_LEN);
buf.put_slice(MAGIC); // magic
buf.put_slice(MAGIC);
buf.put_i32(0); // flags
buf.put_i32(0); // header extension
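
For reference, the header built here follows PostgreSQL's binary COPY framing: the 11-byte magic, a 4-byte flags word, and a 4-byte header-extension length (hence HEADER_LEN = 19), with each row later encoded as a 16-bit field count plus length-prefixed fields, and a -1 trailer at the end. A minimal sketch of an empty payload, using the same bytes 0.5 calls as the code above:

    use bytes::{BufMut, BytesMut};

    fn main() {
        let mut buf = BytesMut::new();
        buf.put_slice(b"PGCOPY\n\xff\r\n\0"); // 11-byte magic (MAGIC above)
        buf.put_i32(0); // flags
        buf.put_i32(0); // header extension length
        // each row would follow here: an i16 field count, then every field
        // prefixed with its i32 byte length (-1 for NULL)
        buf.put_i16(-1); // trailer: no more rows
        assert_eq!(buf.len(), 11 + 4 + 4 + 2); // HEADER_LEN plus the trailer
    }
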
@@ -42,19 +44,23 @@ impl BinaryCopyInWriter {
}
}
pub async fn write(
self: Pin<&mut Self>,
values: &[&(dyn ToSql + Send)],
) -> Result<(), Box<dyn Error + Sync + Send>> {
self.write_raw(values.iter().cloned()).await
/// Writes a single row.
///
/// # Panics
///
/// Panics if the number of values provided does not match the number expected.
pub async fn write(self: Pin<&mut Self>, values: &[&(dyn ToSql + Sync)]) -> Result<(), Error> {
self.write_raw(slice_iter(values)).await
}
pub async fn write_raw<'a, I>(
self: Pin<&mut Self>,
values: I,
) -> Result<(), Box<dyn Error + Sync + Send>>
/// A maximally-flexible version of `write`.
///
/// # Panics
///
/// Panics if the number of values provided does not match the number expected.
pub async fn write_raw<'a, I>(self: Pin<&mut Self>, values: I) -> Result<(), Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Send)>,
I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
let mut this = self.project();
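
With the item type loosened to `&'a dyn ToSql`, `write` becomes a thin wrapper over `write_raw` (via the crate's `slice_iter`), and `write_raw` accepts any exact-size iterator of values. A hypothetical helper sketching that flexibility; the name `write_boxed` is illustrative, not part of the API (the full write-side flow is shown after the `finish` hunk below):

    use std::pin::Pin;
    use tokio_postgres::binary_copy::BinaryCopyInWriter;
    use tokio_postgres::types::ToSql;
    use tokio_postgres::Error;

    // Write one row whose values were boxed up at runtime.
    async fn write_boxed(
        mut writer: Pin<&mut BinaryCopyInWriter>,
        values: &[Box<dyn ToSql + Sync>],
    ) -> Result<(), Error> {
        writer
            .as_mut()
            .write_raw(values.iter().map(|v| &**v as &dyn ToSql))
            .await
    }
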
@@ -69,12 +75,16 @@ impl BinaryCopyInWriter {
this.buf.put_i16(this.types.len() as i16);
for (value, type_) in values.zip(this.types) {
for (i, (value, type_)) in values.zip(this.types).enumerate() {
let idx = this.buf.len();
this.buf.put_i32(0);
let len = match value.to_sql_checked(type_, this.buf)? {
let len = match value
.to_sql_checked(type_, this.buf)
.map_err(|e| Error::to_sql(e, i))?
{
IsNull::Yes => -1,
IsNull::No => i32::try_from(this.buf.len() - idx - 4)?,
IsNull::No => i32::try_from(this.buf.len() - idx - 4)
.map_err(|e| Error::encode(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
};
BigEndian::write_i32(&mut this.buf[idx..], len);
}
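
This loop is also why byteorder moves into tokio-postgres's own dependency list (the Cargo.toml hunk above): each field is written after a 4-byte placeholder that is backfilled with the big-endian length once the serialized size is known. The pattern in isolation, with literal bytes standing in for the `to_sql_checked` output:

    use byteorder::{BigEndian, ByteOrder};
    use bytes::{BufMut, BytesMut};

    fn main() {
        let mut buf = BytesMut::new();
        let idx = buf.len();
        buf.put_i32(0); // length placeholder
        buf.put_slice(b"hello"); // field bytes (stand-in for to_sql_checked)
        let len = (buf.len() - idx - 4) as i32;
        BigEndian::write_i32(&mut buf[idx..], len); // backfill the prefix
        assert_eq!(&buf[..], &b"\x00\x00\x00\x05hello"[..]);
    }
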
@@ -86,7 +96,10 @@
Ok(())
}
pub async fn finish(self: Pin<&mut Self>) -> Result<u64, tokio_postgres::Error> {
/// Completes the copy, returning the number of rows added.
///
/// This method *must* be used to complete the copy process. If it is not, the copy will be aborted.
pub async fn finish(self: Pin<&mut Self>) -> Result<u64, Error> {
let mut this = self.project();
this.buf.put_i16(-1);
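
Put together, the write side now looks like this: a minimal sketch modeled on the tests further down, assuming an open client and a table `foo (id INT, bar TEXT)` (the function name `load` is illustrative):

    use futures::pin_mut;
    use tokio_postgres::binary_copy::BinaryCopyInWriter;
    use tokio_postgres::types::Type;
    use tokio_postgres::{Client, Error};

    async fn load(client: &Client) -> Result<u64, Error> {
        let sink = client.copy_in("COPY foo (id, bar) FROM STDIN BINARY").await?;
        let writer = BinaryCopyInWriter::new(sink, &[Type::INT4, Type::TEXT]);
        pin_mut!(writer);

        writer.as_mut().write(&[&1i32, &"one"]).await?;
        writer.as_mut().write(&[&2i32, &"two"]).await?;

        // Skipping finish would abort the copy instead of committing it.
        writer.finish().await
    }
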
@@ -100,6 +113,7 @@ struct Header {
}
pin_project! {
/// A stream of rows deserialized from the PostgreSQL binary copy format.
pub struct BinaryCopyOutStream {
#[pin]
stream: CopyOutStream,
@@ -109,7 +123,8 @@ pin_project! {
}
impl BinaryCopyOutStream {
pub fn new(types: &[Type], stream: CopyOutStream) -> BinaryCopyOutStream {
/// Creates a stream from a raw copy out stream and the types of the columns being returned.
pub fn new(stream: CopyOutStream, types: &[Type]) -> BinaryCopyOutStream {
BinaryCopyOutStream {
stream,
types: Arc::new(types.to_vec()),
@@ -119,15 +134,15 @@ impl BinaryCopyOutStream {
}
impl Stream for BinaryCopyOutStream {
type Item = Result<BinaryCopyOutRow, Box<dyn Error + Sync + Send>>;
type Item = Result<BinaryCopyOutRow, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
let chunk = match ready!(this.stream.poll_next(cx)) {
Some(Ok(chunk)) => chunk,
Some(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
None => return Poll::Ready(Some(Err("unexpected EOF".into()))),
Some(Err(e)) => return Poll::Ready(Some(Err(e))),
None => return Poll::Ready(Some(Err(Error::closed()))),
};
let mut chunk = Cursor::new(chunk);
@@ -136,7 +151,10 @@ impl Stream for BinaryCopyOutStream {
None => {
check_remaining(&chunk, HEADER_LEN)?;
if &chunk.bytes()[..MAGIC.len()] != MAGIC {
return Poll::Ready(Some(Err("invalid magic value".into())));
return Poll::Ready(Some(Err(Error::parse(io::Error::new(
io::ErrorKind::InvalidData,
"invalid magic value",
)))));
}
chunk.advance(MAGIC.len());
@@ -162,7 +180,10 @@ impl Stream for BinaryCopyOutStream {
len += 1;
}
if len as usize != this.types.len() {
return Poll::Ready(Some(Err("unexpected tuple size".into())));
return Poll::Ready(Some(Err(Error::parse(io::Error::new(
io::ErrorKind::InvalidInput,
format!("expected {} values but got {}", this.types.len(), len),
)))));
}
let mut ranges = vec![];
@@ -188,14 +209,18 @@
}
}
fn check_remaining(buf: &impl Buf, len: usize) -> Result<(), Box<dyn Error + Sync + Send>> {
fn check_remaining(buf: &Cursor<Bytes>, len: usize) -> Result<(), Error> {
if buf.remaining() < len {
Err("unexpected EOF".into())
Err(Error::parse(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
)))
} else {
Ok(())
}
}
/// A row of data parsed from a binary copy out stream.
pub struct BinaryCopyOutRow {
buf: Bytes,
ranges: Vec<Option<Range<usize>>>,
@@ -203,21 +228,36 @@ pub struct BinaryCopyOutRow {
}
impl BinaryCopyOutRow {
pub fn try_get<'a, T>(&'a self, idx: usize) -> Result<T, Box<dyn Error + Sync + Send>>
/// Like `get`, but returns a `Result` rather than panicking.
pub fn try_get<'a, T>(&'a self, idx: usize) -> Result<T, Error>
where
T: FromSql<'a>,
{
let type_ = &self.types[idx];
let type_ = match self.types.get(idx) {
Some(type_) => type_,
None => return Err(Error::column(idx.to_string())),
};
if !T::accepts(type_) {
return Err(WrongType::new::<T>(type_.clone()).into());
return Err(Error::from_sql(
Box::new(WrongType::new::<T>(type_.clone())),
idx,
));
}
match &self.ranges[idx] {
Some(range) => T::from_sql(type_, &self.buf[range.clone()]).map_err(Into::into),
None => T::from_sql_null(type_).map_err(Into::into),
}
let r = match &self.ranges[idx] {
Some(range) => T::from_sql(type_, &self.buf[range.clone()]),
None => T::from_sql_null(type_),
};
r.map_err(|e| Error::from_sql(e, idx))
}
/// Deserializes a value from the row.
///
/// # Panics
///
/// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
pub fn get<'a, T>(&'a self, idx: usize) -> T
where
T: FromSql<'a>,
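
The read side, sketched from the tests below under the same client and table assumptions (`dump` is an illustrative name). Note that `BinaryCopyOutStream::new` now takes the stream first and the column types second:

    use futures::TryStreamExt;
    use tokio_postgres::binary_copy::BinaryCopyOutStream;
    use tokio_postgres::types::Type;
    use tokio_postgres::{Client, Error};

    async fn dump(client: &Client) -> Result<(), Error> {
        let stream = client.copy_out("COPY foo (id, bar) TO STDOUT BINARY").await?;
        let rows = BinaryCopyOutStream::new(stream, &[Type::INT4, Type::TEXT])
            .try_collect::<Vec<_>>()
            .await?;

        for row in &rows {
            // get panics on type or index errors; try_get returns a Result.
            println!("{} {}", row.get::<i32>(0), row.get::<&str>(1));
        }
        Ok(())
    }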

tokio-postgres/src/lib.rs

@@ -123,6 +123,7 @@ pub use crate::to_statement::ToStatement;
pub use crate::transaction::Transaction;
use crate::types::ToSql;
pub mod binary_copy;
mod bind;
#[cfg(feature = "runtime")]
mod cancel_query;
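
For downstream code the move changes only the import path; a sketch, with the old standalone-crate path commented for comparison:

    // Before: the standalone crate
    // use tokio_postgres_binary_copy::{BinaryCopyInWriter, BinaryCopyOutStream};

    // After: a module of the main crate
    use tokio_postgres::binary_copy::{BinaryCopyInWriter, BinaryCopyOutStream};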

tokio-postgres/tests/test/binary_copy.rs

@@ -1,22 +1,11 @@
use crate::{BinaryCopyInWriter, BinaryCopyOutStream};
use crate::connect;
use futures::{pin_mut, TryStreamExt};
use tokio_postgres::binary_copy::{BinaryCopyInWriter, BinaryCopyOutStream};
use tokio_postgres::types::Type;
use tokio_postgres::{Client, NoTls};
async fn connect() -> Client {
let (client, connection) =
tokio_postgres::connect("host=localhost port=5433 user=postgres", NoTls)
.await
.unwrap();
tokio::spawn(async {
connection.await.unwrap();
});
client
}
#[tokio::test]
async fn write_basic() {
let client = connect().await;
let client = connect("user=postgres").await;
client
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar TEXT)")
@@ -50,7 +39,7 @@ async fn write_basic() {
#[tokio::test]
async fn write_many_rows() {
let client = connect().await;
let client = connect("user=postgres").await;
client
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar TEXT)")
@@ -86,7 +75,7 @@ async fn write_many_rows() {
#[tokio::test]
async fn write_big_rows() {
let client = connect().await;
let client = connect("user=postgres").await;
client
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)")
@@ -122,7 +111,7 @@ async fn write_big_rows() {
#[tokio::test]
async fn read_basic() {
let client = connect().await;
let client = connect("user=postgres").await;
client
.batch_execute(
@@ -138,7 +127,7 @@ async fn read_basic() {
.copy_out("COPY foo (id, bar) TO STDIN BINARY")
.await
.unwrap();
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::TEXT], stream)
let rows = BinaryCopyOutStream::new(stream, &[Type::INT4, Type::TEXT])
.try_collect::<Vec<_>>()
.await
.unwrap();
@@ -152,7 +141,7 @@ async fn read_basic() {
#[tokio::test]
async fn read_many_rows() {
let client = connect().await;
let client = connect("user=postgres").await;
client
.batch_execute(
@@ -167,7 +156,7 @@ async fn read_many_rows() {
.copy_out("COPY foo (id, bar) TO STDIN BINARY")
.await
.unwrap();
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::TEXT], stream)
let rows = BinaryCopyOutStream::new(stream, &[Type::INT4, Type::TEXT])
.try_collect::<Vec<_>>()
.await
.unwrap();
@@ -181,7 +170,7 @@ async fn read_many_rows() {
#[tokio::test]
async fn read_big_rows() {
let client = connect().await;
let client = connect("user=postgres").await;
client
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)")
@@ -201,7 +190,7 @@ async fn read_big_rows() {
.copy_out("COPY foo (id, bar) TO STDIN BINARY")
.await
.unwrap();
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::BYTEA], stream)
let rows = BinaryCopyOutStream::new(stream, &[Type::INT4, Type::BYTEA])
.try_collect::<Vec<_>>()
.await
.unwrap();

tokio-postgres/tests/test/main.rs

@@ -14,6 +14,7 @@ use tokio_postgres::tls::{NoTls, NoTlsStream};
use tokio_postgres::types::{Kind, Type};
use tokio_postgres::{AsyncMessage, Client, Config, Connection, Error, SimpleQueryMessage};
mod binary_copy;
mod parse;
#[cfg(feature = "runtime")]
mod runtime;