Start on binary copy rewrite
This commit is contained in:
parent
cff1189cda
commit
8ebe859183
@ -9,6 +9,7 @@ members = [
|
||||
"postgres-protocol",
|
||||
"postgres-types",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres-binary-copy",
|
||||
]
|
||||
|
||||
[profile.release]
|
||||
|
17
tokio-postgres-binary-copy/Cargo.toml
Normal file
17
tokio-postgres-binary-copy/Cargo.toml
Normal file
@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "tokio-postgres-binary-copy"
|
||||
version = "0.1.0"
|
||||
authors = ["Steven Fackler <sfackler@gmail.com>"]
|
||||
edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
bytes = "0.4"
|
||||
futures-preview = "=0.3.0-alpha.19"
|
||||
parking_lot = "0.9"
|
||||
pin-project-lite = "0.1"
|
||||
tokio-postgres = { version = "=0.5.0-alpha.1", default-features = false, path = "../tokio-postgres" }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = "=0.2.0-alpha.6"
|
||||
tokio-postgres = { version = "=0.5.0-alpha.1", path = "../tokio-postgres" }
|
||||
|
123
tokio-postgres-binary-copy/src/lib.rs
Normal file
123
tokio-postgres-binary-copy/src/lib.rs
Normal file
@ -0,0 +1,123 @@
|
||||
use bytes::{BigEndian, BufMut, ByteOrder, Bytes, BytesMut};
|
||||
use futures::{future, Stream};
|
||||
use parking_lot::Mutex;
|
||||
use pin_project_lite::pin_project;
|
||||
use std::convert::TryFrom;
|
||||
use std::error::Error;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio_postgres::types::{IsNull, ToSql, Type};
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
const BLOCK_SIZE: usize = 4096;
|
||||
|
||||
pin_project! {
|
||||
pub struct BinaryCopyStream<F> {
|
||||
#[pin]
|
||||
future: F,
|
||||
buf: Arc<Mutex<BytesMut>>,
|
||||
done: bool,
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> BinaryCopyStream<F>
|
||||
where
|
||||
F: Future<Output = Result<(), Box<dyn Error + Sync + Send>>>,
|
||||
{
|
||||
pub fn new<M>(types: &[Type], write_values: M) -> BinaryCopyStream<F>
|
||||
where
|
||||
M: FnOnce(BinaryCopyWriter) -> F,
|
||||
{
|
||||
let mut buf = BytesMut::new();
|
||||
buf.reserve(11 + 4 + 4);
|
||||
buf.put_slice(b"PGCOPY\n\xff\r\n\0"); // magic
|
||||
buf.put_i32_be(0); // flags
|
||||
buf.put_i32_be(0); // header extension
|
||||
|
||||
let buf = Arc::new(Mutex::new(buf));
|
||||
let writer = BinaryCopyWriter {
|
||||
buf: buf.clone(),
|
||||
types: types.to_vec(),
|
||||
idx: 0,
|
||||
};
|
||||
|
||||
BinaryCopyStream {
|
||||
future: write_values(writer),
|
||||
buf,
|
||||
done: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Stream for BinaryCopyStream<F>
|
||||
where
|
||||
F: Future<Output = Result<(), Box<dyn Error + Sync + Send>>>,
|
||||
{
|
||||
type Item = Result<Bytes, Box<dyn Error + Sync + Send>>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.project();
|
||||
|
||||
if *this.done {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
|
||||
*this.done = this.future.poll(cx)?.is_ready();
|
||||
|
||||
let mut buf = this.buf.lock();
|
||||
if *this.done {
|
||||
buf.reserve(2);
|
||||
buf.put_i16_be(-1);
|
||||
Poll::Ready(Some(Ok(buf.take().freeze())))
|
||||
} else if buf.len() > BLOCK_SIZE {
|
||||
Poll::Ready(Some(Ok(buf.take().freeze())))
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME this should really just take a reference to the buffer, but that requires HKT :(
|
||||
pub struct BinaryCopyWriter {
|
||||
buf: Arc<Mutex<BytesMut>>,
|
||||
types: Vec<Type>,
|
||||
idx: usize,
|
||||
}
|
||||
|
||||
impl BinaryCopyWriter {
|
||||
pub async fn write(
|
||||
&mut self,
|
||||
value: &(dyn ToSql + Send),
|
||||
) -> Result<(), Box<dyn Error + Sync + Send>> {
|
||||
future::poll_fn(|_| {
|
||||
if self.buf.lock().len() > BLOCK_SIZE {
|
||||
Poll::Pending
|
||||
} else {
|
||||
Poll::Ready(())
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
let mut buf = self.buf.lock();
|
||||
if self.idx == 0 {
|
||||
buf.reserve(2);
|
||||
buf.put_i16_be(self.types.len() as i16);
|
||||
}
|
||||
let idx = buf.len();
|
||||
buf.reserve(4);
|
||||
buf.put_i32_be(0);
|
||||
let len = match value.to_sql_checked(&self.types[self.idx], &mut buf)? {
|
||||
IsNull::Yes => -1,
|
||||
IsNull::No => i32::try_from(buf.len() - idx - 4)?,
|
||||
};
|
||||
BigEndian::write_i32(&mut buf[idx..], len);
|
||||
|
||||
self.idx = (self.idx + 1) % self.types.len();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
101
tokio-postgres-binary-copy/src/test.rs
Normal file
101
tokio-postgres-binary-copy/src/test.rs
Normal file
@ -0,0 +1,101 @@
|
||||
use crate::BinaryCopyStream;
|
||||
use tokio_postgres::types::Type;
|
||||
use tokio_postgres::{Client, NoTls};
|
||||
|
||||
async fn connect() -> Client {
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect("host=localhost port=5433 user=postgres", NoTls)
|
||||
.await
|
||||
.unwrap();
|
||||
tokio::spawn(async {
|
||||
connection.await.unwrap();
|
||||
});
|
||||
client
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_basic() {
|
||||
let client = connect().await;
|
||||
|
||||
client
|
||||
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar TEXT)")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let stream = BinaryCopyStream::new(&[Type::INT4, Type::TEXT], |mut w| {
|
||||
async move {
|
||||
w.write(&1i32).await?;
|
||||
w.write(&"foobar").await?;
|
||||
|
||||
w.write(&2i32).await?;
|
||||
w.write(&None::<&str>).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
});
|
||||
|
||||
client
|
||||
.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let rows = client.query("SELECT id, bar FROM foo ORDER BY id", &[]).await.unwrap();
|
||||
assert_eq!(rows.len(), 2);
|
||||
assert_eq!(rows[0].get::<_, i32>(0), 1);
|
||||
assert_eq!(rows[0].get::<_, Option<&str>>(1), Some("foobar"));
|
||||
assert_eq!(rows[1].get::<_, i32>(0), 2);
|
||||
assert_eq!(rows[1].get::<_, Option<&str>>(1), None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_many_rows() {
|
||||
let client = connect().await;
|
||||
|
||||
client
|
||||
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar TEXT)")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let stream = BinaryCopyStream::new(&[Type::INT4, Type::TEXT], |mut w| async move {
|
||||
for i in 0..10_000i32 {
|
||||
w.write(&i).await?;
|
||||
w.write(&format!("the value for {}", i)).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
});
|
||||
|
||||
client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream).await.unwrap();
|
||||
|
||||
let rows = client.query("SELECT id, bar FROM foo ORDER BY id", &[]).await.unwrap();
|
||||
for (i, row) in rows.iter().enumerate() {
|
||||
assert_eq!(row.get::<_, i32>(0), i as i32);
|
||||
assert_eq!(row.get::<_, &str>(1), format!("the value for {}", i));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_big_rows() {
|
||||
let client = connect().await;
|
||||
|
||||
client.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)").await.unwrap();
|
||||
|
||||
let stream = BinaryCopyStream::new(&[Type::INT4, Type::BYTEA], |mut w| {
|
||||
async move {
|
||||
for i in 0..2i32 {
|
||||
w.write(&i).await.unwrap();
|
||||
w.write(&vec![i as u8; 128 * 1024]).await.unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
});
|
||||
|
||||
client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream).await.unwrap();
|
||||
|
||||
let rows = client.query("SELECT id, bar FROM foo ORDER BY id", &[]).await.unwrap();
|
||||
for (i, row) in rows.iter().enumerate() {
|
||||
assert_eq!(row.get::<_, i32>(0), i as i32);
|
||||
assert_eq!(row.get::<_, &[u8]>(1), &*vec![i as u8; 128 * 1024]);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user