Don't block the reactor on DNS
This commit is contained in:
parent
f07ebc7373
commit
3ed4543426
@ -49,13 +49,13 @@
|
|||||||
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
|
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
|
||||||
#![feature(async_await)]
|
#![feature(async_await)]
|
||||||
|
|
||||||
|
use std::future::Future;
|
||||||
|
use std::pin::Pin;
|
||||||
use tokio_io::{AsyncRead, AsyncWrite};
|
use tokio_io::{AsyncRead, AsyncWrite};
|
||||||
#[cfg(feature = "runtime")]
|
#[cfg(feature = "runtime")]
|
||||||
use tokio_postgres::tls::MakeTlsConnect;
|
use tokio_postgres::tls::MakeTlsConnect;
|
||||||
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
|
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
|
||||||
use tokio_tls::TlsStream;
|
use tokio_tls::TlsStream;
|
||||||
use std::pin::Pin;
|
|
||||||
use std::future::Future;
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
@ -111,7 +111,9 @@ where
|
|||||||
{
|
{
|
||||||
type Stream = TlsStream<S>;
|
type Stream = TlsStream<S>;
|
||||||
type Error = native_tls::Error;
|
type Error = native_tls::Error;
|
||||||
type Future = Pin<Box<dyn Future<Output = Result<(TlsStream<S>, ChannelBinding), native_tls::Error>> + Send>>;
|
type Future = Pin<
|
||||||
|
Box<dyn Future<Output = Result<(TlsStream<S>, ChannelBinding), native_tls::Error>> + Send>,
|
||||||
|
>;
|
||||||
|
|
||||||
fn connect(self, stream: S) -> Self::Future {
|
fn connect(self, stream: S) -> Self::Future {
|
||||||
let future = async move {
|
let future = async move {
|
||||||
|
@ -1,16 +1,16 @@
|
|||||||
|
use futures::{FutureExt, TryStreamExt};
|
||||||
use native_tls::{self, Certificate};
|
use native_tls::{self, Certificate};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio_postgres::tls::TlsConnect;
|
use tokio_postgres::tls::TlsConnect;
|
||||||
use futures::{FutureExt, TryStreamExt};
|
|
||||||
|
|
||||||
#[cfg(feature = "runtime")]
|
#[cfg(feature = "runtime")]
|
||||||
use crate::MakeTlsConnector;
|
use crate::MakeTlsConnector;
|
||||||
use crate::TlsConnector;
|
use crate::TlsConnector;
|
||||||
|
|
||||||
async fn smoke_test<T>(s: &str, tls: T)
|
async fn smoke_test<T>(s: &str, tls: T)
|
||||||
where
|
where
|
||||||
T: TlsConnect<TcpStream>,
|
T: TlsConnect<TcpStream>,
|
||||||
T::Stream: 'static + Send,
|
T::Stream: 'static + Send,
|
||||||
{
|
{
|
||||||
let stream = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap())
|
let stream = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap())
|
||||||
.await
|
.await
|
||||||
@ -44,7 +44,8 @@ async fn require() {
|
|||||||
smoke_test(
|
smoke_test(
|
||||||
"user=ssl_user dbname=postgres sslmode=require",
|
"user=ssl_user dbname=postgres sslmode=require",
|
||||||
TlsConnector::new(connector, "localhost"),
|
TlsConnector::new(connector, "localhost"),
|
||||||
).await;
|
)
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@ -58,7 +59,8 @@ async fn prefer() {
|
|||||||
smoke_test(
|
smoke_test(
|
||||||
"user=ssl_user dbname=postgres",
|
"user=ssl_user dbname=postgres",
|
||||||
TlsConnector::new(connector, "localhost"),
|
TlsConnector::new(connector, "localhost"),
|
||||||
).await;
|
)
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@ -72,7 +74,8 @@ async fn scram_user() {
|
|||||||
smoke_test(
|
smoke_test(
|
||||||
"user=scram_user password=password dbname=postgres sslmode=require",
|
"user=scram_user password=password dbname=postgres sslmode=require",
|
||||||
TlsConnector::new(connector, "localhost"),
|
TlsConnector::new(connector, "localhost"),
|
||||||
).await;
|
)
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@ -90,8 +93,8 @@ async fn runtime() {
|
|||||||
"host=localhost port=5433 user=postgres sslmode=require",
|
"host=localhost port=5433 user=postgres sslmode=require",
|
||||||
connector,
|
connector,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let connection = connection.map(|r| r.unwrap());
|
let connection = connection.map(|r| r.unwrap());
|
||||||
tokio::spawn(connection);
|
tokio::spawn(connection);
|
||||||
|
|
||||||
|
@ -277,8 +277,12 @@ impl Config {
|
|||||||
});
|
});
|
||||||
match &self.executor {
|
match &self.executor {
|
||||||
Some(executor) => {
|
Some(executor) => {
|
||||||
executor.lock().unwrap().spawn(Box::pin(connection)).unwrap();
|
executor
|
||||||
},
|
.lock()
|
||||||
|
.unwrap()
|
||||||
|
.spawn(Box::pin(connection))
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
None => {
|
None => {
|
||||||
RUNTIME.spawn(connection);
|
RUNTIME.spawn(connection);
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
use crate::config::Host;
|
use crate::config::Host;
|
||||||
use crate::{Error, Socket};
|
use crate::{Error, Socket};
|
||||||
|
use std::vec;
|
||||||
|
use futures::channel::oneshot;
|
||||||
|
use futures::future;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::io;
|
|
||||||
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
|
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
use std::{io, thread};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
use tokio::net::UnixStream;
|
use tokio::net::UnixStream;
|
||||||
@ -23,10 +26,7 @@ pub(crate) async fn connect_socket(
|
|||||||
// avoid dealing with blocking DNS entirely if possible
|
// avoid dealing with blocking DNS entirely if possible
|
||||||
vec![SocketAddr::new(ip, port)].into_iter()
|
vec![SocketAddr::new(ip, port)].into_iter()
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => dns(host, port).await.map_err(Error::connect)?,
|
||||||
// FIXME what do?
|
|
||||||
(&**host, port).to_socket_addrs().map_err(Error::connect)?
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut error = None;
|
let mut error = None;
|
||||||
@ -64,6 +64,25 @@ pub(crate) async fn connect_socket(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn dns(host: &str, port: u16) -> io::Result<vec::IntoIter<SocketAddr>> {
|
||||||
|
// if we're running on a threadpool, use its blocking support
|
||||||
|
if let Ok(r) =
|
||||||
|
future::poll_fn(|_| tokio_threadpool::blocking(|| (host, port).to_socket_addrs())).await
|
||||||
|
{
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
// FIXME what should we do here?
|
||||||
|
let (tx, rx) = oneshot::channel();
|
||||||
|
let host = host.to_string();
|
||||||
|
thread::spawn(move || {
|
||||||
|
let addrs = (&*host, port).to_socket_addrs();
|
||||||
|
let _ = tx.send(addrs);
|
||||||
|
});
|
||||||
|
|
||||||
|
rx.await.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
async fn connect_with_timeout<F, T>(connect: F, timeout: Option<Duration>) -> Result<T, Error>
|
async fn connect_with_timeout<F, T>(connect: F, timeout: Option<Duration>) -> Result<T, Error>
|
||||||
where
|
where
|
||||||
F: Future<Output = io::Result<T>>,
|
F: Future<Output = io::Result<T>>,
|
||||||
|
Loading…
Reference in New Issue
Block a user