Re-add savepoint method to Transaction
Revives #184. The rewrite for async/await and Tokio accidentally lost functionality that allowed users to assign specific names to savepoints when using nested transactions. This functionality had originally been added in #184 and had been updated in #374. This commit revives this functionality using a similar scheme to the one that existed before. This should allow CockroachDB users to update to the next patch release of version `0.17`.
This commit is contained in:
parent
e3d3c6d5cd
commit
64d6e97eff
@ -151,6 +151,57 @@ fn nested_transactions() {
|
||||
assert_eq!(rows[2].get::<_, i32>(0), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn savepoints() {
|
||||
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
|
||||
|
||||
client
|
||||
.batch_execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)")
|
||||
.unwrap();
|
||||
|
||||
let mut transaction = client.transaction().unwrap();
|
||||
|
||||
transaction
|
||||
.execute("INSERT INTO foo (id) VALUES (1)", &[])
|
||||
.unwrap();
|
||||
|
||||
let mut savepoint1 = transaction.savepoint("savepoint1").unwrap();
|
||||
|
||||
savepoint1
|
||||
.execute("INSERT INTO foo (id) VALUES (2)", &[])
|
||||
.unwrap();
|
||||
|
||||
savepoint1.rollback().unwrap();
|
||||
|
||||
let rows = transaction
|
||||
.query("SELECT id FROM foo ORDER BY id", &[])
|
||||
.unwrap();
|
||||
assert_eq!(rows.len(), 1);
|
||||
assert_eq!(rows[0].get::<_, i32>(0), 1);
|
||||
|
||||
let mut savepoint2 = transaction.savepoint("savepoint2").unwrap();
|
||||
|
||||
savepoint2
|
||||
.execute("INSERT INTO foo (id) VALUES(3)", &[])
|
||||
.unwrap();
|
||||
|
||||
let mut savepoint3 = savepoint2.savepoint("savepoint3").unwrap();
|
||||
|
||||
savepoint3
|
||||
.execute("INSERT INTO foo (id) VALUES(4)", &[])
|
||||
.unwrap();
|
||||
|
||||
savepoint3.commit().unwrap();
|
||||
savepoint2.commit().unwrap();
|
||||
transaction.commit().unwrap();
|
||||
|
||||
let rows = client.query("SELECT id FROM foo ORDER BY id", &[]).unwrap();
|
||||
assert_eq!(rows.len(), 3);
|
||||
assert_eq!(rows[0].get::<_, i32>(0), 1);
|
||||
assert_eq!(rows[1].get::<_, i32>(0), 3);
|
||||
assert_eq!(rows[2].get::<_, i32>(0), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn copy_in() {
|
||||
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
|
||||
|
@ -173,7 +173,7 @@ impl<'a> Transaction<'a> {
|
||||
CancelToken::new(self.transaction.cancel_token())
|
||||
}
|
||||
|
||||
/// Like `Client::transaction`.
|
||||
/// Like `Client::transaction`, but creates a nested transaction via a savepoint.
|
||||
pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
|
||||
let transaction = self.connection.block_on(self.transaction.transaction())?;
|
||||
Ok(Transaction {
|
||||
@ -181,4 +181,15 @@ impl<'a> Transaction<'a> {
|
||||
transaction,
|
||||
})
|
||||
}
|
||||
/// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name.
|
||||
pub fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
|
||||
where
|
||||
I: Into<String>,
|
||||
{
|
||||
let transaction = self.connection.block_on(self.transaction.savepoint(name))?;
|
||||
Ok(Transaction {
|
||||
connection: self.connection.as_ref(),
|
||||
transaction,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -23,20 +23,26 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
/// transaction. Transactions can be nested, with inner transactions implemented via safepoints.
|
||||
pub struct Transaction<'a> {
|
||||
client: &'a mut Client,
|
||||
depth: u32,
|
||||
savepoint: Option<Savepoint>,
|
||||
done: bool,
|
||||
}
|
||||
|
||||
/// A representation of a PostgreSQL database savepoint.
|
||||
struct Savepoint {
|
||||
name: String,
|
||||
depth: u32,
|
||||
}
|
||||
|
||||
impl<'a> Drop for Transaction<'a> {
|
||||
fn drop(&mut self) {
|
||||
if self.done {
|
||||
return;
|
||||
}
|
||||
|
||||
let query = if self.depth == 0 {
|
||||
"ROLLBACK".to_string()
|
||||
let query = if let Some(sp) = self.savepoint.as_ref() {
|
||||
format!("ROLLBACK TO {}", sp.name)
|
||||
} else {
|
||||
format!("ROLLBACK TO sp{}", self.depth)
|
||||
"ROLLBACK".to_string()
|
||||
};
|
||||
let buf = self.client.inner().with_buf(|buf| {
|
||||
frontend::query(&query, buf).unwrap();
|
||||
@ -53,7 +59,7 @@ impl<'a> Transaction<'a> {
|
||||
pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
|
||||
Transaction {
|
||||
client,
|
||||
depth: 0,
|
||||
savepoint: None,
|
||||
done: false,
|
||||
}
|
||||
}
|
||||
@ -61,10 +67,10 @@ impl<'a> Transaction<'a> {
|
||||
/// Consumes the transaction, committing all changes made within it.
|
||||
pub async fn commit(mut self) -> Result<(), Error> {
|
||||
self.done = true;
|
||||
let query = if self.depth == 0 {
|
||||
"COMMIT".to_string()
|
||||
let query = if let Some(sp) = self.savepoint.as_ref() {
|
||||
format!("RELEASE {}", sp.name)
|
||||
} else {
|
||||
format!("RELEASE sp{}", self.depth)
|
||||
"COMMIT".to_string()
|
||||
};
|
||||
self.client.batch_execute(&query).await
|
||||
}
|
||||
@ -74,10 +80,10 @@ impl<'a> Transaction<'a> {
|
||||
/// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
|
||||
pub async fn rollback(mut self) -> Result<(), Error> {
|
||||
self.done = true;
|
||||
let query = if self.depth == 0 {
|
||||
"ROLLBACK".to_string()
|
||||
let query = if let Some(sp) = self.savepoint.as_ref() {
|
||||
format!("ROLLBACK TO {}", sp.name)
|
||||
} else {
|
||||
format!("ROLLBACK TO sp{}", self.depth)
|
||||
"ROLLBACK".to_string()
|
||||
};
|
||||
self.client.batch_execute(&query).await
|
||||
}
|
||||
@ -272,15 +278,28 @@ impl<'a> Transaction<'a> {
|
||||
self.client.cancel_query_raw(stream, tls).await
|
||||
}
|
||||
|
||||
/// Like `Client::transaction`.
|
||||
/// Like `Client::transaction`, but creates a nested transaction via a savepoint.
|
||||
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
|
||||
let depth = self.depth + 1;
|
||||
let query = format!("SAVEPOINT sp{}", depth);
|
||||
self._savepoint(None).await
|
||||
}
|
||||
|
||||
/// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name.
|
||||
pub async fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
|
||||
where
|
||||
I: Into<String>,
|
||||
{
|
||||
self._savepoint(Some(name.into())).await
|
||||
}
|
||||
|
||||
async fn _savepoint(&mut self, name: Option<String>) -> Result<Transaction<'_>, Error> {
|
||||
let depth = self.savepoint.as_ref().map_or(0, |sp| sp.depth) + 1;
|
||||
let name = name.unwrap_or_else(|| format!("sp_{}", depth));
|
||||
let query = format!("SAVEPOINT {}", name);
|
||||
self.batch_execute(&query).await?;
|
||||
|
||||
Ok(Transaction {
|
||||
client: self.client,
|
||||
depth,
|
||||
savepoint: Some(Savepoint { name, depth }),
|
||||
done: false,
|
||||
})
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user