use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::copy_out::CopyStream; use crate::query::RowStream; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; use crate::tls::TlsConnect; use crate::types::{ToSql, Type}; #[cfg(feature = "runtime")] use crate::Socket; use crate::{ bind, query, slice_iter, Client, Error, Portal, Row, SimpleQueryMessage, Statement, ToStatement, }; use bytes::IntoBuf; use futures::{TryStream, TryStreamExt}; use postgres_protocol::message::frontend; use postgres_types::private::BytesMut; use std::error; use tokio::io::{AsyncRead, AsyncWrite}; /// A representation of a PostgreSQL database transaction. /// /// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the /// transaction. Transactions can be nested, with inner transactions implemented via safepoints. pub struct Transaction<'a> { client: &'a mut Client, depth: u32, done: bool, } impl<'a> Drop for Transaction<'a> { fn drop(&mut self) { if self.done { return; } let mut buf = BytesMut::new(); let query = if self.depth == 0 { "ROLLBACK".to_string() } else { format!("ROLLBACK TO sp{}", self.depth) }; frontend::query(&query, &mut buf).unwrap(); let _ = self .client .inner() .send(RequestMessages::Single(FrontendMessage::Raw(buf.freeze()))); } } impl<'a> Transaction<'a> { pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> { Transaction { client, depth: 0, done: false, } } /// Consumes the transaction, committing all changes made within it. pub async fn commit(mut self) -> Result<(), Error> { self.done = true; let query = if self.depth == 0 { "COMMIT".to_string() } else { format!("RELEASE sp{}", self.depth) }; self.client.batch_execute(&query).await } /// Rolls the transaction back, discarding all changes made within it. /// /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller. pub async fn rollback(mut self) -> Result<(), Error> { self.done = true; let query = if self.depth == 0 { "ROLLBACK".to_string() } else { format!("ROLLBACK TO sp{}", self.depth) }; self.client.batch_execute(&query).await } /// Like `Client::prepare`. pub async fn prepare(&self, query: &str) -> Result<Statement, Error> { self.client.prepare(query).await } /// Like `Client::prepare_typed`. pub async fn prepare_typed( &self, query: &str, parameter_types: &[Type], ) -> Result<Statement, Error> { self.client.prepare_typed(query, parameter_types).await } /// Like `Client::query`. pub async fn query<T>( &self, statement: &T, params: &[&(dyn ToSql + Sync)], ) -> Result<Vec<Row>, Error> where T: ?Sized + ToStatement, { self.client.query(statement, params).await } /// Like `Client::query_one`. pub async fn query_one<T>( &self, statement: &T, params: &[&(dyn ToSql + Sync)], ) -> Result<Row, Error> where T: ?Sized + ToStatement, { self.client.query_one(statement, params).await } /// Like `Client::query_raw`. pub async fn query_raw<'b, T, I>(&self, statement: &T, params: I) -> Result<RowStream, Error> where T: ?Sized + ToStatement, I: IntoIterator<Item = &'b dyn ToSql>, I::IntoIter: ExactSizeIterator, { self.client.query_raw(statement, params).await } /// Like `Client::execute`. pub async fn execute<T>( &self, statement: &T, params: &[&(dyn ToSql + Sync)], ) -> Result<u64, Error> where T: ?Sized + ToStatement, { self.client.execute(statement, params).await } /// Like `Client::execute_iter`. pub async fn execute_raw<'b, I, T>( &self, statement: &Statement, params: I, ) -> Result<u64, Error> where T: ?Sized + ToStatement, I: IntoIterator<Item = &'b dyn ToSql>, I::IntoIter: ExactSizeIterator, { self.client.execute_raw(statement, params).await } /// Binds a statement to a set of parameters, creating a `Portal` which can be incrementally queried. /// /// Portals only last for the duration of the transaction in which they are created, and can only be used on the /// connection that created them. /// /// # Panics /// /// Panics if the number of parameters provided does not match the number expected. pub async fn bind<T>( &self, statement: &T, params: &[&(dyn ToSql + Sync)], ) -> Result<Portal, Error> where T: ?Sized + ToStatement, { self.bind_iter(statement, slice_iter(params)).await } /// Like [`bind`], but takes an iterator of parameters rather than a slice. /// /// [`bind`]: #method.bind pub async fn bind_iter<'b, T, I>(&self, statement: &T, params: I) -> Result<Portal, Error> where T: ?Sized + ToStatement, I: IntoIterator<Item = &'b dyn ToSql>, I::IntoIter: ExactSizeIterator, { let statement = statement.__convert().into_statement(&self.client).await?; bind::bind(self.client.inner(), statement, params).await } /// Continues execution of a portal, returning a stream of the resulting rows. /// /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to /// `query_portal`. If the requested number is negative or 0, all rows will be returned. pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> { self.query_portal_raw(portal, max_rows) .await? .try_collect() .await } /// The maximally flexible version of `query_portal`. pub async fn query_portal_raw( &self, portal: &Portal, max_rows: i32, ) -> Result<RowStream, Error> { query::query_portal(self.client.inner(), portal, max_rows).await } /// Like `Client::copy_in`. pub async fn copy_in<T, S>( &self, statement: &T, params: &[&(dyn ToSql + Sync)], stream: S, ) -> Result<u64, Error> where T: ?Sized + ToStatement, S: TryStream, S::Ok: IntoBuf, <S::Ok as IntoBuf>::Buf: 'static + Send, S::Error: Into<Box<dyn error::Error + Sync + Send>>, { self.client.copy_in(statement, params, stream).await } /// Like `Client::copy_out`. pub async fn copy_out<T>( &self, statement: &T, params: &[&(dyn ToSql + Sync)], ) -> Result<CopyStream, Error> where T: ?Sized + ToStatement, { self.client.copy_out(statement, params).await } /// Like `Client::simple_query`. pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> { self.client.simple_query(query).await } /// Like `Client::batch_execute`. pub async fn batch_execute(&self, query: &str) -> Result<(), Error> { self.client.batch_execute(query).await } /// Like `Client::cancel_query`. #[cfg(feature = "runtime")] pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error> where T: MakeTlsConnect<Socket>, { self.client.cancel_query(tls).await } /// Like `Client::cancel_query_raw`. pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect<S>, { self.client.cancel_query_raw(stream, tls).await } /// Like `Client::transaction`. pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> { let depth = self.depth + 1; let query = format!("SAVEPOINT sp{}", depth); self.batch_execute(&query).await?; Ok(Transaction { client: self.client, depth, done: false, }) } }