Switch over to protocol backend parser

We still have to keep our owned Backend enum around since SEME doesn't
work :'(
This commit is contained in:
Steven Fackler 2016-09-11 21:27:13 -07:00
parent 457d700639
commit 96943d7e10
6 changed files with 167 additions and 304 deletions

View File

@ -34,6 +34,7 @@ with-uuid = ["uuid"]
[dependencies]
bufstream = "0.1"
byteorder = "0.5"
fallible-iterator = "0.1"
hex = "0.2"
log = "0.3"
phf = "=0.7.15"

View File

@ -44,6 +44,7 @@
extern crate bufstream;
extern crate byteorder;
extern crate fallible_iterator;
extern crate hex;
#[macro_use]
extern crate log;
@ -64,7 +65,7 @@ use postgres_protocol::message::frontend;
use error::{Error, ConnectError, SqlState, DbError};
use io::TlsHandshake;
use message::{Backend, RowDescriptionEntry, ReadMessage};
use message::{Backend, RowDescriptionEntry};
use notification::{Notifications, Notification};
use params::{ConnectParams, IntoConnectParams, UserInfo};
use priv_io::MessageStream;
@ -288,8 +289,8 @@ impl InnerConnection {
loop {
match try!(conn.read_message()) {
Backend::BackendKeyData { process_id, secret_key } => {
conn.cancel_data.process_id = process_id as i32;
conn.cancel_data.secret_key = secret_key as i32;
conn.cancel_data.process_id = process_id;
conn.cancel_data.secret_key = secret_key;
}
Backend::ReadyForQuery { .. } => break,
Backend::ErrorResponse { fields } => return DbError::new_connect(fields),
@ -303,7 +304,7 @@ impl InnerConnection {
fn read_message_with_notification(&mut self) -> std_io::Result<Backend> {
debug_assert!(!self.desynchronized);
loop {
match try_desync!(self, ReadMessage::read_message(&mut self.stream)) {
match try_desync!(self, self.stream.read_message()) {
Backend::NoticeResponse { fields } => {
if let Ok(err) = DbError::new_raw(fields) {
self.notice_handler.handle_notice(err);
@ -357,9 +358,9 @@ impl InnerConnection {
fn read_message(&mut self) -> std_io::Result<Backend> {
loop {
match try!(self.read_message_with_notification()) {
Backend::NotificationResponse { pid, channel, payload } => {
Backend::NotificationResponse { process_id, channel, payload } => {
self.notifications.push_back(Notification {
pid: pid,
process_id: process_id,
channel: channel,
payload: payload,
})

View File

@ -1,11 +1,8 @@
use fallible_iterator::FallibleIterator;
use postgres_protocol::message::backend::Message;
use std::io;
use std::io::prelude::*;
use std::mem;
use std::time::Duration;
use byteorder::{BigEndian, ReadBytesExt};
use types::Oid;
use priv_io::StreamOptions;
pub enum Backend {
AuthenticationCleartextPassword,
@ -18,8 +15,8 @@ pub enum Backend {
AuthenticationSCMCredential,
AuthenticationSSPI,
BackendKeyData {
process_id: u32,
secret_key: u32,
process_id: i32,
secret_key: i32,
},
BindComplete,
CloseComplete,
@ -50,7 +47,7 @@ pub enum Backend {
fields: Vec<(u8, String)>,
},
NotificationResponse {
pid: u32,
process_id: i32,
channel: String,
payload: String,
},
@ -71,6 +68,107 @@ pub enum Backend {
},
}
impl Backend {
pub fn convert(message: Message) -> io::Result<Backend> {
let ret = match message {
Message::AuthenticationCleartextPassword => Backend::AuthenticationCleartextPassword,
Message::AuthenticationGss => Backend::AuthenticationGSS,
Message::AuthenticationKerberosV5 => Backend::AuthenticationKerberosV5,
Message::AuthenticationMd55Password(body) => {
Backend::AuthenticationMD5Password { salt: body.salt() }
}
Message::AuthenticationOk => Backend::AuthenticationOk,
Message::AuthenticationScmCredential => Backend::AuthenticationSCMCredential,
Message::AuthenticationSspi => Backend::AuthenticationSSPI,
Message::BackendKeyData(body) => {
Backend::BackendKeyData {
process_id: body.process_id(),
secret_key: body.secret_key(),
}
}
Message::BindComplete => Backend::BindComplete,
Message::CloseComplete => Backend::CloseComplete,
Message::CommandComplete(body) => {
Backend::CommandComplete {
tag: body.tag().to_owned()
}
}
Message::CopyData(body) => Backend::CopyData { data: body.data().to_owned() },
Message::CopyDone => Backend::CopyDone,
Message::CopyInResponse(body) => {
Backend::CopyInResponse {
format: body.format(),
column_formats: try!(body.column_formats().collect()),
}
}
Message::CopyOutResponse(body) => {
Backend::CopyOutResponse {
format: body.format(),
column_formats: try!(body.column_formats().collect()),
}
}
Message::DataRow(body) => {
Backend::DataRow {
row: try!(body.values().map(|r| r.map(|d| d.to_owned())).collect()),
}
}
Message::EmptyQueryResponse => Backend::EmptyQueryResponse,
Message::ErrorResponse(body) => {
Backend::ErrorResponse {
fields: try!(body.fields().map(|f| (f.type_(), f.value().to_owned())).collect()),
}
}
Message::NoData => Backend::NoData,
Message::NoticeResponse(body) => {
Backend::NoticeResponse {
fields: try!(body.fields().map(|f| (f.type_(), f.value().to_owned())).collect()),
}
}
Message::NotificationResponse(body) => {
Backend::NotificationResponse {
process_id: body.process_id(),
channel: body.channel().to_owned(),
payload: body.message().to_owned(),
}
}
Message::ParameterDescription(body) => {
Backend::ParameterDescription {
types: try!(body.parameters().collect()),
}
}
Message::ParameterStatus(body) => {
Backend::ParameterStatus {
parameter: body.name().to_owned(),
value: body.value().to_owned(),
}
}
Message::ParseComplete => Backend::ParseComplete,
Message::PortalSuspended => Backend::PortalSuspended,
Message::ReadyForQuery(body) => Backend::ReadyForQuery { _state: body.status() },
Message::RowDescription(body) => {
let fields = body.fields()
.map(|f| {
RowDescriptionEntry {
name: f.name().to_owned(),
table_oid: f.table_oid(),
column_id: f.column_id(),
type_oid: f.type_oid(),
type_size: f.type_size(),
type_modifier: f.type_modifier(),
format: f.format(),
}
});
Backend::RowDescription {
descriptions: try!(fields.collect()),
}
}
_ => return Err(io::Error::new(io::ErrorKind::InvalidInput, "unknown message type")),
};
Ok(ret)
}
}
pub struct RowDescriptionEntry {
pub name: String,
pub table_oid: Oid,
@ -80,254 +178,3 @@ pub struct RowDescriptionEntry {
pub type_modifier: i32,
pub format: i16,
}
#[doc(hidden)]
trait ReadCStr {
fn read_cstr(&mut self) -> io::Result<String>;
}
impl<R: BufRead> ReadCStr for R {
fn read_cstr(&mut self) -> io::Result<String> {
let mut buf = vec![];
try!(self.read_until(0, &mut buf));
buf.pop();
String::from_utf8(buf).map_err(|err| io::Error::new(io::ErrorKind::Other, err))
}
}
#[doc(hidden)]
pub trait ReadMessage {
fn read_message(&mut self) -> io::Result<Backend>;
fn read_message_timeout(&mut self, timeout: Duration) -> io::Result<Option<Backend>>;
fn read_message_nonblocking(&mut self) -> io::Result<Option<Backend>>;
fn finish_read_message(&mut self, ident: u8) -> io::Result<Backend>;
}
impl<R: BufRead + StreamOptions> ReadMessage for R {
fn read_message(&mut self) -> io::Result<Backend> {
let ident = try!(self.read_u8());
self.finish_read_message(ident)
}
fn read_message_timeout(&mut self, timeout: Duration) -> io::Result<Option<Backend>> {
try!(self.set_read_timeout(Some(timeout)));
let ident = self.read_u8();
try!(self.set_read_timeout(None));
match ident {
Ok(ident) => self.finish_read_message(ident).map(Some),
Err(e) => {
let e: io::Error = e.into();
if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::TimedOut {
Ok(None)
} else {
Err(e)
}
}
}
}
fn read_message_nonblocking(&mut self) -> io::Result<Option<Backend>> {
try!(self.set_nonblocking(true));
let ident = self.read_u8();
try!(self.set_nonblocking(false));
match ident {
Ok(ident) => self.finish_read_message(ident).map(Some),
Err(e) => {
let e: io::Error = e.into();
if e.kind() == io::ErrorKind::WouldBlock {
Ok(None)
} else {
Err(e)
}
}
}
}
#[allow(cyclomatic_complexity)]
fn finish_read_message(&mut self, ident: u8) -> io::Result<Backend> {
// subtract size of length value
let len = try!(self.read_u32::<BigEndian>()) - mem::size_of::<u32>() as u32;
let mut rdr = self.by_ref().take(len as u64);
let ret = match ident {
b'1' => Backend::ParseComplete,
b'2' => Backend::BindComplete,
b'3' => Backend::CloseComplete,
b'A' => {
Backend::NotificationResponse {
pid: try!(rdr.read_u32::<BigEndian>()),
channel: try!(rdr.read_cstr()),
payload: try!(rdr.read_cstr()),
}
}
b'c' => Backend::CopyDone,
b'C' => Backend::CommandComplete { tag: try!(rdr.read_cstr()) },
b'd' => {
let mut data = vec![];
try!(rdr.read_to_end(&mut data));
Backend::CopyData { data: data }
}
b'D' => try!(read_data_row(&mut rdr)),
b'E' => Backend::ErrorResponse { fields: try!(read_fields(&mut rdr)) },
b'G' => {
let format = try!(rdr.read_u8());
let mut column_formats = vec![];
for _ in 0..try!(rdr.read_u16::<BigEndian>()) {
column_formats.push(try!(rdr.read_u16::<BigEndian>()));
}
Backend::CopyInResponse {
format: format,
column_formats: column_formats,
}
}
b'H' => {
let format = try!(rdr.read_u8());
let mut column_formats = vec![];
for _ in 0..try!(rdr.read_u16::<BigEndian>()) {
column_formats.push(try!(rdr.read_u16::<BigEndian>()));
}
Backend::CopyOutResponse {
format: format,
column_formats: column_formats,
}
}
b'I' => Backend::EmptyQueryResponse,
b'K' => {
Backend::BackendKeyData {
process_id: try!(rdr.read_u32::<BigEndian>()),
secret_key: try!(rdr.read_u32::<BigEndian>()),
}
}
b'n' => Backend::NoData,
b'N' => Backend::NoticeResponse { fields: try!(read_fields(&mut rdr)) },
b'R' => try!(read_auth_message(&mut rdr)),
b's' => Backend::PortalSuspended,
b'S' => {
Backend::ParameterStatus {
parameter: try!(rdr.read_cstr()),
value: try!(rdr.read_cstr()),
}
}
b't' => try!(read_parameter_description(&mut rdr)),
b'T' => try!(read_row_description(&mut rdr)),
b'Z' => Backend::ReadyForQuery { _state: try!(rdr.read_u8()) },
t => {
return Err(io::Error::new(io::ErrorKind::Other,
format!("unexpected message tag `{}`", t)))
}
};
if rdr.limit() != 0 {
return Err(io::Error::new(io::ErrorKind::Other, "didn't read entire message"));
}
Ok(ret)
}
}
fn read_fields<R: BufRead>(buf: &mut R) -> io::Result<Vec<(u8, String)>> {
let mut fields = vec![];
loop {
let ty = try!(buf.read_u8());
if ty == 0 {
break;
}
fields.push((ty, try!(buf.read_cstr())));
}
Ok(fields)
}
fn read_data_row<R: BufRead>(buf: &mut R) -> io::Result<Backend> {
let len = try!(buf.read_u16::<BigEndian>()) as usize;
let mut values = Vec::with_capacity(len);
for _ in 0..len {
let val = match try!(buf.read_i32::<BigEndian>()) {
-1 => None,
len => {
let mut data = vec![0; len as usize];
try!(buf.read_exact(&mut data));
Some(data)
}
};
values.push(val);
}
Ok(Backend::DataRow { row: values })
}
fn read_auth_message<R: Read>(buf: &mut R) -> io::Result<Backend> {
Ok(match try!(buf.read_i32::<BigEndian>()) {
0 => Backend::AuthenticationOk,
2 => Backend::AuthenticationKerberosV5,
3 => Backend::AuthenticationCleartextPassword,
5 => {
let mut salt = [0; 4];
try!(buf.read_exact(&mut salt));
Backend::AuthenticationMD5Password { salt: salt }
}
6 => Backend::AuthenticationSCMCredential,
7 => Backend::AuthenticationGSS,
9 => Backend::AuthenticationSSPI,
t => {
return Err(io::Error::new(io::ErrorKind::Other,
format!("unexpected authentication tag `{}`", t)))
}
})
}
fn read_parameter_description<R: Read>(buf: &mut R) -> io::Result<Backend> {
let len = try!(buf.read_u16::<BigEndian>()) as usize;
let mut types = Vec::with_capacity(len);
for _ in 0..len {
types.push(try!(buf.read_u32::<BigEndian>()));
}
Ok(Backend::ParameterDescription { types: types })
}
fn read_row_description<R: BufRead>(buf: &mut R) -> io::Result<Backend> {
let len = try!(buf.read_u16::<BigEndian>()) as usize;
let mut types = Vec::with_capacity(len);
for _ in 0..len {
types.push(RowDescriptionEntry {
name: try!(buf.read_cstr()),
table_oid: try!(buf.read_u32::<BigEndian>()),
column_id: try!(buf.read_i16::<BigEndian>()),
type_oid: try!(buf.read_u32::<BigEndian>()),
type_size: try!(buf.read_i16::<BigEndian>()),
type_modifier: try!(buf.read_i32::<BigEndian>()),
format: try!(buf.read_i16::<BigEndian>()),
})
}
Ok(Backend::RowDescription { descriptions: types })
}
trait FromUsize: Sized {
fn from_usize(x: usize) -> io::Result<Self>;
}
macro_rules! from_usize {
($t:ty) => {
impl FromUsize for $t {
fn from_usize(x: usize) -> io::Result<$t> {
if x > <$t>::max_value() as usize {
Err(io::Error::new(io::ErrorKind::InvalidInput, "value too large to transmit"))
} else {
Ok(x as $t)
}
}
}
}
}
from_usize!(u16);
from_usize!(i32);

View File

@ -11,7 +11,7 @@ use error::Error;
#[derive(Clone, Debug)]
pub struct Notification {
/// The process ID of the notifying backend process.
pub pid: u32,
pub process_id: i32,
/// The name of the channel that the notify has been raised on.
pub channel: String,
/// The "payload" string passed from the notifying process.
@ -110,9 +110,9 @@ impl<'a> Iterator for Iter<'a> {
}
match conn.read_message_with_notification_nonblocking() {
Ok(Some(Backend::NotificationResponse { pid, channel, payload })) => {
Ok(Some(Backend::NotificationResponse { process_id, channel, payload })) => {
Some(Ok(Notification {
pid: pid,
process_id: process_id,
channel: channel,
payload: payload,
}))
@ -148,9 +148,9 @@ impl<'a> Iterator for BlockingIter<'a> {
}
match conn.read_message_with_notification() {
Ok(Backend::NotificationResponse { pid, channel, payload }) => {
Ok(Backend::NotificationResponse { process_id, channel, payload }) => {
Some(Ok(Notification {
pid: pid,
process_id: process_id,
channel: channel,
payload: payload,
}))
@ -187,9 +187,9 @@ impl<'a> Iterator for TimeoutIter<'a> {
}
match conn.read_message_with_notification_timeout(self.timeout) {
Ok(Some(Backend::NotificationResponse { pid, channel, payload })) => {
Ok(Some(Backend::NotificationResponse { process_id, channel, payload })) => {
Some(Ok(Notification {
pid: pid,
process_id: process_id,
channel: channel,
payload: payload,
}))

View File

@ -15,9 +15,10 @@ use postgres_protocol::message::frontend;
use postgres_protocol::message::backend::{self, ParseResult};
use TlsMode;
use params::{ConnectParams, ConnectTarget};
use error::ConnectError;
use io::TlsStream;
use message::Backend;
use params::{ConnectParams, ConnectTarget};
const DEFAULT_PORT: u16 = 5432;
const MESSAGE_HEADER_SIZE: usize = 5;
@ -45,9 +46,10 @@ impl MessageStream {
self.stream.write_all(&self.buf)
}
pub fn read_message<'a>(&'a mut self) -> io::Result<backend::Message<'a>> {
fn raw_read_message<'a>(&'a mut self, b: u8) -> io::Result<backend::Message<'a>> {
self.buf.resize(MESSAGE_HEADER_SIZE, 0);
try!(self.stream.read_exact(&mut self.buf));
self.buf[0] = b;
try!(self.stream.read_exact(&mut self.buf[1..]));
let len = match try!(backend::Message::parse(&self.buf)) {
// FIXME this is dumb but an explicit return runs into borrowck issues :(
@ -66,34 +68,46 @@ impl MessageStream {
}
}
fn inner_read_message(&mut self, b: u8) -> io::Result<Backend> {
let message = try!(self.raw_read_message(b));
Backend::convert(message)
}
pub fn read_message(&mut self) -> io::Result<Backend> {
let b = try!(self.stream.read_u8());
self.inner_read_message(b)
}
pub fn read_message_timeout(&mut self, timeout: Duration) -> io::Result<Option<Backend>> {
try!(self.set_read_timeout(Some(timeout)));
let b = self.stream.read_u8();
try!(self.set_read_timeout(None));
match b {
Ok(b) => self.inner_read_message(b).map(Some),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock || e.kind() == io::ErrorKind::TimedOut => {
Ok(None)
}
Err(e) => Err(e),
}
}
pub fn read_message_nonblocking(&mut self) -> io::Result<Option<Backend>> {
try!(self.set_nonblocking(true));
let b = self.stream.read_u8();
try!(self.set_nonblocking(false));
match b {
Ok(b) => self.inner_read_message(b).map(Some),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(None),
Err(e) => Err(e),
}
}
pub fn flush(&mut self) -> io::Result<()> {
self.stream.flush()
}
}
impl io::Read for MessageStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.stream.read(buf)
}
}
impl io::BufRead for MessageStream {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
self.stream.fill_buf()
}
fn consume(&mut self, amt: usize) {
self.stream.consume(amt)
}
}
#[doc(hidden)]
pub trait StreamOptions {
fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>;
fn set_nonblocking(&self, nonblock: bool) -> io::Result<()>;
}
impl StreamOptions for MessageStream {
fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
match self.stream.get_ref().get_ref().0 {
InternalStream::Tcp(ref s) => s.set_read_timeout(timeout),

View File

@ -576,12 +576,12 @@ fn test_notification_iterator_some() {
or_panic!(conn.execute("NOTIFY test_notification_iterator_one_channel2, 'world'", &[]));
check_notification(Notification {
pid: 0,
process_id: 0,
channel: "test_notification_iterator_one_channel".to_string(),
payload: "hello".to_string()
}, it.next().unwrap().unwrap());
check_notification(Notification {
pid: 0,
process_id: 0,
channel: "test_notification_iterator_one_channel2".to_string(),
payload: "world".to_string()
}, it.next().unwrap().unwrap());
@ -589,7 +589,7 @@ fn test_notification_iterator_some() {
or_panic!(conn.execute("NOTIFY test_notification_iterator_one_channel, '!'", &[]));
check_notification(Notification {
pid: 0,
process_id: 0,
channel: "test_notification_iterator_one_channel".to_string(),
payload: "!".to_string()
}, it.next().unwrap().unwrap());
@ -609,7 +609,7 @@ fn test_notifications_next_block() {
let notifications = conn.notifications();
check_notification(Notification {
pid: 0,
process_id: 0,
channel: "test_notifications_next_block".to_string(),
payload: "foo".to_string()
}, or_panic!(notifications.blocking_iter().next().unwrap()));
@ -631,7 +631,7 @@ fn test_notification_next_timeout() {
let notifications = conn.notifications();
let mut it = notifications.timeout_iter(Duration::from_secs(1));
check_notification(Notification {
pid: 0,
process_id: 0,
channel: "test_notifications_next_timeout".to_string(),
payload: "foo".to_string()
}, or_panic!(it.next().unwrap()));