From 49bed84c8108a216959aced67a8b00068ed6e6c4 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 4 Aug 2013 01:21:16 -0400 Subject: [PATCH] Query parameters and transactions --- notes.md | 9 ++++ src/postgres/lib.rs | 119 +++++++++++++++++++++++++++++++++++-------- src/postgres/test.rs | 42 +++++++++++---- 3 files changed, 139 insertions(+), 31 deletions(-) create mode 100644 notes.md diff --git a/notes.md b/notes.md new file mode 100644 index 00000000..bbcf0355 --- /dev/null +++ b/notes.md @@ -0,0 +1,9 @@ +Results + * Postgres: Query a returns PGresult struct with all information that's + contents are valid as long as the struct exists, even after the connection + is closed. Can simply return unique pointer. Supports random access. + + * SQLite: Retrieve one row at a time from the prepared statement object. No + random access, forward only. Can't even return a result object since a + second query of the same prepared statement will lose the state of the old + result. diff --git a/src/postgres/lib.rs b/src/postgres/lib.rs index c1685d02..6e5ce01f 100644 --- a/src/postgres/lib.rs +++ b/src/postgres/lib.rs @@ -4,6 +4,7 @@ use std::ptr; use std::libc::{c_void, c_char, c_int}; use std::cast; use std::iterator::RandomAccessIterator; +use std::vec; mod ffi { use std::libc::{c_char, c_int, c_uchar, c_uint, c_void}; @@ -37,12 +38,14 @@ mod ffi { fn PQclear(result: *PGresult); fn PQprepare(conn: *PGconn, stmtName: *c_char, query: *c_char, nParams: c_int, paramTypes: *Oid) -> *PGresult; + fn PQdescribePrepared(conn: *PGconn, stmtName: *c_char) -> *PGresult; fn PQexecPrepared(conn: *PGconn, stmtName: *c_char, nParams: c_int, paramValues: **c_char, paramLengths: *c_int, paramFormats: *c_int, resultFormat: c_int) -> *PGresult; fn PQntuples(result: *PGresult) -> c_int; fn PQnfields(result: *PGresult) -> c_int; + fn PQnparams(result: *PGresult) -> c_int; fn PQcmdTuples(result: *PGresult) -> *c_char; fn PQgetvalue(result: *PGresult, row_number: c_int, col_number: c_int) -> *c_char; @@ -116,7 +119,7 @@ impl<'self> PostgresConnection<'self> { let name = fmt!("__libpostgres_stmt_%u", id); self.next_stmt_id.put_back(id + 1); - let res = unsafe { + let mut res = unsafe { let raw_res = do query.as_c_str |c_query| { do name.as_c_str |c_name| { ffi::PQprepare(self.conn, c_name, c_query, @@ -126,29 +129,65 @@ impl<'self> PostgresConnection<'self> { PostgresResult {result: raw_res} }; - match res.status() { - ffi::PGRES_COMMAND_OK => - Ok(~PostgresStatement {conn: self, name: name}), - _ => Err(res.error()) + if res.status() != ffi::PGRES_COMMAND_OK { + return Err(res.error()); + } + + res = unsafe { + let raw_res = do name.as_c_str |c_name| { + ffi::PQdescribePrepared(self.conn, c_name) + }; + PostgresResult {result: raw_res} + }; + + if res.status() != ffi::PGRES_COMMAND_OK { + return Err(res.error()); + } + + Ok(~PostgresStatement {conn: self, name: name, + num_params: res.num_params()}) + } + + pub fn update(&self, query: &str, params: &[~str]) -> Result { + do self.prepare(query).chain |stmt| { + stmt.update(params) } } - pub fn update(&self, query: &str) -> Result { + pub fn query(&self, query: &str, params: &[~str]) + -> Result<~PostgresResult, ~str> { do self.prepare(query).chain |stmt| { - stmt.update() + stmt.query(params) } } - pub fn query(&self, query: &str) -> Result<~PostgresResult, ~str> { - do self.prepare(query).chain |stmt| { - stmt.query() + pub fn in_transaction(&self, + blk: &fn(&PostgresConnection) -> Result) + -> Result { + match self.update("BEGIN", []) { + Ok(_) => (), + Err(err) => return Err(err) + }; + + // If the task fails in blk, the transaction will roll back when the + // connection closes + let ret = blk(self); + + // TODO What to do about errors here? + if ret.is_ok() { + self.update("COMMIT", []); + } else { + self.update("ABORT", []); } + + ret } } pub struct PostgresStatement<'self> { priv conn: &'self PostgresConnection<'self>, - priv name: ~str + priv name: ~str, + priv num_params: uint } #[unsafe_destructor] @@ -165,11 +204,19 @@ impl<'self> Drop for PostgresStatement<'self> { } impl<'self> PostgresStatement<'self> { - fn exec(&self) -> Result<~PostgresResult, ~str> { + fn exec(&self, params: &[~str]) -> Result<~PostgresResult, ~str> { + if params.len() != self.num_params { + return Err(~"Incorrect number of parameters"); + } + let res = unsafe { let raw_res = do self.name.as_c_str |c_name| { - ffi::PQexecPrepared(self.conn.conn, c_name, 0, ptr::null(), - ptr::null(), ptr::null(), ffi::TEXT_FORMAT) + do as_c_str_array(params) |c_params| { + ffi::PQexecPrepared(self.conn.conn, c_name, + self.num_params as c_int, + c_params, ptr::null(), ptr::null(), + ffi::TEXT_FORMAT) + } }; ~PostgresResult{result: raw_res} }; @@ -182,14 +229,14 @@ impl<'self> PostgresStatement<'self> { } } - pub fn update(&self) -> Result { - do self.exec().chain |res| { + pub fn update(&self, params: &[~str]) -> Result { + do self.exec(params).chain |res| { Ok(res.affected_rows()) } } - pub fn query(&self) -> Result<~PostgresResult, ~str> { - do self.exec().chain |res| { + pub fn query(&self, params: &[~str]) -> Result<~PostgresResult, ~str> { + do self.exec(params).chain |res| { Ok(res) } } @@ -224,6 +271,10 @@ impl PostgresResult { None => 0 } } + + fn num_params(&self) -> uint { + unsafe { ffi::PQnparams(self.result) as uint } + } } impl Container for PostgresResult { @@ -236,6 +287,14 @@ impl PostgresResult { pub fn iter<'a>(&'a self) -> PostgresResultIterator<'a> { PostgresResultIterator {result: self, next_row: 0} } + + pub fn get<'a>(&'a self, idx: uint) -> PostgresRow<'a> { + if idx >= self.len() { + fail!("Out of bounds access"); + } + + self.iter().idx(idx).get() + } } pub struct PostgresResultIterator<'self> { @@ -287,18 +346,36 @@ impl<'self> Container for PostgresRow<'self> { } impl<'self, T: FromStr> Index> for PostgresRow<'self> { - fn index(&self, index: &uint) -> Option { - if *index >= self.len() { + fn index(&self, idx: &uint) -> Option { + self.get(*idx) + } +} + +impl<'self> PostgresRow<'self> { + pub fn get(&self, idx: uint) -> Option { + if idx >= self.len() { fail!("Out of bounds access"); } let s = unsafe { let raw_s = ffi::PQgetvalue(self.result.result, self.row as c_int, - *index as c_int); + idx as c_int); str::raw::from_c_str(raw_s) }; FromStr::from_str(s) } } + +fn as_c_str_array(array: &[~str], blk: &fn(**c_char) -> T) -> T { + let mut c_array: ~[*c_char] = vec::with_capacity(array.len() + 1); + foreach s in array.iter() { + // DANGER, WILL ROBINSON + do s.as_c_str |c_s| { + c_array.push(c_s); + } + } + c_array.push(ptr::null()); + blk(vec::raw::to_ptr(c_array)) +} diff --git a/src/postgres/test.rs b/src/postgres/test.rs index 719a351e..1f8f642a 100644 --- a/src/postgres/test.rs +++ b/src/postgres/test.rs @@ -12,17 +12,39 @@ macro_rules! chk( ) #[test] -fn test_conn() { +fn test_basic() { + let conn = chk!(PostgresConnection::new("postgres://postgres@localhost")); + + do conn.in_transaction |conn| { + chk!(conn.update("CREATE TABLE basic (id INT PRIMARY KEY)", [])); + chk!(conn.update("INSERT INTO basic (id) VALUES (101)", [])); + + let res = chk!(conn.query("SELECT id from basic WHERE id = 101", [])); + assert_eq!(1, res.len()); + let rows: ~[PostgresRow] = res.iter().collect(); + assert_eq!(1, rows.len()); + assert_eq!(1, rows[0].len()); + assert_eq!(Some(101), rows[0][0]); + + Err::<(), ~str>(~"") + }; +} + +#[test] +fn test_params() { let conn = chk!(PostgresConnection::new("postgres://postgres@localhost")); - chk!(conn.update("DROP TABLE IF EXISTS foo")); - chk!(conn.update("CREATE TABLE foo (foo INT PRIMARY KEY)")); - chk!(conn.update("INSERT INTO foo (foo) VALUES (101)")); + do conn.in_transaction |conn| { + chk!(conn.update("CREATE TABLE basic (id INT PRIMARY KEY)", [])); + chk!(conn.update("INSERT INTO basic (id) VALUES ($1)", [~"101"])); - let res = chk!(conn.query("SELECT foo from foo")); - assert_eq!(1, res.len()); - let rows: ~[PostgresRow] = res.iter().collect(); - assert_eq!(1, rows.len()); - assert_eq!(1, rows[0].len()); - assert_eq!(Some(101), rows[0][0]); + let res = chk!(conn.query("SELECT id from basic WHERE id = $1", [~"101"])); + assert_eq!(1, res.len()); + let rows: ~[PostgresRow] = res.iter().collect(); + assert_eq!(1, rows.len()); + assert_eq!(1, rows[0].len()); + assert_eq!(Some(101), rows[0][0]); + + Err::<(), ~str>(~"") + }; }