use crate::client::{InnerClient, Responses}; use pin_project::pin_project; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::{IsNull, ToSql}; use crate::{Error, Portal, Row, Statement}; use bytes::{Bytes, BytesMut}; use futures::{ready, Stream}; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; use std::pin::Pin; use std::task::{Context, Poll}; use std::marker::PhantomPinned; pub async fn query<'a, I>( client: &InnerClient, statement: Statement, params: I, ) -> Result where I: IntoIterator, I::IntoIter: ExactSizeIterator, { let buf = encode(client, &statement, params)?; let responses = start(client, buf).await?; Ok(RowStream { statement, responses, _p: PhantomPinned, }) } pub async fn query_portal( client: &InnerClient, portal: &Portal, max_rows: i32, ) -> Result { let buf = client.with_buf(|buf| { frontend::execute(portal.name(), max_rows, buf).map_err(Error::encode)?; frontend::sync(buf); Ok(buf.take().freeze()) })?; let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; Ok(RowStream { statement: portal.statement().clone(), responses, _p: PhantomPinned, }) } pub async fn execute<'a, I>( client: &InnerClient, statement: Statement, params: I, ) -> Result where I: IntoIterator, I::IntoIter: ExactSizeIterator, { let buf = encode(client, &statement, params)?; let mut responses = start(client, buf).await?; loop { match responses.next().await? { Message::DataRow(_) => {} Message::CommandComplete(body) => { let rows = body .tag() .map_err(Error::parse)? .rsplit(' ') .next() .unwrap() .parse() .unwrap_or(0); return Ok(rows); } Message::EmptyQueryResponse => return Ok(0), _ => return Err(Error::unexpected_message()), } } } async fn start(client: &InnerClient, buf: Bytes) -> Result { let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; match responses.next().await? { Message::BindComplete => {} _ => return Err(Error::unexpected_message()), } Ok(responses) } pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result where I: IntoIterator, I::IntoIter: ExactSizeIterator, { client.with_buf(|buf| { encode_bind(statement, params, "", buf)?; frontend::execute("", 0, buf).map_err(Error::encode)?; frontend::sync(buf); Ok(buf.take().freeze()) }) } pub fn encode_bind<'a, I>( statement: &Statement, params: I, portal: &str, buf: &mut BytesMut, ) -> Result<(), Error> where I: IntoIterator, I::IntoIter: ExactSizeIterator, { let params = params.into_iter(); assert!( statement.params().len() == params.len(), "expected {} parameters but got {}", statement.params().len(), params.len() ); let mut error_idx = 0; let r = frontend::bind( portal, statement.name(), Some(1), params.zip(statement.params()).enumerate(), |(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) { Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No), Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes), Err(e) => { error_idx = idx; Err(e) } }, Some(1), buf, ); match r { Ok(()) => Ok(()), Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)), Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), } } /// A stream of table rows. #[pin_project] pub struct RowStream { statement: Statement, responses: Responses, #[pin] _p: PhantomPinned, } impl Stream for RowStream { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); match ready!(this.responses.poll_next(cx)?) { Message::DataRow(body) => { Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?))) } Message::EmptyQueryResponse | Message::CommandComplete(_) | Message::PortalSuspended => Poll::Ready(None), Message::ErrorResponse(body) => Poll::Ready(Some(Err(Error::db(body)))), _ => Poll::Ready(Some(Err(Error::unexpected_message()))), } } }