Move to local runtimes per connection

This avoids a bunch of context switches and cross-thread
synchronization, which ends up improving the performance of a simple
query by ~20%, from 252us to 216us.
This commit is contained in:
Steven Fackler 2019-12-03 18:25:29 -08:00
parent d6163c088f
commit 09a63d6255
9 changed files with 151 additions and 179 deletions

View File

@ -10,6 +10,10 @@ readme = "../README.md"
keywords = ["database", "postgres", "postgresql", "sql"] keywords = ["database", "postgres", "postgresql", "sql"]
categories = ["database"] categories = ["database"]
[[bench]]
name = "bench"
harness = false
[package.metadata.docs.rs] [package.metadata.docs.rs]
all-features = true all-features = true
@ -17,9 +21,6 @@ all-features = true
circle-ci = { repository = "sfackler/rust-postgres" } circle-ci = { repository = "sfackler/rust-postgres" }
[features] [features]
default = ["runtime"]
runtime = ["tokio-postgres/runtime", "tokio", "lazy_static", "log"]
with-bit-vec-0_6 = ["tokio-postgres/with-bit-vec-0_6"] with-bit-vec-0_6 = ["tokio-postgres/with-bit-vec-0_6"]
with-chrono-0_4 = ["tokio-postgres/with-chrono-0_4"] with-chrono-0_4 = ["tokio-postgres/with-chrono-0_4"]
with-eui48-0_4 = ["tokio-postgres/with-eui48-0_4"] with-eui48-0_4 = ["tokio-postgres/with-eui48-0_4"]
@ -32,11 +33,11 @@ with-uuid-0_8 = ["tokio-postgres/with-uuid-0_8"]
bytes = "0.5" bytes = "0.5"
fallible-iterator = "0.2" fallible-iterator = "0.2"
futures = "0.3" futures = "0.3"
tokio-postgres = { version = "=0.5.0-alpha.2", path = "../tokio-postgres", default-features = false } tokio-postgres = { version = "=0.5.0-alpha.2", path = "../tokio-postgres" }
tokio = { version = "0.2", optional = true, features = ["rt-threaded"] } tokio = { version = "0.2", features = ["rt-core"] }
lazy_static = { version = "1.0", optional = true } log = "0.4"
log = { version = "0.4", optional = true }
[dev-dependencies] [dev-dependencies]
criterion = "0.3"
tokio = "0.2" tokio = "0.2"

17
postgres/benches/bench.rs Normal file
View File

@ -0,0 +1,17 @@
use criterion::{criterion_group, criterion_main, Criterion};
use postgres::{Client, NoTls};
// spawned: 249us 252us 255us
// local: 214us 216us 219us
fn query_prepared(c: &mut Criterion) {
let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
let stmt = client.prepare("SELECT $1::INT8").unwrap();
c.bench_function("query_prepared", move |b| {
b.iter(|| client.query(&stmt, &[&1i64]).unwrap())
});
}
criterion_group!(group, query_prepared);
criterion_main!(group);

View File

