Merge pull request #588 from sfackler/notifications

Add a notification API to the blocking client
This commit is contained in:
Steven Fackler 2020-03-23 07:29:32 -04:00 committed by GitHub
commit 70ca1b4fa0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 508 additions and 147 deletions

View File

@ -22,7 +22,7 @@ version: 2
jobs:
build:
docker:
- image: rust:1.40.0
- image: rust:1.41.0
environment:
RUSTFLAGS: -D warnings
- image: sfackler/rust-postgres-test:6

View File

@ -35,7 +35,7 @@ fallible-iterator = "0.2"
futures = "0.3"
tokio-postgres = { version = "0.5.3", path = "../tokio-postgres" }
tokio = { version = "0.2", features = ["rt-core"] }
tokio = { version = "0.2", features = ["rt-core", "time"] }
log = "0.4"
[dev-dependencies]

View File

@ -1,7 +1,8 @@
//! Utilities for working with the PostgreSQL binary copy format.
use crate::connection::ConnectionRef;
use crate::types::{ToSql, Type};
use crate::{CopyInWriter, CopyOutReader, Error, Rt};
use crate::{CopyInWriter, CopyOutReader, Error};
use fallible_iterator::FallibleIterator;
use futures::StreamExt;
use std::pin::Pin;
@ -13,7 +14,7 @@ use tokio_postgres::binary_copy::{self, BinaryCopyOutStream};
///
/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
pub struct BinaryCopyInWriter<'a> {
runtime: Rt<'a>,
connection: ConnectionRef<'a>,
sink: Pin<Box<binary_copy::BinaryCopyInWriter>>,
}
@ -26,7 +27,7 @@ impl<'a> BinaryCopyInWriter<'a> {
.expect("writer has already been written to");
BinaryCopyInWriter {
runtime: writer.runtime,
connection: writer.connection,
sink: Box::pin(binary_copy::BinaryCopyInWriter::new(stream, types)),
}
}
@ -37,7 +38,7 @@ impl<'a> BinaryCopyInWriter<'a> {
///
/// Panics if the number of values provided does not match the number expected.
pub fn write(&mut self, values: &[&(dyn ToSql + Sync)]) -> Result<(), Error> {
self.runtime.block_on(self.sink.as_mut().write(values))
self.connection.block_on(self.sink.as_mut().write(values))
}
/// A maximally-flexible version of `write`.
@ -50,20 +51,21 @@ impl<'a> BinaryCopyInWriter<'a> {
I: IntoIterator<Item = &'b dyn ToSql>,
I::IntoIter: ExactSizeIterator,
{
self.runtime.block_on(self.sink.as_mut().write_raw(values))
self.connection
.block_on(self.sink.as_mut().write_raw(values))
}
/// Completes the copy, returning the number of rows added.
///
/// This method *must* be used to complete the copy process. If it is not, the copy will be aborted.
pub fn finish(mut self) -> Result<u64, Error> {
self.runtime.block_on(self.sink.as_mut().finish())
self.connection.block_on(self.sink.as_mut().finish())
}
}
/// An iterator of rows deserialized from the PostgreSQL binary copy format.
pub struct BinaryCopyOutIter<'a> {
runtime: Rt<'a>,
connection: ConnectionRef<'a>,
stream: Pin<Box<BinaryCopyOutStream>>,
}
@ -76,7 +78,7 @@ impl<'a> BinaryCopyOutIter<'a> {
.expect("reader has already been read from");
BinaryCopyOutIter {
runtime: reader.runtime,
connection: reader.connection,
stream: Box::pin(BinaryCopyOutStream::new(stream, types)),
}
}
@ -87,6 +89,8 @@ impl FallibleIterator for BinaryCopyOutIter<'_> {
type Error = Error;
fn next(&mut self) -> Result<Option<BinaryCopyOutRow>, Error> {
self.runtime.block_on(self.stream.next()).transpose()
let stream = &mut self.stream;
self.connection
.block_on(async { stream.next().await.transpose() })
}
}

View File

