Add COPY TO STDOUT support.

Closes #51
This commit is contained in:
Steven Fackler 2015-08-15 23:21:39 -07:00
parent 63e278b9f2
commit 6e99874bd9
4 changed files with 237 additions and 5 deletions

View File

@ -350,3 +350,9 @@ impl From<byteorder::Error> for Error {
Error::IoError(From::from(err))
}
}
impl From<Error> for io::Error {
fn from(err: Error) -> io::Error {
io::Error::new(io::ErrorKind::Other, err)
}
}

View File

@ -89,12 +89,12 @@ mod macros;
mod md5;
mod message;
mod priv_io;
mod stmt;
mod url;
mod util;
pub mod error;
pub mod io;
pub mod rows;
pub mod stmt;
pub mod types;
const TYPEINFO_QUERY: &'static str = "t";

View File

@ -1,8 +1,10 @@
//! Prepared statements
use debug_builders::DebugStruct;
use std::cell::Cell;
use std::cell::{Cell, RefMut};
use std::collections::VecDeque;
use std::fmt;
use std::io;
use std::io::{self, Cursor, BufRead, Read};
use error::{Error, DbError};
use types::{ReadWithInfo, SessionInfo, Type, ToSql, IsNull};
@ -12,7 +14,7 @@ use message::WriteMessage;
use util;
use rows::{Rows, LazyRows};
use {read_rows, bad_response, Connection, Transaction, StatementInternals, Result, RowsNew};
use {SessionInfoNew, LazyRowsNew, DbErrorNew, ColumnNew};
use {InnerConnection, SessionInfoNew, LazyRowsNew, DbErrorNew, ColumnNew};
/// A prepared statement.
pub struct Statement<'conn> {
@ -371,6 +373,84 @@ impl<'conn> Statement<'conn> {
Ok(num)
}
/// Executes a `COPY TO STDOUT` statement, returning a `Read`er of the
/// resulting data.
///
/// See the [Postgres documentation](http://www.postgresql.org/docs/9.4/static/sql-copy.html)
/// for details on the data format.
///
/// If the statement is not a `COPY TO STDOUT` statement it will still be
/// executed and this method will return an error.
///
/// # Warning
///
/// The underlying connection may not be used while the returned `Read`er
/// exists. Any attempt to do so will panic.
///
/// # Examples
///
/// ```rust,no_run
/// # use std::io::Read;
/// # use postgres::{Connection, SslMode};
/// # let conn = Connection::connect("", &SslMode::None).unwrap();
/// conn.batch_execute("
/// CREATE TABLE people (id INT PRIMARY KEY, name VARCHAR);
/// INSERT INTO people (id, name) VALUES (1, 'john'), (2, 'jane');").unwrap();
/// let stmt = conn.prepare("COPY people TO STDOUT").unwrap();
/// let mut r = stmt.copy_out(&[]).unwrap();
/// let mut buf = vec![];
/// r.read_to_end(&mut buf).unwrap();
/// r.finish().unwrap();
/// assert_eq!(buf, b"1\tjohn\n2\tjane\n");
/// ```
pub fn copy_out<'a>(&'a self, params: &[&ToSql]) -> Result<CopyOutReader<'a>> {
try!(self.inner_execute("", 0, params));
let mut conn = self.conn.conn.borrow_mut();
let (format, column_formats) = match try!(conn.read_message()) {
CopyOutResponse { format, column_formats } => (format, column_formats),
CopyInResponse { .. } => {
try!(conn.write_messages(&[
CopyFail {
message: "",
},
CopyDone,
Sync]));
match try!(conn.read_message()) {
ErrorResponse { .. } => { /* expected from the CopyFail */ }
_ => {
conn.desynchronized = true;
return Err(Error::IoError(bad_response()));
}
}
try!(conn.wait_for_ready());
return Err(Error::IoError(io::Error::new(
io::ErrorKind::InvalidInput,
"called `copy_out` on a non-`COPY TO STDOUT` statement")));
}
_ => {
loop {
match try!(conn.read_message()) {
ReadyForQuery { .. } => {
return Err(Error::IoError(io::Error::new(
io::ErrorKind::InvalidInput,
"called `copy_out` on a non-`COPY TO STDOUT` statement")));
}
_ => {}
}
}
}
};
Ok(CopyOutReader {
conn: conn,
format: Format::from_u16(format as u16),
column_formats: column_formats.iter().map(|&f| Format::from_u16(f)).collect(),
buf: Cursor::new(vec![]),
finished: false,
})
}
/// Consumes the statement, clearing it from the Postgres session.
///
/// If this statement was created via the `prepare_cached` method, `finish`
@ -425,4 +505,119 @@ impl Column {
}
}
/// The format of a portion of COPY query data.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Format {
/// A text based format.
Text,
/// A binary format.
Binary,
}
impl Format {
fn from_u16(value: u16) -> Format {
match value {
0 => Format::Text,
_ => Format::Binary,
}
}
}
/// A `Read`er for data from `COPY TO STDOUT` queries.
///
/// # Warning
///
/// The underlying connection may not be used while a `CopyOutReader` exists.
/// Any calls to the connection with panic.
pub struct CopyOutReader<'a> {
conn: RefMut<'a, InnerConnection>,
format: Format,
column_formats: Vec<Format>,
buf: Cursor<Vec<u8>>,
finished: bool,
}
impl<'a> Drop for CopyOutReader<'a> {
fn drop(&mut self) {
let _ = self.finish_inner();
}
}
impl<'a> CopyOutReader<'a> {
/// Returns the format of the overall data.
pub fn format(&self) -> Format {
self.format
}
/// Returns the format of the individual columns.
pub fn column_formats(&self) -> &[Format] {
&self.column_formats
}
/// Consumes the `CopyOutReader`, throwing away any unread data.
///
/// Functionally equivalent to `CopyOutReader`'s `Drop` implementation,
/// except that it returns any error encountered to the caller.
pub fn finish(mut self) -> Result<()> {
self.finish_inner()
}
fn finish_inner(&mut self) -> Result<()> {
while !self.finished {
let pos = self.buf.get_ref().len() as u64;
self.buf.set_position(pos);
try!(self.ensure_filled());
}
Ok(())
}
fn ensure_filled(&mut self) -> Result<()> {
if self.finished || self.buf.position() != self.buf.get_ref().len() as u64 {
return Ok(());
}
match try!(self.conn.read_message()) {
BCopyData { data } => self.buf = Cursor::new(data),
BCopyDone => {
self.finished = true;
match try!(self.conn.read_message()) {
CommandComplete { .. } => {}
_ => {
self.conn.desynchronized = true;
return Err(Error::IoError(bad_response()));
}
}
try!(self.conn.wait_for_ready());
}
ErrorResponse { fields } => {
self.finished = true;
try!(self.conn.wait_for_ready());
return DbError::new(fields);
}
_ => {
self.conn.desynchronized = true;
return Err(Error::IoError(bad_response()));
}
}
Ok(())
}
}
impl<'a> Read for CopyOutReader<'a> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
try!(self.ensure_filled());
self.buf.read(buf)
}
}
impl<'a> BufRead for CopyOutReader<'a> {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
try!(self.ensure_filled());
self.buf.fill_buf()
}
fn consume(&mut self, amt: usize) {
self.buf.consume(amt)
}
}