@ -1,27 +1,25 @@
#[cfg(feature = "runtime")] use crate::{Config, CopyInWriter, CopyOutReader, RowIter, Statement, ToStatement, Transaction};
use crate::Config; use tokio::runtime::Runtime;
use crate::{CopyInWriter, CopyOutReader, RowIter, Statement, ToStatement, Transaction};
use futures::executor;
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect}; use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
use tokio_postgres::types::{ToSql, Type}; use tokio_postgres::types::{ToSql, Type};
#[cfg(feature = "runtime")] use tokio_postgres::{Error, Row, SimpleQueryMessage, Socket};
use tokio_postgres::Socket;
use tokio_postgres::{Error, Row, SimpleQueryMessage};
/// A synchronous PostgreSQL client. /// A synchronous PostgreSQL client.
/// pub struct Client {
/// This is a lightweight wrapper over the asynchronous tokio_postgres `Client`. runtime: Runtime,
pub struct Client(tokio_postgres::Client); client: tokio_postgres::Client,
}
impl Client { impl Client {
pub(crate) fn new(runtime: Runtime, client: tokio_postgres::Client) -> Client {
Client { runtime, client }
}
/// A convenience function which parses a configuration string into a `Config` and then connects to the database. /// A convenience function which parses a configuration string into a `Config` and then connects to the database.
/// ///
/// See the documentation for [`Config`] for information about the connection syntax. /// See the documentation for [`Config`] for information about the connection syntax.
/// ///
/// Requires the `runtime` Cargo feature (enabled by default).
///
/// [`Config`]: config/struct.Config.html /// [`Config`]: config/struct.Config.html
#[cfg(feature = "runtime")]
pub fn connect<T>(params: &str, tls_mode: T) -> Result<Client, Error> pub fn connect<T>(params: &str, tls_mode: T) -> Result<Client, Error>
where where
T: MakeTlsConnect<Socket> + 'static + Send, T: MakeTlsConnect<Socket> + 'static + Send,
@ -78,7 +76,7 @@ impl Client {
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
executor::block_on(self.0.execute(query, params)) self.runtime.block_on(self.client.execute(query, params))
} }
/// Executes a statement, returning the resulting rows. /// Executes a statement, returning the resulting rows.
@ -114,7 +112,7 @@ impl Client {
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
executor::block_on(self.0.query(query, params)) self.runtime.block_on(self.client.query(query, params))
} }
/// Executes a statement which returns a single row, returning it. /// Executes a statement which returns a single row, returning it.
@ -151,7 +149,7 @@ impl Client {
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
executor::block_on(self.0.query_one(query, params)) self.runtime.block_on(self.client.query_one(query, params))
} }
/// Executes a statement which returns zero or one rows, returning it. /// Executes a statement which returns zero or one rows, returning it.
@ -197,7 +195,7 @@ impl Client {
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
executor::block_on(self.0.query_opt(query, params)) self.runtime.block_on(self.client.query_opt(query, params))
} }
/// A maximally-flexible version of `query`. /// A maximally-flexible version of `query`.
@ -235,8 +233,10 @@ impl Client {
I: IntoIterator<Item = &'a dyn ToSql>, I: IntoIterator<Item = &'a dyn ToSql>,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,
{ {
let stream = executor::block_on(self.0.query_raw(query, params))?; let stream = self
Ok(RowIter::new(stream)) .runtime
.block_on(self.client.query_raw(query, params))?;
Ok(RowIter::new(&mut self.runtime, stream))
} }
/// Creates a new prepared statement. /// Creates a new prepared statement.
@ -263,7 +263,7 @@ impl Client {
/// # } /// # }
/// ``` /// ```
pub fn prepare(&mut self, query: &str) -> Result<Statement, Error> { pub fn prepare(&mut self, query: &str) -> Result<Statement, Error> {
executor::block_on(self.0.prepare(query)) self.runtime.block_on(self.client.prepare(query))
} }
/// Like `prepare`, but allows the types of query parameters to be explicitly specified. /// Like `prepare`, but allows the types of query parameters to be explicitly specified.
@ -294,7 +294,8 @@ impl Client {
/// # } /// # }
/// ``` /// ```
pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result<Statement, Error> { pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result<Statement, Error> {
executor::block_on(self.0.prepare_typed(query, types)) self.runtime
.block_on(self.client.prepare_typed(query, types))
} }
/// Executes a `COPY FROM STDIN` statement, returning the number of rows created. /// Executes a `COPY FROM STDIN` statement, returning the number of rows created.
@ -327,8 +328,8 @@ impl Client {
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
let sink = executor::block_on(self.0.copy_in(query, params))?; let sink = self.runtime.block_on(self.client.copy_in(query, params))?;
Ok(CopyInWriter::new(sink)) Ok(CopyInWriter::new(&mut self.runtime, sink))
} }
/// Executes a `COPY TO STDOUT` statement, returning a reader of the resulting data. /// Executes a `COPY TO STDOUT` statement, returning a reader of the resulting data.
@ -358,8 +359,8 @@ impl Client {
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
let stream = executor::block_on(self.0.copy_out(query, params))?; let stream = self.runtime.block_on(self.client.copy_out(query, params))?;
CopyOutReader::new(stream) CopyOutReader::new(&mut self.runtime, stream)
} }
/// Executes a sequence of SQL statements using the simple query protocol. /// Executes a sequence of SQL statements using the simple query protocol.
@ -378,7 +379,7 @@ impl Client {
/// functionality to safely imbed that data in the request. Do not form statements via string concatenation and pass /// functionality to safely imbed that data in the request. Do not form statements via string concatenation and pass
/// them to this method! /// them to this method!
pub fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> { pub fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
executor::block_on(self.0.simple_query(query)) self.runtime.block_on(self.client.simple_query(query))
} }
/// Executes a sequence of SQL statements using the simple query protocol. /// Executes a sequence of SQL statements using the simple query protocol.
@ -392,7 +393,7 @@ impl Client {
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method! /// them to this method!
pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> { pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
executor::block_on(self.0.batch_execute(query)) self.runtime.block_on(self.client.batch_execute(query))
} }
/// Begins a new database transaction. /// Begins a new database transaction.
@ -416,35 +417,14 @@ impl Client {
/// # } /// # }
/// ``` /// ```
pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> { pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
let transaction = executor::block_on(self.0.transaction())?; let transaction = self.runtime.block_on(self.client.transaction())?;
Ok(Transaction::new(transaction)) Ok(Transaction::new(&mut self.runtime, transaction))
} }
/// Determines if the client's connection has already closed. /// Determines if the client's connection has already closed.
/// ///
/// If this returns `true`, the client is no longer usable. /// If this returns `true`, the client is no longer usable.
pub fn is_closed(&self) -> bool { pub fn is_closed(&self) -> bool {
self.0.is_closed() self.client.is_closed()
}
/// Returns a shared reference to the inner nonblocking client.
pub fn get_ref(&self) -> &tokio_postgres::Client {
&self.0
}
/// Returns a mutable reference to the inner nonblocking client.
pub fn get_mut(&mut self) -> &mut tokio_postgres::Client {
&mut self.0
}
/// Consumes the client, returning the inner nonblocking client.
pub fn into_inner(self) -> tokio_postgres::Client {
self.0
}
}
impl From<tokio_postgres::Client> for Client {
fn from(c: tokio_postgres::Client) -> Client {
Client(c)
} }
} }

