From 45444d6c5129a656da69c0cf7b681b00eab50f4e Mon Sep 17 00:00:00 2001 From: Steven Fackler Date: Sun, 30 Dec 2018 11:50:15 -0800 Subject: [PATCH] Support target_session_attrs Closes #399 --- tokio-postgres/src/config.rs | 31 ++++++++++++++ tokio-postgres/src/error/mod.rs | 18 ++++++++ tokio-postgres/src/proto/connect_once.rs | 53 ++++++++++++++++++++---- tokio-postgres/tests/test/parse.rs | 12 ++++-- tokio-postgres/tests/test/runtime.rs | 20 +++++++++ 5 files changed, 123 insertions(+), 11 deletions(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index a3ec78bf..458cdf80 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -15,6 +15,15 @@ use crate::proto::HandshakeFuture; use crate::{Connect, MakeTlsMode, Socket}; use crate::{Error, Handshake, TlsMode}; +#[cfg(feature = "runtime")] +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum TargetSessionAttrs { + Any, + ReadWrite, + #[doc(hidden)] + __NonExhaustive, +} + #[cfg(feature = "runtime")] #[derive(Debug, Clone, PartialEq)] pub(crate) enum Host { @@ -37,6 +46,8 @@ pub(crate) struct Inner { pub(crate) keepalives: bool, #[cfg(feature = "runtime")] pub(crate) keepalives_idle: Duration, + #[cfg(feature = "runtime")] + pub(crate) target_session_attrs: TargetSessionAttrs, } #[derive(Debug, Clone, PartialEq)] @@ -67,6 +78,8 @@ impl Config { keepalives: true, #[cfg(feature = "runtime")] keepalives_idle: Duration::from_secs(2 * 60 * 60), + #[cfg(feature = "runtime")] + target_session_attrs: TargetSessionAttrs::Any, })) } @@ -120,6 +133,15 @@ impl Config { self } + #[cfg(feature = "runtime")] + pub fn target_session_attrs( + &mut self, + target_session_attrs: TargetSessionAttrs, + ) -> &mut Config { + Arc::make_mut(&mut self.0).target_session_attrs = target_session_attrs; + self + } + pub fn password(&mut self, password: T) -> &mut Config where T: AsRef<[u8]>, @@ -204,6 +226,15 @@ impl FromStr for Config { builder.keepalives_idle(Duration::from_secs(keepalives_idle as u64)); } } + #[cfg(feature = "runtime")] + "target_session_attrs" => { + let target_session_attrs = match &*value { + "any" => TargetSessionAttrs::Any, + "read-write" => TargetSessionAttrs::ReadWrite, + _ => return Err(Error::invalid_target_session_attrs()), + }; + builder.target_session_attrs(target_session_attrs); + } key => { builder.param(key, &value); } diff --git a/tokio-postgres/src/error/mod.rs b/tokio-postgres/src/error/mod.rs index 58f0e56b..1823387c 100644 --- a/tokio-postgres/src/error/mod.rs +++ b/tokio-postgres/src/error/mod.rs @@ -361,11 +361,15 @@ enum Kind { #[cfg(feature = "runtime")] InvalidKeepalives, #[cfg(feature = "runtime")] + InvalidTargetSessionAttrs, + #[cfg(feature = "runtime")] InvalidKeepalivesIdle, #[cfg(feature = "runtime")] Timer, #[cfg(feature = "runtime")] ConnectTimeout, + #[cfg(feature = "runtime")] + ReadOnlyDatabase, } struct ErrorInner { @@ -418,9 +422,13 @@ impl fmt::Display for Error { #[cfg(feature = "runtime")] Kind::InvalidKeepalivesIdle => "invalid keepalives_value", #[cfg(feature = "runtime")] + Kind::InvalidTargetSessionAttrs => "invalid target_session_attrs", + #[cfg(feature = "runtime")] Kind::Timer => "timer error", #[cfg(feature = "runtime")] Kind::ConnectTimeout => "timed out connecting to server", + #[cfg(feature = "runtime")] + Kind::ReadOnlyDatabase => "the database was read-only", }; fmt.write_str(s)?; if let Some(ref cause) = self.0.cause { @@ -559,6 +567,11 @@ impl Error { Error::new(Kind::InvalidKeepalivesIdle, Some(Box::new(e))) } + #[cfg(feature = "runtime")] + pub(crate) fn invalid_target_session_attrs() -> Error { + Error::new(Kind::InvalidTargetSessionAttrs, None) + } + #[cfg(feature = "runtime")] pub(crate) fn timer(e: tokio_timer::Error) -> Error { Error::new(Kind::Timer, Some(Box::new(e))) @@ -568,4 +581,9 @@ impl Error { pub(crate) fn connect_timeout() -> Error { Error::new(Kind::ConnectTimeout, None) } + + #[cfg(feature = "runtime")] + pub(crate) fn read_only_database() -> Error { + Error::new(Kind::ReadOnlyDatabase, None) + } } diff --git a/tokio-postgres/src/proto/connect_once.rs b/tokio-postgres/src/proto/connect_once.rs index cfffe3b5..bf92aa6a 100644 --- a/tokio-postgres/src/proto/connect_once.rs +++ b/tokio-postgres/src/proto/connect_once.rs @@ -1,6 +1,6 @@ #![allow(clippy::large_enum_variant)] -use futures::{try_ready, Async, Future, Poll}; +use futures::{try_ready, Async, Future, Poll, Stream}; use futures_cpupool::{CpuFuture, CpuPool}; use lazy_static::lazy_static; use state_machine_future::{transition, RentToOwn, StateMachineFuture}; @@ -15,8 +15,8 @@ use tokio_timer::Delay; #[cfg(unix)] use tokio_uds::UnixStream; -use crate::proto::{Client, Connection, HandshakeFuture}; -use crate::{Config, Error, Host, Socket, TlsMode}; +use crate::proto::{Client, Connection, HandshakeFuture, SimpleQueryStream}; +use crate::{Config, Error, Host, Socket, TargetSessionAttrs, TlsMode}; lazy_static! { static ref DNS_POOL: CpuPool = futures_cpupool::Builder::new() @@ -61,8 +61,17 @@ where tls_mode: T, config: Config, }, + #[state_machine_future(transitions(CheckingSessionAttrs, Finished))] + Handshaking { + future: HandshakeFuture, + target_session_attrs: TargetSessionAttrs, + }, #[state_machine_future(transitions(Finished))] - Handshaking { future: HandshakeFuture }, + CheckingSessionAttrs { + stream: SimpleQueryStream, + client: Client, + connection: Connection, + }, #[state_machine_future(ready)] Finished((Client, Connection)), #[state_machine_future(error)] @@ -130,7 +139,8 @@ where let state = state.take(); transition!(Handshaking { - future: HandshakeFuture::new(stream, state.tls_mode, state.config) + target_session_attrs: state.config.0.target_session_attrs, + future: HandshakeFuture::new(stream, state.tls_mode, state.config), }) } @@ -203,6 +213,7 @@ where let stream = Socket::new_tcp(stream); transition!(Handshaking { + target_session_attrs: state.config.0.target_session_attrs, future: HandshakeFuture::new(stream, state.tls_mode, state.config), }) } @@ -210,9 +221,37 @@ where fn poll_handshaking<'a>( state: &'a mut RentToOwn<'a, Handshaking>, ) -> Poll, Error> { - let r = try_ready!(state.future.poll()); + let (client, connection) = try_ready!(state.future.poll()); - transition!(Finished(r)) + if let TargetSessionAttrs::ReadWrite = state.target_session_attrs { + transition!(CheckingSessionAttrs { + stream: client.batch_execute("SHOW transaction_read_only"), + client, + connection, + }) + } else { + transition!(Finished((client, connection))) + } + } + + fn poll_checking_session_attrs<'a>( + state: &'a mut RentToOwn<'a, CheckingSessionAttrs>, + ) -> Poll, Error> { + if let Async::Ready(()) = state.connection.poll()? { + return Err(Error::closed()); + } + + match try_ready!(state.stream.poll()) { + Some(row) => { + if row.get(0) == Some("on") { + Err(Error::read_only_database()) + } else { + let state = state.take(); + transition!(Finished((state.client, state.connection))) + } + } + None => Err(Error::closed()), + } } } diff --git a/tokio-postgres/tests/test/parse.rs b/tokio-postgres/tests/test/parse.rs index 74fedb38..02345456 100644 --- a/tokio-postgres/tests/test/parse.rs +++ b/tokio-postgres/tests/test/parse.rs @@ -1,5 +1,7 @@ #[cfg(feature = "runtime")] use std::time::Duration; +#[cfg(feature = "runtime")] +use tokio_postgres::TargetSessionAttrs; #[test] fn pairs_ok() { @@ -31,15 +33,17 @@ fn pairs_ws() { #[test] #[cfg(feature = "runtime")] fn settings() { - let params = "connect_timeout=3 keepalives=0 keepalives_idle=30" - .parse::() - .unwrap(); + let params = + "connect_timeout=3 keepalives=0 keepalives_idle=30 target_session_attrs=read-write" + .parse::() + .unwrap(); let mut expected = tokio_postgres::Config::new(); expected .connect_timeout(Duration::from_secs(3)) .keepalives(false) - .keepalives_idle(Duration::from_secs(30)); + .keepalives_idle(Duration::from_secs(30)) + .target_session_attrs(TargetSessionAttrs::ReadWrite); assert_eq!(params, expected); } diff --git a/tokio-postgres/tests/test/runtime.rs b/tokio-postgres/tests/test/runtime.rs index 67246876..29df4d8c 100644 --- a/tokio-postgres/tests/test/runtime.rs +++ b/tokio-postgres/tests/test/runtime.rs @@ -46,3 +46,23 @@ fn wrong_port_count() { ); runtime.block_on(f).err().unwrap(); } + +#[test] +fn target_session_attrs_ok() { + let mut runtime = Runtime::new().unwrap(); + let f = tokio_postgres::connect( + "host=localhost port=5433 user=postgres target_session_attrs=read-write", + NoTls, + ); + runtime.block_on(f).unwrap(); +} + +#[test] +fn target_session_attrs_err() { + let mut runtime = Runtime::new().unwrap(); + let f = tokio_postgres::connect( + "host=localhost port=5433 user=postgres target_session_attrs=read-write default_transaction_read_only=on", + NoTls, + ); + runtime.block_on(f).err().unwrap(); +}