Start on binary copy rewrite

This commit is contained in:
Steven Fackler 2019-11-17 18:28:12 -08:00
parent cff1189cda
commit 8ebe859183
4 changed files with 242 additions and 0 deletions

View File

@ -9,6 +9,7 @@ members = [
"postgres-protocol",
"postgres-types",
"tokio-postgres",
"tokio-postgres-binary-copy",
]
[profile.release]

View File

@ -0,0 +1,17 @@
# Manifest for the binary COPY helper crate layered on top of tokio-postgres.
[package]
name = "tokio-postgres-binary-copy"
version = "0.1.0"
authors = ["Steven Fackler <sfackler@gmail.com>"]
edition = "2018"
# Pre-release dependencies are pinned to an exact version with `=`.
[dependencies]
bytes = "0.4"
futures-preview = "=0.3.0-alpha.19"
parking_lot = "0.9"
pin-project-lite = "0.1"
# The library itself builds against tokio-postgres with its default features turned off.
tokio-postgres = { version = "=0.5.0-alpha.1", default-features = false, path = "../tokio-postgres" }
# The tests re-declare tokio-postgres with default features enabled; they call
# `tokio_postgres::connect` (NOTE(review): presumably gated behind a default feature - confirm).
[dev-dependencies]
tokio = "=0.2.0-alpha.6"
tokio-postgres = { version = "=0.5.0-alpha.1", path = "../tokio-postgres" }

View File