View File

@ -2,23 +2,19 @@
//! //!
//! Requires the `runtime` Cargo feature (enabled by default). //! Requires the `runtime` Cargo feature (enabled by default).
use crate::{Client, RUNTIME}; use crate::Client;
use futures::{executor, FutureExt}; use futures::FutureExt;
use log::error; use log::error;
use std::fmt; use std::fmt;
use std::future::Future;
use std::path::Path; use std::path::Path;
use std::pin::Pin;
use std::str::FromStr; use std::str::FromStr;
use std::sync::{mpsc, Arc};
use std::time::Duration; use std::time::Duration;
use tokio::runtime;
#[doc(inline)] #[doc(inline)]
pub use tokio_postgres::config::{ChannelBinding, SslMode, TargetSessionAttrs}; pub use tokio_postgres::config::{ChannelBinding, SslMode, TargetSessionAttrs};
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect}; use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
use tokio_postgres::{Error, Socket}; use tokio_postgres::{Error, Socket};
type Spawn = dyn Fn(Pin<Box<dyn Future<Output = ()> + Send>>) + Sync + Send;
/// Connection configuration. /// Connection configuration.
/// ///
/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats: /// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
@ -95,7 +91,6 @@ type Spawn = dyn Fn(Pin<Box<dyn Future<Output = ()> + Send>>) + Sync + Send;
#[derive(Clone)] #[derive(Clone)]
pub struct Config { pub struct Config {
config: tokio_postgres::Config, config: tokio_postgres::Config,
spawner: Option<Arc<Spawn>>,
} }
impl fmt::Debug for Config { impl fmt::Debug for Config {
@ -117,7 +112,6 @@ impl Config {
pub fn new() -> Config { pub fn new() -> Config {
Config { Config {
config: tokio_postgres::Config::new(), config: tokio_postgres::Config::new(),
spawner: None,
} }
} }
@ -242,17 +236,6 @@ impl Config {
self self
} }
/// Sets the spawner used to run the connection futures.
///
/// Defaults to a postgres-specific tokio `Runtime`.
pub fn spawner<F>(&mut self, spawn: F) -> &mut Config
where
F: Fn(Pin<Box<dyn Future<Output = ()> + Send>>) + 'static + Sync + Send,
{
self.spawner = Some(Arc::new(spawn));
self
}
/// Opens a connection to a PostgreSQL database. /// Opens a connection to a PostgreSQL database.
pub fn connect<T>(&self, tls: T) -> Result<Client, Error> pub fn connect<T>(&self, tls: T) -> Result<Client, Error>
where where
@ -261,38 +244,23 @@ impl Config {
T::Stream: Send, T::Stream: Send,
<T::TlsConnect as TlsConnect<Socket>>::Future: Send, <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
{ {
let (client, connection) = match &self.spawner { let mut runtime = runtime::Builder::new()
Some(spawn) => { .enable_all()
let (tx, rx) = mpsc::channel(); .basic_scheduler()
let config = self.config.clone(); .build()
let connect = async move { .unwrap(); // FIXME don't unwrap
let r = config.connect(tls).await;
let _ = tx.send(r);
};
spawn(Box::pin(connect));
rx.recv().unwrap()?
}
None => {
let connect = self.config.connect(tls);
RUNTIME.handle().enter(|| executor::block_on(connect))?
}
};
let (client, connection) = runtime.block_on(self.config.connect(tls))?;
// FIXME don't spawn this so error reporting is less weird.
let connection = connection.map(|r| { let connection = connection.map(|r| {
if let Err(e) = r { if let Err(e) = r {
error!("postgres connection error: {}", e) error!("postgres connection error: {}", e)
} }
}); });
match &self.spawner { runtime.spawn(connection);
Some(spawn) => {
spawn(Box::pin(connection));
}
None => {
RUNTIME.spawn(connection);
}
}
Ok(Client::from(client)) Ok(Client::new(runtime, client))
} }
} }
@ -306,9 +274,6 @@ impl FromStr for Config {
impl From<tokio_postgres::Config> for Config { impl From<tokio_postgres::Config> for Config {
fn from(config: tokio_postgres::Config) -> Config { fn from(config: tokio_postgres::Config) -> Config {
Config { Config { config }
config,
spawner: None,
}
} }
} }

View File

@ -1,18 +1,18 @@
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures::{executor, SinkExt}; use futures::SinkExt;
use std::io; use std::io;
use std::io::Write; use std::io::Write;
use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
use tokio::runtime::Runtime;
use tokio_postgres::{CopyInSink, Error}; use tokio_postgres::{CopyInSink, Error};
/// The writer returned by the `copy_in` method. /// The writer returned by the `copy_in` method.
/// ///
/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted. /// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted.
pub struct CopyInWriter<'a> { pub struct CopyInWriter<'a> {
runtime: &'a mut Runtime,
sink: Pin<Box<CopyInSink<Bytes>>>, sink: Pin<Box<CopyInSink<Bytes>>>,
buf: BytesMut, buf: BytesMut,
_p: PhantomData<&'a mut ()>,
} }
// no-op impl to extend borrow until drop // no-op impl to extend borrow until drop
@ -21,11 +21,11 @@ impl Drop for CopyInWriter<'_> {
} }
impl<'a> CopyInWriter<'a> { impl<'a> CopyInWriter<'a> {
pub(crate) fn new(sink: CopyInSink<Bytes>) -> CopyInWriter<'a> { pub(crate) fn new(runtime: &'a mut Runtime, sink: CopyInSink<Bytes>) -> CopyInWriter<'a> {
CopyInWriter { CopyInWriter {
runtime,
sink: Box::pin(sink), sink: Box::pin(sink),
buf: BytesMut::new(), buf: BytesMut::new(),
_p: PhantomData,
} }
} }
@ -34,7 +34,7 @@ impl<'a> CopyInWriter<'a> {
/// If this is not called, the copy will be aborted. /// If this is not called, the copy will be aborted.
pub fn finish(mut self) -> Result<u64, Error> { pub fn finish(mut self) -> Result<u64, Error> {
self.flush_inner()?; self.flush_inner()?;
executor::block_on(self.sink.as_mut().finish()) self.runtime.block_on(self.sink.as_mut().finish())
} }
fn flush_inner(&mut self) -> Result<(), Error> { fn flush_inner(&mut self) -> Result<(), Error> {
@ -42,7 +42,8 @@ impl<'a> CopyInWriter<'a> {
return Ok(()); return Ok(());
} }
executor::block_on(self.sink.as_mut().send(self.buf.split().freeze())) self.runtime
.block_on(self.sink.as_mut().send(self.buf.split().freeze()))
} }
} }

