diff --git a/src/error.rs b/src/error.rs index 1c601a21..11db2461 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,7 @@ //! Postgres errors use std::hashmap::HashMap; +use std::io::IoError; use openssl::ssl::error::SslError; use phf::PhfMap; @@ -362,7 +363,7 @@ pub enum PostgresConnectError { /// There was an error opening a socket to the server SocketError, /// An error from the Postgres server itself - DbError(PostgresDbError), + PgConnectDbError(PostgresDbError), /// A password was required but not provided in the URL MissingPassword, /// The Postgres server requested an authentication method not supported @@ -371,7 +372,9 @@ pub enum PostgresConnectError { /// The Postgres server does not support SSL encryption NoSslSupport, /// There was an error initializing the SSL session - SslError(SslError) + SslError(SslError), + /// There was an error communicating with the server + PgConnectStreamError(IoError), } /// Represents the position of an error in a query @@ -491,3 +494,21 @@ impl PostgresDbError { } } } + +#[deriving(ToStr)] +pub enum PostgresError { + /// An error reported by the Postgres server + PgDbError(PostgresDbError), + /// An error communicating with the Postgres server + PgStreamError(IoError), +} + +impl PostgresError { + #[doc(hidden)] + pub fn pretty_error(&self, query: &str) -> ~str { + match *self { + PgDbError(ref err) => err.pretty_error(query), + PgStreamError(ref err) => format!("{}", *err), + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 7e1c72e0..2820562c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -69,6 +69,7 @@ extern mod openssl; #[phase(syntax)] extern mod phf_mac; extern mod phf; +extern mod uuid; use extra::container::Deque; use extra::hex::ToHex; @@ -78,8 +79,7 @@ use openssl::crypto::hash::{MD5, Hasher}; use openssl::ssl::{SslStream, SslContext}; use std::cell::RefCell; use std::hashmap::HashMap; -use std::io::BufferedStream; -use std::io::io_error; +use std::io::{BufferedStream, IoResult}; use std::io::net::ip::{Port, SocketAddr}; use std::io::net::tcp::TcpStream; use std::io::net; @@ -88,6 +88,8 @@ use std::task; use std::util; use error::{PostgresDbError, + PgConnectDbError, + PgConnectStreamError, PostgresConnectError, InvalidUrl, DnsError, @@ -95,9 +97,11 @@ use error::{PostgresDbError, NoSslSupport, SslError, MissingUser, - DbError, UnsupportedAuthentication, - MissingPassword}; + MissingPassword, + PostgresError, + PgStreamError, + PgDbError}; use message::{BackendMessage, AuthenticationOk, AuthenticationKerberosV5, @@ -235,27 +239,39 @@ pub fn cancel_query(url: &str, ssl: &SslMode, data: PostgresCancelData) fn open_socket(host: &str, port: Port) -> Result { - let addrs = io_error::cond.trap(|_| {}).inside(|| { - net::get_host_addresses(host) - }); - let addrs = match addrs { - Some(addrs) => addrs, - None => return Err(DnsError) + let addrs = match net::get_host_addresses(host) { + Ok(addrs) => addrs, + Err(_) => return Err(DnsError) }; - for addr in addrs.iter() { - let socket = io_error::cond.trap(|_| {}).inside(|| { - TcpStream::connect(SocketAddr { ip: *addr, port: port }) - }); - match socket { - Some(socket) => return Ok(socket), - None => {} + for &addr in addrs.iter() { + match TcpStream::connect(SocketAddr { ip: addr, port: port }) { + Ok(socket) => return Ok(socket), + Err(_) => {} } } Err(SocketError) } +macro_rules! if_ok_pg_conn( + ($e:expr) => ( + match $e { + Ok(ok) => ok, + Err(err) => return Err(PgConnectStreamError(err)) + } + ) +) + +macro_rules! if_ok_pg( + ($e:expr) => ( + match $e { + Ok(ok) => ok, + Err(err) => return Err(PgStreamError(err)) + } + ) +) + fn initialize_stream(host: &str, port: Port, ssl: &SslMode) -> Result { let mut socket = match open_socket(host, port) { @@ -272,7 +288,7 @@ fn initialize_stream(host: &str, port: Port, ssl: &SslMode) socket.write_message(&SslRequest { code: message::SSL_CODE }); socket.flush(); - if socket.read_u8() == 'N' as u8 { + if if_ok_pg_conn!(socket.read_u8()) == 'N' as u8 { if ssl_required { return Err(NoSslSupport); } else { @@ -292,7 +308,7 @@ enum InternalStream { } impl Reader for InternalStream { - fn read(&mut self, buf: &mut [u8]) -> Option { + fn read(&mut self, buf: &mut [u8]) -> IoResult { match *self { Normal(ref mut s) => s.read(buf), Ssl(ref mut s) => s.read(buf) @@ -301,14 +317,14 @@ impl Reader for InternalStream { } impl Writer for InternalStream { - fn write(&mut self, buf: &[u8]) { + fn write(&mut self, buf: &[u8]) -> IoResult<()> { match *self { Normal(ref mut s) => s.write(buf), Ssl(ref mut s) => s.write(buf) } } - fn flush(&mut self) { + fn flush(&mut self) -> IoResult<()> { match *self { Normal(ref mut s) => s.flush(), Ssl(ref mut s) => s.flush() @@ -327,9 +343,7 @@ struct InnerPostgresConnection { impl Drop for InnerPostgresConnection { fn drop(&mut self) { - io_error::cond.trap(|_| {}).inside(|| { - self.write_messages([Terminate]); - }) + let _ = self.write_messages([Terminate]); } } @@ -383,25 +397,25 @@ impl InnerPostgresConnection { path.shift_char(); args.push((~"database", path)); } - conn.write_messages([StartupMessage { + if_ok_pg_conn!(conn.write_messages([StartupMessage { version: message::PROTOCOL_VERSION, parameters: args.as_slice() - }]); + }])); match conn.handle_auth(user) { - Some(err) => return Err(err), - None => {} + Err(err) => return Err(err), + Ok(()) => {} } loop { - match conn.read_message() { + match if_ok_pg_conn!(conn.read_message()) { BackendKeyData { process_id, secret_key } => { conn.cancel_data.process_id = process_id; conn.cancel_data.secret_key = secret_key; } ReadyForQuery { .. } => break, ErrorResponse { fields } => - return Err(DbError(PostgresDbError::new(fields))), + return Err(PgConnectDbError(PostgresDbError::new(fields))), _ => unreachable!() } } @@ -409,46 +423,47 @@ impl InnerPostgresConnection { Ok(conn) } - fn write_messages(&mut self, messages: &[FrontendMessage]) { + fn write_messages(&mut self, messages: &[FrontendMessage]) -> IoResult<()> { for message in messages.iter() { - self.stream.write_message(message); + if_ok!(self.stream.write_message(message)); } - self.stream.flush(); + self.stream.flush() } - fn read_message(&mut self) -> BackendMessage { + fn read_message(&mut self) -> IoResult { loop { match self.stream.read_message() { - NoticeResponse { fields } => + Ok(NoticeResponse { fields }) => self.notice_handler.handle(PostgresDbError::new(fields)), - NotificationResponse { pid, channel, payload } => + Ok(NotificationResponse { pid, channel, payload }) => self.notifications.push_back(PostgresNotification { pid: pid, channel: channel, payload: payload }), - ParameterStatus { parameter, value } => + Ok(ParameterStatus { parameter, value }) => info!("Parameter {} = {}", parameter, value), - msg => return msg + val => return val } } } - fn handle_auth(&mut self, user: UserInfo) -> Option { - match self.read_message() { - AuthenticationOk => return None, + fn handle_auth(&mut self, user: UserInfo) -> + Result<(), PostgresConnectError> { + match if_ok_pg_conn!(self.read_message()) { + AuthenticationOk => return Ok(()), AuthenticationCleartextPassword => { let pass = match user.pass { Some(pass) => pass, - None => return Some(MissingPassword) + None => return Err(MissingPassword) }; - self.write_messages([PasswordMessage { password: pass }]); + if_ok_pg_conn!(self.write_messages([PasswordMessage { password: pass }])); } AuthenticationMD5Password { salt } => { let UserInfo { user, pass } = user; let pass = match pass { Some(pass) => pass, - None => return Some(MissingPassword) + None => return Err(MissingPassword) }; let input = pass + user; let hasher = Hasher::new(MD5); @@ -458,23 +473,23 @@ impl InnerPostgresConnection { hasher.update(output.as_bytes()); hasher.update(salt); let output = "md5" + hasher.final().to_hex(); - self.write_messages([PasswordMessage { + if_ok_pg_conn!(self.write_messages([PasswordMessage { password: output.as_slice() - }]); + }])); } AuthenticationKerberosV5 | AuthenticationSCMCredential | AuthenticationGSS - | AuthenticationSSPI => return Some(UnsupportedAuthentication), + | AuthenticationSSPI => return Err(UnsupportedAuthentication), ErrorResponse { fields } => - return Some(DbError(PostgresDbError::new(fields))), + return Err(PgConnectDbError(PostgresDbError::new(fields))), _ => unreachable!() } - match self.read_message() { - AuthenticationOk => None, + match if_ok_pg_conn!(self.read_message()) { + AuthenticationOk => Ok(()), ErrorResponse { fields } => - Some(DbError(PostgresDbError::new(fields))), + Err(PgConnectDbError(PostgresDbError::new(fields))), _ => unreachable!() } } @@ -485,12 +500,12 @@ impl InnerPostgresConnection { } fn try_prepare<'a>(&mut self, query: &str, conn: &'a PostgresConnection) - -> Result, PostgresDbError> { + -> Result, PostgresError> { let stmt_name = format!("statement_{}", self.next_stmt_id); self.next_stmt_id += 1; let types = []; - self.write_messages([ + if_ok_pg!(self.write_messages([ Parse { name: stmt_name, query: query, @@ -500,24 +515,24 @@ impl InnerPostgresConnection { variant: 'S' as u8, name: stmt_name }, - Sync]); + Sync])); - match self.read_message() { + match if_ok_pg!(self.read_message()) { ParseComplete => {} ErrorResponse { fields } => { self.wait_for_ready(); - return Err(PostgresDbError::new(fields)); + return Err(PgDbError(PostgresDbError::new(fields))); } _ => unreachable!() } - let mut param_types: ~[PostgresType] = match self.read_message() { + let mut param_types: ~[PostgresType] = match if_ok_pg!(self.read_message()) { ParameterDescription { types } => types.iter().map(|ty| PostgresType::from_oid(*ty)).collect(), _ => unreachable!() }; - let mut result_desc: ~[ResultDescription] = match self.read_message() { + let mut result_desc: ~[ResultDescription] = match if_ok_pg!(self.read_message()) { RowDescription { descriptions } => descriptions.move_iter().map(|desc| { ResultDescription::from_row_description_entry(desc) @@ -526,14 +541,14 @@ impl InnerPostgresConnection { _ => unreachable!() }; - self.wait_for_ready(); + if_ok!(self.wait_for_ready()); // now that the connection is ready again, get unknown type names for param in param_types.mut_iter() { match *param { PgUnknownType { oid, .. } => *param = PgUnknownType { - name: self.get_type_name(oid), + name: if_ok!(self.get_type_name(oid)), oid: oid }, _ => {} @@ -544,7 +559,7 @@ impl InnerPostgresConnection { match desc.ty { PgUnknownType { oid, .. } => desc.ty = PgUnknownType { - name: self.get_type_name(oid), + name: if_ok!(self.get_type_name(oid)), oid: oid }, _ => {} @@ -560,43 +575,43 @@ impl InnerPostgresConnection { }) } - fn get_type_name(&mut self, oid: Oid) -> ~str { + fn get_type_name(&mut self, oid: Oid) -> Result<~str, PostgresError> { match self.unknown_types.find(&oid) { - Some(name) => return name.clone(), + Some(name) => return Ok(name.clone()), None => {} } - let name = self.quick_query( - format!("SELECT typname FROM pg_type WHERE oid={}", oid))[0][0] + let name = if_ok!(self.quick_query( + format!("SELECT typname FROM pg_type WHERE oid={}", oid)))[0][0] .unwrap(); self.unknown_types.insert(oid, name.clone()); - name + Ok(name) } - fn wait_for_ready(&mut self) { - match self.read_message() { - ReadyForQuery { .. } => {} + fn wait_for_ready(&mut self) -> Result<(), PostgresError> { + match if_ok_pg!(self.read_message()) { + ReadyForQuery { .. } => Ok(()), _ => unreachable!() } } - fn quick_query(&mut self, query: &str) -> ~[~[Option<~str>]] { - self.write_messages([Query { query: query }]); + fn quick_query(&mut self, query: &str) + -> Result<~[~[Option<~str>]], PostgresError> { + if_ok_pg!(self.write_messages([Query { query: query }])); let mut result = ~[]; loop { - match self.read_message() { + match if_ok_pg!(self.read_message()) { ReadyForQuery { .. } => break, DataRow { row } => result.push(row.move_iter().map(|opt| opt.map(|b| str::from_utf8_owned(b).unwrap())) .collect()), ErrorResponse { fields } => - fail!("Error: {}", - PostgresDbError::new(fields).to_str()), + return Err(PgDbError(PostgresDbError::new(fields))), _ => {} } } - result + Ok(result) } } @@ -663,7 +678,7 @@ impl PostgresConnection { /// The statement is associated with the connection that created it and may /// not outlive that connection. pub fn try_prepare<'a>(&'a self, query: &str) - -> Result, PostgresDbError> { + -> Result, PostgresError> { self.conn.with_mut(|conn| conn.try_prepare(query, self)) } @@ -703,7 +718,7 @@ impl PostgresConnection { /// /// On success, returns the number of rows modified or 0 if not applicable. pub fn try_execute(&self, query: &str, params: &[&ToSql]) - -> Result { + -> Result { self.try_prepare(query).and_then(|stmt| stmt.try_execute(params)) } @@ -728,19 +743,19 @@ impl PostgresConnection { self.conn.with(|conn| conn.cancel_data) } - fn quick_query(&self, query: &str) -> ~[~[Option<~str>]] { + fn quick_query(&self, query: &str) -> Result<~[~[Option<~str>]], PostgresError> { self.conn.with_mut(|conn| conn.quick_query(query)) } - fn wait_for_ready(&self) { + fn wait_for_ready(&self) -> Result<(), PostgresError> { self.conn.with_mut(|conn| conn.wait_for_ready()) } - fn read_message(&self) -> BackendMessage { + fn read_message(&self) -> IoResult { self.conn.with_mut(|conn| conn.read_message()) } - fn write_messages(&self, messages: &[FrontendMessage]) { + fn write_messages(&self, messages: &[FrontendMessage]) -> IoResult<()> { self.conn.with_mut(|conn| conn.write_messages(messages)) } } @@ -765,28 +780,26 @@ pub struct PostgresTransaction<'conn> { #[unsafe_destructor] impl<'conn> Drop for PostgresTransaction<'conn> { fn drop(&mut self) { - io_error::cond.trap(|_| {}).inside(|| { - if task::failing() || !self.commit.with(|x| *x) { - if self.nested { - self.conn.quick_query("ROLLBACK TO sp"); - } else { - self.conn.quick_query("ROLLBACK"); - } + if task::failing() || !self.commit.with(|x| *x) { + if self.nested { + self.conn.quick_query("ROLLBACK TO sp"); } else { - if self.nested { - self.conn.quick_query("RELEASE sp"); - } else { - self.conn.quick_query("COMMIT"); - } + self.conn.quick_query("ROLLBACK"); } - }) + } else { + if self.nested { + self.conn.quick_query("RELEASE sp"); + } else { + self.conn.quick_query("COMMIT"); + } + } } } impl<'conn> PostgresTransaction<'conn> { /// Like `PostgresConnection::try_prepare`. pub fn try_prepare<'a>(&'a self, query: &str) - -> Result, PostgresDbError> { + -> Result, PostgresError> { self.conn.try_prepare(query).map(|stmt| { TransactionalPostgresStatement { stmt: stmt @@ -804,7 +817,7 @@ impl<'conn> PostgresTransaction<'conn> { /// Like `PostgresConnection::try_execute`. pub fn try_execute(&self, query: &str, params: &[&ToSql]) - -> Result { + -> Result { self.conn.try_execute(query, params) } @@ -861,7 +874,7 @@ pub trait PostgresStatement { /// /// Fails if the number or types of the provided parameters do not match /// the parameters of the statement. - fn try_execute(&self, params: &[&ToSql]) -> Result; + fn try_execute(&self, params: &[&ToSql]) -> Result; /// A convenience function wrapping `try_execute`. /// @@ -883,7 +896,7 @@ pub trait PostgresStatement { /// Fails if the number or types of the provided parameters do not match /// the parameters of the statement. fn try_query<'a>(&'a self, params: &[&ToSql]) - -> Result, PostgresDbError>; + -> Result, PostgresError>; /// A convenience function wrapping `try_query`. /// @@ -910,26 +923,25 @@ pub struct NormalPostgresStatement<'conn> { #[unsafe_destructor] impl<'conn> Drop for NormalPostgresStatement<'conn> { fn drop(&mut self) { - io_error::cond.trap(|_| {}).inside(|| { - self.conn.write_messages([ - Close { - variant: 'S' as u8, - name: self.name.as_slice() - }, - Sync]); - loop { - match self.conn.read_message() { - ReadyForQuery { .. } => break, - _ => {} - } + let _ = self.conn.write_messages([ + Close { + variant: 'S' as u8, + name: self.name.as_slice() + }, + Sync]); + loop { + match self.conn.read_message() { + Ok(ReadyForQuery { .. }) => break, + Err(_) => break, + _ => {} } - }) + } } } impl<'conn> NormalPostgresStatement<'conn> { fn execute(&self, portal_name: &str, row_limit: uint, params: &[&ToSql]) - -> Option { + -> Result<(), PostgresError> { let mut formats = ~[]; let mut values = ~[]; assert!(self.param_types.len() == params.len(), @@ -959,25 +971,22 @@ impl<'conn> NormalPostgresStatement<'conn> { }, Sync]); - match self.conn.read_message() { - BindComplete => None, + match if_ok_pg!(self.conn.read_message()) { + BindComplete => Ok(()), ErrorResponse { fields } => { - self.conn.wait_for_ready(); - Some(PostgresDbError::new(fields)) + if_ok!(self.conn.wait_for_ready()); + Err(PgDbError(PostgresDbError::new(fields))) } _ => unreachable!() } } fn try_lazy_query<'a>(&'a self, row_limit: uint, params: &[&ToSql]) - -> Result, PostgresDbError> { + -> Result, PostgresError> { let id = self.next_portal_id.with_mut(|x| { *x += 1; *x - 1 }); let portal_name = format!("{}_portal_{}", self.name, id); - match self.execute(portal_name, row_limit, params) { - Some(err) => return Err(err), - None => {} - } + if_ok!(self.execute(portal_name, row_limit, params)); let mut result = PostgresResult { stmt: self, @@ -1002,19 +1011,16 @@ impl<'conn> PostgresStatement for NormalPostgresStatement<'conn> { } fn try_execute(&self, params: &[&ToSql]) - -> Result { - match self.execute("", 0, params) { - Some(err) => return Err(err), - None => {} - } + -> Result { + if_ok!(self.execute("", 0, params)); let num; loop { - match self.conn.read_message() { + match if_ok_pg!(self.conn.read_message()) { DataRow { .. } => {} ErrorResponse { fields } => { self.conn.wait_for_ready(); - return Err(PostgresDbError::new(fields)); + return Err(PgDbError(PostgresDbError::new(fields))); } CommandComplete { tag } => { let s = tag.split(' ').last().unwrap(); @@ -1031,13 +1037,13 @@ impl<'conn> PostgresStatement for NormalPostgresStatement<'conn> { _ => unreachable!() } } - self.conn.wait_for_ready(); + if_ok!(self.conn.wait_for_ready()); Ok(num) } fn try_query<'a>(&'a self, params: &[&ToSql]) - -> Result, PostgresDbError> { + -> Result, PostgresError> { self.try_lazy_query(0, params) } } @@ -1079,12 +1085,12 @@ impl<'conn> PostgresStatement for TransactionalPostgresStatement<'conn> { self.stmt.result_descriptions() } - fn try_execute(&self, params: &[&ToSql]) -> Result { + fn try_execute(&self, params: &[&ToSql]) -> Result { self.stmt.try_execute(params) } fn try_query<'a>(&'a self, params: &[&ToSql]) - -> Result, PostgresDbError> { + -> Result, PostgresError> { self.stmt.try_query(params) } } @@ -1102,7 +1108,7 @@ impl<'conn> TransactionalPostgresStatement<'conn> { /// Fails if the number or types of the provided parameters do not match /// the parameters of the statement. pub fn try_lazy_query<'a>(&'a self, row_limit: uint, params: &[&ToSql]) - -> Result, PostgresDbError> { + -> Result, PostgresError> { self.stmt.try_lazy_query(row_limit, params) } @@ -1132,27 +1138,25 @@ pub struct PostgresResult<'stmt> { #[unsafe_destructor] impl<'stmt> Drop for PostgresResult<'stmt> { fn drop(&mut self) { - io_error::cond.trap(|_| {}).inside(|| { - self.stmt.conn.write_messages([ - Close { - variant: 'P' as u8, - name: self.name.as_slice() - }, - Sync]); - loop { - match self.stmt.conn.read_message() { - ReadyForQuery { .. } => break, - _ => {} - } + let _ = self.stmt.conn.write_messages([ + Close { + variant: 'P' as u8, + name: self.name.as_slice() + }, + Sync]); + loop { + match self.stmt.conn.read_message() { + Ok(ReadyForQuery { .. }) => break, + _ => {} } - }) + } } } impl<'stmt> PostgresResult<'stmt> { - fn read_rows(&mut self) { + fn read_rows(&mut self) -> Result<(), PostgresError> { loop { - match self.stmt.conn.read_message() { + match if_ok_pg!(self.stmt.conn.read_message()) { EmptyQueryResponse | CommandComplete { .. } => { self.more_rows = false; @@ -1166,7 +1170,7 @@ impl<'stmt> PostgresResult<'stmt> { _ => unreachable!() } } - self.stmt.conn.wait_for_ready(); + self.stmt.conn.wait_for_ready() } fn execute(&mut self) { diff --git a/src/message.rs b/src/message.rs index be923963..f60ddeda 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,5 +1,5 @@ use std::str; -use std::io::{MemWriter, MemReader}; +use std::io::{IoResult, MemWriter, MemReader}; use std::mem; use std::vec; @@ -120,23 +120,23 @@ pub enum FrontendMessage<'a> { #[doc(hidden)] trait WriteCStr { - fn write_cstr(&mut self, s: &str); + fn write_cstr(&mut self, s: &str) -> IoResult<()>; } impl WriteCStr for W { - fn write_cstr(&mut self, s: &str) { - self.write(s.as_bytes()); - self.write_u8(0); + fn write_cstr(&mut self, s: &str) -> IoResult<()> { + if_ok!(self.write(s.as_bytes())); + self.write_u8(0) } } #[doc(hidden)] pub trait WriteMessage { - fn write_message(&mut self, &FrontendMessage); + fn write_message(&mut self, &FrontendMessage) -> IoResult<()> ; } impl WriteMessage for W { - fn write_message(&mut self, message: &FrontendMessage) { + fn write_message(&mut self, message: &FrontendMessage) -> IoResult<()> { debug!("Writing message {:?}", message); let mut buf = MemWriter::new(); let mut ident = None; @@ -144,78 +144,78 @@ impl WriteMessage for W { match *message { Bind { portal, statement, formats, values, result_formats } => { ident = Some('B'); - buf.write_cstr(portal); - buf.write_cstr(statement); + if_ok!(buf.write_cstr(portal)); + if_ok!(buf.write_cstr(statement)); - buf.write_be_i16(formats.len() as i16); + if_ok!(buf.write_be_i16(formats.len() as i16)); for format in formats.iter() { - buf.write_be_i16(*format); + if_ok!(buf.write_be_i16(*format)); } - buf.write_be_i16(values.len() as i16); + if_ok!(buf.write_be_i16(values.len() as i16)); for value in values.iter() { match *value { None => { - buf.write_be_i32(-1); + if_ok!(buf.write_be_i32(-1)); } Some(ref value) => { - buf.write_be_i32(value.len() as i32); - buf.write(*value); + if_ok!(buf.write_be_i32(value.len() as i32)); + if_ok!(buf.write(*value)); } } } - buf.write_be_i16(result_formats.len() as i16); + if_ok!(buf.write_be_i16(result_formats.len() as i16)); for format in result_formats.iter() { - buf.write_be_i16(*format); + if_ok!(buf.write_be_i16(*format)); } } CancelRequest { code, process_id, secret_key } => { - buf.write_be_i32(code); - buf.write_be_i32(process_id); - buf.write_be_i32(secret_key); + if_ok!(buf.write_be_i32(code)); + if_ok!(buf.write_be_i32(process_id)); + if_ok!(buf.write_be_i32(secret_key)); } Close { variant, name } => { ident = Some('C'); - buf.write_u8(variant); - buf.write_cstr(name); + if_ok!(buf.write_u8(variant)); + if_ok!(buf.write_cstr(name)); } Describe { variant, name } => { ident = Some('D'); - buf.write_u8(variant); - buf.write_cstr(name); + if_ok!(buf.write_u8(variant)); + if_ok!(buf.write_cstr(name)); } Execute { portal, max_rows } => { ident = Some('E'); - buf.write_cstr(portal); - buf.write_be_i32(max_rows); + if_ok!(buf.write_cstr(portal)); + if_ok!(buf.write_be_i32(max_rows)); } Parse { name, query, param_types } => { ident = Some('P'); - buf.write_cstr(name); - buf.write_cstr(query); - buf.write_be_i16(param_types.len() as i16); + if_ok!(buf.write_cstr(name)); + if_ok!(buf.write_cstr(query)); + if_ok!(buf.write_be_i16(param_types.len() as i16)); for ty in param_types.iter() { - buf.write_be_i32(*ty); + if_ok!(buf.write_be_i32(*ty)); } } PasswordMessage { password } => { ident = Some('p'); - buf.write_cstr(password); + if_ok!(buf.write_cstr(password)); } Query { query } => { ident = Some('Q'); - buf.write_cstr(query); + if_ok!(buf.write_cstr(query)); } StartupMessage { version, parameters } => { - buf.write_be_i32(version); + if_ok!(buf.write_be_i32(version)); for &(ref k, ref v) in parameters.iter() { - buf.write_cstr(k.as_slice()); - buf.write_cstr(v.as_slice()); + if_ok!(buf.write_cstr(k.as_slice())); + if_ok!(buf.write_cstr(v.as_slice())); } - buf.write_u8(0); + if_ok!(buf.write_u8(0)); } - SslRequest { code } => buf.write_be_i32(code), + SslRequest { code } => if_ok!(buf.write_be_i32(code)), Sync => { ident = Some('S'); } @@ -225,148 +225,150 @@ impl WriteMessage for W { } match ident { - Some(ident) => self.write_u8(ident as u8), + Some(ident) => if_ok!(self.write_u8(ident as u8)), None => () } let buf = buf.unwrap(); // add size of length value - self.write_be_i32((buf.len() + mem::size_of::()) as i32); - self.write(buf); + if_ok!(self.write_be_i32((buf.len() + mem::size_of::()) as i32)); + if_ok!(self.write(buf)); + + Ok(()) } } #[doc(hidden)] trait ReadCStr { - fn read_cstr(&mut self) -> ~str; + fn read_cstr(&mut self) -> IoResult<~str>; } impl ReadCStr for R { - fn read_cstr(&mut self) -> ~str { - let mut buf = self.read_until(0).unwrap(); + fn read_cstr(&mut self) -> IoResult<~str> { + let mut buf = if_ok!(self.read_until(0)); buf.pop(); - str::from_utf8_owned(buf).unwrap() + Ok(str::from_utf8_owned(buf).unwrap()) } } #[doc(hidden)] pub trait ReadMessage { - fn read_message(&mut self) -> BackendMessage; + fn read_message(&mut self) -> IoResult; } impl ReadMessage for R { - fn read_message(&mut self) -> BackendMessage { + fn read_message(&mut self) -> IoResult { debug!("Reading message"); - let ident = self.read_u8(); + let ident = if_ok!(self.read_u8()); // subtract size of length value - let len = self.read_be_i32() as uint - mem::size_of::(); - let mut buf = MemReader::new(self.read_bytes(len)); + let len = if_ok!(self.read_be_i32()) as uint - mem::size_of::(); + let mut buf = MemReader::new(if_ok!(self.read_bytes(len))); let ret = match ident as char { '1' => ParseComplete, '2' => BindComplete, '3' => CloseComplete, 'A' => NotificationResponse { - pid: buf.read_be_i32(), - channel: buf.read_cstr(), - payload: buf.read_cstr() + pid: if_ok!(buf.read_be_i32()), + channel: if_ok!(buf.read_cstr()), + payload: if_ok!(buf.read_cstr()) }, - 'C' => CommandComplete { tag: buf.read_cstr() }, - 'D' => read_data_row(&mut buf), - 'E' => ErrorResponse { fields: read_fields(&mut buf) }, + 'C' => CommandComplete { tag: if_ok!(buf.read_cstr()) }, + 'D' => if_ok!(read_data_row(&mut buf)), + 'E' => ErrorResponse { fields: if_ok!(read_fields(&mut buf)) }, 'I' => EmptyQueryResponse, 'K' => BackendKeyData { - process_id: buf.read_be_i32(), - secret_key: buf.read_be_i32() + process_id: if_ok!(buf.read_be_i32()), + secret_key: if_ok!(buf.read_be_i32()) }, 'n' => NoData, - 'N' => NoticeResponse { fields: read_fields(&mut buf) }, - 'R' => read_auth_message(&mut buf), + 'N' => NoticeResponse { fields: if_ok!(read_fields(&mut buf)) }, + 'R' => if_ok!(read_auth_message(&mut buf)), 's' => PortalSuspended, 'S' => ParameterStatus { - parameter: buf.read_cstr(), - value: buf.read_cstr() + parameter: if_ok!(buf.read_cstr()), + value: if_ok!(buf.read_cstr()) }, - 't' => read_parameter_description(&mut buf), - 'T' => read_row_description(&mut buf), - 'Z' => ReadyForQuery { state: buf.read_u8() }, + 't' => if_ok!(read_parameter_description(&mut buf)), + 'T' => if_ok!(read_row_description(&mut buf)), + 'Z' => ReadyForQuery { state: if_ok!(buf.read_u8()) }, ident => fail!("Unknown message identifier `{}`", ident) }; debug!("Read message {:?}", ret); - ret + Ok(ret) } } -fn read_fields(buf: &mut MemReader) -> ~[(u8, ~str)] { +fn read_fields(buf: &mut MemReader) -> IoResult<~[(u8, ~str)]> { let mut fields = ~[]; loop { - let ty = buf.read_u8(); + let ty = if_ok!(buf.read_u8()); if ty == 0 { break; } - fields.push((ty, buf.read_cstr())); + fields.push((ty, if_ok!(buf.read_cstr()))); } - fields + Ok(fields) } -fn read_data_row(buf: &mut MemReader) -> BackendMessage { - let len = buf.read_be_i16() as uint; +fn read_data_row(buf: &mut MemReader) -> IoResult { + let len = if_ok!(buf.read_be_i16()) as uint; let mut values = vec::with_capacity(len); for _ in range(0, len) { - let val = match buf.read_be_i32() { + let val = match if_ok!(buf.read_be_i32()) { -1 => None, - len => Some(buf.read_bytes(len as uint)) + len => Some(if_ok!(buf.read_bytes(len as uint))) }; values.push(val); } - DataRow { row: values } + Ok(DataRow { row: values }) } -fn read_auth_message(buf: &mut MemReader) -> BackendMessage { - match buf.read_be_i32() { +fn read_auth_message(buf: &mut MemReader) -> IoResult { + Ok(match if_ok!(buf.read_be_i32()) { 0 => AuthenticationOk, 2 => AuthenticationKerberosV5, 3 => AuthenticationCleartextPassword, - 5 => AuthenticationMD5Password { salt: buf.read_bytes(4) }, + 5 => AuthenticationMD5Password { salt: if_ok!(buf.read_bytes(4)) }, 6 => AuthenticationSCMCredential, 7 => AuthenticationGSS, 9 => AuthenticationSSPI, val => fail!("Invalid authentication identifier `{}`", val) - } + }) } -fn read_parameter_description(buf: &mut MemReader) -> BackendMessage { - let len = buf.read_be_i16() as uint; +fn read_parameter_description(buf: &mut MemReader) -> IoResult { + let len = if_ok!(buf.read_be_i16()) as uint; let mut types = vec::with_capacity(len); for _ in range(0, len) { - types.push(buf.read_be_u32()); + types.push(if_ok!(buf.read_be_u32())); } - ParameterDescription { types: types } + Ok(ParameterDescription { types: types }) } -fn read_row_description(buf: &mut MemReader) -> BackendMessage { - let len = buf.read_be_i16() as uint; +fn read_row_description(buf: &mut MemReader) -> IoResult { + let len = if_ok!(buf.read_be_i16()) as uint; let mut types = vec::with_capacity(len); for _ in range(0, len) { types.push(RowDescriptionEntry { - name: buf.read_cstr(), - table_oid: buf.read_be_u32(), - column_id: buf.read_be_i16(), - type_oid: buf.read_be_u32(), - type_size: buf.read_be_i16(), - type_modifier: buf.read_be_i32(), - format: buf.read_be_i16() + name: if_ok!(buf.read_cstr()), + table_oid: if_ok!(buf.read_be_u32()), + column_id: if_ok!(buf.read_be_i16()), + type_oid: if_ok!(buf.read_be_u32()), + type_size: if_ok!(buf.read_be_i16()), + type_modifier: if_ok!(buf.read_be_i32()), + format: if_ok!(buf.read_be_i16()) }); } - RowDescription { descriptions: types } + Ok(RowDescription { descriptions: types }) } diff --git a/src/pool.rs b/src/pool.rs index ed1b2f1e..7f1e7881 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -10,7 +10,7 @@ use super::{PostgresNotifications, NormalPostgresStatement, PostgresTransaction, SslMode}; -use super::error::{PostgresConnectError, PostgresDbError}; +use super::error::{PostgresConnectError, PostgresError}; use super::types::ToSql; struct InnerConnectionPool { @@ -121,7 +121,7 @@ impl Drop for PooledPostgresConnection { impl PooledPostgresConnection { /// Like `PostgresConnection::try_prepare`. pub fn try_prepare<'a>(&'a self, query: &str) - -> Result, PostgresDbError> { + -> Result, PostgresError> { self.conn.get_ref().try_prepare(query) } @@ -132,7 +132,7 @@ impl PooledPostgresConnection { /// Like `PostgresConnection::try_execute`. pub fn try_execute(&self, query: &str, params: &[&ToSql]) - -> Result { + -> Result { self.conn.get_ref().try_execute(query, params) } diff --git a/src/tests.rs b/src/tests.rs index f4ff33be..1af3d7b9 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -3,7 +3,7 @@ use extra::future::Future; use extra::time; use extra::time::Timespec; use extra::json; -use extra::uuid::Uuid; +use uuid::Uuid; use openssl::ssl::{SslContext, Sslv3}; use std::f32; use std::f64; @@ -18,7 +18,8 @@ use {PostgresNoticeHandler, RequireSsl, PreferSsl, NoSsl}; -use error::{DbError, +use error::{PgConnectDbError, + PgDbError, DnsError, MissingPassword, Position, @@ -74,7 +75,7 @@ fn test_url_terminating_slash() { fn test_prepare_err() { let conn = PostgresConnection::connect("postgres://postgres@localhost", &NoSsl); match conn.try_prepare("invalid sql statment") { - Err(PostgresDbError { code: SyntaxError, position: Some(Position(1)), .. }) => (), + Err(PgDbError(PostgresDbError { code: SyntaxError, position: Some(Position(1)), .. })) => (), resp => fail!("Unexpected result {:?}", resp) } } @@ -82,7 +83,7 @@ fn test_prepare_err() { #[test] fn test_unknown_database() { match PostgresConnection::try_connect("postgres://postgres@localhost/asdf", &NoSsl) { - Err(DbError(PostgresDbError { code: InvalidCatalogName, .. })) => {} + Err(PgConnectDbError(PostgresDbError { code: InvalidCatalogName, .. })) => {} resp => fail!("Unexpected result {:?}", resp) } } @@ -702,7 +703,7 @@ fn test_cancel_query() { }); match conn.try_execute("SELECT pg_sleep(10)", []) { - Err(PostgresDbError { code: QueryCanceled, .. }) => {} + Err(PgDbError(PostgresDbError { code: QueryCanceled, .. })) => {} res => fail!("Unexpected result {:?}", res) } } @@ -742,7 +743,7 @@ fn test_plaintext_pass_no_pass() { fn test_plaintext_pass_wrong_pass() { let ret = PostgresConnection::try_connect("postgres://pass_user:asdf@localhost/postgres", &NoSsl); match ret { - Err(DbError(PostgresDbError { code: InvalidPassword, .. })) => (), + Err(PgConnectDbError(PostgresDbError { code: InvalidPassword, .. })) => (), Err(err) => fail!("Unexpected error {}", err.to_str()), _ => fail!("Expected error") } @@ -767,7 +768,7 @@ fn test_md5_pass_no_pass() { fn test_md5_pass_wrong_pass() { let ret = PostgresConnection::try_connect("postgres://md5_user:asdf@localhost/postgres", &NoSsl); match ret { - Err(DbError(PostgresDbError { code: InvalidPassword, .. })) => (), + Err(PgConnectDbError(PostgresDbError { code: InvalidPassword, .. })) => (), Err(err) => fail!("Unexpected error {}", err.to_str()), _ => fail!("Expected error") } diff --git a/src/types/mod.rs b/src/types/mod.rs index e1c84bc4..2e6e8933 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -6,7 +6,7 @@ extern mod extra; use extra::time::Timespec; use extra::json; use extra::json::Json; -use extra::uuid::Uuid; +use uuid::Uuid; use std::hashmap::HashMap; use std::io::{MemWriter, BufReader}; use std::io::util::LimitReader; @@ -227,6 +227,15 @@ macro_rules! check_types( ) ) +macro_rules! or_fail( + ($e:expr) => ( + match $e { + Ok(ok) => ok, + Err(err) => fail!("{}", err) + } + ) +) + /// A trait for types that can be created from a Postgres value pub trait FromSql { /// Creates a new value of this type from a buffer of Postgres data. @@ -248,7 +257,7 @@ macro_rules! raw_from_impl( ($t:ty, $f:ident) => ( impl RawFromSql for $t { fn raw_from_sql(_len: uint, raw: &mut R) -> $t { - raw.$f() + or_fail!(raw.$f()) } } ) @@ -256,19 +265,19 @@ macro_rules! raw_from_impl( impl RawFromSql for bool { fn raw_from_sql(_len: uint, raw: &mut R) -> bool { - raw.read_u8() != 0 + (or_fail!(raw.read_u8())) != 0 } } impl RawFromSql for ~[u8] { fn raw_from_sql(len: uint, raw: &mut R) -> ~[u8] { - raw.read_bytes(len) + or_fail!(raw.read_bytes(len)) } } impl RawFromSql for ~str { fn raw_from_sql(len: uint, raw: &mut R) -> ~str { - str::from_utf8_owned(raw.read_bytes(len)).unwrap() + str::from_utf8_owned(or_fail!(raw.read_bytes(len))).unwrap() } } @@ -281,7 +290,7 @@ raw_from_impl!(f64, read_be_f64) impl RawFromSql for Timespec { fn raw_from_sql(_len: uint, raw: &mut R) -> Timespec { - let t = raw.read_be_i64(); + let t = or_fail!(raw.read_be_i64()); let mut sec = t / USEC_PER_SEC + TIME_SEC_CONVERSION; let mut usec = t % USEC_PER_SEC; @@ -296,7 +305,7 @@ impl RawFromSql for Timespec { impl RawFromSql for Uuid { fn raw_from_sql(len: uint, raw: &mut R) -> Uuid { - Uuid::from_bytes(raw.read_bytes(len)).unwrap() + Uuid::from_bytes(or_fail!(raw.read_bytes(len))).unwrap() } } @@ -304,7 +313,7 @@ macro_rules! from_range_impl( ($($oid:ident)|+, $t:ty) => ( impl RawFromSql for Range<$t> { fn raw_from_sql(_len: uint, rdr: &mut R) -> Range<$t> { - let t = rdr.read_i8(); + let t = or_fail!(rdr.read_i8()); if t & RANGE_EMPTY != 0 { Range::empty() @@ -315,7 +324,7 @@ macro_rules! from_range_impl( 0 => Exclusive, _ => Inclusive }; - let len = rdr.read_be_i32() as uint; + let len = or_fail!(rdr.read_be_i32()) as uint; Some(RangeBound::new( RawFromSql::raw_from_sql(len, rdr), type_)) } @@ -327,7 +336,7 @@ macro_rules! from_range_impl( 0 => Exclusive, _ => Inclusive }; - let len = rdr.read_be_i32() as uint; + let len = or_fail!(rdr.read_be_i32()) as uint; Some(RangeBound::new( RawFromSql::raw_from_sql(len, rdr), type_)) } @@ -402,22 +411,22 @@ macro_rules! from_array_impl( from_map_impl!($($oid)|+, ArrayBase>, |buf| { let mut rdr = BufReader::new(buf.as_slice()); - let ndim = rdr.read_be_i32() as uint; - let _has_null = rdr.read_be_i32() == 1; - let _element_type: Oid = rdr.read_be_u32(); + let ndim = or_fail!(rdr.read_be_i32()) as uint; + let _has_null = or_fail!(rdr.read_be_i32()) == 1; + let _element_type: Oid = or_fail!(rdr.read_be_u32()); let mut dim_info = vec::with_capacity(ndim); for _ in range(0, ndim) { dim_info.push(DimensionInfo { - len: rdr.read_be_i32() as uint, - lower_bound: rdr.read_be_i32() as int + len: or_fail!(rdr.read_be_i32()) as uint, + lower_bound: or_fail!(rdr.read_be_i32()) as int }); } let nele = dim_info.iter().fold(1, |acc, info| acc * info.len); let mut elements = vec::with_capacity(nele); for _ in range(0, nele) { - let len = rdr.read_be_i32(); + let len = or_fail!(rdr.read_be_i32()); if len < 0 { elements.push(None); } else { @@ -452,17 +461,17 @@ from_map_impl!(PgUnknownType { name: ~"hstore", .. }, let mut rdr = BufReader::new(buf.as_slice()); let mut map = HashMap::new(); - let count = rdr.read_be_i32(); + let count = or_fail!(rdr.read_be_i32()); for _ in range(0, count) { - let key_len = rdr.read_be_i32(); - let key = str::from_utf8_owned(rdr.read_bytes(key_len as uint)).unwrap(); + let key_len = or_fail!(rdr.read_be_i32()); + let key = str::from_utf8_owned(or_fail!(rdr.read_bytes(key_len as uint))).unwrap(); - let val_len = rdr.read_be_i32(); + let val_len = or_fail!(rdr.read_be_i32()); let val = if val_len < 0 { None } else { - Some(str::from_utf8_owned(rdr.read_bytes(val_len as uint)).unwrap()) + Some(str::from_utf8_owned(or_fail!(rdr.read_bytes(val_len as uint))).unwrap()) }; map.insert(key, val); @@ -492,7 +501,7 @@ macro_rules! raw_to_impl( ($t:ty, $f:ident) => ( impl RawToSql for $t { fn raw_to_sql(&self, w: &mut W) { - w.$f(*self) + or_fail!(w.$f(*self)) } } ) @@ -500,19 +509,19 @@ macro_rules! raw_to_impl( impl RawToSql for bool { fn raw_to_sql(&self, w: &mut W) { - w.write_u8(*self as u8) + or_fail!(w.write_u8(*self as u8)) } } impl RawToSql for ~[u8] { fn raw_to_sql(&self, w: &mut W) { - w.write(self.as_slice()) + or_fail!(w.write(self.as_slice())) } } impl RawToSql for ~str { fn raw_to_sql(&self, w: &mut W) { - w.write(self.as_bytes()) + or_fail!(w.write(self.as_bytes())) } } @@ -527,13 +536,13 @@ impl RawToSql for Timespec { fn raw_to_sql(&self, w: &mut W) { let t = (self.sec - TIME_SEC_CONVERSION) * USEC_PER_SEC + self.nsec as i64 / NSEC_PER_USEC; - w.write_be_i64(t); + or_fail!(w.write_be_i64(t)) } } impl RawToSql for Uuid { fn raw_to_sql(&self, w: &mut W) { - w.write(self.to_bytes()) + or_fail!(w.write(self.to_bytes())) } } @@ -559,15 +568,15 @@ macro_rules! to_range_impl( } } - buf.write_i8(tag); + or_fail!(buf.write_i8(tag)); match self.lower() { Some(bound) => { let mut inner_buf = MemWriter::new(); bound.value.raw_to_sql(&mut inner_buf); let inner_buf = inner_buf.unwrap(); - buf.write_be_i32(inner_buf.len() as i32); - buf.write(inner_buf); + or_fail!(buf.write_be_i32(inner_buf.len() as i32)); + or_fail!(buf.write(inner_buf)); } None => {} } @@ -576,8 +585,8 @@ macro_rules! to_range_impl( let mut inner_buf = MemWriter::new(); bound.value.raw_to_sql(&mut inner_buf); let inner_buf = inner_buf.unwrap(); - buf.write_be_i32(inner_buf.len() as i32); - buf.write(inner_buf); + or_fail!(buf.write_be_i32(inner_buf.len() as i32)); + or_fail!(buf.write(inner_buf)); } None => {} } @@ -592,7 +601,7 @@ to_range_impl!(PgTsRange | PgTstzRange, Timespec) impl RawToSql for Json { fn raw_to_sql(&self, raw: &mut W) { - self.to_writer(raw as &mut Writer) + or_fail!(self.to_writer(raw as &mut Writer)) } } @@ -684,13 +693,13 @@ macro_rules! to_array_impl( check_types!($($oid)|+, ty) let mut buf = MemWriter::new(); - buf.write_be_i32(self.dimension_info().len() as i32); - buf.write_be_i32(1); - buf.write_be_u32(ty.member_type().to_oid()); + or_fail!(buf.write_be_i32(self.dimension_info().len() as i32)); + or_fail!(buf.write_be_i32(1)); + or_fail!(buf.write_be_u32(ty.member_type().to_oid())); for info in self.dimension_info().iter() { - buf.write_be_i32(info.len as i32); - buf.write_be_i32(info.lower_bound as i32); + or_fail!(buf.write_be_i32(info.len as i32)); + or_fail!(buf.write_be_i32(info.lower_bound as i32)); } for v in self.values() { @@ -699,10 +708,10 @@ macro_rules! to_array_impl( let mut inner_buf = MemWriter::new(); val.raw_to_sql(&mut inner_buf); let inner_buf = inner_buf.unwrap(); - buf.write_be_i32(inner_buf.len() as i32); - buf.write(inner_buf); + or_fail!(buf.write_be_i32(inner_buf.len() as i32)); + or_fail!(buf.write(inner_buf)); } - None => buf.write_be_i32(-1) + None => or_fail!(buf.write_be_i32(-1)) } } @@ -735,18 +744,18 @@ impl<'a> ToSql for HashMap<~str, Option<~str>> { check_types!(PgUnknownType { name: ~"hstore", .. }, ty) let mut buf = MemWriter::new(); - buf.write_be_i32(self.len() as i32); + or_fail!(buf.write_be_i32(self.len() as i32)); for (key, val) in self.iter() { - buf.write_be_i32(key.len() as i32); - buf.write(key.as_bytes()); + or_fail!(buf.write_be_i32(key.len() as i32)); + or_fail!(buf.write(key.as_bytes())); match *val { Some(ref val) => { - buf.write_be_i32(val.len() as i32); - buf.write(val.as_bytes()); + or_fail!(buf.write_be_i32(val.len() as i32)); + or_fail!(buf.write(val.as_bytes())); } - None => buf.write_be_i32(-1) + None => or_fail!(buf.write_be_i32(-1)) } } diff --git a/submodules/rust-openssl b/submodules/rust-openssl index 89e79afa..1a5e625b 160000 --- a/submodules/rust-openssl +++ b/submodules/rust-openssl @@ -1 +1 @@ -Subproject commit 89e79afaf95f5e05ad6f3cd916e2412d184d354d +Subproject commit 1a5e625b4f21c9b4870ef30ab1da3c1fed919672