parent
38db34eb6a
commit
45444d6c51
@ -15,6 +15,15 @@ use crate::proto::HandshakeFuture;
|
||||
use crate::{Connect, MakeTlsMode, Socket};
|
||||
use crate::{Error, Handshake, TlsMode};
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||
pub enum TargetSessionAttrs {
|
||||
Any,
|
||||
ReadWrite,
|
||||
#[doc(hidden)]
|
||||
__NonExhaustive,
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub(crate) enum Host {
|
||||
@ -37,6 +46,8 @@ pub(crate) struct Inner {
|
||||
pub(crate) keepalives: bool,
|
||||
#[cfg(feature = "runtime")]
|
||||
pub(crate) keepalives_idle: Duration,
|
||||
#[cfg(feature = "runtime")]
|
||||
pub(crate) target_session_attrs: TargetSessionAttrs,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
@ -67,6 +78,8 @@ impl Config {
|
||||
keepalives: true,
|
||||
#[cfg(feature = "runtime")]
|
||||
keepalives_idle: Duration::from_secs(2 * 60 * 60),
|
||||
#[cfg(feature = "runtime")]
|
||||
target_session_attrs: TargetSessionAttrs::Any,
|
||||
}))
|
||||
}
|
||||
|
||||
@ -120,6 +133,15 @@ impl Config {
|
||||
self
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
pub fn target_session_attrs(
|
||||
&mut self,
|
||||
target_session_attrs: TargetSessionAttrs,
|
||||
) -> &mut Config {
|
||||
Arc::make_mut(&mut self.0).target_session_attrs = target_session_attrs;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn password<T>(&mut self, password: T) -> &mut Config
|
||||
where
|
||||
T: AsRef<[u8]>,
|
||||
@ -204,6 +226,15 @@ impl FromStr for Config {
|
||||
builder.keepalives_idle(Duration::from_secs(keepalives_idle as u64));
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "runtime")]
|
||||
"target_session_attrs" => {
|
||||
let target_session_attrs = match &*value {
|
||||
"any" => TargetSessionAttrs::Any,
|
||||
"read-write" => TargetSessionAttrs::ReadWrite,
|
||||
_ => return Err(Error::invalid_target_session_attrs()),
|
||||
};
|
||||
builder.target_session_attrs(target_session_attrs);
|
||||
}
|
||||
key => {
|
||||
builder.param(key, &value);
|
||||
}
|
||||
|
@ -361,11 +361,15 @@ enum Kind {
|
||||
#[cfg(feature = "runtime")]
|
||||
InvalidKeepalives,
|
||||
#[cfg(feature = "runtime")]
|
||||
InvalidTargetSessionAttrs,
|
||||
#[cfg(feature = "runtime")]
|
||||
InvalidKeepalivesIdle,
|
||||
#[cfg(feature = "runtime")]
|
||||
Timer,
|
||||
#[cfg(feature = "runtime")]
|
||||
ConnectTimeout,
|
||||
#[cfg(feature = "runtime")]
|
||||
ReadOnlyDatabase,
|
||||
}
|
||||
|
||||
struct ErrorInner {
|
||||
@ -418,9 +422,13 @@ impl fmt::Display for Error {
|
||||
#[cfg(feature = "runtime")]
|
||||
Kind::InvalidKeepalivesIdle => "invalid keepalives_value",
|
||||
#[cfg(feature = "runtime")]
|
||||
Kind::InvalidTargetSessionAttrs => "invalid target_session_attrs",
|
||||
#[cfg(feature = "runtime")]
|
||||
Kind::Timer => "timer error",
|
||||
#[cfg(feature = "runtime")]
|
||||
Kind::ConnectTimeout => "timed out connecting to server",
|
||||
#[cfg(feature = "runtime")]
|
||||
Kind::ReadOnlyDatabase => "the database was read-only",
|
||||
};
|
||||
fmt.write_str(s)?;
|
||||
if let Some(ref cause) = self.0.cause {
|
||||
@ -559,6 +567,11 @@ impl Error {
|
||||
Error::new(Kind::InvalidKeepalivesIdle, Some(Box::new(e)))
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
pub(crate) fn invalid_target_session_attrs() -> Error {
|
||||
Error::new(Kind::InvalidTargetSessionAttrs, None)
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
pub(crate) fn timer(e: tokio_timer::Error) -> Error {
|
||||
Error::new(Kind::Timer, Some(Box::new(e)))
|
||||
@ -568,4 +581,9 @@ impl Error {
|
||||
pub(crate) fn connect_timeout() -> Error {
|
||||
Error::new(Kind::ConnectTimeout, None)
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
pub(crate) fn read_only_database() -> Error {
|
||||
Error::new(Kind::ReadOnlyDatabase, None)
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
#![allow(clippy::large_enum_variant)]
|
||||
|
||||
use futures::{try_ready, Async, Future, Poll};
|
||||
use futures::{try_ready, Async, Future, Poll, Stream};
|
||||
use futures_cpupool::{CpuFuture, CpuPool};
|
||||
use lazy_static::lazy_static;
|
||||
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
|
||||
@ -15,8 +15,8 @@ use tokio_timer::Delay;
|
||||
#[cfg(unix)]
|
||||
use tokio_uds::UnixStream;
|
||||
|
||||
use crate::proto::{Client, Connection, HandshakeFuture};
|
||||
use crate::{Config, Error, Host, Socket, TlsMode};
|
||||
use crate::proto::{Client, Connection, HandshakeFuture, SimpleQueryStream};
|
||||
use crate::{Config, Error, Host, Socket, TargetSessionAttrs, TlsMode};
|
||||
|
||||
lazy_static! {
|
||||
static ref DNS_POOL: CpuPool = futures_cpupool::Builder::new()
|
||||
@ -61,8 +61,17 @@ where
|
||||
tls_mode: T,
|
||||
config: Config,
|
||||
},
|
||||
#[state_machine_future(transitions(CheckingSessionAttrs, Finished))]
|
||||
Handshaking {
|
||||
future: HandshakeFuture<Socket, T>,
|
||||
target_session_attrs: TargetSessionAttrs,
|
||||
},
|
||||
#[state_machine_future(transitions(Finished))]
|
||||
Handshaking { future: HandshakeFuture<Socket, T> },
|
||||
CheckingSessionAttrs {
|
||||
stream: SimpleQueryStream,
|
||||
client: Client,
|
||||
connection: Connection<T::Stream>,
|
||||
},
|
||||
#[state_machine_future(ready)]
|
||||
Finished((Client, Connection<T::Stream>)),
|
||||
#[state_machine_future(error)]
|
||||
@ -130,7 +139,8 @@ where
|
||||
let state = state.take();
|
||||
|
||||
transition!(Handshaking {
|
||||
future: HandshakeFuture::new(stream, state.tls_mode, state.config)
|
||||
target_session_attrs: state.config.0.target_session_attrs,
|
||||
future: HandshakeFuture::new(stream, state.tls_mode, state.config),
|
||||
})
|
||||
}
|
||||
|
||||
@ -203,6 +213,7 @@ where
|
||||
let stream = Socket::new_tcp(stream);
|
||||
|
||||
transition!(Handshaking {
|
||||
target_session_attrs: state.config.0.target_session_attrs,
|
||||
future: HandshakeFuture::new(stream, state.tls_mode, state.config),
|
||||
})
|
||||
}
|
||||
@ -210,9 +221,37 @@ where
|
||||
fn poll_handshaking<'a>(
|
||||
state: &'a mut RentToOwn<'a, Handshaking<T>>,
|
||||
) -> Poll<AfterHandshaking<T>, Error> {
|
||||
let r = try_ready!(state.future.poll());
|
||||
let (client, connection) = try_ready!(state.future.poll());
|
||||
|
||||
transition!(Finished(r))
|
||||
if let TargetSessionAttrs::ReadWrite = state.target_session_attrs {
|
||||
transition!(CheckingSessionAttrs {
|
||||
stream: client.batch_execute("SHOW transaction_read_only"),
|
||||
client,
|
||||
connection,
|
||||
})
|
||||
} else {
|
||||
transition!(Finished((client, connection)))
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_checking_session_attrs<'a>(
|
||||
state: &'a mut RentToOwn<'a, CheckingSessionAttrs<T>>,
|
||||
) -> Poll<AfterCheckingSessionAttrs<T>, Error> {
|
||||
if let Async::Ready(()) = state.connection.poll()? {
|
||||
return Err(Error::closed());
|
||||
}
|
||||
|
||||
match try_ready!(state.stream.poll()) {
|
||||
Some(row) => {
|
||||
if row.get(0) == Some("on") {
|
||||
Err(Error::read_only_database())
|
||||
} else {
|
||||
let state = state.take();
|
||||
transition!(Finished((state.client, state.connection)))
|
||||
}
|
||||
}
|
||||
None => Err(Error::closed()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,7 @@
|
||||
#[cfg(feature = "runtime")]
|
||||
use std::time::Duration;
|
||||
#[cfg(feature = "runtime")]
|
||||
use tokio_postgres::TargetSessionAttrs;
|
||||
|
||||
#[test]
|
||||
fn pairs_ok() {
|
||||
@ -31,7 +33,8 @@ fn pairs_ws() {
|
||||
#[test]
|
||||
#[cfg(feature = "runtime")]
|
||||
fn settings() {
|
||||
let params = "connect_timeout=3 keepalives=0 keepalives_idle=30"
|
||||
let params =
|
||||
"connect_timeout=3 keepalives=0 keepalives_idle=30 target_session_attrs=read-write"
|
||||
.parse::<tokio_postgres::Config>()
|
||||
.unwrap();
|
||||
|
||||
@ -39,7 +42,8 @@ fn settings() {
|
||||
expected
|
||||
.connect_timeout(Duration::from_secs(3))
|
||||
.keepalives(false)
|
||||
.keepalives_idle(Duration::from_secs(30));
|
||||
.keepalives_idle(Duration::from_secs(30))
|
||||
.target_session_attrs(TargetSessionAttrs::ReadWrite);
|
||||
|
||||
assert_eq!(params, expected);
|
||||
}
|
||||
|
@ -46,3 +46,23 @@ fn wrong_port_count() {
|
||||
);
|
||||
runtime.block_on(f).err().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_session_attrs_ok() {
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
let f = tokio_postgres::connect(
|
||||
"host=localhost port=5433 user=postgres target_session_attrs=read-write",
|
||||
NoTls,
|
||||
);
|
||||
runtime.block_on(f).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_session_attrs_err() {
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
let f = tokio_postgres::connect(
|
||||
"host=localhost port=5433 user=postgres target_session_attrs=read-write default_transaction_read_only=on",
|
||||
NoTls,
|
||||
);
|
||||
runtime.block_on(f).err().unwrap();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user