Move binary copy stuff directly into main crate
This commit is contained in:
parent
0c84ed9f82
commit
bf8b335d2b
@ -9,7 +9,6 @@ members = [
|
||||
"postgres-protocol",
|
||||
"postgres-types",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres-binary-copy",
|
||||
]
|
||||
|
||||
[profile.release]
|
||||
|
@ -1,16 +0,0 @@
|
||||
[package]
|
||||
name = "tokio-postgres-binary-copy"
|
||||
version = "0.1.0"
|
||||
authors = ["Steven Fackler <sfackler@gmail.com>"]
|
||||
edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
byteorder = "1.0"
|
||||
bytes = "0.5"
|
||||
futures = "0.3"
|
||||
pin-project-lite = "0.1"
|
||||
tokio-postgres = { version = "=0.5.0-alpha.2", default-features = false, path = "../tokio-postgres" }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { version = "0.2", features = ["full"] }
|
||||
tokio-postgres = { version = "=0.5.0-alpha.2", path = "../tokio-postgres" }
|
@ -37,6 +37,7 @@ with-uuid-0_8 = ["postgres-types/with-uuid-0_8"]
|
||||
|
||||
[dependencies]
|
||||
bytes = "0.5"
|
||||
byteorder = "1.0"
|
||||
fallible-iterator = "0.2"
|
||||
futures = "0.3"
|
||||
log = "0.4"
|
||||
|
@ -1,24 +1,26 @@
|
||||
//! Utilities for working with the PostgreSQL binary copy format.
|
||||
|
||||
use crate::types::{FromSql, IsNull, ToSql, Type, WrongType};
|
||||
use crate::{slice_iter, CopyInSink, CopyOutStream, Error};
|
||||
use byteorder::{BigEndian, ByteOrder};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use futures::{ready, SinkExt, Stream};
|
||||
use pin_project_lite::pin_project;
|
||||
use std::convert::TryFrom;
|
||||
use std::error::Error;
|
||||
use std::io;
|
||||
use std::io::Cursor;
|
||||
use std::ops::Range;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio_postgres::types::{FromSql, IsNull, ToSql, Type, WrongType};
|
||||
use tokio_postgres::{CopyInSink, CopyOutStream};
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0";
|
||||
const HEADER_LEN: usize = MAGIC.len() + 4 + 4;
|
||||
|
||||
pin_project! {
|
||||
/// A type which serializes rows into the PostgreSQL binary copy format.
|
||||
///
|
||||
/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
|
||||
pub struct BinaryCopyInWriter {
|
||||
#[pin]
|
||||
sink: CopyInSink<Bytes>,
|
||||
@ -28,10 +30,10 @@ pin_project! {
|
||||
}
|
||||
|
||||
impl BinaryCopyInWriter {
|
||||
/// Creates a new writer which will write rows of the provided types to the provided sink.
|
||||
pub fn new(sink: CopyInSink<Bytes>, types: &[Type]) -> BinaryCopyInWriter {
|
||||
let mut buf = BytesMut::new();
|
||||
buf.reserve(HEADER_LEN);
|
||||
buf.put_slice(MAGIC); // magic
|
||||
buf.put_slice(MAGIC);
|
||||
buf.put_i32(0); // flags
|
||||
buf.put_i32(0); // header extension
|
||||
|
||||
@ -42,19 +44,23 @@ impl BinaryCopyInWriter {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write(
|
||||
self: Pin<&mut Self>,
|
||||
values: &[&(dyn ToSql + Send)],
|
||||
) -> Result<(), Box<dyn Error + Sync + Send>> {
|
||||
self.write_raw(values.iter().cloned()).await
|
||||
/// Writes a single row.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the number of values provided does not match the number expected.
|
||||
pub async fn write(self: Pin<&mut Self>, values: &[&(dyn ToSql + Sync)]) -> Result<(), Error> {
|
||||
self.write_raw(slice_iter(values)).await
|
||||
}
|
||||
|
||||
pub async fn write_raw<'a, I>(
|
||||
self: Pin<&mut Self>,
|
||||
values: I,
|
||||
) -> Result<(), Box<dyn Error + Sync + Send>>
|
||||
/// A maximally-flexible version of `write`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the number of values provided does not match the number expected.
|
||||
pub async fn write_raw<'a, I>(self: Pin<&mut Self>, values: I) -> Result<(), Error>
|
||||
where
|
||||
I: IntoIterator<Item = &'a (dyn ToSql + Send)>,
|
||||
I: IntoIterator<Item = &'a dyn ToSql>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
{
|
||||
let mut this = self.project();
|
||||
@ -69,12 +75,16 @@ impl BinaryCopyInWriter {
|
||||
|
||||
this.buf.put_i16(this.types.len() as i16);
|
||||
|
||||
for (value, type_) in values.zip(this.types) {
|
||||
for (i, (value, type_)) in values.zip(this.types).enumerate() {
|
||||
let idx = this.buf.len();
|
||||
this.buf.put_i32(0);
|
||||
let len = match value.to_sql_checked(type_, this.buf)? {
|
||||
let len = match value
|
||||
.to_sql_checked(type_, this.buf)
|
||||
.map_err(|e| Error::to_sql(e, i))?
|
||||
{
|
||||
IsNull::Yes => -1,
|
||||
IsNull::No => i32::try_from(this.buf.len() - idx - 4)?,
|
||||
IsNull::No => i32::try_from(this.buf.len() - idx - 4)
|
||||
.map_err(|e| Error::encode(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
|
||||
};
|
||||
BigEndian::write_i32(&mut this.buf[idx..], len);
|
||||
}
|
||||
@ -86,7 +96,10 @@ impl BinaryCopyInWriter {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn finish(self: Pin<&mut Self>) -> Result<u64, tokio_postgres::Error> {
|
||||
/// Completes the copy, returning the number of rows added.
|
||||
///
|
||||
/// This method *must* be used to complete the copy process. If it is not, the copy will be aborted.
|
||||
pub async fn finish(self: Pin<&mut Self>) -> Result<u64, Error> {
|
||||
let mut this = self.project();
|
||||
|
||||
this.buf.put_i16(-1);
|
||||
@ -100,6 +113,7 @@ struct Header {
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// A stream of rows deserialized from the PostgreSQL binary copy format.
|
||||
pub struct BinaryCopyOutStream {
|
||||
#[pin]
|
||||
stream: CopyOutStream,
|
||||
@ -109,7 +123,8 @@ pin_project! {
|
||||
}
|
||||
|
||||
impl BinaryCopyOutStream {
|
||||
pub fn new(types: &[Type], stream: CopyOutStream) -> BinaryCopyOutStream {
|
||||
/// Creates a stream from a raw copy out stream and the types of the columns being returned.
|
||||
pub fn new(stream: CopyOutStream, types: &[Type]) -> BinaryCopyOutStream {
|
||||
BinaryCopyOutStream {
|
||||
stream,
|
||||
types: Arc::new(types.to_vec()),
|
||||
@ -119,15 +134,15 @@ impl BinaryCopyOutStream {
|
||||
}
|
||||
|
||||
impl Stream for BinaryCopyOutStream {
|
||||
type Item = Result<BinaryCopyOutRow, Box<dyn Error + Sync + Send>>;
|
||||
type Item = Result<BinaryCopyOutRow, Error>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.project();
|
||||
|
||||
let chunk = match ready!(this.stream.poll_next(cx)) {
|
||||
Some(Ok(chunk)) => chunk,
|
||||
Some(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
|
||||
None => return Poll::Ready(Some(Err("unexpected EOF".into()))),
|
||||
Some(Err(e)) => return Poll::Ready(Some(Err(e))),
|
||||
None => return Poll::Ready(Some(Err(Error::closed()))),
|
||||
};
|
||||
let mut chunk = Cursor::new(chunk);
|
||||
|
||||
@ -136,7 +151,10 @@ impl Stream for BinaryCopyOutStream {
|
||||
None => {
|
||||
check_remaining(&chunk, HEADER_LEN)?;
|
||||
if &chunk.bytes()[..MAGIC.len()] != MAGIC {
|
||||
return Poll::Ready(Some(Err("invalid magic value".into())));
|
||||
return Poll::Ready(Some(Err(Error::parse(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"invalid magic value",
|
||||
)))));
|
||||
}
|
||||
chunk.advance(MAGIC.len());
|
||||
|
||||
@ -162,7 +180,10 @@ impl Stream for BinaryCopyOutStream {
|
||||
len += 1;
|
||||
}
|
||||
if len as usize != this.types.len() {
|
||||
return Poll::Ready(Some(Err("unexpected tuple size".into())));
|
||||
return Poll::Ready(Some(Err(Error::parse(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
format!("expected {} values but got {}", this.types.len(), len),
|
||||
)))));
|
||||
}
|
||||
|
||||
let mut ranges = vec![];
|
||||
@ -188,14 +209,18 @@ impl Stream for BinaryCopyOutStream {
|
||||
}
|
||||
}
|
||||
|
||||
fn check_remaining(buf: &impl Buf, len: usize) -> Result<(), Box<dyn Error + Sync + Send>> {
|
||||
fn check_remaining(buf: &Cursor<Bytes>, len: usize) -> Result<(), Error> {
|
||||
if buf.remaining() < len {
|
||||
Err("unexpected EOF".into())
|
||||
Err(Error::parse(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
"unexpected EOF",
|
||||
)))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A row of data parsed from a binary copy out stream.
|
||||
pub struct BinaryCopyOutRow {
|
||||
buf: Bytes,
|
||||
ranges: Vec<Option<Range<usize>>>,
|
||||
@ -203,21 +228,36 @@ pub struct BinaryCopyOutRow {
|
||||
}
|
||||
|
||||
impl BinaryCopyOutRow {
|
||||
pub fn try_get<'a, T>(&'a self, idx: usize) -> Result<T, Box<dyn Error + Sync + Send>>
|
||||
/// Like `get`, but returns a `Result` rather than panicking.
|
||||
pub fn try_get<'a, T>(&'a self, idx: usize) -> Result<T, Error>
|
||||
where
|
||||
T: FromSql<'a>,
|
||||
{
|
||||
let type_ = &self.types[idx];
|
||||
let type_ = match self.types.get(idx) {
|
||||
Some(type_) => type_,
|
||||
None => return Err(Error::column(idx.to_string())),
|
||||
};
|
||||
|
||||
if !T::accepts(type_) {
|
||||
return Err(WrongType::new::<T>(type_.clone()).into());
|
||||
return Err(Error::from_sql(
|
||||
Box::new(WrongType::new::<T>(type_.clone())),
|
||||
idx,
|
||||
));
|
||||
}
|
||||
|
||||
match &self.ranges[idx] {
|
||||
Some(range) => T::from_sql(type_, &self.buf[range.clone()]).map_err(Into::into),
|
||||
None => T::from_sql_null(type_).map_err(Into::into),
|
||||
}
|
||||
let r = match &self.ranges[idx] {
|
||||
Some(range) => T::from_sql(type_, &self.buf[range.clone()]),
|
||||
None => T::from_sql_null(type_),
|
||||
};
|
||||
|
||||
r.map_err(|e| Error::from_sql(e, idx))
|
||||
}
|
||||
|
||||
/// Deserializes a value from the row.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
|
||||
pub fn get<'a, T>(&'a self, idx: usize) -> T
|
||||
where
|
||||
T: FromSql<'a>,
|
@ -123,6 +123,7 @@ pub use crate::to_statement::ToStatement;
|
||||
pub use crate::transaction::Transaction;
|
||||
use crate::types::ToSql;
|
||||
|
||||
pub mod binary_copy;
|
||||
mod bind;
|
||||
#[cfg(feature = "runtime")]
|
||||
mod cancel_query;
|
||||
|
@ -1,22 +1,11 @@
|
||||
use crate::{BinaryCopyInWriter, BinaryCopyOutStream};
|
||||
use crate::connect;
|
||||
use futures::{pin_mut, TryStreamExt};
|
||||
use tokio_postgres::binary_copy::{BinaryCopyInWriter, BinaryCopyOutStream};
|
||||
use tokio_postgres::types::Type;
|
||||
use tokio_postgres::{Client, NoTls};
|
||||
|
||||
async fn connect() -> Client {
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect("host=localhost port=5433 user=postgres", NoTls)
|
||||
.await
|
||||
.unwrap();
|
||||
tokio::spawn(async {
|
||||
connection.await.unwrap();
|
||||
});
|
||||
client
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_basic() {
|
||||
let client = connect().await;
|
||||
let client = connect("user=postgres").await;
|
||||
|
||||
client
|
||||
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar TEXT)")
|
||||
@ -50,7 +39,7 @@ async fn write_basic() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_many_rows() {
|
||||
let client = connect().await;
|
||||
let client = connect("user=postgres").await;
|
||||
|
||||
client
|
||||
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar TEXT)")
|
||||
@ -86,7 +75,7 @@ async fn write_many_rows() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_big_rows() {
|
||||
let client = connect().await;
|
||||
let client = connect("user=postgres").await;
|
||||
|
||||
client
|
||||
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)")
|
||||
@ -122,7 +111,7 @@ async fn write_big_rows() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_basic() {
|
||||
let client = connect().await;
|
||||
let client = connect("user=postgres").await;
|
||||
|
||||
client
|
||||
.batch_execute(
|
||||
@ -138,7 +127,7 @@ async fn read_basic() {
|
||||
.copy_out("COPY foo (id, bar) TO STDIN BINARY")
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::TEXT], stream)
|
||||
let rows = BinaryCopyOutStream::new(stream, &[Type::INT4, Type::TEXT])
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
@ -152,7 +141,7 @@ async fn read_basic() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_many_rows() {
|
||||
let client = connect().await;
|
||||
let client = connect("user=postgres").await;
|
||||
|
||||
client
|
||||
.batch_execute(
|
||||
@ -167,7 +156,7 @@ async fn read_many_rows() {
|
||||
.copy_out("COPY foo (id, bar) TO STDIN BINARY")
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::TEXT], stream)
|
||||
let rows = BinaryCopyOutStream::new(stream, &[Type::INT4, Type::TEXT])
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
||||
@ -181,7 +170,7 @@ async fn read_many_rows() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_big_rows() {
|
||||
let client = connect().await;
|
||||
let client = connect("user=postgres").await;
|
||||
|
||||
client
|
||||
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)")
|
||||
@ -201,7 +190,7 @@ async fn read_big_rows() {
|
||||
.copy_out("COPY foo (id, bar) TO STDIN BINARY")
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::BYTEA], stream)
|
||||
let rows = BinaryCopyOutStream::new(stream, &[Type::INT4, Type::BYTEA])
|
||||
.try_collect::<Vec<_>>()
|
||||
.await
|
||||
.unwrap();
|
@ -14,6 +14,7 @@ use tokio_postgres::tls::{NoTls, NoTlsStream};
|
||||
use tokio_postgres::types::{Kind, Type};
|
||||
use tokio_postgres::{AsyncMessage, Client, Config, Connection, Error, SimpleQueryMessage};
|
||||
|
||||
mod binary_copy;
|
||||
mod parse;
|
||||
#[cfg(feature = "runtime")]
|
||||
mod runtime;
|
||||
|
Loading…
Reference in New Issue
Block a user