From 793c5f1b872ccabd03b9a744889ef033646e628d Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 23 Dec 2018 13:08:02 -0800 Subject: [PATCH] Add sync copy_out --- postgres/Cargo.toml | 1 + postgres/src/client.rs | 63 ++++++++++++++++++++++++++++++++++++- postgres/src/test.rs | 26 +++++++++++++++ postgres/src/transaction.rs | 13 +++++++- 4 files changed, 101 insertions(+), 2 deletions(-) diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index 03176d02..a59c0794 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -10,6 +10,7 @@ default = ["runtime"] runtime = ["tokio-postgres/runtime", "tokio", "lazy_static", "log"] [dependencies] +bytes = "0.4" futures = "0.1" tokio-postgres = { version = "0.3", path = "../tokio-postgres", default-features = false } diff --git a/postgres/src/client.rs b/postgres/src/client.rs index 2f2496e2..f038e1c2 100644 --- a/postgres/src/client.rs +++ b/postgres/src/client.rs @@ -1,6 +1,9 @@ +use bytes::{Buf, Bytes}; +use futures::stream; use futures::sync::mpsc; use futures::{Async, AsyncSink, Future, Poll, Sink, Stream}; -use std::io::{self, Read}; +use std::io::{self, BufRead, Cursor, Read}; +use std::marker::PhantomData; use tokio_postgres::types::{ToSql, Type}; use tokio_postgres::{Error, Row}; #[cfg(feature = "runtime")] @@ -78,6 +81,30 @@ impl Client { .wait() } + pub fn copy_out( + &mut self, + query: &T, + params: &[&dyn ToSql], + ) -> Result, Error> + where + T: ?Sized + Query, + { + let statement = query.__statement(self)?; + let mut stream = self.0.copy_out(&statement.0, params).wait(); + + let cur = match stream.next() { + Some(Ok(cur)) => cur, + Some(Err(e)) => return Err(e), + None => Bytes::new(), + }; + + Ok(CopyOutReader { + stream, + cur: Cursor::new(cur), + _p: PhantomData, + }) + } + pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> { self.0.batch_execute(query).wait() } @@ -179,3 +206,37 @@ where } } } + +pub struct CopyOutReader<'a> { + stream: stream::Wait, + cur: Cursor, + _p: PhantomData<&'a mut ()>, +} + +impl<'a> Read for CopyOutReader<'a> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let b = self.fill_buf()?; + let len = usize::min(buf.len(), b.len()); + buf[..len].copy_from_slice(&b[..len]); + self.consume(len); + Ok(len) + } +} + +impl<'a> BufRead for CopyOutReader<'a> { + fn fill_buf(&mut self) -> io::Result<&[u8]> { + if self.cur.remaining() == 0 { + match self.stream.next() { + Some(Ok(cur)) => self.cur = Cursor::new(cur), + Some(Err(e)) => return Err(io::Error::new(io::ErrorKind::Other, e)), + None => {} + }; + } + + Ok(Buf::bytes(&self.cur)) + } + + fn consume(&mut self, amt: usize) { + self.cur.advance(amt); + } +} diff --git a/postgres/src/test.rs b/postgres/src/test.rs index c86f2327..70f8537b 100644 --- a/postgres/src/test.rs +++ b/postgres/src/test.rs @@ -1,3 +1,4 @@ +use std::io::Read; use tokio_postgres::types::Type; use tokio_postgres::NoTls; @@ -171,3 +172,28 @@ fn copy_in() { assert_eq!(rows[1].get::<_, i32>(0), 2); assert_eq!(rows[1].get::<_, &str>(1), "timothy"); } + +#[test] +fn copy_out() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .batch_execute( + " + CREATE TEMPORARY TABLE foo (id INT, name TEXT); + + INSERT INTO foo (id, name) VALUES (1, 'steven'), (2, 'timothy'); + ", + ) + .unwrap(); + + let mut reader = client + .copy_out("COPY foo (id, name) TO STDOUT", &[]) + .unwrap(); + let mut s = String::new(); + reader.read_to_string(&mut s).unwrap(); + + assert_eq!(s, "1\tsteven\n2\ttimothy\n"); + + client.batch_execute("SELECT 1").unwrap(); +} diff --git a/postgres/src/transaction.rs b/postgres/src/transaction.rs index a541ee91..0fa19b9b 100644 --- a/postgres/src/transaction.rs +++ b/postgres/src/transaction.rs @@ -2,7 +2,7 @@ use std::io::Read; use tokio_postgres::types::{ToSql, Type}; use tokio_postgres::{Error, Row}; -use crate::{Client, Query, Statement}; +use crate::{Client, CopyOutReader, Query, Statement}; pub struct Transaction<'a> { client: &'a mut Client, @@ -86,6 +86,17 @@ impl<'a> Transaction<'a> { self.client.copy_in(query, params, reader) } + pub fn copy_out( + &mut self, + query: &T, + params: &[&dyn ToSql], + ) -> Result, Error> + where + T: ?Sized + Query, + { + self.client.copy_out(query, params) + } + pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> { self.client.batch_execute(query) }