Updates and transactions
This commit is contained in:
parent
e5c5b783f3
commit
dd64882d32
103
src/lib.rs
103
src/lib.rs
@ -2,7 +2,6 @@ extern mod extra;
|
||||
|
||||
use std::cell::Cell;
|
||||
use std::hashmap::HashMap;
|
||||
use std::rt::io::io_error;
|
||||
use std::rt::io::net::ip::SocketAddr;
|
||||
use std::rt::io::net::tcp::TcpStream;
|
||||
use extra::url::Url;
|
||||
@ -40,24 +39,20 @@ impl PostgresConnection {
|
||||
|
||||
match conn.read_message() {
|
||||
AuthenticationOk => (),
|
||||
resp => fail!("Bad response: %?", resp)
|
||||
resp => fail!("Bad response: %?", resp.to_str())
|
||||
}
|
||||
|
||||
conn.finish_connect();
|
||||
|
||||
conn
|
||||
}
|
||||
|
||||
fn finish_connect(&self) {
|
||||
loop {
|
||||
match self.read_message() {
|
||||
match conn.read_message() {
|
||||
ParameterStatus(param, value) =>
|
||||
printfln!("Param %s = %s", param, value),
|
||||
BackendKeyData(*) => (),
|
||||
ReadyForQuery(*) => break,
|
||||
resp => fail!("Bad response: %?", resp)
|
||||
resp => fail!("Bad response: %?", resp.to_str())
|
||||
}
|
||||
}
|
||||
|
||||
conn
|
||||
}
|
||||
|
||||
fn write_message(&self, message: &FrontendMessage) {
|
||||
@ -83,8 +78,8 @@ impl PostgresConnection {
|
||||
|
||||
match self.read_message() {
|
||||
ParseComplete => (),
|
||||
ErrorResponse(ref data) => fail!("Error: %?", data),
|
||||
resp => fail!("Bad response: %?", resp)
|
||||
resp @ ErrorResponse(*) => fail!("Error: %?", resp.to_str()),
|
||||
resp => fail!("Bad response: %?", resp.to_str())
|
||||
}
|
||||
|
||||
self.wait_for_ready();
|
||||
@ -94,12 +89,12 @@ impl PostgresConnection {
|
||||
|
||||
let num_params = match self.read_message() {
|
||||
ParameterDescription(ref types) => types.len(),
|
||||
resp => fail!("Bad response: %?", resp)
|
||||
resp => fail!("Bad response: %?", resp.to_str())
|
||||
};
|
||||
|
||||
match self.read_message() {
|
||||
RowDescription(*) | NoData => (),
|
||||
resp => fail!("Bad response: %?", resp)
|
||||
resp => fail!("Bad response: %?", resp.to_str())
|
||||
}
|
||||
|
||||
self.wait_for_ready();
|
||||
@ -112,10 +107,41 @@ impl PostgresConnection {
|
||||
}
|
||||
}
|
||||
|
||||
fn query(&self, query: &str) {
|
||||
self.write_message(&Query(query));
|
||||
|
||||
loop {
|
||||
match self.read_message() {
|
||||
ReadyForQuery(*) => break,
|
||||
resp @ ErrorResponse(*) => fail!("Error: %?", resp.to_str()),
|
||||
_ => ()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn in_transaction<T, E: ToStr>(&self, blk: &fn(&PostgresConnection)
|
||||
-> Result<T, E>)
|
||||
-> Result<T, E> {
|
||||
self.query("BEGIN");
|
||||
|
||||
// If this fails, Postgres will rollback when the connection closes
|
||||
let ret = blk(self);
|
||||
|
||||
if ret.is_ok() {
|
||||
self.query("COMMIT");
|
||||
} else {
|
||||
self.query("ROLLBACK");
|
||||
}
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
fn wait_for_ready(&self) {
|
||||
match self.read_message() {
|
||||
ReadyForQuery(*) => (),
|
||||
resp => fail!("Bad response: %?", resp)
|
||||
loop {
|
||||
match self.read_message() {
|
||||
ReadyForQuery(*) => break,
|
||||
resp => fail!("Bad response: %?", resp.to_str())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -132,8 +158,12 @@ impl<'self> Drop for PostgresStatement<'self> {
|
||||
fn drop(&self) {
|
||||
self.conn.write_message(&Close('S' as u8, self.name.as_slice()));
|
||||
self.conn.write_message(&Sync);
|
||||
self.conn.read_message(); // CloseComplete or ErrorResponse
|
||||
self.conn.wait_for_ready();
|
||||
loop {
|
||||
match self.conn.read_message() {
|
||||
ReadyForQuery(*) => break,
|
||||
_ => ()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -142,25 +172,46 @@ impl<'self> PostgresStatement<'self> {
|
||||
self.num_params
|
||||
}
|
||||
|
||||
pub fn query(&self) {
|
||||
let id = self.next_portal_id.take();
|
||||
let portal_name = ifmt!("{:s}_portal_{}", self.name.as_slice(), id);
|
||||
self.next_portal_id.put_back(id + 1);
|
||||
|
||||
fn execute(&self, portal_name: &str) {
|
||||
let formats = [];
|
||||
let values = [];
|
||||
let result_formats = [];
|
||||
|
||||
self.conn.write_message(&Bind(portal_name, self.name.as_slice(),
|
||||
formats, values, result_formats));
|
||||
self.conn.write_message(&Execute(portal_name.as_slice(), 0));
|
||||
self.conn.write_message(&Sync);
|
||||
|
||||
match self.conn.read_message() {
|
||||
BindComplete => (),
|
||||
ErrorResponse(ref data) => fail!("Error: %?", data),
|
||||
resp => fail!("Bad response: %?", resp)
|
||||
resp @ ErrorResponse(*) => fail!("Error: %?", resp.to_str()),
|
||||
resp => fail!("Bad response: %?", resp.to_str())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update(&self) -> uint {
|
||||
self.execute("");
|
||||
|
||||
let mut num = 0;
|
||||
loop {
|
||||
match self.conn.read_message() {
|
||||
CommandComplete(ret) => {
|
||||
let s = ret.split_iter(' ').last().unwrap();
|
||||
match FromStr::from_str(s) {
|
||||
None => (),
|
||||
Some(n) => num = n
|
||||
}
|
||||
break;
|
||||
}
|
||||
DataRow(*) => (),
|
||||
EmptyQueryResponse => break,
|
||||
NoticeResponse(*) => (),
|
||||
resp @ ErrorResponse(*) => fail!("Error: %?", resp.to_str()),
|
||||
resp => fail!("Bad response: %?", resp.to_str())
|
||||
}
|
||||
}
|
||||
self.conn.wait_for_ready();
|
||||
|
||||
num
|
||||
}
|
||||
}
|
||||
|
@ -9,13 +9,18 @@ use std::vec;
|
||||
|
||||
pub static PROTOCOL_VERSION: i32 = 0x0003_0000;
|
||||
|
||||
#[deriving(ToStr)]
|
||||
pub enum BackendMessage {
|
||||
AuthenticationOk,
|
||||
BackendKeyData(i32, i32),
|
||||
BindComplete,
|
||||
CloseComplete,
|
||||
CommandComplete(~str),
|
||||
DataRow(~[Option<~[u8]>]),
|
||||
EmptyQueryResponse,
|
||||
ErrorResponse(HashMap<u8, ~str>),
|
||||
NoData,
|
||||
NoticeResponse(HashMap<u8, ~str>),
|
||||
ParameterDescription(~[i32]),
|
||||
ParameterStatus(~str, ~str),
|
||||
ParseComplete,
|
||||
@ -23,6 +28,7 @@ pub enum BackendMessage {
|
||||
RowDescription(~[RowDescriptionEntry])
|
||||
}
|
||||
|
||||
#[deriving(ToStr)]
|
||||
pub struct RowDescriptionEntry {
|
||||
name: ~str,
|
||||
table_oid: i32,
|
||||
@ -39,6 +45,7 @@ pub enum FrontendMessage<'self> {
|
||||
&'self [i16]),
|
||||
Close(u8, &'self str),
|
||||
Describe(u8, &'self str),
|
||||
Execute(&'self str, i32),
|
||||
/// name, query, parameter types
|
||||
Parse(&'self str, &'self str, &'self [i32]),
|
||||
Query(&'self str),
|
||||
@ -107,6 +114,11 @@ impl<W: Writer> WriteMessage for W {
|
||||
buf.write_u8_(variant);
|
||||
buf.write_string(name);
|
||||
}
|
||||
Execute(name, num_rows) => {
|
||||
ident = Some('E');
|
||||
buf.write_string(name);
|
||||
buf.write_be_i32_(num_rows);
|
||||
}
|
||||
Parse(name, query, param_types) => {
|
||||
ident = Some('P');
|
||||
buf.write_string(name);
|
||||
@ -185,9 +197,13 @@ impl<R: Reader> ReadMessage for R {
|
||||
'1' => ParseComplete,
|
||||
'2' => BindComplete,
|
||||
'3' => CloseComplete,
|
||||
'E' => read_error_message(&mut buf),
|
||||
'C' => CommandComplete(buf.read_string()),
|
||||
'D' => read_data_row(&mut buf),
|
||||
'E' => ErrorResponse(read_hash(&mut buf)),
|
||||
'I' => EmptyQueryResponse,
|
||||
'K' => BackendKeyData(buf.read_be_i32_(), buf.read_be_i32_()),
|
||||
'n' => NoData,
|
||||
'N' => NoticeResponse(read_hash(&mut buf)),
|
||||
'R' => read_auth_message(&mut buf),
|
||||
'S' => ParameterStatus(buf.read_string(), buf.read_string()),
|
||||
't' => read_parameter_description(&mut buf),
|
||||
@ -201,7 +217,7 @@ impl<R: Reader> ReadMessage for R {
|
||||
}
|
||||
}
|
||||
|
||||
fn read_error_message(buf: &mut MemReader) -> BackendMessage {
|
||||
fn read_hash(buf: &mut MemReader) -> HashMap<u8, ~str> {
|
||||
let mut fields = HashMap::new();
|
||||
loop {
|
||||
let ty = buf.read_u8_();
|
||||
@ -212,7 +228,22 @@ fn read_error_message(buf: &mut MemReader) -> BackendMessage {
|
||||
fields.insert(ty, buf.read_string());
|
||||
}
|
||||
|
||||
ErrorResponse(fields)
|
||||
fields
|
||||
}
|
||||
|
||||
fn read_data_row(buf: &mut MemReader) -> BackendMessage {
|
||||
let len = buf.read_be_i16_() as uint;
|
||||
let mut values = vec::with_capacity(len);
|
||||
|
||||
do len.times() {
|
||||
let val = match buf.read_be_i32_() {
|
||||
-1 => None,
|
||||
len => Some(buf.read_bytes(len as uint))
|
||||
};
|
||||
values.push(val);
|
||||
}
|
||||
|
||||
DataRow(values)
|
||||
}
|
||||
|
||||
fn read_auth_message(buf: &mut MemReader) -> BackendMessage {
|
||||
|
@ -3,9 +3,12 @@ extern mod postgres;
|
||||
use postgres::PostgresConnection;
|
||||
|
||||
#[test]
|
||||
fn test_connect() {
|
||||
fn test_basic() {
|
||||
let conn = PostgresConnection::connect("postgres://postgres@127.0.0.1:5432");
|
||||
|
||||
let stmt = conn.prepare("CREATE TABLE foo (id BIGINT PRIMARY KEY)");
|
||||
stmt.query();
|
||||
do conn.in_transaction |conn| {
|
||||
conn.prepare("CREATE TABLE foo (id BIGINT PRIMARY KEY)").update();
|
||||
|
||||
Err::<(), ()>(())
|
||||
};
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user