Merge pull request #985 from risinglightdb/skyzh/raw-stream

feat: add `rows_affected` to RowStream
This commit is contained in:
Steven Fackler 2023-01-19 21:30:34 -05:00 committed by GitHub
commit 4bae134f50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 28 deletions

View File

@ -1,6 +1,7 @@
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::query::extract_row_affected;
use crate::{query, slice_iter, Error, Statement};
use bytes::{Buf, BufMut, BytesMut};
use futures_channel::mpsc;
@ -110,14 +111,7 @@ where
let this = self.as_mut().project();
match ready!(this.responses.poll_next(cx))? {
Message::CommandComplete(body) => {
let rows = body
.tag()
.map_err(Error::parse)?
.rsplit(' ')
.next()
.unwrap()
.parse()
.unwrap_or(0);
let rows = extract_row_affected(&body)?;
return Poll::Ready(Ok(rows));
}
_ => return Poll::Ready(Err(Error::unexpected_message())),

View File

@ -7,7 +7,7 @@ use bytes::{Bytes, BytesMut};
use futures_util::{ready, Stream};
use log::{debug, log_enabled, Level};
use pin_project_lite::pin_project;
use postgres_protocol::message::backend::Message;
use postgres_protocol::message::backend::{CommandCompleteBody, Message};
use postgres_protocol::message::frontend;
use std::fmt;
use std::marker::PhantomPinned;
@ -52,6 +52,7 @@ where
Ok(RowStream {
statement,
responses,
rows_affected: None,
_p: PhantomPinned,
})
}
@ -72,10 +73,24 @@ pub async fn query_portal(
Ok(RowStream {
statement: portal.statement().clone(),
responses,
rows_affected: None,
_p: PhantomPinned,
})
}
/// Extract the number of rows affected from [`CommandCompleteBody`].
pub fn extract_row_affected(body: &CommandCompleteBody) -> Result<u64, Error> {
let rows = body
.tag()
.map_err(Error::parse)?
.rsplit(' ')
.next()
.unwrap()
.parse()
.unwrap_or(0);
Ok(rows)
}
pub async fn execute<P, I>(
client: &InnerClient,
statement: Statement,
@ -104,14 +119,7 @@ where
match responses.next().await? {
Message::DataRow(_) => {}
Message::CommandComplete(body) => {
rows = body
.tag()
.map_err(Error::parse)?
.rsplit(' ')
.next()
.unwrap()
.parse()
.unwrap_or(0);
rows = extract_row_affected(&body)?;
}
Message::EmptyQueryResponse => rows = 0,
Message::ReadyForQuery(_) => return Ok(rows),
@ -202,6 +210,7 @@ pin_project! {
pub struct RowStream {
statement: Statement,
responses: Responses,
rows_affected: Option<u64>,
#[pin]
_p: PhantomPinned,
}
@ -217,12 +226,22 @@ impl Stream for RowStream {
Message::DataRow(body) => {
return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)))
}
Message::EmptyQueryResponse
| Message::CommandComplete(_)
| Message::PortalSuspended => {}
Message::CommandComplete(body) => {
*this.rows_affected = Some(extract_row_affected(&body)?);
}
Message::EmptyQueryResponse | Message::PortalSuspended => {}
Message::ReadyForQuery(_) => return Poll::Ready(None),
_ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
}
}
}
}
impl RowStream {
/// Returns the number of rows affected by the query.
///
/// This function will return `None` until the stream has been exhausted.
pub fn rows_affected(&self) -> Option<u64> {
self.rows_affected
}
}

View File

@ -1,6 +1,7 @@
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::query::extract_row_affected;
use crate::{Error, SimpleQueryMessage, SimpleQueryRow};
use bytes::Bytes;
use fallible_iterator::FallibleIterator;
@ -87,14 +88,7 @@ impl Stream for SimpleQueryStream {
loop {
match ready!(this.responses.poll_next(cx)?) {
Message::CommandComplete(body) => {
let rows = body
.tag()
.map_err(Error::parse)?
.rsplit(' ')
.next()
.unwrap()
.parse()
.unwrap_or(0);
let rows = extract_row_affected(&body)?;
return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(rows))));
}
Message::EmptyQueryResponse => {