Overhaul query cancellation

Multi-host support means we can't simply take the old approach - we need
to know which of the hosts we actually connected to. It's also nice to
move this from the connection to the client since that's what you'd
normally have access to.
This commit is contained in:
Steven Fackler 2019-01-06 18:03:51 -08:00
parent a6535b4310
commit 1f6d9ddc06
11 changed files with 305 additions and 78 deletions

View File

@ -17,10 +17,10 @@ use tokio_io::{AsyncRead, AsyncWrite};
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
use crate::proto::ConnectFuture; use crate::proto::ConnectFuture;
use crate::proto::{CancelQueryRawFuture, HandshakeFuture}; use crate::proto::HandshakeFuture;
use crate::{CancelData, CancelQueryRaw, Error, Handshake, TlsMode};
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
use crate::{Connect, MakeTlsMode, Socket}; use crate::{Connect, MakeTlsMode, Socket};
use crate::{Error, Handshake, TlsMode};
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
#[derive(Debug, Copy, Clone, PartialEq)] #[derive(Debug, Copy, Clone, PartialEq)]
@ -267,7 +267,7 @@ impl Config {
S: AsyncRead + AsyncWrite, S: AsyncRead + AsyncWrite,
T: TlsMode<S>, T: TlsMode<S>,
{ {
Handshake(HandshakeFuture::new(stream, tls_mode, self.clone())) Handshake(HandshakeFuture::new(stream, tls_mode, self.clone(), None))
} }
#[cfg(feature = "runtime")] #[cfg(feature = "runtime")]
@ -277,19 +277,6 @@ impl Config {
{ {
Connect(ConnectFuture::new(make_tls_mode, Ok(self.clone()))) Connect(ConnectFuture::new(make_tls_mode, Ok(self.clone())))
} }
pub fn cancel_query_raw<S, T>(
&self,
stream: S,
tls_mode: T,
cancel_data: CancelData,
) -> CancelQueryRaw<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
{
CancelQueryRaw(CancelQueryRawFuture::new(stream, tls_mode, cancel_data))
}
} }
impl FromStr for Config { impl FromStr for Config {

View File

@ -206,6 +206,22 @@ impl Client {
BatchExecute(self.0.batch_execute(query)) BatchExecute(self.0.batch_execute(query))
} }
#[cfg(feature = "runtime")]
pub fn cancel_query<T>(&mut self, make_tls_mode: T) -> CancelQuery<T>
where
T: MakeTlsMode<Socket>,
{
CancelQuery(self.0.cancel_query(make_tls_mode))
}
pub fn cancel_query_raw<S, T>(&mut self, stream: S, tls_mode: T) -> CancelQueryRaw<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
{
CancelQueryRaw(self.0.cancel_query_raw(stream, tls_mode))
}
pub fn is_closed(&self) -> bool { pub fn is_closed(&self) -> bool {
self.0.is_closed() self.0.is_closed()
} }
@ -222,10 +238,6 @@ impl<S> Connection<S>
where where
S: AsyncRead + AsyncWrite, S: AsyncRead + AsyncWrite,
{ {
pub fn cancel_data(&self) -> CancelData {
self.0.cancel_data()
}
pub fn parameter(&self, name: &str) -> Option<&str> { pub fn parameter(&self, name: &str) -> Option<&str> {
self.0.parameter(name) self.0.parameter(name)
} }
@ -274,6 +286,25 @@ where
} }
} }
#[cfg(feature = "runtime")]
#[must_use = "futures do nothing unless polled"]
pub struct CancelQuery<T>(proto::CancelQueryFuture<T>)
where
T: MakeTlsMode<Socket>;
#[cfg(feature = "runtime")]
impl<T> Future for CancelQuery<T>
where
T: MakeTlsMode<Socket>,
{
type Item = ();
type Error = Error;
fn poll(&mut self) -> Poll<(), Error> {
self.0.poll()
}
}
#[must_use = "futures do nothing unless polled"] #[must_use = "futures do nothing unless polled"]
pub struct Handshake<S, T>(proto::HandshakeFuture<S, T>) pub struct Handshake<S, T>(proto::HandshakeFuture<S, T>)
where where
@ -478,15 +509,6 @@ impl Future for BatchExecute {
} }
} }
/// Contains information necessary to cancel queries for a session.
#[derive(Copy, Clone, Debug)]
pub struct CancelData {
/// The process ID of the session.
pub process_id: i32,
/// The secret key for the session.
pub secret_key: i32,
}
/// An asynchronous notification. /// An asynchronous notification.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Notification { pub struct Notification {

