use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::{IsNull, ToSql}; use crate::{Error, Row, Statement}; use futures::{ready, Stream}; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; use std::future::Future; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; pub async fn query<'a, I>( client: Arc, statement: &Statement, params: I, ) -> Result where I: IntoIterator, I::IntoIter: ExactSizeIterator, { let responses = start(&client, &statement, params).await?; Ok(Query { statement: statement.clone(), responses, }) } pub async fn execute<'a, I>( client: Arc, statement: &Statement, params: I, ) -> Result where I: IntoIterator, I::IntoIter: ExactSizeIterator, { let mut responses = start(&client, &statement, params).await?; loop { match responses.next().await? { Message::DataRow(_) => {} Message::CommandComplete(body) => { let rows = body .tag() .map_err(Error::parse)? .rsplit(' ') .next() .unwrap() .parse() .unwrap_or(0); return Ok(rows); } Message::EmptyQueryResponse => return Ok(0), _ => return Err(Error::unexpected_message()), } } } async fn start<'a, I>( client: &Arc, statement: &Statement, params: I, ) -> Result where I: IntoIterator, I::IntoIter: ExactSizeIterator, { let params = params.into_iter(); assert!( statement.params().len() == params.len(), "expected {} parameters but got {}", statement.params().len(), params.len() ); let mut buf = vec![]; let mut error_idx = 0; let r = frontend::bind( "", statement.name(), Some(1), params.zip(statement.params()).enumerate(), |(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) { Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No), Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes), Err(e) => { error_idx = idx; Err(e) } }, Some(1), &mut buf, ); match r { Ok(()) => {} Err(frontend::BindError::Conversion(e)) => return Err(Error::to_sql(e, error_idx)), Err(frontend::BindError::Serialization(e)) => return Err(Error::encode(e)), } frontend::execute("", 0, &mut buf).map_err(Error::encode)?; frontend::sync(&mut buf); let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; match responses.next().await? { Message::BindComplete => {} _ => return Err(Error::unexpected_message()), } Ok(responses) } pub struct Query { statement: Statement, responses: Responses, } impl Stream for Query { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match ready!(self.responses.poll_next(cx)?) { Message::DataRow(body) => { Poll::Ready(Some(Ok(Row::new(self.statement.clone(), body)?))) } Message::EmptyQueryResponse | Message::CommandComplete(_) => Poll::Ready(None), Message::ErrorResponse(body) => Poll::Ready(Some(Err(Error::db(body)))), _ => Poll::Ready(Some(Err(Error::unexpected_message()))), } } }