Support target_session_attrs

Closes #399
This commit is contained in:
Steven Fackler 2018-12-30 11:50:15 -08:00
parent 38db34eb6a
commit 45444d6c51
5 changed files with 123 additions and 11 deletions

View File

@ -15,6 +15,15 @@ use crate::proto::HandshakeFuture;
use crate::{Connect, MakeTlsMode, Socket};
use crate::{Error, Handshake, TlsMode};
#[cfg(feature = "runtime")]
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum TargetSessionAttrs {
Any,
ReadWrite,
#[doc(hidden)]
__NonExhaustive,
}
#[cfg(feature = "runtime")]
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum Host {
@ -37,6 +46,8 @@ pub(crate) struct Inner {
pub(crate) keepalives: bool,
#[cfg(feature = "runtime")]
pub(crate) keepalives_idle: Duration,
#[cfg(feature = "runtime")]
pub(crate) target_session_attrs: TargetSessionAttrs,
}
#[derive(Debug, Clone, PartialEq)]
@ -67,6 +78,8 @@ impl Config {
keepalives: true,
#[cfg(feature = "runtime")]
keepalives_idle: Duration::from_secs(2 * 60 * 60),
#[cfg(feature = "runtime")]
target_session_attrs: TargetSessionAttrs::Any,
}))
}
@ -120,6 +133,15 @@ impl Config {
self
}
#[cfg(feature = "runtime")]
pub fn target_session_attrs(
&mut self,
target_session_attrs: TargetSessionAttrs,
) -> &mut Config {
Arc::make_mut(&mut self.0).target_session_attrs = target_session_attrs;
self
}
pub fn password<T>(&mut self, password: T) -> &mut Config
where
T: AsRef<[u8]>,
@ -204,6 +226,15 @@ impl FromStr for Config {
builder.keepalives_idle(Duration::from_secs(keepalives_idle as u64));
}
}
#[cfg(feature = "runtime")]
"target_session_attrs" => {
let target_session_attrs = match &*value {
"any" => TargetSessionAttrs::Any,
"read-write" => TargetSessionAttrs::ReadWrite,
_ => return Err(Error::invalid_target_session_attrs()),
};
builder.target_session_attrs(target_session_attrs);
}
key => {
builder.param(key, &value);
}

View File