View File

@ -0,0 +1,105 @@
use futures::{try_ready, Future, Poll};
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use std::io;
use crate::proto::{CancelQueryRawFuture, ConnectSocketFuture};
use crate::{Config, Error, Host, MakeTlsMode, Socket};
#[derive(StateMachineFuture)]
pub enum CancelQuery<T>
where
T: MakeTlsMode<Socket>,
{
#[state_machine_future(start, transitions(ConnectingSocket))]
Start {
make_tls_mode: T,
idx: Option<usize>,
config: Config,
process_id: i32,
secret_key: i32,
},
#[state_machine_future(transitions(Canceling))]
ConnectingSocket {
future: ConnectSocketFuture,
tls_mode: T::TlsMode,
process_id: i32,
secret_key: i32,
},
#[state_machine_future(transitions(Finished))]
Canceling {
future: CancelQueryRawFuture<Socket, T::TlsMode>,
},
#[state_machine_future(ready)]
Finished(()),
#[state_machine_future(error)]
Failed(Error),
}
impl<T> PollCancelQuery<T> for CancelQuery<T>
where
T: MakeTlsMode<Socket>,
{
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
let mut state = state.take();
let idx = state.idx.ok_or_else(|| {
Error::connect(io::Error::new(io::ErrorKind::InvalidInput, "unknown host"))
})?;
let hostname = match &state.config.0.host[idx] {
Host::Tcp(host) => &**host,
// postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
#[cfg(unix)]
Host::Unix(_) => "",
};
let tls_mode = state
.make_tls_mode
.make_tls_mode(hostname)
.map_err(|e| Error::tls(e.into()))?;
transition!(ConnectingSocket {
future: ConnectSocketFuture::new(state.config, idx),
tls_mode,
process_id: state.process_id,
secret_key: state.secret_key,
})
}
fn poll_connecting_socket<'a>(
state: &'a mut RentToOwn<'a, ConnectingSocket<T>>,
) -> Poll<AfterConnectingSocket<T>, Error> {
let socket = try_ready!(state.future.poll());
let state = state.take();
transition!(Canceling {
future: CancelQueryRawFuture::new(
socket,
state.tls_mode,
state.process_id,
state.secret_key
),
})
}
fn poll_canceling<'a>(
state: &'a mut RentToOwn<'a, Canceling<T>>,
) -> Poll<AfterCanceling, Error> {
try_ready!(state.future.poll());
transition!(Finished(()))
}
}
impl<T> CancelQueryFuture<T>
where
T: MakeTlsMode<Socket>,
{
pub fn new(
make_tls_mode: T,
idx: Option<usize>,
config: Config,
process_id: i32,
secret_key: i32,
) -> CancelQueryFuture<T> {
CancelQuery::start(make_tls_mode, idx, config, process_id, secret_key)
}
}

View File

