Use the bytes crate for backend message parsing (#253)

This commit is contained in:
Steven Fackler 2017-05-06 08:28:07 -07:00 committed by GitHub
parent 413d1db5cd
commit 6b008766bf
9 changed files with 275 additions and 324 deletions

View File

@ -9,6 +9,7 @@ documentation = "https://docs.rs/postgres-protocol/0.2.2/postgres_protocol"
readme = "../README.md"
[dependencies]
bytes = "0.4"
byteorder = "1.0"
fallible-iterator = "0.1"
md5 = "0.3"

View File

@ -11,6 +11,7 @@
//! set to `UTF8`. It will most likely not behave properly if that is not the case.
#![doc(html_root_url="https://docs.rs/postgres-protocol/0.2.2")]
#![warn(missing_docs)]
extern crate bytes;
extern crate byteorder;
extern crate fallible_iterator;
extern crate md5;

View File

@ -1,138 +1,118 @@
#![allow(missing_docs)]
use byteorder::{ReadBytesExt, BigEndian};
use memchr::memchr;
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use memchr::memchr;
use std::cmp;
use std::io::{self, Read};
use std::marker::PhantomData;
use std::ops::Deref;
use std::ops::Range;
use std::str;
use Oid;
/// An enum representing Postgres backend messages.
pub enum Message<T> {
pub enum Message {
AuthenticationCleartextPassword,
AuthenticationGss,
AuthenticationKerberosV5,
AuthenticationMd5Password(AuthenticationMd5PasswordBody<T>),
AuthenticationMd5Password(AuthenticationMd5PasswordBody),
AuthenticationOk,
AuthenticationScmCredential,
AuthenticationSspi,
BackendKeyData(BackendKeyDataBody<T>),
BackendKeyData(BackendKeyDataBody),
BindComplete,
CloseComplete,
CommandComplete(CommandCompleteBody<T>),
CopyData(CopyDataBody<T>),
CommandComplete(CommandCompleteBody),
CopyData(CopyDataBody),
CopyDone,
CopyInResponse(CopyInResponseBody<T>),
CopyOutResponse(CopyOutResponseBody<T>),
DataRow(DataRowBody<T>),
CopyInResponse(CopyInResponseBody),
CopyOutResponse(CopyOutResponseBody),
DataRow(DataRowBody),
EmptyQueryResponse,
ErrorResponse(ErrorResponseBody<T>),
ErrorResponse(ErrorResponseBody),
NoData,
NoticeResponse(NoticeResponseBody<T>),
NotificationResponse(NotificationResponseBody<T>),
ParameterDescription(ParameterDescriptionBody<T>),
ParameterStatus(ParameterStatusBody<T>),
NoticeResponse(NoticeResponseBody),
NotificationResponse(NotificationResponseBody),
ParameterDescription(ParameterDescriptionBody),
ParameterStatus(ParameterStatusBody),
ParseComplete,
PortalSuspended,
ReadyForQuery(ReadyForQueryBody<T>),
RowDescription(RowDescriptionBody<T>),
ReadyForQuery(ReadyForQueryBody),
RowDescription(RowDescriptionBody),
#[doc(hidden)]
__ForExtensibility,
}
impl<'a> Message<&'a [u8]> {
/// Attempts to parse a backend message from the buffer.
///
/// This method is unfortunately difficult to use due to deficiencies in the compiler's borrow
/// checker.
impl Message {
#[inline]
pub fn parse(buf: &'a [u8]) -> io::Result<ParseResult<&'a [u8]>> {
Message::parse_inner(buf)
}
}
impl Message<Vec<u8>> {
/// Attempts to parse a backend message from the buffer.
///
/// In contrast to `parse`, this method produces messages that do not reference the input,
/// buffer by copying any necessary portions internally.
#[inline]
pub fn parse_owned(buf: &[u8]) -> io::Result<ParseResult<Vec<u8>>> {
Message::parse_inner(buf)
}
}
impl<'a, T> Message<T>
where T: From<&'a [u8]>
{
#[inline]
fn parse_inner(buf: &'a [u8]) -> io::Result<ParseResult<T>> {
pub fn parse(buf: &mut BytesMut) -> io::Result<Option<Message>> {
if buf.len() < 5 {
return Ok(ParseResult::Incomplete { required_size: None });
let to_read = 5 - buf.len();
buf.reserve(to_read);
return Ok(None);
}
let mut r = buf;
let tag = r.read_u8().unwrap();
// add a byte for the tag
let len = r.read_u32::<BigEndian>().unwrap() as usize + 1;
let tag = buf[0];
let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
if buf.len() < len {
return Ok(ParseResult::Incomplete { required_size: Some(len) });
if len < 4 {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length"));
}
let mut buf = &buf[5..len];
let total_len = len as usize + 1;
if buf.len() < total_len {
let to_read = total_len - buf.len();
buf.reserve(to_read);
return Ok(None);
}
let mut buf = Buffer {
bytes: buf.split_to(total_len).freeze(),
idx: 5,
};
let message = match tag {
b'1' => Message::ParseComplete,
b'2' => Message::BindComplete,
b'3' => Message::CloseComplete,
b'A' => {
let process_id = try!(buf.read_i32::<BigEndian>());
let channel_end = try!(find_null(buf, 0));
let message_end = try!(find_null(buf, channel_end + 1));
let storage = buf[..message_end].into();
buf = &buf[message_end + 1..];
let channel = try!(buf.read_cstr());
let message = try!(buf.read_cstr());
Message::NotificationResponse(NotificationResponseBody {
storage: storage,
process_id: process_id,
channel_end: channel_end,
channel: channel,
message: message,
})
}
b'c' => Message::CopyDone,
b'C' => {
let tag_end = try!(find_null(buf, 0));
let storage = buf[..tag_end].into();
buf = &buf[tag_end + 1..];
let tag = try!(buf.read_cstr());
Message::CommandComplete(CommandCompleteBody {
storage: storage,
tag: tag,
})
}
b'd' => {
let storage = buf.into();
buf = &[];
let storage = buf.read_all();
Message::CopyData(CopyDataBody { storage: storage })
}
b'D' => {
let len = try!(buf.read_u16::<BigEndian>());
let storage = buf.into();
buf = &[];
let storage = buf.read_all();
Message::DataRow(DataRowBody {
storage: storage,
len: len,
})
}
b'E' => {
let storage = buf.into();
buf = &[];
let storage = buf.read_all();
Message::ErrorResponse(ErrorResponseBody { storage: storage })
}
b'G' => {
let format = try!(buf.read_u8());
let len = try!(buf.read_u16::<BigEndian>());
let storage = buf.into();
buf = &[];
let storage = buf.read_all();
Message::CopyInResponse(CopyInResponseBody {
format: format,
len: len,
@ -142,8 +122,7 @@ impl<'a, T> Message<T>
b'H' => {
let format = try!(buf.read_u8());
let len = try!(buf.read_u16::<BigEndian>());
let storage = buf.into();
buf = &[];
let storage = buf.read_all();
Message::CopyOutResponse(CopyOutResponseBody {
format: format,
len: len,
@ -157,13 +136,11 @@ impl<'a, T> Message<T>
Message::BackendKeyData(BackendKeyDataBody {
process_id: process_id,
secret_key: secret_key,
_p: PhantomData,
})
}
b'n' => Message::NoData,
b'N' => {
let storage = buf.into();
buf = &[];
let storage = buf.read_all();
Message::NoticeResponse(NoticeResponseBody {
storage: storage,
})
@ -178,7 +155,6 @@ impl<'a, T> Message<T>
try!(buf.read_exact(&mut salt));
Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody {
salt: salt,
_p: PhantomData,
})
}
6 => Message::AuthenticationScmCredential,
@ -192,19 +168,16 @@ impl<'a, T> Message<T>
}
b's' => Message::PortalSuspended,
b'S' => {
let name_end = try!(find_null(buf, 0));
let value_end = try!(find_null(buf, name_end + 1));
let storage = buf[0..value_end].into();
buf = &buf[value_end + 1..];
let name = try!(buf.read_cstr());
let value = try!(buf.read_cstr());
Message::ParameterStatus(ParameterStatusBody {
storage: storage,
name_end: name_end,
name: name,
value: value,
})
}
b't' => {
let len = try!(buf.read_u16::<BigEndian>());
let storage = buf.into();
buf = &[];
let storage = buf.read_all();
Message::ParameterDescription(ParameterDescriptionBody {
storage: storage,
len: len,
@ -212,8 +185,7 @@ impl<'a, T> Message<T>
}
b'T' => {
let len = try!(buf.read_u16::<BigEndian>());
let storage = buf.into();
buf = &[];
let storage = buf.read_all();
Message::RowDescription(RowDescriptionBody {
storage: storage,
len: len,
@ -223,7 +195,6 @@ impl<'a, T> Message<T>
let status = try!(buf.read_u8());
Message::ReadyForQuery(ReadyForQueryBody {
status: status,
_p: PhantomData,
})
}
tag => {
@ -236,54 +207,74 @@ impl<'a, T> Message<T>
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length"));
}
Ok(ParseResult::Complete {
message: message,
consumed: len,
})
Ok(Some(message))
}
}
/// The result of an attempted parse.
pub enum ParseResult<T> {
/// The message was successfully parsed.
Complete {
/// The message.
message: Message<T>,
/// The number of bytes of the input buffer consumed to parse this message.
consumed: usize,
},
/// The buffer did not contain a full message.
Incomplete {
/// The number of total bytes required to parse a message, if known.
///
/// This value is present if the input buffer contains at least 5 bytes.
required_size: Option<usize>,
struct Buffer {
bytes: Bytes,
idx: usize,
}
impl Buffer {
fn slice(&self) -> &[u8] {
&self.bytes[self.idx..]
}
fn is_empty(&self) -> bool {
self.slice().is_empty()
}
fn read_cstr(&mut self) -> io::Result<Bytes> {
match memchr(0, self.slice()) {
Some(pos) => {
let start = self.idx;
let end = start + pos;
let cstr = self.bytes.slice(start, end);
self.idx = end + 1;
Ok(cstr)
}
None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")),
}
}
fn read_all(&mut self) -> Bytes {
let buf = self.bytes.slice_from(self.idx);
self.idx = self.bytes.len();
buf
}
}
pub struct AuthenticationMd5PasswordBody<T> {
impl Read for Buffer {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let len = {
let slice = self.slice();
let len = cmp::min(slice.len(), buf.len());
buf[..len].copy_from_slice(&slice[..len]);
len
};
self.idx += len;
Ok(len)
}
}
pub struct AuthenticationMd5PasswordBody {
salt: [u8; 4],
_p: PhantomData<T>,
}
impl<T> AuthenticationMd5PasswordBody<T>
where T: Deref<Target = [u8]>
{
impl AuthenticationMd5PasswordBody {
#[inline]
pub fn salt(&self) -> [u8; 4] {
self.salt
}
}
pub struct BackendKeyDataBody<T> {
pub struct BackendKeyDataBody {
process_id: i32,
secret_key: i32,
_p: PhantomData<T>,
}
impl<T> BackendKeyDataBody<T>
where T: Deref<Target = [u8]>
{
impl BackendKeyDataBody {
#[inline]
pub fn process_id(&self) -> i32 {
self.process_id
@ -295,41 +286,35 @@ impl<T> BackendKeyDataBody<T>
}
}
pub struct CommandCompleteBody<T> {
storage: T,
pub struct CommandCompleteBody {
tag: Bytes,
}
impl<T> CommandCompleteBody<T>
where T: Deref<Target = [u8]>
{
impl CommandCompleteBody {
#[inline]
pub fn tag(&self) -> io::Result<&str> {
get_str(&self.storage)
get_str(&self.tag)
}
}
pub struct CopyDataBody<T> {
storage: T,
pub struct CopyDataBody {
storage: Bytes,
}
impl<T> CopyDataBody<T>
where T: Deref<Target = [u8]>
{
impl CopyDataBody {
#[inline]
pub fn data(&self) -> &[u8] {
&self.storage
}
}
pub struct CopyInResponseBody<T> {
storage: T,
pub struct CopyInResponseBody {
storage: Bytes,
len: u16,
format: u8,
}
impl<T> CopyInResponseBody<T>
where T: Deref<Target = [u8]>
{
impl CopyInResponseBody {
#[inline]
pub fn format(&self) -> u8 {
self.format
@ -374,15 +359,13 @@ impl<'a> FallibleIterator for ColumnFormats<'a> {
}
}
pub struct CopyOutResponseBody<T> {
storage: T,
pub struct CopyOutResponseBody {
storage: Bytes,
len: u16,
format: u8,
}
impl<T> CopyOutResponseBody<T>
where T: Deref<Target = [u8]>
{
impl CopyOutResponseBody {
#[inline]
pub fn format(&self) -> u8 {
self.format
@ -397,34 +380,39 @@ impl<T> CopyOutResponseBody<T>
}
}
pub struct DataRowBody<T> {
storage: T,
pub struct DataRowBody {
storage: Bytes,
len: u16,
}
impl<T> DataRowBody<T>
where T: Deref<Target = [u8]>
{
impl DataRowBody {
#[inline]
pub fn values<'a>(&'a self) -> DataRowValues<'a> {
DataRowValues {
pub fn ranges<'a>(&'a self) -> DataRowRanges<'a> {
DataRowRanges {
buf: &self.storage,
len: self.storage.len(),
remaining: self.len,
}
}
#[inline]
pub fn buffer(&self) -> &[u8] {
&self.storage
}
}
pub struct DataRowValues<'a> {
pub struct DataRowRanges<'a> {
buf: &'a [u8],
len: usize,
remaining: u16,
}
impl<'a> FallibleIterator for DataRowValues<'a> {
type Item = Option<&'a [u8]>;
impl<'a> FallibleIterator for DataRowRanges<'a> {
type Item = Option<Range<usize>>;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<Option<&'a [u8]>>> {
fn next(&mut self) -> io::Result<Option<Option<Range<usize>>>> {
if self.remaining == 0 {
if self.buf.is_empty() {
return Ok(None);
@ -442,9 +430,9 @@ impl<'a> FallibleIterator for DataRowValues<'a> {
if self.buf.len() < len {
return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF"));
}
let (head, tail) = self.buf.split_at(len);
self.buf = tail;
Ok(Some(Some(head)))
let base = self.len - self.buf.len();
self.buf = &self.buf[len as usize..];
Ok(Some(Some(base..base + len)))
}
}
@ -455,18 +443,14 @@ impl<'a> FallibleIterator for DataRowValues<'a> {
}
}
pub struct ErrorResponseBody<T> {
storage: T,
pub struct ErrorResponseBody {
storage: Bytes,
}
impl<T> ErrorResponseBody<T>
where T: Deref<Target = [u8]>
{
impl ErrorResponseBody {
#[inline]
pub fn fields<'a>(&'a self) -> ErrorFields<'a> {
ErrorFields {
buf: &self.storage
}
ErrorFields { buf: &self.storage }
}
}
@ -494,9 +478,9 @@ impl<'a> FallibleIterator for ErrorFields<'a> {
self.buf = &self.buf[value_end + 1..];
Ok(Some(ErrorField {
type_: type_,
value: value,
}))
type_: type_,
value: value,
}))
}
}
@ -517,30 +501,24 @@ impl<'a> ErrorField<'a> {
}
}
pub struct NoticeResponseBody<T> {
storage: T,
pub struct NoticeResponseBody {
storage: Bytes,
}
impl<T> NoticeResponseBody<T>
where T: Deref<Target = [u8]>
{
impl NoticeResponseBody {
#[inline]
pub fn fields<'a>(&'a self) -> ErrorFields<'a> {
ErrorFields {
buf: &self.storage
}
ErrorFields { buf: &self.storage }
}
}
pub struct NotificationResponseBody<T> {
storage: T,
pub struct NotificationResponseBody {
process_id: i32,
channel_end: usize,
channel: Bytes,
message: Bytes,
}
impl<T> NotificationResponseBody<T>
where T: Deref<Target = [u8]>
{
impl NotificationResponseBody {
#[inline]
pub fn process_id(&self) -> i32 {
self.process_id
@ -548,23 +526,21 @@ impl<T> NotificationResponseBody<T>
#[inline]
pub fn channel(&self) -> io::Result<&str> {
get_str(&self.storage[..self.channel_end])
get_str(&self.channel)
}
#[inline]
pub fn message(&self) -> io::Result<&str> {
get_str(&self.storage[self.channel_end + 1..])
get_str(&self.message)
}
}
pub struct ParameterDescriptionBody<T> {
storage: T,
pub struct ParameterDescriptionBody {
storage: Bytes,
len: u16,
}
impl<T> ParameterDescriptionBody<T>
where T: Deref<Target = [u8]>
{
impl ParameterDescriptionBody {
#[inline]
pub fn parameters<'a>(&'a self) -> Parameters<'a> {
Parameters {
@ -604,47 +580,40 @@ impl<'a> FallibleIterator for Parameters<'a> {
}
}
pub struct ParameterStatusBody<T> {
storage: T,
name_end: usize,
pub struct ParameterStatusBody {
name: Bytes,
value: Bytes,
}
impl<T> ParameterStatusBody<T>
where T: Deref<Target = [u8]>
{
impl ParameterStatusBody {
#[inline]
pub fn name(&self) -> io::Result<&str> {
get_str(&self.storage[..self.name_end])
get_str(&self.name)
}
#[inline]
pub fn value(&self) -> io::Result<&str> {
get_str(&self.storage[self.name_end + 1..])
get_str(&self.value)
}
}
pub struct ReadyForQueryBody<T> {
pub struct ReadyForQueryBody {
status: u8,
_p: PhantomData<T>,
}
impl<T> ReadyForQueryBody<T>
where T: Deref<Target = [u8]>
{
impl ReadyForQueryBody {
#[inline]
pub fn status(&self) -> u8 {
self.status
}
}
pub struct RowDescriptionBody<T> {
storage: T,
pub struct RowDescriptionBody {
storage: Bytes,
len: u16,
}
impl<T> RowDescriptionBody<T>
where T: Deref<Target = [u8]>
{
impl RowDescriptionBody {
#[inline]
pub fn fields<'a>(&'a self) -> Fields<'a> {
Fields {
@ -685,14 +654,14 @@ impl<'a> FallibleIterator for Fields<'a> {
let format = try!(self.buf.read_i16::<BigEndian>());
Ok(Some(Field {
name: name,
table_oid: table_oid,
column_id: column_id,
type_oid: type_oid,
type_size: type_size,
type_modifier: type_modifier,
format: format,
}))
name: name,
table_oid: table_oid,
column_id: column_id,
type_oid: type_oid,
type_size: type_size,
type_modifier: type_modifier,
format: format,
}))
}
}
@ -747,11 +716,11 @@ impl<'a> Field<'a> {
fn find_null(buf: &[u8], start: usize) -> io::Result<usize> {
match memchr(0, &buf[start..]) {
Some(pos) => Ok(pos + start),
None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF"))
None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")),
}
}
#[inline]
fn get_str(buf: &[u8]) -> io::Result<&str> {
str::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
}
}

View File

@ -1,5 +1,7 @@
use fallible_iterator::{FallibleIterator, FromFallibleIterator};
use fallible_iterator::{FallibleIterator};
use postgres_protocol::message::backend::DataRowBody;
use std::ascii::AsciiExt;
use std::io;
use std::ops::Range;
use stmt::Column;
@ -32,12 +34,13 @@ impl<'a> RowIndex for str {
// FIXME ASCII-only case insensitivity isn't really the right thing to
// do. Postgres itself uses a dubious wrapper around tolower and JDBC
// uses the US locale.
stmt.iter().position(|d| d.name().eq_ignore_ascii_case(self))
stmt.iter()
.position(|d| d.name().eq_ignore_ascii_case(self))
}
}
impl<'a, T: ?Sized> RowIndex for &'a T
where T: RowIndex
where T: RowIndex
{
#[inline]
fn idx(&self, columns: &[Column]) -> Option<usize> {
@ -47,43 +50,26 @@ where T: RowIndex
#[doc(hidden)]
pub struct RowData {
buf: Vec<u8>,
indices: Vec<Option<Range<usize>>>,
}
impl<'a> FromFallibleIterator<Option<&'a [u8]>> for RowData {
fn from_fallible_iterator<I>(mut it: I) -> Result<RowData, I::Error>
where I: FallibleIterator<Item = Option<&'a [u8]>>
{
let mut row = RowData {
buf: vec![],
indices: Vec::with_capacity(it.size_hint().0),
};
while let Some(cell) = it.next()? {
let index = match cell {
Some(cell) => {
let base = row.buf.len();
row.buf.extend_from_slice(cell);
Some(base..row.buf.len())
}
None => None,
};
row.indices.push(index);
}
Ok(row)
}
body: DataRowBody,
ranges: Vec<Option<Range<usize>>>,
}
impl RowData {
pub fn new(body: DataRowBody) -> io::Result<RowData> {
let ranges = body.ranges().collect()?;
Ok(RowData {
body: body,
ranges: ranges,
})
}
pub fn len(&self) -> usize {
self.indices.len()
self.ranges.len()
}
pub fn get(&self, index: usize) -> Option<&[u8]> {
match &self.indices[index] {
&Some(ref range) => Some(&self.buf[range.clone()]),
match &self.ranges[index] {
&Some(ref range) => Some(&self.body.buffer()[range.clone()]),
&None => None,
}
}

View File

@ -42,7 +42,7 @@ with-security-framework = ["security-framework"]
no-logging = []
[dependencies]
bufstream = "0.1"
bytes = "0.4"
fallible-iterator = "0.1.3"
log = "0.3"

View File

@ -70,7 +70,7 @@
#![warn(missing_docs)]
#![allow(unknown_lints, needless_lifetimes, doc_markdown)] // for clippy
extern crate bufstream;
extern crate bytes;
extern crate fallible_iterator;
#[cfg(not(feature = "no-logging"))]
#[macro_use]
@ -315,7 +315,7 @@ impl InnerConnection {
Ok(conn)
}
fn read_message_with_notification(&mut self) -> io::Result<backend::Message<Vec<u8>>> {
fn read_message_with_notification(&mut self) -> io::Result<backend::Message> {
debug_assert!(!self.desynchronized);
loop {
match try_desync!(self, self.stream.read_message()) {
@ -335,7 +335,7 @@ impl InnerConnection {
fn read_message_with_notification_timeout(&mut self,
timeout: Duration)
-> io::Result<Option<backend::Message<Vec<u8>>>> {
-> io::Result<Option<backend::Message>> {
debug_assert!(!self.desynchronized);
loop {
match try_desync!(self, self.stream.read_message_timeout(timeout)) {
@ -354,7 +354,7 @@ impl InnerConnection {
}
fn read_message_with_notification_nonblocking(&mut self)
-> io::Result<Option<backend::Message<Vec<u8>>>> {
-> io::Result<Option<backend::Message>> {
debug_assert!(!self.desynchronized);
loop {
match try_desync!(self, self.stream.read_message_nonblocking()) {
@ -372,7 +372,7 @@ impl InnerConnection {
}
}
fn read_message(&mut self) -> io::Result<backend::Message<Vec<u8>>> {
fn read_message(&mut self) -> io::Result<backend::Message> {
loop {
match self.read_message_with_notification()? {
backend::Message::NotificationResponse(body) => {
@ -495,7 +495,7 @@ impl InnerConnection {
more_rows = true;
break;
}
backend::Message::DataRow(body) => consumer(body.values().collect()?),
backend::Message::DataRow(body) => consumer(RowData::new(body)?),
backend::Message::ErrorResponse(body) => {
self.wait_for_ready()?;
return Err(err(&mut body.fields()));
@ -832,8 +832,8 @@ impl InnerConnection {
match self.read_message()? {
backend::Message::ReadyForQuery(_) => break,
backend::Message::DataRow(body) => {
let row = body.values()
.map(|v| v.map(|v| String::from_utf8_lossy(v).into_owned()))
let row = body.ranges()
.map(|r| r.map(|r| String::from_utf8_lossy(&body.buffer()[r]).into_owned()))
.collect()?;
result.push(row);
}

View File

@ -1,9 +1,8 @@
use std::io;
use std::io::prelude::*;
use std::io::{self, BufWriter, Read, Write};
use std::fmt;
use std::net::TcpStream;
use std::time::Duration;
use bufstream::BufStream;
use bytes::{BufMut, BytesMut};
#[cfg(unix)]
use std::os::unix::net::UnixStream;
#[cfg(unix)]
@ -11,25 +10,27 @@ use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
use postgres_protocol::message::frontend;
use postgres_protocol::message::backend::{self, ParseResult};
use postgres_protocol::message::backend;
use TlsMode;
use error::ConnectError;
use tls::TlsStream;
use params::{ConnectParams, Host};
const MESSAGE_HEADER_SIZE: usize = 5;
const INITIAL_CAPACITY: usize = 8 * 1024;
pub struct MessageStream {
stream: BufStream<Box<TlsStream>>,
buf: Vec<u8>,
stream: BufWriter<Box<TlsStream>>,
in_buf: BytesMut,
out_buf: Vec<u8>,
}
impl MessageStream {
pub fn new(stream: Box<TlsStream>) -> MessageStream {
MessageStream {
stream: BufStream::new(stream),
buf: vec![],
stream: BufWriter::new(stream),
in_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
out_buf: vec![],
}
}
@ -41,65 +42,66 @@ impl MessageStream {
where F: FnOnce(&mut Vec<u8>) -> Result<(), E>,
E: From<io::Error>
{
self.buf.clear();
f(&mut self.buf)?;
self.stream.write_all(&self.buf).map_err(From::from)
self.out_buf.clear();
f(&mut self.out_buf)?;
self.stream.write_all(&self.out_buf).map_err(From::from)
}
fn inner_read_message(&mut self, b: u8) -> io::Result<backend::Message<Vec<u8>>> {
self.buf.resize(MESSAGE_HEADER_SIZE, 0);
self.buf[0] = b;
self.stream.read_exact(&mut self.buf[1..])?;
let len = match backend::Message::parse_owned(&self.buf)? {
ParseResult::Complete { message, .. } => return Ok(message),
ParseResult::Incomplete { required_size } => Some(required_size.unwrap()),
};
if let Some(len) = len {
self.buf.resize(len, 0);
self.stream.read_exact(&mut self.buf[MESSAGE_HEADER_SIZE..])?;
};
match backend::Message::parse_owned(&self.buf)? {
ParseResult::Complete { message, .. } => Ok(message),
ParseResult::Incomplete { .. } => unreachable!(),
pub fn read_message(&mut self) -> io::Result<backend::Message> {
loop {
match backend::Message::parse(&mut self.in_buf) {
Ok(Some(message)) => return Ok(message),
Ok(None) => self.read_in()?,
Err(e) => return Err(e),
}
}
}
pub fn read_message(&mut self) -> io::Result<backend::Message<Vec<u8>>> {
let mut b = [0; 1];
self.stream.read_exact(&mut b)?;
self.inner_read_message(b[0])
fn read_in(&mut self) -> io::Result<()> {
self.in_buf.reserve(1);
match self.stream.get_mut().read(unsafe { self.in_buf.bytes_mut() }) {
Ok(0) => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")),
Ok(n) => {
unsafe { self.in_buf.advance_mut(n) };
Ok(())
}
Err(e) => Err(e),
}
}
pub fn read_message_timeout(&mut self,
timeout: Duration)
-> io::Result<Option<backend::Message<Vec<u8>>>> {
self.set_read_timeout(Some(timeout))?;
let mut b = [0; 1];
let r = self.stream.read_exact(&mut b);
self.set_read_timeout(None)?;
-> io::Result<Option<backend::Message>> {
if self.in_buf.is_empty() {
self.set_read_timeout(Some(timeout))?;
let r = self.read_in();
self.set_read_timeout(None)?;
match r {
Ok(()) => self.inner_read_message(b[0]).map(Some),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock ||
e.kind() == io::ErrorKind::TimedOut => Ok(None),
Err(e) => Err(e),
match r {
Ok(()) => {},
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock ||
e.kind() == io::ErrorKind::TimedOut => return Ok(None),
Err(e) => return Err(e),
}
}
self.read_message().map(Some)
}
pub fn read_message_nonblocking(&mut self) -> io::Result<Option<backend::Message<Vec<u8>>>> {
self.set_nonblocking(true)?;
let mut b = [0; 1];
let r = self.stream.read_exact(&mut b);
self.set_nonblocking(false)?;
pub fn read_message_nonblocking(&mut self) -> io::Result<Option<backend::Message>> {
if self.in_buf.is_empty() {
self.set_nonblocking(true)?;
let r = self.read_in();
self.set_nonblocking(false)?;
match r {
Ok(()) => self.inner_read_message(b[0]).map(Some),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(None),
Err(e) => Err(e),
match r {
Ok(()) => {},
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
Err(e) => return Err(e),
}
}
self.read_message().map(Some)
}
pub fn flush(&mut self) -> io::Result<()> {

View File

@ -174,7 +174,7 @@ struct InnerConnection {
}
impl InnerConnection {
fn read(self) -> IoFuture<(backend::Message<Vec<u8>>, InnerConnection)> {
fn read(self) -> IoFuture<(backend::Message, InnerConnection)> {
self.into_future()
.map_err(|e| e.0)
.and_then(|(m, mut s)| {
@ -209,10 +209,10 @@ impl InnerConnection {
}
impl Stream for InnerConnection {
type Item = backend::Message<Vec<u8>>;
type Item = backend::Message;
type Error = io::Error;
fn poll(&mut self) -> Poll<Option<backend::Message<Vec<u8>>>, io::Error> {
fn poll(&mut self) -> Poll<Option<backend::Message>, io::Error> {
loop {
match try_ready!(self.stream.poll()) {
Some(backend::Message::ParameterStatus(body)) => {
@ -439,7 +439,7 @@ impl Connection {
Ok((rows, Connection(s))).into_future().boxed()
}
backend::Message::DataRow(body) => {
match body.values().collect() {
match RowData::new(body) {
Ok(row) => {
rows.push(row);
Connection(s).simple_read_rows(rows)
@ -504,7 +504,7 @@ impl Connection {
.boxed()
}
fn ready_err<T>(self, body: ErrorResponseBody<Vec<u8>>) -> BoxFuture<T, Error>
fn ready_err<T>(self, body: ErrorResponseBody) -> BoxFuture<T, Error>
where T: 'static + Send
{
DbError::new(&mut body.fields())
@ -948,8 +948,7 @@ impl Connection {
let c = Connection(s);
match m {
backend::Message::DataRow(body) => {
Either::A(body.values()
.collect()
Either::A(RowData::new(body)
.map(|r| (Some(r), c))
.map_err(Error::Io)
.into_future())

View File

@ -2,7 +2,7 @@ use bytes::{BytesMut, BufMut};
use futures::{BoxFuture, Future, IntoFuture, Sink, Stream as FuturesStream, Poll};
use futures::future::Either;
use postgres_shared::params::Host;
use postgres_protocol::message::backend::{self, ParseResult};
use postgres_protocol::message::backend;
use postgres_protocol::message::frontend;
use std::io::{self, Read, Write};
use tokio_io::{AsyncRead, AsyncWrite};
@ -167,18 +167,11 @@ impl AsyncWrite for Stream {
pub struct PostgresCodec;
impl Decoder for PostgresCodec {
type Item = backend::Message<Vec<u8>>;
type Item = backend::Message;
type Error = io::Error;
// FIXME ideally we'd avoid re-copying the data
fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<Self::Item>> {
match backend::Message::parse_owned(buf.as_ref())? {
ParseResult::Complete { message, consumed } => {
buf.split_to(consumed);
Ok(Some(message))
}
ParseResult::Incomplete { .. } => Ok(None),
}
fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<backend::Message>> {
backend::Message::parse(buf)
}
}