From 5ad3c9a139303ba0c63b5c06337790a41d6474a2 Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Thu, 5 Nov 2020 21:14:56 -0500 Subject: [PATCH] Add back keepalives config handling Also fix connection timeouts to be per-address --- tokio-postgres/Cargo.toml | 1 + tokio-postgres/src/connect_socket.rs | 55 ++++++++++++++++++++++++---- 2 files changed, 48 insertions(+), 8 deletions(-) diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 883f8e1e..14f8c7e9 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -49,6 +49,7 @@ pin-project-lite = "0.1" phf = "0.8" postgres-protocol = { version = "0.5.0", path = "../postgres-protocol" } postgres-types = { version = "0.1.2", path = "../postgres-types" } +socket2 = "0.3" tokio = { version = "0.3", features = ["io-util"] } tokio-util = { version = "0.4", features = ["codec"] } diff --git a/tokio-postgres/src/connect_socket.rs b/tokio-postgres/src/connect_socket.rs index 145eb7dc..564677b0 100644 --- a/tokio-postgres/src/connect_socket.rs +++ b/tokio-postgres/src/connect_socket.rs @@ -1,28 +1,67 @@ use crate::config::Host; use crate::{Error, Socket}; +use socket2::{Domain, Protocol, Type}; use std::future::Future; use std::io; +use std::net::SocketAddr; +#[cfg(unix)] +use std::os::unix::io::{FromRawFd, IntoRawFd}; +#[cfg(windows)] +use std::os::windows::io::{FromRawSocket, IntoRawSocket}; use std::time::Duration; -use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; +use tokio::net::{self, TcpSocket}; use tokio::time; pub(crate) async fn connect_socket( host: &Host, port: u16, connect_timeout: Option, - _keepalives: bool, - _keepalives_idle: Duration, + keepalives: bool, + keepalives_idle: Duration, ) -> Result { match host { Host::Tcp(host) => { - let socket = - connect_with_timeout(TcpStream::connect((&**host, port)), connect_timeout).await?; - socket.set_nodelay(true).map_err(Error::connect)?; - // FIXME support keepalives? + let addrs = net::lookup_host((&**host, port)) + .await + .map_err(Error::connect)?; - Ok(Socket::new_tcp(socket)) + let mut last_err = None; + + for addr in addrs { + let domain = match addr { + SocketAddr::V4(_) => Domain::ipv4(), + SocketAddr::V6(_) => Domain::ipv6(), + }; + + let socket = socket2::Socket::new(domain, Type::stream(), Some(Protocol::tcp())) + .map_err(Error::connect)?; + socket.set_nonblocking(true).map_err(Error::connect)?; + socket.set_nodelay(true).map_err(Error::connect)?; + if keepalives { + socket + .set_keepalive(Some(keepalives_idle)) + .map_err(Error::connect)?; + } + + #[cfg(unix)] + let socket = unsafe { TcpSocket::from_raw_fd(socket.into_raw_fd()) }; + #[cfg(windows)] + let socket = unsafe { TcpSocket::from_raw_socket(socket.into_raw_socket()) }; + + match connect_with_timeout(socket.connect(addr), connect_timeout).await { + Ok(socket) => return Ok(Socket::new_tcp(socket)), + Err(e) => last_err = Some(e), + } + } + + Err(last_err.unwrap_or_else(|| { + Error::connect(io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve any addresses", + )) + })) } #[cfg(unix)] Host::Unix(path) => {