Re-add savepoint method to Transaction

Revives #184.

The rewrite for async/await and Tokio accidentally lost functionality
that allowed users to assign specific names to savepoints when using
nested transactions. This functionality had originally been added
in #184 and had been updated in #374.

This commit revives this functionality using a similar scheme to the
one that existed before. This should allow CockroachDB users to update
to the next patch release of version `0.17`.
This commit is contained in:
Nathan VanBenschoten 2020-05-01 12:55:48 -04:00
parent e3d3c6d5cd
commit 64d6e97eff
3 changed files with 97 additions and 16 deletions

View File

@ -151,6 +151,57 @@ fn nested_transactions() {
assert_eq!(rows[2].get::<_, i32>(0), 4);
}
#[test]
fn savepoints() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
client
.batch_execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)")
.unwrap();
let mut transaction = client.transaction().unwrap();
transaction
.execute("INSERT INTO foo (id) VALUES (1)", &[])
.unwrap();
let mut savepoint1 = transaction.savepoint("savepoint1").unwrap();
savepoint1
.execute("INSERT INTO foo (id) VALUES (2)", &[])
.unwrap();
savepoint1.rollback().unwrap();
let rows = transaction
.query("SELECT id FROM foo ORDER BY id", &[])
.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, i32>(0), 1);
let mut savepoint2 = transaction.savepoint("savepoint2").unwrap();
savepoint2
.execute("INSERT INTO foo (id) VALUES(3)", &[])
.unwrap();
let mut savepoint3 = savepoint2.savepoint("savepoint3").unwrap();
savepoint3
.execute("INSERT INTO foo (id) VALUES(4)", &[])
.unwrap();
savepoint3.commit().unwrap();
savepoint2.commit().unwrap();
transaction.commit().unwrap();
let rows = client.query("SELECT id FROM foo ORDER BY id", &[]).unwrap();
assert_eq!(rows.len(), 3);
assert_eq!(rows[0].get::<_, i32>(0), 1);
assert_eq!(rows[1].get::<_, i32>(0), 3);
assert_eq!(rows[2].get::<_, i32>(0), 4);
}
#[test]
fn copy_in() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();

View File

@ -173,7 +173,7 @@ impl<'a> Transaction<'a> {
CancelToken::new(self.transaction.cancel_token())
}
/// Like `Client::transaction`.
/// Like `Client::transaction`, but creates a nested transaction via a savepoint.
pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
let transaction = self.connection.block_on(self.transaction.transaction())?;
Ok(Transaction {
@ -181,4 +181,15 @@ impl<'a> Transaction<'a> {
transaction,
})
}
/// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name.
pub fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
where
I: Into<String>,
{
let transaction = self.connection.block_on(self.transaction.savepoint(name))?;
Ok(Transaction {
connection: self.connection.as_ref(),
transaction,
})
}
}

View File

@ -23,20 +23,26 @@ use tokio::io::{AsyncRead, AsyncWrite};
/// transaction. Transactions can be nested, with inner transactions implemented via safepoints.
pub struct Transaction<'a> {
client: &'a mut Client,
depth: u32,
savepoint: Option<Savepoint>,
done: bool,
}
/// A representation of a PostgreSQL database savepoint.
struct Savepoint {
name: String,
depth: u32,
}
impl<'a> Drop for Transaction<'a> {
fn drop(&mut self) {
if self.done {
return;
}
let query = if self.depth == 0 {
"ROLLBACK".to_string()
let query = if let Some(sp) = self.savepoint.as_ref() {
format!("ROLLBACK TO {}", sp.name)
} else {
format!("ROLLBACK TO sp{}", self.depth)
"ROLLBACK".to_string()
};
let buf = self.client.inner().with_buf(|buf| {
frontend::query(&query, buf).unwrap();
@ -53,7 +59,7 @@ impl<'a> Transaction<'a> {
pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
Transaction {
client,
depth: 0,
savepoint: None,
done: false,
}
}
@ -61,10 +67,10 @@ impl<'a> Transaction<'a> {
/// Consumes the transaction, committing all changes made within it.
pub async fn commit(mut self) -> Result<(), Error> {
self.done = true;
let query = if self.depth == 0 {
"COMMIT".to_string()
let query = if let Some(sp) = self.savepoint.as_ref() {
format!("RELEASE {}", sp.name)
} else {
format!("RELEASE sp{}", self.depth)
"COMMIT".to_string()
};
self.client.batch_execute(&query).await
}
@ -74,10 +80,10 @@ impl<'a> Transaction<'a> {
/// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
pub async fn rollback(mut self) -> Result<(), Error> {
self.done = true;
let query = if self.depth == 0 {
"ROLLBACK".to_string()
let query = if let Some(sp) = self.savepoint.as_ref() {
format!("ROLLBACK TO {}", sp.name)
} else {
format!("ROLLBACK TO sp{}", self.depth)
"ROLLBACK".to_string()
};
self.client.batch_execute(&query).await
}
@ -272,15 +278,28 @@ impl<'a> Transaction<'a> {
self.client.cancel_query_raw(stream, tls).await
}
/// Like `Client::transaction`.
/// Like `Client::transaction`, but creates a nested transaction via a savepoint.
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
let depth = self.depth + 1;
let query = format!("SAVEPOINT sp{}", depth);
self._savepoint(None).await
}
/// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name.
pub async fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
where
I: Into<String>,
{
self._savepoint(Some(name.into())).await
}
async fn _savepoint(&mut self, name: Option<String>) -> Result<Transaction<'_>, Error> {
let depth = self.savepoint.as_ref().map_or(0, |sp| sp.depth) + 1;
let name = name.unwrap_or_else(|| format!("sp_{}", depth));
let query = format!("SAVEPOINT {}", name);
self.batch_execute(&query).await?;
Ok(Transaction {
client: self.client,
depth,
savepoint: Some(Savepoint { name, depth }),
done: false,
})
}