Add a MessageStream

This commit is contained in:
Steven Fackler 2016-09-11 17:19:06 -07:00
parent f135d22394
commit 82f708c5ff
4 changed files with 87 additions and 54 deletions

View File

@ -68,10 +68,11 @@ use io::{TlsStream, TlsHandshake};
use message::{Backend, RowDescriptionEntry, ReadMessage};
use notification::{Notifications, Notification};
use params::{ConnectParams, IntoConnectParams, UserInfo};
use priv_io::MessageStream;
use rows::{Rows, LazyRows};
use stmt::{Statement, Column};
use types::{IsNull, Kind, Type, SessionInfo, Oid, Other, WrongType, ToSql, FromSql, Field};
use transaction::{Transaction, IsolationLevel};
use types::{IsNull, Kind, Type, SessionInfo, Oid, Other, WrongType, ToSql, FromSql, Field};
#[macro_use]
mod macros;
@ -127,9 +128,9 @@ impl HandleNotice for LoggingNoticeHandler {
#[derive(Copy, Clone, Debug)]
pub struct CancelData {
/// The process ID of the session.
pub process_id: u32,
pub process_id: i32,
/// The secret key for the session.
pub secret_key: u32,
pub secret_key: i32,
}
/// Attempts to cancel an in-progress query.
@ -167,8 +168,8 @@ pub fn cancel_query<T>(params: T,
let mut socket = try!(priv_io::initialize_stream(&params, tls));
let message = frontend::CancelRequest {
process_id: data.process_id as i32,
secret_key: data.secret_key as i32,
process_id: data.process_id,
secret_key: data.secret_key,
};
let mut buf = vec![];
try!(frontend::Message::write(&message, &mut buf));
@ -208,8 +209,7 @@ struct StatementInfo {
}
struct InnerConnection {
stream: BufStream<Box<TlsStream>>,
io_buf: Vec<u8>,
stream: MessageStream,
notice_handler: Box<HandleNotice>,
notifications: VecDeque<Notification>,
cancel_data: CancelData,
@ -250,8 +250,7 @@ impl InnerConnection {
};
let mut conn = InnerConnection {
stream: BufStream::new(stream),
io_buf: vec![],
stream: MessageStream::new(stream),
next_stmt_id: 0,
notice_handler: Box::new(LoggingNoticeHandler),
notifications: VecDeque::new(),
@ -280,7 +279,7 @@ impl InnerConnection {
options.push(("database".to_owned(), database));
}
try!(conn.write_message(&frontend::StartupMessage {
try!(conn.stream.write_message(&frontend::StartupMessage {
parameters: &options,
}));
try!(conn.stream.flush());
@ -290,8 +289,8 @@ impl InnerConnection {
loop {
match try!(conn.read_message()) {
Backend::BackendKeyData { process_id, secret_key } => {
conn.cancel_data.process_id = process_id;
conn.cancel_data.secret_key = secret_key;
conn.cancel_data.process_id = process_id as i32;
conn.cancel_data.secret_key = secret_key as i32;
}
Backend::ReadyForQuery { .. } => break,
Backend::ErrorResponse { fields } => return DbError::new_connect(fields),
@ -302,16 +301,6 @@ impl InnerConnection {
Ok(conn)
}
fn write_message<M>(&mut self, message: &M) -> std_io::Result<()>
where M: frontend::Message
{
debug_assert!(!self.desynchronized);
self.io_buf.clear();
try!(message.write(&mut self.io_buf));
try_desync!(self, self.stream.write_all(&self.io_buf));
Ok(())
}
fn read_message_with_notification(&mut self) -> std_io::Result<Backend> {
debug_assert!(!self.desynchronized);
loop {
@ -388,7 +377,7 @@ impl InnerConnection {
let pass = try!(user.password.ok_or_else(|| {
ConnectError::ConnectParams("a password was requested but not provided".into())
}));
try!(self.write_message(&frontend::PasswordMessage { password: &pass }));
try!(self.stream.write_message(&frontend::PasswordMessage { password: &pass }));
try!(self.stream.flush());
}
Backend::AuthenticationMD5Password { salt } => {
@ -403,7 +392,7 @@ impl InnerConnection {
hasher.input(output.as_bytes());
hasher.input(&salt);
let output = format!("md5{}", hasher.result_str());
try!(self.write_message(&frontend::PasswordMessage { password: &output }));
try!(self.stream.write_message(&frontend::PasswordMessage { password: &output }));
try!(self.stream.flush());
}
Backend::AuthenticationKerberosV5 |
@ -431,16 +420,16 @@ impl InnerConnection {
fn raw_prepare(&mut self, stmt_name: &str, query: &str) -> Result<(Vec<Type>, Vec<Column>)> {
debug!("preparing query with name `{}`: {}", stmt_name, query);
try!(self.write_message(&frontend::Parse {
try!(self.stream.write_message(&frontend::Parse {
name: stmt_name,
query: query,
param_types: &[],
}));
try!(self.write_message(&frontend::Describe {
try!(self.stream.write_message(&frontend::Describe {
variant: b'S',
name: stmt_name,
}));
try!(self.write_message(&frontend::Sync));
try!(self.stream.write_message(&frontend::Sync));
try!(self.stream.flush());
match try!(self.read_message()) {
@ -496,10 +485,10 @@ impl InnerConnection {
return DbError::new(fields);
}
Backend::CopyInResponse { .. } => {
try!(self.write_message(&frontend::CopyFail {
try!(self.stream.write_message(&frontend::CopyFail {
message: "COPY queries cannot be directly executed",
}));
try!(self.write_message(&frontend::Sync));
try!(self.stream.write_message(&frontend::Sync));
try!(self.stream.flush());
}
Backend::CopyOutResponse { .. } => {
@ -545,18 +534,18 @@ impl InnerConnection {
}
}
try!(self.write_message(&frontend::Bind {
try!(self.stream.write_message(&frontend::Bind {
portal: portal_name,
statement: &stmt_name,
formats: &[1],
values: &values,
result_formats: &[1],
}));
try!(self.write_message(&frontend::Execute {
try!(self.stream.write_message(&frontend::Execute {
portal: portal_name,
max_rows: row_limit,
}));
try!(self.write_message(&frontend::Sync));
try!(self.stream.write_message(&frontend::Sync));
try!(self.stream.flush());
match try!(self.read_message()) {
@ -611,11 +600,11 @@ impl InnerConnection {
}
fn close_statement(&mut self, name: &str, type_: u8) -> Result<()> {
try!(self.write_message(&frontend::Close {
try!(self.stream.write_message(&frontend::Close {
variant: type_,
name: name,
}));
try!(self.write_message(&frontend::Sync));
try!(self.stream.write_message(&frontend::Sync));
try!(self.stream.flush());
let resp = match try!(self.read_message()) {
Backend::CloseComplete => Ok(()),
@ -815,7 +804,7 @@ impl InnerConnection {
fn quick_query(&mut self, query: &str) -> Result<Vec<Vec<Option<String>>>> {
check_desync!(self);
debug!("executing query: {}", query);
try!(self.write_message(&frontend::Query { query: query }));
try!(self.stream.write_message(&frontend::Query { query: query }));
try!(self.stream.flush());
let mut result = vec![];
@ -830,10 +819,10 @@ impl InnerConnection {
.collect());
}
Backend::CopyInResponse { .. } => {
try!(self.write_message(&frontend::CopyFail {
try!(self.stream.write_message(&frontend::CopyFail {
message: "COPY queries cannot be directly executed",
}));
try!(self.write_message(&frontend::Sync));
try!(self.stream.write_message(&frontend::Sync));
try!(self.stream.flush());
}
Backend::ErrorResponse { fields } => {
@ -848,7 +837,7 @@ impl InnerConnection {
fn finish_inner(&mut self) -> Result<()> {
check_desync!(self);
try!(self.write_message(&frontend::Terminate));
try!(self.stream.write_message(&frontend::Terminate));
try!(self.stream.flush());
Ok(())
}

View File

@ -20,15 +20,59 @@ use io::TlsStream;
const DEFAULT_PORT: u16 = 5432;
pub struct MessageStream {
stream: BufStream<Box<TlsStream>>,
buf: Vec<u8>,
}
impl MessageStream {
pub fn new(stream: Box<TlsStream>) -> MessageStream {
MessageStream {
stream: BufStream::new(stream),
buf: vec![],
}
}
pub fn get_ref(&self) -> &Box<TlsStream> {
self.stream.get_ref()
}
pub fn write_message(&mut self, message: &frontend::Message) -> io::Result<()> {
self.buf.clear();
try!(frontend::Message::write(message, &mut self.buf));
self.stream.write_all(&self.buf)
}
pub fn flush(&mut self) -> io::Result<()> {
self.stream.flush()
}
}
impl io::Read for MessageStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.stream.read(buf)
}
}
impl io::BufRead for MessageStream {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
self.stream.fill_buf()
}
fn consume(&mut self, amt: usize) {
self.stream.consume(amt)
}
}
#[doc(hidden)]
pub trait StreamOptions {
fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>;
fn set_nonblocking(&self, nonblock: bool) -> io::Result<()>;
}
impl StreamOptions for BufStream<Box<TlsStream>> {
impl StreamOptions for MessageStream {
fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
match self.get_ref().get_ref().0 {
match self.stream.get_ref().get_ref().0 {
InternalStream::Tcp(ref s) => s.set_read_timeout(timeout),
#[cfg(unix)]
InternalStream::Unix(ref s) => s.set_read_timeout(timeout),
@ -36,7 +80,7 @@ impl StreamOptions for BufStream<Box<TlsStream>> {
}
fn set_nonblocking(&self, nonblock: bool) -> io::Result<()> {
match self.get_ref().get_ref().0 {
match self.stream.get_ref().get_ref().0 {
InternalStream::Tcp(ref s) => s.set_nonblocking(nonblock),
#[cfg(unix)]
InternalStream::Unix(ref s) => s.set_nonblocking(nonblock),

View File

@ -351,11 +351,11 @@ impl<'trans, 'stmt> LazyRows<'trans, 'stmt> {
fn execute(&mut self) -> Result<()> {
let mut conn = self.stmt.conn().conn.borrow_mut();
try!(conn.write_message(&frontend::Execute {
try!(conn.stream.write_message(&frontend::Execute {
portal: &self.name,
max_rows: self.row_limit,
}));
try!(conn.write_message(&frontend::Sync));
try!(conn.stream.write_message(&frontend::Sync));
try!(conn.stream.flush());
conn.read_rows(&mut self.data).map(|more_rows| self.more_rows = more_rows)
}

View File

@ -147,10 +147,10 @@ impl<'conn> Statement<'conn> {
break;
}
Backend::CopyInResponse { .. } => {
try!(conn.write_message(&frontend::CopyFail {
try!(conn.stream.write_message(&frontend::CopyFail {
message: "COPY queries cannot be directly executed",
}));
try!(conn.write_message(&frontend::Sync));
try!(conn.stream.write_message(&frontend::Sync));
try!(conn.stream.flush());
}
Backend::CopyOutResponse { .. } => {
@ -297,12 +297,12 @@ impl<'conn> Statement<'conn> {
match fill_copy_buf(&mut buf, r, &info) {
Ok(0) => break,
Ok(len) => {
try!(info.conn.write_message(&frontend::CopyData { data: &buf[..len] }));
try!(info.conn.stream.write_message(&frontend::CopyData { data: &buf[..len] }));
}
Err(err) => {
try!(info.conn.write_message(&frontend::CopyFail { message: "" }));
try!(info.conn.write_message(&frontend::CopyDone));
try!(info.conn.write_message(&frontend::Sync));
try!(info.conn.stream.write_message(&frontend::CopyFail { message: "" }));
try!(info.conn.stream.write_message(&frontend::CopyDone));
try!(info.conn.stream.write_message(&frontend::Sync));
try!(info.conn.stream.flush());
match try!(info.conn.read_message()) {
Backend::ErrorResponse { .. } => {
@ -319,8 +319,8 @@ impl<'conn> Statement<'conn> {
}
}
try!(info.conn.write_message(&frontend::CopyDone));
try!(info.conn.write_message(&frontend::Sync));
try!(info.conn.stream.write_message(&frontend::CopyDone));
try!(info.conn.stream.write_message(&frontend::Sync));
try!(info.conn.stream.flush());
let num = match try!(info.conn.read_message()) {
@ -368,9 +368,9 @@ impl<'conn> Statement<'conn> {
let (format, column_formats) = match try!(conn.read_message()) {
Backend::CopyOutResponse { format, column_formats } => (format, column_formats),
Backend::CopyInResponse { .. } => {
try!(conn.write_message(&frontend::CopyFail { message: "" }));
try!(conn.write_message(&frontend::CopyDone));
try!(conn.write_message(&frontend::Sync));
try!(conn.stream.write_message(&frontend::CopyFail { message: "" }));
try!(conn.stream.write_message(&frontend::CopyDone));
try!(conn.stream.write_message(&frontend::Sync));
try!(conn.stream.flush());
match try!(conn.read_message()) {
Backend::ErrorResponse { .. } => {