@ -1,45 +1,21 @@
use crate::connection::Connection;
use crate::{
CancelToken, Config, CopyInWriter, CopyOutReader, RowIter, Statement, ToStatement, Transaction,
TransactionBuilder,
CancelToken, Config, CopyInWriter, CopyOutReader, Notifications, RowIter, Statement,
ToStatement, Transaction, TransactionBuilder,
};
use std::ops::{Deref, DerefMut};
use tokio::runtime::Runtime;
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
use tokio_postgres::types::{ToSql, Type};
use tokio_postgres::{Error, Row, SimpleQueryMessage, Socket};
pub(crate) struct Rt<'a>(pub &'a mut Runtime);
// no-op impl to extend the borrow until drop
impl Drop for Rt<'_> {
fn drop(&mut self) {}
}
impl Deref for Rt<'_> {
type Target = Runtime;
#[inline]
fn deref(&self) -> &Runtime {
self.0
}
}
impl DerefMut for Rt<'_> {
#[inline]
fn deref_mut(&mut self) -> &mut Runtime {
self.0
}
}
/// A synchronous PostgreSQL client.
pub struct Client {
runtime: Runtime,
connection: Connection,
client: tokio_postgres::Client,
}
impl Client {
pub(crate) fn new(runtime: Runtime, client: tokio_postgres::Client) -> Client {
Client { runtime, client }
pub(crate) fn new(connection: Connection, client: tokio_postgres::Client) -> Client {
Client { connection, client }
}
/// A convenience function which parses a configuration string into a `Config` and then connects to the database.
@ -62,10 +38,6 @@ impl Client {
Config::new()
}
fn rt(&mut self) -> Rt<'_> {
Rt(&mut self.runtime)
}
/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
@ -104,7 +76,7 @@ impl Client {
where
T: ?Sized + ToStatement,
{
self.runtime.block_on(self.client.execute(query, params))
self.connection.block_on(self.client.execute(query, params))
}
/// Executes a statement, returning the resulting rows.
@ -140,7 +112,7 @@ impl Client {
where
T: ?Sized + ToStatement,
{
self.runtime.block_on(self.client.query(query, params))
self.connection.block_on(self.client.query(query, params))
}
/// Executes a statement which returns a single row, returning it.
@ -177,7 +149,8 @@ impl Client {
where
T: ?Sized + ToStatement,
{
self.runtime.block_on(self.client.query_one(query, params))
self.connection
.block_on(self.client.query_one(query, params))
}
/// Executes a statement which returns zero or one rows, returning it.
@ -223,7 +196,8 @@ impl Client {
where
T: ?Sized + ToStatement,
{
self.runtime.block_on(self.client.query_opt(query, params))
self.connection
.block_on(self.client.query_opt(query, params))
}
/// A maximally-flexible version of `query`.
@ -289,9 +263,9 @@ impl Client {
I::IntoIter: ExactSizeIterator,
{
let stream = self
.runtime
.connection
.block_on(self.client.query_raw(query, params))?;
Ok(RowIter::new(self.rt(), stream))
Ok(RowIter::new(self.connection.as_ref(), stream))
}
/// Creates a new prepared statement.
@ -318,7 +292,7 @@ impl Client {
/// # }
/// ```
pub fn prepare(&mut self, query: &str) -> Result<Statement, Error> {
self.runtime.block_on(self.client.prepare(query))
self.connection.block_on(self.client.prepare(query))
}
/// Like `prepare`, but allows the types of query parameters to be explicitly specified.
@ -349,7 +323,7 @@ impl Client {
/// # }
/// ```
pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result<Statement, Error> {
self.runtime
self.connection
.block_on(self.client.prepare_typed(query, types))
}
@ -380,8 +354,8 @@ impl Client {
where
T: ?Sized + ToStatement,
{
let sink = self.runtime.block_on(self.client.copy_in(query))?;
Ok(CopyInWriter::new(self.rt(), sink))
let sink = self.connection.block_on(self.client.copy_in(query))?;
Ok(CopyInWriter::new(self.connection.as_ref(), sink))
}
/// Executes a `COPY TO STDOUT` statement, returning a reader of the resulting data.
@ -408,8 +382,8 @@ impl Client {
where
T: ?Sized + ToStatement,
{
let stream = self.runtime.block_on(self.client.copy_out(query))?;
Ok(CopyOutReader::new(self.rt(), stream))
let stream = self.connection.block_on(self.client.copy_out(query))?;
Ok(CopyOutReader::new(self.connection.as_ref(), stream))
}
/// Executes a sequence of SQL statements using the simple query protocol.
@ -428,7 +402,7 @@ impl Client {
/// functionality to safely imbed that data in the request. Do not form statements via string concatenation and pass
/// them to this method!
pub fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
self.runtime.block_on(self.client.simple_query(query))
self.connection.block_on(self.client.simple_query(query))
}
/// Executes a sequence of SQL statements using the simple query protocol.
@ -442,7 +416,7 @@ impl Client {
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method!
pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
self.runtime.block_on(self.client.batch_execute(query))
self.connection.block_on(self.client.batch_execute(query))
}
/// Begins a new database transaction.
@ -466,8 +440,8 @@ impl Client {
/// # }
/// ```
pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
let transaction = self.runtime.block_on(self.client.transaction())?;
Ok(Transaction::new(&mut self.runtime, transaction))
let transaction = self.connection.block_on(self.client.transaction())?;
Ok(Transaction::new(self.connection.as_ref(), transaction))
}
/// Returns a builder for a transaction with custom settings.
@ -494,7 +468,14 @@ impl Client {
/// # }
/// ```
pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
TransactionBuilder::new(&mut self.runtime, self.client.build_transaction())
TransactionBuilder::new(self.connection.as_ref(), self.client.build_transaction())
}
/// Returns a structure providing access to asynchronous notifications.
///
/// Use the `LISTEN` command to register this connection for notifications.
pub fn notifications(&mut self) -> Notifications<'_> {
Notifications::new(self.connection.as_ref())
}
/// Constructs a cancellation token that can later be used to request
@ -516,7 +497,7 @@ impl Client {
/// thread::spawn(move || {
/// // Abort the query after 5s.
/// thread::sleep(Duration::from_secs(5));
/// cancel_token.cancel_query(NoTls);
/// let _ = cancel_token.cancel_query(NoTls);
/// });
///
/// match client.simple_query("SELECT long_running_query()") {

View File

@ -2,9 +2,8 @@
//!
//! Requires the `runtime` Cargo feature (enabled by default).
use crate::connection::Connection;
use crate::Client;
use futures::FutureExt;
use log::error;
use std::fmt;
use std::path::Path;
use std::str::FromStr;
@ -324,15 +323,8 @@ impl Config {
let (client, connection) = runtime.block_on(self.config.connect(tls))?;
// FIXME don't spawn this so error reporting is less weird.
let connection = connection.map(|r| {
if let Err(e) = r {
error!("postgres connection error: {}", e)
}
});
runtime.spawn(connection);
Ok(Client::new(runtime, client))
let connection = Connection::new(runtime, connection);
Ok(Client::new(connection, client))
}
}

129
postgres/src/connection.rs Normal file
View File

@ -0,0 +1,129 @@
use crate::{Error, Notification};
use futures::future;
use futures::{pin_mut, Stream};
use log::info;
use std::collections::VecDeque;
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::runtime::Runtime;
use tokio_postgres::AsyncMessage;
pub struct Connection {
runtime: Runtime,
connection: Pin<Box<dyn Stream<Item = Result<AsyncMessage, Error>> + Send>>,
notifications: VecDeque<Notification>,
}
impl Connection {
pub fn new<S, T>(runtime: Runtime, connection: tokio_postgres::Connection<S, T>) -> Connection
where
S: AsyncRead + AsyncWrite + Unpin + 'static + Send,
T: AsyncRead + AsyncWrite + Unpin + 'static + Send,
{
Connection {
runtime,
connection: Box::pin(ConnectionStream { connection }),
notifications: VecDeque::new(),
}
}
pub fn as_ref(&mut self) -> ConnectionRef<'_> {
ConnectionRef { connection: self }
}
pub fn enter<F, T>(&self, f: F) -> T
where
F: FnOnce() -> T,
{
self.runtime.enter(f)
}
pub fn block_on<F, T>(&mut self, future: F) -> Result<T, Error>
where
F: Future<Output = Result<T, Error>>,
{
pin_mut!(future);
self.poll_block_on(|cx, _, _| future.as_mut().poll(cx))
}
pub fn poll_block_on<F, T>(&mut self, mut f: F) -> Result<T, Error>
where
F: FnMut(&mut Context<'_>, &mut VecDeque<Notification>, bool) -> Poll<Result<T, Error>>,
{
let connection = &mut self.connection;
let notifications = &mut self.notifications;
self.runtime.block_on({
future::poll_fn(|cx| {
let done = loop {
match connection.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(AsyncMessage::Notification(notification)))) => {
notifications.push_back(notification);
}
Poll::Ready(Some(Ok(AsyncMessage::Notice(notice)))) => {
info!("{}: {}", notice.severity(), notice.message());
}
Poll::Ready(Some(Ok(_))) => {}
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
Poll::Ready(None) => break true,
Poll::Pending => break false,
}
};
f(cx, notifications, done)
})
})
}
pub fn notifications(&self) -> &VecDeque<Notification> {
&self.notifications
}
pub fn notifications_mut(&mut self) -> &mut VecDeque<Notification> {
&mut self.notifications
}
}
pub struct ConnectionRef<'a> {
connection: &'a mut Connection,
}
// no-op impl to extend the borrow until drop
impl Drop for ConnectionRef<'_> {
#[inline]
fn drop(&mut self) {}
}
impl Deref for ConnectionRef<'_> {
type Target = Connection;
#[inline]
fn deref(&self) -> &Connection {
self.connection
}
}
impl DerefMut for ConnectionRef<'_> {
#[inline]
fn deref_mut(&mut self) -> &mut Connection {
self.connection
}
}
struct ConnectionStream<S, T> {
connection: tokio_postgres::Connection<S, T>,
}
impl<S, T> Stream for ConnectionStream<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
type Item = Result<AsyncMessage, Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.connection.poll_message(cx)
}
}

