use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::copy_out::CopyStream; use crate::query::RowStream; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; use crate::tls::TlsConnect; use crate::types::{ToSql, Type}; #[cfg(feature = "runtime")] use crate::Socket; use crate::{ bind, query, slice_iter, Client, Error, Portal, Row, SimpleQueryMessage, Statement, ToStatement, }; use bytes::IntoBuf; use futures::{TryStream, TryStreamExt}; use postgres_protocol::message::frontend; use std::error; use tokio::io::{AsyncRead, AsyncWrite}; /// A representation of a PostgreSQL database transaction. /// /// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the /// transaction. Transactions can be nested, with inner transactions implemented via safepoints. pub struct Transaction<'a> { client: &'a mut Client, depth: u32, done: bool, } impl<'a> Drop for Transaction<'a> { fn drop(&mut self) { if self.done { return; } let mut buf = vec![]; let query = if self.depth == 0 { "ROLLBACK".to_string() } else { format!("ROLLBACK TO sp{}", self.depth) }; frontend::query(&query, &mut buf).unwrap(); let _ = self .client .inner() .send(RequestMessages::Single(FrontendMessage::Raw(buf))); } } impl<'a> Transaction<'a> { pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> { Transaction { client, depth: 0, done: false, } } /// Consumes the transaction, committing all changes made within it. pub async fn commit(mut self) -> Result<(), Error> { self.done = true; let query = if self.depth == 0 { "COMMIT".to_string() } else { format!("RELEASE sp{}", self.depth) }; self.client.batch_execute(&query).await } /// Rolls the transaction back, discarding all changes made within it. /// /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller. pub async fn rollback(mut self) -> Result<(), Error> { self.done = true; let query = if self.depth == 0 { "ROLLBACK".to_string() } else { format!("ROLLBACK TO sp{}", self.depth) }; self.client.batch_execute(&query).await } /// Like `Client::prepare`. pub async fn prepare(&self, query: &str) -> Result { self.client.prepare(query).await } /// Like `Client::prepare_typed`. pub async fn prepare_typed( &self, query: &str, parameter_types: &[Type], ) -> Result { self.client.prepare_typed(query, parameter_types).await } /// Like `Client::query`. pub async fn query( &self, statement: &T, params: &[&(dyn ToSql + Sync)], ) -> Result, Error> where T: ?Sized + ToStatement, { self.client.query(statement, params).await } /// Like `Client::query_one`. pub async fn query_one( &self, statement: &T, params: &[&(dyn ToSql + Sync)], ) -> Result where T: ?Sized + ToStatement, { self.client.query_one(statement, params).await } /// Like `Client::query_raw`. pub async fn query_raw<'b, T, I>(&self, statement: &T, params: I) -> Result where T: ?Sized + ToStatement, I: IntoIterator, I::IntoIter: ExactSizeIterator, { self.client.query_raw(statement, params).await } /// Like `Client::execute`. pub async fn execute( &self, statement: &T, params: &[&(dyn ToSql + Sync)], ) -> Result where T: ?Sized + ToStatement, { self.client.execute(statement, params).await } /// Like `Client::execute_iter`. pub async fn execute_raw<'b, I, T>( &self, statement: &Statement, params: I, ) -> Result where T: ?Sized + ToStatement, I: IntoIterator, I::IntoIter: ExactSizeIterator, { self.client.execute_raw(statement, params).await } /// Binds a statement to a set of parameters, creating a `Portal` which can be incrementally queried. /// /// Portals only last for the duration of the transaction in which they are created, and can only be used on the /// connection that created them. /// /// # Panics /// /// Panics if the number of parameters provided does not match the number expected. pub async fn bind( &self, statement: &T, params: &[&(dyn ToSql + Sync)], ) -> Result where T: ?Sized + ToStatement, { self.bind_iter(statement, slice_iter(params)).await } /// Like [`bind`], but takes an iterator of parameters rather than a slice. /// /// [`bind`]: #method.bind pub async fn bind_iter<'b, T, I>(&self, statement: &T, params: I) -> Result where T: ?Sized + ToStatement, I: IntoIterator, I::IntoIter: ExactSizeIterator, { let statement = statement.__convert().into_statement(&self.client).await?; bind::bind(self.client.inner(), statement, params).await } /// Continues execution of a portal, returning a stream of the resulting rows. /// /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to /// `query_portal`. If the requested number is negative or 0, all rows will be returned. pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result, Error> { self.query_portal_raw(portal, max_rows) .await? .try_collect() .await } /// The maximally flexible version of `query_portal`. pub async fn query_portal_raw( &self, portal: &Portal, max_rows: i32, ) -> Result { query::query_portal(self.client.inner(), portal, max_rows).await } /// Like `Client::copy_in`. pub async fn copy_in( &self, statement: &T, params: &[&(dyn ToSql + Sync)], stream: S, ) -> Result where T: ?Sized + ToStatement, S: TryStream, S::Ok: IntoBuf, ::Buf: 'static + Send, S::Error: Into>, { self.client.copy_in(statement, params, stream).await } /// Like `Client::copy_out`. pub async fn copy_out( &self, statement: &T, params: &[&(dyn ToSql + Sync)], ) -> Result where T: ?Sized + ToStatement, { self.client.copy_out(statement, params).await } /// Like `Client::simple_query`. pub async fn simple_query( &self, query: &str, ) -> Result, Error> { self.client.simple_query(query).await } /// Like `Client::batch_execute`. pub async fn batch_execute(&self, query: &str) -> Result<(), Error> { self.client.batch_execute(query).await } /// Like `Client::cancel_query`. #[cfg(feature = "runtime")] pub async fn cancel_query(&self, tls: T) -> Result<(), Error> where T: MakeTlsConnect, { self.client.cancel_query(tls).await } /// Like `Client::cancel_query_raw`. pub async fn cancel_query_raw(&self, stream: S, tls: T) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { self.client.cancel_query_raw(stream, tls).await } /// Like `Client::transaction`. pub async fn transaction(&mut self) -> Result, Error> { let depth = self.depth + 1; let query = format!("SAVEPOINT sp{}", depth); self.batch_execute(&query).await?; Ok(Transaction { client: self.client, depth, done: false, }) } }