Use the bytes crate for backend message parsing (#253)
This commit is contained in:
parent
413d1db5cd
commit
6b008766bf
@ -9,6 +9,7 @@ documentation = "https://docs.rs/postgres-protocol/0.2.2/postgres_protocol"
|
||||
readme = "../README.md"
|
||||
|
||||
[dependencies]
|
||||
bytes = "0.4"
|
||||
byteorder = "1.0"
|
||||
fallible-iterator = "0.1"
|
||||
md5 = "0.3"
|
||||
|
@ -11,6 +11,7 @@
|
||||
//! set to `UTF8`. It will most likely not behave properly if that is not the case.
|
||||
#![doc(html_root_url="https://docs.rs/postgres-protocol/0.2.2")]
|
||||
#![warn(missing_docs)]
|
||||
extern crate bytes;
|
||||
extern crate byteorder;
|
||||
extern crate fallible_iterator;
|
||||
extern crate md5;
|
||||
|
@ -1,138 +1,118 @@
|
||||
#![allow(missing_docs)]
|
||||
|
||||
use byteorder::{ReadBytesExt, BigEndian};
|
||||
use memchr::memchr;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use memchr::memchr;
|
||||
use std::cmp;
|
||||
use std::io::{self, Read};
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::Deref;
|
||||
use std::ops::Range;
|
||||
use std::str;
|
||||
|
||||
use Oid;
|
||||
|
||||
/// An enum representing Postgres backend messages.
|
||||
pub enum Message<T> {
|
||||
pub enum Message {
|
||||
AuthenticationCleartextPassword,
|
||||
AuthenticationGss,
|
||||
AuthenticationKerberosV5,
|
||||
AuthenticationMd5Password(AuthenticationMd5PasswordBody<T>),
|
||||
AuthenticationMd5Password(AuthenticationMd5PasswordBody),
|
||||
AuthenticationOk,
|
||||
AuthenticationScmCredential,
|
||||
AuthenticationSspi,
|
||||
BackendKeyData(BackendKeyDataBody<T>),
|
||||
BackendKeyData(BackendKeyDataBody),
|
||||
BindComplete,
|
||||
CloseComplete,
|
||||
CommandComplete(CommandCompleteBody<T>),
|
||||
CopyData(CopyDataBody<T>),
|
||||
CommandComplete(CommandCompleteBody),
|
||||
CopyData(CopyDataBody),
|
||||
CopyDone,
|
||||
CopyInResponse(CopyInResponseBody<T>),
|
||||
CopyOutResponse(CopyOutResponseBody<T>),
|
||||
DataRow(DataRowBody<T>),
|
||||
CopyInResponse(CopyInResponseBody),
|
||||
CopyOutResponse(CopyOutResponseBody),
|
||||
DataRow(DataRowBody),
|
||||
EmptyQueryResponse,
|
||||
ErrorResponse(ErrorResponseBody<T>),
|
||||
ErrorResponse(ErrorResponseBody),
|
||||
NoData,
|
||||
NoticeResponse(NoticeResponseBody<T>),
|
||||
NotificationResponse(NotificationResponseBody<T>),
|
||||
ParameterDescription(ParameterDescriptionBody<T>),
|
||||
ParameterStatus(ParameterStatusBody<T>),
|
||||
NoticeResponse(NoticeResponseBody),
|
||||
NotificationResponse(NotificationResponseBody),
|
||||
ParameterDescription(ParameterDescriptionBody),
|
||||
ParameterStatus(ParameterStatusBody),
|
||||
ParseComplete,
|
||||
PortalSuspended,
|
||||
ReadyForQuery(ReadyForQueryBody<T>),
|
||||
RowDescription(RowDescriptionBody<T>),
|
||||
ReadyForQuery(ReadyForQueryBody),
|
||||
RowDescription(RowDescriptionBody),
|
||||
#[doc(hidden)]
|
||||
__ForExtensibility,
|
||||
}
|
||||
|
||||
impl<'a> Message<&'a [u8]> {
|
||||
/// Attempts to parse a backend message from the buffer.
|
||||
///
|
||||
/// This method is unfortunately difficult to use due to deficiencies in the compiler's borrow
|
||||
/// checker.
|
||||
impl Message {
|
||||
#[inline]
|
||||
pub fn parse(buf: &'a [u8]) -> io::Result<ParseResult<&'a [u8]>> {
|
||||
Message::parse_inner(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl Message<Vec<u8>> {
|
||||
/// Attempts to parse a backend message from the buffer.
|
||||
///
|
||||
/// In contrast to `parse`, this method produces messages that do not reference the input,
|
||||
/// buffer by copying any necessary portions internally.
|
||||
#[inline]
|
||||
pub fn parse_owned(buf: &[u8]) -> io::Result<ParseResult<Vec<u8>>> {
|
||||
Message::parse_inner(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Message<T>
|
||||
where T: From<&'a [u8]>
|
||||
{
|
||||
#[inline]
|
||||
fn parse_inner(buf: &'a [u8]) -> io::Result<ParseResult<T>> {
|
||||
pub fn parse(buf: &mut BytesMut) -> io::Result<Option<Message>> {
|
||||
if buf.len() < 5 {
|
||||
return Ok(ParseResult::Incomplete { required_size: None });
|
||||
let to_read = 5 - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut r = buf;
|
||||
let tag = r.read_u8().unwrap();
|
||||
// add a byte for the tag
|
||||
let len = r.read_u32::<BigEndian>().unwrap() as usize + 1;
|
||||
let tag = buf[0];
|
||||
let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
|
||||
|
||||
if buf.len() < len {
|
||||
return Ok(ParseResult::Incomplete { required_size: Some(len) });
|
||||
if len < 4 {
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length"));
|
||||
}
|
||||
|
||||
let mut buf = &buf[5..len];
|
||||
let total_len = len as usize + 1;
|
||||
if buf.len() < total_len {
|
||||
let to_read = total_len - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut buf = Buffer {
|
||||
bytes: buf.split_to(total_len).freeze(),
|
||||
idx: 5,
|
||||
};
|
||||
|
||||
let message = match tag {
|
||||
b'1' => Message::ParseComplete,
|
||||
b'2' => Message::BindComplete,
|
||||
b'3' => Message::CloseComplete,
|
||||
b'A' => {
|
||||
let process_id = try!(buf.read_i32::<BigEndian>());
|
||||
let channel_end = try!(find_null(buf, 0));
|
||||
let message_end = try!(find_null(buf, channel_end + 1));
|
||||
let storage = buf[..message_end].into();
|
||||
buf = &buf[message_end + 1..];
|
||||
let channel = try!(buf.read_cstr());
|
||||
let message = try!(buf.read_cstr());
|
||||
Message::NotificationResponse(NotificationResponseBody {
|
||||
storage: storage,
|
||||
process_id: process_id,
|
||||
channel_end: channel_end,
|
||||
channel: channel,
|
||||
message: message,
|
||||
})
|
||||
}
|
||||
b'c' => Message::CopyDone,
|
||||
b'C' => {
|
||||
let tag_end = try!(find_null(buf, 0));
|
||||
let storage = buf[..tag_end].into();
|
||||
buf = &buf[tag_end + 1..];
|
||||
let tag = try!(buf.read_cstr());
|
||||
Message::CommandComplete(CommandCompleteBody {
|
||||
storage: storage,
|
||||
tag: tag,
|
||||
})
|
||||
}
|
||||
b'd' => {
|
||||
let storage = buf.into();
|
||||
buf = &[];
|
||||
let storage = buf.read_all();
|
||||
Message::CopyData(CopyDataBody { storage: storage })
|
||||
}
|
||||
b'D' => {
|
||||
let len = try!(buf.read_u16::<BigEndian>());
|
||||
let storage = buf.into();
|
||||
buf = &[];
|
||||
let storage = buf.read_all();
|
||||
Message::DataRow(DataRowBody {
|
||||
storage: storage,
|
||||
len: len,
|
||||
})
|
||||
}
|
||||
b'E' => {
|
||||
let storage = buf.into();
|
||||
buf = &[];
|
||||
let storage = buf.read_all();
|
||||
Message::ErrorResponse(ErrorResponseBody { storage: storage })
|
||||
}
|
||||
b'G' => {
|
||||
let format = try!(buf.read_u8());
|
||||
let len = try!(buf.read_u16::<BigEndian>());
|
||||
let storage = buf.into();
|
||||
buf = &[];
|
||||
let storage = buf.read_all();
|
||||
Message::CopyInResponse(CopyInResponseBody {
|
||||
format: format,
|
||||
len: len,
|
||||
@ -142,8 +122,7 @@ impl<'a, T> Message<T>
|
||||
b'H' => {
|
||||
let format = try!(buf.read_u8());
|
||||
let len = try!(buf.read_u16::<BigEndian>());
|
||||
let storage = buf.into();
|
||||
buf = &[];
|
||||
let storage = buf.read_all();
|
||||
Message::CopyOutResponse(CopyOutResponseBody {
|
||||
format: format,
|
||||
len: len,
|
||||
@ -157,13 +136,11 @@ impl<'a, T> Message<T>
|
||||
Message::BackendKeyData(BackendKeyDataBody {
|
||||
process_id: process_id,
|
||||
secret_key: secret_key,
|
||||
_p: PhantomData,
|
||||
})
|
||||
}
|
||||
b'n' => Message::NoData,
|
||||
b'N' => {
|
||||
let storage = buf.into();
|
||||
buf = &[];
|
||||
let storage = buf.read_all();
|
||||
Message::NoticeResponse(NoticeResponseBody {
|
||||
storage: storage,
|
||||
})
|
||||
@ -178,7 +155,6 @@ impl<'a, T> Message<T>
|
||||
try!(buf.read_exact(&mut salt));
|
||||
Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody {
|
||||
salt: salt,
|
||||
_p: PhantomData,
|
||||
})
|
||||
}
|
||||
6 => Message::AuthenticationScmCredential,
|
||||
@ -192,19 +168,16 @@ impl<'a, T> Message<T>
|
||||
}
|
||||
b's' => Message::PortalSuspended,
|
||||
b'S' => {
|
||||
let name_end = try!(find_null(buf, 0));
|
||||
let value_end = try!(find_null(buf, name_end + 1));
|
||||
let storage = buf[0..value_end].into();
|
||||
buf = &buf[value_end + 1..];
|
||||
let name = try!(buf.read_cstr());
|
||||
let value = try!(buf.read_cstr());
|
||||
Message::ParameterStatus(ParameterStatusBody {
|
||||
storage: storage,
|
||||
name_end: name_end,
|
||||
name: name,
|
||||
value: value,
|
||||
})
|
||||
}
|
||||
b't' => {
|
||||
let len = try!(buf.read_u16::<BigEndian>());
|
||||
let storage = buf.into();
|
||||
buf = &[];
|
||||
let storage = buf.read_all();
|
||||
Message::ParameterDescription(ParameterDescriptionBody {
|
||||
storage: storage,
|
||||
len: len,
|
||||
@ -212,8 +185,7 @@ impl<'a, T> Message<T>
|
||||
}
|
||||
b'T' => {
|
||||
let len = try!(buf.read_u16::<BigEndian>());
|
||||
let storage = buf.into();
|
||||
buf = &[];
|
||||
let storage = buf.read_all();
|
||||
Message::RowDescription(RowDescriptionBody {
|
||||
storage: storage,
|
||||
len: len,
|
||||
@ -223,7 +195,6 @@ impl<'a, T> Message<T>
|
||||
let status = try!(buf.read_u8());
|
||||
Message::ReadyForQuery(ReadyForQueryBody {
|
||||
status: status,
|
||||
_p: PhantomData,
|
||||
})
|
||||
}
|
||||
tag => {
|
||||
@ -236,54 +207,74 @@ impl<'a, T> Message<T>
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid message length"));
|
||||
}
|
||||
|
||||
Ok(ParseResult::Complete {
|
||||
message: message,
|
||||
consumed: len,
|
||||
})
|
||||
Ok(Some(message))
|
||||
}
|
||||
}
|
||||
|
||||
/// The result of an attempted parse.
|
||||
pub enum ParseResult<T> {
|
||||
/// The message was successfully parsed.
|
||||
Complete {
|
||||
/// The message.
|
||||
message: Message<T>,
|
||||
/// The number of bytes of the input buffer consumed to parse this message.
|
||||
consumed: usize,
|
||||
},
|
||||
/// The buffer did not contain a full message.
|
||||
Incomplete {
|
||||
/// The number of total bytes required to parse a message, if known.
|
||||
///
|
||||
/// This value is present if the input buffer contains at least 5 bytes.
|
||||
required_size: Option<usize>,
|
||||
struct Buffer {
|
||||
bytes: Bytes,
|
||||
idx: usize,
|
||||
}
|
||||
|
||||
impl Buffer {
|
||||
fn slice(&self) -> &[u8] {
|
||||
&self.bytes[self.idx..]
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.slice().is_empty()
|
||||
}
|
||||
|
||||
fn read_cstr(&mut self) -> io::Result<Bytes> {
|
||||
match memchr(0, self.slice()) {
|
||||
Some(pos) => {
|
||||
let start = self.idx;
|
||||
let end = start + pos;
|
||||
let cstr = self.bytes.slice(start, end);
|
||||
self.idx = end + 1;
|
||||
Ok(cstr)
|
||||
}
|
||||
None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")),
|
||||
}
|
||||
}
|
||||
|
||||
fn read_all(&mut self) -> Bytes {
|
||||
let buf = self.bytes.slice_from(self.idx);
|
||||
self.idx = self.bytes.len();
|
||||
buf
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AuthenticationMd5PasswordBody<T> {
|
||||
impl Read for Buffer {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
let len = {
|
||||
let slice = self.slice();
|
||||
let len = cmp::min(slice.len(), buf.len());
|
||||
buf[..len].copy_from_slice(&slice[..len]);
|
||||
len
|
||||
};
|
||||
self.idx += len;
|
||||
Ok(len)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AuthenticationMd5PasswordBody {
|
||||
salt: [u8; 4],
|
||||
_p: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> AuthenticationMd5PasswordBody<T>
|
||||
where T: Deref<Target = [u8]>
|
||||
{
|
||||
impl AuthenticationMd5PasswordBody {
|
||||
#[inline]
|
||||
pub fn salt(&self) -> [u8; 4] {
|
||||
self.salt
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BackendKeyDataBody<T> {
|
||||
pub struct BackendKeyDataBody {
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
_p: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> BackendKeyDataBody<T>
|
||||
where T: Deref<Target = [u8]>
|
||||
{
|
||||
impl BackendKeyDataBody {
|
||||
#[inline]
|
||||
pub fn process_id(&self) -> i32 {
|
||||
self.process_id
|
||||
@ -295,41 +286,35 @@ impl<T> BackendKeyDataBody<T>
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CommandCompleteBody<T> {
|
||||
storage: T,
|
||||
pub struct CommandCompleteBody {
|
||||
tag: Bytes,
|
||||
}
|
||||
|
||||
impl<T> CommandCompleteBody<T>
|
||||
where T: Deref<Target = [u8]>
|
||||
{
|
||||
impl CommandCompleteBody {
|
||||
#[inline]
|
||||
pub fn tag(&self) -> io::Result<&str> {
|
||||
get_str(&self.storage)
|
||||
get_str(&self.tag)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CopyDataBody<T> {
|
||||
storage: T,
|
||||
pub struct CopyDataBody {
|
||||
storage: Bytes,
|
||||
}
|
||||
|
||||
impl<T> CopyDataBody<T>
|
||||
where T: Deref<Target = [u8]>
|
||||
{
|
||||
impl CopyDataBody {
|
||||
#[inline]
|
||||
pub fn data(&self) -> &[u8] {
|
||||
&self.storage
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CopyInResponseBody<T> {
|
||||
storage: T,
|
||||
pub struct CopyInResponseBody {
|
||||
storage: Bytes,
|
||||
len: u16,
|
||||
format: u8,
|
||||
}
|
||||
|
||||
impl<T> CopyInResponseBody<T>
|
||||
where T: Deref<Target = [u8]>
|
||||
{
|
||||
impl CopyInResponseBody {
|
||||
#[inline]
|
||||
pub fn format(&self) -> u8 {
|
||||
self.format
|
||||
@ -374,15 +359,13 @@ impl<'a> FallibleIterator for ColumnFormats<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CopyOutResponseBody<T> {
|
||||
storage: T,
|
||||
pub struct CopyOutResponseBody {
|
||||
storage: Bytes,
|
||||
len: u16,
|
||||
format: u8,
|
||||
}
|
||||
|
||||
impl<T> CopyOutResponseBody<T>
|
||||
where T: Deref<Target = [u8]>
|
||||
{
|
||||
impl CopyOutResponseBody {
|
||||
#[inline]
|
||||
pub fn format(&self) -> u8 {
|
||||
self.format
|
||||
@ -397,34 +380,39 @@ impl<T> CopyOutResponseBody<T>
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DataRowBody<T> {
|
||||
storage: T,
|
||||
pub struct DataRowBody {
|
||||
storage: Bytes,
|
||||
len: u16,
|
||||
}
|
||||
|
||||
impl<T> DataRowBody<T>
|
||||
where T: Deref<Target = [u8]>
|
||||
{
|
||||
impl DataRowBody {
|
||||
#[inline]
|
||||
pub fn values<'a>(&'a self) -> DataRowValues<'a> {
|
||||
DataRowValues {
|
||||
pub fn ranges<'a>(&'a self) -> DataRowRanges<'a> {
|
||||
DataRowRanges {
|
||||
buf: &self.storage,
|
||||
len: self.storage.len(),
|
||||
remaining: self.len,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn buffer(&self) -> &[u8] {
|
||||
&self.storage
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DataRowValues<'a> {
|
||||
pub struct DataRowRanges<'a> {
|
||||
buf: &'a [u8],
|
||||
len: usize,
|
||||
remaining: u16,
|
||||
}
|
||||
|
||||
impl<'a> FallibleIterator for DataRowValues<'a> {
|
||||
type Item = Option<&'a [u8]>;
|
||||
impl<'a> FallibleIterator for DataRowRanges<'a> {
|
||||
type Item = Option<Range<usize>>;
|
||||
type Error = io::Error;
|
||||
|
||||
#[inline]
|
||||
fn next(&mut self) -> io::Result<Option<Option<&'a [u8]>>> {
|
||||
fn next(&mut self) -> io::Result<Option<Option<Range<usize>>>> {
|
||||
if self.remaining == 0 {
|
||||
if self.buf.is_empty() {
|
||||
return Ok(None);
|
||||
@ -442,9 +430,9 @@ impl<'a> FallibleIterator for DataRowValues<'a> {
|
||||
if self.buf.len() < len {
|
||||
return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF"));
|
||||
}
|
||||
let (head, tail) = self.buf.split_at(len);
|
||||
self.buf = tail;
|
||||
Ok(Some(Some(head)))
|
||||
let base = self.len - self.buf.len();
|
||||
self.buf = &self.buf[len as usize..];
|
||||
Ok(Some(Some(base..base + len)))
|
||||
}
|
||||
}
|
||||
|
||||
@ -455,18 +443,14 @@ impl<'a> FallibleIterator for DataRowValues<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ErrorResponseBody<T> {
|
||||
storage: T,
|
||||
pub struct ErrorResponseBody {
|
||||
storage: Bytes,
|
||||
}
|
||||
|
||||
impl<T> ErrorResponseBody<T>
|
||||
where T: Deref<Target = [u8]>
|
||||
{
|
||||
impl ErrorResponseBody {
|
||||
#[inline]
|
||||
pub fn fields<'a>(&'a self) -> ErrorFields<'a> {
|
||||
ErrorFields {
|
||||
buf: &self.storage
|
||||
}
|
||||
ErrorFields { buf: &self.storage }
|
||||
}
|
||||
}
|
||||
|
||||
@ -494,9 +478,9 @@ impl<'a> FallibleIterator for ErrorFields<'a> {
|
||||
self.buf = &self.buf[value_end + 1..];
|
||||
|
||||
Ok(Some(ErrorField {
|
||||
type_: type_,
|
||||
value: value,
|
||||
}))
|
||||
type_: type_,
|
||||
value: value,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@ -517,30 +501,24 @@ impl<'a> ErrorField<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NoticeResponseBody<T> {
|
||||
storage: T,
|
||||
pub struct NoticeResponseBody {
|
||||
storage: Bytes,
|
||||
}
|
||||
|
||||
impl<T> NoticeResponseBody<T>
|
||||
where T: Deref<Target = [u8]>
|
||||
{
|
||||
impl NoticeResponseBody {
|
||||
#[inline]
|
||||
pub fn fields<'a>(&'a self) -> ErrorFields<'a> {
|
||||
ErrorFields {
|
||||
buf: &self.storage
|
||||
}
|
||||
ErrorFields { buf: &self.storage }
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NotificationResponseBody<T> {
|
||||
storage: T,
|
||||
pub struct NotificationResponseBody {
|
||||
process_id: i32,
|
||||
channel_end: usize,
|
||||
channel: Bytes,
|
||||
message: Bytes,
|
||||
}
|
||||
|
||||
impl<T> NotificationResponseBody<T>
|
||||
where T: Deref<Target = [u8]>
|
||||
{
|
||||
impl NotificationResponseBody {
|
||||
#[inline]
|
||||
pub fn process_id(&self) -> i32 {
|
||||
self.process_id
|
||||
@ -548,23 +526,21 @@ impl<T> NotificationResponseBody<T>
|
||||
|
||||
#[inline]
|
||||
pub fn channel(&self) -> io::Result<&str> {
|
||||
get_str(&self.storage[..self.channel_end])
|
||||
get_str(&self.channel)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn message(&self) -> io::Result<&str> {
|
||||
get_str(&self.storage[self.channel_end + 1..])
|
||||
get_str(&self.message)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ParameterDescriptionBody<T> {
|
||||
storage: T,
|
||||
pub struct ParameterDescriptionBody {
|
||||
storage: Bytes,
|
||||
len: u16,
|
||||
}
|
||||
|
||||
impl<T> ParameterDescriptionBody<T>
|
||||
where T: Deref<Target = [u8]>
|
||||
{
|
||||
impl ParameterDescriptionBody {
|
||||
#[inline]
|
||||
pub fn parameters<'a>(&'a self) -> Parameters<'a> {
|
||||
Parameters {
|
||||
@ -604,47 +580,40 @@ impl<'a> FallibleIterator for Parameters<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ParameterStatusBody<T> {
|
||||
storage: T,
|
||||
name_end: usize,
|
||||
pub struct ParameterStatusBody {
|
||||
name: Bytes,
|
||||
value: Bytes,
|
||||
}
|
||||
|
||||
impl<T> ParameterStatusBody<T>
|
||||
where T: Deref<Target = [u8]>
|
||||
{
|
||||
impl ParameterStatusBody {
|
||||
#[inline]
|
||||
pub fn name(&self) -> io::Result<&str> {
|
||||
get_str(&self.storage[..self.name_end])
|
||||
get_str(&self.name)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn value(&self) -> io::Result<&str> {
|
||||
get_str(&self.storage[self.name_end + 1..])
|
||||
get_str(&self.value)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ReadyForQueryBody<T> {
|
||||
pub struct ReadyForQueryBody {
|
||||
status: u8,
|
||||
_p: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> ReadyForQueryBody<T>
|
||||
where T: Deref<Target = [u8]>
|
||||
{
|
||||
impl ReadyForQueryBody {
|
||||
#[inline]
|
||||
pub fn status(&self) -> u8 {
|
||||
self.status
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RowDescriptionBody<T> {
|
||||
storage: T,
|
||||
pub struct RowDescriptionBody {
|
||||
storage: Bytes,
|
||||
len: u16,
|
||||
}
|
||||
|
||||
impl<T> RowDescriptionBody<T>
|
||||
where T: Deref<Target = [u8]>
|
||||
{
|
||||
impl RowDescriptionBody {
|
||||
#[inline]
|
||||
pub fn fields<'a>(&'a self) -> Fields<'a> {
|
||||
Fields {
|
||||
@ -685,14 +654,14 @@ impl<'a> FallibleIterator for Fields<'a> {
|
||||
let format = try!(self.buf.read_i16::<BigEndian>());
|
||||
|
||||
Ok(Some(Field {
|
||||
name: name,
|
||||
table_oid: table_oid,
|
||||
column_id: column_id,
|
||||
type_oid: type_oid,
|
||||
type_size: type_size,
|
||||
type_modifier: type_modifier,
|
||||
format: format,
|
||||
}))
|
||||
name: name,
|
||||
table_oid: table_oid,
|
||||
column_id: column_id,
|
||||
type_oid: type_oid,
|
||||
type_size: type_size,
|
||||
type_modifier: type_modifier,
|
||||
format: format,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@ -747,11 +716,11 @@ impl<'a> Field<'a> {
|
||||
fn find_null(buf: &[u8], start: usize) -> io::Result<usize> {
|
||||
match memchr(0, &buf[start..]) {
|
||||
Some(pos) => Ok(pos + start),
|
||||
None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF"))
|
||||
None => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn get_str(buf: &[u8]) -> io::Result<&str> {
|
||||
str::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,7 @@
|
||||
use fallible_iterator::{FallibleIterator, FromFallibleIterator};
|
||||
use fallible_iterator::{FallibleIterator};
|
||||
use postgres_protocol::message::backend::DataRowBody;
|
||||
use std::ascii::AsciiExt;
|
||||
use std::io;
|
||||
use std::ops::Range;
|
||||
|
||||
use stmt::Column;
|
||||
@ -32,12 +34,13 @@ impl<'a> RowIndex for str {
|
||||
// FIXME ASCII-only case insensitivity isn't really the right thing to
|
||||
// do. Postgres itself uses a dubious wrapper around tolower and JDBC
|
||||
// uses the US locale.
|
||||
stmt.iter().position(|d| d.name().eq_ignore_ascii_case(self))
|
||||
stmt.iter()
|
||||
.position(|d| d.name().eq_ignore_ascii_case(self))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: ?Sized> RowIndex for &'a T
|
||||
where T: RowIndex
|
||||
where T: RowIndex
|
||||
{
|
||||
#[inline]
|
||||
fn idx(&self, columns: &[Column]) -> Option<usize> {
|
||||
@ -47,43 +50,26 @@ where T: RowIndex
|
||||
|
||||
#[doc(hidden)]
|
||||
pub struct RowData {
|
||||
buf: Vec<u8>,
|
||||
indices: Vec<Option<Range<usize>>>,
|
||||
}
|
||||
|
||||
impl<'a> FromFallibleIterator<Option<&'a [u8]>> for RowData {
|
||||
fn from_fallible_iterator<I>(mut it: I) -> Result<RowData, I::Error>
|
||||
where I: FallibleIterator<Item = Option<&'a [u8]>>
|
||||
{
|
||||
let mut row = RowData {
|
||||
buf: vec![],
|
||||
indices: Vec::with_capacity(it.size_hint().0),
|
||||
};
|
||||
|
||||
while let Some(cell) = it.next()? {
|
||||
let index = match cell {
|
||||
Some(cell) => {
|
||||
let base = row.buf.len();
|
||||
row.buf.extend_from_slice(cell);
|
||||
Some(base..row.buf.len())
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
row.indices.push(index);
|
||||
}
|
||||
|
||||
Ok(row)
|
||||
}
|
||||
body: DataRowBody,
|
||||
ranges: Vec<Option<Range<usize>>>,
|
||||
}
|
||||
|
||||
impl RowData {
|
||||
pub fn new(body: DataRowBody) -> io::Result<RowData> {
|
||||
let ranges = body.ranges().collect()?;
|
||||
Ok(RowData {
|
||||
body: body,
|
||||
ranges: ranges,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.indices.len()
|
||||
self.ranges.len()
|
||||
}
|
||||
|
||||
pub fn get(&self, index: usize) -> Option<&[u8]> {
|
||||
match &self.indices[index] {
|
||||
&Some(ref range) => Some(&self.buf[range.clone()]),
|
||||
match &self.ranges[index] {
|
||||
&Some(ref range) => Some(&self.body.buffer()[range.clone()]),
|
||||
&None => None,
|
||||
}
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ with-security-framework = ["security-framework"]
|
||||
no-logging = []
|
||||
|
||||
[dependencies]
|
||||
bufstream = "0.1"
|
||||
bytes = "0.4"
|
||||
fallible-iterator = "0.1.3"
|
||||
log = "0.3"
|
||||
|
||||
|
@ -70,7 +70,7 @@
|
||||
#![warn(missing_docs)]
|
||||
#![allow(unknown_lints, needless_lifetimes, doc_markdown)] // for clippy
|
||||
|
||||
extern crate bufstream;
|
||||
extern crate bytes;
|
||||
extern crate fallible_iterator;
|
||||
#[cfg(not(feature = "no-logging"))]
|
||||
#[macro_use]
|
||||
@ -315,7 +315,7 @@ impl InnerConnection {
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
fn read_message_with_notification(&mut self) -> io::Result<backend::Message<Vec<u8>>> {
|
||||
fn read_message_with_notification(&mut self) -> io::Result<backend::Message> {
|
||||
debug_assert!(!self.desynchronized);
|
||||
loop {
|
||||
match try_desync!(self, self.stream.read_message()) {
|
||||
@ -335,7 +335,7 @@ impl InnerConnection {
|
||||
|
||||
fn read_message_with_notification_timeout(&mut self,
|
||||
timeout: Duration)
|
||||
-> io::Result<Option<backend::Message<Vec<u8>>>> {
|
||||
-> io::Result<Option<backend::Message>> {
|
||||
debug_assert!(!self.desynchronized);
|
||||
loop {
|
||||
match try_desync!(self, self.stream.read_message_timeout(timeout)) {
|
||||
@ -354,7 +354,7 @@ impl InnerConnection {
|
||||
}
|
||||
|
||||
fn read_message_with_notification_nonblocking(&mut self)
|
||||
-> io::Result<Option<backend::Message<Vec<u8>>>> {
|
||||
-> io::Result<Option<backend::Message>> {
|
||||
debug_assert!(!self.desynchronized);
|
||||
loop {
|
||||
match try_desync!(self, self.stream.read_message_nonblocking()) {
|
||||
@ -372,7 +372,7 @@ impl InnerConnection {
|
||||
}
|
||||
}
|
||||
|
||||
fn read_message(&mut self) -> io::Result<backend::Message<Vec<u8>>> {
|
||||
fn read_message(&mut self) -> io::Result<backend::Message> {
|
||||
loop {
|
||||
match self.read_message_with_notification()? {
|
||||
backend::Message::NotificationResponse(body) => {
|
||||
@ -495,7 +495,7 @@ impl InnerConnection {
|
||||
more_rows = true;
|
||||
break;
|
||||
}
|
||||
backend::Message::DataRow(body) => consumer(body.values().collect()?),
|
||||
backend::Message::DataRow(body) => consumer(RowData::new(body)?),
|
||||
backend::Message::ErrorResponse(body) => {
|
||||
self.wait_for_ready()?;
|
||||
return Err(err(&mut body.fields()));
|
||||
@ -832,8 +832,8 @@ impl InnerConnection {
|
||||
match self.read_message()? {
|
||||
backend::Message::ReadyForQuery(_) => break,
|
||||
backend::Message::DataRow(body) => {
|
||||
let row = body.values()
|
||||
.map(|v| v.map(|v| String::from_utf8_lossy(v).into_owned()))
|
||||
let row = body.ranges()
|
||||
.map(|r| r.map(|r| String::from_utf8_lossy(&body.buffer()[r]).into_owned()))
|
||||
.collect()?;
|
||||
result.push(row);
|
||||
}
|
||||
|
@ -1,9 +1,8 @@
|
||||
use std::io;
|
||||
use std::io::prelude::*;
|
||||
use std::io::{self, BufWriter, Read, Write};
|
||||
use std::fmt;
|
||||
use std::net::TcpStream;
|
||||
use std::time::Duration;
|
||||
use bufstream::BufStream;
|
||||
use bytes::{BufMut, BytesMut};
|
||||
#[cfg(unix)]
|
||||
use std::os::unix::net::UnixStream;
|
||||
#[cfg(unix)]
|
||||
@ -11,25 +10,27 @@ use std::os::unix::io::{AsRawFd, RawFd};
|
||||
#[cfg(windows)]
|
||||
use std::os::windows::io::{AsRawSocket, RawSocket};
|
||||
use postgres_protocol::message::frontend;
|
||||
use postgres_protocol::message::backend::{self, ParseResult};
|
||||
use postgres_protocol::message::backend;
|
||||
|
||||
use TlsMode;
|
||||
use error::ConnectError;
|
||||
use tls::TlsStream;
|
||||
use params::{ConnectParams, Host};
|
||||
|
||||
const MESSAGE_HEADER_SIZE: usize = 5;
|
||||
const INITIAL_CAPACITY: usize = 8 * 1024;
|
||||
|
||||
pub struct MessageStream {
|
||||
stream: BufStream<Box<TlsStream>>,
|
||||
buf: Vec<u8>,
|
||||
stream: BufWriter<Box<TlsStream>>,
|
||||
in_buf: BytesMut,
|
||||
out_buf: Vec<u8>,
|
||||
}
|
||||
|
||||
impl MessageStream {
|
||||
pub fn new(stream: Box<TlsStream>) -> MessageStream {
|
||||
MessageStream {
|
||||
stream: BufStream::new(stream),
|
||||
buf: vec![],
|
||||
stream: BufWriter::new(stream),
|
||||
in_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
|
||||
out_buf: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
@ -41,65 +42,66 @@ impl MessageStream {
|
||||
where F: FnOnce(&mut Vec<u8>) -> Result<(), E>,
|
||||
E: From<io::Error>
|
||||
{
|
||||
self.buf.clear();
|
||||
f(&mut self.buf)?;
|
||||
self.stream.write_all(&self.buf).map_err(From::from)
|
||||
self.out_buf.clear();
|
||||
f(&mut self.out_buf)?;
|
||||
self.stream.write_all(&self.out_buf).map_err(From::from)
|
||||
}
|
||||
|
||||
fn inner_read_message(&mut self, b: u8) -> io::Result<backend::Message<Vec<u8>>> {
|
||||
self.buf.resize(MESSAGE_HEADER_SIZE, 0);
|
||||
self.buf[0] = b;
|
||||
self.stream.read_exact(&mut self.buf[1..])?;
|
||||
|
||||
let len = match backend::Message::parse_owned(&self.buf)? {
|
||||
ParseResult::Complete { message, .. } => return Ok(message),
|
||||
ParseResult::Incomplete { required_size } => Some(required_size.unwrap()),
|
||||
};
|
||||
|
||||
if let Some(len) = len {
|
||||
self.buf.resize(len, 0);
|
||||
self.stream.read_exact(&mut self.buf[MESSAGE_HEADER_SIZE..])?;
|
||||
};
|
||||
|
||||
match backend::Message::parse_owned(&self.buf)? {
|
||||
ParseResult::Complete { message, .. } => Ok(message),
|
||||
ParseResult::Incomplete { .. } => unreachable!(),
|
||||
pub fn read_message(&mut self) -> io::Result<backend::Message> {
|
||||
loop {
|
||||
match backend::Message::parse(&mut self.in_buf) {
|
||||
Ok(Some(message)) => return Ok(message),
|
||||
Ok(None) => self.read_in()?,
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_message(&mut self) -> io::Result<backend::Message<Vec<u8>>> {
|
||||
let mut b = [0; 1];
|
||||
self.stream.read_exact(&mut b)?;
|
||||
self.inner_read_message(b[0])
|
||||
fn read_in(&mut self) -> io::Result<()> {
|
||||
self.in_buf.reserve(1);
|
||||
match self.stream.get_mut().read(unsafe { self.in_buf.bytes_mut() }) {
|
||||
Ok(0) => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF")),
|
||||
Ok(n) => {
|
||||
unsafe { self.in_buf.advance_mut(n) };
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_message_timeout(&mut self,
|
||||
timeout: Duration)
|
||||
-> io::Result<Option<backend::Message<Vec<u8>>>> {
|
||||
self.set_read_timeout(Some(timeout))?;
|
||||
let mut b = [0; 1];
|
||||
let r = self.stream.read_exact(&mut b);
|
||||
self.set_read_timeout(None)?;
|
||||
-> io::Result<Option<backend::Message>> {
|
||||
if self.in_buf.is_empty() {
|
||||
self.set_read_timeout(Some(timeout))?;
|
||||
let r = self.read_in();
|
||||
self.set_read_timeout(None)?;
|
||||
|
||||
match r {
|
||||
Ok(()) => self.inner_read_message(b[0]).map(Some),
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock ||
|
||||
e.kind() == io::ErrorKind::TimedOut => Ok(None),
|
||||
Err(e) => Err(e),
|
||||
match r {
|
||||
Ok(()) => {},
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock ||
|
||||
e.kind() == io::ErrorKind::TimedOut => return Ok(None),
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
self.read_message().map(Some)
|
||||
}
|
||||
|
||||
pub fn read_message_nonblocking(&mut self) -> io::Result<Option<backend::Message<Vec<u8>>>> {
|
||||
self.set_nonblocking(true)?;
|
||||
let mut b = [0; 1];
|
||||
let r = self.stream.read_exact(&mut b);
|
||||
self.set_nonblocking(false)?;
|
||||
pub fn read_message_nonblocking(&mut self) -> io::Result<Option<backend::Message>> {
|
||||
if self.in_buf.is_empty() {
|
||||
self.set_nonblocking(true)?;
|
||||
let r = self.read_in();
|
||||
self.set_nonblocking(false)?;
|
||||
|
||||
match r {
|
||||
Ok(()) => self.inner_read_message(b[0]).map(Some),
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(None),
|
||||
Err(e) => Err(e),
|
||||
match r {
|
||||
Ok(()) => {},
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None),
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
self.read_message().map(Some)
|
||||
}
|
||||
|
||||
pub fn flush(&mut self) -> io::Result<()> {
|
||||
|
@ -174,7 +174,7 @@ struct InnerConnection {
|
||||
}
|
||||
|
||||
impl InnerConnection {
|
||||
fn read(self) -> IoFuture<(backend::Message<Vec<u8>>, InnerConnection)> {
|
||||
fn read(self) -> IoFuture<(backend::Message, InnerConnection)> {
|
||||
self.into_future()
|
||||
.map_err(|e| e.0)
|
||||
.and_then(|(m, mut s)| {
|
||||
@ -209,10 +209,10 @@ impl InnerConnection {
|
||||
}
|
||||
|
||||
impl Stream for InnerConnection {
|
||||
type Item = backend::Message<Vec<u8>>;
|
||||
type Item = backend::Message;
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<backend::Message<Vec<u8>>>, io::Error> {
|
||||
fn poll(&mut self) -> Poll<Option<backend::Message>, io::Error> {
|
||||
loop {
|
||||
match try_ready!(self.stream.poll()) {
|
||||
Some(backend::Message::ParameterStatus(body)) => {
|
||||
@ -439,7 +439,7 @@ impl Connection {
|
||||
Ok((rows, Connection(s))).into_future().boxed()
|
||||
}
|
||||
backend::Message::DataRow(body) => {
|
||||
match body.values().collect() {
|
||||
match RowData::new(body) {
|
||||
Ok(row) => {
|
||||
rows.push(row);
|
||||
Connection(s).simple_read_rows(rows)
|
||||
@ -504,7 +504,7 @@ impl Connection {
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn ready_err<T>(self, body: ErrorResponseBody<Vec<u8>>) -> BoxFuture<T, Error>
|
||||
fn ready_err<T>(self, body: ErrorResponseBody) -> BoxFuture<T, Error>
|
||||
where T: 'static + Send
|
||||
{
|
||||
DbError::new(&mut body.fields())
|
||||
@ -948,8 +948,7 @@ impl Connection {
|
||||
let c = Connection(s);
|
||||
match m {
|
||||
backend::Message::DataRow(body) => {
|
||||
Either::A(body.values()
|
||||
.collect()
|
||||
Either::A(RowData::new(body)
|
||||
.map(|r| (Some(r), c))
|
||||
.map_err(Error::Io)
|
||||
.into_future())
|
||||
|
@ -2,7 +2,7 @@ use bytes::{BytesMut, BufMut};
|
||||
use futures::{BoxFuture, Future, IntoFuture, Sink, Stream as FuturesStream, Poll};
|
||||
use futures::future::Either;
|
||||
use postgres_shared::params::Host;
|
||||
use postgres_protocol::message::backend::{self, ParseResult};
|
||||
use postgres_protocol::message::backend;
|
||||
use postgres_protocol::message::frontend;
|
||||
use std::io::{self, Read, Write};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
@ -167,18 +167,11 @@ impl AsyncWrite for Stream {
|
||||
pub struct PostgresCodec;
|
||||
|
||||
impl Decoder for PostgresCodec {
|
||||
type Item = backend::Message<Vec<u8>>;
|
||||
type Item = backend::Message;
|
||||
type Error = io::Error;
|
||||
|
||||
// FIXME ideally we'd avoid re-copying the data
|
||||
fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<Self::Item>> {
|
||||
match backend::Message::parse_owned(buf.as_ref())? {
|
||||
ParseResult::Complete { message, consumed } => {
|
||||
buf.split_to(consumed);
|
||||
Ok(Some(message))
|
||||
}
|
||||
ParseResult::Incomplete { .. } => Ok(None),
|
||||
}
|
||||
fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<backend::Message>> {
|
||||
backend::Message::parse(buf)
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user