View File

@ -8,6 +8,7 @@ extern crate openssl;
use openssl::ssl::{SslContext, SslMethod};
use std::thread;
use std::io;
use std::io::prelude::*;
use postgres::{HandleNotice,
Notification,
@ -757,7 +758,7 @@ fn test_copy() {
}
#[test]
fn test_copy_out_query() {
fn test_query_copy_out_err() {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
or_panic!(conn.batch_execute("
CREATE TEMPORARY TABLE foo (id INT);
@ -770,6 +771,36 @@ fn test_copy_out_query() {
}
}
#[test]
fn test_copy_out() {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
or_panic!(conn.batch_execute("
CREATE TEMPORARY TABLE foo (id INT);
INSERT INTO foo (id) VALUES (0), (1), (2), (3)"));
let stmt = or_panic!(conn.prepare("COPY (SELECT id FROM foo ORDER BY id) TO STDOUT"));
let mut reader = or_panic!(stmt.copy_out(&[]));
let mut out = vec![];
or_panic!(reader.read_to_end(&mut out));
assert_eq!(out, b"0\n1\n2\n3\n");
drop(reader);
or_panic!(conn.batch_execute("SELECT 1"));
}
#[test]
fn test_copy_out_partial_read() {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
or_panic!(conn.batch_execute("
CREATE TEMPORARY TABLE foo (id INT);
INSERT INTO foo (id) VALUES (0), (1), (2), (3)"));
let stmt = or_panic!(conn.prepare("COPY (SELECT id FROM foo ORDER BY id) TO STDOUT"));
let mut reader = or_panic!(stmt.copy_out(&[]));
let mut out = vec![];
or_panic!(reader.by_ref().take(5).read_to_end(&mut out));
assert_eq!(out, b"0\n1\n2");
drop(reader);
or_panic!(conn.batch_execute("SELECT 1"));
}
#[test]
// Just make sure the impls don't infinite loop
fn test_generic_connection() {