diff --git a/postgres/src/client.rs b/postgres/src/client.rs index e4a1e382..bb53e0a7 100644 --- a/postgres/src/client.rs +++ b/postgres/src/client.rs @@ -319,6 +319,20 @@ impl Client { Ok(Iter::new(self.0.simple_query(query))) } + /// Executes a sequence of SQL statements using the simple query protocol. + /// + /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that + /// point. This is intended for use when, for example, initializing a database schema. + /// + /// # Warning + /// + /// Prepared statements should be use for any query which contains user-specified data, as they provided the + /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass + /// them to this method! + pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> { + executor::block_on(self.0.batch_execute(query)) + } + /// Begins a new database transaction. /// /// The transaction will roll back by default - use the `commit` method to commit it. diff --git a/postgres/src/test.rs b/postgres/src/test.rs index 06399a19..59953e4e 100644 --- a/postgres/src/test.rs +++ b/postgres/src/test.rs @@ -96,13 +96,12 @@ fn transaction_drop() { assert_eq!(rows.len(), 0); } -/* #[test] fn nested_transactions() { let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); client - .simple_query("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)") + .batch_execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)") .unwrap(); let mut transaction = client.transaction().unwrap(); @@ -147,7 +146,6 @@ fn nested_transactions() { assert_eq!(rows[1].get::<_, i32>(0), 3); assert_eq!(rows[2].get::<_, i32>(0), 4); } -*/ #[test] fn copy_in() { diff --git a/postgres/src/transaction.rs b/postgres/src/transaction.rs index fcf7ce78..65ac5176 100644 --- a/postgres/src/transaction.rs +++ b/postgres/src/transaction.rs @@ -150,15 +150,14 @@ impl<'a> Transaction<'a> { Ok(Iter::new(self.0.simple_query(query))) } - // /// Like `Client::transaction`. - // pub fn transaction(&mut self) -> Result, Error> { - // let depth = self.depth + 1; - // self.client - // .simple_query(&format!("SAVEPOINT sp{}", depth))?; - // Ok(Transaction { - // client: self.client, - // depth, - // done: false, - // }) - // } + /// Like `Client::batch_execute`. + pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> { + executor::block_on(self.0.batch_execute(query)) + } + + /// Like `Client::transaction`. + pub fn transaction(&mut self) -> Result, Error> { + let transaction = executor::block_on(self.0.transaction())?; + Ok(Transaction(transaction)) + } } diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index c1271d21..0489f09f 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -20,6 +20,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; /// transaction. Transactions can be nested, with inner transactions implemented via safepoints. pub struct Transaction<'a> { client: &'a mut Client, + depth: u32, done: bool, } @@ -30,7 +31,12 @@ impl<'a> Drop for Transaction<'a> { } let mut buf = vec![]; - frontend::query("ROLLBACK", &mut buf).unwrap(); + let query = if self.depth == 0 { + "ROLLBACK".to_string() + } else { + format!("ROLLBACK TO sp{}", self.depth) + }; + frontend::query(&query, &mut buf).unwrap(); let _ = self .client .inner() @@ -42,6 +48,7 @@ impl<'a> Transaction<'a> { pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> { Transaction { client, + depth: 0, done: false, } } @@ -49,7 +56,12 @@ impl<'a> Transaction<'a> { /// Consumes the transaction, committing all changes made within it. pub async fn commit(mut self) -> Result<(), Error> { self.done = true; - self.client.batch_execute("COMMIT").await + let query = if self.depth == 0 { + "COMMIT".to_string() + } else { + format!("RELEASE sp{}", self.depth) + }; + self.client.batch_execute(&query).await } /// Rolls the transaction back, discarding all changes made within it. @@ -57,7 +69,12 @@ impl<'a> Transaction<'a> { /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller. pub async fn rollback(mut self) -> Result<(), Error> { self.done = true; - self.client.batch_execute("ROLLBACK").await + let query = if self.depth == 0 { + "ROLLBACK".to_string() + } else { + format!("ROLLBACK TO sp{}", self.depth) + }; + self.client.batch_execute(&query).await } /// Like `Client::prepare`. @@ -227,4 +244,17 @@ impl<'a> Transaction<'a> { { self.client.cancel_query_raw(stream, tls) } + + /// Like `Client::transaction`. + pub async fn transaction(&mut self) -> Result, Error> { + let depth = self.depth + 1; + let query = format!("SAVEPOINT sp{}", depth); + self.batch_execute(&query).await?; + + Ok(Transaction { + client: self.client, + depth, + done: false, + }) + } }