@ -0,0 +1,123 @@
use bytes::{BigEndian, BufMut, ByteOrder, Bytes, BytesMut};
use futures::{future, Stream};
use parking_lot::Mutex;
use pin_project_lite::pin_project;
use std::convert::TryFrom;
use std::error::Error;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio_postgres::types::{IsNull, ToSql, Type};
#[cfg(test)]
mod test;
// Number of buffered bytes above which the stream hands the buffer out as a chunk and
// `BinaryCopyWriter::write` stops accepting values until the buffer has been drained.
const BLOCK_SIZE: usize = 4096;
pin_project! {
/// A `Stream` of byte chunks in the binary copy format, produced by driving a
/// user-supplied future that writes values through a `BinaryCopyWriter`.
pub struct BinaryCopyStream<F> {
// Future returned by the `write_values` closure given to `new`; it is polled in place
// by `poll_next`, hence the structural pin.
#[pin]
future: F,
// Output buffer shared with the `BinaryCopyWriter` handed to the closure.
buf: Arc<Mutex<BytesMut>>,
// Set once `future` has completed; `poll_next` returns `None` from then on.
done: bool,
}
}
impl<F> BinaryCopyStream<F>
where
    F: Future<Output = Result<(), Box<dyn Error + Sync + Send>>>,
{
    /// Builds a stream whose contents are produced by `write_values`.
    ///
    /// `types` lists the column types, in order, of every row that will be written.
    /// `write_values` receives a [`BinaryCopyWriter`] sharing this stream's buffer and
    /// returns the future that performs the writes; that future is only driven when the
    /// stream itself is polled.
    ///
    /// The buffer is pre-seeded with the binary copy header: the signature, a zero flags
    /// field and a zero-length header extension.
    pub fn new<M>(types: &[Type], write_values: M) -> BinaryCopyStream<F>
    where
        M: FnOnce(BinaryCopyWriter) -> F,
    {
        // signature (11 bytes) + flags (4 bytes) + header extension length (4 bytes)
        let mut header = BytesMut::with_capacity(11 + 4 + 4);
        header.put_slice(b"PGCOPY\n\xff\r\n\0"); // magic
        header.put_i32_be(0); // flags
        header.put_i32_be(0); // header extension

        let shared = Arc::new(Mutex::new(header));
        let future = write_values(BinaryCopyWriter {
            buf: Arc::clone(&shared),
            types: types.to_vec(),
            idx: 0,
        });

        BinaryCopyStream {
            future,
            buf: shared,
            done: false,
        }
    }
}
/// Yields the buffered copy data in chunks while driving the writer future.
///
/// Each poll first advances the wrapped future so that it can append more values to the
/// shared buffer, then decides what to hand out:
///
/// * if the future finished successfully, the end-of-data trailer is appended and
///   everything still buffered is yielded as the final chunk;
/// * if the future failed, its error is yielded and the stream is terminated;
/// * if more than `BLOCK_SIZE` bytes are buffered, they are yielded as a chunk;
/// * otherwise the stream is pending.
impl<F> Stream for BinaryCopyStream<F>
where
    F: Future<Output = Result<(), Box<dyn Error + Sync + Send>>>,
{
    type Item = Result<Bytes, Box<dyn Error + Sync + Send>>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        if *this.done {
            return Poll::Ready(None);
        }

        match this.future.poll(cx) {
            Poll::Ready(Ok(())) => *this.done = true,
            Poll::Ready(Err(e)) => {
                // The future has completed and must not be polled again (doing so
                // typically panics), so mark the stream finished before reporting the
                // error; any later poll then simply returns `None`.
                *this.done = true;
                return Poll::Ready(Some(Err(e)));
            }
            Poll::Pending => {}
        }

        let mut buf = this.buf.lock();
        if *this.done {
            // A field count of -1 is the trailer that terminates the copy data.
            buf.reserve(2);
            buf.put_i16_be(-1);
            Poll::Ready(Some(Ok(buf.take().freeze())))
        } else if buf.len() > BLOCK_SIZE {
            // The writer stops at this threshold without registering a waker; draining
            // the buffer here lets it make progress on the consumer's next poll.
            Poll::Ready(Some(Ok(buf.take().freeze())))
        } else {
            Poll::Pending
        }
    }
}
// FIXME this should really just take a reference to the buffer, but that requires HKT :(
/// Handle used to append column values to the buffer of a `BinaryCopyStream`.
pub struct BinaryCopyWriter {
// Output buffer shared with the owning `BinaryCopyStream`.
buf: Arc<Mutex<BytesMut>>,
// Column types of a row, in the order values are expected to be written.
types: Vec<Type>,
// Index into `types` of the column the next `write` call encodes; wraps back to 0
// after the last column.
idx: usize,
}
impl BinaryCopyWriter {
    /// Encodes `value` as the next column of the current row.
    ///
    /// Values must be written in column order; after the last column of a row the writer
    /// wraps around and the following call starts a new row. The first value of each row
    /// is preceded by the 16-bit field count, and every value is framed by a 32-bit
    /// length (`-1` for SQL `NULL`).
    ///
    /// The call waits while more than `BLOCK_SIZE` bytes are buffered, which gives the
    /// owning stream a chance to drain the buffer before more data is appended.
    ///
    /// # Errors
    ///
    /// Returns an error if the writer was created with no column types, if the column
    /// count does not fit in the 16-bit field count, if `value` cannot be encoded as the
    /// current column's type, or if the encoded value is too large for a 32-bit length.
    pub async fn write(
        &mut self,
        value: &(dyn ToSql + Send),
    ) -> Result<(), Box<dyn Error + Sync + Send>> {
        // Without any column there is no type to encode against, and the wrap-around
        // below would divide by zero; report it instead of panicking.
        if self.types.is_empty() {
            return Err("cannot write a value to a binary copy with no columns".into());
        }

        // Backpressure: no waker is registered here on purpose. The owning stream yields
        // the full buffer and polls this future again on its next `poll_next`.
        future::poll_fn(|_| {
            if self.buf.lock().len() > BLOCK_SIZE {
                Poll::Pending
            } else {
                Poll::Ready(())
            }
        })
        .await;

        let mut buf = self.buf.lock();

        if self.idx == 0 {
            // Start of a new row: emit the field count, rejecting counts that would
            // otherwise be silently truncated.
            let field_count = i16::try_from(self.types.len())?;
            buf.reserve(2);
            buf.put_i16_be(field_count);
        }

        // Write a placeholder length, encode the value, then patch the real length in.
        let idx = buf.len();
        buf.reserve(4);
        buf.put_i32_be(0);
        let len = match value.to_sql_checked(&self.types[self.idx], &mut buf)? {
            IsNull::Yes => -1,
            IsNull::No => i32::try_from(buf.len() - idx - 4)?,
        };
        BigEndian::write_i32(&mut buf[idx..], len);

        self.idx = (self.idx + 1) % self.types.len();

        Ok(())
    }
}

