From 3ed45434261b98c5666e3ed9cbb69223a8504b25 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 4 Aug 2019 19:21:32 -0700 Subject: [PATCH] Don't block the reactor on DNS --- postgres-native-tls/src/lib.rs | 8 +++++--- postgres-native-tls/src/test.rs | 21 +++++++++++--------- postgres/src/config.rs | 8 ++++++-- tokio-postgres/src/connect_socket.rs | 29 +++++++++++++++++++++++----- 4 files changed, 47 insertions(+), 19 deletions(-) diff --git a/postgres-native-tls/src/lib.rs b/postgres-native-tls/src/lib.rs index 32809f59..0e9ac30c 100644 --- a/postgres-native-tls/src/lib.rs +++ b/postgres-native-tls/src/lib.rs @@ -49,13 +49,13 @@ #![warn(rust_2018_idioms, clippy::all, missing_docs)] #![feature(async_await)] +use std::future::Future; +use std::pin::Pin; use tokio_io::{AsyncRead, AsyncWrite}; #[cfg(feature = "runtime")] use tokio_postgres::tls::MakeTlsConnect; use tokio_postgres::tls::{ChannelBinding, TlsConnect}; use tokio_tls::TlsStream; -use std::pin::Pin; -use std::future::Future; #[cfg(test)] mod test; @@ -111,7 +111,9 @@ where { type Stream = TlsStream; type Error = native_tls::Error; - type Future = Pin, ChannelBinding), native_tls::Error>> + Send>>; + type Future = Pin< + Box, ChannelBinding), native_tls::Error>> + Send>, + >; fn connect(self, stream: S) -> Self::Future { let future = async move { diff --git a/postgres-native-tls/src/test.rs b/postgres-native-tls/src/test.rs index 81f93398..5e9dac58 100644 --- a/postgres-native-tls/src/test.rs +++ b/postgres-native-tls/src/test.rs @@ -1,16 +1,16 @@ +use futures::{FutureExt, TryStreamExt}; use native_tls::{self, Certificate}; use tokio::net::TcpStream; use tokio_postgres::tls::TlsConnect; -use futures::{FutureExt, TryStreamExt}; #[cfg(feature = "runtime")] use crate::MakeTlsConnector; use crate::TlsConnector; async fn smoke_test(s: &str, tls: T) - where - T: TlsConnect, - T::Stream: 'static + Send, +where + T: TlsConnect, + T::Stream: 'static + Send, { let stream = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap()) .await @@ -44,7 +44,8 @@ async fn require() { smoke_test( "user=ssl_user dbname=postgres sslmode=require", TlsConnector::new(connector, "localhost"), - ).await; + ) + .await; } #[tokio::test] @@ -58,7 +59,8 @@ async fn prefer() { smoke_test( "user=ssl_user dbname=postgres", TlsConnector::new(connector, "localhost"), - ).await; + ) + .await; } #[tokio::test] @@ -72,7 +74,8 @@ async fn scram_user() { smoke_test( "user=scram_user password=password dbname=postgres sslmode=require", TlsConnector::new(connector, "localhost"), - ).await; + ) + .await; } #[tokio::test] @@ -90,8 +93,8 @@ async fn runtime() { "host=localhost port=5433 user=postgres sslmode=require", connector, ) - .await - .unwrap(); + .await + .unwrap(); let connection = connection.map(|r| r.unwrap()); tokio::spawn(connection); diff --git a/postgres/src/config.rs b/postgres/src/config.rs index 2c2fa655..9c44c4cb 100644 --- a/postgres/src/config.rs +++ b/postgres/src/config.rs @@ -277,8 +277,12 @@ impl Config { }); match &self.executor { Some(executor) => { - executor.lock().unwrap().spawn(Box::pin(connection)).unwrap(); - }, + executor + .lock() + .unwrap() + .spawn(Box::pin(connection)) + .unwrap(); + } None => { RUNTIME.spawn(connection); } diff --git a/tokio-postgres/src/connect_socket.rs b/tokio-postgres/src/connect_socket.rs index 3209b139..5c7d7271 100644 --- a/tokio-postgres/src/connect_socket.rs +++ b/tokio-postgres/src/connect_socket.rs @@ -1,9 +1,12 @@ use crate::config::Host; use crate::{Error, Socket}; +use std::vec; +use futures::channel::oneshot; +use futures::future; use std::future::Future; -use std::io; use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; use std::time::Duration; +use std::{io, thread}; use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; @@ -23,10 +26,7 @@ pub(crate) async fn connect_socket( // avoid dealing with blocking DNS entirely if possible vec![SocketAddr::new(ip, port)].into_iter() } - Err(_) => { - // FIXME what do? - (&**host, port).to_socket_addrs().map_err(Error::connect)? - } + Err(_) => dns(host, port).await.map_err(Error::connect)?, }; let mut error = None; @@ -64,6 +64,25 @@ pub(crate) async fn connect_socket( } } +async fn dns(host: &str, port: u16) -> io::Result> { + // if we're running on a threadpool, use its blocking support + if let Ok(r) = + future::poll_fn(|_| tokio_threadpool::blocking(|| (host, port).to_socket_addrs())).await + { + return r; + } + + // FIXME what should we do here? + let (tx, rx) = oneshot::channel(); + let host = host.to_string(); + thread::spawn(move || { + let addrs = (&*host, port).to_socket_addrs(); + let _ = tx.send(addrs); + }); + + rx.await.unwrap() +} + async fn connect_with_timeout(connect: F, timeout: Option) -> Result where F: Future>,