rust-postgres/tokio-postgres/src/transaction.rs
Steven Fackler 0736382593 Fix tests
2019-10-12 17:47:55 -07:00

280 lines
8.2 KiB
Rust

use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::copy_out::CopyStream;
use crate::query::RowStream;
#[cfg(feature = "runtime")]
use crate::tls::MakeTlsConnect;
use crate::tls::TlsConnect;
use crate::types::{ToSql, Type};
#[cfg(feature = "runtime")]
use crate::Socket;
use crate::{
bind, query, slice_iter, Client, Error, Portal, Row, SimpleQueryMessage, Statement, ToStatement,
};
use bytes::IntoBuf;
use futures::{TryStream, TryStreamExt};
use postgres_protocol::message::frontend;
use postgres_types::private::BytesMut;
use std::error;
use tokio::io::{AsyncRead, AsyncWrite};
/// A representation of a PostgreSQL database transaction.
///
/// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the
/// transaction. Transactions can be nested, with inner transactions implemented via savepoints.
///
/// The implicit rollback on drop is best-effort: the rollback message is handed to the connection without waiting
/// for a response, and any failure is discarded. Use the `rollback` method to observe errors.
pub struct Transaction<'a> {
    /// The client the transaction runs on. It is exclusively borrowed for the transaction's lifetime, so the client
    /// cannot be used directly until the transaction is finished.
    client: &'a mut Client,
    /// Nesting level: `0` for the outermost transaction, `n` for a nested transaction backed by savepoint `spn`.
    depth: u32,
    /// Set once the transaction has been explicitly committed or rolled back; suppresses the rollback in `Drop`.
    done: bool,
}
impl<'a> Drop for Transaction<'a> {
    /// Rolls back a transaction that was neither committed nor explicitly rolled back.
    ///
    /// The outermost transaction is ended with `ROLLBACK`; a nested one rolls back to its own savepoint with
    /// `ROLLBACK TO spN`. The message is handed to the connection via `inner().send` without awaiting a response,
    /// and any send error is discarded because `drop` has no way to report it.
    fn drop(&mut self) {
        if !self.done {
            let rollback = match self.depth {
                0 => String::from("ROLLBACK"),
                level => format!("ROLLBACK TO sp{}", level),
            };
            let mut message = BytesMut::new();
            // The statement text is a fixed, well-formed string, so encoding is not expected to fail here.
            frontend::query(&rollback, &mut message).unwrap();
            let request = RequestMessages::Single(FrontendMessage::Raw(message.freeze()));
            let _ = self.client.inner().send(request);
        }
    }
}
impl<'a> Transaction<'a> {
    /// Wraps `client` in a top-level (depth 0) transaction.
    ///
    /// This does not issue any statement itself. NOTE(review): presumably the caller (`Client::transaction`) has
    /// already started the transaction on the server; confirm.
    pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
        Transaction {
            client,
            depth: 0,
            done: false,
        }
    }

    /// Builds the statement that ends this transaction level: `top_level` for the outermost transaction, or
    /// `"<savepoint_command> spN"` for a nested transaction at depth `N`.
    fn finish_query(&self, top_level: &str, savepoint_command: &str) -> String {
        if self.depth == 0 {
            top_level.to_string()
        } else {
            format!("{} sp{}", savepoint_command, self.depth)
        }
    }

    /// Consumes the transaction, committing all changes made within it.
    ///
    /// A top-level transaction issues `COMMIT`; a nested transaction issues `RELEASE spN` for the savepoint created
    /// by [`transaction`](#method.transaction).
    ///
    /// The transaction is marked as finished before the statement is sent, so no rollback is attempted on drop even
    /// if this returns an error.
    pub async fn commit(mut self) -> Result<(), Error> {
        self.done = true;
        let query = self.finish_query("COMMIT", "RELEASE");
        self.client.batch_execute(&query).await
    }

    /// Rolls the transaction back, discarding all changes made within it.
    ///
    /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
    /// A top-level transaction issues `ROLLBACK`; a nested transaction issues `ROLLBACK TO spN`.
    pub async fn rollback(mut self) -> Result<(), Error> {
        self.done = true;
        let query = self.finish_query("ROLLBACK", "ROLLBACK TO");
        self.client.batch_execute(&query).await
    }

    /// Like `Client::prepare`.
    pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
        self.client.prepare(query).await
    }

    /// Like `Client::prepare_typed`.
    pub async fn prepare_typed(
        &self,
        query: &str,
        parameter_types: &[Type],
    ) -> Result<Statement, Error> {
        self.client.prepare_typed(query, parameter_types).await
    }

    /// Like `Client::query`.
    pub async fn query<T>(
        &self,
        statement: &T,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<Vec<Row>, Error>
    where
        T: ?Sized + ToStatement,
    {
        self.client.query(statement, params).await
    }

    /// Like `Client::query_one`.
    pub async fn query_one<T>(
        &self,
        statement: &T,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<Row, Error>
    where
        T: ?Sized + ToStatement,
    {
        self.client.query_one(statement, params).await
    }

    /// Like `Client::query_raw`.
    pub async fn query_raw<'b, T, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
    where
        T: ?Sized + ToStatement,
        I: IntoIterator<Item = &'b dyn ToSql>,
        I::IntoIter: ExactSizeIterator,
    {
        self.client.query_raw(statement, params).await
    }

    /// Like `Client::execute`.
    pub async fn execute<T>(
        &self,
        statement: &T,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<u64, Error>
    where
        T: ?Sized + ToStatement,
    {
        self.client.execute(statement, params).await
    }

    /// Like `Client::execute_raw`.
    ///
    /// `statement` may be a prepared `Statement` or any other `ToStatement` value; it is converted to a `Statement`
    /// before execution.
    pub async fn execute_raw<'b, I, T>(&self, statement: &T, params: I) -> Result<u64, Error>
    where
        T: ?Sized + ToStatement,
        I: IntoIterator<Item = &'b dyn ToSql>,
        I::IntoIter: ExactSizeIterator,
    {
        // `T` used to be unconnected to any argument, forcing callers to name it with a turbofish. Taking `&T` and
        // converting here (as `bind_iter` does) lets it be inferred, while `&Statement` arguments still work.
        let statement = statement.__convert().into_statement(&self.client).await?;
        self.client.execute_raw(&statement, params).await
    }

    /// Binds a statement to a set of parameters, creating a `Portal` which can be incrementally queried.
    ///
    /// Portals only last for the duration of the transaction in which they are created, and can only be used on the
    /// connection that created them.
    ///
    /// # Panics
    ///
    /// Panics if the number of parameters provided does not match the number expected.
    pub async fn bind<T>(
        &self,
        statement: &T,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<Portal, Error>
    where
        T: ?Sized + ToStatement,
    {
        self.bind_iter(statement, slice_iter(params)).await
    }

    /// Like [`bind`], but takes an iterator of parameters rather than a slice.
    ///
    /// [`bind`]: #method.bind
    pub async fn bind_iter<'b, T, I>(&self, statement: &T, params: I) -> Result<Portal, Error>
    where
        T: ?Sized + ToStatement,
        I: IntoIterator<Item = &'b dyn ToSql>,
        I::IntoIter: ExactSizeIterator,
    {
        let statement = statement.__convert().into_statement(&self.client).await?;
        bind::bind(self.client.inner(), statement, params).await
    }

    /// Continues execution of a portal, returning a stream of the resulting rows.
    ///
    /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
    /// `query_portal`. If the requested number is negative or 0, all rows will be returned.
    pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
        self.query_portal_raw(portal, max_rows)
            .await?
            .try_collect()
            .await
    }

    /// The maximally flexible version of `query_portal`.
    pub async fn query_portal_raw(
        &self,
        portal: &Portal,
        max_rows: i32,
    ) -> Result<RowStream, Error> {
        query::query_portal(self.client.inner(), portal, max_rows).await
    }

    /// Like `Client::copy_in`.
    pub async fn copy_in<T, S>(
        &self,
        statement: &T,
        params: &[&(dyn ToSql + Sync)],
        stream: S,
    ) -> Result<u64, Error>
    where
        T: ?Sized + ToStatement,
        S: TryStream,
        S::Ok: IntoBuf,
        <S::Ok as IntoBuf>::Buf: 'static + Send,
        S::Error: Into<Box<dyn error::Error + Sync + Send>>,
    {
        self.client.copy_in(statement, params, stream).await
    }

    /// Like `Client::copy_out`.
    pub async fn copy_out<T>(
        &self,
        statement: &T,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<CopyStream, Error>
    where
        T: ?Sized + ToStatement,
    {
        self.client.copy_out(statement, params).await
    }

    /// Like `Client::simple_query`.
    pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
        self.client.simple_query(query).await
    }

    /// Like `Client::batch_execute`.
    pub async fn batch_execute(&self, query: &str) -> Result<(), Error> {
        self.client.batch_execute(query).await
    }

    /// Like `Client::cancel_query`.
    #[cfg(feature = "runtime")]
    pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
    where
        T: MakeTlsConnect<Socket>,
    {
        self.client.cancel_query(tls).await
    }

    /// Like `Client::cancel_query_raw`.
    pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
    where
        S: AsyncRead + AsyncWrite + Unpin,
        T: TlsConnect<S>,
    {
        self.client.cancel_query_raw(stream, tls).await
    }

    /// Like `Client::transaction`, but creates a nested transaction backed by a savepoint.
    ///
    /// Issues `SAVEPOINT spN`, where `N` is one more than this transaction's depth. The returned transaction
    /// exclusively borrows this one until it is committed, rolled back, or dropped.
    pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
        let depth = self.depth + 1;
        let query = format!("SAVEPOINT sp{}", depth);
        self.batch_execute(&query).await?;
        Ok(Transaction {
            client: self.client,
            depth,
            done: false,
        })
    }
}