From dd64882d32404fe82c7def23862836a251643424 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Fri, 23 Aug 2013 01:24:14 -0400 Subject: [PATCH] Updates and transactions --- src/lib.rs | 103 ++++++++++++++++++++++++++++++++++++------------- src/message.rs | 37 ++++++++++++++++-- src/test.rs | 9 +++-- 3 files changed, 117 insertions(+), 32 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 8156e381..cd949f96 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,6 @@ extern mod extra; use std::cell::Cell; use std::hashmap::HashMap; -use std::rt::io::io_error; use std::rt::io::net::ip::SocketAddr; use std::rt::io::net::tcp::TcpStream; use extra::url::Url; @@ -40,24 +39,20 @@ impl PostgresConnection { match conn.read_message() { AuthenticationOk => (), - resp => fail!("Bad response: %?", resp) + resp => fail!("Bad response: %?", resp.to_str()) } - conn.finish_connect(); - - conn - } - - fn finish_connect(&self) { loop { - match self.read_message() { + match conn.read_message() { ParameterStatus(param, value) => printfln!("Param %s = %s", param, value), BackendKeyData(*) => (), ReadyForQuery(*) => break, - resp => fail!("Bad response: %?", resp) + resp => fail!("Bad response: %?", resp.to_str()) } } + + conn } fn write_message(&self, message: &FrontendMessage) { @@ -83,8 +78,8 @@ impl PostgresConnection { match self.read_message() { ParseComplete => (), - ErrorResponse(ref data) => fail!("Error: %?", data), - resp => fail!("Bad response: %?", resp) + resp @ ErrorResponse(*) => fail!("Error: %?", resp.to_str()), + resp => fail!("Bad response: %?", resp.to_str()) } self.wait_for_ready(); @@ -94,12 +89,12 @@ impl PostgresConnection { let num_params = match self.read_message() { ParameterDescription(ref types) => types.len(), - resp => fail!("Bad response: %?", resp) + resp => fail!("Bad response: %?", resp.to_str()) }; match self.read_message() { RowDescription(*) | NoData => (), - resp => fail!("Bad response: %?", resp) + resp => fail!("Bad response: %?", resp.to_str()) } self.wait_for_ready(); @@ -112,10 +107,41 @@ impl PostgresConnection { } } + fn query(&self, query: &str) { + self.write_message(&Query(query)); + + loop { + match self.read_message() { + ReadyForQuery(*) => break, + resp @ ErrorResponse(*) => fail!("Error: %?", resp.to_str()), + _ => () + } + } + } + + pub fn in_transaction(&self, blk: &fn(&PostgresConnection) + -> Result) + -> Result { + self.query("BEGIN"); + + // If this fails, Postgres will rollback when the connection closes + let ret = blk(self); + + if ret.is_ok() { + self.query("COMMIT"); + } else { + self.query("ROLLBACK"); + } + + ret + } + fn wait_for_ready(&self) { - match self.read_message() { - ReadyForQuery(*) => (), - resp => fail!("Bad response: %?", resp) + loop { + match self.read_message() { + ReadyForQuery(*) => break, + resp => fail!("Bad response: %?", resp.to_str()) + } } } } @@ -132,8 +158,12 @@ impl<'self> Drop for PostgresStatement<'self> { fn drop(&self) { self.conn.write_message(&Close('S' as u8, self.name.as_slice())); self.conn.write_message(&Sync); - self.conn.read_message(); // CloseComplete or ErrorResponse - self.conn.wait_for_ready(); + loop { + match self.conn.read_message() { + ReadyForQuery(*) => break, + _ => () + } + } } } @@ -142,25 +172,46 @@ impl<'self> PostgresStatement<'self> { self.num_params } - pub fn query(&self) { - let id = self.next_portal_id.take(); - let portal_name = ifmt!("{:s}_portal_{}", self.name.as_slice(), id); - self.next_portal_id.put_back(id + 1); - + fn execute(&self, portal_name: &str) { let formats = []; let values = []; let result_formats = []; self.conn.write_message(&Bind(portal_name, self.name.as_slice(), formats, values, result_formats)); + self.conn.write_message(&Execute(portal_name.as_slice(), 0)); self.conn.write_message(&Sync); match self.conn.read_message() { BindComplete => (), - ErrorResponse(ref data) => fail!("Error: %?", data), - resp => fail!("Bad response: %?", resp) + resp @ ErrorResponse(*) => fail!("Error: %?", resp.to_str()), + resp => fail!("Bad response: %?", resp.to_str()) } + } + pub fn update(&self) -> uint { + self.execute(""); + + let mut num = 0; + loop { + match self.conn.read_message() { + CommandComplete(ret) => { + let s = ret.split_iter(' ').last().unwrap(); + match FromStr::from_str(s) { + None => (), + Some(n) => num = n + } + break; + } + DataRow(*) => (), + EmptyQueryResponse => break, + NoticeResponse(*) => (), + resp @ ErrorResponse(*) => fail!("Error: %?", resp.to_str()), + resp => fail!("Bad response: %?", resp.to_str()) + } + } self.conn.wait_for_ready(); + + num } } diff --git a/src/message.rs b/src/message.rs index a2bb33ee..7a3c634f 100644 --- a/src/message.rs +++ b/src/message.rs @@ -9,13 +9,18 @@ use std::vec; pub static PROTOCOL_VERSION: i32 = 0x0003_0000; +#[deriving(ToStr)] pub enum BackendMessage { AuthenticationOk, BackendKeyData(i32, i32), BindComplete, CloseComplete, + CommandComplete(~str), + DataRow(~[Option<~[u8]>]), + EmptyQueryResponse, ErrorResponse(HashMap), NoData, + NoticeResponse(HashMap), ParameterDescription(~[i32]), ParameterStatus(~str, ~str), ParseComplete, @@ -23,6 +28,7 @@ pub enum BackendMessage { RowDescription(~[RowDescriptionEntry]) } +#[deriving(ToStr)] pub struct RowDescriptionEntry { name: ~str, table_oid: i32, @@ -39,6 +45,7 @@ pub enum FrontendMessage<'self> { &'self [i16]), Close(u8, &'self str), Describe(u8, &'self str), + Execute(&'self str, i32), /// name, query, parameter types Parse(&'self str, &'self str, &'self [i32]), Query(&'self str), @@ -107,6 +114,11 @@ impl WriteMessage for W { buf.write_u8_(variant); buf.write_string(name); } + Execute(name, num_rows) => { + ident = Some('E'); + buf.write_string(name); + buf.write_be_i32_(num_rows); + } Parse(name, query, param_types) => { ident = Some('P'); buf.write_string(name); @@ -185,9 +197,13 @@ impl ReadMessage for R { '1' => ParseComplete, '2' => BindComplete, '3' => CloseComplete, - 'E' => read_error_message(&mut buf), + 'C' => CommandComplete(buf.read_string()), + 'D' => read_data_row(&mut buf), + 'E' => ErrorResponse(read_hash(&mut buf)), + 'I' => EmptyQueryResponse, 'K' => BackendKeyData(buf.read_be_i32_(), buf.read_be_i32_()), 'n' => NoData, + 'N' => NoticeResponse(read_hash(&mut buf)), 'R' => read_auth_message(&mut buf), 'S' => ParameterStatus(buf.read_string(), buf.read_string()), 't' => read_parameter_description(&mut buf), @@ -201,7 +217,7 @@ impl ReadMessage for R { } } -fn read_error_message(buf: &mut MemReader) -> BackendMessage { +fn read_hash(buf: &mut MemReader) -> HashMap { let mut fields = HashMap::new(); loop { let ty = buf.read_u8_(); @@ -212,7 +228,22 @@ fn read_error_message(buf: &mut MemReader) -> BackendMessage { fields.insert(ty, buf.read_string()); } - ErrorResponse(fields) + fields +} + +fn read_data_row(buf: &mut MemReader) -> BackendMessage { + let len = buf.read_be_i16_() as uint; + let mut values = vec::with_capacity(len); + + do len.times() { + let val = match buf.read_be_i32_() { + -1 => None, + len => Some(buf.read_bytes(len as uint)) + }; + values.push(val); + } + + DataRow(values) } fn read_auth_message(buf: &mut MemReader) -> BackendMessage { diff --git a/src/test.rs b/src/test.rs index 0b307c15..69ca1d65 100644 --- a/src/test.rs +++ b/src/test.rs @@ -3,9 +3,12 @@ extern mod postgres; use postgres::PostgresConnection; #[test] -fn test_connect() { +fn test_basic() { let conn = PostgresConnection::connect("postgres://postgres@127.0.0.1:5432"); - let stmt = conn.prepare("CREATE TABLE foo (id BIGINT PRIMARY KEY)"); - stmt.query(); + do conn.in_transaction |conn| { + conn.prepare("CREATE TABLE foo (id BIGINT PRIMARY KEY)").update(); + + Err::<(), ()>(()) + }; }