View File

@ -1,15 +1,15 @@
use bytes::{Buf, Bytes}; use bytes::{Buf, Bytes};
use futures::executor; use futures::StreamExt;
use std::io::{self, BufRead, Cursor, Read}; use std::io::{self, BufRead, Cursor, Read};
use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
use tokio::runtime::Runtime;
use tokio_postgres::{CopyOutStream, Error}; use tokio_postgres::{CopyOutStream, Error};
/// The reader returned by the `copy_out` method. /// The reader returned by the `copy_out` method.
pub struct CopyOutReader<'a> { pub struct CopyOutReader<'a> {
it: executor::BlockingStream<Pin<Box<CopyOutStream>>>, runtime: &'a mut Runtime,
stream: Pin<Box<CopyOutStream>>,
cur: Cursor<Bytes>, cur: Cursor<Bytes>,
_p: PhantomData<&'a mut ()>,
} }
// no-op impl to extend borrow until drop // no-op impl to extend borrow until drop
@ -18,18 +18,21 @@ impl Drop for CopyOutReader<'_> {
} }
impl<'a> CopyOutReader<'a> { impl<'a> CopyOutReader<'a> {
pub(crate) fn new(stream: CopyOutStream) -> Result<CopyOutReader<'a>, Error> { pub(crate) fn new(
let mut it = executor::block_on_stream(Box::pin(stream)); runtime: &'a mut Runtime,
let cur = match it.next() { stream: CopyOutStream,
) -> Result<CopyOutReader<'a>, Error> {
let mut stream = Box::pin(stream);
let cur = match runtime.block_on(stream.next()) {
Some(Ok(cur)) => cur, Some(Ok(cur)) => cur,
Some(Err(e)) => return Err(e), Some(Err(e)) => return Err(e),
None => Bytes::new(), None => Bytes::new(),
}; };
Ok(CopyOutReader { Ok(CopyOutReader {
it, runtime,
stream,
cur: Cursor::new(cur), cur: Cursor::new(cur),
_p: PhantomData,
}) })
} }
} }
@ -47,7 +50,7 @@ impl Read for CopyOutReader<'_> {
impl BufRead for CopyOutReader<'_> { impl BufRead for CopyOutReader<'_> {
fn fill_buf(&mut self) -> io::Result<&[u8]> { fn fill_buf(&mut self) -> io::Result<&[u8]> {
if self.cur.remaining() == 0 { if self.cur.remaining() == 0 {
match self.it.next() { match self.runtime.block_on(self.stream.next()) {
Some(Ok(cur)) => self.cur = Cursor::new(cur), Some(Ok(cur)) => self.cur = Cursor::new(cur),
Some(Err(e)) => return Err(io::Error::new(io::ErrorKind::Other, e)), Some(Err(e)) => return Err(io::Error::new(io::ErrorKind::Other, e)),
None => {} None => {}

View File

@ -55,19 +55,11 @@
#![doc(html_root_url = "https://docs.rs/postgres/0.17")] #![doc(html_root_url = "https://docs.rs/postgres/0.17")]
#![warn(clippy::all, rust_2018_idioms, missing_docs)] #![warn(clippy::all, rust_2018_idioms, missing_docs)]
#[cfg(feature = "runtime")]
use lazy_static::lazy_static;
#[cfg(feature = "runtime")]
use tokio::runtime::{self, Runtime};
#[cfg(feature = "runtime")]
pub use tokio_postgres::Socket;
pub use tokio_postgres::{ pub use tokio_postgres::{
error, row, tls, types, Column, Portal, SimpleQueryMessage, Statement, ToStatement, error, row, tls, types, Column, Portal, SimpleQueryMessage, Socket, Statement, ToStatement,
}; };
pub use crate::client::*; pub use crate::client::*;
#[cfg(feature = "runtime")]
pub use crate::config::Config; pub use crate::config::Config;
pub use crate::copy_in_writer::CopyInWriter; pub use crate::copy_in_writer::CopyInWriter;
pub use crate::copy_out_reader::CopyOutReader; pub use crate::copy_out_reader::CopyOutReader;
@ -81,23 +73,11 @@ pub use crate::tls::NoTls;
pub use crate::transaction::*; pub use crate::transaction::*;
mod client; mod client;
#[cfg(feature = "runtime")]
pub mod config; pub mod config;
mod copy_in_writer; mod copy_in_writer;
mod copy_out_reader; mod copy_out_reader;
mod row_iter; mod row_iter;
mod transaction; mod transaction;
#[cfg(feature = "runtime")]
#[cfg(test)] #[cfg(test)]
mod test; mod test;
#[cfg(feature = "runtime")]
lazy_static! {
static ref RUNTIME: Runtime = runtime::Builder::new()
.thread_name("postgres")
.threaded_scheduler()
.enable_all()
.build()
.unwrap();
}

View File

@ -1,13 +1,13 @@
use fallible_iterator::FallibleIterator; use fallible_iterator::FallibleIterator;
use futures::executor::{self, BlockingStream};
use std::marker::PhantomData;
use std::pin::Pin; use std::pin::Pin;
use tokio::runtime::Runtime;
use tokio_postgres::{Error, Row, RowStream}; use tokio_postgres::{Error, Row, RowStream};
use futures::StreamExt;
/// The iterator returned by `query_raw`. /// The iterator returned by `query_raw`.
pub struct RowIter<'a> { pub struct RowIter<'a> {
it: BlockingStream<Pin<Box<RowStream>>>, runtime: &'a mut Runtime,
_p: PhantomData<&'a mut ()>, it: Pin<Box<RowStream>>,
} }
// no-op impl to extend the borrow until drop // no-op impl to extend the borrow until drop
@ -16,10 +16,10 @@ impl Drop for RowIter<'_> {
} }
impl<'a> RowIter<'a> { impl<'a> RowIter<'a> {
pub(crate) fn new(stream: RowStream) -> RowIter<'a> { pub(crate) fn new(runtime: &'a mut Runtime, stream: RowStream) -> RowIter<'a> {
RowIter { RowIter {
it: executor::block_on_stream(Box::pin(stream)), runtime,
_p: PhantomData, it: Box::pin(stream),
} }
} }
} }
@ -29,6 +29,6 @@ impl FallibleIterator for RowIter<'_> {
type Error = Error; type Error = Error;
fn next(&mut self) -> Result<Option<Row>, Error> { fn next(&mut self) -> Result<Option<Row>, Error> {
self.it.next().transpose() self.runtime.block_on(self.it.next()).transpose()
} }
} }

