Upgrade to tokio 0.3

This commit is contained in:
Steven Fackler 2020-10-17 09:49:45 -04:00
parent d1f9d6d802
commit 2689070d19
17 changed files with 49 additions and 191 deletions

View File

@ -16,13 +16,12 @@ default = ["runtime"]
runtime = ["tokio-postgres/runtime"]
[dependencies]
bytes = "0.5"
futures = "0.3"
native-tls = "0.2"
tokio = "0.2"
tokio-tls = "0.3"
tokio = "0.3"
tokio-native-tls = "0.2"
tokio-postgres = { version = "0.5.0", path = "../tokio-postgres", default-features = false }
[dev-dependencies]
tokio = { version = "0.2", features = ["full"] }
tokio = { version = "0.3", features = ["full"] }
postgres = { version = "0.17.0", path = "../postgres" }

View File

@ -48,13 +48,11 @@
#![doc(html_root_url = "https://docs.rs/postgres-native-tls/0.3")]
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
use bytes::{Buf, BufMut};
use std::future::Future;
use std::io;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_postgres::tls;
#[cfg(feature = "runtime")]
use tokio_postgres::tls::MakeTlsConnect;
@ -94,7 +92,7 @@ where
/// A `TlsConnect` implementation using the `native-tls` crate.
pub struct TlsConnector {
connector: tokio_tls::TlsConnector,
connector: tokio_native_tls::TlsConnector,
domain: String,
}
@ -102,7 +100,7 @@ impl TlsConnector {
/// Creates a new connector configured to connect to the specified domain.
pub fn new(connector: native_tls::TlsConnector, domain: &str) -> TlsConnector {
TlsConnector {
connector: tokio_tls::TlsConnector::from(connector),
connector: tokio_native_tls::TlsConnector::from(connector),
domain: domain.to_string(),
}
}
@ -129,34 +127,19 @@ where
}
/// The stream returned by `TlsConnector`.
pub struct TlsStream<S>(tokio_tls::TlsStream<S>);
pub struct TlsStream<S>(tokio_native_tls::TlsStream<S>);
impl<S> AsyncRead for TlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
self.0.prepare_uninitialized_buffer(buf)
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
fn poll_read_buf<B: BufMut>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
{
Pin::new(&mut self.0).poll_read_buf(cx, buf)
}
}
impl<S> AsyncWrite for TlsStream<S>
@ -178,17 +161,6 @@ where
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
fn poll_write_buf<B: Buf>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
{
Pin::new(&mut self.0).poll_write_buf(cx, buf)
}
}
impl<S> tls::TlsStream for TlsStream<S>
@ -196,7 +168,9 @@ where
S: AsyncRead + AsyncWrite + Unpin,
{
fn channel_binding(&self) -> ChannelBinding {
// FIXME https://github.com/tokio-rs/tokio/issues/1383
ChannelBinding::none()
match self.0.get_ref().tls_server_end_point().ok().flatten() {
Some(buf) => ChannelBinding::tls_server_end_point(buf),
None => ChannelBinding::none(),
}
}
}

View File

@ -16,13 +16,12 @@ default = ["runtime"]
runtime = ["tokio-postgres/runtime"]
[dependencies]
bytes = "0.5"
futures = "0.3"
openssl = "0.10"
tokio = "0.2"
tokio-openssl = "0.4"
tokio = "0.3"
tokio-openssl = "0.5"
tokio-postgres = { version = "0.5.0", path = "../tokio-postgres", default-features = false }
[dev-dependencies]
tokio = { version = "0.2", features = ["full"] }
tokio = { version = "0.3", features = ["full"] }
postgres = { version = "0.17.0", path = "../postgres" }

View File

@ -42,7 +42,6 @@
#![doc(html_root_url = "https://docs.rs/postgres-openssl/0.3")]
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
use bytes::{Buf, BufMut};
#[cfg(feature = "runtime")]
use openssl::error::ErrorStack;
use openssl::hash::MessageDigest;
@ -53,12 +52,11 @@ use openssl::ssl::{ConnectConfiguration, SslRef};
use std::fmt::Debug;
use std::future::Future;
use std::io;
use std::mem::MaybeUninit;
use std::pin::Pin;
#[cfg(feature = "runtime")]
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_openssl::{HandshakeError, SslStream};
use tokio_postgres::tls;
#[cfg(feature = "runtime")]
@ -157,28 +155,13 @@ impl<S> AsyncRead for TlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
self.0.prepare_uninitialized_buffer(buf)
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
fn poll_read_buf<B: BufMut>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
{
Pin::new(&mut self.0).poll_read_buf(cx, buf)
}
}
impl<S> AsyncWrite for TlsStream<S>
@ -200,17 +183,6 @@ where
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
fn poll_write_buf<B: Buf>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
{
Pin::new(&mut self.0).poll_write_buf(cx, buf)
}
}
impl<S> tls::TlsStream for TlsStream<S>

