diff --git a/Makefile b/Makefile index ba734a8f..a85e57d4 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ RUSTC ?= rustc -RUSTFLAGS += -L. --cfg debug +RUSTFLAGS += -L. --cfg debug -Z debug-info .PHONY: all all: postgres.dummy diff --git a/src/lib.rs b/src/lib.rs index 1de2fdd8..436b70b8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,11 @@ extern mod extra; +use extra::digest::Digest; +use extra::md5::Md5; +use extra::url::Url; use std::cell::Cell; -use std::hashmap::HashMap; use std::rt::io::net::ip::SocketAddr; use std::rt::io::net::tcp::TcpStream; -use extra::url::Url; use std::str; use message::*; @@ -24,29 +25,29 @@ impl Drop for PostgresConnection { impl PostgresConnection { pub fn connect(url: &str) -> PostgresConnection { - let parsed_url: Url = FromStr::from_str(url).unwrap(); + let url: Url = FromStr::from_str(url).unwrap(); - let socket_url = fmt!("%s:%s", parsed_url.host, - parsed_url.port.get_ref().as_slice()); + let socket_url = fmt!("%s:%s", url.host, + url.port.get_ref().as_slice()); let addr: SocketAddr = FromStr::from_str(socket_url).unwrap(); let conn = PostgresConnection { stream: Cell::new(TcpStream::connect(addr).unwrap()), next_stmt_id: Cell::new(0) }; - let mut args = HashMap::new(); - args.insert(&"user", parsed_url.user.get_ref().user.as_slice()); - conn.write_message(&StartupMessage(args)); - - match conn.read_message() { - AuthenticationOk => (), - resp => fail!("Bad response: %?", resp.to_str()) + let mut args = url.query.clone(); + args.push((~"user", url.user.get_ref().user.clone())); + if !url.path.is_empty() { + args.push((~"database", url.path.clone())); } + conn.write_message(&StartupMessage(args.as_slice())); + + conn.handle_auth(&url); loop { match conn.read_message() { ParameterStatus(param, value) => - info!("Param %s = %s", param, value), + info!("Parameter %s = %s", param, value), BackendKeyData(*) => (), ReadyForQuery(*) => break, resp => fail!("Bad response: %?", resp.to_str()) @@ -68,6 +69,31 @@ impl PostgresConnection { } } + fn handle_auth(&self, url: &Url) { + loop { + match self.read_message() { + AuthenticationOk => break, + AuthenticationCleartextPassword => { + let pass = url.user.get_ref().pass.get_ref().as_slice(); + self.write_message(&PasswordMessage(pass)); + } + AuthenticationMD5Password(nonce) => { + let input = url.user.get_ref().pass.get_ref().as_slice() + + url.user.get_ref().user.as_slice(); + let mut md5 = Md5::new(); + md5.input_str(input); + let output = md5.result_str(); + md5.reset(); + md5.input_str(output); + md5.input(nonce); + let output = "md5" + md5.result_str(); + self.write_message(&PasswordMessage(output.as_slice())); + } + resp => fail!("Bad response: %?", resp.to_str()) + } + } + } + pub fn prepare<'a>(&'a self, query: &str) -> PostgresStatement<'a> { let id = self.next_stmt_id.take(); let stmt_name = ifmt!("statement_{}", id); diff --git a/src/message.rs b/src/message.rs index 4617fd0f..3613d717 100644 --- a/src/message.rs +++ b/src/message.rs @@ -3,7 +3,6 @@ use std::rt::io::{Decorator, Reader, Writer}; use std::rt::io::extensions::{ReaderUtil, ReaderByteConversions, WriterByteConversions}; use std::rt::io::mem::{MemWriter, MemReader}; -use std::hashmap::HashMap; use std::sys; use std::vec; @@ -11,6 +10,8 @@ pub static PROTOCOL_VERSION: i32 = 0x0003_0000; #[deriving(ToStr)] pub enum BackendMessage { + AuthenticationCleartextPassword, + AuthenticationMD5Password(~[u8]), AuthenticationOk, BackendKeyData(i32, i32), BindComplete, @@ -18,9 +19,9 @@ pub enum BackendMessage { CommandComplete(~str), DataRow(~[Option<~[u8]>]), EmptyQueryResponse, - ErrorResponse(HashMap), + ErrorResponse(~[(u8, ~str)]), NoData, - NoticeResponse(HashMap), + NoticeResponse(~[(u8, ~str)]), ParameterDescription(~[i32]), ParameterStatus(~str, ~str), ParseComplete, @@ -48,8 +49,9 @@ pub enum FrontendMessage<'self> { Execute(&'self str, i32), /// name, query, parameter types Parse(&'self str, &'self str, &'self [i32]), + PasswordMessage(&'self str), Query(&'self str), - StartupMessage(HashMap<&'self str, &'self str>), + StartupMessage(&'self [(~str, ~str)]), Sync, Terminate } @@ -128,15 +130,19 @@ impl WriteMessage for W { buf.write_be_i32_(*ty); } } + PasswordMessage(password) => { + ident = Some('p'); + buf.write_string(password); + } Query(query) => { ident = Some('Q'); buf.write_string(query); } StartupMessage(ref params) => { buf.write_be_i32_(PROTOCOL_VERSION); - for (k, v) in params.iter() { - buf.write_string(*k); - buf.write_string(*v); + for &(ref k, ref v) in params.iter() { + buf.write_string(k.as_slice()); + buf.write_string(v.as_slice()); } buf.write_u8_(0); } @@ -199,11 +205,11 @@ impl ReadMessage for R { '3' => CloseComplete, 'C' => CommandComplete(buf.read_string()), 'D' => read_data_row(&mut buf), - 'E' => ErrorResponse(read_hash(&mut buf)), + 'E' => ErrorResponse(read_fields(&mut buf)), 'I' => EmptyQueryResponse, 'K' => BackendKeyData(buf.read_be_i32_(), buf.read_be_i32_()), 'n' => NoData, - 'N' => NoticeResponse(read_hash(&mut buf)), + 'N' => NoticeResponse(read_fields(&mut buf)), 'R' => read_auth_message(&mut buf), 'S' => ParameterStatus(buf.read_string(), buf.read_string()), 't' => read_parameter_description(&mut buf), @@ -217,15 +223,15 @@ impl ReadMessage for R { } } -fn read_hash(buf: &mut MemReader) -> HashMap { - let mut fields = HashMap::new(); +fn read_fields(buf: &mut MemReader) -> ~[(u8, ~str)] { + let mut fields = ~[]; loop { let ty = buf.read_u8_(); if ty == 0 { break; } - fields.insert(ty, buf.read_string()); + fields.push((ty, buf.read_string())); } fields @@ -249,6 +255,8 @@ fn read_data_row(buf: &mut MemReader) -> BackendMessage { fn read_auth_message(buf: &mut MemReader) -> BackendMessage { match buf.read_be_i32_() { 0 => AuthenticationOk, + 3 => AuthenticationCleartextPassword, + 5 => AuthenticationMD5Password(buf.read_bytes(4)), val => fail!("Unknown Authentication identifier `%?`", val) } } diff --git a/src/test.rs b/src/test.rs index 62a73630..2f882a7b 100644 --- a/src/test.rs +++ b/src/test.rs @@ -38,7 +38,7 @@ fn test_nulls() { conn.prepare("CREATE TABLE foo ( id BIGINT PRIMARY KEY, val VARCHAR - )").update([]); + )").update([]); conn.prepare("INSERT INTO foo (id, val) VALUES ($1, $2), ($3, $4)") .update([&1 as &ToSql, & &"foobar" as &ToSql, &2 as &ToSql, &None::<~str> as &ToSql]); @@ -51,3 +51,19 @@ fn test_nulls() { Err::<(), ()>(()) }; } + +#[test] +fn test_plaintext_pass() { + PostgresConnection::connect("postgres://pass_user:password@127.0.0.1:5432"); +} + +#[test] +#[should_fail] +fn test_plaintext_pass_no_pass() { + PostgresConnection::connect("postgres://pass_user@127.0.0.1:5432"); +} + +#[test] +fn test_md5_pass() { + PostgresConnection::connect("postgres://md5_user:password@127.0.0.1:5432"); +}