Overhaul query cancellation
Multi-host support means we can't simply take the old approach - we need to know which of the hosts we actually connected to. It's also nice to move this from the connection to the client since that's what you'd normally have access to.
This commit is contained in:
parent
a6535b4310
commit
1f6d9ddc06
@ -17,10 +17,10 @@ use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::proto::ConnectFuture;
|
||||
use crate::proto::{CancelQueryRawFuture, HandshakeFuture};
|
||||
use crate::{CancelData, CancelQueryRaw, Error, Handshake, TlsMode};
|
||||
use crate::proto::HandshakeFuture;
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::{Connect, MakeTlsMode, Socket};
|
||||
use crate::{Error, Handshake, TlsMode};
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||
@ -267,7 +267,7 @@ impl Config {
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
{
|
||||
Handshake(HandshakeFuture::new(stream, tls_mode, self.clone()))
|
||||
Handshake(HandshakeFuture::new(stream, tls_mode, self.clone(), None))
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
@ -277,19 +277,6 @@ impl Config {
|
||||
{
|
||||
Connect(ConnectFuture::new(make_tls_mode, Ok(self.clone())))
|
||||
}
|
||||
|
||||
pub fn cancel_query_raw<S, T>(
|
||||
&self,
|
||||
stream: S,
|
||||
tls_mode: T,
|
||||
cancel_data: CancelData,
|
||||
) -> CancelQueryRaw<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
{
|
||||
CancelQueryRaw(CancelQueryRawFuture::new(stream, tls_mode, cancel_data))
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Config {
|
||||
|
@ -206,6 +206,22 @@ impl Client {
|
||||
BatchExecute(self.0.batch_execute(query))
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
pub fn cancel_query<T>(&mut self, make_tls_mode: T) -> CancelQuery<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
CancelQuery(self.0.cancel_query(make_tls_mode))
|
||||
}
|
||||
|
||||
pub fn cancel_query_raw<S, T>(&mut self, stream: S, tls_mode: T) -> CancelQueryRaw<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
{
|
||||
CancelQueryRaw(self.0.cancel_query_raw(stream, tls_mode))
|
||||
}
|
||||
|
||||
pub fn is_closed(&self) -> bool {
|
||||
self.0.is_closed()
|
||||
}
|
||||
@ -222,10 +238,6 @@ impl<S> Connection<S>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
{
|
||||
pub fn cancel_data(&self) -> CancelData {
|
||||
self.0.cancel_data()
|
||||
}
|
||||
|
||||
pub fn parameter(&self, name: &str) -> Option<&str> {
|
||||
self.0.parameter(name)
|
||||
}
|
||||
@ -274,6 +286,25 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct CancelQuery<T>(proto::CancelQueryFuture<T>)
|
||||
where
|
||||
T: MakeTlsMode<Socket>;
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
impl<T> Future for CancelQuery<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
type Item = ();
|
||||
type Error = Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<(), Error> {
|
||||
self.0.poll()
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use = "futures do nothing unless polled"]
|
||||
pub struct Handshake<S, T>(proto::HandshakeFuture<S, T>)
|
||||
where
|
||||
@ -478,15 +509,6 @@ impl Future for BatchExecute {
|
||||
}
|
||||
}
|
||||
|
||||
/// Contains information necessary to cancel queries for a session.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct CancelData {
|
||||
/// The process ID of the session.
|
||||
pub process_id: i32,
|
||||
/// The secret key for the session.
|
||||
pub secret_key: i32,
|
||||
}
|
||||
|
||||
/// An asynchronous notification.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Notification {
|
||||
|
105
tokio-postgres/src/proto/cancel_query.rs
Normal file
105
tokio-postgres/src/proto/cancel_query.rs
Normal file
@ -0,0 +1,105 @@
|
||||
use futures::{try_ready, Future, Poll};
|
||||
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
|
||||
use std::io;
|
||||
|
||||
use crate::proto::{CancelQueryRawFuture, ConnectSocketFuture};
|
||||
use crate::{Config, Error, Host, MakeTlsMode, Socket};
|
||||
|
||||
#[derive(StateMachineFuture)]
|
||||
pub enum CancelQuery<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
#[state_machine_future(start, transitions(ConnectingSocket))]
|
||||
Start {
|
||||
make_tls_mode: T,
|
||||
idx: Option<usize>,
|
||||
config: Config,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
},
|
||||
#[state_machine_future(transitions(Canceling))]
|
||||
ConnectingSocket {
|
||||
future: ConnectSocketFuture,
|
||||
tls_mode: T::TlsMode,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
},
|
||||
#[state_machine_future(transitions(Finished))]
|
||||
Canceling {
|
||||
future: CancelQueryRawFuture<Socket, T::TlsMode>,
|
||||
},
|
||||
#[state_machine_future(ready)]
|
||||
Finished(()),
|
||||
#[state_machine_future(error)]
|
||||
Failed(Error),
|
||||
}
|
||||
|
||||
impl<T> PollCancelQuery<T> for CancelQuery<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start<T>>) -> Poll<AfterStart<T>, Error> {
|
||||
let mut state = state.take();
|
||||
|
||||
let idx = state.idx.ok_or_else(|| {
|
||||
Error::connect(io::Error::new(io::ErrorKind::InvalidInput, "unknown host"))
|
||||
})?;
|
||||
|
||||
let hostname = match &state.config.0.host[idx] {
|
||||
Host::Tcp(host) => &**host,
|
||||
// postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
|
||||
#[cfg(unix)]
|
||||
Host::Unix(_) => "",
|
||||
};
|
||||
let tls_mode = state
|
||||
.make_tls_mode
|
||||
.make_tls_mode(hostname)
|
||||
.map_err(|e| Error::tls(e.into()))?;
|
||||
|
||||
transition!(ConnectingSocket {
|
||||
future: ConnectSocketFuture::new(state.config, idx),
|
||||
tls_mode,
|
||||
process_id: state.process_id,
|
||||
secret_key: state.secret_key,
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_connecting_socket<'a>(
|
||||
state: &'a mut RentToOwn<'a, ConnectingSocket<T>>,
|
||||
) -> Poll<AfterConnectingSocket<T>, Error> {
|
||||
let socket = try_ready!(state.future.poll());
|
||||
let state = state.take();
|
||||
|
||||
transition!(Canceling {
|
||||
future: CancelQueryRawFuture::new(
|
||||
socket,
|
||||
state.tls_mode,
|
||||
state.process_id,
|
||||
state.secret_key
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_canceling<'a>(
|
||||
state: &'a mut RentToOwn<'a, Canceling<T>>,
|
||||
) -> Poll<AfterCanceling, Error> {
|
||||
try_ready!(state.future.poll());
|
||||
transition!(Finished(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> CancelQueryFuture<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
pub fn new(
|
||||
make_tls_mode: T,
|
||||
idx: Option<usize>,
|
||||
config: Config,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
) -> CancelQueryFuture<T> {
|
||||
CancelQuery::start(make_tls_mode, idx, config, process_id, secret_key)
|
||||
}
|
||||
}
|
@ -6,7 +6,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::proto::TlsFuture;
|
||||
use crate::{CancelData, TlsMode};
|
||||
use crate::TlsMode;
|
||||
|
||||
#[derive(StateMachineFuture)]
|
||||
pub enum CancelQueryRaw<S, T>
|
||||
@ -17,7 +17,8 @@ where
|
||||
#[state_machine_future(start, transitions(SendingCancel))]
|
||||
Start {
|
||||
future: TlsFuture<S, T>,
|
||||
cancel_data: CancelData,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
},
|
||||
#[state_machine_future(transitions(FlushingCancel))]
|
||||
SendingCancel {
|
||||
@ -40,11 +41,7 @@ where
|
||||
let (stream, _) = try_ready!(state.future.poll());
|
||||
|
||||
let mut buf = vec![];
|
||||
frontend::cancel_request(
|
||||
state.cancel_data.process_id,
|
||||
state.cancel_data.secret_key,
|
||||
&mut buf,
|
||||
);
|
||||
frontend::cancel_request(state.process_id, state.secret_key, &mut buf);
|
||||
|
||||
transition!(SendingCancel {
|
||||
future: io::write_all(stream, buf),
|
||||
@ -74,7 +71,12 @@ where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
{
|
||||
pub fn new(stream: S, tls_mode: T, cancel_data: CancelData) -> CancelQueryRawFuture<S, T> {
|
||||
CancelQueryRaw::start(TlsFuture::new(stream, tls_mode), cancel_data)
|
||||
pub fn new(
|
||||
stream: S,
|
||||
tls_mode: T,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
) -> CancelQueryRawFuture<S, T> {
|
||||
CancelQueryRaw::start(TlsFuture::new(stream, tls_mode), process_id, secret_key)
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ use postgres_protocol::message::frontend;
|
||||
use std::collections::HashMap;
|
||||
use std::error::Error as StdError;
|
||||
use std::sync::{Arc, Weak};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::proto::bind::BindFuture;
|
||||
use crate::proto::connection::{Request, RequestMessages};
|
||||
@ -20,8 +21,13 @@ use crate::proto::prepare::PrepareFuture;
|
||||
use crate::proto::query::QueryStream;
|
||||
use crate::proto::simple_query::SimpleQueryStream;
|
||||
use crate::proto::statement::Statement;
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::proto::CancelQueryFuture;
|
||||
use crate::proto::CancelQueryRawFuture;
|
||||
use crate::types::{IsNull, Oid, ToSql, Type};
|
||||
use crate::Error;
|
||||
use crate::{Config, Error, TlsMode};
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::{MakeTlsMode, Socket};
|
||||
|
||||
pub struct PendingRequest(Result<(RequestMessages, IdleGuard), Error>);
|
||||
|
||||
@ -44,13 +50,25 @@ struct Inner {
|
||||
state: Mutex<State>,
|
||||
idle: IdleState,
|
||||
sender: mpsc::UnboundedSender<Request>,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
#[cfg_attr(not(feature = "runtime"), allow(dead_code))]
|
||||
config: Config,
|
||||
#[cfg_attr(not(feature = "runtime"), allow(dead_code))]
|
||||
idx: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Client(Arc<Inner>);
|
||||
|
||||
impl Client {
|
||||
pub fn new(sender: mpsc::UnboundedSender<Request>) -> Client {
|
||||
pub fn new(
|
||||
sender: mpsc::UnboundedSender<Request>,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
config: Config,
|
||||
idx: Option<usize>,
|
||||
) -> Client {
|
||||
Client(Arc::new(Inner {
|
||||
state: Mutex::new(State {
|
||||
types: HashMap::new(),
|
||||
@ -60,6 +78,10 @@ impl Client {
|
||||
}),
|
||||
idle: IdleState::new(),
|
||||
sender,
|
||||
process_id,
|
||||
secret_key,
|
||||
config,
|
||||
idx,
|
||||
}))
|
||||
}
|
||||
|
||||
@ -222,6 +244,28 @@ impl Client {
|
||||
self.close(b'P', name)
|
||||
}
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
pub fn cancel_query<T>(&self, make_tls_mode: T) -> CancelQueryFuture<T>
|
||||
where
|
||||
T: MakeTlsMode<Socket>,
|
||||
{
|
||||
CancelQueryFuture::new(
|
||||
make_tls_mode,
|
||||
self.0.idx,
|
||||
self.0.config.clone(),
|
||||
self.0.process_id,
|
||||
self.0.secret_key,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn cancel_query_raw<S, T>(&self, stream: S, tls_mode: T) -> CancelQueryRawFuture<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
{
|
||||
CancelQueryRawFuture::new(stream, tls_mode, self.0.process_id, self.0.secret_key)
|
||||
}
|
||||
|
||||
fn close(&self, ty: u8, name: &str) {
|
||||
let mut buf = vec![];
|
||||
frontend::close(ty, name, &mut buf).expect("statement name not valid");
|
||||
|
@ -65,7 +65,7 @@ where
|
||||
|
||||
transition!(Handshaking {
|
||||
target_session_attrs: state.config.0.target_session_attrs,
|
||||
future: HandshakeFuture::new(socket, state.tls_mode, state.config),
|
||||
future: HandshakeFuture::new(socket, state.tls_mode, state.config, Some(state.idx)),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -11,7 +11,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use crate::proto::codec::PostgresCodec;
|
||||
use crate::proto::copy_in::CopyInReceiver;
|
||||
use crate::proto::idle::IdleGuard;
|
||||
use crate::{AsyncMessage, CancelData, Notification};
|
||||
use crate::{AsyncMessage, Notification};
|
||||
use crate::{DbError, Error};
|
||||
|
||||
pub enum RequestMessages {
|
||||
@ -42,7 +42,6 @@ enum State {
|
||||
|
||||
pub struct Connection<S> {
|
||||
stream: Framed<S, PostgresCodec>,
|
||||
cancel_data: CancelData,
|
||||
parameters: HashMap<String, String>,
|
||||
receiver: mpsc::UnboundedReceiver<Request>,
|
||||
pending_request: Option<RequestMessages>,
|
||||
@ -57,13 +56,11 @@ where
|
||||
{
|
||||
pub fn new(
|
||||
stream: Framed<S, PostgresCodec>,
|
||||
cancel_data: CancelData,
|
||||
parameters: HashMap<String, String>,
|
||||
receiver: mpsc::UnboundedReceiver<Request>,
|
||||
) -> Connection<S> {
|
||||
Connection {
|
||||
stream,
|
||||
cancel_data,
|
||||
parameters,
|
||||
receiver,
|
||||
pending_request: None,
|
||||
@ -73,10 +70,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cancel_data(&self) -> CancelData {
|
||||
self.cancel_data
|
||||
}
|
||||
|
||||
pub fn parameter(&self, name: &str) -> Option<&str> {
|
||||
self.parameters.get(name).map(|s| &**s)
|
||||
}
|
||||
|
@ -8,12 +8,11 @@ use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
use state_machine_future::{transition, RentToOwn, StateMachineFuture};
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use tokio_codec::Framed;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use crate::proto::{Client, Connection, PostgresCodec, TlsFuture};
|
||||
use crate::{CancelData, ChannelBinding, Config, Error, TlsMode};
|
||||
use crate::{ChannelBinding, Config, Error, TlsMode};
|
||||
|
||||
#[derive(StateMachineFuture)]
|
||||
pub enum Handshake<S, T>
|
||||
@ -25,42 +24,56 @@ where
|
||||
Start {
|
||||
future: TlsFuture<S, T>,
|
||||
config: Config,
|
||||
idx: Option<usize>,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingAuth))]
|
||||
SendingStartup {
|
||||
future: sink::Send<Framed<T::Stream, PostgresCodec>>,
|
||||
config: Config,
|
||||
idx: Option<usize>,
|
||||
channel_binding: ChannelBinding,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))]
|
||||
ReadingAuth {
|
||||
stream: Framed<T::Stream, PostgresCodec>,
|
||||
config: Config,
|
||||
idx: Option<usize>,
|
||||
channel_binding: ChannelBinding,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingAuthCompletion))]
|
||||
SendingPassword {
|
||||
future: sink::Send<Framed<T::Stream, PostgresCodec>>,
|
||||
config: Config,
|
||||
idx: Option<usize>,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingSasl))]
|
||||
SendingSasl {
|
||||
future: sink::Send<Framed<T::Stream, PostgresCodec>>,
|
||||
scram: ScramSha256,
|
||||
config: Config,
|
||||
idx: Option<usize>,
|
||||
},
|
||||
#[state_machine_future(transitions(SendingSasl, ReadingAuthCompletion))]
|
||||
ReadingSasl {
|
||||
stream: Framed<T::Stream, PostgresCodec>,
|
||||
scram: ScramSha256,
|
||||
config: Config,
|
||||
idx: Option<usize>,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingInfo))]
|
||||
ReadingAuthCompletion {
|
||||
stream: Framed<T::Stream, PostgresCodec>,
|
||||
config: Config,
|
||||
idx: Option<usize>,
|
||||
},
|
||||
#[state_machine_future(transitions(Finished))]
|
||||
ReadingInfo {
|
||||
stream: Framed<T::Stream, PostgresCodec>,
|
||||
cancel_data: Option<CancelData>,
|
||||
process_id: i32,
|
||||
secret_key: i32,
|
||||
parameters: HashMap<String, String>,
|
||||
config: Config,
|
||||
idx: Option<usize>,
|
||||
},
|
||||
#[state_machine_future(ready)]
|
||||
Finished((Client, Connection<T::Stream>)),
|
||||
@ -99,6 +112,7 @@ where
|
||||
transition!(SendingStartup {
|
||||
future: stream.send(buf),
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
channel_binding,
|
||||
})
|
||||
}
|
||||
@ -111,6 +125,7 @@ where
|
||||
transition!(ReadingAuth {
|
||||
stream,
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
channel_binding: state.channel_binding,
|
||||
})
|
||||
}
|
||||
@ -124,8 +139,11 @@ where
|
||||
match message {
|
||||
Some(Message::AuthenticationOk) => transition!(ReadingInfo {
|
||||
stream: state.stream,
|
||||
cancel_data: None,
|
||||
process_id: 0,
|
||||
secret_key: 0,
|
||||
parameters: HashMap::new(),
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
}),
|
||||
Some(Message::AuthenticationCleartextPassword) => {
|
||||
let pass = state
|
||||
@ -137,7 +155,9 @@ where
|
||||
let mut buf = vec![];
|
||||
frontend::password_message(pass, &mut buf).map_err(Error::encode)?;
|
||||
transition!(SendingPassword {
|
||||
future: state.stream.send(buf)
|
||||
future: state.stream.send(buf),
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
})
|
||||
}
|
||||
Some(Message::AuthenticationMd5Password(body)) => {
|
||||
@ -157,7 +177,9 @@ where
|
||||
let mut buf = vec![];
|
||||
frontend::password_message(output.as_bytes(), &mut buf).map_err(Error::encode)?;
|
||||
transition!(SendingPassword {
|
||||
future: state.stream.send(buf)
|
||||
future: state.stream.send(buf),
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
})
|
||||
}
|
||||
Some(Message::AuthenticationSasl(body)) => {
|
||||
@ -214,6 +236,8 @@ where
|
||||
transition!(SendingSasl {
|
||||
future: state.stream.send(buf),
|
||||
scram,
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
})
|
||||
}
|
||||
Some(Message::AuthenticationKerberosV5)
|
||||
@ -232,7 +256,12 @@ where
|
||||
state: &'a mut RentToOwn<'a, SendingPassword<S, T>>,
|
||||
) -> Poll<AfterSendingPassword<S, T>, Error> {
|
||||
let stream = try_ready!(state.future.poll().map_err(Error::io));
|
||||
transition!(ReadingAuthCompletion { stream })
|
||||
let state = state.take();
|
||||
transition!(ReadingAuthCompletion {
|
||||
stream,
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_sending_sasl<'a>(
|
||||
@ -243,6 +272,8 @@ where
|
||||
transition!(ReadingSasl {
|
||||
stream,
|
||||
scram: state.scram,
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
})
|
||||
}
|
||||
|
||||
@ -263,6 +294,8 @@ where
|
||||
transition!(SendingSasl {
|
||||
future: state.stream.send(buf),
|
||||
scram: state.scram,
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
})
|
||||
}
|
||||
Some(Message::AuthenticationSaslFinal(body)) => {
|
||||
@ -271,7 +304,9 @@ where
|
||||
.finish(body.data())
|
||||
.map_err(|e| Error::authentication(Box::new(e)))?;
|
||||
transition!(ReadingAuthCompletion {
|
||||
stream: state.stream
|
||||
stream: state.stream,
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
})
|
||||
}
|
||||
Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
|
||||
@ -289,8 +324,11 @@ where
|
||||
match message {
|
||||
Some(Message::AuthenticationOk) => transition!(ReadingInfo {
|
||||
stream: state.stream,
|
||||
cancel_data: None,
|
||||
parameters: HashMap::new()
|
||||
process_id: 0,
|
||||
secret_key: 0,
|
||||
parameters: HashMap::new(),
|
||||
config: state.config,
|
||||
idx: state.idx,
|
||||
}),
|
||||
Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
|
||||
Some(_) => Err(Error::unexpected_message()),
|
||||
@ -305,10 +343,8 @@ where
|
||||
let message = try_ready!(state.stream.poll().map_err(Error::io));
|
||||
match message {
|
||||
Some(Message::BackendKeyData(body)) => {
|
||||
state.cancel_data = Some(CancelData {
|
||||
process_id: body.process_id(),
|
||||
secret_key: body.secret_key(),
|
||||
});
|
||||
state.process_id = body.process_id();
|
||||
state.secret_key = body.secret_key();
|
||||
}
|
||||
Some(Message::ParameterStatus(body)) => {
|
||||
state.parameters.insert(
|
||||
@ -318,16 +354,15 @@ where
|
||||
}
|
||||
Some(Message::ReadyForQuery(_)) => {
|
||||
let state = state.take();
|
||||
let cancel_data = state.cancel_data.ok_or_else(|| {
|
||||
Error::parse(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"BackendKeyData message missing",
|
||||
))
|
||||
})?;
|
||||
let (sender, receiver) = mpsc::unbounded();
|
||||
let client = Client::new(sender);
|
||||
let connection =
|
||||
Connection::new(state.stream, cancel_data, state.parameters, receiver);
|
||||
let client = Client::new(
|
||||
sender,
|
||||
state.process_id,
|
||||
state.secret_key,
|
||||
state.config,
|
||||
state.idx,
|
||||
);
|
||||
let connection = Connection::new(state.stream, state.parameters, receiver);
|
||||
transition!(Finished((client, connection)))
|
||||
}
|
||||
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
|
||||
@ -344,7 +379,12 @@ where
|
||||
S: AsyncRead + AsyncWrite,
|
||||
T: TlsMode<S>,
|
||||
{
|
||||
pub fn new(stream: S, tls_mode: T, config: Config) -> HandshakeFuture<S, T> {
|
||||
Handshake::start(TlsFuture::new(stream, tls_mode), config)
|
||||
pub fn new(
|
||||
stream: S,
|
||||
tls_mode: T,
|
||||
config: Config,
|
||||
idx: Option<usize>,
|
||||
) -> HandshakeFuture<S, T> {
|
||||
Handshake::start(TlsFuture::new(stream, tls_mode), config, idx)
|
||||
}
|
||||
}
|
||||
|
@ -19,6 +19,8 @@ macro_rules! try_ready_closed {
|
||||
}
|
||||
|
||||
mod bind;
|
||||
#[cfg(feature = "runtime")]
|
||||
mod cancel_query;
|
||||
mod cancel_query_raw;
|
||||
mod client;
|
||||
mod codec;
|
||||
@ -46,6 +48,8 @@ mod typeinfo_composite;
|
||||
mod typeinfo_enum;
|
||||
|
||||
pub use crate::proto::bind::BindFuture;
|
||||
#[cfg(feature = "runtime")]
|
||||
pub use crate::proto::cancel_query::CancelQueryFuture;
|
||||
pub use crate::proto::cancel_query_raw::CancelQueryRawFuture;
|
||||
pub use crate::proto::client::Client;
|
||||
pub use crate::proto::codec::PostgresCodec;
|
||||
|
@ -222,12 +222,11 @@ fn query_portal() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cancel_query() {
|
||||
fn cancel_query_raw() {
|
||||
let _ = env_logger::try_init();
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap();
|
||||
let cancel_data = connection.cancel_data();
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
runtime.handle().spawn(connection).unwrap();
|
||||
|
||||
@ -245,7 +244,7 @@ fn cancel_query() {
|
||||
})
|
||||
.then(|r| {
|
||||
let s = r.unwrap();
|
||||
tokio_postgres::Config::new().cancel_query_raw(s, NoTls, cancel_data)
|
||||
client.cancel_query_raw(s, NoTls)
|
||||
})
|
||||
.then(|r| {
|
||||
r.unwrap();
|
||||
|
@ -1,6 +1,8 @@
|
||||
use futures::Future;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::runtime::current_thread::Runtime;
|
||||
use tokio_postgres::NoTls;
|
||||
use tokio::timer::Delay;
|
||||
use tokio_postgres::{NoTls, SqlState};
|
||||
|
||||
fn smoke_test(s: &str) {
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
@ -67,3 +69,32 @@ fn target_session_attrs_err() {
|
||||
);
|
||||
runtime.block_on(f).err().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cancel_query() {
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
|
||||
let connect = tokio_postgres::connect("host=localhost port=5433 user=postgres", NoTls);
|
||||
let (mut client, connection) = runtime.block_on(connect).unwrap();
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
runtime.spawn(connection);
|
||||
|
||||
let sleep = client
|
||||
.batch_execute("SELECT pg_sleep(100)")
|
||||
.then(|r| match r {
|
||||
Ok(_) => panic!("unexpected success"),
|
||||
Err(ref e) if e.code() == Some(&SqlState::QUERY_CANCELED) => Ok::<(), ()>(()),
|
||||
Err(e) => panic!("unexpected error {}", e),
|
||||
});
|
||||
let cancel = Delay::new(Instant::now() + Duration::from_millis(100))
|
||||
.then(|r| {
|
||||
r.unwrap();
|
||||
client.cancel_query(NoTls)
|
||||
})
|
||||
.then(|r| {
|
||||
r.unwrap();
|
||||
Ok::<(), ()>(())
|
||||
});
|
||||
|
||||
let ((), ()) = runtime.block_on(sleep.join(cancel)).unwrap();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user