From 9b9b82a7db615671f225997241cfbaf332748c03 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Mon, 25 May 2015 20:40:57 -0700 Subject: [PATCH] Add copy_in to Statement --- src/lib.rs | 81 ++++++++++++++++++++++++++++++++++++++++++++++++++- tests/test.rs | 13 +++++++++ 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index f9d9a770..3e619dbf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -72,7 +72,6 @@ use std::mem; use std::slice; use std::result; use std::vec; -use byteorder::{WriteBytesExt, BigEndian}; #[cfg(feature = "unix_socket")] use std::path::PathBuf; @@ -1598,6 +1597,86 @@ impl<'conn> Statement<'conn> { }) } + /// Executes a `COPY FROM STDIN` statement, returning the number of rows + /// added. + /// + /// The contents of the provided `Read`er are passed to the Postgres server + /// verbatim; it is the caller's responsibility to ensure the data is in + /// the proper format. See the [Postgres documentation](http://www.postgresql.org/docs/9.4/static/sql-copy.html) + /// for details. + /// + /// If the statement is not a `COPY FROM STDIN` statement, this method will + /// return an error though the statement will still be executed. + /// + /// # Examples + /// + /// ```rust,no_run + /// # use postgres::{Connection, SslMode}; + /// # let conn = Connection::connect("", &SslMode::None).unwrap(); + /// conn.batch_execute("CREATE TABLE people (id INT PRIMARY KEY, name VARCHAR)").unwrap(); + /// let stmt = conn.prepare("COPY people FROM STDIN").unwrap(); + /// stmt.copy_in(&[], &mut "1\tjohn\n2\tjane\n".as_bytes()).unwrap(); + /// ``` + pub fn copy_in(&self, params: &[&ToSql], r: &mut R) -> Result { + try!(self.inner_execute("", 0, params)); + let mut conn = self.conn.conn.borrow_mut(); + + match try!(conn.read_message()) { + CopyInResponse { .. } => {} + _ => { + loop { + match try!(conn.read_message()) { + ReadyForQuery { .. } => { + return Err(Error::IoError(std_io::Error::new( + std_io::ErrorKind::InvalidInput, + "called `copy_in` on a non-`COPY FROM STDIN` statement"))); + } + _ => {} + } + } + } + } + + let mut buf = vec![]; + loop { + match std::io::copy(&mut r.take(16 * 1024), &mut buf) { + Ok(0) => break, + Ok(len) => { + try_desync!(conn, conn.stream.write_message( + &CopyData { + data: &buf[..len as usize], + })); + buf.clear(); + } + Err(err) => { + // FIXME better to return the error directly + try_desync!(conn, conn.stream.write_message( + &CopyFail { + message: &err.to_string(), + })); + break; + } + } + } + + try!(conn.write_messages(&[CopyDone, Sync])); + + let num = match try!(conn.read_message()) { + CommandComplete { tag } => util::parse_update_count(tag), + ErrorResponse { fields } => { + try!(conn.wait_for_ready()); + return DbError::new(fields); + } + _ => { + conn.desynchronized = true; + return Err(Error::IoError(bad_response())); + } + }; + + try!(conn.wait_for_ready()); + Ok(num) + } + /// Consumes the statement, clearing it from the Postgres session. /// /// If this statement was created via the `prepare_cached` method, `finish` diff --git a/tests/test.rs b/tests/test.rs index 08176f81..e4094fa0 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -765,6 +765,19 @@ fn test_batch_execute_copy_from_err() { } } +#[test] +fn test_copy() { + let conn = or_panic!(Connection::connect("postgres://postgres@localhost", &SslMode::None)); + or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[])); + let stmt = or_panic!(conn.prepare("COPY foo (id) FROM STDIN")); + let mut data = &b"1\n2\n3\n5\n8\n"[..]; + assert_eq!(5, or_panic!(stmt.copy_in(&[], &mut data))); + let stmt = or_panic!(conn.prepare("SELECT id FROM foo ORDER BY id")); + assert_eq!(vec![1i32, 2, 3, 5, 8], + stmt.query(&[]).unwrap().iter().map(|r| r.get(0)).collect::>()); +} + + #[test] // Just make sure the impls don't infinite loop fn test_generic_connection() {