diff --git a/tokio-postgres/src/proto/client.rs b/tokio-postgres/src/proto/client.rs index 99ca2100..18b19cf2 100644 --- a/tokio-postgres/src/proto/client.rs +++ b/tokio-postgres/src/proto/client.rs @@ -4,7 +4,7 @@ use postgres_protocol; use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; use std::collections::HashMap; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use disconnected; use error::{self, Error}; @@ -15,49 +15,91 @@ use proto::query::QueryStream; use proto::statement::Statement; use types::{IsNull, Oid, ToSql, Type}; -pub struct PendingRequest { - sender: mpsc::UnboundedSender, - messages: Result, Error>, -} +pub struct PendingRequest(Result, Error>); -impl PendingRequest { - pub fn send(self) -> Result, Error> { - let messages = self.messages?; - let (sender, receiver) = mpsc::channel(0); - self.sender - .unbounded_send(Request { messages, sender }) - .map(|_| receiver) - .map_err(|_| disconnected()) +pub struct WeakClient(Weak); + +impl WeakClient { + pub fn upgrade(&self) -> Option { + self.0.upgrade().map(Client) } } -pub struct State { - pub types: HashMap, - pub typeinfo_query: Option, - pub typeinfo_enum_query: Option, - pub typeinfo_composite_query: Option, +struct State { + types: HashMap, + typeinfo_query: Option, + typeinfo_enum_query: Option, + typeinfo_composite_query: Option, } -#[derive(Clone)] -pub struct Client { - pub state: Arc>, +struct Inner { + state: Mutex, sender: mpsc::UnboundedSender, } +#[derive(Clone)] +pub struct Client(Arc); + impl Client { pub fn new(sender: mpsc::UnboundedSender) -> Client { - Client { - state: Arc::new(Mutex::new(State { + Client(Arc::new(Inner { + state: Mutex::new(State { types: HashMap::new(), typeinfo_query: None, typeinfo_enum_query: None, typeinfo_composite_query: None, - })), + }), sender, - } + })) } - pub fn prepare(&mut self, name: String, query: &str, param_types: &[Type]) -> PrepareFuture { + pub fn downgrade(&self) -> WeakClient { + WeakClient(Arc::downgrade(&self.0)) + } + + pub fn cached_type(&self, oid: Oid) -> Option { + self.0.state.lock().types.get(&oid).cloned() + } + + pub fn cache_type(&self, ty: &Type) { + self.0.state.lock().types.insert(ty.oid(), ty.clone()); + } + + pub fn typeinfo_query(&self) -> Option { + self.0.state.lock().typeinfo_query.clone() + } + + pub fn set_typeinfo_query(&self, statement: &Statement) { + self.0.state.lock().typeinfo_query = Some(statement.clone()); + } + + pub fn typeinfo_enum_query(&self) -> Option { + self.0.state.lock().typeinfo_enum_query.clone() + } + + pub fn set_typeinfo_enum_query(&self, statement: &Statement) { + self.0.state.lock().typeinfo_enum_query = Some(statement.clone()); + } + + pub fn typeinfo_composite_query(&self) -> Option { + self.0.state.lock().typeinfo_composite_query.clone() + } + + pub fn set_typeinfo_composite_query(&self, statement: &Statement) { + self.0.state.lock().typeinfo_composite_query = Some(statement.clone()); + } + + pub fn send(&self, request: PendingRequest) -> Result, Error> { + let messages = request.0?; + let (sender, receiver) = mpsc::channel(0); + self.0 + .sender + .unbounded_send(Request { messages, sender }) + .map(|_| receiver) + .map_err(|_| disconnected()) + } + + pub fn prepare(&self, name: String, query: &str, param_types: &[Type]) -> PrepareFuture { let pending = self.pending(|buf| { frontend::parse(&name, query, param_types.iter().map(|t| t.oid()), buf)?; frontend::describe(b'S', &name, buf)?; @@ -65,17 +107,28 @@ impl Client { Ok(()) }); - PrepareFuture::new(pending, self.sender.clone(), name, self.clone()) + PrepareFuture::new(self.clone(), pending, name) } - pub fn execute(&mut self, statement: &Statement, params: &[&ToSql]) -> ExecuteFuture { + pub fn execute(&self, statement: &Statement, params: &[&ToSql]) -> ExecuteFuture { let pending = self.pending_execute(statement, params); - ExecuteFuture::new(pending, statement.clone()) + ExecuteFuture::new(self.clone(), pending, statement.clone()) } - pub fn query(&mut self, statement: &Statement, params: &[&ToSql]) -> QueryStream { + pub fn query(&self, statement: &Statement, params: &[&ToSql]) -> QueryStream { let pending = self.pending_execute(statement, params); - QueryStream::new(pending, statement.clone()) + QueryStream::new(self.clone(), pending, statement.clone()) + } + + pub fn close_statement(&self, name: &str) { + let mut buf = vec![]; + frontend::close(b'S', name, &mut buf).expect("statement name not valid"); + frontend::sync(&mut buf); + let (sender, _) = mpsc::channel(0); + let _ = self.0.sender.unbounded_send(Request { + messages: buf, + sender, + }); } fn pending_execute(&self, statement: &Statement, params: &[&ToSql]) -> PendingRequest { @@ -109,9 +162,6 @@ impl Client { F: FnOnce(&mut Vec) -> Result<(), Error>, { let mut buf = vec![]; - PendingRequest { - sender: self.sender.clone(), - messages: messages(&mut buf).map(|()| buf), - } + PendingRequest(messages(&mut buf).map(|()| buf)) } } diff --git a/tokio-postgres/src/proto/execute.rs b/tokio-postgres/src/proto/execute.rs index 012e334c..67315c88 100644 --- a/tokio-postgres/src/proto/execute.rs +++ b/tokio-postgres/src/proto/execute.rs @@ -4,7 +4,7 @@ use postgres_protocol::message::backend::Message; use state_machine_future::RentToOwn; use error::{self, Error}; -use proto::client::PendingRequest; +use proto::client::{Client, PendingRequest}; use proto::statement::Statement; use {bad_response, disconnected}; @@ -12,6 +12,7 @@ use {bad_response, disconnected}; pub enum Execute { #[state_machine_future(start, transitions(ReadResponse))] Start { + client: Client, request: PendingRequest, statement: Statement, }, @@ -31,7 +32,7 @@ pub enum Execute { impl PollExecute for Execute { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll { let state = state.take(); - let receiver = state.request.send()?; + let receiver = state.client.send(state.request)?; // the statement can drop after this point, since its close will queue up after the execution transition!(ReadResponse { receiver }) @@ -82,7 +83,7 @@ impl PollExecute for Execute { } impl ExecuteFuture { - pub fn new(request: PendingRequest, statement: Statement) -> ExecuteFuture { - Execute::start(request, statement) + pub fn new(client: Client, request: PendingRequest, statement: Statement) -> ExecuteFuture { + Execute::start(client, request, statement) } } diff --git a/tokio-postgres/src/proto/prepare.rs b/tokio-postgres/src/proto/prepare.rs index 70fe3b50..2252189e 100644 --- a/tokio-postgres/src/proto/prepare.rs +++ b/tokio-postgres/src/proto/prepare.rs @@ -8,7 +8,6 @@ use std::vec; use error::{self, Error}; use proto::client::{Client, PendingRequest}; -use proto::connection::Request; use proto::statement::Statement; use proto::typeinfo::TypeinfoFuture; use types::{Oid, Type}; @@ -19,47 +18,41 @@ use {bad_response, disconnected}; pub enum Prepare { #[state_machine_future(start, transitions(ReadParseComplete))] Start { - request: PendingRequest, - sender: mpsc::UnboundedSender, - name: String, client: Client, + request: PendingRequest, + name: String, }, #[state_machine_future(transitions(ReadParameterDescription))] ReadParseComplete { - sender: mpsc::UnboundedSender, + client: Client, receiver: mpsc::Receiver, name: String, - client: Client, }, #[state_machine_future(transitions(ReadRowDescription))] ReadParameterDescription { - sender: mpsc::UnboundedSender, + client: Client, receiver: mpsc::Receiver, name: String, - client: Client, }, #[state_machine_future(transitions(ReadReadyForQuery))] ReadRowDescription { - sender: mpsc::UnboundedSender, + client: Client, receiver: mpsc::Receiver, name: String, parameters: Vec, - client: Client, }, #[state_machine_future(transitions(GetParameterTypes, GetColumnTypes, Finished))] ReadReadyForQuery { - sender: mpsc::UnboundedSender, + client: Client, receiver: mpsc::Receiver, name: String, parameters: Vec, columns: Vec<(String, Oid)>, - client: Client, }, #[state_machine_future(transitions(GetColumnTypes, Finished))] GetParameterTypes { future: TypeinfoFuture, remaining_parameters: vec::IntoIter, - sender: mpsc::UnboundedSender, name: String, parameters: Vec, columns: Vec<(String, Oid)>, @@ -69,7 +62,6 @@ pub enum Prepare { future: TypeinfoFuture, cur_column_name: String, remaining_columns: vec::IntoIter<(String, Oid)>, - sender: mpsc::UnboundedSender, name: String, parameters: Vec, columns: Vec, @@ -83,10 +75,9 @@ pub enum Prepare { impl PollPrepare for Prepare { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll { let state = state.take(); - let receiver = state.request.send()?; + let receiver = state.client.send(state.request)?; transition!(ReadParseComplete { - sender: state.sender, receiver, name: state.name, client: state.client, @@ -101,7 +92,6 @@ impl PollPrepare for Prepare { match message { Some(Message::ParseComplete) => transition!(ReadParameterDescription { - sender: state.sender, receiver: state.receiver, name: state.name, client: state.client, @@ -120,7 +110,6 @@ impl PollPrepare for Prepare { match message { Some(Message::ParameterDescription(body)) => transition!(ReadRowDescription { - sender: state.sender, receiver: state.receiver, name: state.name, parameters: body.parameters().collect()?, @@ -148,7 +137,6 @@ impl PollPrepare for Prepare { }; transition!(ReadReadyForQuery { - sender: state.sender, receiver: state.receiver, name: state.name, parameters: state.parameters, @@ -174,7 +162,6 @@ impl PollPrepare for Prepare { transition!(GetParameterTypes { future: TypeinfoFuture::new(oid, state.client), remaining_parameters: parameters, - sender: state.sender, name: state.name, parameters: vec![], columns: state.columns, @@ -187,7 +174,6 @@ impl PollPrepare for Prepare { future: TypeinfoFuture::new(oid, state.client), cur_column_name: name, remaining_columns: columns, - sender: state.sender, name: state.name, parameters: vec![], columns: vec![], @@ -195,7 +181,7 @@ impl PollPrepare for Prepare { } transition!(Finished(Statement::new( - state.sender, + state.client.downgrade(), state.name, vec![], vec![] @@ -222,7 +208,6 @@ impl PollPrepare for Prepare { future: TypeinfoFuture::new(oid, client), cur_column_name: name, remaining_columns: columns, - sender: state.sender, name: state.name, parameters: state.parameters, columns: vec![], @@ -230,7 +215,7 @@ impl PollPrepare for Prepare { } transition!(Finished(Statement::new( - state.sender, + client.downgrade(), state.name, state.parameters, vec![], @@ -240,7 +225,7 @@ impl PollPrepare for Prepare { fn poll_get_column_types<'a>( state: &'a mut RentToOwn<'a, GetColumnTypes>, ) -> Poll { - loop { + let client = loop { let (ty, client) = try_ready!(state.future.poll()); let name = mem::replace(&mut state.cur_column_name, String::new()); state.columns.push(Column::new(name, ty)); @@ -250,13 +235,13 @@ impl PollPrepare for Prepare { state.cur_column_name = name; state.future = TypeinfoFuture::new(oid, client); } - None => break, + None => break client, } - } + }; let state = state.take(); transition!(Finished(Statement::new( - state.sender, + client.downgrade(), state.name, state.parameters, state.columns, @@ -265,12 +250,7 @@ impl PollPrepare for Prepare { } impl PrepareFuture { - pub fn new( - request: PendingRequest, - sender: mpsc::UnboundedSender, - name: String, - client: Client, - ) -> PrepareFuture { - Prepare::start(request, sender, name, client) + pub fn new(client: Client, request: PendingRequest, name: String) -> PrepareFuture { + Prepare::start(client, request, name) } } diff --git a/tokio-postgres/src/proto/query.rs b/tokio-postgres/src/proto/query.rs index 8ba108ba..5922b481 100644 --- a/tokio-postgres/src/proto/query.rs +++ b/tokio-postgres/src/proto/query.rs @@ -4,13 +4,14 @@ use postgres_protocol::message::backend::Message; use std::mem; use error::{self, Error}; -use proto::client::PendingRequest; +use proto::client::{Client, PendingRequest}; use proto::row::Row; use proto::statement::Statement; use {bad_response, disconnected}; enum State { Start { + client: Client, request: PendingRequest, statement: Statement, }, @@ -33,8 +34,12 @@ impl Stream for QueryStream { fn poll(&mut self) -> Poll, Error> { loop { match mem::replace(&mut self.0, State::Done) { - State::Start { request, statement } => { - let receiver = request.send()?; + State::Start { + client, + request, + statement, + } => { + let receiver = client.send(request)?; self.0 = State::ReadingResponse { receiver, statement, @@ -102,7 +107,11 @@ impl Stream for QueryStream { } impl QueryStream { - pub fn new(request: PendingRequest, statement: Statement) -> QueryStream { - QueryStream(State::Start { request, statement }) + pub fn new(client: Client, request: PendingRequest, statement: Statement) -> QueryStream { + QueryStream(State::Start { + client, + request, + statement, + }) } } diff --git a/tokio-postgres/src/proto/statement.rs b/tokio-postgres/src/proto/statement.rs index def1613c..3460a76c 100644 --- a/tokio-postgres/src/proto/statement.rs +++ b/tokio-postgres/src/proto/statement.rs @@ -1,13 +1,11 @@ -use futures::sync::mpsc; -use postgres_protocol::message::frontend; use postgres_shared::stmt::Column; use std::sync::Arc; -use proto::connection::Request; +use proto::client::WeakClient; use types::Type; pub struct StatementInner { - sender: mpsc::UnboundedSender, + client: WeakClient, name: String, params: Vec, columns: Vec, @@ -15,14 +13,9 @@ pub struct StatementInner { impl Drop for StatementInner { fn drop(&mut self) { - let mut buf = vec![]; - frontend::close(b'S', &self.name, &mut buf).expect("statement name not valid"); - frontend::sync(&mut buf); - let (sender, _) = mpsc::channel(0); - let _ = self.sender.unbounded_send(Request { - messages: buf, - sender, - }); + if let Some(client) = self.client.upgrade() { + client.close_statement(&self.name); + } } } @@ -31,13 +24,13 @@ pub struct Statement(Arc); impl Statement { pub fn new( - sender: mpsc::UnboundedSender, + client: WeakClient, name: String, params: Vec, columns: Vec, ) -> Statement { Statement(Arc::new(StatementInner { - sender, + client, name, params, columns, diff --git a/tokio-postgres/src/proto/typeinfo.rs b/tokio-postgres/src/proto/typeinfo.rs index fbdc30e5..f50741fe 100644 --- a/tokio-postgres/src/proto/typeinfo.rs +++ b/tokio-postgres/src/proto/typeinfo.rs @@ -31,7 +31,10 @@ WHERE t.oid = $1 #[derive(StateMachineFuture)] pub enum Typeinfo { - #[state_machine_future(start, transitions(PreparingTypeinfo, QueryingTypeinfo, Finished))] + #[state_machine_future( + start, + transitions(PreparingTypeinfo, QueryingTypeinfo, Finished) + )] Start { oid: Oid, client: Client }, #[state_machine_future(transitions(PreparingTypeinfoFallback, QueryingTypeinfo))] PreparingTypeinfo { @@ -47,8 +50,12 @@ pub enum Typeinfo { }, #[state_machine_future( transitions( - CachingType, QueryingEnumVariants, QueryingDomainBasetype, QueryingArrayElem, - QueryingCompositeFields, QueryingRangeSubtype + CachingType, + QueryingEnumVariants, + QueryingDomainBasetype, + QueryingArrayElem, + QueryingCompositeFields, + QueryingRangeSubtype ) )] QueryingTypeinfo { @@ -101,19 +108,17 @@ pub enum Typeinfo { impl PollTypeinfo for Typeinfo { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll { - let mut state = state.take(); + let state = state.take(); if let Some(ty) = Type::from_oid(state.oid) { transition!(Finished((ty, state.client))); } - let ty = state.client.state.lock().types.get(&state.oid).cloned(); - if let Some(ty) = ty { + if let Some(ty) = state.client.cached_type(state.oid) { transition!(Finished((ty, state.client))); } - let statement = state.client.state.lock().typeinfo_query.clone(); - match statement { + match state.client.typeinfo_query() { Some(statement) => transition!(QueryingTypeinfo { future: state.client.query(&statement, &[&state.oid]).collect(), oid: state.oid, @@ -152,10 +157,10 @@ impl PollTypeinfo for Typeinfo { } Err(e) => return Err(e), }; - let mut state = state.take(); + let state = state.take(); let future = state.client.query(&statement, &[&state.oid]).collect(); - state.client.state.lock().typeinfo_query = Some(statement); + state.client.set_typeinfo_query(&statement); transition!(QueryingTypeinfo { future, oid: state.oid, @@ -167,10 +172,10 @@ impl PollTypeinfo for Typeinfo { state: &'a mut RentToOwn<'a, PreparingTypeinfoFallback>, ) -> Poll { let statement = try_ready!(state.future.poll()); - let mut state = state.take(); + let state = state.take(); let future = state.client.query(&statement, &[&state.oid]).collect(); - state.client.state.lock().typeinfo_query = Some(statement); + state.client.set_typeinfo_query(&statement); transition!(QueryingTypeinfo { future, oid: state.oid, @@ -320,12 +325,7 @@ impl PollTypeinfo for Typeinfo { state: &'a mut RentToOwn<'a, CachingType>, ) -> Poll { let state = state.take(); - state - .client - .state - .lock() - .types - .insert(state.oid, state.ty.clone()); + state.client.cache_type(&state.ty); transition!(Finished((state.ty, state.client))) } } diff --git a/tokio-postgres/src/proto/typeinfo_composite.rs b/tokio-postgres/src/proto/typeinfo_composite.rs index d5841cff..ebbe041b 100644 --- a/tokio-postgres/src/proto/typeinfo_composite.rs +++ b/tokio-postgres/src/proto/typeinfo_composite.rs @@ -26,7 +26,8 @@ ORDER BY attnum #[derive(StateMachineFuture)] pub enum TypeinfoComposite { #[state_machine_future( - start, transitions(PreparingTypeinfoComposite, QueryingCompositeFields) + start, + transitions(PreparingTypeinfoComposite, QueryingCompositeFields) )] Start { oid: Oid, client: Client }, #[state_machine_future(transitions(QueryingCompositeFields))] @@ -55,10 +56,9 @@ pub enum TypeinfoComposite { impl PollTypeinfoComposite for TypeinfoComposite { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll { - let mut state = state.take(); + let state = state.take(); - let statement = state.client.state.lock().typeinfo_composite_query.clone(); - match statement { + match state.client.typeinfo_composite_query() { Some(statement) => transition!(QueryingCompositeFields { future: state.client.query(&statement, &[&state.oid]).collect(), client: state.client, @@ -79,9 +79,9 @@ impl PollTypeinfoComposite for TypeinfoComposite { state: &'a mut RentToOwn<'a, PreparingTypeinfoComposite>, ) -> Poll { let statement = try_ready!(state.future.poll()); - let mut state = state.take(); + let state = state.take(); - state.client.state.lock().typeinfo_composite_query = Some(statement.clone()); + state.client.set_typeinfo_composite_query(&statement); transition!(QueryingCompositeFields { future: state.client.query(&statement, &[&state.oid]).collect(), client: state.client, diff --git a/tokio-postgres/src/proto/typeinfo_enum.rs b/tokio-postgres/src/proto/typeinfo_enum.rs index 6cb10e5c..f8eebf9a 100644 --- a/tokio-postgres/src/proto/typeinfo_enum.rs +++ b/tokio-postgres/src/proto/typeinfo_enum.rs @@ -28,7 +28,10 @@ ORDER BY oid #[derive(StateMachineFuture)] pub enum TypeinfoEnum { - #[state_machine_future(start, transitions(PreparingTypeinfoEnum, QueryingEnumVariants))] + #[state_machine_future( + start, + transitions(PreparingTypeinfoEnum, QueryingEnumVariants) + )] Start { oid: Oid, client: Client }, #[state_machine_future(transitions(PreparingTypeinfoEnumFallback, QueryingEnumVariants))] PreparingTypeinfoEnum { @@ -55,10 +58,9 @@ pub enum TypeinfoEnum { impl PollTypeinfoEnum for TypeinfoEnum { fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll { - let mut state = state.take(); + let state = state.take(); - let statement = state.client.state.lock().typeinfo_enum_query.clone(); - match statement { + match state.client.typeinfo_enum_query() { Some(statement) => transition!(QueryingEnumVariants { future: state.client.query(&statement, &[&state.oid]).collect(), client: state.client, @@ -96,9 +98,9 @@ impl PollTypeinfoEnum for TypeinfoEnum { } Err(e) => return Err(e), }; - let mut state = state.take(); + let state = state.take(); - state.client.state.lock().typeinfo_enum_query = Some(statement.clone()); + state.client.set_typeinfo_enum_query(&statement); transition!(QueryingEnumVariants { future: state.client.query(&statement, &[&state.oid]).collect(), client: state.client, @@ -109,9 +111,9 @@ impl PollTypeinfoEnum for TypeinfoEnum { state: &'a mut RentToOwn<'a, PreparingTypeinfoEnumFallback>, ) -> Poll { let statement = try_ready!(state.future.poll()); - let mut state = state.take(); + let state = state.take(); - state.client.state.lock().typeinfo_enum_query = Some(statement.clone()); + state.client.set_typeinfo_enum_query(&statement); transition!(QueryingEnumVariants { future: state.client.query(&statement, &[&state.oid]).collect(), client: state.client, diff --git a/tokio-postgres/tests/test.rs b/tokio-postgres/tests/test.rs index ee126661..69ce8942 100644 --- a/tokio-postgres/tests/test.rs +++ b/tokio-postgres/tests/test.rs @@ -482,17 +482,14 @@ fn notifications() { let listen = client.prepare("LISTEN test_notifications"); let listen = runtime.block_on(listen).unwrap(); runtime.block_on(client.execute(&listen, &[])).unwrap(); - drop(listen); // FIXME let notify = client.prepare("NOTIFY test_notifications, 'hello'"); let notify = runtime.block_on(notify).unwrap(); runtime.block_on(client.execute(¬ify, &[])).unwrap(); - drop(notify); // FIXME let notify = client.prepare("NOTIFY test_notifications, 'world'"); let notify = runtime.block_on(notify).unwrap(); runtime.block_on(client.execute(¬ify, &[])).unwrap(); - drop(notify); // FIXME drop(client); runtime.run().unwrap();