Add copy_in to Statement

This commit is contained in:
Steven Fackler 2015-05-25 20:40:57 -07:00
parent b5d9a38a59
commit 9b9b82a7db
2 changed files with 93 additions and 1 deletions

View File

@ -72,7 +72,6 @@ use std::mem;
use std::slice;
use std::result;
use std::vec;
use byteorder::{WriteBytesExt, BigEndian};
#[cfg(feature = "unix_socket")]
use std::path::PathBuf;
@ -1598,6 +1597,86 @@ impl<'conn> Statement<'conn> {
})
}
/// Executes a `COPY FROM STDIN` statement, returning the number of rows
/// added.
///
/// The contents of the provided `Read`er are passed to the Postgres server
/// verbatim; it is the caller's responsibility to ensure the data is in
/// the proper format. See the [Postgres documentation](http://www.postgresql.org/docs/9.4/static/sql-copy.html)
/// for details.
///
/// If the statement is not a `COPY FROM STDIN` statement, this method will
/// return an error though the statement will still be executed.
///
/// # Examples
///
/// ```rust,no_run
/// # use postgres::{Connection, SslMode};
/// # let conn = Connection::connect("", &SslMode::None).unwrap();
/// conn.batch_execute("CREATE TABLE people (id INT PRIMARY KEY, name VARCHAR)").unwrap();
/// let stmt = conn.prepare("COPY people FROM STDIN").unwrap();
/// stmt.copy_in(&[], &mut "1\tjohn\n2\tjane\n".as_bytes()).unwrap();
/// ```
pub fn copy_in<R: Read>(&self, params: &[&ToSql], r: &mut R) -> Result<u64> {
try!(self.inner_execute("", 0, params));
let mut conn = self.conn.conn.borrow_mut();
match try!(conn.read_message()) {
CopyInResponse { .. } => {}
_ => {
loop {
match try!(conn.read_message()) {
ReadyForQuery { .. } => {
return Err(Error::IoError(std_io::Error::new(
std_io::ErrorKind::InvalidInput,
"called `copy_in` on a non-`COPY FROM STDIN` statement")));
}
_ => {}
}
}
}
}
let mut buf = vec![];
loop {
match std::io::copy(&mut r.take(16 * 1024), &mut buf) {
Ok(0) => break,
Ok(len) => {
try_desync!(conn, conn.stream.write_message(
&CopyData {
data: &buf[..len as usize],
}));
buf.clear();
}
Err(err) => {
// FIXME better to return the error directly
try_desync!(conn, conn.stream.write_message(
&CopyFail {
message: &err.to_string(),
}));
break;
}
}
}
try!(conn.write_messages(&[CopyDone, Sync]));
let num = match try!(conn.read_message()) {
CommandComplete { tag } => util::parse_update_count(tag),
ErrorResponse { fields } => {
try!(conn.wait_for_ready());
return DbError::new(fields);
}
_ => {
conn.desynchronized = true;
return Err(Error::IoError(bad_response()));
}
};
try!(conn.wait_for_ready());
Ok(num)
}
/// Consumes the statement, clearing it from the Postgres session.
///
/// If this statement was created via the `prepare_cached` method, `finish`

View File

@ -765,6 +765,19 @@ fn test_batch_execute_copy_from_err() {
}
}
#[test]
fn test_copy() {
let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None));
or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[]));
let stmt = or_panic!(conn.prepare("COPY foo (id) FROM STDIN"));
let mut data = &b"1\n2\n3\n5\n8\n"[..];
assert_eq!(5, or_panic!(stmt.copy_in(&[], &mut data)));
let stmt = or_panic!(conn.prepare("SELECT id FROM foo ORDER BY id"));
assert_eq!(vec![1i32, 2, 3, 5, 8],
stmt.query(&[]).unwrap().iter().map(|r| r.get(0)).collect::<Vec<i32>>());
}
#[test]
// Just make sure the impls don't infinite loop
fn test_generic_connection() {