Don't block the reactor on DNS

This commit is contained in:
Steven Fackler 2019-08-04 19:21:32 -07:00
parent f07ebc7373
commit 3ed4543426
4 changed files with 47 additions and 19 deletions

View File

@ -49,13 +49,13 @@
#![warn(rust_2018_idioms, clippy::all, missing_docs)] #![warn(rust_2018_idioms, clippy::all, missing_docs)]
#![feature(async_await)] #![feature(async_await)]
use std::future::Future;
use std::pin::Pin;
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
use tokio_postgres::tls::MakeTlsConnect; use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::tls::{ChannelBinding, TlsConnect}; use tokio_postgres::tls::{ChannelBinding, TlsConnect};
use tokio_tls::TlsStream; use tokio_tls::TlsStream;
use std::pin::Pin;
use std::future::Future;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
@ -111,7 +111,9 @@ where
{ {
type Stream = TlsStream<S>; type Stream = TlsStream<S>;
type Error = native_tls::Error; type Error = native_tls::Error;
type Future = Pin<Box<dyn Future<Output = Result<(TlsStream<S>, ChannelBinding), native_tls::Error>> + Send>>; type Future = Pin<
Box<dyn Future<Output = Result<(TlsStream<S>, ChannelBinding), native_tls::Error>> + Send>,
>;
fn connect(self, stream: S) -> Self::Future { fn connect(self, stream: S) -> Self::Future {
let future = async move { let future = async move {

View File

@ -1,16 +1,16 @@
use futures::{FutureExt, TryStreamExt};
use native_tls::{self, Certificate}; use native_tls::{self, Certificate};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio_postgres::tls::TlsConnect; use tokio_postgres::tls::TlsConnect;
use futures::{FutureExt, TryStreamExt};
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
use crate::MakeTlsConnector; use crate::MakeTlsConnector;
use crate::TlsConnector; use crate::TlsConnector;
async fn smoke_test<T>(s: &str, tls: T) async fn smoke_test<T>(s: &str, tls: T)
where where
T: TlsConnect<TcpStream>, T: TlsConnect<TcpStream>,
T::Stream: 'static + Send, T::Stream: 'static + Send,
{ {
let stream = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap()) let stream = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap())
.await .await
@ -44,7 +44,8 @@ async fn require() {
smoke_test( smoke_test(
"user=ssl_user dbname=postgres sslmode=require", "user=ssl_user dbname=postgres sslmode=require",
TlsConnector::new(connector, "localhost"), TlsConnector::new(connector, "localhost"),
).await; )
.await;
} }
#[tokio::test] #[tokio::test]
@ -58,7 +59,8 @@ async fn prefer() {
smoke_test( smoke_test(
"user=ssl_user dbname=postgres", "user=ssl_user dbname=postgres",
TlsConnector::new(connector, "localhost"), TlsConnector::new(connector, "localhost"),
).await; )
.await;
} }
#[tokio::test] #[tokio::test]
@ -72,7 +74,8 @@ async fn scram_user() {
smoke_test( smoke_test(
"user=scram_user password=password dbname=postgres sslmode=require", "user=scram_user password=password dbname=postgres sslmode=require",
TlsConnector::new(connector, "localhost"), TlsConnector::new(connector, "localhost"),
).await; )
.await;
} }
#[tokio::test] #[tokio::test]
@ -90,8 +93,8 @@ async fn runtime() {
"host=localhost port=5433 user=postgres sslmode=require", "host=localhost port=5433 user=postgres sslmode=require",
connector, connector,
) )
.await .await
.unwrap(); .unwrap();
let connection = connection.map(|r| r.unwrap()); let connection = connection.map(|r| r.unwrap());
tokio::spawn(connection); tokio::spawn(connection);

View File

@ -277,8 +277,12 @@ impl Config {
}); });
match &self.executor { match &self.executor {
Some(executor) => { Some(executor) => {
executor.lock().unwrap().spawn(Box::pin(connection)).unwrap(); executor
}, .lock()
.unwrap()
.spawn(Box::pin(connection))
.unwrap();
}
None => { None => {
RUNTIME.spawn(connection); RUNTIME.spawn(connection);
} }

View File

@ -1,9 +1,12 @@
use crate::config::Host; use crate::config::Host;
use crate::{Error, Socket}; use crate::{Error, Socket};
use std::vec;
use futures::channel::oneshot;
use futures::future;
use std::future::Future; use std::future::Future;
use std::io;
use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use std::time::Duration; use std::time::Duration;
use std::{io, thread};
use tokio::net::TcpStream; use tokio::net::TcpStream;
#[cfg(unix)] #[cfg(unix)]
use tokio::net::UnixStream; use tokio::net::UnixStream;
@ -23,10 +26,7 @@ pub(crate) async fn connect_socket(
// avoid dealing with blocking DNS entirely if possible // avoid dealing with blocking DNS entirely if possible
vec![SocketAddr::new(ip, port)].into_iter() vec![SocketAddr::new(ip, port)].into_iter()
} }
Err(_) => { Err(_) => dns(host, port).await.map_err(Error::connect)?,
// FIXME what do?
(&**host, port).to_socket_addrs().map_err(Error::connect)?
}
}; };
let mut error = None; let mut error = None;
@ -64,6 +64,25 @@ pub(crate) async fn connect_socket(
} }
} }
async fn dns(host: &str, port: u16) -> io::Result<vec::IntoIter<SocketAddr>> {
// if we're running on a threadpool, use its blocking support
if let Ok(r) =
future::poll_fn(|_| tokio_threadpool::blocking(|| (host, port).to_socket_addrs())).await
{
return r;
}
// FIXME what should we do here?
let (tx, rx) = oneshot::channel();
let host = host.to_string();
thread::spawn(move || {
let addrs = (&*host, port).to_socket_addrs();
let _ = tx.send(addrs);
});
rx.await.unwrap()
}
async fn connect_with_timeout<F, T>(connect: F, timeout: Option<Duration>) -> Result<T, Error> async fn connect_with_timeout<F, T>(connect: F, timeout: Option<Duration>) -> Result<T, Error>
where where
F: Future<Output = io::Result<T>>, F: Future<Output = io::Result<T>>,