Write full rows in binary copy

This commit is contained in:
Steven Fackler 2019-11-18 18:06:03 -08:00
parent 8ebe859183
commit 6e2435eb60
2 changed files with 64 additions and 35 deletions

View File

@ -42,7 +42,6 @@ where
let writer = BinaryCopyWriter {
buf: buf.clone(),
types: types.to_vec(),
idx: 0,
};
BinaryCopyStream {
@ -85,14 +84,29 @@ where
pub struct BinaryCopyWriter {
buf: Arc<Mutex<BytesMut>>,
types: Vec<Type>,
idx: usize,
}
impl BinaryCopyWriter {
pub async fn write(
&mut self,
value: &(dyn ToSql + Send),
values: &[&(dyn ToSql + Send)],
) -> Result<(), Box<dyn Error + Sync + Send>> {
self.write_raw(values.iter().cloned()).await
}
pub async fn write_raw<'a, I>(&mut self, values: I) -> Result<(), Box<dyn Error + Sync + Send>>
where
I: IntoIterator<Item = &'a (dyn ToSql + Send)>,
I::IntoIter: ExactSizeIterator,
{
let values = values.into_iter();
assert!(
values.len() == self.types.len(),
"expected {} values but got {}",
self.types.len(),
values.len(),
);
future::poll_fn(|_| {
if self.buf.lock().len() > BLOCK_SIZE {
Poll::Pending
@ -103,20 +117,20 @@ impl BinaryCopyWriter {
.await;
let mut buf = self.buf.lock();
if self.idx == 0 {
buf.reserve(2);
buf.put_i16_be(self.types.len() as i16);
}
let idx = buf.len();
buf.reserve(4);
buf.put_i32_be(0);
let len = match value.to_sql_checked(&self.types[self.idx], &mut buf)? {
IsNull::Yes => -1,
IsNull::No => i32::try_from(buf.len() - idx - 4)?,
};
BigEndian::write_i32(&mut buf[idx..], len);
self.idx = (self.idx + 1) % self.types.len();
buf.reserve(2);
buf.put_i16_be(self.types.len() as i16);
for (value, type_) in values.zip(&self.types) {
let idx = buf.len();
buf.reserve(4);
buf.put_i32_be(0);
let len = match value.to_sql_checked(type_, &mut buf)? {
IsNull::Yes => -1,
IsNull::No => i32::try_from(buf.len() - idx - 4)?,
};
BigEndian::write_i32(&mut buf[idx..], len);
}
Ok(())
}

View File

@ -24,11 +24,8 @@ async fn write_basic() {
let stream = BinaryCopyStream::new(&[Type::INT4, Type::TEXT], |mut w| {
async move {
w.write(&1i32).await?;
w.write(&"foobar").await?;
w.write(&2i32).await?;
w.write(&None::<&str>).await?;
w.write(&[&1i32, &"foobar"]).await?;
w.write(&[&2i32, &None::<&str>]).await?;
Ok(())
}
@ -39,7 +36,10 @@ async fn write_basic() {
.await
.unwrap();
let rows = client.query("SELECT id, bar FROM foo ORDER BY id", &[]).await.unwrap();
let rows = client
.query("SELECT id, bar FROM foo ORDER BY id", &[])
.await
.unwrap();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].get::<_, i32>(0), 1);
assert_eq!(rows[0].get::<_, Option<&str>>(1), Some("foobar"));
@ -56,18 +56,25 @@ async fn write_many_rows() {
.await
.unwrap();
let stream = BinaryCopyStream::new(&[Type::INT4, Type::TEXT], |mut w| async move {
for i in 0..10_000i32 {
w.write(&i).await?;
w.write(&format!("the value for {}", i)).await?;
}
let stream = BinaryCopyStream::new(&[Type::INT4, Type::TEXT], |mut w| {
async move {
for i in 0..10_000i32 {
w.write(&[&i, &format!("the value for {}", i)]).await?;
}
Ok(())
Ok(())
}
});
client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream).await.unwrap();
client
.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream)
.await
.unwrap();
let rows = client.query("SELECT id, bar FROM foo ORDER BY id", &[]).await.unwrap();
let rows = client
.query("SELECT id, bar FROM foo ORDER BY id", &[])
.await
.unwrap();
for (i, row) in rows.iter().enumerate() {
assert_eq!(row.get::<_, i32>(0), i as i32);
assert_eq!(row.get::<_, &str>(1), format!("the value for {}", i));
@ -78,22 +85,30 @@ async fn write_many_rows() {
async fn write_big_rows() {
let client = connect().await;
client.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)").await.unwrap();
client
.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)")
.await
.unwrap();
let stream = BinaryCopyStream::new(&[Type::INT4, Type::BYTEA], |mut w| {
async move {
for i in 0..2i32 {
w.write(&i).await.unwrap();
w.write(&vec![i as u8; 128 * 1024]).await.unwrap();
w.write(&[&i, &vec![i as u8; 128 * 1024]]).await?;
}
Ok(())
}
});
client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream).await.unwrap();
client
.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream)
.await
.unwrap();
let rows = client.query("SELECT id, bar FROM foo ORDER BY id", &[]).await.unwrap();
let rows = client
.query("SELECT id, bar FROM foo ORDER BY id", &[])
.await
.unwrap();
for (i, row) in rows.iter().enumerate() {
assert_eq!(row.get::<_, i32>(0), i as i32);
assert_eq!(row.get::<_, &[u8]>(1), &*vec![i as u8; 128 * 1024]);