Ensure transactions roll back immediately on drop

Closes #635
This commit is contained in:
Steven Fackler 2020-07-19 13:24:46 -06:00
parent 4fd7527c3c
commit a4a68d543d
2 changed files with 83 additions and 33 deletions

View File

@ -100,6 +100,31 @@ fn transaction_drop() {
assert_eq!(rows.len(), 0);
}
#[test]
fn transaction_drop_immediate_rollback() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
let mut client2 = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
client
.simple_query("CREATE TABLE IF NOT EXISTS foo (id SERIAL PRIMARY KEY)")
.unwrap();
client
.execute("INSERT INTO foo VALUES (1) ON CONFLICT DO NOTHING", &[])
.unwrap();
let mut transaction = client.transaction().unwrap();
transaction
.execute("SELECT * FROM foo FOR UPDATE", &[])
.unwrap();
drop(transaction);
let rows = client2.query("SELECT * FROM foo FOR UPDATE", &[]).unwrap();
assert_eq!(rows.len(), 1);
}
#[test]
fn nested_transactions() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();

View File

@ -9,7 +9,15 @@ use tokio_postgres::{Error, Row, SimpleQueryMessage};
/// in the transaction. Transactions can be nested, with inner transactions implemented via savepoints.
pub struct Transaction<'a> {
connection: ConnectionRef<'a>,
transaction: tokio_postgres::Transaction<'a>,
transaction: Option<tokio_postgres::Transaction<'a>>,
}
impl<'a> Drop for Transaction<'a> {
fn drop(&mut self) {
if let Some(transaction) = self.transaction.take() {
let _ = self.connection.block_on(transaction.rollback());
}
}
}
impl<'a> Transaction<'a> {
@ -19,31 +27,38 @@ impl<'a> Transaction<'a> {
) -> Transaction<'a> {
Transaction {
connection,
transaction,
transaction: Some(transaction),
}
}
/// Consumes the transaction, committing all changes made within it.
pub fn commit(mut self) -> Result<(), Error> {
self.connection.block_on(self.transaction.commit())
self.connection
.block_on(self.transaction.take().unwrap().commit())
}
/// Rolls the transaction back, discarding all changes made within it.
///
/// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
pub fn rollback(mut self) -> Result<(), Error> {
self.connection.block_on(self.transaction.rollback())
self.connection
.block_on(self.transaction.take().unwrap().rollback())
}
/// Like `Client::prepare`.
pub fn prepare(&mut self, query: &str) -> Result<Statement, Error> {
self.connection.block_on(self.transaction.prepare(query))
self.connection
.block_on(self.transaction.as_ref().unwrap().prepare(query))
}
/// Like `Client::prepare_typed`.
pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result<Statement, Error> {
self.connection
.block_on(self.transaction.prepare_typed(query, types))
self.connection.block_on(
self.transaction
.as_ref()
.unwrap()
.prepare_typed(query, types),
)
}
/// Like `Client::execute`.
@ -52,7 +67,7 @@ impl<'a> Transaction<'a> {
T: ?Sized + ToStatement,
{
self.connection
.block_on(self.transaction.execute(query, params))
.block_on(self.transaction.as_ref().unwrap().execute(query, params))
}
/// Like `Client::query`.
@ -61,7 +76,7 @@ impl<'a> Transaction<'a> {
T: ?Sized + ToStatement,
{
self.connection
.block_on(self.transaction.query(query, params))
.block_on(self.transaction.as_ref().unwrap().query(query, params))
}
/// Like `Client::query_one`.
@ -70,7 +85,7 @@ impl<'a> Transaction<'a> {
T: ?Sized + ToStatement,
{
self.connection
.block_on(self.transaction.query_one(query, params))
.block_on(self.transaction.as_ref().unwrap().query_one(query, params))
}
/// Like `Client::query_opt`.
@ -83,7 +98,7 @@ impl<'a> Transaction<'a> {
T: ?Sized + ToStatement,
{
self.connection
.block_on(self.transaction.query_opt(query, params))
.block_on(self.transaction.as_ref().unwrap().query_opt(query, params))
}
/// Like `Client::query_raw`.
@ -95,7 +110,7 @@ impl<'a> Transaction<'a> {
{
let stream = self
.connection
.block_on(self.transaction.query_raw(query, params))?;
.block_on(self.transaction.as_ref().unwrap().query_raw(query, params))?;
Ok(RowIter::new(self.connection.as_ref(), stream))
}
@ -114,7 +129,7 @@ impl<'a> Transaction<'a> {
T: ?Sized + ToStatement,
{
self.connection
.block_on(self.transaction.bind(query, params))
.block_on(self.transaction.as_ref().unwrap().bind(query, params))
}
/// Continues execution of a portal, returning the next set of rows.
@ -122,8 +137,12 @@ impl<'a> Transaction<'a> {
/// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
/// `query_portal`. If the requested number is negative or 0, all remaining rows will be returned.
pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
self.connection
.block_on(self.transaction.query_portal(portal, max_rows))
self.connection.block_on(
self.transaction
.as_ref()
.unwrap()
.query_portal(portal, max_rows),
)
}
/// The maximally flexible version of `query_portal`.
@ -132,9 +151,12 @@ impl<'a> Transaction<'a> {
portal: &Portal,
max_rows: i32,
) -> Result<RowIter<'_>, Error> {
let stream = self
.connection
.block_on(self.transaction.query_portal_raw(portal, max_rows))?;
let stream = self.connection.block_on(
self.transaction
.as_ref()
.unwrap()
.query_portal_raw(portal, max_rows),
)?;
Ok(RowIter::new(self.connection.as_ref(), stream))
}
@ -143,7 +165,9 @@ impl<'a> Transaction<'a> {
where
T: ?Sized + ToStatement,
{
let sink = self.connection.block_on(self.transaction.copy_in(query))?;
let sink = self
.connection
.block_on(self.transaction.as_ref().unwrap().copy_in(query))?;
Ok(CopyInWriter::new(self.connection.as_ref(), sink))
}
@ -152,44 +176,45 @@ impl<'a> Transaction<'a> {
where
T: ?Sized + ToStatement,
{
let stream = self.connection.block_on(self.transaction.copy_out(query))?;
let stream = self
.connection
.block_on(self.transaction.as_ref().unwrap().copy_out(query))?;
Ok(CopyOutReader::new(self.connection.as_ref(), stream))
}
/// Like `Client::simple_query`.
pub fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
self.connection
.block_on(self.transaction.simple_query(query))
.block_on(self.transaction.as_ref().unwrap().simple_query(query))
}
/// Like `Client::batch_execute`.
pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
self.connection
.block_on(self.transaction.batch_execute(query))
.block_on(self.transaction.as_ref().unwrap().batch_execute(query))
}
/// Like `Client::cancel_token`.
pub fn cancel_token(&self) -> CancelToken {
CancelToken::new(self.transaction.cancel_token())
CancelToken::new(self.transaction.as_ref().unwrap().cancel_token())
}
/// Like `Client::transaction`, but creates a nested transaction via a savepoint.
pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
let transaction = self.connection.block_on(self.transaction.transaction())?;
Ok(Transaction {
connection: self.connection.as_ref(),
transaction,
})
let transaction = self
.connection
.block_on(self.transaction.as_mut().unwrap().transaction())?;
Ok(Transaction::new(self.connection.as_ref(), transaction))
}
/// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name.
pub fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
where
I: Into<String>,
{
let transaction = self.connection.block_on(self.transaction.savepoint(name))?;
Ok(Transaction {
connection: self.connection.as_ref(),
transaction,
})
let transaction = self
.connection
.block_on(self.transaction.as_mut().unwrap().savepoint(name))?;
Ok(Transaction::new(self.connection.as_ref(), transaction))
}
}