Don't suppress notices during startup flow

NoticeResponses received during the startup flow were previously being
dropped on the floor. Instead stash them away so they can be delivered
to the user after the startup flow is complete.
This commit is contained in:
Nikhil Benesch 2020-01-31 00:03:31 -05:00
parent 5429a79997
commit 7ea1b2d785
No known key found for this signature in database
GPG Key ID: FCF98542083C5A69
3 changed files with 51 additions and 7 deletions

View File

@ -13,7 +13,7 @@ use postgres_protocol::authentication::sasl;
use postgres_protocol::authentication::sasl::ScramSha256;
use postgres_protocol::message::backend::{AuthenticationSaslBody, Message};
use postgres_protocol::message::frontend;
use std::collections::HashMap;
use std::collections::{HashMap, VecDeque};
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
@ -23,6 +23,7 @@ use tokio_util::codec::Framed;
pub struct StartupStream<S, T> {
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
buf: BackendMessages,
delayed: VecDeque<BackendMessage>,
}
impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
@ -91,6 +92,7 @@ where
let mut stream = StartupStream {
inner: Framed::new(stream, PostgresCodec),
buf: BackendMessages::empty(),
delayed: VecDeque::new(),
};
startup(&mut stream, config).await?;
@ -99,7 +101,7 @@ where
let (sender, receiver) = mpsc::unbounded();
let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
let connection = Connection::new(stream.inner, parameters, receiver);
let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);
Ok((client, connection))
}
@ -332,7 +334,9 @@ where
body.value().map_err(Error::parse)?.to_string(),
);
}
Some(Message::NoticeResponse(_)) => {}
Some(msg @ Message::NoticeResponse(_)) => {
stream.delayed.push_back(BackendMessage::Async(msg))
}
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),

View File

@ -52,7 +52,7 @@ pub struct Connection<S, T> {
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
pending_request: Option<RequestMessages>,
pending_response: Option<BackendMessage>,
pending_responses: VecDeque<BackendMessage>,
responses: VecDeque<Response>,
state: State,
}
@ -64,6 +64,7 @@ where
{
pub(crate) fn new(
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pending_responses: VecDeque<BackendMessage>,
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
) -> Connection<S, T> {
@ -72,7 +73,7 @@ where
parameters,
receiver,
pending_request: None,
pending_response: None,
pending_responses,
responses: VecDeque::new(),
state: State::Active,
}
@ -82,7 +83,7 @@ where
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<BackendMessage, Error>>> {
if let Some(message) = self.pending_response.take() {
if let Some(message) = self.pending_responses.pop_front() {
trace!("retrying pending response");
return Poll::Ready(Some(Ok(message)));
}
@ -158,7 +159,7 @@ where
}
Poll::Pending => {
self.responses.push_front(response);
self.pending_response = Some(BackendMessage::Normal {
self.pending_responses.push_back(BackendMessage::Normal {
messages,
request_complete,
});

View File

@ -570,6 +570,45 @@ async fn copy_out() {
assert_eq!(&data[..], b"1\tjim\n2\tjoe\n");
}
#[tokio::test]
async fn notices() {
let long_name = "x".repeat(65);
let (client, mut connection) =
connect_raw(&format!("user=postgres application_name={}", long_name,))
.await
.unwrap();
let (tx, rx) = mpsc::unbounded();
let stream = stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!(e));
let connection = stream.forward(tx).map(|r| r.unwrap());
tokio::spawn(connection);
client
.batch_execute("DROP DATABASE IF EXISTS noexistdb")
.await
.unwrap();
drop(client);
let notices = rx
.filter_map(|m| match m {
AsyncMessage::Notice(n) => future::ready(Some(n)),
_ => future::ready(None),
})
.collect::<Vec<_>>()
.await;
assert_eq!(notices.len(), 2);
assert_eq!(
notices[0].message(),
"identifier \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" \
will be truncated to \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\""
);
assert_eq!(
notices[1].message(),
"database \"noexistdb\" does not exist, skipping"
);
}
#[tokio::test]
async fn notifications() {
let (client, mut connection) = connect_raw("user=postgres").await.unwrap();