From b2df11579f8b49728d3096b6bd6da0b7ab27ccf0 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Tue, 19 Oct 2021 19:36:14 -0400 Subject: [PATCH] Fix commit-time error reporting Closes #832 --- tokio-postgres/src/query.rs | 25 ++++++++++++++----------- tokio-postgres/tests/test/main.rs | 26 ++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index f139ed91..cdb95219 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -99,11 +99,12 @@ where }; let mut responses = start(client, buf).await?; + let mut rows = 0; loop { match responses.next().await? { Message::DataRow(_) => {} Message::CommandComplete(body) => { - let rows = body + rows = body .tag() .map_err(Error::parse)? .rsplit(' ') @@ -111,9 +112,9 @@ where .unwrap() .parse() .unwrap_or(0); - return Ok(rows); } - Message::EmptyQueryResponse => return Ok(0), + Message::EmptyQueryResponse => rows = 0, + Message::ReadyForQuery(_) => return Ok(rows), _ => return Err(Error::unexpected_message()), } } @@ -203,15 +204,17 @@ impl Stream for RowStream { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); - match ready!(this.responses.poll_next(cx)?) { - Message::DataRow(body) => { - Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?))) + loop { + match ready!(this.responses.poll_next(cx)?) { + Message::DataRow(body) => { + return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?))) + } + Message::EmptyQueryResponse + | Message::CommandComplete(_) + | Message::PortalSuspended => {} + Message::ReadyForQuery(_) => return Poll::Ready(None), + _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), } - Message::EmptyQueryResponse - | Message::CommandComplete(_) - | Message::PortalSuspended => Poll::Ready(None), - Message::ErrorResponse(body) => Poll::Ready(Some(Err(Error::db(body)))), - _ => Poll::Ready(Some(Err(Error::unexpected_message()))), } } } diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index c0b4bf20..31d7fa29 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -805,3 +805,29 @@ async fn query_opt() { .err() .unwrap(); } + +#[tokio::test] +async fn deferred_constraint() { + let client = connect("user=postgres").await; + + client + .batch_execute( + " + CREATE TEMPORARY TABLE t ( + i INT, + UNIQUE (i) DEFERRABLE INITIALLY DEFERRED + ); + ", + ) + .await + .unwrap(); + + client + .execute("INSERT INTO t (i) VALUES (1)", &[]) + .await + .unwrap(); + client + .execute("INSERT INTO t (i) VALUES (1)", &[]) + .await + .unwrap_err(); +}