// 269 lines
// 7.9 KiB
// 269 lines
// 7.9 KiB
use std::fmt;
use std::io;
use std::io::prelude::*;
use std::net::TcpStream;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(unix)]
use std::os::unix::net::UnixStream;
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};
use std::time::Duration;

use bufstream::BufStream;
use byteorder::ReadBytesExt;
use postgres_protocol::message::backend::{self, ParseResult};
use postgres_protocol::message::frontend;

use error::ConnectError;
use io::TlsStream;
use params::{ConnectParams, ConnectTarget};
use TlsMode;
/// Number of bytes read before a backend message can be sized: the one-byte
/// message tag plus four further bytes (presumably the wire-protocol length
/// field).
const MESSAGE_HEADER_SIZE: usize = 5;

/// Port used when the connection parameters do not name one.
const DEFAULT_PORT: u16 = 5432;
pub struct MessageStream {
stream: BufStream<Box<TlsStream>>,
buf: Vec<u8>,
impl MessageStream {
pub fn new(stream: Box<TlsStream>) -> MessageStream {
MessageStream {
stream: BufStream::new(stream),
buf: vec![],
pub fn get_ref(&self) -> &Box<TlsStream> {
pub fn write_message(&mut self, message: &frontend::Message) -> io::Result<()> {
try!(frontend::Message::write(message, &mut self.buf));
fn inner_read_message(&mut self, b: u8) -> io::Result<backend::Message> {
self.buf.resize(MESSAGE_HEADER_SIZE, 0);
self.buf[0] = b;
try!(self.stream.read_exact(&mut self.buf[1..]));
let len = match try!(backend::Message::parse(&self.buf)) {
ParseResult::Complete { message, .. } => return Ok(message),
ParseResult::Incomplete { required_size } => Some(required_size.unwrap()),
if let Some(len) = len {
self.buf.resize(len, 0);
try!(self.stream.read_exact(&mut self.buf[MESSAGE_HEADER_SIZE..]));
match try!(backend::Message::parse(&self.buf)) {
ParseResult::Complete { message, .. } => Ok(message),
ParseResult::Incomplete { .. } => unreachable!(),
pub fn read_message(&mut self) -> io::Result<backend::Message> {
let b = try!(self.stream.read_u8());
pub fn read_message_timeout(&mut self,
timeout: Duration)
-> io::Result<Option<backend::Message>> {
let b = self.stream.read_u8();
match b {
Ok(b) => self.inner_read_message(b).map(Some),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock ||
e.kind() == io::ErrorKind::TimedOut => Ok(None),
Err(e) => Err(e),
pub fn read_message_nonblocking(&mut self) -> io::Result<Option<backend::Message>> {
let b = self.stream.read_u8();
match b {
Ok(b) => self.inner_read_message(b).map(Some),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(None),
Err(e) => Err(e),
pub fn flush(&mut self) -> io::Result<()> {
fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
match self.stream.get_ref().get_ref().0 {
InternalStream::Tcp(ref s) => s.set_read_timeout(timeout),
InternalStream::Unix(ref s) => s.set_read_timeout(timeout),
fn set_nonblocking(&self, nonblock: bool) -> io::Result<()> {
match self.stream.get_ref().get_ref().0 {
InternalStream::Tcp(ref s) => s.set_nonblocking(nonblock),
InternalStream::Unix(ref s) => s.set_nonblocking(nonblock),
/// A raw connection to a Postgres server, over TCP or (on Unix) a Unix-domain
/// socket.
///
/// `Stream` implements `Read`, `Write` and `TlsStream`. On Unix platforms it
/// additionally implements `AsRawFd`; on Windows platforms, `AsRawSocket`.
pub struct Stream(
    // The concrete socket; the field is private, so callers go through the traits.
    InternalStream,
);
impl fmt::Debug for Stream {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match self.0 {
InternalStream::Tcp(ref s) => fmt::Debug::fmt(s, fmt),
InternalStream::Unix(ref s) => fmt::Debug::fmt(s, fmt),
impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
fn flush(&mut self) -> io::Result<()> {
impl TlsStream for Stream {
fn get_ref(&self) -> &Stream {
fn get_mut(&mut self) -> &mut Stream {
impl AsRawFd for Stream {
fn as_raw_fd(&self) -> RawFd {
match self.0 {
InternalStream::Tcp(ref s) => s.as_raw_fd(),
InternalStream::Unix(ref s) => s.as_raw_fd(),
// `AsRawSocket`/`RawSocket` come from `std::os::windows` and only exist on Windows.
#[cfg(windows)]
impl AsRawSocket for Stream {
    /// Returns the raw socket handle of the underlying TCP connection.
    fn as_raw_socket(&self) -> RawSocket {
        // Unix sockets aren't supported on Windows, so TCP is the only case to match.
        match self.0 {
            InternalStream::Tcp(ref s) => s.as_raw_socket(),
        }
    }
}
/// The concrete socket wrapped by a `Stream`.
enum InternalStream {
    /// A TCP connection.
    Tcp(TcpStream),
    /// A Unix-domain socket connection; only available on Unix platforms.
    #[cfg(unix)]
    Unix(UnixStream),
}
impl Read for InternalStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match *self {
InternalStream::Tcp(ref mut s) => s.read(buf),
InternalStream::Unix(ref mut s) => s.read(buf),
impl Write for InternalStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match *self {
InternalStream::Tcp(ref mut s) => s.write(buf),
InternalStream::Unix(ref mut s) => s.write(buf),
fn flush(&mut self) -> io::Result<()> {
match *self {
InternalStream::Tcp(ref mut s) => s.flush(),
InternalStream::Unix(ref mut s) => s.flush(),
fn open_socket(params: &ConnectParams) -> Result<InternalStream, ConnectError> {
let port = params.port.unwrap_or(DEFAULT_PORT);
match params.target {
ConnectTarget::Tcp(ref host) => {
Ok(try!(TcpStream::connect(&(&**host, port)).map(InternalStream::Tcp)))
ConnectTarget::Unix(ref path) => {
let path = path.join(&format!(".s.PGSQL.{}", port));
ConnectTarget::Unix(..) => {
"unix sockets are not supported on this system")))
pub fn initialize_stream(params: &ConnectParams,
tls: TlsMode)
-> Result<Box<TlsStream>, ConnectError> {
let mut socket = Stream(try!(open_socket(params)));
let (tls_required, handshaker) = match tls {
TlsMode::None => return Ok(Box::new(socket)),
TlsMode::Prefer(handshaker) => (false, handshaker),
TlsMode::Require(handshaker) => (true, handshaker),
let mut buf = vec![];
try!(frontend::Message::write(&frontend::SslRequest, &mut buf));
if try!(socket.read_u8()) == b'N' {
if tls_required {
return Err(ConnectError::Tls("the server does not support TLS".into()));
} else {
return Ok(Box::new(socket));
let host = match params.target {
ConnectTarget::Tcp(ref host) => host,
// Postgres doesn't support TLS over unix sockets
ConnectTarget::Unix(_) => return Err(ConnectError::Io(::bad_response())),
handshaker.tls_handshake(host, socket).map_err(ConnectError::Tls)