Support TCP keepalive

This commit is contained in:
Steven Fackler 2018-12-30 09:38:12 -08:00
parent 983de2ef9d
commit 38db34eb6a
4 changed files with 78 additions and 1 deletions

View File

@ -33,6 +33,10 @@ pub(crate) struct Inner {
pub(crate) port: Vec<u16>, pub(crate) port: Vec<u16>,
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
pub(crate) connect_timeout: Option<Duration>, pub(crate) connect_timeout: Option<Duration>,
#[cfg(feature = "runtime")]
pub(crate) keepalives: bool,
#[cfg(feature = "runtime")]
pub(crate) keepalives_idle: Duration,
} }
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
@ -59,6 +63,10 @@ impl Config {
port: vec![], port: vec![],
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
connect_timeout: None, connect_timeout: None,
#[cfg(feature = "runtime")]
keepalives: true,
#[cfg(feature = "runtime")]
keepalives_idle: Duration::from_secs(2 * 60 * 60),
})) }))
} }
@ -100,6 +108,18 @@ impl Config {
self self
} }
#[cfg(feature = "runtime")]
pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
Arc::make_mut(&mut self.0).keepalives = keepalives;
self
}
#[cfg(feature = "runtime")]
pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
Arc::make_mut(&mut self.0).keepalives_idle = keepalives_idle;
self
}
pub fn password<T>(&mut self, password: T) -> &mut Config pub fn password<T>(&mut self, password: T) -> &mut Config
where where
T: AsRef<[u8]>, T: AsRef<[u8]>,
@ -170,6 +190,20 @@ impl FromStr for Config {
builder.connect_timeout(Duration::from_secs(timeout as u64)); builder.connect_timeout(Duration::from_secs(timeout as u64));
} }
} }
#[cfg(feature = "runtime")]
"keepalives" => {
let keepalives = value.parse::<u64>().map_err(Error::invalid_keepalives)?;
builder.keepalives(keepalives != 0);
}
#[cfg(feature = "runtime")]
"keepalives_idle" => {
let keepalives_idle = value
.parse::<i64>()
.map_err(Error::invalid_keepalives_idle)?;
if keepalives_idle > 0 {
builder.keepalives_idle(Duration::from_secs(keepalives_idle as u64));
}
}
key => { key => {
builder.param(key, &value); builder.param(key, &value);
} }

View File

@ -359,6 +359,10 @@ enum Kind {
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
InvalidConnectTimeout, InvalidConnectTimeout,
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
InvalidKeepalives,
#[cfg(feature = "runtime")]
InvalidKeepalivesIdle,
#[cfg(feature = "runtime")]
Timer, Timer,
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
ConnectTimeout, ConnectTimeout,
@ -410,6 +414,10 @@ impl fmt::Display for Error {
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
Kind::InvalidConnectTimeout => "invalid connect_timeout", Kind::InvalidConnectTimeout => "invalid connect_timeout",
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
Kind::InvalidKeepalives => "invalid keepalives",
#[cfg(feature = "runtime")]
Kind::InvalidKeepalivesIdle => "invalid keepalives_value",
#[cfg(feature = "runtime")]
Kind::Timer => "timer error", Kind::Timer => "timer error",
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
Kind::ConnectTimeout => "timed out connecting to server", Kind::ConnectTimeout => "timed out connecting to server",
@ -541,6 +549,16 @@ impl Error {
Error::new(Kind::InvalidConnectTimeout, Some(Box::new(e))) Error::new(Kind::InvalidConnectTimeout, Some(Box::new(e)))
} }
#[cfg(feature = "runtime")]
pub(crate) fn invalid_keepalives(e: ParseIntError) -> Error {
Error::new(Kind::InvalidKeepalives, Some(Box::new(e)))
}
#[cfg(feature = "runtime")]
pub(crate) fn invalid_keepalives_idle(e: ParseIntError) -> Error {
Error::new(Kind::InvalidKeepalivesIdle, Some(Box::new(e)))
}
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
pub(crate) fn timer(e: tokio_timer::Error) -> Error { pub(crate) fn timer(e: tokio_timer::Error) -> Error {
Error::new(Kind::Timer, Some(Box::new(e))) Error::new(Kind::Timer, Some(Box::new(e)))

View File

@ -194,6 +194,12 @@ where
let state = state.take(); let state = state.take();
stream.set_nodelay(true).map_err(Error::connect)?; stream.set_nodelay(true).map_err(Error::connect)?;
if state.config.0.keepalives {
stream
.set_keepalive(Some(state.config.0.keepalives_idle))
.map_err(Error::connect)?;
}
let stream = Socket::new_tcp(stream); let stream = Socket::new_tcp(stream);
transition!(Handshaking { transition!(Handshaking {

View File

@ -1,3 +1,6 @@
#[cfg(feature = "runtime")]
use std::time::Duration;
#[test] #[test]
fn pairs_ok() { fn pairs_ok() {
let params = r"user=foo password=' fizz \'buzz\\ ' thing = ''" let params = r"user=foo password=' fizz \'buzz\\ ' thing = ''"
@ -17,10 +20,26 @@ fn pairs_ok() {
fn pairs_ws() { fn pairs_ws() {
let params = " user\t=\r\n\x0bfoo \t password = hunter2 " let params = " user\t=\r\n\x0bfoo \t password = hunter2 "
.parse::<tokio_postgres::Config>() .parse::<tokio_postgres::Config>()
.unwrap();; .unwrap();
let mut expected = tokio_postgres::Config::new(); let mut expected = tokio_postgres::Config::new();
expected.param("user", "foo").password("hunter2"); expected.param("user", "foo").password("hunter2");
assert_eq!(params, expected); assert_eq!(params, expected);
} }
#[test]
#[cfg(feature = "runtime")]
fn settings() {
let params = "connect_timeout=3 keepalives=0 keepalives_idle=30"
.parse::<tokio_postgres::Config>()
.unwrap();
let mut expected = tokio_postgres::Config::new();
expected
.connect_timeout(Duration::from_secs(3))
.keepalives(false)
.keepalives_idle(Duration::from_secs(30));
assert_eq!(params, expected);
}