diff --git a/src/stmt.rs b/src/stmt.rs index 0b423d6f..9ebbb19d 100644 --- a/src/stmt.rs +++ b/src/stmt.rs @@ -7,7 +7,7 @@ use std::fmt; use std::io::{self, Cursor, BufRead, Read}; use error::{Error, DbError}; -use types::{ReadWithInfo, SessionInfo, Type, ToSql, IsNull}; +use types::{SessionInfo, Type, ToSql, IsNull}; use message::FrontendMessage::*; use message::BackendMessage::*; use message::WriteMessage; @@ -309,8 +309,8 @@ impl<'conn> Statement<'conn> { try!(self.inner_execute("", 0, params)); let mut conn = self.conn.conn.borrow_mut(); - match try!(conn.read_message()) { - CopyInResponse { .. } => {} + let (format, column_formats) = match try!(conn.read_message()) { + CopyInResponse { format, column_formats } => (format, column_formats), _ => { loop { match try!(conn.read_message()) { @@ -323,53 +323,59 @@ impl<'conn> Statement<'conn> { } } } - } + }; + + let mut info = CopyInfo { + conn: conn, + format: Format::from_u16(format as u16), + column_formats: column_formats.iter().map(|&f| Format::from_u16(f)).collect(), + }; let mut buf = [0; 16 * 1024]; loop { - match fill_copy_buf(&mut buf, r, &SessionInfo::new(&conn)) { + match fill_copy_buf(&mut buf, r, &info) { Ok(0) => break, Ok(len) => { - try_desync!(conn, conn.stream.write_message( + try_desync!(info.conn, info.conn.stream.write_message( &CopyData { data: &buf[..len], })); } Err(err) => { - try!(conn.write_messages(&[ + try!(info.conn.write_messages(&[ CopyFail { message: "", }, CopyDone, Sync])); - match try!(conn.read_message()) { + match try!(info.conn.read_message()) { ErrorResponse { .. } => { /* expected from the CopyFail */ } _ => { - conn.desynchronized = true; + info.conn.desynchronized = true; return Err(Error::IoError(bad_response())); } } - try!(conn.wait_for_ready()); + try!(info.conn.wait_for_ready()); return Err(Error::IoError(err)); } } } - try!(conn.write_messages(&[CopyDone, Sync])); + try!(info.conn.write_messages(&[CopyDone, Sync])); - let num = match try!(conn.read_message()) { + let num = match try!(info.conn.read_message()) { CommandComplete { tag } => util::parse_update_count(tag), ErrorResponse { fields } => { - try!(conn.wait_for_ready()); + try!(info.conn.wait_for_ready()); return DbError::new(fields); } _ => { - conn.desynchronized = true; + info.conn.desynchronized = true; return Err(Error::IoError(bad_response())); } }; - try!(conn.wait_for_ready()); + try!(info.conn.wait_for_ready()); Ok(num) } @@ -443,9 +449,11 @@ impl<'conn> Statement<'conn> { }; Ok(CopyOutReader { - conn: conn, - format: Format::from_u16(format as u16), - column_formats: column_formats.iter().map(|&f| Format::from_u16(f)).collect(), + info: CopyInfo { + conn: conn, + format: Format::from_u16(format as u16), + column_formats: column_formats.iter().map(|&f| Format::from_u16(f)).collect(), + }, buf: Cursor::new(vec![]), finished: false, }) @@ -463,7 +471,7 @@ impl<'conn> Statement<'conn> { } } -fn fill_copy_buf(buf: &mut [u8], r: &mut R, info: &SessionInfo) +fn fill_copy_buf(buf: &mut [u8], r: &mut R, info: &CopyInfo) -> io::Result { let mut nread = 0; while nread < buf.len() { @@ -493,6 +501,44 @@ impl ColumnNew for Column { } } +/// A struct containing information relevant for a `COPY` operation. +pub struct CopyInfo<'a> { + conn: RefMut<'a, InnerConnection>, + format: Format, + column_formats: Vec, +} + +impl<'a> CopyInfo<'a> { + /// Returns the format of the overall data. + pub fn format(&self) -> Format { + self.format + } + + /// Returns the format of the individual columns. + pub fn column_formats(&self) -> &[Format] { + &self.column_formats + } + + /// Returns session info for the associated connection. + pub fn session_info<'b>(&'b self) -> SessionInfo<'b> { + SessionInfo::new(&*self.conn) + } +} + +/// Like `Read` except that a `CopyInfo` object is provided as well. +/// +/// All types that implement `Read` also implement this trait. +pub trait ReadWithInfo { + /// Like `Read::read`. + fn read_with_info(&mut self, buf: &mut [u8], info: &CopyInfo) -> io::Result; +} + +impl ReadWithInfo for R { + fn read_with_info(&mut self, buf: &mut [u8], _: &CopyInfo) -> io::Result { + self.read(buf) + } +} + impl Column { /// The name of the column. pub fn name(&self) -> &str { @@ -530,9 +576,7 @@ impl Format { /// The underlying connection may not be used while a `CopyOutReader` exists. /// Any attempt to do so will panic. pub struct CopyOutReader<'a> { - conn: RefMut<'a, InnerConnection>, - format: Format, - column_formats: Vec, + info: CopyInfo<'a>, buf: Cursor>, finished: bool, } @@ -544,19 +588,9 @@ impl<'a> Drop for CopyOutReader<'a> { } impl<'a> CopyOutReader<'a> { - /// Returns the format of the overall data. - pub fn format(&self) -> Format { - self.format - } - - /// Returns the format of the individual columns. - pub fn column_formats(&self) -> &[Format] { - &self.column_formats - } - - /// Returns session info for the associated connection. - pub fn session_info<'b>(&'b self) -> SessionInfo<'b> { - SessionInfo::new(&*self.conn) + /// Returns the `CopyInfo` for the current operation. + pub fn info(&self) -> &CopyInfo { + &self.info } /// Consumes the `CopyOutReader`, throwing away any unread data. @@ -581,26 +615,26 @@ impl<'a> CopyOutReader<'a> { return Ok(()); } - match try!(self.conn.read_message()) { + match try!(self.info.conn.read_message()) { BCopyData { data } => self.buf = Cursor::new(data), BCopyDone => { self.finished = true; - match try!(self.conn.read_message()) { + match try!(self.info.conn.read_message()) { CommandComplete { .. } => {} _ => { - self.conn.desynchronized = true; + self.info.conn.desynchronized = true; return Err(Error::IoError(bad_response())); } } - try!(self.conn.wait_for_ready()); + try!(self.info.conn.wait_for_ready()); } ErrorResponse { fields } => { self.finished = true; - try!(self.conn.wait_for_ready()); + try!(self.info.conn.wait_for_ready()); return DbError::new(fields); } _ => { - self.conn.desynchronized = true; + self.info.conn.desynchronized = true; return Err(Error::IoError(bad_response())); } } diff --git a/src/types/mod.rs b/src/types/mod.rs index 9364cf22..1ab388f4 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -76,20 +76,6 @@ impl<'a> SessionInfo<'a> { } } -/// Like `Read` except that a `SessionInfo` object is provided as well. -/// -/// All types that implement `Read` also implement this trait. -pub trait ReadWithInfo { - /// Like `Read::read`. - fn read_with_info(&mut self, buf: &mut [u8], info: &SessionInfo) -> io::Result; -} - -impl ReadWithInfo for R { - fn read_with_info(&mut self, buf: &mut [u8], _: &SessionInfo) -> io::Result { - self.read(buf) - } -} - /// A Postgres OID. pub type Oid = u32;