View File

@ -36,7 +36,7 @@ fallible-iterator = "0.2"
futures = "0.3"
tokio-postgres = { version = "0.5.5", path = "../tokio-postgres" }
tokio = { version = "0.2", features = ["rt-core", "time"] }
tokio = { version = "0.3", features = ["rt", "time"] }
log = "0.4"
[dev-dependencies]

View File

@ -26,9 +26,8 @@ impl CancelToken {
where
T: MakeTlsConnect<Socket>,
{
runtime::Builder::new()
runtime::Builder::new_current_thread()
.enable_all()
.basic_scheduler()
.build()
.unwrap() // FIXME don't unwrap
.block_on(self.0.cancel_query(tls))

View File

@ -336,9 +336,8 @@ impl Config {
T::Stream: Send,
<T::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
let mut runtime = runtime::Builder::new()
let runtime = runtime::Builder::new_current_thread()
.enable_all()
.basic_scheduler()
.build()
.unwrap(); // FIXME don't unwrap

View File

@ -45,7 +45,8 @@ impl Connection {
where
F: FnOnce() -> T,
{
self.runtime.enter(f)
let _guard = self.runtime.enter();
f()
}
pub fn block_on<F, T>(&mut self, future: F) -> Result<T, Error>

View File

@ -6,7 +6,7 @@ use fallible_iterator::FallibleIterator;
use futures::{ready, FutureExt};
use std::task::Poll;
use std::time::Duration;
use tokio::time::{self, Delay, Instant};
use tokio::time::{self, Instant, Sleep};
/// Notifications from a PostgreSQL backend.
pub struct Notifications<'a> {
@ -64,7 +64,7 @@ impl<'a> Notifications<'a> {
/// This iterator may start returning `Some` after previously returning `None` if more notifications are received.
pub fn timeout_iter(&mut self, timeout: Duration) -> TimeoutIter<'_> {
TimeoutIter {
delay: self.connection.enter(|| time::delay_for(timeout)),
delay: self.connection.enter(|| time::sleep(timeout)),
timeout,
connection: self.connection.as_ref(),
}
@ -124,7 +124,7 @@ impl<'a> FallibleIterator for BlockingIter<'a> {
/// A time-limited blocking iterator over pending notifications.
pub struct TimeoutIter<'a> {
connection: ConnectionRef<'a>,
delay: Delay,
delay: Sleep,
timeout: Duration,
}

View File

@ -25,7 +25,7 @@ circle-ci = { repository = "sfackler/rust-postgres" }
[features]
default = ["runtime"]
runtime = ["tokio/dns", "tokio/net", "tokio/time"]
runtime = ["tokio/net", "tokio/time"]
with-bit-vec-0_6 = ["postgres-types/with-bit-vec-0_6"]
with-chrono-0_4 = ["postgres-types/with-chrono-0_4"]
@ -49,11 +49,11 @@ pin-project-lite = "0.1"
phf = "0.8"
postgres-protocol = { version = "0.5.0", path = "../postgres-protocol" }
postgres-types = { version = "0.1.2", path = "../postgres-types" }
tokio = { version = "0.2", features = ["io-util"] }
tokio-util = { version = "0.3", features = ["codec"] }
tokio = { version = "0.3", features = ["io-util"] }
tokio-util = { version = "0.4", features = ["codec"] }
[dev-dependencies]
tokio = { version = "0.2", features = ["full"] }
tokio = { version = "0.3", features = ["full"] }
env_logger = "0.7"
criterion = "0.3"

View File

@ -7,7 +7,7 @@ use tokio::runtime::Runtime;
use tokio_postgres::{Client, NoTls};
fn setup() -> (Client, Runtime) {
let mut runtime = Runtime::new().unwrap();
let runtime = Runtime::new().unwrap();
let (client, conn) = runtime
.block_on(tokio_postgres::connect(
"host=localhost port=5433 user=postgres",
@ -19,7 +19,7 @@ fn setup() -> (Client, Runtime) {
}
fn query_prepared(c: &mut Criterion) {
let (client, mut runtime) = setup();
let (client, runtime) = setup();
let statement = runtime.block_on(client.prepare("SELECT $1::INT8")).unwrap();
c.bench_function("runtime_block_on", move |b| {
b.iter(|| {
@ -29,13 +29,13 @@ fn query_prepared(c: &mut Criterion) {
})
});
let (client, mut runtime) = setup();
let (client, runtime) = setup();
let statement = runtime.block_on(client.prepare("SELECT $1::INT8")).unwrap();
c.bench_function("executor_block_on", move |b| {
b.iter(|| executor::block_on(client.query(&statement, &[&1i64])).unwrap())
});
let (client, mut runtime) = setup();
let (client, runtime) = setup();
let client = Arc::new(client);
let statement = runtime.block_on(client.prepare("SELECT $1::INT8")).unwrap();
c.bench_function("spawned", move |b| {

View File

@ -12,19 +12,15 @@ pub(crate) async fn connect_socket(
host: &Host,
port: u16,
connect_timeout: Option<Duration>,
keepalives: bool,
keepalives_idle: Duration,
_keepalives: bool,
_keepalives_idle: Duration,
) -> Result<Socket, Error> {
match host {
Host::Tcp(host) => {
let socket =
connect_with_timeout(TcpStream::connect((&**host, port)), connect_timeout).await?;
socket.set_nodelay(true).map_err(Error::connect)?;
if keepalives {
socket
.set_keepalive(Some(keepalives_idle))
.map_err(Error::connect)?;
}
// FIXME support keepalives?
Ok(Socket::new_tcp(socket))
}

View File

@ -1,10 +1,8 @@
use crate::tls::{ChannelBinding, TlsStream};
use bytes::{Buf, BufMut};
use std::io;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
pub enum MaybeTlsStream<S, T> {
Raw(S),
@ -16,38 +14,16 @@ where
S: AsyncRead + Unpin,
T: AsyncRead + Unpin,
{
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
match self {
MaybeTlsStream::Raw(s) => s.prepare_uninitialized_buffer(buf),
MaybeTlsStream::Tls(s) => s.prepare_uninitialized_buffer(buf),
}
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
}
}
fn poll_read_buf<B>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
B: BufMut,
{
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_read_buf(cx, buf),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_read_buf(cx, buf),
}
}
}
impl<S, T> AsyncWrite for MaybeTlsStream<S, T>
@ -79,21 +55,6 @@ where
MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
}
}
fn poll_write_buf<B>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
B: Buf,
{
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_write_buf(cx, buf),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_write_buf(cx, buf),
}
}
}
impl<S, T> TlsStream for MaybeTlsStream<S, T>

View File

@ -1,9 +1,7 @@
use bytes::{Buf, BufMut};
use std::io;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
@ -33,41 +31,17 @@ impl Socket {
}
impl AsyncRead for Socket {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [MaybeUninit<u8>]) -> bool {
match &self.0 {
Inner::Tcp(s) => s.prepare_uninitialized_buffer(buf),
#[cfg(unix)]
Inner::Unix(s) => s.prepare_uninitialized_buffer(buf),
}
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match &mut self.0 {
Inner::Tcp(s) => Pin::new(s).poll_read(cx, buf),
#[cfg(unix)]
Inner::Unix(s) => Pin::new(s).poll_read(cx, buf),
}
}
fn poll_read_buf<B>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
B: BufMut,
{
match &mut self.0 {
Inner::Tcp(s) => Pin::new(s).poll_read_buf(cx, buf),
#[cfg(unix)]
Inner::Unix(s) => Pin::new(s).poll_read_buf(cx, buf),
}
}
}
impl AsyncWrite for Socket {
@ -98,20 +72,4 @@ impl AsyncWrite for Socket {
Inner::Unix(s) => Pin::new(s).poll_shutdown(cx),
}
}
fn poll_write_buf<B>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
B: Buf,
{
match &mut self.0 {
Inner::Tcp(s) => Pin::new(s).poll_write_buf(cx, buf),
#[cfg(unix)]
Inner::Unix(s) => Pin::new(s).poll_write_buf(cx, buf),
}
}
}

View File

@ -5,7 +5,7 @@ use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, io};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
pub(crate) mod private {
pub struct ForcePrivateApi;
@ -125,8 +125,8 @@ impl AsyncRead for NoTlsStream {
fn poll_read(
self: Pin<&mut Self>,
_: &mut Context<'_>,
_: &mut [u8],
) -> Poll<io::Result<usize>> {
_: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match *self {}
}
}

View File

@ -308,7 +308,7 @@ async fn cancel_query_raw() {
let socket = TcpStream::connect("127.0.0.1:5433").await.unwrap();
let cancel_token = client.cancel_token();
let cancel = cancel_token.cancel_query_raw(socket, NoTls);
let cancel = time::delay_for(Duration::from_millis(100)).then(|()| cancel);
let cancel = time::sleep(Duration::from_millis(100)).then(|()| cancel);
let sleep = client.batch_execute("SELECT pg_sleep(100)");

View File

@ -72,7 +72,7 @@ async fn cancel_query() {
let cancel_token = client.cancel_token();
let cancel = cancel_token.cancel_query(NoTls);
let cancel = time::delay_for(Duration::from_millis(100)).then(|()| cancel);
let cancel = time::sleep(Duration::from_millis(100)).then(|()| cancel);
let sleep = client.batch_execute("SELECT pg_sleep(100)");