View File

@ -1,5 +1,5 @@
use crate::connection::ConnectionRef;
use crate::lazy_pin::LazyPin;
use crate::Rt;
use bytes::{Bytes, BytesMut};
use futures::SinkExt;
use std::io;
@ -10,15 +10,15 @@ use tokio_postgres::{CopyInSink, Error};
///
/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
pub struct CopyInWriter<'a> {
pub(crate) runtime: Rt<'a>,
pub(crate) connection: ConnectionRef<'a>,
pub(crate) sink: LazyPin<CopyInSink<Bytes>>,
buf: BytesMut,
}
impl<'a> CopyInWriter<'a> {
pub(crate) fn new(runtime: Rt<'a>, sink: CopyInSink<Bytes>) -> CopyInWriter<'a> {
pub(crate) fn new(connection: ConnectionRef<'a>, sink: CopyInSink<Bytes>) -> CopyInWriter<'a> {
CopyInWriter {
runtime,
connection,
sink: LazyPin::new(sink),
buf: BytesMut::new(),
}
@ -29,7 +29,7 @@ impl<'a> CopyInWriter<'a> {
/// If this is not called, the copy will be aborted.
pub fn finish(mut self) -> Result<u64, Error> {
self.flush_inner()?;
self.runtime.block_on(self.sink.pinned().finish())
self.connection.block_on(self.sink.pinned().finish())
}
fn flush_inner(&mut self) -> Result<(), Error> {
@ -37,7 +37,7 @@ impl<'a> CopyInWriter<'a> {
return Ok(());
}
self.runtime
self.connection
.block_on(self.sink.pinned().send(self.buf.split().freeze()))
}
}

View File

@ -1,5 +1,5 @@
use crate::connection::ConnectionRef;
use crate::lazy_pin::LazyPin;
use crate::Rt;
use bytes::{Buf, Bytes};
use futures::StreamExt;
use std::io::{self, BufRead, Read};
@ -7,15 +7,15 @@ use tokio_postgres::CopyOutStream;
/// The reader returned by the `copy_out` method.
pub struct CopyOutReader<'a> {
pub(crate) runtime: Rt<'a>,
pub(crate) connection: ConnectionRef<'a>,
pub(crate) stream: LazyPin<CopyOutStream>,
cur: Bytes,
}
impl<'a> CopyOutReader<'a> {
pub(crate) fn new(runtime: Rt<'a>, stream: CopyOutStream) -> CopyOutReader<'a> {
pub(crate) fn new(connection: ConnectionRef<'a>, stream: CopyOutStream) -> CopyOutReader<'a> {
CopyOutReader {
runtime,
connection,
stream: LazyPin::new(stream),
cur: Bytes::new(),
}
@ -35,10 +35,14 @@ impl Read for CopyOutReader<'_> {
impl BufRead for CopyOutReader<'_> {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
if !self.cur.has_remaining() {
match self.runtime.block_on(self.stream.pinned().next()) {
Some(Ok(cur)) => self.cur = cur,
Some(Err(e)) => return Err(io::Error::new(io::ErrorKind::Other, e)),
None => {}
let mut stream = self.stream.pinned();
match self
.connection
.block_on({ async { stream.next().await.transpose() } })
{
Ok(Some(cur)) => self.cur = cur,
Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)),
Ok(None) => {}
};
}

View File

@ -65,8 +65,8 @@
pub use fallible_iterator;
pub use tokio_postgres::{
error, row, tls, types, Column, IsolationLevel, Portal, SimpleQueryMessage, Socket, Statement,
ToStatement,
error, row, tls, types, Column, IsolationLevel, Notification, Portal, SimpleQueryMessage,
Socket, Statement, ToStatement,
};
pub use crate::cancel_token::CancelToken;
@ -77,6 +77,8 @@ pub use crate::copy_out_reader::CopyOutReader;
#[doc(no_inline)]
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
#[doc(inline)]
pub use crate::notifications::Notifications;
#[doc(no_inline)]
pub use crate::row::{Row, SimpleQueryRow};
pub use crate::row_iter::RowIter;
@ -89,10 +91,12 @@ pub mod binary_copy;
mod cancel_token;
mod client;
pub mod config;
mod connection;
mod copy_in_writer;
mod copy_out_reader;
mod generic_client;
mod lazy_pin;
pub mod notifications;
mod row_iter;
mod transaction;
mod transaction_builder;

View File

@ -0,0 +1,161 @@
//! Asynchronous notifications.
use crate::connection::ConnectionRef;
use crate::{Error, Notification};
use fallible_iterator::FallibleIterator;
use futures::{ready, FutureExt};
use std::task::Poll;
use std::time::Duration;
use tokio::time::{self, Delay, Instant};
/// Notifications from a PostgreSQL backend.
pub struct Notifications<'a> {
connection: ConnectionRef<'a>,
}
impl<'a> Notifications<'a> {
pub(crate) fn new(connection: ConnectionRef<'a>) -> Notifications<'a> {
Notifications { connection }
}
/// Returns the number of already buffered pending notifications.
pub fn len(&self) -> usize {
self.connection.notifications().len()
}
/// Determines if there are any already buffered pending notifications.
pub fn is_empty(&self) -> bool {
self.connection.notifications().is_empty()
}
/// Returns a nonblocking iterator over notifications.
///
/// If there are no already buffered pending notifications, this iterator will poll the connection but will not
/// block waiting on notifications over the network. A return value of `None` either indicates that there are no
/// pending notifications or that the server has disconnected.
///
/// # Note
///
/// This iterator may start returning `Some` after previously returning `None` if more notifications are received.
pub fn iter(&mut self) -> Iter<'_> {
Iter {
connection: self.connection.as_ref(),
}
}
/// Returns a blocking iterator over notifications.
///
/// If there are no already buffered pending notifications, this iterator will block indefinitely waiting on the
/// PostgreSQL backend server to send one. It will only return `None` if the server has disconnected.
pub fn blocking_iter(&mut self) -> BlockingIter<'_> {
BlockingIter {
connection: self.connection.as_ref(),
}
}
/// Returns an iterator over notifications which blocks a limited amount of time.
///
/// If there are no already buffered pending notifications, this iterator will block waiting on the PostgreSQL
/// backend server to send one up to the provided timeout. A return value of `None` either indicates that there are
/// no pending notifications or that the server has disconnected.
///
/// # Note
///
/// This iterator may start returning `Some` after previously returning `None` if more notifications are received.
pub fn timeout_iter(&mut self, timeout: Duration) -> TimeoutIter<'_> {
TimeoutIter {
delay: self.connection.enter(|| time::delay_for(timeout)),
timeout,
connection: self.connection.as_ref(),
}
}
}
/// A nonblocking iterator over pending notifications.
pub struct Iter<'a> {
connection: ConnectionRef<'a>,
}
impl<'a> FallibleIterator for Iter<'a> {
type Item = Notification;
type Error = Error;
fn next(&mut self) -> Result<Option<Self::Item>, Self::Error> {
if let Some(notification) = self.connection.notifications_mut().pop_front() {
return Ok(Some(notification));
}
self.connection
.poll_block_on(|_, notifications, _| Poll::Ready(Ok(notifications.pop_front())))
}
fn size_hint(&self) -> (usize, Option<usize>) {
(self.connection.notifications().len(), None)
}
}
/// A blocking iterator over pending notifications.
pub struct BlockingIter<'a> {
connection: ConnectionRef<'a>,
}
impl<'a> FallibleIterator for BlockingIter<'a> {
type Item = Notification;
type Error = Error;
fn next(&mut self) -> Result<Option<Self::Item>, Self::Error> {
if let Some(notification) = self.connection.notifications_mut().pop_front() {
return Ok(Some(notification));
}
self.connection
.poll_block_on(|_, notifications, done| match notifications.pop_front() {
Some(notification) => Poll::Ready(Ok(Some(notification))),
None if done => Poll::Ready(Ok(None)),
None => Poll::Pending,
})
}
fn size_hint(&self) -> (usize, Option<usize>) {
(self.connection.notifications().len(), None)
}
}
/// A time-limited blocking iterator over pending notifications.
pub struct TimeoutIter<'a> {
connection: ConnectionRef<'a>,
delay: Delay,
timeout: Duration,
}
impl<'a> FallibleIterator for TimeoutIter<'a> {
type Item = Notification;
type Error = Error;
fn next(&mut self) -> Result<Option<Self::Item>, Self::Error> {
if let Some(notification) = self.connection.notifications_mut().pop_front() {
self.delay.reset(Instant::now() + self.timeout);
return Ok(Some(notification));
}
let delay = &mut self.delay;
let timeout = self.timeout;
self.connection.poll_block_on(|cx, notifications, done| {
match notifications.pop_front() {
Some(notification) => {
delay.reset(Instant::now() + timeout);
return Poll::Ready(Ok(Some(notification)));
}
None if done => return Poll::Ready(Ok(None)),
None => {}
}
ready!(delay.poll_unpin(cx));
Poll::Ready(Ok(None))
})
}
fn size_hint(&self) -> (usize, Option<usize>) {
(self.connection.notifications().len(), None)
}
}

View File

@ -1,4 +1,4 @@
use crate::Rt;
use crate::connection::ConnectionRef;
use fallible_iterator::FallibleIterator;
use futures::StreamExt;
use std::pin::Pin;
@ -6,19 +6,14 @@ use tokio_postgres::{Error, Row, RowStream};
/// The iterator returned by `query_raw`.
pub struct RowIter<'a> {
runtime: Rt<'a>,
connection: ConnectionRef<'a>,
it: Pin<Box<RowStream>>,
}
// no-op impl to extend the borrow until drop
impl Drop for RowIter<'_> {
fn drop(&mut self) {}
}
impl<'a> RowIter<'a> {
pub(crate) fn new(runtime: Rt<'a>, stream: RowStream) -> RowIter<'a> {
pub(crate) fn new(connection: ConnectionRef<'a>, stream: RowStream) -> RowIter<'a> {
RowIter {
runtime,
connection,
it: Box::pin(stream),
}
}
@ -29,6 +24,8 @@ impl FallibleIterator for RowIter<'_> {
type Error = Error;
fn next(&mut self) -> Result<Option<Row>, Error> {
self.runtime.block_on(self.it.next()).transpose()
let it = &mut self.it;
self.connection
.block_on(async { it.next().await.transpose() })
}
}

View File

@ -309,3 +309,93 @@ fn cancel_query() {
cancel_thread.join().unwrap();
}
#[test]
fn notifications_iter() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
client
.batch_execute(
"\
LISTEN notifications_iter;
NOTIFY notifications_iter, 'hello';
NOTIFY notifications_iter, 'world';
",
)
.unwrap();
let notifications = client.notifications().iter().collect::<Vec<_>>().unwrap();
assert_eq!(notifications.len(), 2);
assert_eq!(notifications[0].payload(), "hello");
assert_eq!(notifications[1].payload(), "world");
}
#[test]
fn notifications_blocking_iter() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
client
.batch_execute(
"\
LISTEN notifications_blocking_iter;
NOTIFY notifications_blocking_iter, 'hello';
",
)
.unwrap();
thread::spawn(|| {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
thread::sleep(Duration::from_secs(1));
client
.batch_execute("NOTIFY notifications_blocking_iter, 'world'")
.unwrap();
});
let notifications = client
.notifications()
.blocking_iter()
.take(2)
.collect::<Vec<_>>()
.unwrap();
assert_eq!(notifications.len(), 2);
assert_eq!(notifications[0].payload(), "hello");
assert_eq!(notifications[1].payload(), "world");
}
#[test]
fn notifications_timeout_iter() {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
client
.batch_execute(
"\
LISTEN notifications_timeout_iter;
NOTIFY notifications_timeout_iter, 'hello';
",
)
.unwrap();
thread::spawn(|| {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
thread::sleep(Duration::from_secs(1));
client
.batch_execute("NOTIFY notifications_timeout_iter, 'world'")
.unwrap();
thread::sleep(Duration::from_secs(10));
client
.batch_execute("NOTIFY notifications_timeout_iter, '!'")
.unwrap();
});
let notifications = client
.notifications()
.timeout_iter(Duration::from_secs(2))
.collect::<Vec<_>>()
.unwrap();
assert_eq!(notifications.len(), 2);
assert_eq!(notifications[0].payload(), "hello");
assert_eq!(notifications[1].payload(), "world");
}

View File

@ -1,7 +1,5 @@
use crate::{
CancelToken, CopyInWriter, CopyOutReader, Portal, RowIter, Rt, Statement, ToStatement,
};
use tokio::runtime::Runtime;
use crate::connection::ConnectionRef;
use crate::{CancelToken, CopyInWriter, CopyOutReader, Portal, RowIter, Statement, ToStatement};
use tokio_postgres::types::{ToSql, Type};
use tokio_postgres::{Error, Row, SimpleQueryMessage};
@ -10,45 +8,41 @@ use tokio_postgres::{Error, Row, SimpleQueryMessage};
/// Transactions will implicitly roll back by default when dropped. Use the `commit` method to commit the changes made
/// in the transaction. Transactions can be nested, with inner transactions implemented via savepoints.
pub struct Transaction<'a> {
runtime: &'a mut Runtime,
connection: ConnectionRef<'a>,
transaction: tokio_postgres::Transaction<'a>,
}
impl<'a> Transaction<'a> {
pub(crate) fn new(
runtime: &'a mut Runtime,
connection: ConnectionRef<'a>,
transaction: tokio_postgres::Transaction<'a>,
) -> Transaction<'a> {
Transaction {
runtime,
connection,
transaction,
}
}
fn rt(&mut self) -> Rt<'_> {
Rt(self.runtime)
}
/// Consumes the transaction, committing all changes made within it.
pub fn commit(self) -> Result<(), Error> {
self.runtime.block_on(self.transaction.commit())
pub fn commit(mut self) -> Result<(), Error> {
self.connection.block_on(self.transaction.commit())
}
/// Rolls the transaction back, discarding all changes made within it.
///
/// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
pub fn rollback(self) -> Result<(), Error> {
self.runtime.block_on(self.transaction.rollback())
pub fn rollback(mut self) -> Result<(), Error> {
self.connection.block_on(self.transaction.rollback())
}
/// Like `Client::prepare`.
pub fn prepare(&mut self, query: &str) -> Result<Statement, Error> {
self.runtime.block_on(self.transaction.prepare(query))
self.connection.block_on(self.transaction.prepare(query))
}
/// Like `Client::prepare_typed`.
pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result<Statement, Error> {
self.runtime
self.connection
.block_on(self.transaction.prepare_typed(query, types))
}
@ -57,7 +51,7 @@ impl<'a> Transaction<'a> {
where
T: ?Sized + ToStatement,
{
self.runtime
self.connection
.block_on(self.transaction.execute(query, params))
}
@ -66,7 +60,8 @@ impl<'a> Transaction<'a> {
where
T: ?Sized + ToStatement,
{
self.runtime.block_on(self.transaction.query(query, params))
self.connection
.block_on(self.transaction.query(query, params))
}
/// Like `Client::query_one`.
@ -74,7 +69,7 @@ impl<'a> Transaction<'a> {
where
T: ?Sized + ToStatement,
{
self.runtime
self.connection
.block_on(self.transaction.query_one(query, params))
}
@ -87,7 +82,7 @@ impl<'a> Transaction<'a> {
where
T: ?Sized + ToStatement,
{
self.runtime
self.connection
.block_on(self.transaction.query_opt(query, params))
}
@ -99,9 +94,9 @@ impl<'a> Transaction<'a> {
I::IntoIter: ExactSizeIterator,
{
let stream = self
.runtime
.connection
.block_on(self.transaction.query_raw(query, params))?;
Ok(RowIter::new(self.rt(), stream))
Ok(RowIter::new(self.connection.as_ref(), stream))
}
/// Binds parameters to a statement, creating a "portal".
@ -118,7 +113,8 @@ impl<'a> Transaction<'a> {
where
T: ?Sized + ToStatement,
{
self.runtime.block_on(self.transaction.bind(query, params))
self.connection
.block_on(self.transaction.bind(query, params))
}
/// Continues execution of a portal, returning the next set of rows.
@ -126,7 +122,7 @@ impl<'a> Transaction<'a> {
/// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
/// `query_portal`. If the requested number is negative or 0, all remaining rows will be returned.
pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
self.runtime
self.connection
.block_on(self.transaction.query_portal(portal, max_rows))
}
@ -137,9 +133,9 @@ impl<'a> Transaction<'a> {
max_rows: i32,
) -> Result<RowIter<'_>, Error> {
let stream = self
.runtime
.connection
.block_on(self.transaction.query_portal_raw(portal, max_rows))?;
Ok(RowIter::new(self.rt(), stream))
Ok(RowIter::new(self.connection.as_ref(), stream))
}
/// Like `Client::copy_in`.
@ -147,8 +143,8 @@ impl<'a> Transaction<'a> {
where
T: ?Sized + ToStatement,
{
let sink = self.runtime.block_on(self.transaction.copy_in(query))?;
Ok(CopyInWriter::new(self.rt(), sink))
let sink = self.connection.block_on(self.transaction.copy_in(query))?;
Ok(CopyInWriter::new(self.connection.as_ref(), sink))
}
/// Like `Client::copy_out`.
@ -156,18 +152,20 @@ impl<'a> Transaction<'a> {
where
T: ?Sized + ToStatement,
{
let stream = self.runtime.block_on(self.transaction.copy_out(query))?;
Ok(CopyOutReader::new(self.rt(), stream))
let stream = self.connection.block_on(self.transaction.copy_out(query))?;
Ok(CopyOutReader::new(self.connection.as_ref(), stream))
}
/// Like `Client::simple_query`.
pub fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
self.runtime.block_on(self.transaction.simple_query(query))
self.connection
.block_on(self.transaction.simple_query(query))
}
/// Like `Client::batch_execute`.
pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
self.runtime.block_on(self.transaction.batch_execute(query))
self.connection
.block_on(self.transaction.batch_execute(query))
}
/// Like `Client::cancel_token`.
@ -177,9 +175,9 @@ impl<'a> Transaction<'a> {
/// Like `Client::transaction`.
pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
let transaction = self.runtime.block_on(self.transaction.transaction())?;
let transaction = self.connection.block_on(self.transaction.transaction())?;
Ok(Transaction {
runtime: self.runtime,
connection: self.connection.as_ref(),
transaction,
})
}

View File

@ -1,18 +1,21 @@
use crate::connection::ConnectionRef;
use crate::{Error, IsolationLevel, Transaction};
use tokio::runtime::Runtime;
/// A builder for database transactions.
pub struct TransactionBuilder<'a> {
runtime: &'a mut Runtime,
connection: ConnectionRef<'a>,
builder: tokio_postgres::TransactionBuilder<'a>,
}
impl<'a> TransactionBuilder<'a> {
pub(crate) fn new(
runtime: &'a mut Runtime,
connection: ConnectionRef<'a>,
builder: tokio_postgres::TransactionBuilder<'a>,
) -> TransactionBuilder<'a> {
TransactionBuilder { runtime, builder }
TransactionBuilder {
connection,
builder,
}
}
/// Sets the isolation level of the transaction.
@ -40,8 +43,8 @@ impl<'a> TransactionBuilder<'a> {
/// Begins the transaction.
///
/// The transaction will roll back by default - use the `commit` method to commit it.
pub fn start(self) -> Result<Transaction<'a>, Error> {
let transaction = self.runtime.block_on(self.builder.start())?;
Ok(Transaction::new(self.runtime, transaction))
pub fn start(mut self) -> Result<Transaction<'a>, Error> {
let transaction = self.connection.block_on(self.builder.start())?;
Ok(Transaction::new(self.connection, transaction))
}
}

View File

@ -559,11 +559,9 @@ async fn copy_out() {
.copy_out(&stmt)
.await
.unwrap()
.try_fold(BytesMut::new(), |mut buf, chunk| {
async move {
buf.extend_from_slice(&chunk);
Ok(buf)
}
.try_fold(BytesMut::new(), |mut buf, chunk| async move {
buf.extend_from_slice(&chunk);
Ok(buf)
})
.await
.unwrap();