Merge pull request #985 from risinglightdb/skyzh/raw-stream
feat: add `rows_affected` to RowStream
This commit is contained in:
commit
4bae134f50
@ -1,6 +1,7 @@
|
||||
use crate::client::{InnerClient, Responses};
|
||||
use crate::codec::FrontendMessage;
|
||||
use crate::connection::RequestMessages;
|
||||
use crate::query::extract_row_affected;
|
||||
use crate::{query, slice_iter, Error, Statement};
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use futures_channel::mpsc;
|
||||
@ -110,14 +111,7 @@ where
|
||||
let this = self.as_mut().project();
|
||||
match ready!(this.responses.poll_next(cx))? {
|
||||
Message::CommandComplete(body) => {
|
||||
let rows = body
|
||||
.tag()
|
||||
.map_err(Error::parse)?
|
||||
.rsplit(' ')
|
||||
.next()
|
||||
.unwrap()
|
||||
.parse()
|
||||
.unwrap_or(0);
|
||||
let rows = extract_row_affected(&body)?;
|
||||
return Poll::Ready(Ok(rows));
|
||||
}
|
||||
_ => return Poll::Ready(Err(Error::unexpected_message())),
|
||||
|
@ -7,7 +7,7 @@ use bytes::{Bytes, BytesMut};
|
||||
use futures_util::{ready, Stream};
|
||||
use log::{debug, log_enabled, Level};
|
||||
use pin_project_lite::pin_project;
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::backend::{CommandCompleteBody, Message};
|
||||
use postgres_protocol::message::frontend;
|
||||
use std::fmt;
|
||||
use std::marker::PhantomPinned;
|
||||
@ -52,6 +52,7 @@ where
|
||||
Ok(RowStream {
|
||||
statement,
|
||||
responses,
|
||||
rows_affected: None,
|
||||
_p: PhantomPinned,
|
||||
})
|
||||
}
|
||||
@ -72,10 +73,24 @@ pub async fn query_portal(
|
||||
Ok(RowStream {
|
||||
statement: portal.statement().clone(),
|
||||
responses,
|
||||
rows_affected: None,
|
||||
_p: PhantomPinned,
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract the number of rows affected from [`CommandCompleteBody`].
|
||||
pub fn extract_row_affected(body: &CommandCompleteBody) -> Result<u64, Error> {
|
||||
let rows = body
|
||||
.tag()
|
||||
.map_err(Error::parse)?
|
||||
.rsplit(' ')
|
||||
.next()
|
||||
.unwrap()
|
||||
.parse()
|
||||
.unwrap_or(0);
|
||||
Ok(rows)
|
||||
}
|
||||
|
||||
pub async fn execute<P, I>(
|
||||
client: &InnerClient,
|
||||
statement: Statement,
|
||||
@ -104,14 +119,7 @@ where
|
||||
match responses.next().await? {
|
||||
Message::DataRow(_) => {}
|
||||
Message::CommandComplete(body) => {
|
||||
rows = body
|
||||
.tag()
|
||||
.map_err(Error::parse)?
|
||||
.rsplit(' ')
|
||||
.next()
|
||||
.unwrap()
|
||||
.parse()
|
||||
.unwrap_or(0);
|
||||
rows = extract_row_affected(&body)?;
|
||||
}
|
||||
Message::EmptyQueryResponse => rows = 0,
|
||||
Message::ReadyForQuery(_) => return Ok(rows),
|
||||
@ -202,6 +210,7 @@ pin_project! {
|
||||
pub struct RowStream {
|
||||
statement: Statement,
|
||||
responses: Responses,
|
||||
rows_affected: Option<u64>,
|
||||
#[pin]
|
||||
_p: PhantomPinned,
|
||||
}
|
||||
@ -217,12 +226,22 @@ impl Stream for RowStream {
|
||||
Message::DataRow(body) => {
|
||||
return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)))
|
||||
}
|
||||
Message::EmptyQueryResponse
|
||||
| Message::CommandComplete(_)
|
||||
| Message::PortalSuspended => {}
|
||||
Message::CommandComplete(body) => {
|
||||
*this.rows_affected = Some(extract_row_affected(&body)?);
|
||||
}
|
||||
Message::EmptyQueryResponse | Message::PortalSuspended => {}
|
||||
Message::ReadyForQuery(_) => return Poll::Ready(None),
|
||||
_ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RowStream {
|
||||
/// Returns the number of rows affected by the query.
|
||||
///
|
||||
/// This function will return `None` until the stream has been exhausted.
|
||||
pub fn rows_affected(&self) -> Option<u64> {
|
||||
self.rows_affected
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
use crate::client::{InnerClient, Responses};
|
||||
use crate::codec::FrontendMessage;
|
||||
use crate::connection::RequestMessages;
|
||||
use crate::query::extract_row_affected;
|
||||
use crate::{Error, SimpleQueryMessage, SimpleQueryRow};
|
||||
use bytes::Bytes;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
@ -87,14 +88,7 @@ impl Stream for SimpleQueryStream {
|
||||
loop {
|
||||
match ready!(this.responses.poll_next(cx)?) {
|
||||
Message::CommandComplete(body) => {
|
||||
let rows = body
|
||||
.tag()
|
||||
.map_err(Error::parse)?
|
||||
.rsplit(' ')
|
||||
.next()
|
||||
.unwrap()
|
||||
.parse()
|
||||
.unwrap_or(0);
|
||||
let rows = extract_row_affected(&body)?;
|
||||
return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(rows))));
|
||||
}
|
||||
Message::EmptyQueryResponse => {
|
||||
|
Loading…
Reference in New Issue
Block a user