Connection IO logic

This commit is contained in:
Steven Fackler 2019-07-22 21:27:21 -07:00
parent 32fe52490e
commit 2480fefd2c
4 changed files with 339 additions and 23 deletions

View File

@ -1,8 +1,20 @@
use crate::codec::{BackendMessages, FrontendMessage, PostgresCodec};
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::error::DbError;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::{AsyncMessage, Error, Notification};
use fallible_iterator::FallibleIterator;
use futures::channel::mpsc;
use std::collections::HashMap;
use futures::{ready, Sink, Stream, StreamExt};
use log::trace;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::codec::Framed;
use tokio::io::{AsyncRead, AsyncWrite};
pub enum RequestMessages {
Single(FrontendMessage),
@ -13,13 +25,40 @@ pub struct Request {
pub sender: mpsc::Sender<BackendMessages>,
}
pub struct Response {
sender: mpsc::Sender<BackendMessages>,
}
#[derive(PartialEq, Debug)]
enum State {
Active,
Terminating,
Closing,
}
/// A connection to a PostgreSQL database.
///
/// This is one half of what is returned when a new connection is established. It performs the actual IO with the
/// server, and should generally be spawned off onto an executor to run in the background.
///
/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
#[must_use = "futures do nothing unless polled"]
pub struct Connection<S, T> {
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
pending_request: Option<RequestMessages>,
pending_response: Option<BackendMessage>,
responses: VecDeque<Response>,
state: State,
}
impl<S, T> Connection<S, T> {
impl<S, T> Connection<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
pub(crate) fn new(
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
parameters: HashMap<String, String>,
@ -29,6 +68,240 @@ impl<S, T> Connection<S, T> {
stream,
parameters,
receiver,
pending_request: None,
pending_response: None,
responses: VecDeque::new(),
state: State::Active,
}
}
/// Returns the value of a runtime parameter for this connection.
pub fn parameter(&self, name: &str) -> Option<&str> {
self.parameters.get(name).map(|s| &**s)
}
fn poll_response(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<BackendMessage, Error>>> {
if let Some(message) = self.pending_response.take() {
trace!("retrying pending response");
return Poll::Ready(Some(Ok(message)));
}
Pin::new(&mut self.stream)
.poll_next(cx)
.map(|o| o.map(|r| r.map_err(Error::io)))
}
fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
if self.state != State::Active {
trace!("poll_read: done");
return Ok(None);
}
loop {
let message = match self.poll_response(cx)? {
Poll::Ready(Some(message)) => message,
Poll::Ready(None) => return Err(Error::closed()),
Poll::Pending => {
trace!("poll_read: waiting on response");
return Ok(None);
}
};
let (mut messages, request_complete) = match message {
BackendMessage::Async(Message::NoticeResponse(body)) => {
let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
return Ok(Some(AsyncMessage::Notice(error)));
}
BackendMessage::Async(Message::NotificationResponse(body)) => {
let notification = Notification {
process_id: body.process_id(),
channel: body.channel().map_err(Error::parse)?.to_string(),
payload: body.message().map_err(Error::parse)?.to_string(),
};
return Ok(Some(AsyncMessage::Notification(notification)));
}
BackendMessage::Async(Message::ParameterStatus(body)) => {
self.parameters.insert(
body.name().map_err(Error::parse)?.to_string(),
body.value().map_err(Error::parse)?.to_string(),
);
continue;
}
BackendMessage::Async(_) => unreachable!(),
BackendMessage::Normal {
messages,
request_complete,
} => (messages, request_complete),
};
let mut response = match self.responses.pop_front() {
Some(response) => response,
None => match messages.next().map_err(Error::parse)? {
Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
_ => return Err(Error::unexpected_message()),
},
};
match response.sender.poll_ready(cx) {
Poll::Ready(Ok(())) => {
let _ = response.sender.start_send(messages);
if !request_complete {
self.responses.push_front(response);
}
}
Poll::Ready(Err(_)) => {
// we need to keep paging through the rest of the messages even if the receiver's hung up
if !request_complete {
self.responses.push_front(response);
}
}
Poll::Pending => {
self.responses.push_front(response);
self.pending_response = Some(BackendMessage::Normal {
messages,
request_complete,
});
trace!("poll_read: waiting on sender");
return Ok(None);
}
}
}
}
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
if let Some(messages) = self.pending_request.take() {
trace!("retrying pending request");
return Poll::Ready(Some(messages));
}
match self.receiver.poll_next_unpin(cx) {
Poll::Ready(Some(request)) => {
trace!("polled new request");
self.responses.push_back(Response {
sender: request.sender,
});
Poll::Ready(Some(request.messages))
}
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
loop {
if self.state == State::Closing {
trace!("poll_write: done");
return Ok(false);
}
let request = match self.poll_request(cx) {
Poll::Ready(Some(request)) => request,
Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
trace!("poll_write: at eof, terminating");
self.state = State::Terminating;
let mut request = vec![];
frontend::terminate(&mut request);
RequestMessages::Single(FrontendMessage::Raw(request))
}
Poll::Ready(None) => {
trace!(
"poll_write: at eof, pending responses {}",
self.responses.len()
);
return Ok(true);
}
Poll::Pending => {
trace!("poll_write: waiting on request");
return Ok(true);
}
};
if let Poll::Pending = Pin::new(&mut self.stream)
.poll_ready(cx)
.map_err(Error::io)?
{
trace!("poll_write: waiting on socket");
self.pending_request = Some(request);
return Ok(false);
}
match request {
RequestMessages::Single(request) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
if self.state == State::Terminating {
trace!("poll_write: sent eof, closing");
self.state = State::Closing;
}
}
}
}
}
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
match Pin::new(&mut self.stream)
.poll_flush(cx)
.map_err(Error::io)?
{
Poll::Ready(()) => trace!("poll_flush: flushed"),
Poll::Pending => trace!("poll_flush: waiting on socket"),
}
Ok(())
}
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if self.state != State::Closing {
return Poll::Pending;
}
match Pin::new(&mut self.stream)
.poll_close(cx)
.map_err(Error::io)?
{
Poll::Ready(()) => {
trace!("poll_shutdown: complete");
Poll::Ready(Ok(()))
}
Poll::Pending => {
trace!("poll_shutdown: waiting on socket");
Poll::Pending
}
}
}
pub fn poll_message(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<AsyncMessage, Error>>> {
let message = self.poll_read(cx)?;
let want_flush = self.poll_write(cx)?;
if want_flush {
self.poll_flush(cx)?;
}
match message {
Some(message) => Poll::Ready(Some(Ok(message))),
None => match self.poll_shutdown(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Pending => Poll::Pending,
},
}
}
}
impl<S, T> Future for Connection<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
type Output = Result<(), Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
while let Some(_) = ready!(Pin::as_mut(&mut self).poll_message(cx)?) {}
Poll::Ready(Ok(()))
}
}

