Don't suppress notices during startup flow
NoticeResponses received during the startup flow were previously being dropped on the floor. Instead stash them away so they can be delivered to the user after the startup flow is complete.
This commit is contained in:
parent
5429a79997
commit
7ea1b2d785
@ -13,7 +13,7 @@ use postgres_protocol::authentication::sasl;
|
|||||||
use postgres_protocol::authentication::sasl::ScramSha256;
|
use postgres_protocol::authentication::sasl::ScramSha256;
|
||||||
use postgres_protocol::message::backend::{AuthenticationSaslBody, Message};
|
use postgres_protocol::message::backend::{AuthenticationSaslBody, Message};
|
||||||
use postgres_protocol::message::frontend;
|
use postgres_protocol::message::frontend;
|
||||||
use std::collections::HashMap;
|
use std::collections::{HashMap, VecDeque};
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
@ -23,6 +23,7 @@ use tokio_util::codec::Framed;
|
|||||||
pub struct StartupStream<S, T> {
|
pub struct StartupStream<S, T> {
|
||||||
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||||
buf: BackendMessages,
|
buf: BackendMessages,
|
||||||
|
delayed: VecDeque<BackendMessage>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
|
impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
|
||||||
@ -91,6 +92,7 @@ where
|
|||||||
let mut stream = StartupStream {
|
let mut stream = StartupStream {
|
||||||
inner: Framed::new(stream, PostgresCodec),
|
inner: Framed::new(stream, PostgresCodec),
|
||||||
buf: BackendMessages::empty(),
|
buf: BackendMessages::empty(),
|
||||||
|
delayed: VecDeque::new(),
|
||||||
};
|
};
|
||||||
|
|
||||||
startup(&mut stream, config).await?;
|
startup(&mut stream, config).await?;
|
||||||
@ -99,7 +101,7 @@ where
|
|||||||
|
|
||||||
let (sender, receiver) = mpsc::unbounded();
|
let (sender, receiver) = mpsc::unbounded();
|
||||||
let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
|
let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
|
||||||
let connection = Connection::new(stream.inner, parameters, receiver);
|
let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);
|
||||||
|
|
||||||
Ok((client, connection))
|
Ok((client, connection))
|
||||||
}
|
}
|
||||||
@ -332,7 +334,9 @@ where
|
|||||||
body.value().map_err(Error::parse)?.to_string(),
|
body.value().map_err(Error::parse)?.to_string(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
Some(Message::NoticeResponse(_)) => {}
|
Some(msg @ Message::NoticeResponse(_)) => {
|
||||||
|
stream.delayed.push_back(BackendMessage::Async(msg))
|
||||||
|
}
|
||||||
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
|
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
|
||||||
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
|
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
|
||||||
Some(_) => return Err(Error::unexpected_message()),
|
Some(_) => return Err(Error::unexpected_message()),
|
||||||
|
@ -52,7 +52,7 @@ pub struct Connection<S, T> {
|
|||||||
parameters: HashMap<String, String>,
|
parameters: HashMap<String, String>,
|
||||||
receiver: mpsc::UnboundedReceiver<Request>,
|
receiver: mpsc::UnboundedReceiver<Request>,
|
||||||
pending_request: Option<RequestMessages>,
|
pending_request: Option<RequestMessages>,
|
||||||
pending_response: Option<BackendMessage>,
|
pending_responses: VecDeque<BackendMessage>,
|
||||||
responses: VecDeque<Response>,
|
responses: VecDeque<Response>,
|
||||||
state: State,
|
state: State,
|
||||||
}
|
}
|
||||||
@ -64,6 +64,7 @@ where
|
|||||||
{
|
{
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||||
|
pending_responses: VecDeque<BackendMessage>,
|
||||||
parameters: HashMap<String, String>,
|
parameters: HashMap<String, String>,
|
||||||
receiver: mpsc::UnboundedReceiver<Request>,
|
receiver: mpsc::UnboundedReceiver<Request>,
|
||||||
) -> Connection<S, T> {
|
) -> Connection<S, T> {
|
||||||
@ -72,7 +73,7 @@ where
|
|||||||
parameters,
|
parameters,
|
||||||
receiver,
|
receiver,
|
||||||
pending_request: None,
|
pending_request: None,
|
||||||
pending_response: None,
|
pending_responses,
|
||||||
responses: VecDeque::new(),
|
responses: VecDeque::new(),
|
||||||
state: State::Active,
|
state: State::Active,
|
||||||
}
|
}
|
||||||
@ -82,7 +83,7 @@ where
|
|||||||
&mut self,
|
&mut self,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
) -> Poll<Option<Result<BackendMessage, Error>>> {
|
) -> Poll<Option<Result<BackendMessage, Error>>> {
|
||||||
if let Some(message) = self.pending_response.take() {
|
if let Some(message) = self.pending_responses.pop_front() {
|
||||||
trace!("retrying pending response");
|
trace!("retrying pending response");
|
||||||
return Poll::Ready(Some(Ok(message)));
|
return Poll::Ready(Some(Ok(message)));
|
||||||
}
|
}
|
||||||
@ -158,7 +159,7 @@ where
|
|||||||
}
|
}
|
||||||
Poll::Pending => {
|
Poll::Pending => {
|
||||||
self.responses.push_front(response);
|
self.responses.push_front(response);
|
||||||
self.pending_response = Some(BackendMessage::Normal {
|
self.pending_responses.push_back(BackendMessage::Normal {
|
||||||
messages,
|
messages,
|
||||||
request_complete,
|
request_complete,
|
||||||
});
|
});
|
||||||
|
@ -570,6 +570,45 @@ async fn copy_out() {
|
|||||||
assert_eq!(&data[..], b"1\tjim\n2\tjoe\n");
|
assert_eq!(&data[..], b"1\tjim\n2\tjoe\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn notices() {
|
||||||
|
let long_name = "x".repeat(65);
|
||||||
|
let (client, mut connection) =
|
||||||
|
connect_raw(&format!("user=postgres application_name={}", long_name,))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let (tx, rx) = mpsc::unbounded();
|
||||||
|
let stream = stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!(e));
|
||||||
|
let connection = stream.forward(tx).map(|r| r.unwrap());
|
||||||
|
tokio::spawn(connection);
|
||||||
|
|
||||||
|
client
|
||||||
|
.batch_execute("DROP DATABASE IF EXISTS noexistdb")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
drop(client);
|
||||||
|
|
||||||
|
let notices = rx
|
||||||
|
.filter_map(|m| match m {
|
||||||
|
AsyncMessage::Notice(n) => future::ready(Some(n)),
|
||||||
|
_ => future::ready(None),
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.await;
|
||||||
|
assert_eq!(notices.len(), 2);
|
||||||
|
assert_eq!(
|
||||||
|
notices[0].message(),
|
||||||
|
"identifier \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" \
|
||||||
|
will be truncated to \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\""
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
notices[1].message(),
|
||||||
|
"database \"noexistdb\" does not exist, skipping"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn notifications() {
|
async fn notifications() {
|
||||||
let (client, mut connection) = connect_raw("user=postgres").await.unwrap();
|
let (client, mut connection) = connect_raw("user=postgres").await.unwrap();
|
||||||
|
Loading…
Reference in New Issue
Block a user