Updates and transactions

This commit is contained in:
Steven Fackler 2013-08-23 01:24:14 -04:00
parent e5c5b783f3
commit dd64882d32
3 changed files with 117 additions and 32 deletions

View File

@ -2,7 +2,6 @@ extern mod extra;
use std::cell::Cell;
use std::hashmap::HashMap;
use std::rt::io::io_error;
use std::rt::io::net::ip::SocketAddr;
use std::rt::io::net::tcp::TcpStream;
use extra::url::Url;
@ -40,24 +39,20 @@ impl PostgresConnection {
match conn.read_message() {
AuthenticationOk => (),
resp => fail!("Bad response: %?", resp)
resp => fail!("Bad response: %?", resp.to_str())
}
conn.finish_connect();
conn
}
fn finish_connect(&self) {
loop {
match self.read_message() {
match conn.read_message() {
ParameterStatus(param, value) =>
printfln!("Param %s = %s", param, value),
BackendKeyData(*) => (),
ReadyForQuery(*) => break,
resp => fail!("Bad response: %?", resp)
resp => fail!("Bad response: %?", resp.to_str())
}
}
conn
}
fn write_message(&self, message: &FrontendMessage) {
@ -83,8 +78,8 @@ impl PostgresConnection {
match self.read_message() {
ParseComplete => (),
ErrorResponse(ref data) => fail!("Error: %?", data),
resp => fail!("Bad response: %?", resp)
resp @ ErrorResponse(*) => fail!("Error: %?", resp.to_str()),
resp => fail!("Bad response: %?", resp.to_str())
}
self.wait_for_ready();
@ -94,12 +89,12 @@ impl PostgresConnection {
let num_params = match self.read_message() {
ParameterDescription(ref types) => types.len(),
resp => fail!("Bad response: %?", resp)
resp => fail!("Bad response: %?", resp.to_str())
};
match self.read_message() {
RowDescription(*) | NoData => (),
resp => fail!("Bad response: %?", resp)
resp => fail!("Bad response: %?", resp.to_str())
}
self.wait_for_ready();
@ -112,10 +107,41 @@ impl PostgresConnection {
}
}
fn query(&self, query: &str) {
self.write_message(&Query(query));
loop {
match self.read_message() {
ReadyForQuery(*) => break,
resp @ ErrorResponse(*) => fail!("Error: %?", resp.to_str()),
_ => ()
}
}
}
pub fn in_transaction<T, E: ToStr>(&self, blk: &fn(&PostgresConnection)
-> Result<T, E>)
-> Result<T, E> {
self.query("BEGIN");
// If this fails, Postgres will rollback when the connection closes
let ret = blk(self);
if ret.is_ok() {
self.query("COMMIT");
} else {
self.query("ROLLBACK");
}
ret
}
fn wait_for_ready(&self) {
match self.read_message() {
ReadyForQuery(*) => (),
resp => fail!("Bad response: %?", resp)
loop {
match self.read_message() {
ReadyForQuery(*) => break,
resp => fail!("Bad response: %?", resp.to_str())
}
}
}
}
@ -132,8 +158,12 @@ impl<'self> Drop for PostgresStatement<'self> {
fn drop(&self) {
self.conn.write_message(&Close('S' as u8, self.name.as_slice()));
self.conn.write_message(&Sync);
self.conn.read_message(); // CloseComplete or ErrorResponse
self.conn.wait_for_ready();
loop {
match self.conn.read_message() {
ReadyForQuery(*) => break,
_ => ()
}
}
}
}
@ -142,25 +172,46 @@ impl<'self> PostgresStatement<'self> {
self.num_params
}
pub fn query(&self) {
let id = self.next_portal_id.take();
let portal_name = ifmt!("{:s}_portal_{}", self.name.as_slice(), id);
self.next_portal_id.put_back(id + 1);
fn execute(&self, portal_name: &str) {
let formats = [];
let values = [];
let result_formats = [];
self.conn.write_message(&Bind(portal_name, self.name.as_slice(),
formats, values, result_formats));
self.conn.write_message(&Execute(portal_name.as_slice(), 0));
self.conn.write_message(&Sync);
match self.conn.read_message() {
BindComplete => (),
ErrorResponse(ref data) => fail!("Error: %?", data),
resp => fail!("Bad response: %?", resp)
resp @ ErrorResponse(*) => fail!("Error: %?", resp.to_str()),
resp => fail!("Bad response: %?", resp.to_str())
}
}
pub fn update(&self) -> uint {
self.execute("");
let mut num = 0;
loop {
match self.conn.read_message() {
CommandComplete(ret) => {
let s = ret.split_iter(' ').last().unwrap();
match FromStr::from_str(s) {
None => (),
Some(n) => num = n
}
break;
}
DataRow(*) => (),
EmptyQueryResponse => break,
NoticeResponse(*) => (),
resp @ ErrorResponse(*) => fail!("Error: %?", resp.to_str()),
resp => fail!("Bad response: %?", resp.to_str())
}
}
self.conn.wait_for_ready();
num
}
}

View File

@ -9,13 +9,18 @@ use std::vec;
pub static PROTOCOL_VERSION: i32 = 0x0003_0000;
#[deriving(ToStr)]
pub enum BackendMessage {
AuthenticationOk,
BackendKeyData(i32, i32),
BindComplete,
CloseComplete,
CommandComplete(~str),
DataRow(~[Option<~[u8]>]),
EmptyQueryResponse,
ErrorResponse(HashMap<u8, ~str>),
NoData,
NoticeResponse(HashMap<u8, ~str>),
ParameterDescription(~[i32]),
ParameterStatus(~str, ~str),
ParseComplete,
@ -23,6 +28,7 @@ pub enum BackendMessage {
RowDescription(~[RowDescriptionEntry])
}
#[deriving(ToStr)]
pub struct RowDescriptionEntry {
name: ~str,
table_oid: i32,
@ -39,6 +45,7 @@ pub enum FrontendMessage<'self> {
&'self [i16]),
Close(u8, &'self str),
Describe(u8, &'self str),
Execute(&'self str, i32),
/// name, query, parameter types
Parse(&'self str, &'self str, &'self [i32]),
Query(&'self str),
@ -107,6 +114,11 @@ impl<W: Writer> WriteMessage for W {
buf.write_u8_(variant);
buf.write_string(name);
}
Execute(name, num_rows) => {
ident = Some('E');
buf.write_string(name);
buf.write_be_i32_(num_rows);
}
Parse(name, query, param_types) => {
ident = Some('P');
buf.write_string(name);
@ -185,9 +197,13 @@ impl<R: Reader> ReadMessage for R {
'1' => ParseComplete,
'2' => BindComplete,
'3' => CloseComplete,
'E' => read_error_message(&mut buf),
'C' => CommandComplete(buf.read_string()),
'D' => read_data_row(&mut buf),
'E' => ErrorResponse(read_hash(&mut buf)),
'I' => EmptyQueryResponse,
'K' => BackendKeyData(buf.read_be_i32_(), buf.read_be_i32_()),
'n' => NoData,
'N' => NoticeResponse(read_hash(&mut buf)),
'R' => read_auth_message(&mut buf),
'S' => ParameterStatus(buf.read_string(), buf.read_string()),
't' => read_parameter_description(&mut buf),
@ -201,7 +217,7 @@ impl<R: Reader> ReadMessage for R {
}
}
fn read_error_message(buf: &mut MemReader) -> BackendMessage {
fn read_hash(buf: &mut MemReader) -> HashMap<u8, ~str> {
let mut fields = HashMap::new();
loop {
let ty = buf.read_u8_();
@ -212,7 +228,22 @@ fn read_error_message(buf: &mut MemReader) -> BackendMessage {
fields.insert(ty, buf.read_string());
}
ErrorResponse(fields)
fields
}
fn read_data_row(buf: &mut MemReader) -> BackendMessage {
let len = buf.read_be_i16_() as uint;
let mut values = vec::with_capacity(len);
do len.times() {
let val = match buf.read_be_i32_() {
-1 => None,
len => Some(buf.read_bytes(len as uint))
};
values.push(val);
}
DataRow(values)
}
fn read_auth_message(buf: &mut MemReader) -> BackendMessage {

View File

@ -3,9 +3,12 @@ extern mod postgres;
use postgres::PostgresConnection;
#[test]
fn test_connect() {
fn test_basic() {
let conn = PostgresConnection::connect("postgres://postgres@127.0.0.1:5432");
let stmt = conn.prepare("CREATE TABLE foo (id BIGINT PRIMARY KEY)");
stmt.query();
do conn.in_transaction |conn| {
conn.prepare("CREATE TABLE foo (id BIGINT PRIMARY KEY)").update();
Err::<(), ()>(())
};
}