Fix transaction not being rolled back on Client::transaction()
Future
dropped before completion
This commit is contained in:
parent
0adcf58555
commit
f6189a95f2
@ -1,4 +1,4 @@
|
||||
use crate::codec::BackendMessages;
|
||||
use crate::codec::{BackendMessages, FrontendMessage};
|
||||
use crate::config::{Host, SslMode};
|
||||
use crate::connection::{Request, RequestMessages};
|
||||
use crate::copy_out::CopyOutStream;
|
||||
@ -19,7 +19,7 @@ use fallible_iterator::FallibleIterator;
|
||||
use futures::channel::mpsc;
|
||||
use futures::{future, pin_mut, ready, StreamExt, TryStreamExt};
|
||||
use parking_lot::Mutex;
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::{backend::Message, frontend};
|
||||
use postgres_types::BorrowToSql;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
@ -488,7 +488,42 @@ impl Client {
|
||||
///
|
||||
/// The transaction will roll back by default - use the `commit` method to commit it.
|
||||
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
|
||||
self.batch_execute("BEGIN").await?;
|
||||
struct RollbackIfNotDone<'me> {
|
||||
client: &'me Client,
|
||||
done: bool,
|
||||
}
|
||||
|
||||
impl<'a> Drop for RollbackIfNotDone<'a> {
|
||||
fn drop(&mut self) {
|
||||
if self.done {
|
||||
return;
|
||||
}
|
||||
|
||||
let buf = self.client.inner().with_buf(|buf| {
|
||||
frontend::query("ROLLBACK", buf).unwrap();
|
||||
buf.split().freeze()
|
||||
});
|
||||
let _ = self
|
||||
.client
|
||||
.inner()
|
||||
.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
|
||||
}
|
||||
}
|
||||
|
||||
// This is done, as `Future` created by this method can be dropped after
|
||||
// `RequestMessages` is synchronously send to the `Connection` by
|
||||
// `batch_execute()`, but before `Responses` is asynchronously polled to
|
||||
// completion. In that case `Transaction` won't be created and thus
|
||||
// won't be rolled back.
|
||||
{
|
||||
let mut cleaner = RollbackIfNotDone {
|
||||
client: self,
|
||||
done: false,
|
||||
};
|
||||
self.batch_execute("BEGIN").await?;
|
||||
cleaner.done = true;
|
||||
}
|
||||
|
||||
Ok(Transaction::new(self))
|
||||
}
|
||||
|
||||
|
@ -3,9 +3,12 @@
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use futures::channel::mpsc;
|
||||
use futures::{
|
||||
future, join, pin_mut, stream, try_join, FutureExt, SinkExt, StreamExt, TryStreamExt,
|
||||
future, join, pin_mut, stream, try_join, Future, FutureExt, SinkExt, StreamExt, TryStreamExt,
|
||||
};
|
||||
use pin_project_lite::pin_project;
|
||||
use std::fmt::Write;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time;
|
||||
@ -22,6 +25,35 @@ mod parse;
|
||||
mod runtime;
|
||||
mod types;
|
||||
|
||||
pin_project! {
|
||||
/// Polls `F` at most `polls_left` times returning `Some(F::Output)` if
|
||||
/// [`Future`] returned [`Poll::Ready`] or [`None`] otherwise.
|
||||
struct Cancellable<F> {
|
||||
#[pin]
|
||||
fut: F,
|
||||
polls_left: usize,
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Future> Future for Cancellable<F> {
|
||||
type Output = Option<F::Output>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = self.project();
|
||||
match this.fut.poll(ctx) {
|
||||
Poll::Ready(r) => Poll::Ready(Some(r)),
|
||||
Poll::Pending => {
|
||||
*this.polls_left = this.polls_left.saturating_sub(1);
|
||||
if *this.polls_left == 0 {
|
||||
Poll::Ready(None)
|
||||
} else {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_raw(s: &str) -> Result<(Client, Connection<TcpStream, NoTlsStream>), Error> {
|
||||
let socket = TcpStream::connect("127.0.0.1:5433").await.unwrap();
|
||||
let config = s.parse::<Config>().unwrap();
|
||||
@ -35,6 +67,20 @@ async fn connect(s: &str) -> Client {
|
||||
client
|
||||
}
|
||||
|
||||
async fn current_transaction_id(client: &Client) -> i64 {
|
||||
client
|
||||
.query("SELECT txid_current()", &[])
|
||||
.await
|
||||
.unwrap()
|
||||
.pop()
|
||||
.unwrap()
|
||||
.get::<_, i64>("txid_current")
|
||||
}
|
||||
|
||||
async fn in_transaction(client: &Client) -> bool {
|
||||
current_transaction_id(client).await == current_transaction_id(client).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn plain_password_missing() {
|
||||
connect_raw("user=pass_user dbname=postgres")
|
||||
@ -377,6 +423,80 @@ async fn transaction_rollback() {
|
||||
assert_eq!(rows.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transaction_future_cancellation() {
|
||||
let mut client = connect("user=postgres").await;
|
||||
|
||||
for i in 0.. {
|
||||
let done = {
|
||||
let txn = client.transaction();
|
||||
let fut = Cancellable {
|
||||
fut: txn,
|
||||
polls_left: i,
|
||||
};
|
||||
fut.await
|
||||
.map(|res| res.expect("transaction failed"))
|
||||
.is_some()
|
||||
};
|
||||
|
||||
assert!(!in_transaction(&client).await);
|
||||
|
||||
if done {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transaction_commit_future_cancellation() {
|
||||
let mut client = connect("user=postgres").await;
|
||||
|
||||
for i in 0.. {
|
||||
let done = {
|
||||
let txn = client.transaction().await.unwrap();
|
||||
let commit = txn.commit();
|
||||
let fut = Cancellable {
|
||||
fut: commit,
|
||||
polls_left: i,
|
||||
};
|
||||
fut.await
|
||||
.map(|res| res.expect("transaction failed"))
|
||||
.is_some()
|
||||
};
|
||||
|
||||
assert!(!in_transaction(&client).await);
|
||||
|
||||
if done {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transaction_rollback_future_cancellation() {
|
||||
let mut client = connect("user=postgres").await;
|
||||
|
||||
for i in 0.. {
|
||||
let done = {
|
||||
let txn = client.transaction().await.unwrap();
|
||||
let rollback = txn.rollback();
|
||||
let fut = Cancellable {
|
||||
fut: rollback,
|
||||
polls_left: i,
|
||||
};
|
||||
fut.await
|
||||
.map(|res| res.expect("transaction failed"))
|
||||
.is_some()
|
||||
};
|
||||
|
||||
assert!(!in_transaction(&client).await);
|
||||
|
||||
if done {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn transaction_rollback_drop() {
|
||||
let mut client = connect("user=postgres").await;
|
||||
|
Loading…
Reference in New Issue
Block a user