Proper NULL support

This commit is contained in:
Steven Fackler 2013-07-27 16:39:10 -04:00
parent 762a55b38d
commit af2db76953
2 changed files with 87 additions and 24 deletions

View File

@ -1,6 +1,8 @@
use std::libc::c_int;
use std::ptr;
use std::str;
use std::vec;
use std::uint;
mod ffi {
use std::libc::{c_char, c_int, c_void};
@ -13,6 +15,8 @@ mod ffi {
pub static SQLITE_ROW: c_int = 100;
pub static SQLITE_DONE: c_int = 101;
pub static SQLITE_NULL: c_int = 5;
// A function because Rust doesn't like casting from an int to a function
// pointer in a static declaration
pub fn SQLITE_TRANSIENT() -> extern "C" fn(*c_void) {
@ -31,8 +35,10 @@ mod ffi {
fn sqlite3_reset(pStmt: *sqlite3_stmt) -> c_int;
fn sqlite3_bind_text(pStmt: *sqlite3_stmt, idx: c_int, text: *c_char,
n: c_int, free: extern "C" fn(*c_void)) -> c_int;
fn sqlite3_bind_null(pStmt: *sqlite3_stmt, idx: c_int) -> c_int;
fn sqlite3_step(pStmt: *sqlite3_stmt) -> c_int;
fn sqlite3_column_count(pStmt: *sqlite3_stmt) -> c_int;
fn sqlite3_column_type(pStmt: *sqlite3_stmt, iCol: c_int) -> c_int;
fn sqlite3_column_text(pStmt: *sqlite3_stmt, iCol: c_int) -> *c_char;
fn sqlite3_finalize(pStmt: *sqlite3_stmt) -> c_int;
}
@ -161,11 +167,16 @@ impl<'self> PreparedStatement<'self> {
fn bind_params(&self, params: &[@SqlType]) -> Result<(), ~str> {
for params.iter().enumerate().advance |(idx, param)| {
let ret = do param.to_sql_str().as_c_str |c_param| {
unsafe {
ffi::sqlite3_bind_text(self.stmt, (idx+1) as c_int,
c_param, -1,
ffi::SQLITE_TRANSIENT())
let ret = match param.to_sql_str() {
Some(val) => do val.as_c_str |c_param| {
unsafe {
ffi::sqlite3_bind_text(self.stmt, (idx+1) as c_int,
c_param, -1,
ffi::SQLITE_TRANSIENT())
}
},
None => unsafe {
ffi::sqlite3_bind_null(self.stmt, (idx+1) as c_int)
}
};
@ -218,7 +229,7 @@ impl<'self> Iterator<Row<'self>> for ResultIterator<'self> {
fn next(&mut self) -> Option<Row<'self>> {
let ret = unsafe { ffi::sqlite3_step(self.stmt.stmt) };
match ret {
ffi::SQLITE_ROW => Some(Row {stmt: self.stmt}),
ffi::SQLITE_ROW => Some(Row::new(self.stmt)),
// TODO: Ignoring errors for now
_ => None
}
@ -226,7 +237,8 @@ impl<'self> Iterator<Row<'self>> for ResultIterator<'self> {
}
pub struct Row<'self> {
priv stmt: &'self PreparedStatement<'self>
priv stmt: &'self PreparedStatement<'self>,
priv cols: ~[Option<~str>]
}
impl<'self> Container for Row<'self> {
@ -240,30 +252,61 @@ impl<'self> Container for Row<'self> {
}
impl<'self> Row<'self> {
pub fn get<T: SqlType>(&self, idx: uint) -> Option<T> {
let raw = unsafe {
ffi::sqlite3_column_text(self.stmt.stmt, idx as c_int)
};
fn new(stmt: &'self PreparedStatement<'self>) -> Row<'self> {
let count = unsafe { ffi::sqlite3_column_count(stmt.stmt) as uint};
let mut row = Row {stmt: stmt, cols: vec::with_capacity(count)};
if ptr::is_null(raw) {
return None;
for uint::range(0, count) |i| {
let typ = unsafe {
ffi::sqlite3_column_type(stmt.stmt, i as c_int)
};
let val = match typ {
ffi::SQLITE_NULL => None,
_ => Some(unsafe {
str::raw::from_c_str(ffi::sqlite3_column_text(stmt.stmt,
i as c_int))
})
};
row.cols.push(val);
}
SqlType::from_sql_str(unsafe { str::raw::from_c_str(raw) })
return row
}
}
impl<'self> Row<'self> {
pub fn get<T: SqlType>(&self, idx: uint) -> T {
SqlType::from_sql_str(&self.cols[idx])
}
}
pub trait SqlType {
fn to_sql_str(&self) -> ~str;
fn from_sql_str(sql_str: &str) -> Option<Self>;
fn to_sql_str(&self) -> Option<~str>;
fn from_sql_str(sql_str: &Option<~str>) -> Self;
}
impl SqlType for int {
fn to_sql_str(&self) -> ~str {
self.to_str()
fn to_sql_str(&self) -> Option<~str> {
Some(self.to_str())
}
fn from_sql_str(sql_str: &str) -> Option<int> {
FromStr::from_str(sql_str)
fn from_sql_str(sql_str: &Option<~str>) -> int {
FromStr::from_str(*sql_str.get_ref()).get()
}
}
impl SqlType for Option<int> {
fn to_sql_str(&self) -> Option<~str> {
match *self {
None => None,
Some(v) => Some(v.to_str())
}
}
fn from_sql_str(sql_str: &Option<~str>) -> Option<int> {
match *sql_str {
None => None,
Some(ref s) => Some(FromStr::from_str(*s).get())
}
}
}

View File

@ -21,7 +21,7 @@ fn test_basic() {
do conn.query("SELECT id FROM foo") |it| {
for it.advance |row| {
printfln!("%u %d", row.len(), row.get(0).get());
printfln!("%u %d", row.len(), row.get(0));
}
};
}
@ -37,7 +37,7 @@ fn test_trans() {
Err::<(), ~str>(~"")
};
assert_eq!(0, chk!(conn.query("SELECT COUNT(*) FROM bar", |it| {
it.next().get().get(0).get()
it.next().get().get(0)
})));
do conn.in_transaction |conn| {
@ -46,7 +46,7 @@ fn test_trans() {
};
assert_eq!(1, chk!(conn.query("SELECT COUNT(*) FROM bar", |it| {
it.next().get().get(0).get()
it.next().get().get(0)
})));
}
@ -60,6 +60,26 @@ fn test_params() {
&[@100 as @SqlType, @101 as @SqlType]));
assert_eq!(2, chk!(conn.query("SELECT COUNT(*) FROM foo", |it| {
it.next().get().get(0).get()
it.next().get().get(0)
})));
}
#[test]
fn test_null() {
let conn = chk!(sqlite3::open(":memory:"));
chk!(conn.update("CREATE TABLE foo (
id BIGINT PRIMARY KEY,
n BIGINT
)"));
chk!(conn.update_params("INSERT INTO foo (id, n) VALUES (?, ?), (?, ?)",
&[@100 as @SqlType, @None::<int> as @SqlType,
@101 as @SqlType, @Some(1) as @SqlType]));
do conn.query("SELECT n FROM foo WHERE id = 100") |it| {
assert!(it.next().get().get::<Option<int>>(0).is_none());
};
do conn.query("SELECT n FROM foo WHERE id = 101") |it| {
assert_eq!(Some(1), it.next().get().get(0))
};
}