View File

@ -115,6 +115,7 @@
pub use crate::client::Client;
pub use crate::config::Config;
pub use crate::connection::Connection;
use crate::error::DbError;
pub use crate::error::Error;
#[cfg(feature = "runtime")]
pub use crate::socket::Socket;
@ -157,3 +158,43 @@ where
let config = config.parse::<Config>()?;
config.connect(tls).await
}
/// An asynchronous notification.
#[derive(Clone, Debug)]
pub struct Notification {
process_id: i32,
channel: String,
payload: String,
}
/// An asynchronous message from the server.
#[allow(clippy::large_enum_variant)]
pub enum AsyncMessage {
/// A notice.
///
/// Notices use the same format as errors, but aren't "errors" per-se.
Notice(DbError),
/// A notification.
///
/// Connections can subscribe to notifications with the `LISTEN` command.
Notification(Notification),
#[doc(hidden)]
__NonExhaustive,
}
impl Notification {
/// The process ID of the notifying backend process.
pub fn process_id(&self) -> i32 {
self.process_id
}
/// The name of the channel that the notify has been raised on.
pub fn channel(&self) -> &str {
&self.channel
}
/// The "payload" string passed from the notifying process.
pub fn payload(&self) -> &str {
&self.payload
}
}

View File

@ -1,6 +1,7 @@
#![warn(rust_2018_idioms)]
#![feature(async_await)]
use futures::FutureExt;
use tokio::net::TcpStream;
use tokio_postgres::error::SqlState;
use tokio_postgres::tls::{NoTls, NoTlsStream};
@ -13,7 +14,7 @@ mod runtime;
mod types;
*/
async fn connect(s: &str) -> Result<(Client, Connection<TcpStream, NoTlsStream>), Error> {
async fn connect_raw(s: &str) -> Result<(Client, Connection<TcpStream, NoTlsStream>), Error> {
let socket = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap())
.await
.unwrap();
@ -21,9 +22,16 @@ async fn connect(s: &str) -> Result<(Client, Connection<TcpStream, NoTlsStream>)
config.connect_raw(socket, NoTls).await
}
async fn connect(s: &str) -> Client {
let (client, connection) = connect_raw(s).await.unwrap();
let connection = connection.map(|r| r.unwrap());
tokio::spawn(connection);
client
}
#[tokio::test]
async fn plain_password_missing() {
connect("user=pass_user dbname=postgres")
connect_raw("user=pass_user dbname=postgres")
.await
.err()
.unwrap();
@ -31,7 +39,7 @@ async fn plain_password_missing() {
#[tokio::test]
async fn plain_password_wrong() {
match connect("user=pass_user password=foo dbname=postgres").await {
match connect_raw("user=pass_user password=foo dbname=postgres").await {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
Err(e) => panic!("{}", e),
@ -40,14 +48,12 @@ async fn plain_password_wrong() {
#[tokio::test]
async fn plain_password_ok() {
connect("user=pass_user password=password dbname=postgres")
.await
.unwrap();
connect("user=pass_user password=password dbname=postgres").await;
}
#[tokio::test]
async fn md5_password_missing() {
connect("user=md5_user dbname=postgres")
connect_raw("user=md5_user dbname=postgres")
.await
.err()
.unwrap();
@ -55,7 +61,7 @@ async fn md5_password_missing() {
#[tokio::test]
async fn md5_password_wrong() {
match connect("user=md5_user password=foo dbname=postgres").await {
match connect_raw("user=md5_user password=foo dbname=postgres").await {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
Err(e) => panic!("{}", e),
@ -64,14 +70,12 @@ async fn md5_password_wrong() {
#[tokio::test]
async fn md5_password_ok() {
connect("user=md5_user password=password dbname=postgres")
.await
.unwrap();
connect("user=md5_user password=password dbname=postgres").await;
}
#[tokio::test]
async fn scram_password_missing() {
connect("user=scram_user dbname=postgres")
connect_raw("user=scram_user dbname=postgres")
.await
.err()
.unwrap();
@ -79,7 +83,7 @@ async fn scram_password_missing() {
#[tokio::test]
async fn scram_password_wrong() {
match connect("user=scram_user password=foo dbname=postgres").await {
match connect_raw("user=scram_user password=foo dbname=postgres").await {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
Err(e) => panic!("{}", e),
@ -88,9 +92,7 @@ async fn scram_password_wrong() {
#[tokio::test]
async fn scram_password_ok() {
connect("user=scram_user password=password dbname=postgres")
.await
.unwrap();
connect("user=scram_user password=password dbname=postgres").await;
}
/*

View File

@ -1,4 +1,4 @@
use futures::{Future, Stream};
use futures::{Future, FutureExt, Stream};
use std::time::{Duration, Instant};
use tokio::runtime::current_thread::Runtime;
use tokio::timer::Delay;
@ -7,10 +7,10 @@ use tokio_postgres::NoTls;
async fn smoke_test(s: &str) {
let (mut client, connection) = tokio_postgres::connect(s, NoTls).await.unwrap();
/*
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
let connection = connection.map(|e| e.unwrap());
tokio::spawn(connection);
/*
let execute = client.simple_query("SELECT 1").for_each(|_| Ok(()));
runtime.block_on(execute).unwrap();
*/