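// rust-postgres/src/lib.rs
//
// (Captured from a Git blame/history view: 493 lines, 13 KiB, Rust. Lines
// consisting only of a timestamp such as `2013-08-22 05:52:15 +00:00` are
// blame annotations, not part of the source code.)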
extern mod extra;
2013-08-04 02:17:32 +00:00
use extra::digest::Digest;
use extra::md5::Md5;
use extra::url::Url;
2013-08-22 05:52:15 +00:00
use std::cell::Cell;
use std::rt::io::io_error;
2013-08-22 05:52:15 +00:00
use std::rt::io::net::ip::SocketAddr;
use std::rt::io::net::tcp::TcpStream;
2013-08-23 07:13:42 +00:00
use std::str;
2013-08-04 02:17:32 +00:00
2013-08-22 05:52:15 +00:00
use message::*;
2013-07-25 07:10:18 +00:00
2013-08-22 05:52:15 +00:00
mod message;
2013-07-25 07:10:18 +00:00
2013-08-18 03:30:31 +00:00
pub struct PostgresConnection {
2013-08-22 05:52:15 +00:00
priv stream: Cell<TcpStream>,
priv next_stmt_id: Cell<int>
2013-08-04 02:17:32 +00:00
}
2013-08-18 03:30:31 +00:00
impl Drop for PostgresConnection {
2013-08-04 02:17:32 +00:00
fn drop(&self) {
do io_error::cond.trap(|_| {}).inside {
self.write_message(&Terminate);
}
2013-08-18 03:42:40 +00:00
}
2013-08-04 02:17:32 +00:00
}
2013-07-25 07:10:18 +00:00
2013-08-18 03:30:31 +00:00
impl PostgresConnection {
2013-08-23 02:47:06 +00:00
pub fn connect(url: &str) -> PostgresConnection {
let url: Url = FromStr::from_str(url).unwrap();
2013-08-22 05:52:15 +00:00
let socket_url = fmt!("%s:%s", url.host,
url.port.get_ref().as_slice());
2013-08-22 05:52:15 +00:00
let addr: SocketAddr = FromStr::from_str(socket_url).unwrap();
let conn = PostgresConnection {
stream: Cell::new(TcpStream::connect(addr).unwrap()),
2013-08-18 03:42:40 +00:00
next_stmt_id: Cell::new(0)
};
let mut args = url.query.clone();
args.push((~"user", url.user.get_ref().user.clone()));
if !url.path.is_empty() {
args.push((~"database", url.path.clone()));
2013-08-04 02:17:32 +00:00
}
conn.write_message(&StartupMessage(args.as_slice()));
conn.handle_auth(&url);
2013-08-04 02:17:32 +00:00
2013-08-22 05:52:15 +00:00
loop {
2013-08-23 05:24:14 +00:00
match conn.read_message() {
2013-08-22 05:52:15 +00:00
ParameterStatus(param, value) =>
info!("Parameter %s = %s", param, value),
2013-08-22 06:41:26 +00:00
BackendKeyData(*) => (),
ReadyForQuery(*) => break,
2013-08-23 05:24:14 +00:00
resp => fail!("Bad response: %?", resp.to_str())
}
2013-08-05 00:48:48 +00:00
}
2013-08-23 05:24:14 +00:00
conn
2013-08-05 00:48:48 +00:00
}
2013-08-22 06:41:26 +00:00
fn write_message(&self, message: &FrontendMessage) {
2013-08-22 05:52:15 +00:00
do self.stream.with_mut_ref |s| {
2013-08-22 06:41:26 +00:00
s.write_message(message);
2013-07-25 07:10:18 +00:00
}
2013-08-22 06:41:26 +00:00
}
2013-08-04 02:17:32 +00:00
2013-08-22 06:41:26 +00:00
fn read_message(&self) -> BackendMessage {
2013-08-22 05:52:15 +00:00
do self.stream.with_mut_ref |s| {
2013-08-22 06:41:26 +00:00
s.read_message()
2013-08-04 05:21:16 +00:00
}
2013-08-22 06:41:26 +00:00
}
fn handle_auth(&self, url: &Url) {
loop {
match self.read_message() {
AuthenticationOk => break,
AuthenticationCleartextPassword => {
let pass = url.user.get_ref().pass.get_ref().as_slice();
self.write_message(&PasswordMessage(pass));
}
AuthenticationMD5Password(nonce) => {
let input = url.user.get_ref().pass.get_ref().as_slice() +
url.user.get_ref().user.as_slice();
let mut md5 = Md5::new();
md5.input_str(input);
let output = md5.result_str();
md5.reset();
md5.input_str(output);
md5.input(nonce);
let output = "md5" + md5.result_str();
self.write_message(&PasswordMessage(output.as_slice()));
}
resp => fail!("Bad response: %?", resp.to_str())
}
}
}
2013-08-22 06:41:26 +00:00
pub fn prepare<'a>(&'a self, query: &str) -> PostgresStatement<'a> {
let id = self.next_stmt_id.take();
2013-08-27 02:38:02 +00:00
let stmt_name = format!("statement_{}", id);
2013-08-22 06:41:26 +00:00
self.next_stmt_id.put_back(id + 1);
2013-08-05 00:48:48 +00:00
2013-08-22 06:41:26 +00:00
let types = [];
self.write_message(&Parse(stmt_name, query, types));
self.write_message(&Sync);
match self.read_message() {
2013-08-22 05:52:15 +00:00
ParseComplete => (),
2013-08-23 05:24:14 +00:00
resp @ ErrorResponse(*) => fail!("Error: %?", resp.to_str()),
resp => fail!("Bad response: %?", resp.to_str())
2013-08-05 00:48:48 +00:00
}
2013-08-22 06:41:26 +00:00
self.wait_for_ready();
self.write_message(&Describe('S' as u8, stmt_name));
self.write_message(&Sync);
let num_params = match self.read_message() {
ParameterDescription(ref types) => types.len(),
2013-08-23 05:24:14 +00:00
resp => fail!("Bad response: %?", resp.to_str())
2013-08-22 06:41:26 +00:00
};
match self.read_message() {
RowDescription(*) | NoData => (),
2013-08-23 05:24:14 +00:00
resp => fail!("Bad response: %?", resp.to_str())
2013-08-05 00:48:48 +00:00
}
2013-08-17 22:09:26 +00:00
2013-08-22 06:41:26 +00:00
self.wait_for_ready();
PostgresStatement {
conn: self,
name: stmt_name,
2013-08-22 07:12:35 +00:00
num_params: num_params,
next_portal_id: Cell::new(0)
2013-08-22 06:41:26 +00:00
}
}
2013-08-23 05:24:14 +00:00
pub fn in_transaction<T, E: ToStr>(&self, blk: &fn(&PostgresConnection)
-> Result<T, E>)
-> Result<T, E> {
2013-08-23 07:13:42 +00:00
self.quick_query("BEGIN");
2013-08-23 05:24:14 +00:00
// If this fails, Postgres will rollback when the connection closes
let ret = blk(self);
if ret.is_ok() {
2013-08-23 07:13:42 +00:00
self.quick_query("COMMIT");
2013-08-23 05:24:14 +00:00
} else {
2013-08-23 07:13:42 +00:00
self.quick_query("ROLLBACK");
2013-08-23 05:24:14 +00:00
}
ret
}
2013-08-23 07:13:42 +00:00
fn quick_query(&self, query: &str) {
self.write_message(&Query(query));
loop {
match self.read_message() {
ReadyForQuery(*) => break,
resp @ ErrorResponse(*) => fail!("Error: %?", resp.to_str()),
_ => ()
}
}
}
2013-08-22 06:41:26 +00:00
fn wait_for_ready(&self) {
2013-08-23 05:24:14 +00:00
loop {
match self.read_message() {
ReadyForQuery(*) => break,
resp => fail!("Bad response: %?", resp.to_str())
}
2013-08-22 06:41:26 +00:00
}
2013-08-17 22:09:26 +00:00
}
}
2013-08-22 05:52:15 +00:00
/// A statement prepared on the server by `PostgresConnection::prepare`.
///
/// The server-side statement is closed when this value is dropped.
pub struct PostgresStatement<'self> {
    // Connection the statement was prepared on.
    priv conn: &'self PostgresConnection,
    // Server-side statement name generated by `prepare`.
    priv name: ~str,
    // Parameter count reported by the server's ParameterDescription.
    priv num_params: uint,
    // Counter consumed by `query` to build unique portal names.
    priv next_portal_id: Cell<uint>
}
#[unsafe_destructor]
impl<'self> Drop for PostgresStatement<'self> {
fn drop(&self) {
do io_error::cond.trap(|_| {}).inside {
self.conn.write_message(&Close('S' as u8, self.name.as_slice()));
self.conn.write_message(&Sync);
loop {
match self.conn.read_message() {
ReadyForQuery(*) => break,
_ => ()
}
2013-08-23 05:24:14 +00:00
}
}
2013-08-22 07:12:35 +00:00
}
}
2013-08-22 05:52:15 +00:00
impl<'self> PostgresStatement<'self> {
2013-08-22 07:12:35 +00:00
pub fn num_params(&self) -> uint {
self.num_params
}
2013-08-25 03:47:36 +00:00
fn execute(&self, portal_name: &str, params: &[&ToSql]) {
if self.num_params != params.len() {
fail!("Expected %u params but got %u", self.num_params,
params.len());
}
2013-08-22 07:12:35 +00:00
let formats = [];
2013-08-25 03:47:36 +00:00
let values: ~[Option<~[u8]>] = params.iter().map(|val| val.to_sql())
.collect();
2013-08-22 07:12:35 +00:00
let result_formats = [];
self.conn.write_message(&Bind(portal_name, self.name.as_slice(),
formats, values, result_formats));
2013-08-23 05:24:14 +00:00
self.conn.write_message(&Execute(portal_name.as_slice(), 0));
2013-08-22 07:12:35 +00:00
self.conn.write_message(&Sync);
match self.conn.read_message() {
BindComplete => (),
2013-08-23 05:24:14 +00:00
resp @ ErrorResponse(*) => fail!("Error: %?", resp.to_str()),
resp => fail!("Bad response: %?", resp.to_str())
2013-08-22 07:12:35 +00:00
}
2013-08-23 05:24:14 +00:00
}
2013-08-25 03:47:36 +00:00
pub fn update(&self, params: &[&ToSql]) -> uint {
// The unnamed portal is automatically cleaned up at sync time
self.execute("", params);
2013-08-22 07:12:35 +00:00
2013-08-23 05:24:14 +00:00
let mut num = 0;
loop {
match self.conn.read_message() {
CommandComplete(ret) => {
let s = ret.split_iter(' ').last().unwrap();
match FromStr::from_str(s) {
None => (),
Some(n) => num = n
}
break;
}
DataRow(*) => (),
EmptyQueryResponse => break,
NoticeResponse(*) => (),
resp @ ErrorResponse(*) => fail!("Error: %?", resp.to_str()),
resp => fail!("Bad response: %?", resp.to_str())
}
}
2013-08-22 07:12:35 +00:00
self.conn.wait_for_ready();
2013-08-23 05:24:14 +00:00
num
2013-08-22 07:12:35 +00:00
}
2013-08-23 07:13:42 +00:00
2013-08-25 03:47:36 +00:00
pub fn query<'a>(&'a self, params: &[&ToSql]) -> PostgresResult<'a> {
2013-08-23 07:13:42 +00:00
let id = self.next_portal_id.take();
2013-08-27 02:38:02 +00:00
let portal_name = format!("{:s}_portal_{}", self.name.as_slice(), id);
2013-08-23 07:13:42 +00:00
self.next_portal_id.put_back(id + 1);
2013-08-25 03:47:36 +00:00
self.execute(portal_name, params);
2013-08-23 07:13:42 +00:00
let mut data = ~[];
loop {
match self.conn.read_message() {
EmptyQueryResponse => break,
DataRow(row) => data.push(row),
CommandComplete(*) => break,
NoticeResponse(*) => (),
resp @ ErrorResponse(*) => fail!("Error: %?", resp.to_str()),
resp => fail!("Bad response: %?", resp.to_str())
}
}
PostgresResult {
stmt: self,
name: portal_name,
data: data
}
}
}
/// The fully buffered rows produced by `PostgresStatement::query`.
///
/// Dropping it closes the portal the rows came from.
pub struct PostgresResult<'self> {
    // Statement the rows were produced from.
    priv stmt: &'self PostgresStatement<'self>,
    // Name of the server-side portal used for the query.
    priv name: ~str,
    // One entry per row; each row holds one raw value per column.
    priv data: ~[~[Option<~[u8]>]]
}
#[unsafe_destructor]
impl<'self> Drop for PostgresResult<'self> {
    /// Closes the portal this result was read from.
    ///
    /// Sends `Close('P')` followed by `Sync`, then skips replies until
    /// `ReadyForQuery`. I/O errors are trapped and ignored because a
    /// destructor cannot report them.
    fn drop(&self) {
        do io_error::cond.trap(|_| {}).inside {
            let conn = self.stmt.conn;
            conn.write_message(&Close('P' as u8, self.name.as_slice()));
            conn.write_message(&Sync);
            let mut ready = false;
            while !ready {
                ready = match conn.read_message() {
                    ReadyForQuery(*) => true,
                    _ => false
                };
            }
        }
    }
}
impl<'self> PostgresResult<'self> {
    /// Returns an iterator over the buffered rows, starting at the first row
    /// the server sent.
    pub fn iter<'a>(&'a self) -> PostgresResultIterator<'a> {
        let first_row = 0;
        PostgresResultIterator { next_row: first_row, result: self }
    }
}
/// Iterator over the rows of a `PostgresResult`, created by
/// `PostgresResult::iter`.
pub struct PostgresResultIterator<'self> {
    // Result whose buffered rows are being walked.
    priv result: &'self PostgresResult<'self>,
    // Index of the row the next call to `next` will yield.
    priv next_row: uint
}
impl<'self> Iterator<PostgresRow<'self>> for PostgresResultIterator<'self> {
    /// Yields a view of the next row, or `None` once every row was returned.
    fn next(&mut self) -> Option<PostgresRow<'self>> {
        if self.next_row == self.result.data.len() {
            None
        } else {
            let current = self.next_row;
            self.next_row = current + 1;
            Some(PostgresRow { result: self.result, row: current })
        }
    }
}
/// A view of one row inside a `PostgresResult`.
pub struct PostgresRow<'self> {
    // Result that owns the row's data.
    priv result: &'self PostgresResult<'self>,
    // Index of this row within the result.
    priv row: uint
}
impl<'self> Container for PostgresRow<'self> {
    /// Returns the number of columns in the row.
    fn len(&self) -> uint {
        let columns = &self.result.data[self.row];
        columns.len()
    }
}
impl<'self, T: FromSql> Index<uint, T> for PostgresRow<'self> {
    /// `row[i]` is shorthand for `row.get(i)`.
    fn index(&self, idx: &uint) -> T {
        let column = *idx;
        self.get(column)
    }
}
impl<'self> PostgresRow<'self> {
    /// Decodes column `idx` as a `T` via `FromSql`.
    ///
    /// Fails the task if `idx` is out of bounds, or if `T`'s `from_sql`
    /// fails.
    pub fn get<T: FromSql>(&self, idx: uint) -> T {
        let columns = &self.result.data[self.row];
        FromSql::from_sql(&columns[idx])
    }
}
pub trait FromSql {
fn from_sql(raw: &Option<~[u8]>) -> Self;
}
2013-08-25 03:47:36 +00:00
// Implements `FromSql` for `Option<$t>`.
// * A `None` column decodes to `None`.
// * Otherwise the bytes are read as text and parsed with `FromStr`; a parse
//   failure fails the task.
macro_rules! from_str_impl(
    ($t:ty) => (
        impl FromSql for Option<$t> {
            fn from_sql(raw: &Option<~[u8]>) -> Option<$t> {
                match *raw {
                    Some(ref buf) => {
                        let text = str::from_bytes_slice(buf.as_slice());
                        let parsed: Option<$t> = FromStr::from_str(text);
                        Some(parsed.unwrap())
                    }
                    None => None
                }
            }
        }
    )
)
// Implements `FromSql` for a plain `$t` by decoding as `Option<$t>` and
// unwrapping, so a `None` column fails the task.
macro_rules! from_option_impl(
    ($t:ty) => (
        impl FromSql for $t {
            fn from_sql(raw: &Option<~[u8]>) -> $t {
                let value: Option<$t> = FromSql::from_sql(raw);
                value.unwrap()
            }
        }
    )
)
from_str_impl!(int)
from_option_impl!(int)
from_str_impl!(i8)
from_option_impl!(i8)
from_str_impl!(i16)
from_option_impl!(i16)
from_str_impl!(i32)
from_option_impl!(i32)
from_str_impl!(i64)
from_option_impl!(i64)
from_str_impl!(uint)
from_option_impl!(uint)
from_str_impl!(u8)
from_option_impl!(u8)
from_str_impl!(u16)
from_option_impl!(u16)
from_str_impl!(u32)
from_option_impl!(u32)
from_str_impl!(u64)
from_option_impl!(u64)
from_str_impl!(float)
from_option_impl!(float)
from_str_impl!(f32)
from_option_impl!(f32)
from_str_impl!(f64)
from_option_impl!(f64)
impl FromSql for Option<~str> {
fn from_sql(raw: &Option<~[u8]>) -> Option<~str> {
do raw.chain_ref |buf| {
Some(str::from_bytes(buf.as_slice()))
}
}
}
from_option_impl!(~str)
/// Types that can be sent to the server as a statement parameter.
///
/// `None` is bound as the parameter value to represent no data, presumably
/// SQL NULL (verify against the `message` module's `Bind` encoding).
pub trait ToSql {
    /// Encodes the value as raw parameter bytes.
    fn to_sql(&self) -> Option<~[u8]>;
}
// Implements `ToSql` for `$t` by sending the bytes of its `to_str()` text.
macro_rules! to_str_impl(
    ($t:ty) => (
        impl ToSql for $t {
            fn to_sql(&self) -> Option<~[u8]> {
                let text = self.to_str();
                Some(text.into_bytes())
            }
        }
    )
)
// Implements `ToSql` for `Option<$t>`: `None` encodes as `None`, and
// `Some(v)` encodes as `v.to_sql()`.
macro_rules! to_option_impl(
    ($t:ty) => (
        impl ToSql for Option<$t> {
            fn to_sql(&self) -> Option<~[u8]> {
                match *self {
                    Some(val) => val.to_sql(),
                    None => None
                }
            }
        }
    )
)
// `ToSql` for the numeric types, sent in their textual form.
to_str_impl!(int)
to_str_impl!(i8)
to_str_impl!(i16)
to_str_impl!(i32)
to_str_impl!(i64)
to_str_impl!(uint)
to_str_impl!(u8)
to_str_impl!(u16)
to_str_impl!(u32)
to_str_impl!(u64)
to_str_impl!(float)
to_str_impl!(f32)
to_str_impl!(f64)
// `ToSql` for the optional versions of the same types.
to_option_impl!(int)
to_option_impl!(i8)
to_option_impl!(i16)
to_option_impl!(i32)
to_option_impl!(i64)
to_option_impl!(uint)
to_option_impl!(u8)
to_option_impl!(u16)
to_option_impl!(u32)
to_option_impl!(u64)
to_option_impl!(float)
to_option_impl!(f32)
to_option_impl!(f64)
/// Sends a borrowed string as a copy of its bytes.
impl<'self> ToSql for &'self str {
    fn to_sql(&self) -> Option<~[u8]> {
        let bytes = self.as_bytes();
        Some(bytes.to_owned())
    }
}
/// Optional owned string: `None` encodes as `None`, and `Some(s)` encodes
/// as `s.to_sql()`.
impl ToSql for Option<~str> {
    fn to_sql(&self) -> Option<~[u8]> {
        match *self {
            Some(ref val) => val.to_sql(),
            None => None
        }
    }
}
/// Optional borrowed string: `None` encodes as `None`, and `Some(s)` encodes
/// as `s.to_sql()`.
impl<'self> ToSql for Option<&'self str> {
    fn to_sql(&self) -> Option<~[u8]> {
        match *self {
            Some(val) => val.to_sql(),
            None => None
        }
    }
}