diff --git a/tokio-postgres-binary-copy/src/lib.rs b/tokio-postgres-binary-copy/src/lib.rs index 9e0677d2..fa42a84a 100644 --- a/tokio-postgres-binary-copy/src/lib.rs +++ b/tokio-postgres-binary-copy/src/lib.rs @@ -42,7 +42,6 @@ where let writer = BinaryCopyWriter { buf: buf.clone(), types: types.to_vec(), - idx: 0, }; BinaryCopyStream { @@ -85,14 +84,29 @@ where pub struct BinaryCopyWriter { buf: Arc>, types: Vec, - idx: usize, } impl BinaryCopyWriter { pub async fn write( &mut self, - value: &(dyn ToSql + Send), + values: &[&(dyn ToSql + Send)], ) -> Result<(), Box> { + self.write_raw(values.iter().cloned()).await + } + + pub async fn write_raw<'a, I>(&mut self, values: I) -> Result<(), Box> + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let values = values.into_iter(); + assert!( + values.len() == self.types.len(), + "expected {} values but got {}", + self.types.len(), + values.len(), + ); + future::poll_fn(|_| { if self.buf.lock().len() > BLOCK_SIZE { Poll::Pending @@ -103,20 +117,20 @@ impl BinaryCopyWriter { .await; let mut buf = self.buf.lock(); - if self.idx == 0 { - buf.reserve(2); - buf.put_i16_be(self.types.len() as i16); - } - let idx = buf.len(); - buf.reserve(4); - buf.put_i32_be(0); - let len = match value.to_sql_checked(&self.types[self.idx], &mut buf)? { - IsNull::Yes => -1, - IsNull::No => i32::try_from(buf.len() - idx - 4)?, - }; - BigEndian::write_i32(&mut buf[idx..], len); - self.idx = (self.idx + 1) % self.types.len(); + buf.reserve(2); + buf.put_i16_be(self.types.len() as i16); + + for (value, type_) in values.zip(&self.types) { + let idx = buf.len(); + buf.reserve(4); + buf.put_i32_be(0); + let len = match value.to_sql_checked(type_, &mut buf)? { + IsNull::Yes => -1, + IsNull::No => i32::try_from(buf.len() - idx - 4)?, + }; + BigEndian::write_i32(&mut buf[idx..], len); + } Ok(()) } diff --git a/tokio-postgres-binary-copy/src/test.rs b/tokio-postgres-binary-copy/src/test.rs index dcf91b69..486ac581 100644 --- a/tokio-postgres-binary-copy/src/test.rs +++ b/tokio-postgres-binary-copy/src/test.rs @@ -24,11 +24,8 @@ async fn write_basic() { let stream = BinaryCopyStream::new(&[Type::INT4, Type::TEXT], |mut w| { async move { - w.write(&1i32).await?; - w.write(&"foobar").await?; - - w.write(&2i32).await?; - w.write(&None::<&str>).await?; + w.write(&[&1i32, &"foobar"]).await?; + w.write(&[&2i32, &None::<&str>]).await?; Ok(()) } @@ -39,7 +36,10 @@ async fn write_basic() { .await .unwrap(); - let rows = client.query("SELECT id, bar FROM foo ORDER BY id", &[]).await.unwrap(); + let rows = client + .query("SELECT id, bar FROM foo ORDER BY id", &[]) + .await + .unwrap(); assert_eq!(rows.len(), 2); assert_eq!(rows[0].get::<_, i32>(0), 1); assert_eq!(rows[0].get::<_, Option<&str>>(1), Some("foobar")); @@ -56,18 +56,25 @@ async fn write_many_rows() { .await .unwrap(); - let stream = BinaryCopyStream::new(&[Type::INT4, Type::TEXT], |mut w| async move { - for i in 0..10_000i32 { - w.write(&i).await?; - w.write(&format!("the value for {}", i)).await?; - } + let stream = BinaryCopyStream::new(&[Type::INT4, Type::TEXT], |mut w| { + async move { + for i in 0..10_000i32 { + w.write(&[&i, &format!("the value for {}", i)]).await?; + } - Ok(()) + Ok(()) + } }); - client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream).await.unwrap(); + client + .copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream) + .await + .unwrap(); - let rows = client.query("SELECT id, bar FROM foo ORDER BY id", &[]).await.unwrap(); + let rows = client + .query("SELECT id, bar FROM foo ORDER BY id", &[]) + .await + .unwrap(); for (i, row) in rows.iter().enumerate() { assert_eq!(row.get::<_, i32>(0), i as i32); assert_eq!(row.get::<_, &str>(1), format!("the value for {}", i)); @@ -78,22 +85,30 @@ async fn write_many_rows() { async fn write_big_rows() { let client = connect().await; - client.batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)").await.unwrap(); + client + .batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)") + .await + .unwrap(); let stream = BinaryCopyStream::new(&[Type::INT4, Type::BYTEA], |mut w| { async move { for i in 0..2i32 { - w.write(&i).await.unwrap(); - w.write(&vec![i as u8; 128 * 1024]).await.unwrap(); + w.write(&[&i, &vec![i as u8; 128 * 1024]]).await?; } Ok(()) } }); - client.copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream).await.unwrap(); + client + .copy_in("COPY foo (id, bar) FROM STDIN BINARY", &[], stream) + .await + .unwrap(); - let rows = client.query("SELECT id, bar FROM foo ORDER BY id", &[]).await.unwrap(); + let rows = client + .query("SELECT id, bar FROM foo ORDER BY id", &[]) + .await + .unwrap(); for (i, row) in rows.iter().enumerate() { assert_eq!(row.get::<_, i32>(0), i as i32); assert_eq!(row.get::<_, &[u8]>(1), &*vec![i as u8; 128 * 1024]);