Move binary copy stuff directly into main crate
This commit is contained in:
parent
0c84ed9f82
commit
bf8b335d2b
@ -9,7 +9,6 @@ members = [
|
|||||||
"postgres-protocol",
|
"postgres-protocol",
|
||||||
"postgres-types",
|
"postgres-types",
|
||||||
"tokio-postgres",
|
"tokio-postgres",
|
||||||
"tokio-postgres-binary-copy",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
|
@ -1,16 +0,0 @@
|
|||||||
[package]
|
|
||||||
name = "tokio-postgres-binary-copy"
|
|
||||||
version = "0.1.0"
|
|
||||||
authors = ["Steven Fackler <sfackler@gmail.com>"]
|
|
||||||
edition = "2018"
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
byteorder = "1.0"
|
|
||||||
bytes = "0.5"
|
|
||||||
futures = "0.3"
|
|
||||||
pin-project-lite = "0.1"
|
|
||||||
tokio-postgres = { version = "=0.5.0-alpha.2", default-features = false, path = "../tokio-postgres" }
|
|
||||||
|
|
||||||
[dev-dependencies]
|
|
||||||
tokio = { version = "0.2", features = ["full"] }
|
|
||||||
tokio-postgres = { version = "=0.5.0-alpha.2", path = "../tokio-postgres" }
|
|
@ -37,6 +37,7 @@ with-uuid-0_8 = ["postgres-types/with-uuid-0_8"]
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
bytes = "0.5"
|
bytes = "0.5"
|
||||||
|
byteorder = "1.0"
|
||||||
fallible-iterator = "0.2"
|
fallible-iterator = "0.2"
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
|
@ -1,24 +1,26 @@
|
|||||||
|
//! Utilities for working with the PostgreSQL binary copy format.
|
||||||
|
|
||||||
|
use crate::types::{FromSql, IsNull, ToSql, Type, WrongType};
|
||||||
|
use crate::{slice_iter, CopyInSink, CopyOutStream, Error};
|
||||||
use byteorder::{BigEndian, ByteOrder};
|
use byteorder::{BigEndian, ByteOrder};
|
||||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||||
use futures::{ready, SinkExt, Stream};
|
use futures::{ready, SinkExt, Stream};
|
||||||
use pin_project_lite::pin_project;
|
use pin_project_lite::pin_project;
|
||||||
use std::convert::TryFrom;
|
use std::convert::TryFrom;
|
||||||
use std::error::Error;
|
use std::io;
|
||||||
use std::io::Cursor;
|
use std::io::Cursor;
|
||||||
use std::ops::Range;
|
use std::ops::Range;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
use tokio_postgres::types::{FromSql, IsNull, ToSql, Type, WrongType};
|
|
||||||
use tokio_postgres::{CopyInSink, CopyOutStream};
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod test;
|
|
||||||
|
|
||||||
const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0";
|
const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0";
|
||||||
const HEADER_LEN: usize = MAGIC.len() + 4 + 4;
|
const HEADER_LEN: usize = MAGIC.len() + 4 + 4;
|
||||||
|
|
||||||
pin_project! {
|
pin_project! {
|
||||||
|
/// A type which serializes rows into the PostgreSQL binary copy format.
|
||||||
|
///
|
||||||
|
/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
|
||||||
pub struct BinaryCopyInWriter {
|
pub struct BinaryCopyInWriter {
|
||||||
#[pin]
|
#[pin]
|
||||||
sink: CopyInSink<Bytes>,
|
sink: CopyInSink<Bytes>,
|
||||||
@ -28,10 +30,10 @@ pin_project! {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl BinaryCopyInWriter {
|
impl BinaryCopyInWriter {
|
||||||
|
/// Creates a new writer which will write rows of the provided types to the provided sink.
|
||||||
pub fn new(sink: CopyInSink<Bytes>, types: &[Type]) -> BinaryCopyInWriter {
|
pub fn new(sink: CopyInSink<Bytes>, types: &[Type]) -> BinaryCopyInWriter {
|
||||||
let mut buf = BytesMut::new();
|
let mut buf = BytesMut::new();
|
||||||
buf.reserve(HEADER_LEN);
|
buf.put_slice(MAGIC);
|
||||||
buf.put_slice(MAGIC); // magic
|
|
||||||
buf.put_i32(0); // flags
|
buf.put_i32(0); // flags
|
||||||
buf.put_i32(0); // header extension
|
buf.put_i32(0); // header extension
|
||||||
|
|
||||||
@ -42,19 +44,23 @@ impl BinaryCopyInWriter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn write(
|
/// Writes a single row.
|
||||||
self: Pin<&mut Self>,
|
///
|
||||||
values: &[&(dyn ToSql + Send)],
|
/// # Panics
|
||||||
) -> Result<(), Box<dyn Error + Sync + Send>> {
|
///
|
||||||
self.write_raw(values.iter().cloned()).await
|
/// Panics if the number of values provided does not match the number expected.
|
||||||
|
pub async fn write(self: Pin<&mut Self>, values: &[&(dyn ToSql + Sync)]) -> Result<(), Error> {
|
||||||
|
self.write_raw(slice_iter(values)).await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn write_raw<'a, I>(
|
/// A maximally-flexible version of `write`.
|
||||||
self: Pin<&mut Self>,
|
///
|
||||||
values: I,
|
/// # Panics
|
||||||
) -> Result<(), Box<dyn Error + Sync + Send>>
|
///
|
||||||
|
/// Panics if the number of values provided does not match the number expected.
|
||||||
|
pub async fn write_raw<'a, I>(self: Pin<&mut Self>, values: I) -> Result<(), Error>
|
||||||
where
|
where
|
||||||
I: IntoIterator<Item = &'a (dyn ToSql + Send)>,
|
I: IntoIterator<Item = &'a dyn ToSql>,
|
||||||
I::IntoIter: ExactSizeIterator,
|
I::IntoIter: ExactSizeIterator,
|
||||||
{
|
{
|
||||||
let mut this = self.project();
|
let mut this = self.project();
|
||||||
@ -69,12 +75,16 @@ impl BinaryCopyInWriter {
|
|||||||
|
|
||||||
this.buf.put_i16(this.types.len() as i16);
|
this.buf.put_i16(this.types.len() as i16);
|
||||||
|
|
||||||
for (value, type_) in values.zip(this.types) {
|
for (i, (value, type_)) in values.zip(this.types).enumerate() {
|
||||||
let idx = this.buf.len();
|
let idx = this.buf.len();
|
||||||
this.buf.put_i32(0);
|
this.buf.put_i32(0);
|
||||||
let len = match value.to_sql_checked(type_, this.buf)? {
|
let len = match value
|
||||||
|
.to_sql_checked(type_, this.buf)
|
||||||
|
.map_err(|e| Error::to_sql(e, i))?
|
||||||
|
{
|
||||||
IsNull::Yes => -1,
|
IsNull::Yes => -1,
|
||||||
IsNull::No => i32::try_from(this.buf.len() - idx - 4)?,
|
IsNull::No => i32::try_from(this.buf.len() - idx - 4)
|
||||||
|
.map_err(|e| Error::encode(io::Error::new(io::ErrorKind::InvalidInput, e)))?,
|
||||||
};
|
};
|
||||||
BigEndian::write_i32(&mut this.buf[idx..], len);
|
BigEndian::write_i32(&mut this.buf[idx..], len);
|
||||||
}
|
}
|
||||||
@ -86,7 +96,10 @@ impl BinaryCopyInWriter {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn finish(self: Pin<&mut Self>) -> Result<u64, tokio_postgres::Error> {
|
/// Completes the copy, returning the number of rows added.
|
||||||
|
///
|
||||||
|
/// This method *must* be used to complete the copy process. If it is not, the copy will be aborted.
|
||||||
|
pub async fn finish(self: Pin<&mut Self>) -> Result<u64, Error> {
|
||||||
let mut this = self.project();
|
let mut this = self.project();
|
||||||
|
|
||||||
this.buf.put_i16(-1);
|
this.buf.put_i16(-1);
|
||||||
@ -100,6 +113,7 @@ struct Header {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pin_project! {
|
pin_project! {
|
||||||
|
/// A stream of rows deserialized from the PostgreSQL binary copy format.
|
||||||
pub struct BinaryCopyOutStream {
|
pub struct BinaryCopyOutStream {
|
||||||
#[pin]
|
#[pin]
|
||||||
stream: CopyOutStream,
|
stream: CopyOutStream,
|
||||||
@ -109,7 +123,8 @@ pin_project! {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl BinaryCopyOutStream {
|
impl BinaryCopyOutStream {
|
||||||
pub fn new(types: &[Type], stream: CopyOutStream) -> BinaryCopyOutStream {
|
/// Creates a stream from a raw copy out stream and the types of the columns being returned.
|
||||||
|
pub fn new(stream: CopyOutStream, types: &[Type]) -> BinaryCopyOutStream {
|
||||||
BinaryCopyOutStream {
|
BinaryCopyOutStream {
|
||||||
stream,
|
stream,
|
||||||
types: Arc::new(types.to_vec()),
|
types: Arc::new(types.to_vec()),
|
||||||
@ -119,15 +134,15 @@ impl BinaryCopyOutStream {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Stream for BinaryCopyOutStream {
|
impl Stream for BinaryCopyOutStream {
|
||||||
type Item = Result<BinaryCopyOutRow, Box<dyn Error + Sync + Send>>;
|
type Item = Result<BinaryCopyOutRow, Error>;
|
||||||
|
|
||||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
let this = self.project();
|
let this = self.project();
|
||||||
|
|
||||||
let chunk = match ready!(this.stream.poll_next(cx)) {
|
let chunk = match ready!(this.stream.poll_next(cx)) {
|
||||||
Some(Ok(chunk)) => chunk,
|
Some(Ok(chunk)) => chunk,
|
||||||
Some(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
|
Some(Err(e)) => return Poll::Ready(Some(Err(e))),
|
||||||
None => return Poll::Ready(Some(Err("unexpected EOF".into()))),
|
None => return Poll::Ready(Some(Err(Error::closed()))),
|
||||||
};
|
};
|
||||||
let mut chunk = Cursor::new(chunk);
|
let mut chunk = Cursor::new(chunk);
|
||||||
|
|
||||||
@ -136,7 +151,10 @@ impl Stream for BinaryCopyOutStream {
|
|||||||
None => {
|
None => {
|
||||||
check_remaining(&chunk, HEADER_LEN)?;
|
check_remaining(&chunk, HEADER_LEN)?;
|
||||||
if &chunk.bytes()[..MAGIC.len()] != MAGIC {
|
if &chunk.bytes()[..MAGIC.len()] != MAGIC {
|
||||||
return Poll::Ready(Some(Err("invalid magic value".into())));
|
return Poll::Ready(Some(Err(Error::parse(io::Error::new(
|
||||||
|
io::ErrorKind::InvalidData,
|
||||||
|
"invalid magic value",
|
||||||
|
)))));
|
||||||
}
|
}
|
||||||
chunk.advance(MAGIC.len());
|
chunk.advance(MAGIC.len());
|
||||||
|
|
||||||
@ -162,7 +180,10 @@ impl Stream for BinaryCopyOutStream {
|
|||||||
len += 1;
|
len += 1;
|
||||||
}
|
}
|
||||||
if len as usize != this.types.len() {
|
if len as usize != this.types.len() {
|
||||||
return Poll::Ready(Some(Err("unexpected tuple size".into())));
|
return Poll::Ready(Some(Err(Error::parse(io::Error::new(
|
||||||
|
io::ErrorKind::InvalidInput,
|
||||||
|
format!("expected {} values but got {}", this.types.len(), len),
|
||||||
|
)))));
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut ranges = vec![];
|
let mut ranges = vec![];
|
||||||
@ -188,14 +209,18 @@ impl Stream for BinaryCopyOutStream {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_remaining(buf: &impl Buf, len: usize) -> Result<(), Box<dyn Error + Sync + Send>> {
|
fn check_remaining(buf: &Cursor<Bytes>, len: usize) -> Result<(), Error> {
|
||||||
if buf.remaining() < len {
|
if buf.remaining() < len {
|
||||||
Err("unexpected EOF".into())
|
Err(Error::parse(io::Error::new(
|
||||||
|
io::ErrorKind::UnexpectedEof,
|
||||||
|
"unexpected EOF",
|
||||||
|
)))
|
||||||
} else {
|
} else {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A row of data parsed from a binary copy out stream.
|
||||||
pub struct BinaryCopyOutRow {
|
pub struct BinaryCopyOutRow {
|
||||||
buf: Bytes,
|
buf: Bytes,
|
||||||
ranges: Vec<Option<Range<usize>>>,
|
ranges: Vec<Option<Range<usize>>>,
|
||||||
@ -203,21 +228,36 @@ pub struct BinaryCopyOutRow {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl BinaryCopyOutRow {
|
impl BinaryCopyOutRow {
|
||||||
pub fn try_get<'a, T>(&'a self, idx: usize) -> Result<T, Box<dyn Error + Sync + Send>>
|
/// Like `get`, but returns a `Result` rather than panicking.
|
||||||
|
pub fn try_get<'a, T>(&'a self, idx: usize) -> Result<T, Error>
|
||||||
where
|
where
|
||||||
T: FromSql<'a>,
|
T: FromSql<'a>,
|
||||||
{
|
{
|
||||||
let type_ = &self.types[idx];
|
let type_ = match self.types.get(idx) {
|
||||||
|
Some(type_) => type_,
|
||||||
|
None => return Err(Error::column(idx.to_string())),
|
||||||
|
};
|
||||||
|
|
||||||
if !T::accepts(type_) {
|
if !T::accepts(type_) {
|
||||||
return Err(WrongType::new::<T>(type_.clone()).into());
|
return Err(Error::from_sql(
|
||||||
|
Box::new(WrongType::new::<T>(type_.clone())),
|
||||||
|
idx,
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
match &self.ranges[idx] {
|
let r = match &self.ranges[idx] {
|
||||||
Some(range) => T::from_sql(type_, &self.buf[range.clone()]).map_err(Into::into),
|
Some(range) => T::from_sql(type_, &self.buf[range.clone()]),
|
||||||
None => T::from_sql_null(type_).map_err(Into::into),
|
None => T::from_sql_null(type_),
|
||||||
}
|
};
|
||||||
|
|
||||||
|
r.map_err(|e| Error::from_sql(e, idx))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Deserializes a value from the row.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
///
|
||||||
|
/// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
|
||||||
pub fn get<'a, T>(&'a self, idx: usize) -> T
|
pub fn get<'a, T>(&'a self, idx: usize) -> T
|
||||||
where
|
where
|
||||||
T: FromSql<'a>,
|
T: FromSql<'a>,
|
@ -123,6 +123,7 @@ pub use crate::to_statement::ToStatement;
|
|||||||
pub use crate::transaction::Transaction;
|
pub use crate::transaction::Transaction;
|
||||||
use crate::types::ToSql;
|
use crate::types::ToSql;
|
||||||
|
|
||||||
|
pub mod binary_copy;
|
||||||
mod bind;
|
mod bind;
|
||||||
#[cfg(feature = "runtime")]
|
#[cfg(feature = "runtime")]
|
||||||
mod cancel_query;
|
mod cancel_query;
|
||||||
|
@ -1,22 +1,11 @@
|
|||||||
use crate::{BinaryCopyInWriter, BinaryCopyOutStream};
|
use crate::connect;
|
||||||
use futures::{pin_mut, TryStreamExt};
|
use futures::{pin_mut, TryStreamExt};
|
||||||
|
use tokio_postgres::binary_copy::{BinaryCopyInWriter, BinaryCopyOutStream};
|
||||||
use tokio_postgres::types::Type;
|
use tokio_postgres::types::Type;
|
||||||
use tokio_postgres::{Client, NoTls};
|
|
||||||
|
|
||||||
async fn connect() -> Client {
|
|
||||||
let (client, connection) =
|
|
||||||
tokio_postgres::connect("host=localhost port=5433 user=postgres", NoTls)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
tokio::spawn(async {
|
|
||||||
connection.await.unwrap();
|
|
||||||
});
|
|
||||||
client
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn write_basic() {
|
async fn write_basic() {
|
||||||
let client = connect().await;
|
let client = connect("user=postgres").await;
|
||||||
|
|
||||||
client
|
client
|
||||||
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar TEXT)")
|
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar TEXT)")
|
||||||
@ -50,7 +39,7 @@ async fn write_basic() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn write_many_rows() {
|
async fn write_many_rows() {
|
||||||
let client = connect().await;
|
let client = connect("user=postgres").await;
|
||||||
|
|
||||||
client
|
client
|
||||||
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar TEXT)")
|
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar TEXT)")
|
||||||
@ -86,7 +75,7 @@ async fn write_many_rows() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn write_big_rows() {
|
async fn write_big_rows() {
|
||||||
let client = connect().await;
|
let client = connect("user=postgres").await;
|
||||||
|
|
||||||
client
|
client
|
||||||
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)")
|
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)")
|
||||||
@ -122,7 +111,7 @@ async fn write_big_rows() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn read_basic() {
|
async fn read_basic() {
|
||||||
let client = connect().await;
|
let client = connect("user=postgres").await;
|
||||||
|
|
||||||
client
|
client
|
||||||
.batch_execute(
|
.batch_execute(
|
||||||
@ -138,7 +127,7 @@ async fn read_basic() {
|
|||||||
.copy_out("COPY foo (id, bar) TO STDIN BINARY")
|
.copy_out("COPY foo (id, bar) TO STDIN BINARY")
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::TEXT], stream)
|
let rows = BinaryCopyOutStream::new(stream, &[Type::INT4, Type::TEXT])
|
||||||
.try_collect::<Vec<_>>()
|
.try_collect::<Vec<_>>()
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -152,7 +141,7 @@ async fn read_basic() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn read_many_rows() {
|
async fn read_many_rows() {
|
||||||
let client = connect().await;
|
let client = connect("user=postgres").await;
|
||||||
|
|
||||||
client
|
client
|
||||||
.batch_execute(
|
.batch_execute(
|
||||||
@ -167,7 +156,7 @@ async fn read_many_rows() {
|
|||||||
.copy_out("COPY foo (id, bar) TO STDIN BINARY")
|
.copy_out("COPY foo (id, bar) TO STDIN BINARY")
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::TEXT], stream)
|
let rows = BinaryCopyOutStream::new(stream, &[Type::INT4, Type::TEXT])
|
||||||
.try_collect::<Vec<_>>()
|
.try_collect::<Vec<_>>()
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -181,7 +170,7 @@ async fn read_many_rows() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn read_big_rows() {
|
async fn read_big_rows() {
|
||||||
let client = connect().await;
|
let client = connect("user=postgres").await;
|
||||||
|
|
||||||
client
|
client
|
||||||
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)")
|
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)")
|
||||||
@ -201,7 +190,7 @@ async fn read_big_rows() {
|
|||||||
.copy_out("COPY foo (id, bar) TO STDIN BINARY")
|
.copy_out("COPY foo (id, bar) TO STDIN BINARY")
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let rows = BinaryCopyOutStream::new(&[Type::INT4, Type::BYTEA], stream)
|
let rows = BinaryCopyOutStream::new(stream, &[Type::INT4, Type::BYTEA])
|
||||||
.try_collect::<Vec<_>>()
|
.try_collect::<Vec<_>>()
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
@ -14,6 +14,7 @@ use tokio_postgres::tls::{NoTls, NoTlsStream};
|
|||||||
use tokio_postgres::types::{Kind, Type};
|
use tokio_postgres::types::{Kind, Type};
|
||||||
use tokio_postgres::{AsyncMessage, Client, Config, Connection, Error, SimpleQueryMessage};
|
use tokio_postgres::{AsyncMessage, Client, Config, Connection, Error, SimpleQueryMessage};
|
||||||
|
|
||||||
|
mod binary_copy;
|
||||||
mod parse;
|
mod parse;
|
||||||
#[cfg(feature = "runtime")]
|
#[cfg(feature = "runtime")]
|
||||||
mod runtime;
|
mod runtime;
|
||||||
|
Loading…
Reference in New Issue
Block a user