View File

@ -0,0 +1,101 @@
use crate::BinaryCopyStream;
use tokio_postgres::types::Type;
use tokio_postgres::{Client, NoTls};
/// Opens a connection to the local test database and spawns its connection task.
///
/// Panics if the connection cannot be established; the spawned task panics if the
/// connection later fails.
async fn connect() -> Client {
    let params = "host=localhost port=5433 user=postgres";
    let (client, connection) = tokio_postgres::connect(params, NoTls).await.unwrap();

    // The connection future performs the actual socket I/O and has to be driven
    // separately from the client handle.
    tokio::spawn(async {
        connection.await.unwrap();
    });

    client
}
// Copies two small rows (one containing a NULL) and reads them back.
#[tokio::test]
async fn write_basic() {
    let client = connect().await;

    let create_table = "CREATE TEMPORARY TABLE foo (id INT, bar TEXT)";
    client.batch_execute(create_table).await.unwrap();

    let column_types = [Type::INT4, Type::TEXT];
    let stream = BinaryCopyStream::new(&column_types, |mut writer| {
        async move {
            // first row
            writer.write(&1i32).await?;
            writer.write(&"foobar").await?;
            // second row, with a NULL text column
            writer.write(&2i32).await?;
            writer.write(&None::<&str>).await?;
            Ok(())
        }
    });

    client
        .copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream)
        .await
        .unwrap();

    let rows = client
        .query("SELECT id, bar FROM foo ORDER BY id", &[])
        .await
        .unwrap();

    let expected = [(1i32, Some("foobar")), (2i32, None)];
    assert_eq!(rows.len(), expected.len());
    for (row, (id, bar)) in rows.iter().zip(expected.iter()) {
        assert_eq!(row.get::<_, i32>(0), *id);
        assert_eq!(row.get::<_, Option<&str>>(1), *bar);
    }
}
// Copies enough rows that the stream has to emit many chunks, then checks every row.
#[tokio::test]
async fn write_many_rows() {
    let client = connect().await;

    client
        .batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar TEXT)")
        .await
        .unwrap();

    let stream = BinaryCopyStream::new(&[Type::INT4, Type::TEXT], |mut w| async move {
        for i in 0..10_000i32 {
            w.write(&i).await?;
            w.write(&format!("the value for {}", i)).await?;
        }
        Ok(())
    });

    client
        .copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream)
        .await
        .unwrap();

    let rows = client
        .query("SELECT id, bar FROM foo ORDER BY id", &[])
        .await
        .unwrap();

    // Without this check the loop below passes vacuously when no rows were copied.
    assert_eq!(rows.len(), 10_000);
    for (i, row) in rows.iter().enumerate() {
        assert_eq!(row.get::<_, i32>(0), i as i32);
        assert_eq!(row.get::<_, &str>(1), format!("the value for {}", i));
    }
}
// Copies rows whose single value is far larger than the stream's chunk threshold.
#[tokio::test]
async fn write_big_rows() {
    let client = connect().await;

    client
        .batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)")
        .await
        .unwrap();

    let stream = BinaryCopyStream::new(&[Type::INT4, Type::BYTEA], |mut w| {
        async move {
            for i in 0..2i32 {
                // Propagate failures through the stream, as the other tests do, so they
                // surface via the `copy_in` result instead of panicking inside the future.
                w.write(&i).await?;
                w.write(&vec![i as u8; 128 * 1024]).await?;
            }
            Ok(())
        }
    });

    client
        .copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream)
        .await
        .unwrap();

    let rows = client
        .query("SELECT id, bar FROM foo ORDER BY id", &[])
        .await
        .unwrap();

    // Without this check the loop below passes vacuously when no rows were copied.
    assert_eq!(rows.len(), 2);
    for (i, row) in rows.iter().enumerate() {
        assert_eq!(row.get::<_, i32>(0), i as i32);
        assert_eq!(row.get::<_, &[u8]>(1), &*vec![i as u8; 128 * 1024]);
    }
}