Add support for binary copy in execution

Binary copy in usage needs to call into `ToSql::to_sql`, which needs a
`SessionInfo`. This defines a `Read`-like trait that also passes an
instance in. A blanket impl is provided for `R: Read` so this should be
backwards compatible.
This commit is contained in:
Steven Fackler 2015-07-02 23:57:54 -07:00
parent 525801327b
commit 03ee761108
2 changed files with 39 additions and 10 deletions

View File

@ -73,7 +73,7 @@ use std::path::PathBuf;
use error::{Error, ConnectError, SqlState, DbError}; use error::{Error, ConnectError, SqlState, DbError};
use types::{ToSql, FromSql}; use types::{ToSql, FromSql};
use io::{StreamWrapper, NegotiateSsl}; use io::{StreamWrapper, NegotiateSsl};
use types::{IsNull, Kind, Type, SessionInfo, Oid, Other}; use types::{IsNull, Kind, Type, SessionInfo, Oid, Other, ReadWithInfo};
use message::BackendMessage::*; use message::BackendMessage::*;
use message::FrontendMessage::*; use message::FrontendMessage::*;
use message::{FrontendMessage, BackendMessage, RowDescriptionEntry}; use message::{FrontendMessage, BackendMessage, RowDescriptionEntry};
@ -1530,9 +1530,10 @@ impl<'conn> Statement<'conn> {
/// Executes a `COPY FROM STDIN` statement, returning the number of rows /// Executes a `COPY FROM STDIN` statement, returning the number of rows
/// added. /// added.
/// ///
/// The contents of the provided `Read`er are passed to the Postgres server /// The data read out of the provided `Read`er are passed to the Postgres
/// verbatim; it is the caller's responsibility to ensure the data is in /// server verbatim; it is the caller's responsibility to ensure the data
/// the proper format. See the [Postgres documentation](http://www.postgresql.org/docs/9.4/static/sql-copy.html) /// is in the proper format. See the
/// [Postgres documentation](http://www.postgresql.org/docs/9.4/static/sql-copy.html)
/// for details. /// for details.
/// ///
/// If the statement is not a `COPY FROM STDIN` statement, it will still be /// If the statement is not a `COPY FROM STDIN` statement, it will still be
@ -1547,7 +1548,7 @@ impl<'conn> Statement<'conn> {
/// let stmt = conn.prepare("COPY people FROM STDIN").unwrap(); /// let stmt = conn.prepare("COPY people FROM STDIN").unwrap();
/// stmt.copy_in(&[], &mut "1\tjohn\n2\tjane\n".as_bytes()).unwrap(); /// stmt.copy_in(&[], &mut "1\tjohn\n2\tjane\n".as_bytes()).unwrap();
/// ``` /// ```
pub fn copy_in<R: Read>(&self, params: &[&ToSql], r: &mut R) -> Result<u64> { pub fn copy_in<R: ReadWithInfo>(&self, params: &[&ToSql], r: &mut R) -> Result<u64> {
try!(self.inner_execute("", 0, params)); try!(self.inner_execute("", 0, params));
let mut conn = self.conn.conn.borrow_mut(); let mut conn = self.conn.conn.borrow_mut();
@ -1567,16 +1568,15 @@ impl<'conn> Statement<'conn> {
} }
} }
let mut buf = vec![]; let mut buf = [0; 16 * 1024];
loop { loop {
match r.take(16 * 1024).read_to_end(&mut buf) { match fill_copy_buf(&mut buf, r, &SessionInfo::new(&conn)) {
Ok(0) => break, Ok(0) => break,
Ok(_) => { Ok(len) => {
try_desync!(conn, conn.stream.write_message( try_desync!(conn, conn.stream.write_message(
&CopyData { &CopyData {
data: &buf, data: &buf[..len],
})); }));
buf.clear();
} }
Err(err) => { Err(err) => {
try!(conn.write_messages(&[ try!(conn.write_messages(&[
@ -1628,6 +1628,20 @@ impl<'conn> Statement<'conn> {
} }
} }
fn fill_copy_buf<R: ReadWithInfo>(buf: &mut [u8], r: &mut R, info: &SessionInfo)
-> std_io::Result<usize> {
let mut nread = 0;
while nread < buf.len() {
match r.read_with_info(&mut buf[nread..], info) {
Ok(0) => break,
Ok(n) => nread += n,
Err(ref e) if e.kind() == std_io::ErrorKind::Interrupted => {}
Err(e) => return Err(e),
}
}
Ok(nread)
}
/// Information about a column of the result of a query. /// Information about a column of the result of a query.
#[derive(PartialEq, Eq, Clone, Debug)] #[derive(PartialEq, Eq, Clone, Debug)]
pub struct Column { pub struct Column {

View File

@ -4,6 +4,7 @@ use std::collections::HashMap;
use std::error; use std::error;
use std::fmt; use std::fmt;
use std::io::prelude::*; use std::io::prelude::*;
use std::io;
use byteorder::{ReadBytesExt, WriteBytesExt, BigEndian}; use byteorder::{ReadBytesExt, WriteBytesExt, BigEndian};
pub use self::slice::Slice; pub use self::slice::Slice;
@ -75,6 +76,20 @@ impl<'a> SessionInfo<'a> {
} }
} }
/// Like `Read` except that a `SessionInfo` object is provided as well.
///
/// All types that implement `Read` also implement this trait.
pub trait ReadWithInfo {
/// Like `Read::read`.
fn read_with_info(&mut self, buf: &mut [u8], info: &SessionInfo) -> io::Result<usize>;
}
impl<R: Read> ReadWithInfo for R {
fn read_with_info(&mut self, buf: &mut [u8], _: &SessionInfo) -> io::Result<usize> {
self.read(buf)
}
}
/// A Postgres OID. /// A Postgres OID.
pub type Oid = u32; pub type Oid = u32;