View File

@ -1,5 +1,5 @@
use crate::{CopyInWriter, CopyOutReader, Portal, RowIter, Statement, ToStatement}; use crate::{CopyInWriter, CopyOutReader, Portal, RowIter, Statement, ToStatement};
use futures::executor; use tokio::runtime::Runtime;
use tokio_postgres::types::{ToSql, Type}; use tokio_postgres::types::{ToSql, Type};
use tokio_postgres::{Error, Row, SimpleQueryMessage}; use tokio_postgres::{Error, Row, SimpleQueryMessage};
@ -7,33 +7,43 @@ use tokio_postgres::{Error, Row, SimpleQueryMessage};
/// ///
/// Transactions will implicitly roll back by default when dropped. Use the `commit` method to commit the changes made /// Transactions will implicitly roll back by default when dropped. Use the `commit` method to commit the changes made
/// in the transaction. Transactions can be nested, with inner transactions implemented via savepoints. /// in the transaction. Transactions can be nested, with inner transactions implemented via savepoints.
pub struct Transaction<'a>(tokio_postgres::Transaction<'a>); pub struct Transaction<'a> {
runtime: &'a mut Runtime,
transaction: tokio_postgres::Transaction<'a>,
}
impl<'a> Transaction<'a> { impl<'a> Transaction<'a> {
pub(crate) fn new(transaction: tokio_postgres::Transaction<'a>) -> Transaction<'a> { pub(crate) fn new(
Transaction(transaction) runtime: &'a mut Runtime,
transaction: tokio_postgres::Transaction<'a>,
) -> Transaction<'a> {
Transaction {
runtime,
transaction,
}
} }
/// Consumes the transaction, committing all changes made within it. /// Consumes the transaction, committing all changes made within it.
pub fn commit(self) -> Result<(), Error> { pub fn commit(self) -> Result<(), Error> {
executor::block_on(self.0.commit()) self.runtime.block_on(self.transaction.commit())
} }
/// Rolls the transaction back, discarding all changes made within it. /// Rolls the transaction back, discarding all changes made within it.
/// ///
/// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller. /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
pub fn rollback(self) -> Result<(), Error> { pub fn rollback(self) -> Result<(), Error> {
executor::block_on(self.0.rollback()) self.runtime.block_on(self.transaction.rollback())
} }
/// Like `Client::prepare`. /// Like `Client::prepare`.
pub fn prepare(&mut self, query: &str) -> Result<Statement, Error> { pub fn prepare(&mut self, query: &str) -> Result<Statement, Error> {
executor::block_on(self.0.prepare(query)) self.runtime.block_on(self.transaction.prepare(query))
} }
/// Like `Client::prepare_typed`. /// Like `Client::prepare_typed`.
pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result<Statement, Error> { pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result<Statement, Error> {
executor::block_on(self.0.prepare_typed(query, types)) self.runtime
.block_on(self.transaction.prepare_typed(query, types))
} }
/// Like `Client::execute`. /// Like `Client::execute`.
@ -41,7 +51,8 @@ impl<'a> Transaction<'a> {
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
executor::block_on(self.0.execute(query, params)) self.runtime
.block_on(self.transaction.execute(query, params))
} }
/// Like `Client::query`. /// Like `Client::query`.
@ -49,7 +60,7 @@ impl<'a> Transaction<'a> {
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
executor::block_on(self.0.query(query, params)) self.runtime.block_on(self.transaction.query(query, params))
} }
/// Like `Client::query_one`. /// Like `Client::query_one`.
@ -57,7 +68,8 @@ impl<'a> Transaction<'a> {
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
executor::block_on(self.0.query_one(query, params)) self.runtime
.block_on(self.transaction.query_one(query, params))
} }
/// Like `Client::query_opt`. /// Like `Client::query_opt`.
@ -69,7 +81,8 @@ impl<'a> Transaction<'a> {
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
executor::block_on(self.0.query_opt(query, params)) self.runtime
.block_on(self.transaction.query_opt(query, params))
} }
/// Like `Client::query_raw`. /// Like `Client::query_raw`.
@ -79,8 +92,10 @@ impl<'a> Transaction<'a> {
I: IntoIterator<Item = &'b dyn ToSql>, I: IntoIterator<Item = &'b dyn ToSql>,
I::IntoIter: ExactSizeIterator, I::IntoIter: ExactSizeIterator,
{ {
let stream = executor::block_on(self.0.query_raw(query, params))?; let stream = self
Ok(RowIter::new(stream)) .runtime
.block_on(self.transaction.query_raw(query, params))?;
Ok(RowIter::new(self.runtime, stream))
} }
/// Binds parameters to a statement, creating a "portal". /// Binds parameters to a statement, creating a "portal".
@ -97,7 +112,7 @@ impl<'a> Transaction<'a> {
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
executor::block_on(self.0.bind(query, params)) self.runtime.block_on(self.transaction.bind(query, params))
} }
/// Continues execution of a portal, returning the next set of rows. /// Continues execution of a portal, returning the next set of rows.
@ -105,7 +120,8 @@ impl<'a> Transaction<'a> {
/// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to
/// `query_portal`. If the requested number is negative or 0, all remaining rows will be returned. /// `query_portal`. If the requested number is negative or 0, all remaining rows will be returned.
pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> { pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result<Vec<Row>, Error> {
executor::block_on(self.0.query_portal(portal, max_rows)) self.runtime
.block_on(self.transaction.query_portal(portal, max_rows))
} }
/// The maximally flexible version of `query_portal`. /// The maximally flexible version of `query_portal`.
@ -114,8 +130,10 @@ impl<'a> Transaction<'a> {
portal: &Portal, portal: &Portal,
max_rows: i32, max_rows: i32,
) -> Result<RowIter<'_>, Error> { ) -> Result<RowIter<'_>, Error> {
let stream = executor::block_on(self.0.query_portal_raw(portal, max_rows))?; let stream = self
Ok(RowIter::new(stream)) .runtime
.block_on(self.transaction.query_portal_raw(portal, max_rows))?;
Ok(RowIter::new(self.runtime, stream))
} }
/// Like `Client::copy_in`. /// Like `Client::copy_in`.
@ -127,8 +145,10 @@ impl<'a> Transaction<'a> {
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
let sink = executor::block_on(self.0.copy_in(query, params))?; let sink = self
Ok(CopyInWriter::new(sink)) .runtime
.block_on(self.transaction.copy_in(query, params))?;
Ok(CopyInWriter::new(self.runtime, sink))
} }
/// Like `Client::copy_out`. /// Like `Client::copy_out`.
@ -140,23 +160,28 @@ impl<'a> Transaction<'a> {
where where
T: ?Sized + ToStatement, T: ?Sized + ToStatement,
{ {
let stream = executor::block_on(self.0.copy_out(query, params))?; let stream = self
CopyOutReader::new(stream) .runtime
.block_on(self.transaction.copy_out(query, params))?;
CopyOutReader::new(self.runtime, stream)
} }
/// Like `Client::simple_query`. /// Like `Client::simple_query`.
pub fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> { pub fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
executor::block_on(self.0.simple_query(query)) self.runtime.block_on(self.transaction.simple_query(query))
} }
/// Like `Client::batch_execute`. /// Like `Client::batch_execute`.
pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> { pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> {
executor::block_on(self.0.batch_execute(query)) self.runtime.block_on(self.transaction.batch_execute(query))
} }
/// Like `Client::transaction`. /// Like `Client::transaction`.
pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> { pub fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
let transaction = executor::block_on(self.0.transaction())?; let transaction = self.runtime.block_on(self.transaction.transaction())?;
Ok(Transaction(transaction)) Ok(Transaction {
runtime: self.runtime,
transaction,
})
} }
} }