@ -361,11 +361,15 @@ enum Kind {
#[cfg(feature = "runtime")]
InvalidKeepalives,
#[cfg(feature = "runtime")]
InvalidTargetSessionAttrs,
#[cfg(feature = "runtime")]
InvalidKeepalivesIdle,
#[cfg(feature = "runtime")]
Timer,
#[cfg(feature = "runtime")]
ConnectTimeout,
#[cfg(feature = "runtime")]
ReadOnlyDatabase,
}
struct ErrorInner {
@ -418,9 +422,13 @@ impl fmt::Display for Error {
#[cfg(feature = "runtime")]
Kind::InvalidKeepalivesIdle => "invalid keepalives_value",
#[cfg(feature = "runtime")]
Kind::InvalidTargetSessionAttrs => "invalid target_session_attrs",
#[cfg(feature = "runtime")]
Kind::Timer => "timer error",
#[cfg(feature = "runtime")]
Kind::ConnectTimeout => "timed out connecting to server",
#[cfg(feature = "runtime")]
Kind::ReadOnlyDatabase => "the database was read-only",
};
fmt.write_str(s)?;
if let Some(ref cause) = self.0.cause {
@ -559,6 +567,11 @@ impl Error {
Error::new(Kind::InvalidKeepalivesIdle, Some(Box::new(e)))
}
#[cfg(feature = "runtime")]
pub(crate) fn invalid_target_session_attrs() -> Error {
Error::new(Kind::InvalidTargetSessionAttrs, None)
}
#[cfg(feature = "runtime")]
pub(crate) fn timer(e: tokio_timer::Error) -> Error {
Error::new(Kind::Timer, Some(Box::new(e)))
@ -568,4 +581,9 @@ impl Error {
pub(crate) fn connect_timeout() -> Error {
Error::new(Kind::ConnectTimeout, None)
}
#[cfg(feature = "runtime")]
pub(crate) fn read_only_database() -> Error {
Error::new(Kind::ReadOnlyDatabase, None)
}
}

View File

@ -1,6 +1,6 @@
#![allow(clippy::large_enum_variant)]
use futures::{try_ready, Async, Future, Poll};
use futures::{try_ready, Async, Future, Poll, Stream};
use futures_cpupool::{CpuFuture, CpuPool};
use lazy_static::lazy_static;
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
@ -15,8 +15,8 @@ use tokio_timer::Delay;
#[cfg(unix)]
use tokio_uds::UnixStream;
use crate::proto::{Client, Connection, HandshakeFuture};
use crate::{Config, Error, Host, Socket, TlsMode};
use crate::proto::{Client, Connection, HandshakeFuture, SimpleQueryStream};
use crate::{Config, Error, Host, Socket, TargetSessionAttrs, TlsMode};
lazy_static! {
static ref DNS_POOL: CpuPool = futures_cpupool::Builder::new()
@ -61,8 +61,17 @@ where
tls_mode: T,
config: Config,
},
#[state_machine_future(transitions(CheckingSessionAttrs, Finished))]
Handshaking {
future: HandshakeFuture<Socket, T>,
target_session_attrs: TargetSessionAttrs,
},
#[state_machine_future(transitions(Finished))]
Handshaking { future: HandshakeFuture<Socket, T> },
CheckingSessionAttrs {
stream: SimpleQueryStream,
client: Client,
connection: Connection<T::Stream>,
},
#[state_machine_future(ready)]
Finished((Client, Connection<T::Stream>)),
#[state_machine_future(error)]
@ -130,7 +139,8 @@ where
let state = state.take();
transition!(Handshaking {
future: HandshakeFuture::new(stream, state.tls_mode, state.config)
target_session_attrs: state.config.0.target_session_attrs,
future: HandshakeFuture::new(stream, state.tls_mode, state.config),
})
}
@ -203,6 +213,7 @@ where
let stream = Socket::new_tcp(stream);
transition!(Handshaking {
target_session_attrs: state.config.0.target_session_attrs,
future: HandshakeFuture::new(stream, state.tls_mode, state.config),
})
}
@ -210,9 +221,37 @@ where
fn poll_handshaking<'a>(
state: &'a mut RentToOwn<'a, Handshaking<T>>,
) -> Poll<AfterHandshaking<T>, Error> {
let r = try_ready!(state.future.poll());
let (client, connection) = try_ready!(state.future.poll());
transition!(Finished(r))
if let TargetSessionAttrs::ReadWrite = state.target_session_attrs {
transition!(CheckingSessionAttrs {
stream: client.batch_execute("SHOW transaction_read_only"),
client,
connection,
})
} else {
transition!(Finished((client, connection)))
}
}
fn poll_checking_session_attrs<'a>(
state: &'a mut RentToOwn<'a, CheckingSessionAttrs<T>>,
) -> Poll<AfterCheckingSessionAttrs<T>, Error> {
if let Async::Ready(()) = state.connection.poll()? {
return Err(Error::closed());
}
match try_ready!(state.stream.poll()) {
Some(row) => {
if row.get(0) == Some("on") {
Err(Error::read_only_database())
} else {
let state = state.take();
transition!(Finished((state.client, state.connection)))
}
}
None => Err(Error::closed()),
}
}
}

View File

@ -1,5 +1,7 @@
#[cfg(feature = "runtime")]
use std::time::Duration;
#[cfg(feature = "runtime")]
use tokio_postgres::TargetSessionAttrs;
#[test]
fn pairs_ok() {
@ -31,15 +33,17 @@ fn pairs_ws() {
#[test]
#[cfg(feature = "runtime")]
fn settings() {
let params = "connect_timeout=3 keepalives=0 keepalives_idle=30"
.parse::<tokio_postgres::Config>()
.unwrap();
let params =
"connect_timeout=3 keepalives=0 keepalives_idle=30 target_session_attrs=read-write"
.parse::<tokio_postgres::Config>()
.unwrap();
let mut expected = tokio_postgres::Config::new();
expected
.connect_timeout(Duration::from_secs(3))
.keepalives(false)
.keepalives_idle(Duration::from_secs(30));
.keepalives_idle(Duration::from_secs(30))
.target_session_attrs(TargetSessionAttrs::ReadWrite);
assert_eq!(params, expected);
}

View File

@ -46,3 +46,23 @@ fn wrong_port_count() {
);
runtime.block_on(f).err().unwrap();
}
#[test]
fn target_session_attrs_ok() {
let mut runtime = Runtime::new().unwrap();
let f = tokio_postgres::connect(
"host=localhost port=5433 user=postgres target_session_attrs=read-write",
NoTls,
);
runtime.block_on(f).unwrap();
}
#[test]
fn target_session_attrs_err() {
let mut runtime = Runtime::new().unwrap();
let f = tokio_postgres::connect(
"host=localhost port=5433 user=postgres target_session_attrs=read-write default_transaction_read_only=on",
NoTls,
);
runtime.block_on(f).err().unwrap();
}