Overhaul query cancellation

Multi-host support means we can't simply take the old approach - we need
to know which of the hosts we actually connected to. It's also nice to
move this from the connection to the client since that's what you'd
normally have access to.
This commit is contained in:
Steven Fackler 2019-01-06 18:03:51 -08:00
parent a6535b4310
commit 1f6d9ddc06
11 changed files with 305 additions and 78 deletions

View File

@ -17,10 +17,10 @@ use tokio_io::{AsyncRead, AsyncWrite};
#[cfg(feature = "runtime")]
use crate::proto::ConnectFuture;
use crate::proto::{CancelQueryRawFuture, HandshakeFuture};
use crate::{CancelData, CancelQueryRaw, Error, Handshake, TlsMode};
use crate::proto::HandshakeFuture;
#[cfg(feature = "runtime")]
use crate::{Connect, MakeTlsMode, Socket};
use crate::{Error, Handshake, TlsMode};
#[cfg(feature = "runtime")]
#[derive(Debug, Copy, Clone, PartialEq)]
@ -267,7 +267,7 @@ impl Config {
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
{
Handshake(HandshakeFuture::new(stream, tls_mode, self.clone()))
Handshake(HandshakeFuture::new(stream, tls_mode, self.clone(), None))
}
#[cfg(feature = "runtime")]
@ -277,19 +277,6 @@ impl Config {
{
Connect(ConnectFuture::new(make_tls_mode, Ok(self.clone())))
}
pub fn cancel_query_raw<S, T>(
&self,
stream: S,
tls_mode: T,
cancel_data: CancelData,
) -> CancelQueryRaw<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
{
CancelQueryRaw(CancelQueryRawFuture::new(stream, tls_mode, cancel_data))
}
}
impl FromStr for Config {

View File

@ -206,6 +206,22 @@ impl Client {
BatchExecute(self.0.batch_execute(query))
}
#[cfg(feature = "runtime")]
pub fn cancel_query<T>(&mut self, make_tls_mode: T) -> CancelQuery<T>
where
T: MakeTlsMode<Socket>,
{
CancelQuery(self.0.cancel_query(make_tls_mode))
}
pub fn cancel_query_raw<S, T>(&mut self, stream: S, tls_mode: T) -> CancelQueryRaw<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
{
CancelQueryRaw(self.0.cancel_query_raw(stream, tls_mode))
}
pub fn is_closed(&self) -> bool {
self.0.is_closed()
}
@ -222,10 +238,6 @@ impl<S> Connection<S>
where
S: AsyncRead + AsyncWrite,
{
pub fn cancel_data(&self) -> CancelData {
self.0.cancel_data()
}
pub fn parameter(&self, name: &str) -> Option<&str> {
self.0.parameter(name)
}
@ -274,6 +286,25 @@ where
}
}
#[cfg(feature = "runtime")]
#[must_use = "futures do nothing unless polled"]
pub struct CancelQuery<T>(proto::CancelQueryFuture<T>)
where
T: MakeTlsMode<Socket>;
#[cfg(feature = "runtime")]
impl<T> Future for CancelQuery<T>
where
T: MakeTlsMode<Socket>,
{
type Item = ();
type Error = Error;
fn poll(&mut self) -> Poll<(), Error> {
self.0.poll()
}
}
#[must_use = "futures do nothing unless polled"]
pub struct Handshake<S, T>(proto::HandshakeFuture<S, T>)
where
@ -478,15 +509,6 @@ impl Future for BatchExecute {
}
}
/// Contains information necessary to cancel queries for a session.
#[derive(Copy, Clone, Debug)]
pub struct CancelData {
/// The process ID of the session.
pub process_id: i32,
/// The secret key for the session.
pub secret_key: i32,
}
/// An asynchronous notification.
#[derive(Clone, Debug)]
pub struct Notification {

View File

@ -0,0 +1,105 @@
use futures::{try_ready, Future, Poll};
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use std::io;
use crate::proto::{CancelQueryRawFuture, ConnectSocketFuture};
use crate::{Config, Error, Host, MakeTlsMode, Socket};
#[derive(StateMachineFuture)]
pub enum CancelQuery<T>
where
T: MakeTlsMode<Socket>,
{
#[state_machine_future(start, transitions(ConnectingSocket))]
Start {
make_tls_mode: T,
idx: Option<usize>,
config: Config,
process_id: i32,
secret_key: i32,
},
#[state_machine_future(transitions(Canceling))]
ConnectingSocket {
future: ConnectSocketFuture,
tls_mode: T::TlsMode,
process_id: i32,
secret_key: i32,
},
#[state_machine_future(transitions(Finished))]
Canceling {
future: CancelQueryRawFuture<Socket, T::TlsMode>,
},
#[state_machine_future(ready)]
Finished(()),
#[state_machine_future(error)]
Failed(Error),
}
impl<T> PollCancelQuery<T> for CancelQuery<T>
where
T: MakeTlsMode<Socket>,
{
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
let mut state = state.take();
let idx = state.idx.ok_or_else(|| {
Error::connect(io::Error::new(io::ErrorKind::InvalidInput, "unknown host"))
})?;
let hostname = match &state.config.0.host[idx] {
Host::Tcp(host) => &**host,
// postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
#[cfg(unix)]
Host::Unix(_) => "",
};
let tls_mode = state
.make_tls_mode
.make_tls_mode(hostname)
.map_err(|e| Error::tls(e.into()))?;
transition!(ConnectingSocket {
future: ConnectSocketFuture::new(state.config, idx),
tls_mode,
process_id: state.process_id,
secret_key: state.secret_key,
})
}
fn poll_connecting_socket<'a>(
state: &'a mut RentToOwn<'a, ConnectingSocket<T>>,
) -> Poll<AfterConnectingSocket<T>, Error> {
let socket = try_ready!(state.future.poll());
let state = state.take();
transition!(Canceling {
future: CancelQueryRawFuture::new(
socket,
state.tls_mode,
state.process_id,
state.secret_key
),
})
}
fn poll_canceling<'a>(
state: &'a mut RentToOwn<'a, Canceling<T>>,
) -> Poll<AfterCanceling, Error> {
try_ready!(state.future.poll());
transition!(Finished(()))
}
}
impl<T> CancelQueryFuture<T>
where
T: MakeTlsMode<Socket>,
{
pub fn new(
make_tls_mode: T,
idx: Option<usize>,
config: Config,
process_id: i32,
secret_key: i32,
) -> CancelQueryFuture<T> {
CancelQuery::start(make_tls_mode, idx, config, process_id, secret_key)
}
}

