diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 7ccfe9b5..d07d5a2d 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -13,7 +13,7 @@ use postgres_protocol::authentication::sasl; use postgres_protocol::authentication::sasl::ScramSha256; use postgres_protocol::message::backend::{AuthenticationSaslBody, Message}; use postgres_protocol::message::frontend; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; @@ -23,6 +23,7 @@ use tokio_util::codec::Framed; pub struct StartupStream { inner: Framed, PostgresCodec>, buf: BackendMessages, + delayed: VecDeque, } impl Sink for StartupStream @@ -91,6 +92,7 @@ where let mut stream = StartupStream { inner: Framed::new(stream, PostgresCodec), buf: BackendMessages::empty(), + delayed: VecDeque::new(), }; startup(&mut stream, config).await?; @@ -99,7 +101,7 @@ where let (sender, receiver) = mpsc::unbounded(); let client = Client::new(sender, config.ssl_mode, process_id, secret_key); - let connection = Connection::new(stream.inner, parameters, receiver); + let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver); Ok((client, connection)) } @@ -332,7 +334,9 @@ where body.value().map_err(Error::parse)?.to_string(), ); } - Some(Message::NoticeResponse(_)) => {} + Some(msg @ Message::NoticeResponse(_)) => { + stream.delayed.push_back(BackendMessage::Async(msg)) + } Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)), Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), Some(_) => return Err(Error::unexpected_message()), diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs index 5b014428..ac186743 100644 --- a/tokio-postgres/src/connection.rs +++ b/tokio-postgres/src/connection.rs @@ -52,7 +52,7 @@ pub struct Connection { parameters: HashMap, receiver: mpsc::UnboundedReceiver, pending_request: Option, - pending_response: Option, + pending_responses: VecDeque, responses: VecDeque, state: State, } @@ -64,6 +64,7 @@ where { pub(crate) fn new( stream: Framed, PostgresCodec>, + pending_responses: VecDeque, parameters: HashMap, receiver: mpsc::UnboundedReceiver, ) -> Connection { @@ -72,7 +73,7 @@ where parameters, receiver, pending_request: None, - pending_response: None, + pending_responses, responses: VecDeque::new(), state: State::Active, } @@ -82,7 +83,7 @@ where &mut self, cx: &mut Context<'_>, ) -> Poll>> { - if let Some(message) = self.pending_response.take() { + if let Some(message) = self.pending_responses.pop_front() { trace!("retrying pending response"); return Poll::Ready(Some(Ok(message))); } @@ -158,7 +159,7 @@ where } Poll::Pending => { self.responses.push_front(response); - self.pending_response = Some(BackendMessage::Normal { + self.pending_responses.push_back(BackendMessage::Normal { messages, request_complete, }); diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 92f1edce..73860115 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -570,6 +570,45 @@ async fn copy_out() { assert_eq!(&data[..], b"1\tjim\n2\tjoe\n"); } +#[tokio::test] +async fn notices() { + let long_name = "x".repeat(65); + let (client, mut connection) = + connect_raw(&format!("user=postgres application_name={}", long_name,)) + .await + .unwrap(); + + let (tx, rx) = mpsc::unbounded(); + let stream = stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!(e)); + let connection = stream.forward(tx).map(|r| r.unwrap()); + tokio::spawn(connection); + + client + .batch_execute("DROP DATABASE IF EXISTS noexistdb") + .await + .unwrap(); + + drop(client); + + let notices = rx + .filter_map(|m| match m { + AsyncMessage::Notice(n) => future::ready(Some(n)), + _ => future::ready(None), + }) + .collect::>() + .await; + assert_eq!(notices.len(), 2); + assert_eq!( + notices[0].message(), + "identifier \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" \ + will be truncated to \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\"" + ); + assert_eq!( + notices[1].message(), + "database \"noexistdb\" does not exist, skipping" + ); +} + #[tokio::test] async fn notifications() { let (client, mut connection) = connect_raw("user=postgres").await.unwrap();