diff --git a/src/lib.rs b/src/lib.rs index 5d14c7d7..38936512 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,8 @@ use extra::md5::Md5; use extra::url::{UserInfo, Url}; use std::cell::Cell; use std::hashmap::HashMap; -use std::rt::io::io_error; +use std::rt::io::{io_error, Decorator}; +use std::rt::io::mem::MemWriter; use std::rt::io::net::ip::SocketAddr; use std::rt::io::net::tcp::TcpStream; @@ -183,6 +184,16 @@ impl PostgresConnection { Ok(conn) } + fn write_messages(&self, messages: &[&FrontendMessage]) { + let mut buf = MemWriter::new(); + for &message in messages.iter() { + buf.write_message(message); + } + do self.stream.with_mut_ref |s| { + s.write(buf.inner_ref().as_slice()); + } + } + fn write_message(&self, message: &FrontendMessage) { do self.stream.with_mut_ref |s| { s.write_message(message); @@ -249,13 +260,17 @@ impl PostgresConnection { self.next_stmt_id.put_back(id + 1); let types = []; - self.write_message(&Parse { - name: stmt_name, - query: query, - param_types: types - }); - self.write_message(&Describe { variant: 'S' as u8, name: stmt_name }); - self.write_message(&Sync); + self.write_messages([ + &Parse { + name: stmt_name, + query: query, + param_types: types + }, + &Describe { + variant: 'S' as u8, + name: stmt_name + }, + &Sync]); match_read_message!(self, { ParseComplete => (), @@ -371,11 +386,12 @@ pub struct PostgresStatement<'self> { impl<'self> Drop for PostgresStatement<'self> { fn drop(&self) { do io_error::cond.trap(|_| {}).inside { - self.conn.write_message(&Close { - variant: 'S' as u8, - name: self.name.as_slice() - }); - self.conn.write_message(&Sync); + self.conn.write_messages([ + &Close { + variant: 'S' as u8, + name: self.name.as_slice() + }, + &Sync]); loop { match_read_message!(self.conn, { ReadyForQuery {_} => break, @@ -399,18 +415,19 @@ impl<'self> PostgresStatement<'self> { let result_formats = []; - self.conn.write_message(&Bind { - portal: portal_name, - statement: self.name.as_slice(), - formats: formats, - values: values, - result_formats: result_formats - }); - self.conn.write_message(&Execute { - portal: portal_name.as_slice(), - max_rows: 0 - }); - self.conn.write_message(&Sync); + self.conn.write_messages([ + &Bind { + portal: portal_name, + statement: self.name.as_slice(), + formats: formats, + values: values, + result_formats: result_formats + }, + &Execute { + portal: portal_name.as_slice(), + max_rows: 0 + }, + &Sync]); match_read_message!(self.conn, { BindComplete => None, @@ -521,11 +538,12 @@ pub struct PostgresResult<'self> { impl<'self> Drop for PostgresResult<'self> { fn drop(&self) { do io_error::cond.trap(|_| {}).inside { - self.stmt.conn.write_message(&Close { - variant: 'P' as u8, - name: self.name.as_slice() - }); - self.stmt.conn.write_message(&Sync); + self.stmt.conn.write_messages([ + &Close { + variant: 'P' as u8, + name: self.name.as_slice() + }, + &Sync]); loop { match_read_message!(self.stmt.conn, { ReadyForQuery {_} => break,