Don't suppress notices during startup flow
NoticeResponses received during the startup flow were previously being dropped on the floor. Instead stash them away so they can be delivered to the user after the startup flow is complete.
This commit is contained in:
parent
5429a79997
commit
7ea1b2d785
@ -13,7 +13,7 @@ use postgres_protocol::authentication::sasl;
|
||||
use postgres_protocol::authentication::sasl::ScramSha256;
|
||||
use postgres_protocol::message::backend::{AuthenticationSaslBody, Message};
|
||||
use postgres_protocol::message::frontend;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
@ -23,6 +23,7 @@ use tokio_util::codec::Framed;
|
||||
pub struct StartupStream<S, T> {
|
||||
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
buf: BackendMessages,
|
||||
delayed: VecDeque<BackendMessage>,
|
||||
}
|
||||
|
||||
impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
|
||||
@ -91,6 +92,7 @@ where
|
||||
let mut stream = StartupStream {
|
||||
inner: Framed::new(stream, PostgresCodec),
|
||||
buf: BackendMessages::empty(),
|
||||
delayed: VecDeque::new(),
|
||||
};
|
||||
|
||||
startup(&mut stream, config).await?;
|
||||
@ -99,7 +101,7 @@ where
|
||||
|
||||
let (sender, receiver) = mpsc::unbounded();
|
||||
let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
|
||||
let connection = Connection::new(stream.inner, parameters, receiver);
|
||||
let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);
|
||||
|
||||
Ok((client, connection))
|
||||
}
|
||||
@ -332,7 +334,9 @@ where
|
||||
body.value().map_err(Error::parse)?.to_string(),
|
||||
);
|
||||
}
|
||||
Some(Message::NoticeResponse(_)) => {}
|
||||
Some(msg @ Message::NoticeResponse(_)) => {
|
||||
stream.delayed.push_back(BackendMessage::Async(msg))
|
||||
}
|
||||
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
|
||||
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
|
||||
Some(_) => return Err(Error::unexpected_message()),
|
||||
|
@ -52,7 +52,7 @@ pub struct Connection<S, T> {
|
||||
parameters: HashMap<String, String>,
|
||||
receiver: mpsc::UnboundedReceiver<Request>,
|
||||
pending_request: Option<RequestMessages>,
|
||||
pending_response: Option<BackendMessage>,
|
||||
pending_responses: VecDeque<BackendMessage>,
|
||||
responses: VecDeque<Response>,
|
||||
state: State,
|
||||
}
|
||||
@ -64,6 +64,7 @@ where
|
||||
{
|
||||
pub(crate) fn new(
|
||||
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
pending_responses: VecDeque<BackendMessage>,
|
||||
parameters: HashMap<String, String>,
|
||||
receiver: mpsc::UnboundedReceiver<Request>,
|
||||
) -> Connection<S, T> {
|
||||
@ -72,7 +73,7 @@ where
|
||||
parameters,
|
||||
receiver,
|
||||
pending_request: None,
|
||||
pending_response: None,
|
||||
pending_responses,
|
||||
responses: VecDeque::new(),
|
||||
state: State::Active,
|
||||
}
|
||||
@ -82,7 +83,7 @@ where
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<BackendMessage, Error>>> {
|
||||
if let Some(message) = self.pending_response.take() {
|
||||
if let Some(message) = self.pending_responses.pop_front() {
|
||||
trace!("retrying pending response");
|
||||
return Poll::Ready(Some(Ok(message)));
|
||||
}
|
||||
@ -158,7 +159,7 @@ where
|
||||
}
|
||||
Poll::Pending => {
|
||||
self.responses.push_front(response);
|
||||
self.pending_response = Some(BackendMessage::Normal {
|
||||
self.pending_responses.push_back(BackendMessage::Normal {
|
||||
messages,
|
||||
request_complete,
|
||||
});
|
||||
|
@ -570,6 +570,45 @@ async fn copy_out() {
|
||||
assert_eq!(&data[..], b"1\tjim\n2\tjoe\n");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn notices() {
|
||||
let long_name = "x".repeat(65);
|
||||
let (client, mut connection) =
|
||||
connect_raw(&format!("user=postgres application_name={}", long_name,))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
let stream = stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!(e));
|
||||
let connection = stream.forward(tx).map(|r| r.unwrap());
|
||||
tokio::spawn(connection);
|
||||
|
||||
client
|
||||
.batch_execute("DROP DATABASE IF EXISTS noexistdb")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
drop(client);
|
||||
|
||||
let notices = rx
|
||||
.filter_map(|m| match m {
|
||||
AsyncMessage::Notice(n) => future::ready(Some(n)),
|
||||
_ => future::ready(None),
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.await;
|
||||
assert_eq!(notices.len(), 2);
|
||||
assert_eq!(
|
||||
notices[0].message(),
|
||||
"identifier \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" \
|
||||
will be truncated to \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\""
|
||||
);
|
||||
assert_eq!(
|
||||
notices[1].message(),
|
||||
"database \"noexistdb\" does not exist, skipping"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn notifications() {
|
||||
let (client, mut connection) = connect_raw("user=postgres").await.unwrap();
|
||||
|
Loading…
Reference in New Issue
Block a user