diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index f5510443..a3ec78bf 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -33,6 +33,10 @@ pub(crate) struct Inner { pub(crate) port: Vec, #[cfg(feature = "runtime")] pub(crate) connect_timeout: Option, + #[cfg(feature = "runtime")] + pub(crate) keepalives: bool, + #[cfg(feature = "runtime")] + pub(crate) keepalives_idle: Duration, } #[derive(Debug, Clone, PartialEq)] @@ -59,6 +63,10 @@ impl Config { port: vec![], #[cfg(feature = "runtime")] connect_timeout: None, + #[cfg(feature = "runtime")] + keepalives: true, + #[cfg(feature = "runtime")] + keepalives_idle: Duration::from_secs(2 * 60 * 60), })) } @@ -100,6 +108,18 @@ impl Config { self } + #[cfg(feature = "runtime")] + pub fn keepalives(&mut self, keepalives: bool) -> &mut Config { + Arc::make_mut(&mut self.0).keepalives = keepalives; + self + } + + #[cfg(feature = "runtime")] + pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config { + Arc::make_mut(&mut self.0).keepalives_idle = keepalives_idle; + self + } + pub fn password(&mut self, password: T) -> &mut Config where T: AsRef<[u8]>, @@ -170,6 +190,20 @@ impl FromStr for Config { builder.connect_timeout(Duration::from_secs(timeout as u64)); } } + #[cfg(feature = "runtime")] + "keepalives" => { + let keepalives = value.parse::().map_err(Error::invalid_keepalives)?; + builder.keepalives(keepalives != 0); + } + #[cfg(feature = "runtime")] + "keepalives_idle" => { + let keepalives_idle = value + .parse::() + .map_err(Error::invalid_keepalives_idle)?; + if keepalives_idle > 0 { + builder.keepalives_idle(Duration::from_secs(keepalives_idle as u64)); + } + } key => { builder.param(key, &value); } diff --git a/tokio-postgres/src/error/mod.rs b/tokio-postgres/src/error/mod.rs index 1b7fca2c..58f0e56b 100644 --- a/tokio-postgres/src/error/mod.rs +++ b/tokio-postgres/src/error/mod.rs @@ -359,6 +359,10 @@ enum Kind { #[cfg(feature = "runtime")] InvalidConnectTimeout, #[cfg(feature = "runtime")] + InvalidKeepalives, + #[cfg(feature = "runtime")] + InvalidKeepalivesIdle, + #[cfg(feature = "runtime")] Timer, #[cfg(feature = "runtime")] ConnectTimeout, @@ -410,6 +414,10 @@ impl fmt::Display for Error { #[cfg(feature = "runtime")] Kind::InvalidConnectTimeout => "invalid connect_timeout", #[cfg(feature = "runtime")] + Kind::InvalidKeepalives => "invalid keepalives", + #[cfg(feature = "runtime")] + Kind::InvalidKeepalivesIdle => "invalid keepalives_value", + #[cfg(feature = "runtime")] Kind::Timer => "timer error", #[cfg(feature = "runtime")] Kind::ConnectTimeout => "timed out connecting to server", @@ -541,6 +549,16 @@ impl Error { Error::new(Kind::InvalidConnectTimeout, Some(Box::new(e))) } + #[cfg(feature = "runtime")] + pub(crate) fn invalid_keepalives(e: ParseIntError) -> Error { + Error::new(Kind::InvalidKeepalives, Some(Box::new(e))) + } + + #[cfg(feature = "runtime")] + pub(crate) fn invalid_keepalives_idle(e: ParseIntError) -> Error { + Error::new(Kind::InvalidKeepalivesIdle, Some(Box::new(e))) + } + #[cfg(feature = "runtime")] pub(crate) fn timer(e: tokio_timer::Error) -> Error { Error::new(Kind::Timer, Some(Box::new(e))) diff --git a/tokio-postgres/src/proto/connect_once.rs b/tokio-postgres/src/proto/connect_once.rs index 2c21f8d8..cfffe3b5 100644 --- a/tokio-postgres/src/proto/connect_once.rs +++ b/tokio-postgres/src/proto/connect_once.rs @@ -194,6 +194,12 @@ where let state = state.take(); stream.set_nodelay(true).map_err(Error::connect)?; + if state.config.0.keepalives { + stream + .set_keepalive(Some(state.config.0.keepalives_idle)) + .map_err(Error::connect)?; + } + let stream = Socket::new_tcp(stream); transition!(Handshaking { diff --git a/tokio-postgres/tests/test/parse.rs b/tokio-postgres/tests/test/parse.rs index ac320d20..74fedb38 100644 --- a/tokio-postgres/tests/test/parse.rs +++ b/tokio-postgres/tests/test/parse.rs @@ -1,3 +1,6 @@ +#[cfg(feature = "runtime")] +use std::time::Duration; + #[test] fn pairs_ok() { let params = r"user=foo password=' fizz \'buzz\\ ' thing = ''" @@ -17,10 +20,26 @@ fn pairs_ok() { fn pairs_ws() { let params = " user\t=\r\n\x0bfoo \t password = hunter2 " .parse::() - .unwrap();; + .unwrap(); let mut expected = tokio_postgres::Config::new(); expected.param("user", "foo").password("hunter2"); assert_eq!(params, expected); } + +#[test] +#[cfg(feature = "runtime")] +fn settings() { + let params = "connect_timeout=3 keepalives=0 keepalives_idle=30" + .parse::() + .unwrap(); + + let mut expected = tokio_postgres::Config::new(); + expected + .connect_timeout(Duration::from_secs(3)) + .keepalives(false) + .keepalives_idle(Duration::from_secs(30)); + + assert_eq!(params, expected); +}