diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index c37776dd..cf796cec 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -17,10 +17,10 @@ use tokio_io::{AsyncRead, AsyncWrite}; #[cfg(feature = "runtime")] use crate::proto::ConnectFuture; -use crate::proto::{CancelQueryRawFuture, HandshakeFuture}; -use crate::{CancelData, CancelQueryRaw, Error, Handshake, TlsMode}; +use crate::proto::HandshakeFuture; #[cfg(feature = "runtime")] use crate::{Connect, MakeTlsMode, Socket}; +use crate::{Error, Handshake, TlsMode}; #[cfg(feature = "runtime")] #[derive(Debug, Copy, Clone, PartialEq)] @@ -267,7 +267,7 @@ impl Config { S: AsyncRead + AsyncWrite, T: TlsMode, { - Handshake(HandshakeFuture::new(stream, tls_mode, self.clone())) + Handshake(HandshakeFuture::new(stream, tls_mode, self.clone(), None)) } #[cfg(feature = "runtime")] @@ -277,19 +277,6 @@ impl Config { { Connect(ConnectFuture::new(make_tls_mode, Ok(self.clone()))) } - - pub fn cancel_query_raw( - &self, - stream: S, - tls_mode: T, - cancel_data: CancelData, - ) -> CancelQueryRaw - where - S: AsyncRead + AsyncWrite, - T: TlsMode, - { - CancelQueryRaw(CancelQueryRawFuture::new(stream, tls_mode, cancel_data)) - } } impl FromStr for Config { diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index 1ae2e60d..1a5317f5 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -206,6 +206,22 @@ impl Client { BatchExecute(self.0.batch_execute(query)) } + #[cfg(feature = "runtime")] + pub fn cancel_query(&mut self, make_tls_mode: T) -> CancelQuery + where + T: MakeTlsMode, + { + CancelQuery(self.0.cancel_query(make_tls_mode)) + } + + pub fn cancel_query_raw(&mut self, stream: S, tls_mode: T) -> CancelQueryRaw + where + S: AsyncRead + AsyncWrite, + T: TlsMode, + { + CancelQueryRaw(self.0.cancel_query_raw(stream, tls_mode)) + } + pub fn is_closed(&self) -> bool { self.0.is_closed() } @@ -222,10 +238,6 @@ impl Connection where S: AsyncRead + AsyncWrite, { - pub fn cancel_data(&self) -> CancelData { - self.0.cancel_data() - } - pub fn parameter(&self, name: &str) -> Option<&str> { self.0.parameter(name) } @@ -274,6 +286,25 @@ where } } +#[cfg(feature = "runtime")] +#[must_use = "futures do nothing unless polled"] +pub struct CancelQuery(proto::CancelQueryFuture) +where + T: MakeTlsMode; + +#[cfg(feature = "runtime")] +impl Future for CancelQuery +where + T: MakeTlsMode, +{ + type Item = (); + type Error = Error; + + fn poll(&mut self) -> Poll<(), Error> { + self.0.poll() + } +} + #[must_use = "futures do nothing unless polled"] pub struct Handshake(proto::HandshakeFuture) where @@ -478,15 +509,6 @@ impl Future for BatchExecute { } } -/// Contains information necessary to cancel queries for a session. -#[derive(Copy, Clone, Debug)] -pub struct CancelData { - /// The process ID of the session. - pub process_id: i32, - /// The secret key for the session. - pub secret_key: i32, -} - /// An asynchronous notification. #[derive(Clone, Debug)] pub struct Notification { diff --git a/tokio-postgres/src/proto/cancel_query.rs b/tokio-postgres/src/proto/cancel_query.rs new file mode 100644 index 00000000..61484803 --- /dev/null +++ b/tokio-postgres/src/proto/cancel_query.rs @@ -0,0 +1,105 @@ +use futures::{try_ready, Future, Poll}; +use state_machine_future::{transition, RentToOwn, StateMachineFuture}; +use std::io; + +use crate::proto::{CancelQueryRawFuture, ConnectSocketFuture}; +use crate::{Config, Error, Host, MakeTlsMode, Socket}; + +#[derive(StateMachineFuture)] +pub enum CancelQuery +where + T: MakeTlsMode, +{ + #[state_machine_future(start, transitions(ConnectingSocket))] + Start { + make_tls_mode: T, + idx: Option, + config: Config, + process_id: i32, + secret_key: i32, + }, + #[state_machine_future(transitions(Canceling))] + ConnectingSocket { + future: ConnectSocketFuture, + tls_mode: T::TlsMode, + process_id: i32, + secret_key: i32, + }, + #[state_machine_future(transitions(Finished))] + Canceling { + future: CancelQueryRawFuture, + }, + #[state_machine_future(ready)] + Finished(()), + #[state_machine_future(error)] + Failed(Error), +} + +impl PollCancelQuery for CancelQuery +where + T: MakeTlsMode, +{ + fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll, Error> { + let mut state = state.take(); + + let idx = state.idx.ok_or_else(|| { + Error::connect(io::Error::new(io::ErrorKind::InvalidInput, "unknown host")) + })?; + + let hostname = match &state.config.0.host[idx] { + Host::Tcp(host) => &**host, + // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter + #[cfg(unix)] + Host::Unix(_) => "", + }; + let tls_mode = state + .make_tls_mode + .make_tls_mode(hostname) + .map_err(|e| Error::tls(e.into()))?; + + transition!(ConnectingSocket { + future: ConnectSocketFuture::new(state.config, idx), + tls_mode, + process_id: state.process_id, + secret_key: state.secret_key, + }) + } + + fn poll_connecting_socket<'a>( + state: &'a mut RentToOwn<'a, ConnectingSocket>, + ) -> Poll, Error> { + let socket = try_ready!(state.future.poll()); + let state = state.take(); + + transition!(Canceling { + future: CancelQueryRawFuture::new( + socket, + state.tls_mode, + state.process_id, + state.secret_key + ), + }) + } + + fn poll_canceling<'a>( + state: &'a mut RentToOwn<'a, Canceling>, + ) -> Poll { + try_ready!(state.future.poll()); + transition!(Finished(())) + } +} + +impl CancelQueryFuture +where + T: MakeTlsMode, +{ + pub fn new( + make_tls_mode: T, + idx: Option, + config: Config, + process_id: i32, + secret_key: i32, + ) -> CancelQueryFuture { + CancelQuery::start(make_tls_mode, idx, config, process_id, secret_key) + } +} diff --git a/tokio-postgres/src/proto/cancel_query_raw.rs b/tokio-postgres/src/proto/cancel_query_raw.rs index 0b1c7534..ae2aee45 100644 --- a/tokio-postgres/src/proto/cancel_query_raw.rs +++ b/tokio-postgres/src/proto/cancel_query_raw.rs @@ -6,7 +6,7 @@ use tokio_io::{AsyncRead, AsyncWrite}; use crate::error::Error; use crate::proto::TlsFuture; -use crate::{CancelData, TlsMode}; +use crate::TlsMode; #[derive(StateMachineFuture)] pub enum CancelQueryRaw @@ -17,7 +17,8 @@ where #[state_machine_future(start, transitions(SendingCancel))] Start { future: TlsFuture, - cancel_data: CancelData, + process_id: i32, + secret_key: i32, }, #[state_machine_future(transitions(FlushingCancel))] SendingCancel { @@ -40,11 +41,7 @@ where let (stream, _) = try_ready!(state.future.poll()); let mut buf = vec![]; - frontend::cancel_request( - state.cancel_data.process_id, - state.cancel_data.secret_key, - &mut buf, - ); + frontend::cancel_request(state.process_id, state.secret_key, &mut buf); transition!(SendingCancel { future: io::write_all(stream, buf), @@ -74,7 +71,12 @@ where S: AsyncRead + AsyncWrite, T: TlsMode, { - pub fn new(stream: S, tls_mode: T, cancel_data: CancelData) -> CancelQueryRawFuture { - CancelQueryRaw::start(TlsFuture::new(stream, tls_mode), cancel_data) + pub fn new( + stream: S, + tls_mode: T, + process_id: i32, + secret_key: i32, + ) -> CancelQueryRawFuture { + CancelQueryRaw::start(TlsFuture::new(stream, tls_mode), process_id, secret_key) } } diff --git a/tokio-postgres/src/proto/client.rs b/tokio-postgres/src/proto/client.rs index f41b5bfe..feb06bd2 100644 --- a/tokio-postgres/src/proto/client.rs +++ b/tokio-postgres/src/proto/client.rs @@ -8,6 +8,7 @@ use postgres_protocol::message::frontend; use std::collections::HashMap; use std::error::Error as StdError; use std::sync::{Arc, Weak}; +use tokio_io::{AsyncRead, AsyncWrite}; use crate::proto::bind::BindFuture; use crate::proto::connection::{Request, RequestMessages}; @@ -20,8 +21,13 @@ use crate::proto::prepare::PrepareFuture; use crate::proto::query::QueryStream; use crate::proto::simple_query::SimpleQueryStream; use crate::proto::statement::Statement; +#[cfg(feature = "runtime")] +use crate::proto::CancelQueryFuture; +use crate::proto::CancelQueryRawFuture; use crate::types::{IsNull, Oid, ToSql, Type}; -use crate::Error; +use crate::{Config, Error, TlsMode}; +#[cfg(feature = "runtime")] +use crate::{MakeTlsMode, Socket}; pub struct PendingRequest(Result<(RequestMessages, IdleGuard), Error>); @@ -44,13 +50,25 @@ struct Inner { state: Mutex, idle: IdleState, sender: mpsc::UnboundedSender, + process_id: i32, + secret_key: i32, + #[cfg_attr(not(feature = "runtime"), allow(dead_code))] + config: Config, + #[cfg_attr(not(feature = "runtime"), allow(dead_code))] + idx: Option, } #[derive(Clone)] pub struct Client(Arc); impl Client { - pub fn new(sender: mpsc::UnboundedSender) -> Client { + pub fn new( + sender: mpsc::UnboundedSender, + process_id: i32, + secret_key: i32, + config: Config, + idx: Option, + ) -> Client { Client(Arc::new(Inner { state: Mutex::new(State { types: HashMap::new(), @@ -60,6 +78,10 @@ impl Client { }), idle: IdleState::new(), sender, + process_id, + secret_key, + config, + idx, })) } @@ -222,6 +244,28 @@ impl Client { self.close(b'P', name) } + #[cfg(feature = "runtime")] + pub fn cancel_query(&self, make_tls_mode: T) -> CancelQueryFuture + where + T: MakeTlsMode, + { + CancelQueryFuture::new( + make_tls_mode, + self.0.idx, + self.0.config.clone(), + self.0.process_id, + self.0.secret_key, + ) + } + + pub fn cancel_query_raw(&self, stream: S, tls_mode: T) -> CancelQueryRawFuture + where + S: AsyncRead + AsyncWrite, + T: TlsMode, + { + CancelQueryRawFuture::new(stream, tls_mode, self.0.process_id, self.0.secret_key) + } + fn close(&self, ty: u8, name: &str) { let mut buf = vec![]; frontend::close(ty, name, &mut buf).expect("statement name not valid"); diff --git a/tokio-postgres/src/proto/connect_once.rs b/tokio-postgres/src/proto/connect_once.rs index 2180e14f..c784e2d7 100644 --- a/tokio-postgres/src/proto/connect_once.rs +++ b/tokio-postgres/src/proto/connect_once.rs @@ -65,7 +65,7 @@ where transition!(Handshaking { target_session_attrs: state.config.0.target_session_attrs, - future: HandshakeFuture::new(socket, state.tls_mode, state.config), + future: HandshakeFuture::new(socket, state.tls_mode, state.config, Some(state.idx)), }) } diff --git a/tokio-postgres/src/proto/connection.rs b/tokio-postgres/src/proto/connection.rs index e4c80fa1..dd9f30fe 100644 --- a/tokio-postgres/src/proto/connection.rs +++ b/tokio-postgres/src/proto/connection.rs @@ -11,7 +11,7 @@ use tokio_io::{AsyncRead, AsyncWrite}; use crate::proto::codec::PostgresCodec; use crate::proto::copy_in::CopyInReceiver; use crate::proto::idle::IdleGuard; -use crate::{AsyncMessage, CancelData, Notification}; +use crate::{AsyncMessage, Notification}; use crate::{DbError, Error}; pub enum RequestMessages { @@ -42,7 +42,6 @@ enum State { pub struct Connection { stream: Framed, - cancel_data: CancelData, parameters: HashMap, receiver: mpsc::UnboundedReceiver, pending_request: Option, @@ -57,13 +56,11 @@ where { pub fn new( stream: Framed, - cancel_data: CancelData, parameters: HashMap, receiver: mpsc::UnboundedReceiver, ) -> Connection { Connection { stream, - cancel_data, parameters, receiver, pending_request: None, @@ -73,10 +70,6 @@ where } } - pub fn cancel_data(&self) -> CancelData { - self.cancel_data - } - pub fn parameter(&self, name: &str) -> Option<&str> { self.parameters.get(name).map(|s| &**s) } diff --git a/tokio-postgres/src/proto/handshake.rs b/tokio-postgres/src/proto/handshake.rs index 6b27f0ea..1d245ae1 100644 --- a/tokio-postgres/src/proto/handshake.rs +++ b/tokio-postgres/src/proto/handshake.rs @@ -8,12 +8,11 @@ use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; use state_machine_future::{transition, RentToOwn, StateMachineFuture}; use std::collections::HashMap; -use std::io; use tokio_codec::Framed; use tokio_io::{AsyncRead, AsyncWrite}; use crate::proto::{Client, Connection, PostgresCodec, TlsFuture}; -use crate::{CancelData, ChannelBinding, Config, Error, TlsMode}; +use crate::{ChannelBinding, Config, Error, TlsMode}; #[derive(StateMachineFuture)] pub enum Handshake @@ -25,42 +24,56 @@ where Start { future: TlsFuture, config: Config, + idx: Option, }, #[state_machine_future(transitions(ReadingAuth))] SendingStartup { future: sink::Send>, config: Config, + idx: Option, channel_binding: ChannelBinding, }, #[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))] ReadingAuth { stream: Framed, config: Config, + idx: Option, channel_binding: ChannelBinding, }, #[state_machine_future(transitions(ReadingAuthCompletion))] SendingPassword { future: sink::Send>, + config: Config, + idx: Option, }, #[state_machine_future(transitions(ReadingSasl))] SendingSasl { future: sink::Send>, scram: ScramSha256, + config: Config, + idx: Option, }, #[state_machine_future(transitions(SendingSasl, ReadingAuthCompletion))] ReadingSasl { stream: Framed, scram: ScramSha256, + config: Config, + idx: Option, }, #[state_machine_future(transitions(ReadingInfo))] ReadingAuthCompletion { stream: Framed, + config: Config, + idx: Option, }, #[state_machine_future(transitions(Finished))] ReadingInfo { stream: Framed, - cancel_data: Option, + process_id: i32, + secret_key: i32, parameters: HashMap, + config: Config, + idx: Option, }, #[state_machine_future(ready)] Finished((Client, Connection)), @@ -99,6 +112,7 @@ where transition!(SendingStartup { future: stream.send(buf), config: state.config, + idx: state.idx, channel_binding, }) } @@ -111,6 +125,7 @@ where transition!(ReadingAuth { stream, config: state.config, + idx: state.idx, channel_binding: state.channel_binding, }) } @@ -124,8 +139,11 @@ where match message { Some(Message::AuthenticationOk) => transition!(ReadingInfo { stream: state.stream, - cancel_data: None, + process_id: 0, + secret_key: 0, parameters: HashMap::new(), + config: state.config, + idx: state.idx, }), Some(Message::AuthenticationCleartextPassword) => { let pass = state @@ -137,7 +155,9 @@ where let mut buf = vec![]; frontend::password_message(pass, &mut buf).map_err(Error::encode)?; transition!(SendingPassword { - future: state.stream.send(buf) + future: state.stream.send(buf), + config: state.config, + idx: state.idx, }) } Some(Message::AuthenticationMd5Password(body)) => { @@ -157,7 +177,9 @@ where let mut buf = vec![]; frontend::password_message(output.as_bytes(), &mut buf).map_err(Error::encode)?; transition!(SendingPassword { - future: state.stream.send(buf) + future: state.stream.send(buf), + config: state.config, + idx: state.idx, }) } Some(Message::AuthenticationSasl(body)) => { @@ -214,6 +236,8 @@ where transition!(SendingSasl { future: state.stream.send(buf), scram, + config: state.config, + idx: state.idx, }) } Some(Message::AuthenticationKerberosV5) @@ -232,7 +256,12 @@ where state: &'a mut RentToOwn<'a, SendingPassword>, ) -> Poll, Error> { let stream = try_ready!(state.future.poll().map_err(Error::io)); - transition!(ReadingAuthCompletion { stream }) + let state = state.take(); + transition!(ReadingAuthCompletion { + stream, + config: state.config, + idx: state.idx, + }) } fn poll_sending_sasl<'a>( @@ -243,6 +272,8 @@ where transition!(ReadingSasl { stream, scram: state.scram, + config: state.config, + idx: state.idx, }) } @@ -263,6 +294,8 @@ where transition!(SendingSasl { future: state.stream.send(buf), scram: state.scram, + config: state.config, + idx: state.idx, }) } Some(Message::AuthenticationSaslFinal(body)) => { @@ -271,7 +304,9 @@ where .finish(body.data()) .map_err(|e| Error::authentication(Box::new(e)))?; transition!(ReadingAuthCompletion { - stream: state.stream + stream: state.stream, + config: state.config, + idx: state.idx, }) } Some(Message::ErrorResponse(body)) => Err(Error::db(body)), @@ -289,8 +324,11 @@ where match message { Some(Message::AuthenticationOk) => transition!(ReadingInfo { stream: state.stream, - cancel_data: None, - parameters: HashMap::new() + process_id: 0, + secret_key: 0, + parameters: HashMap::new(), + config: state.config, + idx: state.idx, }), Some(Message::ErrorResponse(body)) => Err(Error::db(body)), Some(_) => Err(Error::unexpected_message()), @@ -305,10 +343,8 @@ where let message = try_ready!(state.stream.poll().map_err(Error::io)); match message { Some(Message::BackendKeyData(body)) => { - state.cancel_data = Some(CancelData { - process_id: body.process_id(), - secret_key: body.secret_key(), - }); + state.process_id = body.process_id(); + state.secret_key = body.secret_key(); } Some(Message::ParameterStatus(body)) => { state.parameters.insert( @@ -318,16 +354,15 @@ where } Some(Message::ReadyForQuery(_)) => { let state = state.take(); - let cancel_data = state.cancel_data.ok_or_else(|| { - Error::parse(io::Error::new( - io::ErrorKind::InvalidData, - "BackendKeyData message missing", - )) - })?; let (sender, receiver) = mpsc::unbounded(); - let client = Client::new(sender); - let connection = - Connection::new(state.stream, cancel_data, state.parameters, receiver); + let client = Client::new( + sender, + state.process_id, + state.secret_key, + state.config, + state.idx, + ); + let connection = Connection::new(state.stream, state.parameters, receiver); transition!(Finished((client, connection))) } Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), @@ -344,7 +379,12 @@ where S: AsyncRead + AsyncWrite, T: TlsMode, { - pub fn new(stream: S, tls_mode: T, config: Config) -> HandshakeFuture { - Handshake::start(TlsFuture::new(stream, tls_mode), config) + pub fn new( + stream: S, + tls_mode: T, + config: Config, + idx: Option, + ) -> HandshakeFuture { + Handshake::start(TlsFuture::new(stream, tls_mode), config, idx) } } diff --git a/tokio-postgres/src/proto/mod.rs b/tokio-postgres/src/proto/mod.rs index ceeffd07..aff4a0a9 100644 --- a/tokio-postgres/src/proto/mod.rs +++ b/tokio-postgres/src/proto/mod.rs @@ -19,6 +19,8 @@ macro_rules! try_ready_closed { } mod bind; +#[cfg(feature = "runtime")] +mod cancel_query; mod cancel_query_raw; mod client; mod codec; @@ -46,6 +48,8 @@ mod typeinfo_composite; mod typeinfo_enum; pub use crate::proto::bind::BindFuture; +#[cfg(feature = "runtime")] +pub use crate::proto::cancel_query::CancelQueryFuture; pub use crate::proto::cancel_query_raw::CancelQueryRawFuture; pub use crate::proto::client::Client; pub use crate::proto::codec::PostgresCodec; diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index f21e0d7f..e68f697c 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -222,12 +222,11 @@ fn query_portal() { } #[test] -fn cancel_query() { +fn cancel_query_raw() { let _ = env_logger::try_init(); let mut runtime = Runtime::new().unwrap(); let (mut client, connection) = runtime.block_on(connect("user=postgres")).unwrap(); - let cancel_data = connection.cancel_data(); let connection = connection.map_err(|e| panic!("{}", e)); runtime.handle().spawn(connection).unwrap(); @@ -245,7 +244,7 @@ fn cancel_query() { }) .then(|r| { let s = r.unwrap(); - tokio_postgres::Config::new().cancel_query_raw(s, NoTls, cancel_data) + client.cancel_query_raw(s, NoTls) }) .then(|r| { r.unwrap(); diff --git a/tokio-postgres/tests/test/runtime.rs b/tokio-postgres/tests/test/runtime.rs index bdcd98d6..691a5161 100644 --- a/tokio-postgres/tests/test/runtime.rs +++ b/tokio-postgres/tests/test/runtime.rs @@ -1,6 +1,8 @@ use futures::Future; +use std::time::{Duration, Instant}; use tokio::runtime::current_thread::Runtime; -use tokio_postgres::NoTls; +use tokio::timer::Delay; +use tokio_postgres::{NoTls, SqlState}; fn smoke_test(s: &str) { let mut runtime = Runtime::new().unwrap(); @@ -67,3 +69,32 @@ fn target_session_attrs_err() { ); runtime.block_on(f).err().unwrap(); } + +#[test] +fn cancel_query() { + let mut runtime = Runtime::new().unwrap(); + + let connect = tokio_postgres::connect("host=localhost port=5433 user=postgres", NoTls); + let (mut client, connection) = runtime.block_on(connect).unwrap(); + let connection = connection.map_err(|e| panic!("{}", e)); + runtime.spawn(connection); + + let sleep = client + .batch_execute("SELECT pg_sleep(100)") + .then(|r| match r { + Ok(_) => panic!("unexpected success"), + Err(ref e) if e.code() == Some(&SqlState::QUERY_CANCELED) => Ok::<(), ()>(()), + Err(e) => panic!("unexpected error {}", e), + }); + let cancel = Delay::new(Instant::now() + Duration::from_millis(100)) + .then(|r| { + r.unwrap(); + client.cancel_query(NoTls) + }) + .then(|r| { + r.unwrap(); + Ok::<(), ()>(()) + }); + + let ((), ()) = runtime.block_on(sleep.join(cancel)).unwrap(); +}