diff --git a/postgres/src/binary_copy.rs b/postgres/src/binary_copy.rs index 7828cb59..25934719 100644 --- a/postgres/src/binary_copy.rs +++ b/postgres/src/binary_copy.rs @@ -1,7 +1,8 @@ //! Utilities for working with the PostgreSQL binary copy format. +use crate::connection::ConnectionRef; use crate::types::{ToSql, Type}; -use crate::{CopyInWriter, CopyOutReader, Error, Rt}; +use crate::{CopyInWriter, CopyOutReader, Error}; use fallible_iterator::FallibleIterator; use futures::StreamExt; use std::pin::Pin; @@ -13,7 +14,7 @@ use tokio_postgres::binary_copy::{self, BinaryCopyOutStream}; /// /// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted. pub struct BinaryCopyInWriter<'a> { - runtime: Rt<'a>, + connection: ConnectionRef<'a>, sink: Pin>, } @@ -26,7 +27,7 @@ impl<'a> BinaryCopyInWriter<'a> { .expect("writer has already been written to"); BinaryCopyInWriter { - runtime: writer.runtime, + connection: writer.connection, sink: Box::pin(binary_copy::BinaryCopyInWriter::new(stream, types)), } } @@ -37,7 +38,7 @@ impl<'a> BinaryCopyInWriter<'a> { /// /// Panics if the number of values provided does not match the number expected. pub fn write(&mut self, values: &[&(dyn ToSql + Sync)]) -> Result<(), Error> { - self.runtime.block_on(self.sink.as_mut().write(values)) + self.connection.block_on(self.sink.as_mut().write(values)) } /// A maximally-flexible version of `write`. @@ -50,20 +51,21 @@ impl<'a> BinaryCopyInWriter<'a> { I: IntoIterator, I::IntoIter: ExactSizeIterator, { - self.runtime.block_on(self.sink.as_mut().write_raw(values)) + self.connection + .block_on(self.sink.as_mut().write_raw(values)) } /// Completes the copy, returning the number of rows added. /// /// This method *must* be used to complete the copy process. If it is not, the copy will be aborted. pub fn finish(mut self) -> Result { - self.runtime.block_on(self.sink.as_mut().finish()) + self.connection.block_on(self.sink.as_mut().finish()) } } /// An iterator of rows deserialized from the PostgreSQL binary copy format. pub struct BinaryCopyOutIter<'a> { - runtime: Rt<'a>, + connection: ConnectionRef<'a>, stream: Pin>, } @@ -76,7 +78,7 @@ impl<'a> BinaryCopyOutIter<'a> { .expect("reader has already been read from"); BinaryCopyOutIter { - runtime: reader.runtime, + connection: reader.connection, stream: Box::pin(BinaryCopyOutStream::new(stream, types)), } } @@ -87,6 +89,8 @@ impl FallibleIterator for BinaryCopyOutIter<'_> { type Error = Error; fn next(&mut self) -> Result, Error> { - self.runtime.block_on(self.stream.next()).transpose() + let stream = &mut self.stream; + self.connection + .block_on(async { stream.next().await.transpose() }) } } diff --git a/postgres/src/client.rs b/postgres/src/client.rs index 0a3a51e1..3ae5f86c 100644 --- a/postgres/src/client.rs +++ b/postgres/src/client.rs @@ -1,45 +1,21 @@ +use crate::connection::Connection; use crate::{ CancelToken, Config, CopyInWriter, CopyOutReader, RowIter, Statement, ToStatement, Transaction, TransactionBuilder, }; -use std::ops::{Deref, DerefMut}; -use tokio::runtime::Runtime; use tokio_postgres::tls::{MakeTlsConnect, TlsConnect}; use tokio_postgres::types::{ToSql, Type}; use tokio_postgres::{Error, Row, SimpleQueryMessage, Socket}; -pub(crate) struct Rt<'a>(pub &'a mut Runtime); - -// no-op impl to extend the borrow until drop -impl Drop for Rt<'_> { - fn drop(&mut self) {} -} - -impl Deref for Rt<'_> { - type Target = Runtime; - - #[inline] - fn deref(&self) -> &Runtime { - self.0 - } -} - -impl DerefMut for Rt<'_> { - #[inline] - fn deref_mut(&mut self) -> &mut Runtime { - self.0 - } -} - /// A synchronous PostgreSQL client. pub struct Client { - runtime: Runtime, + connection: Connection, client: tokio_postgres::Client, } impl Client { - pub(crate) fn new(runtime: Runtime, client: tokio_postgres::Client) -> Client { - Client { runtime, client } + pub(crate) fn new(connection: Connection, client: tokio_postgres::Client) -> Client { + Client { connection, client } } /// A convenience function which parses a configuration string into a `Config` and then connects to the database. @@ -62,10 +38,6 @@ impl Client { Config::new() } - fn rt(&mut self) -> Rt<'_> { - Rt(&mut self.runtime) - } - /// Executes a statement, returning the number of rows modified. /// /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list @@ -104,7 +76,7 @@ impl Client { where T: ?Sized + ToStatement, { - self.runtime.block_on(self.client.execute(query, params)) + self.connection.block_on(self.client.execute(query, params)) } /// Executes a statement, returning the resulting rows. @@ -140,7 +112,7 @@ impl Client { where T: ?Sized + ToStatement, { - self.runtime.block_on(self.client.query(query, params)) + self.connection.block_on(self.client.query(query, params)) } /// Executes a statement which returns a single row, returning it. @@ -177,7 +149,8 @@ impl Client { where T: ?Sized + ToStatement, { - self.runtime.block_on(self.client.query_one(query, params)) + self.connection + .block_on(self.client.query_one(query, params)) } /// Executes a statement which returns zero or one rows, returning it. @@ -223,7 +196,8 @@ impl Client { where T: ?Sized + ToStatement, { - self.runtime.block_on(self.client.query_opt(query, params)) + self.connection + .block_on(self.client.query_opt(query, params)) } /// A maximally-flexible version of `query`. @@ -289,9 +263,9 @@ impl Client { I::IntoIter: ExactSizeIterator, { let stream = self - .runtime + .connection .block_on(self.client.query_raw(query, params))?; - Ok(RowIter::new(self.rt(), stream)) + Ok(RowIter::new(self.connection.as_ref(), stream)) } /// Creates a new prepared statement. @@ -318,7 +292,7 @@ impl Client { /// # } /// ``` pub fn prepare(&mut self, query: &str) -> Result { - self.runtime.block_on(self.client.prepare(query)) + self.connection.block_on(self.client.prepare(query)) } /// Like `prepare`, but allows the types of query parameters to be explicitly specified. @@ -349,7 +323,7 @@ impl Client { /// # } /// ``` pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result { - self.runtime + self.connection .block_on(self.client.prepare_typed(query, types)) } @@ -380,8 +354,8 @@ impl Client { where T: ?Sized + ToStatement, { - let sink = self.runtime.block_on(self.client.copy_in(query))?; - Ok(CopyInWriter::new(self.rt(), sink)) + let sink = self.connection.block_on(self.client.copy_in(query))?; + Ok(CopyInWriter::new(self.connection.as_ref(), sink)) } /// Executes a `COPY TO STDOUT` statement, returning a reader of the resulting data. @@ -408,8 +382,8 @@ impl Client { where T: ?Sized + ToStatement, { - let stream = self.runtime.block_on(self.client.copy_out(query))?; - Ok(CopyOutReader::new(self.rt(), stream)) + let stream = self.connection.block_on(self.client.copy_out(query))?; + Ok(CopyOutReader::new(self.connection.as_ref(), stream)) } /// Executes a sequence of SQL statements using the simple query protocol. @@ -428,7 +402,7 @@ impl Client { /// functionality to safely imbed that data in the request. Do not form statements via string concatenation and pass /// them to this method! pub fn simple_query(&mut self, query: &str) -> Result, Error> { - self.runtime.block_on(self.client.simple_query(query)) + self.connection.block_on(self.client.simple_query(query)) } /// Executes a sequence of SQL statements using the simple query protocol. @@ -442,7 +416,7 @@ impl Client { /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass /// them to this method! pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> { - self.runtime.block_on(self.client.batch_execute(query)) + self.connection.block_on(self.client.batch_execute(query)) } /// Begins a new database transaction. @@ -466,8 +440,8 @@ impl Client { /// # } /// ``` pub fn transaction(&mut self) -> Result, Error> { - let transaction = self.runtime.block_on(self.client.transaction())?; - Ok(Transaction::new(&mut self.runtime, transaction)) + let transaction = self.connection.block_on(self.client.transaction())?; + Ok(Transaction::new(self.connection.as_ref(), transaction)) } /// Returns a builder for a transaction with custom settings. @@ -494,7 +468,7 @@ impl Client { /// # } /// ``` pub fn build_transaction(&mut self) -> TransactionBuilder<'_> { - TransactionBuilder::new(&mut self.runtime, self.client.build_transaction()) + TransactionBuilder::new(self.connection.as_ref(), self.client.build_transaction()) } /// Constructs a cancellation token that can later be used to request diff --git a/postgres/src/config.rs b/postgres/src/config.rs index f6b151a8..b344efdd 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -2,9 +2,8 @@ //! //! Requires the `runtime` Cargo feature (enabled by default). +use crate::connection::Connection; use crate::Client; -use futures::FutureExt; -use log::error; use std::fmt; use std::path::Path; use std::str::FromStr; @@ -324,15 +323,8 @@ impl Config { let (client, connection) = runtime.block_on(self.config.connect(tls))?; - // FIXME don't spawn this so error reporting is less weird. - let connection = connection.map(|r| { - if let Err(e) = r { - error!("postgres connection error: {}", e) - } - }); - runtime.spawn(connection); - - Ok(Client::new(runtime, client)) + let connection = Connection::new(runtime, connection); + Ok(Client::new(connection, client)) } } diff --git a/postgres/src/connection.rs b/postgres/src/connection.rs new file mode 100644 index 00000000..440ad5da --- /dev/null +++ b/postgres/src/connection.rs @@ -0,0 +1,106 @@ +use crate::{Error, Notification}; +use futures::future; +use futures::{pin_mut, Stream}; +use log::info; +use std::collections::VecDeque; +use std::future::Future; +use std::ops::{Deref, DerefMut}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::runtime::Runtime; +use tokio_postgres::AsyncMessage; + +pub struct Connection { + runtime: Runtime, + connection: Pin> + Send>>, + notifications: VecDeque, +} + +impl Connection { + pub fn new(runtime: Runtime, connection: tokio_postgres::Connection) -> Connection + where + S: AsyncRead + AsyncWrite + Unpin + 'static + Send, + T: AsyncRead + AsyncWrite + Unpin + 'static + Send, + { + Connection { + runtime, + connection: Box::pin(ConnectionStream { connection }), + notifications: VecDeque::new(), + } + } + + pub fn as_ref(&mut self) -> ConnectionRef<'_> { + ConnectionRef { connection: self } + } + + pub fn block_on(&mut self, future: F) -> Result + where + F: Future>, + { + pin_mut!(future); + let connection = &mut self.connection; + let notifications = &mut self.notifications; + self.runtime.block_on({ + future::poll_fn(|cx| { + loop { + match connection.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(AsyncMessage::Notification(notification)))) => { + notifications.push_back(notification); + } + Poll::Ready(Some(Ok(AsyncMessage::Notice(notice)))) => { + info!("{}: {}", notice.severity(), notice.message()); + } + Poll::Ready(Some(Ok(_))) => {} + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), + Poll::Ready(None) | Poll::Pending => break, + } + } + + future.as_mut().poll(cx) + }) + }) + } +} + +pub struct ConnectionRef<'a> { + connection: &'a mut Connection, +} + +// no-op impl to extend the borrow until drop +impl Drop for ConnectionRef<'_> { + #[inline] + fn drop(&mut self) {} +} + +impl Deref for ConnectionRef<'_> { + type Target = Connection; + + #[inline] + fn deref(&self) -> &Connection { + self.connection + } +} + +impl DerefMut for ConnectionRef<'_> { + #[inline] + fn deref_mut(&mut self) -> &mut Connection { + self.connection + } +} + +struct ConnectionStream { + connection: tokio_postgres::Connection, +} + +impl Stream for ConnectionStream +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.connection.poll_message(cx) + } +} diff --git a/postgres/src/copy_in_writer.rs b/postgres/src/copy_in_writer.rs index fc11818a..c996ed85 100644 --- a/postgres/src/copy_in_writer.rs +++ b/postgres/src/copy_in_writer.rs @@ -1,5 +1,5 @@ +use crate::connection::ConnectionRef; use crate::lazy_pin::LazyPin; -use crate::Rt; use bytes::{Bytes, BytesMut}; use futures::SinkExt; use std::io; @@ -10,15 +10,15 @@ use tokio_postgres::{CopyInSink, Error}; /// /// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted. pub struct CopyInWriter<'a> { - pub(crate) runtime: Rt<'a>, + pub(crate) connection: ConnectionRef<'a>, pub(crate) sink: LazyPin>, buf: BytesMut, } impl<'a> CopyInWriter<'a> { - pub(crate) fn new(runtime: Rt<'a>, sink: CopyInSink) -> CopyInWriter<'a> { + pub(crate) fn new(connection: ConnectionRef<'a>, sink: CopyInSink) -> CopyInWriter<'a> { CopyInWriter { - runtime, + connection, sink: LazyPin::new(sink), buf: BytesMut::new(), } @@ -29,7 +29,7 @@ impl<'a> CopyInWriter<'a> { /// If this is not called, the copy will be aborted. pub fn finish(mut self) -> Result { self.flush_inner()?; - self.runtime.block_on(self.sink.pinned().finish()) + self.connection.block_on(self.sink.pinned().finish()) } fn flush_inner(&mut self) -> Result<(), Error> { @@ -37,7 +37,7 @@ impl<'a> CopyInWriter<'a> { return Ok(()); } - self.runtime + self.connection .block_on(self.sink.pinned().send(self.buf.split().freeze())) } } diff --git a/postgres/src/copy_out_reader.rs b/postgres/src/copy_out_reader.rs index 9091e220..a205d1a1 100644 --- a/postgres/src/copy_out_reader.rs +++ b/postgres/src/copy_out_reader.rs @@ -1,5 +1,5 @@ +use crate::connection::ConnectionRef; use crate::lazy_pin::LazyPin; -use crate::Rt; use bytes::{Buf, Bytes}; use futures::StreamExt; use std::io::{self, BufRead, Read}; @@ -7,15 +7,15 @@ use tokio_postgres::CopyOutStream; /// The reader returned by the `copy_out` method. pub struct CopyOutReader<'a> { - pub(crate) runtime: Rt<'a>, + pub(crate) connection: ConnectionRef<'a>, pub(crate) stream: LazyPin, cur: Bytes, } impl<'a> CopyOutReader<'a> { - pub(crate) fn new(runtime: Rt<'a>, stream: CopyOutStream) -> CopyOutReader<'a> { + pub(crate) fn new(connection: ConnectionRef<'a>, stream: CopyOutStream) -> CopyOutReader<'a> { CopyOutReader { - runtime, + connection, stream: LazyPin::new(stream), cur: Bytes::new(), } @@ -35,10 +35,14 @@ impl Read for CopyOutReader<'_> { impl BufRead for CopyOutReader<'_> { fn fill_buf(&mut self) -> io::Result<&[u8]> { if !self.cur.has_remaining() { - match self.runtime.block_on(self.stream.pinned().next()) { - Some(Ok(cur)) => self.cur = cur, - Some(Err(e)) => return Err(io::Error::new(io::ErrorKind::Other, e)), - None => {} + let mut stream = self.stream.pinned(); + match self + .connection + .block_on({ async { stream.next().await.transpose() } }) + { + Ok(Some(cur)) => self.cur = cur, + Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)), + Ok(None) => {} }; } diff --git a/postgres/src/lib.rs b/postgres/src/lib.rs index 2b2dcec3..78b318b1 100644 --- a/postgres/src/lib.rs +++ b/postgres/src/lib.rs @@ -65,8 +65,8 @@ pub use fallible_iterator; pub use tokio_postgres::{ - error, row, tls, types, Column, IsolationLevel, Portal, SimpleQueryMessage, Socket, Statement, - ToStatement, + error, row, tls, types, Column, IsolationLevel, Notification, Portal, SimpleQueryMessage, + Socket, Statement, ToStatement, }; pub use crate::cancel_token::CancelToken; @@ -89,6 +89,7 @@ pub mod binary_copy; mod cancel_token; mod client; pub mod config; +mod connection; mod copy_in_writer; mod copy_out_reader; mod generic_client; diff --git a/postgres/src/row_iter.rs b/postgres/src/row_iter.rs index 4be5f347..3cd41b90 100644 --- a/postgres/src/row_iter.rs +++ b/postgres/src/row_iter.rs @@ -1,4 +1,4 @@ -use crate::Rt; +use crate::connection::ConnectionRef; use fallible_iterator::FallibleIterator; use futures::StreamExt; use std::pin::Pin; @@ -6,19 +6,14 @@ use tokio_postgres::{Error, Row, RowStream}; /// The iterator returned by `query_raw`. pub struct RowIter<'a> { - runtime: Rt<'a>, + connection: ConnectionRef<'a>, it: Pin>, } -// no-op impl to extend the borrow until drop -impl Drop for RowIter<'_> { - fn drop(&mut self) {} -} - impl<'a> RowIter<'a> { - pub(crate) fn new(runtime: Rt<'a>, stream: RowStream) -> RowIter<'a> { + pub(crate) fn new(connection: ConnectionRef<'a>, stream: RowStream) -> RowIter<'a> { RowIter { - runtime, + connection, it: Box::pin(stream), } } @@ -29,6 +24,8 @@ impl FallibleIterator for RowIter<'_> { type Error = Error; fn next(&mut self) -> Result, Error> { - self.runtime.block_on(self.it.next()).transpose() + let it = &mut self.it; + self.connection + .block_on(async { it.next().await.transpose() }) } } diff --git a/postgres/src/transaction.rs b/postgres/src/transaction.rs index e5b3682f..25bfff57 100644 --- a/postgres/src/transaction.rs +++ b/postgres/src/transaction.rs @@ -1,7 +1,5 @@ -use crate::{ - CancelToken, CopyInWriter, CopyOutReader, Portal, RowIter, Rt, Statement, ToStatement, -}; -use tokio::runtime::Runtime; +use crate::connection::ConnectionRef; +use crate::{CancelToken, CopyInWriter, CopyOutReader, Portal, RowIter, Statement, ToStatement}; use tokio_postgres::types::{ToSql, Type}; use tokio_postgres::{Error, Row, SimpleQueryMessage}; @@ -10,45 +8,41 @@ use tokio_postgres::{Error, Row, SimpleQueryMessage}; /// Transactions will implicitly roll back by default when dropped. Use the `commit` method to commit the changes made /// in the transaction. Transactions can be nested, with inner transactions implemented via savepoints. pub struct Transaction<'a> { - runtime: &'a mut Runtime, + connection: ConnectionRef<'a>, transaction: tokio_postgres::Transaction<'a>, } impl<'a> Transaction<'a> { pub(crate) fn new( - runtime: &'a mut Runtime, + connection: ConnectionRef<'a>, transaction: tokio_postgres::Transaction<'a>, ) -> Transaction<'a> { Transaction { - runtime, + connection, transaction, } } - fn rt(&mut self) -> Rt<'_> { - Rt(self.runtime) - } - /// Consumes the transaction, committing all changes made within it. - pub fn commit(self) -> Result<(), Error> { - self.runtime.block_on(self.transaction.commit()) + pub fn commit(mut self) -> Result<(), Error> { + self.connection.block_on(self.transaction.commit()) } /// Rolls the transaction back, discarding all changes made within it. /// /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller. - pub fn rollback(self) -> Result<(), Error> { - self.runtime.block_on(self.transaction.rollback()) + pub fn rollback(mut self) -> Result<(), Error> { + self.connection.block_on(self.transaction.rollback()) } /// Like `Client::prepare`. pub fn prepare(&mut self, query: &str) -> Result { - self.runtime.block_on(self.transaction.prepare(query)) + self.connection.block_on(self.transaction.prepare(query)) } /// Like `Client::prepare_typed`. pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result { - self.runtime + self.connection .block_on(self.transaction.prepare_typed(query, types)) } @@ -57,7 +51,7 @@ impl<'a> Transaction<'a> { where T: ?Sized + ToStatement, { - self.runtime + self.connection .block_on(self.transaction.execute(query, params)) } @@ -66,7 +60,8 @@ impl<'a> Transaction<'a> { where T: ?Sized + ToStatement, { - self.runtime.block_on(self.transaction.query(query, params)) + self.connection + .block_on(self.transaction.query(query, params)) } /// Like `Client::query_one`. @@ -74,7 +69,7 @@ impl<'a> Transaction<'a> { where T: ?Sized + ToStatement, { - self.runtime + self.connection .block_on(self.transaction.query_one(query, params)) } @@ -87,7 +82,7 @@ impl<'a> Transaction<'a> { where T: ?Sized + ToStatement, { - self.runtime + self.connection .block_on(self.transaction.query_opt(query, params)) } @@ -99,9 +94,9 @@ impl<'a> Transaction<'a> { I::IntoIter: ExactSizeIterator, { let stream = self - .runtime + .connection .block_on(self.transaction.query_raw(query, params))?; - Ok(RowIter::new(self.rt(), stream)) + Ok(RowIter::new(self.connection.as_ref(), stream)) } /// Binds parameters to a statement, creating a "portal". @@ -118,7 +113,8 @@ impl<'a> Transaction<'a> { where T: ?Sized + ToStatement, { - self.runtime.block_on(self.transaction.bind(query, params)) + self.connection + .block_on(self.transaction.bind(query, params)) } /// Continues execution of a portal, returning the next set of rows. @@ -126,7 +122,7 @@ impl<'a> Transaction<'a> { /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to /// `query_portal`. If the requested number is negative or 0, all remaining rows will be returned. pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result, Error> { - self.runtime + self.connection .block_on(self.transaction.query_portal(portal, max_rows)) } @@ -137,9 +133,9 @@ impl<'a> Transaction<'a> { max_rows: i32, ) -> Result, Error> { let stream = self - .runtime + .connection .block_on(self.transaction.query_portal_raw(portal, max_rows))?; - Ok(RowIter::new(self.rt(), stream)) + Ok(RowIter::new(self.connection.as_ref(), stream)) } /// Like `Client::copy_in`. @@ -147,8 +143,8 @@ impl<'a> Transaction<'a> { where T: ?Sized + ToStatement, { - let sink = self.runtime.block_on(self.transaction.copy_in(query))?; - Ok(CopyInWriter::new(self.rt(), sink)) + let sink = self.connection.block_on(self.transaction.copy_in(query))?; + Ok(CopyInWriter::new(self.connection.as_ref(), sink)) } /// Like `Client::copy_out`. @@ -156,18 +152,20 @@ impl<'a> Transaction<'a> { where T: ?Sized + ToStatement, { - let stream = self.runtime.block_on(self.transaction.copy_out(query))?; - Ok(CopyOutReader::new(self.rt(), stream)) + let stream = self.connection.block_on(self.transaction.copy_out(query))?; + Ok(CopyOutReader::new(self.connection.as_ref(), stream)) } /// Like `Client::simple_query`. pub fn simple_query(&mut self, query: &str) -> Result, Error> { - self.runtime.block_on(self.transaction.simple_query(query)) + self.connection + .block_on(self.transaction.simple_query(query)) } /// Like `Client::batch_execute`. pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> { - self.runtime.block_on(self.transaction.batch_execute(query)) + self.connection + .block_on(self.transaction.batch_execute(query)) } /// Like `Client::cancel_token`. @@ -177,9 +175,9 @@ impl<'a> Transaction<'a> { /// Like `Client::transaction`. pub fn transaction(&mut self) -> Result, Error> { - let transaction = self.runtime.block_on(self.transaction.transaction())?; + let transaction = self.connection.block_on(self.transaction.transaction())?; Ok(Transaction { - runtime: self.runtime, + connection: self.connection.as_ref(), transaction, }) } diff --git a/postgres/src/transaction_builder.rs b/postgres/src/transaction_builder.rs index d87d1a12..e0f8a56e 100644 --- a/postgres/src/transaction_builder.rs +++ b/postgres/src/transaction_builder.rs @@ -1,18 +1,21 @@ +use crate::connection::ConnectionRef; use crate::{Error, IsolationLevel, Transaction}; -use tokio::runtime::Runtime; /// A builder for database transactions. pub struct TransactionBuilder<'a> { - runtime: &'a mut Runtime, + connection: ConnectionRef<'a>, builder: tokio_postgres::TransactionBuilder<'a>, } impl<'a> TransactionBuilder<'a> { pub(crate) fn new( - runtime: &'a mut Runtime, + connection: ConnectionRef<'a>, builder: tokio_postgres::TransactionBuilder<'a>, ) -> TransactionBuilder<'a> { - TransactionBuilder { runtime, builder } + TransactionBuilder { + connection, + builder, + } } /// Sets the isolation level of the transaction. @@ -40,8 +43,8 @@ impl<'a> TransactionBuilder<'a> { /// Begins the transaction. /// /// The transaction will roll back by default - use the `commit` method to commit it. - pub fn start(self) -> Result, Error> { - let transaction = self.runtime.block_on(self.builder.start())?; - Ok(Transaction::new(self.runtime, transaction)) + pub fn start(mut self) -> Result, Error> { + let transaction = self.connection.block_on(self.builder.start())?; + Ok(Transaction::new(self.connection, transaction)) } } diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 73860115..b01037ed 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -559,11 +559,9 @@ async fn copy_out() { .copy_out(&stmt) .await .unwrap() - .try_fold(BytesMut::new(), |mut buf, chunk| { - async move { - buf.extend_from_slice(&chunk); - Ok(buf) - } + .try_fold(BytesMut::new(), |mut buf, chunk| async move { + buf.extend_from_slice(&chunk); + Ok(buf) }) .await .unwrap();