@ -6,7 +6,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
use crate::error::Error; use crate::error::Error;
use crate::proto::TlsFuture; use crate::proto::TlsFuture;
use crate::{CancelData, TlsMode}; use crate::TlsMode;
#[derive(StateMachineFuture)] #[derive(StateMachineFuture)]
pub enum CancelQueryRaw<S, T> pub enum CancelQueryRaw<S, T>
@ -17,7 +17,8 @@ where
#[state_machine_future(start, transitions(SendingCancel))] #[state_machine_future(start, transitions(SendingCancel))]
Start { Start {
future: TlsFuture<S, T>, future: TlsFuture<S, T>,
cancel_data: CancelData, process_id: i32,
secret_key: i32,
}, },
#[state_machine_future(transitions(FlushingCancel))] #[state_machine_future(transitions(FlushingCancel))]
SendingCancel { SendingCancel {
@ -40,11 +41,7 @@ where
let (stream, _) = try_ready!(state.future.poll()); let (stream, _) = try_ready!(state.future.poll());
let mut buf = vec![]; let mut buf = vec![];
frontend::cancel_request( frontend::cancel_request(state.process_id, state.secret_key, &mut buf);
state.cancel_data.process_id,
state.cancel_data.secret_key,
&mut buf,
);
transition!(SendingCancel { transition!(SendingCancel {
future: io::write_all(stream, buf), future: io::write_all(stream, buf),
@ -74,7 +71,12 @@ where
S: AsyncRead + AsyncWrite, S: AsyncRead + AsyncWrite,
T: TlsMode<S>, T: TlsMode<S>,
{ {
pub fn new(stream: S, tls_mode: T, cancel_data: CancelData) -> CancelQueryRawFuture<S, T> { pub fn new(
CancelQueryRaw::start(TlsFuture::new(stream, tls_mode), cancel_data) stream: S,
tls_mode: T,
process_id: i32,
secret_key: i32,
) -> CancelQueryRawFuture<S, T> {
CancelQueryRaw::start(TlsFuture::new(stream, tls_mode), process_id, secret_key)
} }
} }

View File

@ -8,6 +8,7 @@ use postgres_protocol::message::frontend;
use std::collections::HashMap; use std::collections::HashMap;
use std::error::Error as StdError; use std::error::Error as StdError;
use std::sync::{Arc, Weak}; use std::sync::{Arc, Weak};
use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto::bind::BindFuture; use crate::proto::bind::BindFuture;
use crate::proto::connection::{Request, RequestMessages}; use crate::proto::connection::{Request, RequestMessages};
@ -20,8 +21,13 @@ use crate::proto::prepare::PrepareFuture;
use crate::proto::query::QueryStream; use crate::proto::query::QueryStream;
use crate::proto::simple_query::SimpleQueryStream; use crate::proto::simple_query::SimpleQueryStream;
use crate::proto::statement::Statement; use crate::proto::statement::Statement;
#[cfg(feature = "runtime")]
use crate::proto::CancelQueryFuture;
use crate::proto::CancelQueryRawFuture;
use crate::types::{IsNull, Oid, ToSql, Type}; use crate::types::{IsNull, Oid, ToSql, Type};
use crate::Error; use crate::{Config, Error, TlsMode};
#[cfg(feature = "runtime")]
use crate::{MakeTlsMode, Socket};
pub struct PendingRequest(Result<(RequestMessages, IdleGuard), Error>); pub struct PendingRequest(Result<(RequestMessages, IdleGuard), Error>);
@ -44,13 +50,25 @@ struct Inner {
state: Mutex<State>, state: Mutex<State>,
idle: IdleState, idle: IdleState,
sender: mpsc::UnboundedSender<Request>, sender: mpsc::UnboundedSender<Request>,
process_id: i32,
secret_key: i32,
#[cfg_attr(not(feature = "runtime"), allow(dead_code))]
config: Config,
#[cfg_attr(not(feature = "runtime"), allow(dead_code))]
idx: Option<usize>,
} }
#[derive(Clone)] #[derive(Clone)]
pub struct Client(Arc<Inner>); pub struct Client(Arc<Inner>);
impl Client { impl Client {
pub fn new(sender: mpsc::UnboundedSender<Request>) -> Client { pub fn new(
sender: mpsc::UnboundedSender<Request>,
process_id: i32,
secret_key: i32,
config: Config,
idx: Option<usize>,
) -> Client {
Client(Arc::new(Inner { Client(Arc::new(Inner {
state: Mutex::new(State { state: Mutex::new(State {
types: HashMap::new(), types: HashMap::new(),
@ -60,6 +78,10 @@ impl Client {
}), }),
idle: IdleState::new(), idle: IdleState::new(),
sender, sender,
process_id,
secret_key,
config,
idx,
})) }))
} }
@ -222,6 +244,28 @@ impl Client {
self.close(b'P', name) self.close(b'P', name)
} }
#[cfg(feature = "runtime")]
pub fn cancel_query<T>(&self, make_tls_mode: T) -> CancelQueryFuture<T>
where
T: MakeTlsMode<Socket>,
{
CancelQueryFuture::new(
make_tls_mode,
self.0.idx,
self.0.config.clone(),
self.0.process_id,
self.0.secret_key,
)
}
pub fn cancel_query_raw<S, T>(&self, stream: S, tls_mode: T) -> CancelQueryRawFuture<S, T>
where
S: AsyncRead + AsyncWrite,
T: TlsMode<S>,
{
CancelQueryRawFuture::new(stream, tls_mode, self.0.process_id, self.0.secret_key)
}
fn close(&self, ty: u8, name: &str) { fn close(&self, ty: u8, name: &str) {
let mut buf = vec![]; let mut buf = vec![];
frontend::close(ty, name, &mut buf).expect("statement name not valid"); frontend::close(ty, name, &mut buf).expect("statement name not valid");

View File

@ -65,7 +65,7 @@ where
transition!(Handshaking { transition!(Handshaking {
target_session_attrs: state.config.0.target_session_attrs, target_session_attrs: state.config.0.target_session_attrs,
future: HandshakeFuture::new(socket, state.tls_mode, state.config), future: HandshakeFuture::new(socket, state.tls_mode, state.config, Some(state.idx)),
}) })
} }

View File

@ -11,7 +11,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto::codec::PostgresCodec; use crate::proto::codec::PostgresCodec;
use crate::proto::copy_in::CopyInReceiver; use crate::proto::copy_in::CopyInReceiver;
use crate::proto::idle::IdleGuard; use crate::proto::idle::IdleGuard;
use crate::{AsyncMessage, CancelData, Notification}; use crate::{AsyncMessage, Notification};
use crate::{DbError, Error}; use crate::{DbError, Error};
pub enum RequestMessages { pub enum RequestMessages {
@ -42,7 +42,6 @@ enum State {
pub struct Connection<S> { pub struct Connection<S> {
stream: Framed<S, PostgresCodec>, stream: Framed<S, PostgresCodec>,
cancel_data: CancelData,
parameters: HashMap<String, String>, parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>, receiver: mpsc::UnboundedReceiver<Request>,
pending_request: Option<RequestMessages>, pending_request: Option<RequestMessages>,
@ -57,13 +56,11 @@ where
{ {
pub fn new( pub fn new(
stream: Framed<S, PostgresCodec>, stream: Framed<S, PostgresCodec>,
cancel_data: CancelData,
parameters: HashMap<String, String>, parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>, receiver: mpsc::UnboundedReceiver<Request>,
) -> Connection<S> { ) -> Connection<S> {
Connection { Connection {
stream, stream,
cancel_data,
parameters, parameters,
receiver, receiver,
pending_request: None, pending_request: None,
@ -73,10 +70,6 @@ where
} }
} }
pub fn cancel_data(&self) -> CancelData {
self.cancel_data
}
pub fn parameter(&self, name: &str) -> Option<&str> { pub fn parameter(&self, name: &str) -> Option<&str> {
self.parameters.get(name).map(|s| &**s) self.parameters.get(name).map(|s| &**s)
} }

View File

@ -8,12 +8,11 @@ use postgres_protocol::message::backend::Message;
use postgres_protocol::message::frontend; use postgres_protocol::message::frontend;
use state_machine_future::{transition, RentToOwn, StateMachineFuture}; use state_machine_future::{transition, RentToOwn, StateMachineFuture};
use std::collections::HashMap; use std::collections::HashMap;
use std::io;
use tokio_codec::Framed; use tokio_codec::Framed;
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use crate::proto::{Client, Connection, PostgresCodec, TlsFuture}; use crate::proto::{Client, Connection, PostgresCodec, TlsFuture};
use crate::{CancelData, ChannelBinding, Config, Error, TlsMode}; use crate::{ChannelBinding, Config, Error, TlsMode};
#[derive(StateMachineFuture)] #[derive(StateMachineFuture)]
pub enum Handshake<S, T> pub enum Handshake<S, T>
@ -25,42 +24,56 @@ where
Start { Start {
future: TlsFuture<S, T>, future: TlsFuture<S, T>,
config: Config, config: Config,
idx: Option<usize>,
}, },
#[state_machine_future(transitions(ReadingAuth))] #[state_machine_future(transitions(ReadingAuth))]
SendingStartup { SendingStartup {
future: sink::Send<Framed<T::Stream, PostgresCodec>>, future: sink::Send<Framed<T::Stream, PostgresCodec>>,
config: Config, config: Config,
idx: Option<usize>,
channel_binding: ChannelBinding, channel_binding: ChannelBinding,
}, },
#[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))] #[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))]
ReadingAuth { ReadingAuth {
stream: Framed<T::Stream, PostgresCodec>, stream: Framed<T::Stream, PostgresCodec>,
config: Config, config: Config,
idx: Option<usize>,
channel_binding: ChannelBinding, channel_binding: ChannelBinding,
}, },
#[state_machine_future(transitions(ReadingAuthCompletion))] #[state_machine_future(transitions(ReadingAuthCompletion))]
SendingPassword { SendingPassword {
future: sink::Send<Framed<T::Stream, PostgresCodec>>, future: sink::Send<Framed<T::Stream, PostgresCodec>>,
config: Config,
idx: Option<usize>,
}, },
#[state_machine_future(transitions(ReadingSasl))] #[state_machine_future(transitions(ReadingSasl))]
SendingSasl { SendingSasl {
future: sink::Send<Framed<T::Stream, PostgresCodec>>, future: sink::Send<Framed<T::Stream, PostgresCodec>>,
scram: ScramSha256, scram: ScramSha256,
config: Config,
idx: Option<usize>,
}, },
#[state_machine_future(transitions(SendingSasl, ReadingAuthCompletion))] #[state_machine_future(transitions(SendingSasl, ReadingAuthCompletion))]
ReadingSasl { ReadingSasl {
stream: Framed<T::Stream, PostgresCodec>, stream: Framed<T::Stream, PostgresCodec>,
scram: ScramSha256, scram: ScramSha256,
config: Config,
idx: Option<usize>,
}, },
#[state_machine_future(transitions(ReadingInfo))] #[state_machine_future(transitions(ReadingInfo))]
ReadingAuthCompletion { ReadingAuthCompletion {
stream: Framed<T::Stream, PostgresCodec>, stream: Framed<T::Stream, PostgresCodec>,
config: Config,
idx: Option<usize>,
}, },
#[state_machine_future(transitions(Finished))] #[state_machine_future(transitions(Finished))]
ReadingInfo { ReadingInfo {
stream: Framed<T::Stream, PostgresCodec>, stream: Framed<T::Stream, PostgresCodec>,
cancel_data: Option<CancelData>, process_id: i32,
secret_key: i32,
parameters: HashMap<String, String>, parameters: HashMap<String, String>,
config: Config,
idx: Option<usize>,
}, },
#[state_machine_future(ready)] #[state_machine_future(ready)]
Finished((Client, Connection<T::Stream>)), Finished((Client, Connection<T::Stream>)),
@ -99,6 +112,7 @@ where
transition!(SendingStartup { transition!(SendingStartup {
future: stream.send(buf), future: stream.send(buf),
config: state.config, config: state.config,
idx: state.idx,
channel_binding, channel_binding,
}) })
} }
@ -111,6 +125,7 @@ where
transition!(ReadingAuth { transition!(ReadingAuth {
stream, stream,
config: state.config, config: state.config,
idx: state.idx,
channel_binding: state.channel_binding, channel_binding: state.channel_binding,
}) })
} }
@ -124,8 +139,11 @@ where
match message { match message {
Some(Message::AuthenticationOk) => transition!(ReadingInfo { Some(Message::AuthenticationOk) => transition!(ReadingInfo {
stream: state.stream, stream: state.stream,
cancel_data: None, process_id: 0,
secret_key: 0,
parameters: HashMap::new(), parameters: HashMap::new(),
config: state.config,
idx: state.idx,
}), }),
Some(Message::AuthenticationCleartextPassword) => { Some(Message::AuthenticationCleartextPassword) => {
let pass = state let pass = state
@ -137,7 +155,9 @@ where
let mut buf = vec![]; let mut buf = vec![];
frontend::password_message(pass, &mut buf).map_err(Error::encode)?; frontend::password_message(pass, &mut buf).map_err(Error::encode)?;
transition!(SendingPassword { transition!(SendingPassword {
future: state.stream.send(buf) future: state.stream.send(buf),
config: state.config,
idx: state.idx,
}) })
} }
Some(Message::AuthenticationMd5Password(body)) => { Some(Message::AuthenticationMd5Password(body)) => {
@ -157,7 +177,9 @@ where
let mut buf = vec![]; let mut buf = vec![];
frontend::password_message(output.as_bytes(), &mut buf).map_err(Error::encode)?; frontend::password_message(output.as_bytes(), &mut buf).map_err(Error::encode)?;
transition!(SendingPassword { transition!(SendingPassword {
future: state.stream.send(buf) future: state.stream.send(buf),
config: state.config,
idx: state.idx,
}) })
} }
Some(Message::AuthenticationSasl(body)) => { Some(Message::AuthenticationSasl(body)) => {
@ -214,6 +236,8 @@ where
transition!(SendingSasl { transition!(SendingSasl {
future: state.stream.send(buf), future: state.stream.send(buf),
scram, scram,
config: state.config,
idx: state.idx,
}) })
} }
Some(Message::AuthenticationKerberosV5) Some(Message::AuthenticationKerberosV5)
@ -232,7 +256,12 @@ where
state: &'a mut RentToOwn<'a, SendingPassword<S, T>>, state: &'a mut RentToOwn<'a, SendingPassword<S, T>>,
) -> Poll<AfterSendingPassword<S, T>, Error> { ) -> Poll<AfterSendingPassword<S, T>, Error> {
let stream = try_ready!(state.future.poll().map_err(Error::io)); let stream = try_ready!(state.future.poll().map_err(Error::io));
transition!(ReadingAuthCompletion { stream }) let state = state.take();
transition!(ReadingAuthCompletion {
stream,
config: state.config,
idx: state.idx,
})
} }
fn poll_sending_sasl<'a>( fn poll_sending_sasl<'a>(
@ -243,6 +272,8 @@ where
transition!(ReadingSasl { transition!(ReadingSasl {
stream, stream,
scram: state.scram, scram: state.scram,
config: state.config,
idx: state.idx,
}) })
} }
@ -263,6 +294,8 @@ where
transition!(SendingSasl { transition!(SendingSasl {
future: state.stream.send(buf), future: state.stream.send(buf),
scram: state.scram, scram: state.scram,
config: state.config,
idx: state.idx,
}) })
} }
Some(Message::AuthenticationSaslFinal(body)) => { Some(Message::AuthenticationSaslFinal(body)) => {
@ -271,7 +304,9 @@ where
.finish(body.data()) .finish(body.data())
.map_err(|e| Error::authentication(Box::new(e)))?; .map_err(|e| Error::authentication(Box::new(e)))?;
transition!(ReadingAuthCompletion { transition!(ReadingAuthCompletion {
stream: state.stream stream: state.stream,
config: state.config,
idx: state.idx,
}) })
} }
Some(Message::ErrorResponse(body)) => Err(Error::db(body)), Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
@ -289,8 +324,11 @@ where
match message { match message {
Some(Message::AuthenticationOk) => transition!(ReadingInfo { Some(Message::AuthenticationOk) => transition!(ReadingInfo {
stream: state.stream, stream: state.stream,
cancel_data: None, process_id: 0,
parameters: HashMap::new() secret_key: 0,
parameters: HashMap::new(),
config: state.config,
idx: state.idx,
}), }),
Some(Message::ErrorResponse(body)) => Err(Error::db(body)), Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
Some(_) => Err(Error::unexpected_message()), Some(_) => Err(Error::unexpected_message()),
@ -305,10 +343,8 @@ where
let message = try_ready!(state.stream.poll().map_err(Error::io)); let message = try_ready!(state.stream.poll().map_err(Error::io));
match message { match message {
Some(Message::BackendKeyData(body)) => { Some(Message::BackendKeyData(body)) => {
state.cancel_data = Some(CancelData { state.process_id = body.process_id();
process_id: body.process_id(), state.secret_key = body.secret_key();
secret_key: body.secret_key(),
});
} }
Some(Message::ParameterStatus(body)) => { Some(Message::ParameterStatus(body)) => {
state.parameters.insert( state.parameters.insert(
@ -318,16 +354,15 @@ where
} }
Some(Message::ReadyForQuery(_)) => { Some(Message::ReadyForQuery(_)) => {
let state = state.take(); let state = state.take();
let cancel_data = state.cancel_data.ok_or_else(|| {
Error::parse(io::Error::new(
io::ErrorKind::InvalidData,
"BackendKeyData message missing",
))
})?;
let (sender, receiver) = mpsc::unbounded(); let (sender, receiver) = mpsc::unbounded();
let client = Client::new(sender); let client = Client::new(
let connection = sender,
Connection::new(state.stream, cancel_data, state.parameters, receiver); state.process_id,
state.secret_key,
state.config,
state.idx,
);
let connection = Connection::new(state.stream, state.parameters, receiver);
transition!(Finished((client, connection))) transition!(Finished((client, connection)))
} }
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
@ -344,7 +379,12 @@ where
S: AsyncRead + AsyncWrite, S: AsyncRead + AsyncWrite,
T: TlsMode<S>, T: TlsMode<S>,
{ {
pub fn new(stream: S, tls_mode: T, config: Config) -> HandshakeFuture<S, T> { pub fn new(
Handshake::start(TlsFuture::new(stream, tls_mode), config) stream: S,
tls_mode: T,
config: Config,
idx: Option<usize>,
) -> HandshakeFuture<S, T> {
Handshake::start(TlsFuture::new(stream, tls_mode), config, idx)
} }
} }

View File

@ -19,6 +19,8 @@ macro_rules! try_ready_closed {
} }
mod bind; mod bind;
#[cfg(feature = "runtime")]
mod cancel_query;
mod cancel_query_raw; mod cancel_query_raw;
mod client; mod client;
mod codec; mod codec;
@ -46,6 +48,8 @@ mod typeinfo_composite;
mod typeinfo_enum; mod typeinfo_enum;
pub use crate::proto::bind::BindFuture; pub use crate::proto::bind::BindFuture;
#[cfg(feature = "runtime")]
pub use crate::proto::cancel_query::CancelQueryFuture;
pub use crate::proto::cancel_query_raw::CancelQueryRawFuture; pub use crate::proto::cancel_query_raw::CancelQueryRawFuture;
pub use crate::proto::client::Client; pub use crate::proto::client::Client;
pub use crate::proto::codec::PostgresCodec; pub use crate::proto::codec::PostgresCodec;

View File

@ -222,12 +222,11 @@ fn query_portal() {
} }
#[test] #[test]
fn cancel_query() { fn cancel_query_raw() {
let _ = env_logger::try_init(); let _ = env_logger::try_init();
let mut runtime = Runtime::new().unwrap(); let mut runtime = Runtime::new().unwrap();
let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap(); let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
let cancel_data = connection.cancel_data();
let connection = connection.map_err(|e| panic!("{}", e)); let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap(); runtime.handle().spawn(connection).unwrap();
@ -245,7 +244,7 @@ fn cancel_query() {
}) })
.then(|r| { .then(|r| {
let s = r.unwrap(); let s = r.unwrap();
tokio_postgres::Config::new().cancel_query_raw(s, NoTls, cancel_data) client.cancel_query_raw(s, NoTls)
}) })
.then(|r| { .then(|r| {
r.unwrap(); r.unwrap();

View File

@ -1,6 +1,8 @@
use futures::Future; use futures::Future;
use std::time::{Duration, Instant};
use tokio::runtime::current_thread::Runtime; use tokio::runtime::current_thread::Runtime;
use tokio_postgres::NoTls; use tokio::timer::Delay;
use tokio_postgres::{NoTls, SqlState};
fn smoke_test(s: &str) { fn smoke_test(s: &str) {
let mut runtime = Runtime::new().unwrap(); let mut runtime = Runtime::new().unwrap();
@ -67,3 +69,32 @@ fn target_session_attrs_err() {
); );
runtime.block_on(f).err().unwrap(); runtime.block_on(f).err().unwrap();
} }
#[test]
fn cancel_query() {
let mut runtime = Runtime::new().unwrap();
let connect = tokio_postgres::connect("host=localhost port=5433 user=postgres", NoTls);
let (mut client, connection) = runtime.block_on(connect).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
let sleep = client
.batch_execute("SELECT pg_sleep(100)")
.then(|r| match r {
Ok(_) => panic!("unexpected success"),
Err(ref e) if e.code() == Some(&SqlState::QUERY_CANCELED) => Ok::<(), ()>(()),
Err(e) => panic!("unexpected error {}", e),
});
let cancel = Delay::new(Instant::now() + Duration::from_millis(100))
.then(|r| {
r.unwrap();
client.cancel_query(NoTls)
})
.then(|r| {
r.unwrap();
Ok::<(), ()>(())
});
let ((), ()) = runtime.block_on(sleep.join(cancel)).unwrap();
}