diff --git a/src/lib.rs b/src/lib.rs index 98e62605..48341a3c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -73,7 +73,7 @@ use std::path::PathBuf; use error::{Error, ConnectError, SqlState, DbError}; use types::{ToSql, FromSql}; use io::{StreamWrapper, NegotiateSsl}; -use types::{IsNull, Kind, Type, SessionInfo, Oid, Other}; +use types::{IsNull, Kind, Type, SessionInfo, Oid, Other, ReadWithInfo}; use message::BackendMessage::*; use message::FrontendMessage::*; use message::{FrontendMessage, BackendMessage, RowDescriptionEntry}; @@ -1530,9 +1530,10 @@ impl<'conn> Statement<'conn> { /// Executes a `COPY FROM STDIN` statement, returning the number of rows /// added. /// - /// The contents of the provided `Read`er are passed to the Postgres server - /// verbatim; it is the caller's responsibility to ensure the data is in - /// the proper format. See the [Postgres documentation](http://www.postgresql.org/docs/9.4/static/sql-copy.html) + /// The data read out of the provided `Read`er are passed to the Postgres + /// server verbatim; it is the caller's responsibility to ensure the data + /// is in the proper format. See the + /// [Postgres documentation](http://www.postgresql.org/docs/9.4/static/sql-copy.html) /// for details. /// /// If the statement is not a `COPY FROM STDIN` statement, it will still be @@ -1547,7 +1548,7 @@ impl<'conn> Statement<'conn> { /// let stmt = conn.prepare("COPY people FROM STDIN").unwrap(); /// stmt.copy_in(&[], &mut "1\tjohn\n2\tjane\n".as_bytes()).unwrap(); /// ``` - pub fn copy_in(&self, params: &[&ToSql], r: &mut R) -> Result { + pub fn copy_in(&self, params: &[&ToSql], r: &mut R) -> Result { try!(self.inner_execute("", 0, params)); let mut conn = self.conn.conn.borrow_mut(); @@ -1567,16 +1568,15 @@ impl<'conn> Statement<'conn> { } } - let mut buf = vec![]; + let mut buf = [0; 16 * 1024]; loop { - match r.take(16 * 1024).read_to_end(&mut buf) { + match fill_copy_buf(&mut buf, r, &SessionInfo::new(&conn)) { Ok(0) => break, - Ok(_) => { + Ok(len) => { try_desync!(conn, conn.stream.write_message( &CopyData { - data: &buf, + data: &buf[..len], })); - buf.clear(); } Err(err) => { try!(conn.write_messages(&[ @@ -1628,6 +1628,20 @@ impl<'conn> Statement<'conn> { } } +fn fill_copy_buf(buf: &mut [u8], r: &mut R, info: &SessionInfo) + -> std_io::Result { + let mut nread = 0; + while nread < buf.len() { + match r.read_with_info(&mut buf[nread..], info) { + Ok(0) => break, + Ok(n) => nread += n, + Err(ref e) if e.kind() == std_io::ErrorKind::Interrupted => {} + Err(e) => return Err(e), + } + } + Ok(nread) +} + /// Information about a column of the result of a query. #[derive(PartialEq, Eq, Clone, Debug)] pub struct Column { diff --git a/src/types/mod.rs b/src/types/mod.rs index 49f5cc03..2eab704a 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -4,6 +4,7 @@ use std::collections::HashMap; use std::error; use std::fmt; use std::io::prelude::*; +use std::io; use byteorder::{ReadBytesExt, WriteBytesExt, BigEndian}; pub use self::slice::Slice; @@ -75,6 +76,20 @@ impl<'a> SessionInfo<'a> { } } +/// Like `Read` except that a `SessionInfo` object is provided as well. +/// +/// All types that implement `Read` also implement this trait. +pub trait ReadWithInfo { + /// Like `Read::read`. + fn read_with_info(&mut self, buf: &mut [u8], info: &SessionInfo) -> io::Result; +} + +impl ReadWithInfo for R { + fn read_with_info(&mut self, buf: &mut [u8], _: &SessionInfo) -> io::Result { + self.read(buf) + } +} + /// A Postgres OID. pub type Oid = u32;