Merge pull request #564 from benesch/startup-notices

Don't suppress notices during startup flow
This commit is contained in:
Steven Fackler 2020-01-31 18:24:09 -05:00 committed by GitHub
commit 2ce4f08f46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 51 additions and 7 deletions

View File

@ -13,7 +13,7 @@ use postgres_protocol::authentication::sasl;
use postgres_protocol::authentication::sasl::ScramSha256;
use postgres_protocol::message::backend::{AuthenticationSaslBody, Message};
use postgres_protocol::message::frontend;
use std::collections::HashMap;
use std::collections::{HashMap, VecDeque};
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
@ -23,6 +23,7 @@ use tokio_util::codec::Framed;
pub struct StartupStream<S, T> {
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
buf: BackendMessages,
delayed: VecDeque<BackendMessage>,
}
impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
@ -91,6 +92,7 @@ where
let mut stream = StartupStream {
inner: Framed::new(stream, PostgresCodec),
buf: BackendMessages::empty(),
delayed: VecDeque::new(),
};
startup(&mut stream, config).await?;
@ -99,7 +101,7 @@ where
let (sender, receiver) = mpsc::unbounded();
let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
let connection = Connection::new(stream.inner, parameters, receiver);
let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);
Ok((client, connection))
}
@ -332,7 +334,9 @@ where
body.value().map_err(Error::parse)?.to_string(),
);
}
Some(Message::NoticeResponse(_)) => {}
Some(msg @ Message::NoticeResponse(_)) => {
stream.delayed.push_back(BackendMessage::Async(msg))
}
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),

View File

@ -52,7 +52,7 @@ pub struct Connection<S, T> {
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
pending_request: Option<RequestMessages>,
pending_response: Option<BackendMessage>,
pending_responses: VecDeque<BackendMessage>,
responses: VecDeque<Response>,
state: State,
}
@ -64,6 +64,7 @@ where
{
pub(crate) fn new(
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pending_responses: VecDeque<BackendMessage>,
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
) -> Connection<S, T> {
@ -72,7 +73,7 @@ where
parameters,
receiver,
pending_request: None,
pending_response: None,
pending_responses,
responses: VecDeque::new(),
state: State::Active,
}
@ -82,7 +83,7 @@ where
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<BackendMessage, Error>>> {
if let Some(message) = self.pending_response.take() {
if let Some(message) = self.pending_responses.pop_front() {
trace!("retrying pending response");
return Poll::Ready(Some(Ok(message)));
}
@ -158,7 +159,7 @@ where
}
Poll::Pending => {
self.responses.push_front(response);
self.pending_response = Some(BackendMessage::Normal {
self.pending_responses.push_back(BackendMessage::Normal {
messages,
request_complete,
});

View File

@ -570,6 +570,45 @@ async fn copy_out() {
assert_eq!(&data[..], b"1\tjim\n2\tjoe\n");
}
#[tokio::test]
async fn notices() {
let long_name = "x".repeat(65);
let (client, mut connection) =
connect_raw(&format!("user=postgres application_name={}", long_name,))
.await
.unwrap();
let (tx, rx) = mpsc::unbounded();
let stream = stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!(e));
let connection = stream.forward(tx).map(|r| r.unwrap());
tokio::spawn(connection);
client
.batch_execute("DROP DATABASE IF EXISTS noexistdb")
.await
.unwrap();
drop(client);
let notices = rx
.filter_map(|m| match m {
AsyncMessage::Notice(n) => future::ready(Some(n)),
_ => future::ready(None),
})
.collect::<Vec<_>>()
.await;
assert_eq!(notices.len(), 2);
assert_eq!(
notices[0].message(),
"identifier \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" \
will be truncated to \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\""
);
assert_eq!(
notices[1].message(),
"database \"noexistdb\" does not exist, skipping"
);
}
#[tokio::test]
async fn notifications() {
let (client, mut connection) = connect_raw("user=postgres").await.unwrap();