View File

@ -6,7 +6,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
use crate::error::Error;
use crate::proto::TlsFuture;
use crate::{CancelData, TlsMode};
use crate::TlsMode;
#[derive(StateMachineFuture)]
pub enum CancelQueryRaw<S, T>
@ -17,7 +17,8 @@ where
#[state_machine_future(start, transitions(SendingCancel))]
Start {
future: TlsFuture<S, T>,
cancel_data: CancelData,
process_id: i32,
secret_key: i32,
},
#[state_machine_future(transitions(FlushingCancel))]
SendingCancel {
@ -40,11 +41,7 @@ where
let (stream, _) = try_ready!(state.future.poll());
let mut buf = vec![];
frontend::cancel_request(
state.cancel_data.process_id,
state.cancel_data.secret_key,
&mut buf,
);
frontend::cancel_request(state.process_id, state.secret_key, &mut buf);
transition!(SendingCancel {
future: io::write_all(stream, buf),
@ -74,7 +71,12 @@ where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
{
pub fn new(stream: S, tls_mode: T, cancel_data: CancelData) -> CancelQueryRawFuture<S, T> {
CancelQueryRaw::start(TlsFuture::new(stream, tls_mode), cancel_data)
pub fn new(
stream: S,
tls_mode: T,
process_id: i32,
secret_key: i32,
) -> CancelQueryRawFuture<S, T> {
CancelQueryRaw::start(TlsFuture::new(stream, tls_mode), process_id, secret_key)
}
}

View File

@ -8,6 +8,7 @@ use postgres_protocol::message::frontend;
use std::collections::HashMap;
use std::error::Error as StdError;
use std::sync::{Arc, Weak};
use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto::bind::BindFuture;
use crate::proto::connection::{Request, RequestMessages};
@ -20,8 +21,13 @@ use crate::proto::prepare::PrepareFuture;
use crate::proto::query::QueryStream;
use crate::proto::simple_query::SimpleQueryStream;
use crate::proto::statement::Statement;
#[cfg(feature = "runtime")]
use crate::proto::CancelQueryFuture;
use crate::proto::CancelQueryRawFuture;
use crate::types::{IsNull, Oid, ToSql, Type};
use crate::Error;
use crate::{Config, Error, TlsMode};
#[cfg(feature = "runtime")]
use crate::{MakeTlsMode, Socket};
pub struct PendingRequest(Result<(RequestMessages, IdleGuard), Error>);
@ -44,13 +50,25 @@ struct Inner {
state: Mutex<State>,
idle: IdleState,
sender: mpsc::UnboundedSender<Request>,
process_id: i32,
secret_key: i32,
#[cfg_attr(not(feature = "runtime"), allow(dead_code))]
config: Config,
#[cfg_attr(not(feature = "runtime"), allow(dead_code))]
idx: Option<usize>,
}
#[derive(Clone)]
pub struct Client(Arc<Inner>);
impl Client {
pub fn new(sender: mpsc::UnboundedSender<Request>) -> Client {
pub fn new(
sender: mpsc::UnboundedSender<Request>,
process_id: i32,
secret_key: i32,
config: Config,
idx: Option<usize>,
) -> Client {
Client(Arc::new(Inner {
state: Mutex::new(State {
types: HashMap::new(),
@ -60,6 +78,10 @@ impl Client {
}),
idle: IdleState::new(),
sender,
process_id,
secret_key,
config,
idx,
}))
}
@ -222,6 +244,28 @@ impl Client {
self.close(b'P', name)
}
#[cfg(feature = "runtime")]
pub fn cancel_query<T>(&self, make_tls_mode: T) -> CancelQueryFuture<T>
where
T: MakeTlsMode<Socket>,
{
CancelQueryFuture::new(
make_tls_mode,
self.0.idx,
self.0.config.clone(),
self.0.process_id,
self.0.secret_key,
)
}
pub fn cancel_query_raw<S, T>(&self, stream: S, tls_mode: T) -> CancelQueryRawFuture<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
{
CancelQueryRawFuture::new(stream, tls_mode, self.0.process_id, self.0.secret_key)
}
fn close(&self, ty: u8, name: &str) {
let mut buf = vec![];
frontend::close(ty, name, &mut buf).expect("statement name not valid");

View File

@ -65,7 +65,7 @@ where
transition!(Handshaking {
target_session_attrs: state.config.0.target_session_attrs,
future: HandshakeFuture::new(socket, state.tls_mode, state.config),
future: HandshakeFuture::new(socket, state.tls_mode, state.config, Some(state.idx)),
})
}

View File

@ -11,7 +11,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto::codec::PostgresCodec;
use crate::proto::copy_in::CopyInReceiver;
use crate::proto::idle::IdleGuard;
use crate::{AsyncMessage, CancelData, Notification};
use crate::{AsyncMessage, Notification};
use crate::{DbError, Error};
pub enum RequestMessages {
@ -42,7 +42,6 @@ enum State {
pub struct Connection<S> {
stream: Framed<S, PostgresCodec>,
cancel_data: CancelData,
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
pending_request: Option<RequestMessages>,
@ -57,13 +56,11 @@ where
{
pub fn new(
stream: Framed<S, PostgresCodec>,
cancel_data: CancelData,
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
) -> Connection<S> {
Connection {
stream,
cancel_data,
parameters,
receiver,
pending_request: None,
@ -73,10 +70,6 @@ where
}
}
pub fn cancel_data(&self) -> CancelData {
self.cancel_data
}
pub fn parameter(&self, name: &str) -> Option<&str> {
self.parameters.get(name).map(|s| &**s)
}

View File

@ -8,12 +8,11 @@ use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend;
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use std::collections::HashMap;
use std::io;
use tokio_codec::Framed;
use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto::{Client, Connection, PostgresCodec, TlsFuture};
use crate::{CancelData, ChannelBinding, Config, Error, TlsMode};
use crate::{ChannelBinding, Config, Error, TlsMode};
#[derive(StateMachineFuture)]
pub enum Handshake<S, T>
@ -25,42 +24,56 @@ where
Start {
future: TlsFuture<S, T>,
config: Config,
idx: Option<usize>,
},
#[state_machine_future(transitions(ReadingAuth))]
SendingStartup {
future: sink::Send<Framed<T::Stream, PostgresCodec>>,
config: Config,
idx: Option<usize>,
channel_binding: ChannelBinding,
},
#[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))]
ReadingAuth {
stream: Framed<T::Stream, PostgresCodec>,
config: Config,
idx: Option<usize>,
channel_binding: ChannelBinding,
},
#[state_machine_future(transitions(ReadingAuthCompletion))]
SendingPassword {
future: sink::Send<Framed<T::Stream, PostgresCodec>>,
config: Config,
idx: Option<usize>,
},
#[state_machine_future(transitions(ReadingSasl))]
SendingSasl {
future: sink::Send<Framed<T::Stream, PostgresCodec>>,
scram: ScramSha256,
config: Config,
idx: Option<usize>,
},
#[state_machine_future(transitions(SendingSasl, ReadingAuthCompletion))]
ReadingSasl {
stream: Framed<T::Stream, PostgresCodec>,
scram: ScramSha256,
config: Config,
idx: Option<usize>,
},
#[state_machine_future(transitions(ReadingInfo))]
ReadingAuthCompletion {
stream: Framed<T::Stream, PostgresCodec>,
config: Config,
idx: Option<usize>,
},
#[state_machine_future(transitions(Finished))]
ReadingInfo {
stream: Framed<T::Stream, PostgresCodec>,
cancel_data: Option<CancelData>,
process_id: i32,
secret_key: i32,
parameters: HashMap<String, String>,
config: Config,
idx: Option<usize>,
},
#[state_machine_future(ready)]
Finished((Client, Connection<T::Stream>)),
@ -99,6 +112,7 @@ where
transition!(SendingStartup {
future: stream.send(buf),
config: state.config,
idx: state.idx,
channel_binding,
})
}
@ -111,6 +125,7 @@ where
transition!(ReadingAuth {
stream,
config: state.config,
idx: state.idx,
channel_binding: state.channel_binding,
})
}
@ -124,8 +139,11 @@ where
match message {
Some(Message::AuthenticationOk) => transition!(ReadingInfo {
stream: state.stream,
cancel_data: None,
process_id: 0,
secret_key: 0,
parameters: HashMap::new(),
config: state.config,
idx: state.idx,
}),
Some(Message::AuthenticationCleartextPassword) => {
let pass = state
@ -137,7 +155,9 @@ where
let mut buf = vec![];
frontend::password_message(pass, &mut buf).map_err(Error::encode)?;
transition!(SendingPassword {
future: state.stream.send(buf)
future: state.stream.send(buf),
config: state.config,
idx: state.idx,
})
}
Some(Message::AuthenticationMd5Password(body)) => {
@ -157,7 +177,9 @@ where
let mut buf = vec![];
frontend::password_message(output.as_bytes(), &mut buf).map_err(Error::encode)?;
transition!(SendingPassword {
future: state.stream.send(buf)
future: state.stream.send(buf),
config: state.config,
idx: state.idx,
})
}
Some(Message::AuthenticationSasl(body)) => {
@ -214,6 +236,8 @@ where
transition!(SendingSasl {
future: state.stream.send(buf),
scram,
config: state.config,
idx: state.idx,
})
}
Some(Message::AuthenticationKerberosV5)
@ -232,7 +256,12 @@ where
state: &'a mut RentToOwn<'a, SendingPassword<S, T>>,
) -> Poll<AfterSendingPassword<S, T>, Error> {
let stream = try_ready!(state.future.poll().map_err(Error::io));
transition!(ReadingAuthCompletion { stream })
let state = state.take();
transition!(ReadingAuthCompletion {
stream,
config: state.config,
idx: state.idx,
})
}
fn poll_sending_sasl<'a>(
@ -243,6 +272,8 @@ where
transition!(ReadingSasl {
stream,
scram: state.scram,
config: state.config,
idx: state.idx,
})
}
@ -263,6 +294,8 @@ where
transition!(SendingSasl {
future: state.stream.send(buf),
scram: state.scram,
config: state.config,
idx: state.idx,
})
}
Some(Message::AuthenticationSaslFinal(body)) => {
@ -271,7 +304,9 @@ where
.finish(body.data())
.map_err(|e| Error::authentication(Box::new(e)))?;
transition!(ReadingAuthCompletion {
stream: state.stream
stream: state.stream,
config: state.config,
idx: state.idx,
})
}
Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
@ -289,8 +324,11 @@ where
match message {
Some(Message::AuthenticationOk) => transition!(ReadingInfo {
stream: state.stream,
cancel_data: None,
parameters: HashMap::new()
process_id: 0,
secret_key: 0,
parameters: HashMap::new(),
config: state.config,
idx: state.idx,
}),
Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
Some(_) => Err(Error::unexpected_message()),
@ -305,10 +343,8 @@ where
let message = try_ready!(state.stream.poll().map_err(Error::io));
match message {
Some(Message::BackendKeyData(body)) => {
state.cancel_data = Some(CancelData {
process_id: body.process_id(),
secret_key: body.secret_key(),
});
state.process_id = body.process_id();
state.secret_key = body.secret_key();
}
Some(Message::ParameterStatus(body)) => {
state.parameters.insert(
@ -318,16 +354,15 @@ where
}
Some(Message::ReadyForQuery(_)) => {
let state = state.take();
let cancel_data = state.cancel_data.ok_or_else(|| {
Error::parse(io::Error::new(
io::ErrorKind::InvalidData,
"BackendKeyData message missing",
))
})?;
let (sender, receiver) = mpsc::unbounded();
let client = Client::new(sender);
let connection =
Connection::new(state.stream, cancel_data, state.parameters, receiver);
let client = Client::new(
sender,
state.process_id,
state.secret_key,
state.config,
state.idx,
);
let connection = Connection::new(state.stream, state.parameters, receiver);
transition!(Finished((client, connection)))
}
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
@ -344,7 +379,12 @@ where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
{
pub fn new(stream: S, tls_mode: T, config: Config) -> HandshakeFuture<S, T> {
Handshake::start(TlsFuture::new(stream, tls_mode), config)
pub fn new(
stream: S,
tls_mode: T,
config: Config,
idx: Option<usize>,
) -> HandshakeFuture<S, T> {
Handshake::start(TlsFuture::new(stream, tls_mode), config, idx)
}
}

View File

@ -19,6 +19,8 @@ macro_rules! try_ready_closed {
}
mod bind;
#[cfg(feature = "runtime")]
mod cancel_query;
mod cancel_query_raw;
mod client;
mod codec;
@ -46,6 +48,8 @@ mod typeinfo_composite;
mod typeinfo_enum;
pub use crate::proto::bind::BindFuture;
#[cfg(feature = "runtime")]
pub use crate::proto::cancel_query::CancelQueryFuture;
pub use crate::proto::cancel_query_raw::CancelQueryRawFuture;
pub use crate::proto::client::Client;
pub use crate::proto::codec::PostgresCodec;

View File

@ -222,12 +222,11 @@ fn query_portal() {
}
#[test]
fn cancel_query() {
fn cancel_query_raw() {
let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap();
let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
let cancel_data = connection.cancel_data();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();
@ -245,7 +244,7 @@ fn cancel_query() {
})
.then(|r| {
let s = r.unwrap();
tokio_postgres::Config::new().cancel_query_raw(s, NoTls, cancel_data)
client.cancel_query_raw(s, NoTls)
})
.then(|r| {
r.unwrap();

View File

@ -1,6 +1,8 @@
use futures::Future;
use std::time::{Duration, Instant};
use tokio::runtime::current_thread::Runtime;
use tokio_postgres::NoTls;
use tokio::timer::Delay;
use tokio_postgres::{NoTls, SqlState};
fn smoke_test(s: &str) {
let mut runtime = Runtime::new().unwrap();
@ -67,3 +69,32 @@ fn target_session_attrs_err() {
);
runtime.block_on(f).err().unwrap();
}
#[test]
fn cancel_query() {
let mut runtime = Runtime::new().unwrap();
let connect = tokio_postgres::connect("host=localhost port=5433 user=postgres", NoTls);
let (mut client, connection) = runtime.block_on(connect).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
let sleep = client
.batch_execute("SELECT pg_sleep(100)")
.then(|r| match r {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.code() == Some(&SqlState::QUERY_CANCELED) => Ok::<(), ()>(()),
Err(e) => panic!("unexpected error {}", e),
});
let cancel = Delay::new(Instant::now() + Duration::from_millis(100))
.then(|r| {
r.unwrap();
client.cancel_query(NoTls)
})
.then(|r| {
r.unwrap();
Ok::<(), ()>(())
});
let ((), ()) = runtime.block_on(sleep.join(cancel)).unwrap();
}