Connection IO logic
This commit is contained in:
parent
32fe52490e
commit
2480fefd2c
@ -1,8 +1,20 @@
|
||||
use crate::codec::{BackendMessages, FrontendMessage, PostgresCodec};
|
||||
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
|
||||
use crate::error::DbError;
|
||||
use crate::maybe_tls_stream::MaybeTlsStream;
|
||||
use crate::{AsyncMessage, Error, Notification};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures::channel::mpsc;
|
||||
use std::collections::HashMap;
|
||||
use futures::{ready, Sink, Stream, StreamExt};
|
||||
use log::trace;
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::codec::Framed;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
pub enum RequestMessages {
|
||||
Single(FrontendMessage),
|
||||
@ -13,13 +25,40 @@ pub struct Request {
|
||||
pub sender: mpsc::Sender<BackendMessages>,
|
||||
}
|
||||
|
||||
pub struct Response {
|
||||
sender: mpsc::Sender<BackendMessages>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug)]
|
||||
enum State {
|
||||
Active,
|
||||
Terminating,
|
||||
Closing,
|
||||
}
|
||||
|
||||
/// A connection to a PostgreSQL database.
|
||||
///
|
||||
/// This is one half of what is returned when a new connection is established. It performs the actual IO with the
|
||||
/// server, and should generally be spawned off onto an executor to run in the background.
|
||||
///
|
||||
/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
|
||||
/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct Connection<S, T> {
|
||||
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
parameters: HashMap<String, String>,
|
||||
receiver: mpsc::UnboundedReceiver<Request>,
|
||||
pending_request: Option<RequestMessages>,
|
||||
pending_response: Option<BackendMessage>,
|
||||
responses: VecDeque<Response>,
|
||||
state: State,
|
||||
}
|
||||
|
||||
impl<S, T> Connection<S, T> {
|
||||
impl<S, T> Connection<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
pub(crate) fn new(
|
||||
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
parameters: HashMap<String, String>,
|
||||
@ -29,6 +68,240 @@ impl<S, T> Connection<S, T> {
|
||||
stream,
|
||||
parameters,
|
||||
receiver,
|
||||
pending_request: None,
|
||||
pending_response: None,
|
||||
responses: VecDeque::new(),
|
||||
state: State::Active,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the value of a runtime parameter for this connection.
|
||||
pub fn parameter(&self, name: &str) -> Option<&str> {
|
||||
self.parameters.get(name).map(|s| &**s)
|
||||
}
|
||||
|
||||
fn poll_response(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<BackendMessage, Error>>> {
|
||||
if let Some(message) = self.pending_response.take() {
|
||||
trace!("retrying pending response");
|
||||
return Poll::Ready(Some(Ok(message)));
|
||||
}
|
||||
|
||||
Pin::new(&mut self.stream)
|
||||
.poll_next(cx)
|
||||
.map(|o| o.map(|r| r.map_err(Error::io)))
|
||||
}
|
||||
|
||||
fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
|
||||
if self.state != State::Active {
|
||||
trace!("poll_read: done");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
loop {
|
||||
let message = match self.poll_response(cx)? {
|
||||
Poll::Ready(Some(message)) => message,
|
||||
Poll::Ready(None) => return Err(Error::closed()),
|
||||
Poll::Pending => {
|
||||
trace!("poll_read: waiting on response");
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
|
||||
let (mut messages, request_complete) = match message {
|
||||
BackendMessage::Async(Message::NoticeResponse(body)) => {
|
||||
let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
|
||||
return Ok(Some(AsyncMessage::Notice(error)));
|
||||
}
|
||||
BackendMessage::Async(Message::NotificationResponse(body)) => {
|
||||
let notification = Notification {
|
||||
process_id: body.process_id(),
|
||||
channel: body.channel().map_err(Error::parse)?.to_string(),
|
||||
payload: body.message().map_err(Error::parse)?.to_string(),
|
||||
};
|
||||
return Ok(Some(AsyncMessage::Notification(notification)));
|
||||
}
|
||||
BackendMessage::Async(Message::ParameterStatus(body)) => {
|
||||
self.parameters.insert(
|
||||
body.name().map_err(Error::parse)?.to_string(),
|
||||
body.value().map_err(Error::parse)?.to_string(),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
BackendMessage::Async(_) => unreachable!(),
|
||||
BackendMessage::Normal {
|
||||
messages,
|
||||
request_complete,
|
||||
} => (messages, request_complete),
|
||||
};
|
||||
|
||||
let mut response = match self.responses.pop_front() {
|
||||
Some(response) => response,
|
||||
None => match messages.next().map_err(Error::parse)? {
|
||||
Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
|
||||
_ => return Err(Error::unexpected_message()),
|
||||
},
|
||||
};
|
||||
|
||||
match response.sender.poll_ready(cx) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
let _ = response.sender.start_send(messages);
|
||||
if !request_complete {
|
||||
self.responses.push_front(response);
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(_)) => {
|
||||
// we need to keep paging through the rest of the messages even if the receiver's hung up
|
||||
if !request_complete {
|
||||
self.responses.push_front(response);
|
||||
}
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.responses.push_front(response);
|
||||
self.pending_response = Some(BackendMessage::Normal {
|
||||
messages,
|
||||
request_complete,
|
||||
});
|
||||
trace!("poll_read: waiting on sender");
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
|
||||
if let Some(messages) = self.pending_request.take() {
|
||||
trace!("retrying pending request");
|
||||
return Poll::Ready(Some(messages));
|
||||
}
|
||||
|
||||
match self.receiver.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(request)) => {
|
||||
trace!("polled new request");
|
||||
self.responses.push_back(Response {
|
||||
sender: request.sender,
|
||||
});
|
||||
Poll::Ready(Some(request.messages))
|
||||
}
|
||||
Poll::Ready(None) => Poll::Ready(None),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
|
||||
loop {
|
||||
if self.state == State::Closing {
|
||||
trace!("poll_write: done");
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let request = match self.poll_request(cx) {
|
||||
Poll::Ready(Some(request)) => request,
|
||||
Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
|
||||
trace!("poll_write: at eof, terminating");
|
||||
self.state = State::Terminating;
|
||||
let mut request = vec![];
|
||||
frontend::terminate(&mut request);
|
||||
RequestMessages::Single(FrontendMessage::Raw(request))
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
trace!(
|
||||
"poll_write: at eof, pending responses {}",
|
||||
self.responses.len()
|
||||
);
|
||||
return Ok(true);
|
||||
}
|
||||
Poll::Pending => {
|
||||
trace!("poll_write: waiting on request");
|
||||
return Ok(true);
|
||||
}
|
||||
};
|
||||
|
||||
if let Poll::Pending = Pin::new(&mut self.stream)
|
||||
.poll_ready(cx)
|
||||
.map_err(Error::io)?
|
||||
{
|
||||
trace!("poll_write: waiting on socket");
|
||||
self.pending_request = Some(request);
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
match request {
|
||||
RequestMessages::Single(request) => {
|
||||
Pin::new(&mut self.stream)
|
||||
.start_send(request)
|
||||
.map_err(Error::io)?;
|
||||
if self.state == State::Terminating {
|
||||
trace!("poll_write: sent eof, closing");
|
||||
self.state = State::Closing;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
|
||||
match Pin::new(&mut self.stream)
|
||||
.poll_flush(cx)
|
||||
.map_err(Error::io)?
|
||||
{
|
||||
Poll::Ready(()) => trace!("poll_flush: flushed"),
|
||||
Poll::Pending => trace!("poll_flush: waiting on socket"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
if self.state != State::Closing {
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
match Pin::new(&mut self.stream)
|
||||
.poll_close(cx)
|
||||
.map_err(Error::io)?
|
||||
{
|
||||
Poll::Ready(()) => {
|
||||
trace!("poll_shutdown: complete");
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
Poll::Pending => {
|
||||
trace!("poll_shutdown: waiting on socket");
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn poll_message(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<AsyncMessage, Error>>> {
|
||||
let message = self.poll_read(cx)?;
|
||||
let want_flush = self.poll_write(cx)?;
|
||||
if want_flush {
|
||||
self.poll_flush(cx)?;
|
||||
}
|
||||
match message {
|
||||
Some(message) => Poll::Ready(Some(Ok(message))),
|
||||
None => match self.poll_shutdown(cx) {
|
||||
Poll::Ready(Ok(())) => Poll::Ready(None),
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
|
||||
Poll::Pending => Poll::Pending,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T> Future for Connection<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
type Output = Result<(), Error>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
while let Some(_) = ready!(Pin::as_mut(&mut self).poll_message(cx)?) {}
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
@ -115,6 +115,7 @@
|
||||
pub use crate::client::Client;
|
||||
pub use crate::config::Config;
|
||||
pub use crate::connection::Connection;
|
||||
use crate::error::DbError;
|
||||
pub use crate::error::Error;
|
||||
#[cfg(feature = "runtime")]
|
||||
pub use crate::socket::Socket;
|
||||
@ -157,3 +158,43 @@ where
|
||||
let config = config.parse::<Config>()?;
|
||||
config.connect(tls).await
|
||||
}
|
||||
|
||||
/// An asynchronous notification.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Notification {
|
||||
process_id: i32,
|
||||
channel: String,
|
||||
payload: String,
|
||||
}
|
||||
|
||||
/// An asynchronous message from the server.
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub enum AsyncMessage {
|
||||
/// A notice.
|
||||
///
|
||||
/// Notices use the same format as errors, but aren't "errors" per-se.
|
||||
Notice(DbError),
|
||||
/// A notification.
|
||||
///
|
||||
/// Connections can subscribe to notifications with the `LISTEN` command.
|
||||
Notification(Notification),
|
||||
#[doc(hidden)]
|
||||
__NonExhaustive,
|
||||
}
|
||||
|
||||
impl Notification {
|
||||
/// The process ID of the notifying backend process.
|
||||
pub fn process_id(&self) -> i32 {
|
||||
self.process_id
|
||||
}
|
||||
|
||||
/// The name of the channel that the notify has been raised on.
|
||||
pub fn channel(&self) -> &str {
|
||||
&self.channel
|
||||
}
|
||||
|
||||
/// The "payload" string passed from the notifying process.
|
||||
pub fn payload(&self) -> &str {
|
||||
&self.payload
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
#![warn(rust_2018_idioms)]
|
||||
#![feature(async_await)]
|
||||
|
||||
use futures::FutureExt;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_postgres::error::SqlState;
|
||||
use tokio_postgres::tls::{NoTls, NoTlsStream};
|
||||
@ -13,7 +14,7 @@ mod runtime;
|
||||
mod types;
|
||||
*/
|
||||
|
||||
async fn connect(s: &str) -> Result<(Client, Connection<TcpStream, NoTlsStream>), Error> {
|
||||
async fn connect_raw(s: &str) -> Result<(Client, Connection<TcpStream, NoTlsStream>), Error> {
|
||||
let socket = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
@ -21,9 +22,16 @@ async fn connect(s: &str) -> Result<(Client, Connection<TcpStream, NoTlsStream>)
|
||||
config.connect_raw(socket, NoTls).await
|
||||
}
|
||||
|
||||
async fn connect(s: &str) -> Client {
|
||||
let (client, connection) = connect_raw(s).await.unwrap();
|
||||
let connection = connection.map(|r| r.unwrap());
|
||||
tokio::spawn(connection);
|
||||
client
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn plain_password_missing() {
|
||||
connect("user=pass_user dbname=postgres")
|
||||
connect_raw("user=pass_user dbname=postgres")
|
||||
.await
|
||||
.err()
|
||||
.unwrap();
|
||||
@ -31,7 +39,7 @@ async fn plain_password_missing() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn plain_password_wrong() {
|
||||
match connect("user=pass_user password=foo dbname=postgres").await {
|
||||
match connect_raw("user=pass_user password=foo dbname=postgres").await {
|
||||
Ok(_) => panic!("unexpected success"),
|
||||
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
|
||||
Err(e) => panic!("{}", e),
|
||||
@ -40,14 +48,12 @@ async fn plain_password_wrong() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn plain_password_ok() {
|
||||
connect("user=pass_user password=password dbname=postgres")
|
||||
.await
|
||||
.unwrap();
|
||||
connect("user=pass_user password=password dbname=postgres").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn md5_password_missing() {
|
||||
connect("user=md5_user dbname=postgres")
|
||||
connect_raw("user=md5_user dbname=postgres")
|
||||
.await
|
||||
.err()
|
||||
.unwrap();
|
||||
@ -55,7 +61,7 @@ async fn md5_password_missing() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn md5_password_wrong() {
|
||||
match connect("user=md5_user password=foo dbname=postgres").await {
|
||||
match connect_raw("user=md5_user password=foo dbname=postgres").await {
|
||||
Ok(_) => panic!("unexpected success"),
|
||||
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
|
||||
Err(e) => panic!("{}", e),
|
||||
@ -64,14 +70,12 @@ async fn md5_password_wrong() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn md5_password_ok() {
|
||||
connect("user=md5_user password=password dbname=postgres")
|
||||
.await
|
||||
.unwrap();
|
||||
connect("user=md5_user password=password dbname=postgres").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn scram_password_missing() {
|
||||
connect("user=scram_user dbname=postgres")
|
||||
connect_raw("user=scram_user dbname=postgres")
|
||||
.await
|
||||
.err()
|
||||
.unwrap();
|
||||
@ -79,7 +83,7 @@ async fn scram_password_missing() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn scram_password_wrong() {
|
||||
match connect("user=scram_user password=foo dbname=postgres").await {
|
||||
match connect_raw("user=scram_user password=foo dbname=postgres").await {
|
||||
Ok(_) => panic!("unexpected success"),
|
||||
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
|
||||
Err(e) => panic!("{}", e),
|
||||
@ -88,9 +92,7 @@ async fn scram_password_wrong() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn scram_password_ok() {
|
||||
connect("user=scram_user password=password dbname=postgres")
|
||||
.await
|
||||
.unwrap();
|
||||
connect("user=scram_user password=password dbname=postgres").await;
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -1,4 +1,4 @@
|
||||
use futures::{Future, Stream};
|
||||
use futures::{Future, FutureExt, Stream};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::runtime::current_thread::Runtime;
|
||||
use tokio::timer::Delay;
|
||||
@ -7,10 +7,10 @@ use tokio_postgres::NoTls;
|
||||
|
||||
async fn smoke_test(s: &str) {
|
||||
let (mut client, connection) = tokio_postgres::connect(s, NoTls).await.unwrap();
|
||||
/*
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
runtime.spawn(connection);
|
||||
let connection = connection.map(|e| e.unwrap());
|
||||
tokio::spawn(connection);
|
||||
|
||||
/*
|
||||
let execute = client.simple_query("SELECT 1").for_each(|_| Ok(()));
|
||||
runtime.block_on(execute).unwrap();
|
||||
*/
|
||||
|
Loading…
Reference in New Issue
Block a user