Start on tokio-postgres rewrite
This commit is contained in:
parent
5ad7850009
commit
8c3770bd57
@ -8,3 +8,9 @@ members = [
|
||||
"postgres-native-tls",
|
||||
"tokio-postgres",
|
||||
]
|
||||
|
||||
[patch.crates-io]
|
||||
tokio-uds = { git = "https://github.com/sfackler/tokio" }
|
||||
tokio-io = { git = "https://github.com/sfackler/tokio" }
|
||||
tokio-timer = { git = "https://github.com/sfackler/tokio" }
|
||||
tokio-codec = { git = "https://github.com/sfackler/tokio" }
|
||||
|
@ -1,9 +1,9 @@
|
||||
//! Errors.
|
||||
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use postgres_protocol::message::backend::ErrorFields;
|
||||
use std::error;
|
||||
use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody};
|
||||
use std::convert::From;
|
||||
use std::error;
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
|
||||
@ -214,36 +214,29 @@ impl DbError {
|
||||
}
|
||||
|
||||
Ok(DbError {
|
||||
severity: severity.ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing")
|
||||
})?,
|
||||
severity: severity
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing"))?,
|
||||
parsed_severity: parsed_severity,
|
||||
code: code.ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::InvalidInput, "`C` field missing")
|
||||
})?,
|
||||
message: message.ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::InvalidInput, "`M` field missing")
|
||||
})?,
|
||||
code: code
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`C` field missing"))?,
|
||||
message: message
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`M` field missing"))?,
|
||||
detail: detail,
|
||||
hint: hint,
|
||||
position: match normal_position {
|
||||
Some(position) => Some(ErrorPosition::Normal(position)),
|
||||
None => {
|
||||
match internal_position {
|
||||
Some(position) => {
|
||||
Some(ErrorPosition::Internal {
|
||||
position: position,
|
||||
query: internal_query.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"`q` field missing but `p` field present",
|
||||
)
|
||||
})?,
|
||||
})
|
||||
}
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
None => match internal_position {
|
||||
Some(position) => Some(ErrorPosition::Internal {
|
||||
position: position,
|
||||
query: internal_query.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"`q` field missing but `p` field present",
|
||||
)
|
||||
})?,
|
||||
}),
|
||||
None => None,
|
||||
},
|
||||
},
|
||||
where_: where_,
|
||||
schema: schema,
|
||||
@ -324,6 +317,14 @@ pub fn db(e: DbError) -> Error {
|
||||
Error(Box::new(ErrorKind::Db(e)))
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub fn __db(e: ErrorResponseBody) -> Error {
|
||||
match DbError::new(&mut e.fields()) {
|
||||
Ok(e) => Error(Box::new(ErrorKind::Db(e))),
|
||||
Err(e) => Error(Box::new(ErrorKind::Io(e))),
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub fn io(e: io::Error) -> Error {
|
||||
Error(Box::new(ErrorKind::Io(e)))
|
||||
@ -401,7 +402,7 @@ impl Error {
|
||||
pub fn as_db(&self) -> Option<&DbError> {
|
||||
match *self.0 {
|
||||
ErrorKind::Db(ref err) => Some(err),
|
||||
_ => None
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -31,21 +31,19 @@ circle-ci = { repository = "sfackler/rust-postgres" }
|
||||
"with-serde_json-1" = ["postgres-shared/with-serde_json-1"]
|
||||
"with-uuid-0.6" = ["postgres-shared/with-uuid-0.6"]
|
||||
|
||||
with-openssl = ["tokio-openssl", "openssl"]
|
||||
|
||||
[dependencies]
|
||||
bytes = "0.4"
|
||||
fallible-iterator = "0.1.3"
|
||||
futures = "0.1.7"
|
||||
futures-state-stream = "0.2"
|
||||
futures-cpupool = "0.1"
|
||||
lazy_static = "1.0"
|
||||
postgres-protocol = { version = "0.3.0", path = "../postgres-protocol" }
|
||||
postgres-shared = { version = "0.4.0", path = "../postgres-shared" }
|
||||
tokio-core = "0.1.8"
|
||||
tokio-dns-unofficial = "0.1"
|
||||
state_machine_future = "0.1.7"
|
||||
tokio-codec = "0.1"
|
||||
tokio-io = "0.1"
|
||||
|
||||
tokio-openssl = { version = "0.2", optional = true }
|
||||
openssl = { version = "0.10", optional = true }
|
||||
tokio-tcp = "0.1"
|
||||
tokio-timer = "0.2"
|
||||
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
tokio-uds = "0.1"
|
||||
tokio-uds = "0.2"
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,31 +0,0 @@
|
||||
/// Generates a simple implementation of `ToSql::accepts` which accepts the
|
||||
/// types passed to it.
|
||||
#[macro_export]
|
||||
macro_rules! accepts {
|
||||
($($expected:pat),+) => (
|
||||
fn accepts(ty: &$crate::types::Type) -> bool {
|
||||
match *ty {
|
||||
$($expected)|+ => true,
|
||||
_ => false
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates an implementation of `ToSql::to_sql_checked`.
|
||||
///
|
||||
/// All `ToSql` implementations should use this macro.
|
||||
#[macro_export]
|
||||
macro_rules! to_sql_checked {
|
||||
() => {
|
||||
fn to_sql_checked(&self,
|
||||
ty: &$crate::types::Type,
|
||||
out: &mut ::std::vec::Vec<u8>)
|
||||
-> ::std::result::Result<$crate::types::IsNull,
|
||||
Box<::std::error::Error +
|
||||
::std::marker::Sync +
|
||||
::std::marker::Send>> {
|
||||
$crate::types::__to_sql_checked(self, ty, out)
|
||||
}
|
||||
}
|
||||
}
|
25
tokio-postgres/src/proto/codec.rs
Normal file
25
tokio-postgres/src/proto/codec.rs
Normal file
@ -0,0 +1,25 @@
|
||||
use bytes::BytesMut;
|
||||
use postgres_protocol::message::backend;
|
||||
use std::io;
|
||||
use tokio_codec::{Decoder, Encoder};
|
||||
|
||||
pub struct PostgresCodec;
|
||||
|
||||
impl Encoder for PostgresCodec {
|
||||
type Item = Vec<u8>;
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), io::Error> {
|
||||
dst.extend_from_slice(&item);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for PostgresCodec {
|
||||
type Item = backend::Message;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<backend::Message>, io::Error> {
|
||||
backend::Message::parse(src)
|
||||
}
|
||||
}
|
192
tokio-postgres/src/proto/connection.rs
Normal file
192
tokio-postgres/src/proto/connection.rs
Normal file
@ -0,0 +1,192 @@
|
||||
use futures::sync::mpsc;
|
||||
use futures::{Async, AsyncSink, Future, Poll, Sink, Stream};
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::io;
|
||||
use tokio_codec::Framed;
|
||||
|
||||
use error::{self, Error};
|
||||
use proto::codec::PostgresCodec;
|
||||
use proto::socket::Socket;
|
||||
use {bad_response, CancelData};
|
||||
|
||||
pub struct Request {
|
||||
pub messages: Vec<u8>,
|
||||
pub sender: mpsc::Sender<Message>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq)]
|
||||
enum State {
|
||||
Active,
|
||||
Terminating,
|
||||
Closing,
|
||||
}
|
||||
|
||||
pub struct Connection {
|
||||
stream: Framed<Socket, PostgresCodec>,
|
||||
cancel_data: CancelData,
|
||||
parameters: HashMap<String, String>,
|
||||
receiver: mpsc::Receiver<Request>,
|
||||
pending_request: Option<Vec<u8>>,
|
||||
pending_response: Option<Message>,
|
||||
responses: VecDeque<mpsc::Sender<Message>>,
|
||||
state: State,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
pub fn new(
|
||||
stream: Framed<Socket, PostgresCodec>,
|
||||
cancel_data: CancelData,
|
||||
parameters: HashMap<String, String>,
|
||||
receiver: mpsc::Receiver<Request>,
|
||||
) -> Connection {
|
||||
Connection {
|
||||
stream,
|
||||
cancel_data,
|
||||
parameters,
|
||||
receiver,
|
||||
pending_request: None,
|
||||
pending_response: None,
|
||||
responses: VecDeque::new(),
|
||||
state: State::Active,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cancel_data(&self) -> CancelData {
|
||||
self.cancel_data
|
||||
}
|
||||
|
||||
pub fn parameter(&self, name: &str) -> Option<&str> {
|
||||
self.parameters.get(name).map(|s| &**s)
|
||||
}
|
||||
|
||||
fn poll_response(&mut self) -> Poll<Option<Message>, io::Error> {
|
||||
if let Some(message) = self.pending_response.take() {
|
||||
return Ok(Async::Ready(Some(message)));
|
||||
}
|
||||
|
||||
self.stream.poll()
|
||||
}
|
||||
|
||||
fn poll_read(&mut self) -> Result<(), Error> {
|
||||
loop {
|
||||
let message = match self.poll_response()? {
|
||||
Async::Ready(Some(message)) => message,
|
||||
Async::Ready(None) => {
|
||||
return Err(Error::from(io::Error::from(io::ErrorKind::UnexpectedEof)));
|
||||
}
|
||||
Async::NotReady => return Ok(()),
|
||||
};
|
||||
|
||||
let message = match message {
|
||||
Message::NoticeResponse(_) | Message::NotificationResponse(_) => {
|
||||
// FIXME handle these
|
||||
continue;
|
||||
}
|
||||
Message::ParameterStatus(body) => {
|
||||
self.parameters
|
||||
.insert(body.name()?.to_string(), body.value()?.to_string());
|
||||
continue;
|
||||
}
|
||||
m => m,
|
||||
};
|
||||
|
||||
let mut sender = match self.responses.pop_front() {
|
||||
Some(sender) => sender,
|
||||
None => match message {
|
||||
Message::ErrorResponse(error) => return Err(error::__db(error)),
|
||||
_ => return Err(bad_response()),
|
||||
},
|
||||
};
|
||||
|
||||
let ready = match message {
|
||||
Message::ReadyForQuery(_) => true,
|
||||
_ => false,
|
||||
};
|
||||
|
||||
match sender.start_send(message) {
|
||||
// if the receiver's hung up we still need to page through the rest of the messages
|
||||
// designated to it
|
||||
Ok(AsyncSink::Ready) | Err(_) => {
|
||||
if !ready {
|
||||
self.responses.push_front(sender);
|
||||
}
|
||||
}
|
||||
Ok(AsyncSink::NotReady(message)) => {
|
||||
self.responses.push_front(sender);
|
||||
self.pending_response = Some(message);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_request(&mut self) -> Poll<Option<Vec<u8>>, Error> {
|
||||
if let Some(message) = self.pending_request.take() {
|
||||
return Ok(Async::Ready(Some(message)));
|
||||
}
|
||||
|
||||
match self.receiver.poll() {
|
||||
Ok(Async::Ready(Some(request))) => {
|
||||
self.responses.push_back(request.sender);
|
||||
Ok(Async::Ready(Some(request.messages)))
|
||||
}
|
||||
Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
|
||||
Ok(Async::NotReady) => Ok(Async::NotReady),
|
||||
Err(()) => unreachable!("mpsc::Receiver doesn't return errors"),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_write(&mut self) -> Result<(), Error> {
|
||||
loop {
|
||||
let request = match self.poll_request()? {
|
||||
Async::Ready(Some(request)) => request,
|
||||
Async::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
|
||||
self.state = State::Terminating;
|
||||
let mut request = vec![];
|
||||
frontend::terminate(&mut request);
|
||||
request
|
||||
}
|
||||
Async::Ready(None) => return Ok(()),
|
||||
Async::NotReady => return Ok(()),
|
||||
};
|
||||
|
||||
match self.stream.start_send(request)? {
|
||||
AsyncSink::Ready => {
|
||||
if self.state == State::Terminating {
|
||||
self.state = State::Closing;
|
||||
}
|
||||
}
|
||||
AsyncSink::NotReady(request) => {
|
||||
self.pending_request = Some(request);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(&mut self) -> Result<(), Error> {
|
||||
self.stream.poll_complete()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_shutdown(&mut self) -> Poll<(), Error> {
|
||||
match self.state {
|
||||
State::Active | State::Terminating => Ok(Async::NotReady),
|
||||
State::Closing => self.stream.close().map_err(Into::into),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for Connection {
|
||||
type Item = ();
|
||||
type Error = Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<(), Error> {
|
||||
self.poll_read()?;
|
||||
self.poll_write()?;
|
||||
self.poll_flush()?;
|
||||
self.poll_shutdown()
|
||||
}
|
||||
}
|
306
tokio-postgres/src/proto/handshake.rs
Normal file
306
tokio-postgres/src/proto/handshake.rs
Normal file
@ -0,0 +1,306 @@
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures::sink;
|
||||
use futures::sync::mpsc;
|
||||
use futures::{Future, Poll, Sink, Stream};
|
||||
use postgres_protocol::authentication;
|
||||
use postgres_protocol::authentication::sasl::{self, ChannelBinding, ScramSha256};
|
||||
use postgres_protocol::message::backend::Message;
|
||||
use postgres_protocol::message::frontend;
|
||||
use state_machine_future::RentToOwn;
|
||||
use std::collections::HashMap;
|
||||
use std::io;
|
||||
use tokio_codec::Framed;
|
||||
|
||||
use error::{self, Error};
|
||||
use params::{ConnectParams, User};
|
||||
use proto::codec::PostgresCodec;
|
||||
use proto::connection::{Connection, Request};
|
||||
use proto::socket::{ConnectFuture, Socket};
|
||||
use {bad_response, disconnected, CancelData};
|
||||
|
||||
#[derive(StateMachineFuture)]
|
||||
pub enum Handshake {
|
||||
#[state_machine_future(start, transitions(SendingStartup))]
|
||||
Start {
|
||||
future: ConnectFuture,
|
||||
params: ConnectParams,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingAuth))]
|
||||
SendingStartup {
|
||||
future: sink::Send<Framed<Socket, PostgresCodec>>,
|
||||
user: User,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))]
|
||||
ReadingAuth {
|
||||
stream: Framed<Socket, PostgresCodec>,
|
||||
user: User,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingAuthCompletion))]
|
||||
SendingPassword {
|
||||
future: sink::Send<Framed<Socket, PostgresCodec>>,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingSasl))]
|
||||
SendingSasl {
|
||||
future: sink::Send<Framed<Socket, PostgresCodec>>,
|
||||
scram: ScramSha256,
|
||||
},
|
||||
#[state_machine_future(transitions(SendingSasl, ReadingAuthCompletion))]
|
||||
ReadingSasl {
|
||||
stream: Framed<Socket, PostgresCodec>,
|
||||
scram: ScramSha256,
|
||||
},
|
||||
#[state_machine_future(transitions(ReadingInfo))]
|
||||
ReadingAuthCompletion {
|
||||
stream: Framed<Socket, PostgresCodec>,
|
||||
},
|
||||
#[state_machine_future(transitions(Finished))]
|
||||
ReadingInfo {
|
||||
stream: Framed<Socket, PostgresCodec>,
|
||||
cancel_data: Option<CancelData>,
|
||||
parameters: HashMap<String, String>,
|
||||
},
|
||||
#[state_machine_future(ready)]
|
||||
Finished((mpsc::Sender<Request>, Connection)),
|
||||
#[state_machine_future(error)]
|
||||
Failed(Error),
|
||||
}
|
||||
|
||||
impl PollHandshake for Handshake {
|
||||
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll<AfterStart, Error> {
|
||||
let stream = try_ready!(state.future.poll());
|
||||
let state = state.take();
|
||||
|
||||
let user = match state.params.user() {
|
||||
Some(user) => user.clone(),
|
||||
None => {
|
||||
return Err(error::connect(
|
||||
"user missing from connection parameters".into(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let mut buf = vec![];
|
||||
{
|
||||
let options = state
|
||||
.params
|
||||
.options()
|
||||
.iter()
|
||||
.map(|&(ref key, ref value)| (&**key, &**value));
|
||||
let client_encoding = Some(("client_encoding", "UTF8"));
|
||||
let timezone = Some(("timezone", "GMT"));
|
||||
let user = Some(("user", user.name()));
|
||||
let database = state.params.database().map(|s| ("database", s));
|
||||
|
||||
frontend::startup_message(
|
||||
options
|
||||
.chain(client_encoding)
|
||||
.chain(timezone)
|
||||
.chain(user)
|
||||
.chain(database),
|
||||
&mut buf,
|
||||
)?;
|
||||
}
|
||||
|
||||
let stream = Framed::new(stream, PostgresCodec);
|
||||
|
||||
transition!(SendingStartup {
|
||||
future: stream.send(buf),
|
||||
user,
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_sending_startup<'a>(
|
||||
state: &'a mut RentToOwn<'a, SendingStartup>,
|
||||
) -> Poll<AfterSendingStartup, Error> {
|
||||
let stream = try_ready!(state.future.poll());
|
||||
let state = state.take();
|
||||
transition!(ReadingAuth {
|
||||
stream,
|
||||
user: state.user,
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_reading_auth<'a>(
|
||||
state: &'a mut RentToOwn<'a, ReadingAuth>,
|
||||
) -> Poll<AfterReadingAuth, Error> {
|
||||
let message = try_ready!(state.stream.poll());
|
||||
let state = state.take();
|
||||
|
||||
match message {
|
||||
Some(Message::AuthenticationOk) => transition!(ReadingInfo {
|
||||
stream: state.stream,
|
||||
cancel_data: None,
|
||||
parameters: HashMap::new(),
|
||||
}),
|
||||
Some(Message::AuthenticationCleartextPassword) => {
|
||||
let pass = state.user.password().ok_or_else(missing_password)?;
|
||||
let mut buf = vec![];
|
||||
frontend::password_message(pass, &mut buf)?;
|
||||
transition!(SendingPassword {
|
||||
future: state.stream.send(buf)
|
||||
})
|
||||
}
|
||||
Some(Message::AuthenticationMd5Password(body)) => {
|
||||
let pass = state.user.password().ok_or_else(missing_password)?;
|
||||
let output = authentication::md5_hash(
|
||||
state.user.name().as_bytes(),
|
||||
pass.as_bytes(),
|
||||
body.salt(),
|
||||
);
|
||||
let mut buf = vec![];
|
||||
frontend::password_message(&output, &mut buf)?;
|
||||
transition!(SendingPassword {
|
||||
future: state.stream.send(buf)
|
||||
})
|
||||
}
|
||||
Some(Message::AuthenticationSasl(body)) => {
|
||||
let pass = state.user.password().ok_or_else(missing_password)?;
|
||||
|
||||
let mut has_scram = false;
|
||||
let mut mechanisms = body.mechanisms();
|
||||
while let Some(mechanism) = mechanisms.next()? {
|
||||
match mechanism {
|
||||
sasl::SCRAM_SHA_256 => has_scram = true,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if !has_scram {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"unsupported SASL authentication",
|
||||
).into());
|
||||
}
|
||||
|
||||
let mut scram = ScramSha256::new(pass.as_bytes(), ChannelBinding::unsupported())?;
|
||||
|
||||
let mut buf = vec![];
|
||||
frontend::sasl_initial_response(sasl::SCRAM_SHA_256, scram.message(), &mut buf)?;
|
||||
|
||||
transition!(SendingSasl {
|
||||
future: state.stream.send(buf),
|
||||
scram,
|
||||
})
|
||||
}
|
||||
Some(Message::AuthenticationKerberosV5)
|
||||
| Some(Message::AuthenticationScmCredential)
|
||||
| Some(Message::AuthenticationGss)
|
||||
| Some(Message::AuthenticationSspi) => Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"unsupported authentication method",
|
||||
).into()),
|
||||
Some(Message::ErrorResponse(body)) => Err(error::__db(body)),
|
||||
Some(_) => Err(bad_response()),
|
||||
None => Err(disconnected()),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_sending_password<'a>(
|
||||
state: &'a mut RentToOwn<'a, SendingPassword>,
|
||||
) -> Poll<AfterSendingPassword, Error> {
|
||||
let stream = try_ready!(state.future.poll());
|
||||
transition!(ReadingAuthCompletion { stream })
|
||||
}
|
||||
|
||||
fn poll_sending_sasl<'a>(
|
||||
state: &'a mut RentToOwn<'a, SendingSasl>,
|
||||
) -> Poll<AfterSendingSasl, Error> {
|
||||
let stream = try_ready!(state.future.poll());
|
||||
let state = state.take();
|
||||
transition!(ReadingSasl {
|
||||
stream,
|
||||
scram: state.scram
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_reading_sasl<'a>(
|
||||
state: &'a mut RentToOwn<'a, ReadingSasl>,
|
||||
) -> Poll<AfterReadingSasl, Error> {
|
||||
let message = try_ready!(state.stream.poll());
|
||||
let mut state = state.take();
|
||||
|
||||
match message {
|
||||
Some(Message::AuthenticationSaslContinue(body)) => {
|
||||
state.scram.update(body.data())?;
|
||||
let mut buf = vec![];
|
||||
frontend::sasl_response(state.scram.message(), &mut buf)?;
|
||||
transition!(SendingSasl {
|
||||
future: state.stream.send(buf),
|
||||
scram: state.scram,
|
||||
})
|
||||
}
|
||||
Some(Message::AuthenticationSaslFinal(body)) => {
|
||||
state.scram.finish(body.data())?;
|
||||
transition!(ReadingAuthCompletion {
|
||||
stream: state.stream,
|
||||
})
|
||||
}
|
||||
Some(Message::ErrorResponse(body)) => Err(error::__db(body)),
|
||||
Some(_) => Err(bad_response()),
|
||||
None => Err(disconnected()),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_reading_auth_completion<'a>(
|
||||
state: &'a mut RentToOwn<'a, ReadingAuthCompletion>,
|
||||
) -> Poll<AfterReadingAuthCompletion, Error> {
|
||||
let message = try_ready!(state.stream.poll());
|
||||
let state = state.take();
|
||||
|
||||
match message {
|
||||
Some(Message::AuthenticationOk) => transition!(ReadingInfo {
|
||||
stream: state.stream,
|
||||
cancel_data: None,
|
||||
parameters: HashMap::new(),
|
||||
}),
|
||||
Some(Message::ErrorResponse(body)) => Err(error::__db(body)),
|
||||
Some(_) => Err(bad_response()),
|
||||
None => Err(disconnected()),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_reading_info<'a>(
|
||||
state: &'a mut RentToOwn<'a, ReadingInfo>,
|
||||
) -> Poll<AfterReadingInfo, Error> {
|
||||
loop {
|
||||
let message = try_ready!(state.stream.poll());
|
||||
match message {
|
||||
Some(Message::BackendKeyData(body)) => {
|
||||
state.cancel_data = Some(CancelData {
|
||||
process_id: body.process_id(),
|
||||
secret_key: body.secret_key(),
|
||||
});
|
||||
}
|
||||
Some(Message::ParameterStatus(body)) => {
|
||||
state
|
||||
.parameters
|
||||
.insert(body.name()?.to_string(), body.value()?.to_string());
|
||||
}
|
||||
Some(Message::ReadyForQuery(_)) => {
|
||||
let state = state.take();
|
||||
let cancel_data = state.cancel_data.ok_or_else(|| {
|
||||
io::Error::new(io::ErrorKind::InvalidData, "BackendKeyData message missing")
|
||||
})?;
|
||||
let (sender, receiver) = mpsc::channel(0);
|
||||
let connection =
|
||||
Connection::new(state.stream, cancel_data, state.parameters, receiver);
|
||||
transition!(Finished((sender, connection)))
|
||||
}
|
||||
Some(Message::ErrorResponse(body)) => return Err(error::__db(body)),
|
||||
Some(Message::NoticeResponse(_)) => {}
|
||||
Some(_) => return Err(bad_response()),
|
||||
None => return Err(disconnected()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HandshakeFuture {
|
||||
pub fn new(params: ConnectParams) -> HandshakeFuture {
|
||||
Handshake::start(Socket::connect(¶ms), params)
|
||||
}
|
||||
}
|
||||
|
||||
fn missing_password() -> Error {
|
||||
error::connect("a password was requested but not provided".into())
|
||||
}
|
9
tokio-postgres/src/proto/mod.rs
Normal file
9
tokio-postgres/src/proto/mod.rs
Normal file
@ -0,0 +1,9 @@
|
||||
mod codec;
|
||||
mod connection;
|
||||
mod handshake;
|
||||
mod socket;
|
||||
|
||||
pub use proto::codec::PostgresCodec;
|
||||
pub use proto::connection::{Connection, Request};
|
||||
pub use proto::handshake::HandshakeFuture;
|
||||
pub use proto::socket::Socket;
|
233
tokio-postgres/src/proto/socket.rs
Normal file
233
tokio-postgres/src/proto/socket.rs
Normal file
@ -0,0 +1,233 @@
|
||||
use bytes::{Buf, BufMut};
|
||||
use futures::{Async, Future, Poll};
|
||||
use futures_cpupool::{CpuFuture, CpuPool};
|
||||
use state_machine_future::RentToOwn;
|
||||
use std::io::{self, Read, Write};
|
||||
use std::net::{SocketAddr, ToSocketAddrs};
|
||||
use std::time::{Duration, Instant};
|
||||
use std::vec;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use tokio_tcp::{self, TcpStream};
|
||||
use tokio_timer::Delay;
|
||||
|
||||
#[cfg(unix)]
|
||||
use tokio_uds::{self, UnixStream};
|
||||
|
||||
use params::{ConnectParams, Host};
|
||||
|
||||
lazy_static! {
|
||||
static ref DNS_POOL: CpuPool = CpuPool::new(2);
|
||||
}
|
||||
|
||||
pub enum Socket {
|
||||
Tcp(TcpStream),
|
||||
#[cfg(unix)]
|
||||
Unix(UnixStream),
|
||||
}
|
||||
|
||||
impl Read for Socket {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Socket::Tcp(stream) => stream.read(buf),
|
||||
#[cfg(unix)]
|
||||
Socket::Unix(stream) => stream.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for Socket {
|
||||
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
|
||||
match self {
|
||||
Socket::Tcp(stream) => stream.prepare_uninitialized_buffer(buf),
|
||||
#[cfg(unix)]
|
||||
Socket::Unix(stream) => stream.prepare_uninitialized_buffer(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
|
||||
where
|
||||
B: BufMut,
|
||||
{
|
||||
match self {
|
||||
Socket::Tcp(stream) => stream.read_buf(buf),
|
||||
#[cfg(unix)]
|
||||
Socket::Unix(stream) => stream.read_buf(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for Socket {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Socket::Tcp(stream) => stream.write(buf),
|
||||
#[cfg(unix)]
|
||||
Socket::Unix(stream) => stream.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
Socket::Tcp(stream) => stream.flush(),
|
||||
#[cfg(unix)]
|
||||
Socket::Unix(stream) => stream.flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for Socket {
|
||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
||||
match self {
|
||||
Socket::Tcp(stream) => stream.shutdown(),
|
||||
#[cfg(unix)]
|
||||
Socket::Unix(stream) => stream.shutdown(),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
|
||||
where
|
||||
B: Buf,
|
||||
{
|
||||
match self {
|
||||
Socket::Tcp(stream) => stream.write_buf(buf),
|
||||
#[cfg(unix)]
|
||||
Socket::Unix(stream) => stream.write_buf(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Socket {
|
||||
pub fn connect(params: &ConnectParams) -> ConnectFuture {
|
||||
Connect::start(params.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(StateMachineFuture)]
|
||||
pub enum Connect {
|
||||
#[state_machine_future(start)]
|
||||
#[cfg_attr(unix, state_machine_future(transitions(ResolvingDns, ConnectingUnix)))]
|
||||
#[cfg_attr(not(unix), state_machine_future(transitions(ResolvingDns)))]
|
||||
Start { params: ConnectParams },
|
||||
#[state_machine_future(transitions(ConnectingTcp))]
|
||||
ResolvingDns {
|
||||
future: CpuFuture<vec::IntoIter<SocketAddr>, io::Error>,
|
||||
timeout: Option<Duration>,
|
||||
},
|
||||
#[state_machine_future(transitions(Ready))]
|
||||
ConnectingTcp {
|
||||
addrs: vec::IntoIter<SocketAddr>,
|
||||
future: tokio_tcp::ConnectFuture,
|
||||
timeout: Option<(Duration, Delay)>,
|
||||
},
|
||||
#[cfg(unix)]
|
||||
#[state_machine_future(transitions(Ready))]
|
||||
ConnectingUnix {
|
||||
future: tokio_uds::ConnectFuture,
|
||||
timeout: Option<Delay>,
|
||||
},
|
||||
#[state_machine_future(ready)]
|
||||
Ready(Socket),
|
||||
#[state_machine_future(error)]
|
||||
Failed(io::Error),
|
||||
}
|
||||
|
||||
impl PollConnect for Connect {
|
||||
fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll<AfterStart, io::Error> {
|
||||
let timeout = state.params.connect_timeout();
|
||||
let port = state.params.port();
|
||||
|
||||
match state.params.host() {
|
||||
Host::Tcp(ref host) => {
|
||||
let host = host.clone();
|
||||
transition!(ResolvingDns {
|
||||
future: DNS_POOL.spawn_fn(move || (&*host, port).to_socket_addrs()),
|
||||
timeout,
|
||||
})
|
||||
}
|
||||
#[cfg(unix)]
|
||||
Host::Unix(ref path) => {
|
||||
let path = path.join(format!(".s.PGSQL.{}", port));
|
||||
transition!(ConnectingUnix {
|
||||
future: UnixStream::connect(path),
|
||||
timeout: timeout.map(|t| Delay::new(Instant::now() + t))
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_resolving_dns<'a>(
|
||||
state: &'a mut RentToOwn<'a, ResolvingDns>,
|
||||
) -> Poll<AfterResolvingDns, io::Error> {
|
||||
let mut addrs = try_ready!(state.future.poll());
|
||||
|
||||
let addr = match addrs.next() {
|
||||
Some(addr) => addr,
|
||||
None => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"resolved to 0 addresses",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
transition!(ConnectingTcp {
|
||||
addrs,
|
||||
future: TcpStream::connect(&addr),
|
||||
timeout: state.timeout.map(|t| (t, Delay::new(Instant::now() + t))),
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_connecting_tcp<'a>(
|
||||
state: &'a mut RentToOwn<'a, ConnectingTcp>,
|
||||
) -> Poll<AfterConnectingTcp, io::Error> {
|
||||
loop {
|
||||
let error = match state.future.poll() {
|
||||
Ok(Async::Ready(socket)) => transition!(Ready(Socket::Tcp(socket))),
|
||||
Ok(Async::NotReady) => match state.timeout {
|
||||
Some((_, ref mut delay)) => {
|
||||
try_ready!(
|
||||
delay
|
||||
.poll()
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
);
|
||||
io::Error::new(io::ErrorKind::TimedOut, "connection timed out")
|
||||
}
|
||||
None => return Ok(Async::NotReady),
|
||||
},
|
||||
Err(e) => e,
|
||||
};
|
||||
|
||||
let addr = match state.addrs.next() {
|
||||
Some(addr) => addr,
|
||||
None => return Err(error),
|
||||
};
|
||||
|
||||
state.future = TcpStream::connect(&addr);
|
||||
if let Some((timeout, ref mut delay)) = state.timeout {
|
||||
delay.reset(Instant::now() + timeout);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn poll_connecting_unix<'a>(
|
||||
state: &'a mut RentToOwn<'a, ConnectingUnix>,
|
||||
) -> Poll<AfterConnectingUnix, io::Error> {
|
||||
match state.future.poll()? {
|
||||
Async::Ready(socket) => transition!(Ready(Socket::Unix(socket))),
|
||||
Async::NotReady => match state.timeout {
|
||||
Some(ref mut delay) => {
|
||||
try_ready!(
|
||||
delay
|
||||
.poll()
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||
);
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"connection timed out",
|
||||
))
|
||||
}
|
||||
None => Ok(Async::NotReady),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
@ -1,84 +0,0 @@
|
||||
//! Postgres rows.
|
||||
|
||||
use postgres_shared::rows::RowData;
|
||||
use postgres_shared::stmt::Column;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[doc(inline)]
|
||||
pub use postgres_shared::rows::RowIndex;
|
||||
|
||||
use types::{FromSql, WrongType};
|
||||
|
||||
/// A row from Postgres.
|
||||
pub struct Row {
|
||||
columns: Arc<Vec<Column>>,
|
||||
data: RowData,
|
||||
}
|
||||
|
||||
impl Row {
|
||||
pub(crate) fn new(columns: Arc<Vec<Column>>, data: RowData) -> Row {
|
||||
Row {
|
||||
columns: columns,
|
||||
data: data,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns information about the columns in the row.
|
||||
pub fn columns(&self) -> &[Column] {
|
||||
&self.columns
|
||||
}
|
||||
|
||||
/// Returns the number of values in the row
|
||||
pub fn len(&self) -> usize {
|
||||
self.columns.len()
|
||||
}
|
||||
|
||||
/// Retrieves the contents of a field of the row.
|
||||
///
|
||||
/// A field can be accessed by the name or index of its column, though
|
||||
/// access by index is more efficient. Rows are 0-indexed.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the index does not reference a column or the return type is
|
||||
/// not compatible with the Postgres type.
|
||||
pub fn get<'a, T, I>(&'a self, idx: I) -> T
|
||||
where
|
||||
T: FromSql<'a>,
|
||||
I: RowIndex + fmt::Debug,
|
||||
{
|
||||
match self.try_get(&idx) {
|
||||
Ok(Some(v)) => v,
|
||||
Ok(None) => panic!("no such column {:?}", idx),
|
||||
Err(e) => panic!("error retrieving row {:?}: {}", idx, e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves the contents of a field of the row.
|
||||
///
|
||||
/// A field can be accessed by the name or index of its column, though
|
||||
/// access by index is more efficient. Rows are 0-indexed.
|
||||
///
|
||||
/// Returns `None` if the index does not reference a column, `Some(Err(..))`
|
||||
/// if there was an error converting the result value, and `Some(Ok(..))`
|
||||
/// on success.
|
||||
pub fn try_get<'a, T, I>(&'a self, idx: I) -> Result<Option<T>, Box<Error + Sync + Send>>
|
||||
where
|
||||
T: FromSql<'a>,
|
||||
I: RowIndex,
|
||||
{
|
||||
let idx = match idx.__idx(&self.columns) {
|
||||
Some(idx) => idx,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let ty = self.columns[idx].type_();
|
||||
if !T::accepts(ty) {
|
||||
return Err(Box::new(WrongType::new(ty.clone())));
|
||||
}
|
||||
|
||||
T::from_sql_nullable(ty, self.data.get(idx)).map(Some)
|
||||
}
|
||||
}
|
@ -1,162 +0,0 @@
|
||||
use futures::{Sink, Future, Poll, AsyncSink, Async, Stream};
|
||||
use futures::stream::Fuse;
|
||||
|
||||
pub trait SinkExt: Sink {
|
||||
fn send2(self, item: Self::SinkItem) -> Send2<Self>
|
||||
where
|
||||
Self: Sized;
|
||||
|
||||
// unlike send_all, this doesn't close the stream.
|
||||
fn send_all2<S>(self, stream: S) -> SendAll2<Self, S>
|
||||
where
|
||||
S: Stream<Item = Self::SinkItem>,
|
||||
Self::SinkError: From<S::Error>,
|
||||
Self: Sized;
|
||||
}
|
||||
|
||||
impl<T> SinkExt for T
|
||||
where
|
||||
T: Sink,
|
||||
{
|
||||
fn send2(self, item: Self::SinkItem) -> Send2<Self>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
Send2 {
|
||||
sink: Some(self),
|
||||
item: Some(item),
|
||||
}
|
||||
}
|
||||
|
||||
fn send_all2<S>(self, stream: S) -> SendAll2<Self, S>
|
||||
where
|
||||
S: Stream<Item = Self::SinkItem>,
|
||||
Self::SinkError: From<S::Error>,
|
||||
Self: Sized,
|
||||
{
|
||||
SendAll2 {
|
||||
sink: Some(self),
|
||||
stream: Some(stream.fuse()),
|
||||
buffered: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Send2<T>
|
||||
where
|
||||
T: Sink,
|
||||
{
|
||||
sink: Option<T>,
|
||||
item: Option<T::SinkItem>,
|
||||
}
|
||||
|
||||
impl<T> Future for Send2<T>
|
||||
where
|
||||
T: Sink,
|
||||
{
|
||||
type Item = T;
|
||||
type Error = (T::SinkError, T);
|
||||
|
||||
fn poll(&mut self) -> Poll<T, (T::SinkError, T)> {
|
||||
let mut sink = self.sink.take().expect("poll called after completion");
|
||||
|
||||
if let Some(item) = self.item.take() {
|
||||
match sink.start_send(item) {
|
||||
Ok(AsyncSink::NotReady(item)) => {
|
||||
self.sink = Some(sink);
|
||||
self.item = Some(item);
|
||||
return Ok(Async::NotReady);
|
||||
}
|
||||
Ok(AsyncSink::Ready) => {}
|
||||
Err(e) => return Err((e, sink)),
|
||||
}
|
||||
}
|
||||
|
||||
match sink.poll_complete() {
|
||||
Ok(Async::Ready(())) => {}
|
||||
Ok(Async::NotReady) => {
|
||||
self.sink = Some(sink);
|
||||
return Ok(Async::NotReady);
|
||||
}
|
||||
Err(e) => return Err((e, sink)),
|
||||
}
|
||||
|
||||
Ok(Async::Ready(sink))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SendAll2<T, U>
|
||||
where
|
||||
U: Stream,
|
||||
{
|
||||
sink: Option<T>,
|
||||
stream: Option<Fuse<U>>,
|
||||
buffered: Option<U::Item>,
|
||||
}
|
||||
|
||||
impl<T, U> Future for SendAll2<T, U>
|
||||
where
|
||||
T: Sink,
|
||||
U: Stream<Item = T::SinkItem>,
|
||||
T::SinkError: From<U::Error>,
|
||||
{
|
||||
type Item = (T, U);
|
||||
type Error = (T::SinkError, T, U);
|
||||
|
||||
fn poll(&mut self) -> Poll<(T, U), (T::SinkError, T, U)> {
|
||||
let mut stream = self.stream.take().expect("poll called after completion");
|
||||
let mut sink = self.sink.take().expect("poll called after completion");
|
||||
|
||||
if let Some(item) = self.buffered.take() {
|
||||
match sink.start_send(item) {
|
||||
Ok(AsyncSink::Ready) => {}
|
||||
Ok(AsyncSink::NotReady(item)) => {
|
||||
self.sink = Some(sink);
|
||||
self.buffered = Some(item);
|
||||
self.stream = Some(stream);
|
||||
return Ok(Async::NotReady);
|
||||
}
|
||||
Err(e) => return Err((e, sink, stream.into_inner())),
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
match stream.poll() {
|
||||
Ok(Async::Ready(Some(item))) => {
|
||||
match sink.start_send(item) {
|
||||
Ok(AsyncSink::Ready) => {}
|
||||
Ok(AsyncSink::NotReady(item)) => {
|
||||
self.sink = Some(sink);
|
||||
self.buffered = Some(item);
|
||||
self.stream = Some(stream);
|
||||
return Ok(Async::NotReady);
|
||||
}
|
||||
Err(e) => return Err((e, sink, stream.into_inner())),
|
||||
}
|
||||
}
|
||||
Ok(Async::Ready(None)) => {
|
||||
match sink.poll_complete() {
|
||||
Ok(Async::Ready(())) => return Ok(Async::Ready((sink, stream.into_inner()))),
|
||||
Ok(Async::NotReady) => {
|
||||
self.sink = Some(sink);
|
||||
self.stream = Some(stream);
|
||||
return Ok(Async::NotReady);
|
||||
}
|
||||
Err(e) => return Err((e, sink, stream.into_inner())),
|
||||
}
|
||||
}
|
||||
Ok(Async::NotReady) => {
|
||||
match sink.poll_complete() {
|
||||
Ok(Async::Ready(())) | Ok(Async::NotReady) => {
|
||||
self.sink = Some(sink);
|
||||
self.stream = Some(stream);
|
||||
return Ok(Async::NotReady);
|
||||
}
|
||||
Err(e) => return Err((e, sink, stream.into_inner())),
|
||||
}
|
||||
}
|
||||
Err(e) => return Err((e.into(), sink, stream.into_inner())),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,59 +0,0 @@
|
||||
//! Prepared statements.
|
||||
|
||||
use std::mem;
|
||||
use std::sync::Arc;
|
||||
use std::sync::mpsc::Sender;
|
||||
|
||||
#[doc(inline)]
|
||||
pub use postgres_shared::stmt::Column;
|
||||
|
||||
use types::Type;
|
||||
|
||||
/// A prepared statement.
|
||||
pub struct Statement {
|
||||
close_sender: Sender<(u8, String)>,
|
||||
name: String,
|
||||
params: Vec<Type>,
|
||||
columns: Arc<Vec<Column>>,
|
||||
}
|
||||
|
||||
impl Drop for Statement {
|
||||
fn drop(&mut self) {
|
||||
let name = mem::replace(&mut self.name, String::new());
|
||||
let _ = self.close_sender.send((b'S', name));
|
||||
}
|
||||
}
|
||||
|
||||
impl Statement {
|
||||
pub(crate) fn new(
|
||||
close_sender: Sender<(u8, String)>,
|
||||
name: String,
|
||||
params: Vec<Type>,
|
||||
columns: Arc<Vec<Column>>,
|
||||
) -> Statement {
|
||||
Statement {
|
||||
close_sender: close_sender,
|
||||
name: name,
|
||||
params: params,
|
||||
columns: columns,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn columns_arc(&self) -> &Arc<Vec<Column>> {
|
||||
&self.columns
|
||||
}
|
||||
|
||||
pub(crate) fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
/// Returns the types of query parameters for this statement.
|
||||
pub fn parameters(&self) -> &[Type] {
|
||||
&self.params
|
||||
}
|
||||
|
||||
/// Returns information about the resulting columns for this statement.
|
||||
pub fn columns(&self) -> &[Column] {
|
||||
&self.columns
|
||||
}
|
||||
}
|
@ -1,221 +0,0 @@
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use futures::future::Either;
|
||||
use futures::{Future, IntoFuture, Poll, Sink, Stream as FuturesStream};
|
||||
use postgres_protocol::message::backend;
|
||||
use postgres_protocol::message::frontend;
|
||||
use postgres_shared::params::Host;
|
||||
use std::io::{self, Read, Write};
|
||||
use std::time::Duration;
|
||||
use tokio_core::net::TcpStream;
|
||||
use tokio_core::reactor::Handle;
|
||||
use tokio_dns;
|
||||
use tokio_io::codec::{Decoder, Encoder, Framed};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
#[cfg(unix)]
|
||||
use tokio_uds::UnixStream;
|
||||
|
||||
use error;
|
||||
use tls::TlsStream;
|
||||
use {BoxedFuture, Error, TlsMode};
|
||||
|
||||
pub type PostgresStream = Framed<Box<TlsStream>, PostgresCodec>;
|
||||
|
||||
pub fn connect(
|
||||
host: Host,
|
||||
port: u16,
|
||||
keepalive: Option<Duration>,
|
||||
tls_mode: TlsMode,
|
||||
handle: &Handle,
|
||||
) -> Box<Future<Item = PostgresStream, Error = Error> + Send> {
|
||||
let inner = match host {
|
||||
Host::Tcp(ref host) => Either::A(
|
||||
tokio_dns::tcp_connect((&**host, port), handle.remote().clone())
|
||||
.and_then(move |s| match keepalive {
|
||||
Some(keepalive) => s.set_keepalive(Some(keepalive)).map(|_| s),
|
||||
None => Ok(s),
|
||||
})
|
||||
.map(|s| Stream(InnerStream::Tcp(s)))
|
||||
.map_err(error::io),
|
||||
),
|
||||
#[cfg(unix)]
|
||||
Host::Unix(ref host) => {
|
||||
let addr = host.join(format!(".s.PGSQL.{}", port));
|
||||
Either::B(
|
||||
UnixStream::connect(addr, handle)
|
||||
.map(|s| Stream(InnerStream::Unix(s)))
|
||||
.map_err(error::io)
|
||||
.into_future(),
|
||||
)
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
Host::Unix(_) => Either::B(
|
||||
Err(error::connect(
|
||||
"unix sockets are not supported on this platform".into(),
|
||||
)).into_future(),
|
||||
),
|
||||
};
|
||||
|
||||
let (required, handshaker) = match tls_mode {
|
||||
TlsMode::Require(h) => (true, h),
|
||||
TlsMode::Prefer(h) => (false, h),
|
||||
TlsMode::None => {
|
||||
return inner
|
||||
.map(|s| {
|
||||
let s: Box<TlsStream> = Box::new(s);
|
||||
s.framed(PostgresCodec)
|
||||
})
|
||||
.boxed2()
|
||||
}
|
||||
};
|
||||
|
||||
inner
|
||||
.map(|s| s.framed(SslCodec))
|
||||
.and_then(|s| {
|
||||
let mut buf = vec![];
|
||||
frontend::ssl_request(&mut buf);
|
||||
s.send(buf).map_err(error::io)
|
||||
})
|
||||
.and_then(|s| s.into_future().map_err(|e| error::io(e.0)))
|
||||
.and_then(move |(m, s)| {
|
||||
let s = s.into_inner();
|
||||
match (m, required) {
|
||||
(Some(b'N'), true) => Either::A(
|
||||
Err(error::tls("the server does not support TLS".into())).into_future(),
|
||||
),
|
||||
(Some(b'N'), false) => {
|
||||
let s: Box<TlsStream> = Box::new(s);
|
||||
Either::A(Ok(s).into_future())
|
||||
}
|
||||
(None, _) => Either::A(
|
||||
Err(error::io(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
"unexpected EOF",
|
||||
))).into_future(),
|
||||
),
|
||||
_ => {
|
||||
let host = match host {
|
||||
Host::Tcp(ref host) => host,
|
||||
Host::Unix(_) => unreachable!(),
|
||||
};
|
||||
Either::B(handshaker.handshake(host, s).map_err(error::tls))
|
||||
}
|
||||
}
|
||||
})
|
||||
.map(|s| s.framed(PostgresCodec))
|
||||
.boxed2()
|
||||
}
|
||||
|
||||
/// A raw connection to the database.
|
||||
pub struct Stream(InnerStream);
|
||||
|
||||
enum InnerStream {
|
||||
Tcp(TcpStream),
|
||||
#[cfg(unix)]
|
||||
Unix(UnixStream),
|
||||
}
|
||||
|
||||
impl Read for Stream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self.0 {
|
||||
InnerStream::Tcp(ref mut s) => s.read(buf),
|
||||
#[cfg(unix)]
|
||||
InnerStream::Unix(ref mut s) => s.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for Stream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self.0 {
|
||||
InnerStream::Tcp(ref mut s) => s.write(buf),
|
||||
#[cfg(unix)]
|
||||
InnerStream::Unix(ref mut s) => s.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self.0 {
|
||||
InnerStream::Tcp(ref mut s) => s.flush(),
|
||||
#[cfg(unix)]
|
||||
InnerStream::Unix(ref mut s) => s.flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for Stream {
|
||||
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
|
||||
match self.0 {
|
||||
InnerStream::Tcp(ref s) => s.prepare_uninitialized_buffer(buf),
|
||||
#[cfg(unix)]
|
||||
InnerStream::Unix(ref s) => s.prepare_uninitialized_buffer(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
|
||||
where
|
||||
B: BufMut,
|
||||
{
|
||||
match self.0 {
|
||||
InnerStream::Tcp(ref mut s) => s.read_buf(buf),
|
||||
#[cfg(unix)]
|
||||
InnerStream::Unix(ref mut s) => s.read_buf(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for Stream {
|
||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
||||
match self.0 {
|
||||
InnerStream::Tcp(ref mut s) => s.shutdown(),
|
||||
#[cfg(unix)]
|
||||
InnerStream::Unix(ref mut s) => s.shutdown(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresCodec;
|
||||
|
||||
impl Decoder for PostgresCodec {
|
||||
type Item = backend::Message;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<backend::Message>> {
|
||||
backend::Message::parse(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder for PostgresCodec {
|
||||
type Item = Vec<u8>;
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, msg: Vec<u8>, buf: &mut BytesMut) -> io::Result<()> {
|
||||
buf.extend(&msg);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct SslCodec;
|
||||
|
||||
impl Decoder for SslCodec {
|
||||
type Item = u8;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<u8>> {
|
||||
if buf.is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(buf.split_to(1)[0]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder for SslCodec {
|
||||
type Item = Vec<u8>;
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, msg: Vec<u8>, buf: &mut BytesMut) -> io::Result<()> {
|
||||
buf.extend(&msg);
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -1,481 +0,0 @@
|
||||
use futures::{Future, Stream};
|
||||
use futures_state_stream::StateStream;
|
||||
use std::error::Error as StdError;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
use tokio_core::reactor::{Core, Interval};
|
||||
|
||||
use super::*;
|
||||
use error::SqlState;
|
||||
use params::{ConnectParams, Host};
|
||||
use types::{FromSql, IsNull, Kind, ToSql, Type};
|
||||
|
||||
#[test]
|
||||
fn md5_user() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
let done = Connection::connect(
|
||||
"postgres://md5_user:password@localhost:5433/postgres",
|
||||
TlsMode::None,
|
||||
&handle,
|
||||
);
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn md5_user_no_pass() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
let done = Connection::connect(
|
||||
"postgres://md5_user@localhost:5433/postgres",
|
||||
TlsMode::None,
|
||||
&handle,
|
||||
);
|
||||
match l.run(done) {
|
||||
Err(ref e) if e.as_connection().is_some() => {}
|
||||
Err(e) => panic!("unexpected error {}", e),
|
||||
Ok(_) => panic!("unexpected success"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn md5_user_wrong_pass() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
let done = Connection::connect(
|
||||
"postgres://md5_user:foobar@localhost:5433/postgres",
|
||||
TlsMode::None,
|
||||
&handle,
|
||||
);
|
||||
match l.run(done) {
|
||||
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
|
||||
Err(e) => panic!("unexpected error {}", e),
|
||||
Ok(_) => panic!("unexpected success"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pass_user() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
let done = Connection::connect(
|
||||
"postgres://pass_user:password@localhost:5433/postgres",
|
||||
TlsMode::None,
|
||||
&handle,
|
||||
);
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pass_user_no_pass() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
let done = Connection::connect(
|
||||
"postgres://pass_user@localhost:5433/postgres",
|
||||
TlsMode::None,
|
||||
&handle,
|
||||
);
|
||||
match l.run(done) {
|
||||
Err(ref e) if e.as_connection().is_some() => {}
|
||||
Err(e) => panic!("unexpected error {}", e),
|
||||
Ok(_) => panic!("unexpected success"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pass_user_wrong_pass() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
let done = Connection::connect(
|
||||
"postgres://pass_user:foobar@localhost:5433/postgres",
|
||||
TlsMode::None,
|
||||
&handle,
|
||||
);
|
||||
match l.run(done) {
|
||||
Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {}
|
||||
Err(e) => panic!("unexpected error {}", e),
|
||||
Ok(_) => panic!("unexpected success"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn batch_execute_ok() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let done = Connection::connect(
|
||||
"postgres://postgres@localhost:5433",
|
||||
TlsMode::None,
|
||||
&l.handle(),
|
||||
).then(|c| {
|
||||
c.unwrap()
|
||||
.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL);")
|
||||
});
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn batch_execute_err() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let done = Connection::connect(
|
||||
"postgres://postgres@localhost:5433",
|
||||
TlsMode::None,
|
||||
&l.handle(),
|
||||
).then(|r| {
|
||||
r.unwrap().batch_execute(
|
||||
"CREATE TEMPORARY TABLE foo (id SERIAL); INSERT INTO foo DEFAULT \
|
||||
VALUES;",
|
||||
)
|
||||
})
|
||||
.and_then(|c| c.batch_execute("SELECT * FROM bogo"))
|
||||
.then(|r| match r {
|
||||
Err((e, s)) => {
|
||||
assert_eq!(e.code(), Some(&SqlState::UNDEFINED_TABLE));
|
||||
s.batch_execute("SELECT * FROM foo")
|
||||
}
|
||||
Ok(_) => panic!("unexpected success"),
|
||||
});
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prepare_execute() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let done = Connection::connect(
|
||||
"postgres://postgres@localhost:5433",
|
||||
TlsMode::None,
|
||||
&l.handle(),
|
||||
).then(|c| {
|
||||
c.unwrap()
|
||||
.prepare("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY, name VARCHAR)")
|
||||
})
|
||||
.and_then(|(s, c)| c.execute(&s, &[]))
|
||||
.and_then(|(n, c)| {
|
||||
assert_eq!(0, n);
|
||||
c.prepare("INSERT INTO foo (name) VALUES ($1), ($2)")
|
||||
})
|
||||
.and_then(|(s, c)| c.execute(&s, &[&"steven", &"bob"]))
|
||||
.map(|(n, _)| assert_eq!(n, 2));
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prepare_execute_rows() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let done = Connection::connect(
|
||||
"postgres://postgres@localhost:5433",
|
||||
TlsMode::None,
|
||||
&l.handle(),
|
||||
).then(|c| c.unwrap().prepare("SELECT 1"))
|
||||
.and_then(|(s, c)| c.execute(&s, &[]));
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn query() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let done = Connection::connect(
|
||||
"postgres://postgres@localhost:5433",
|
||||
TlsMode::None,
|
||||
&l.handle(),
|
||||
).then(|c| {
|
||||
c.unwrap().batch_execute(
|
||||
"CREATE TEMPORARY TABLE foo (id SERIAL, name VARCHAR);
|
||||
INSERT INTO foo (name) VALUES ('joe'), ('bob')",
|
||||
)
|
||||
})
|
||||
.and_then(|c| c.prepare("SELECT id, name FROM foo ORDER BY id"))
|
||||
.and_then(|(s, c)| c.query(&s, &[]).collect())
|
||||
.and_then(|(r, c)| {
|
||||
assert_eq!(r[0].get::<i32, _>("id"), 1);
|
||||
assert_eq!(r[0].get::<String, _>("name"), "joe");
|
||||
assert_eq!(r[1].get::<i32, _>("id"), 2);
|
||||
assert_eq!(r[1].get::<String, _>("name"), "bob");
|
||||
c.prepare("")
|
||||
})
|
||||
.and_then(|(s, c)| c.query(&s, &[]).collect())
|
||||
.map(|(r, _)| assert!(r.is_empty()));
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn transaction() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let done = Connection::connect(
|
||||
"postgres://postgres@localhost:5433",
|
||||
TlsMode::None,
|
||||
&l.handle(),
|
||||
).then(|c| {
|
||||
c.unwrap()
|
||||
.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name VARCHAR);")
|
||||
})
|
||||
.then(|c| c.unwrap().transaction())
|
||||
.then(|t| {
|
||||
t.unwrap()
|
||||
.batch_execute("INSERT INTO foo (name) VALUES ('joe');")
|
||||
})
|
||||
.then(|t| t.unwrap().rollback())
|
||||
.then(|c| c.unwrap().transaction())
|
||||
.then(|t| {
|
||||
t.unwrap()
|
||||
.batch_execute("INSERT INTO foo (name) VALUES ('bob');")
|
||||
})
|
||||
.then(|t| t.unwrap().commit())
|
||||
.then(|c| c.unwrap().prepare("SELECT name FROM foo"))
|
||||
.and_then(|(s, c)| c.query(&s, &[]).collect())
|
||||
.map(|(r, _)| {
|
||||
assert_eq!(r.len(), 1);
|
||||
assert_eq!(r[0].get::<String, _>("name"), "bob");
|
||||
});
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // not supported on our CI setup :(
|
||||
fn unix_socket() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
let done = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None, &handle)
|
||||
.then(|c| c.unwrap().prepare("SHOW unix_socket_directories"))
|
||||
.and_then(|(s, c)| c.query(&s, &[]).collect())
|
||||
.then(|r| {
|
||||
let r = r.unwrap().0;
|
||||
let params = ConnectParams::builder()
|
||||
.user("postgres", None)
|
||||
.build(Host::Unix(PathBuf::from(r[0].get::<String, _>(0))));
|
||||
Connection::connect(params, TlsMode::None, &handle)
|
||||
})
|
||||
.then(|c| c.unwrap().batch_execute(""));
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssl_user_ssl_required() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
|
||||
let done = Connection::connect(
|
||||
"postgres://ssl_user@localhost:5433/postgres",
|
||||
TlsMode::None,
|
||||
&handle,
|
||||
);
|
||||
|
||||
match l.run(done) {
|
||||
Err(ref e) => assert_eq!(
|
||||
e.code(),
|
||||
Some(&SqlState::INVALID_AUTHORIZATION_SPECIFICATION)
|
||||
),
|
||||
Ok(_) => panic!("unexpected success"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "with-openssl")]
|
||||
#[test]
|
||||
fn openssl_required() {
|
||||
use tls::openssl::openssl::ssl::{SslConnector, SslMethod};
|
||||
use tls::openssl::OpenSsl;
|
||||
|
||||
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
|
||||
builder.set_ca_file("../test/server.crt").unwrap();
|
||||
let negotiator = OpenSsl::from(builder.build());
|
||||
|
||||
let mut l = Core::new().unwrap();
|
||||
let done = Connection::connect(
|
||||
"postgres://ssl_user@localhost:5433/postgres",
|
||||
TlsMode::Require(Box::new(negotiator)),
|
||||
&l.handle(),
|
||||
).then(|c| c.unwrap().prepare("SELECT 1"))
|
||||
.and_then(|(s, c)| c.query(&s, &[]).collect())
|
||||
.map(|(r, _)| assert_eq!(r[0].get::<i32, _>(0), 1));
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn domain() {
|
||||
#[derive(Debug, PartialEq)]
|
||||
struct SessionId(Vec<u8>);
|
||||
|
||||
impl ToSql for SessionId {
|
||||
fn to_sql(
|
||||
&self,
|
||||
ty: &Type,
|
||||
out: &mut Vec<u8>,
|
||||
) -> Result<IsNull, Box<StdError + Sync + Send>> {
|
||||
let inner = match *ty.kind() {
|
||||
Kind::Domain(ref inner) => inner,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.0.to_sql(inner, out)
|
||||
}
|
||||
|
||||
fn accepts(ty: &Type) -> bool {
|
||||
match *ty.kind() {
|
||||
Kind::Domain(Type::BYTEA) => ty.name() == "session_id",
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
to_sql_checked!();
|
||||
}
|
||||
|
||||
impl<'a> FromSql<'a> for SessionId {
|
||||
fn from_sql(ty: &Type, raw: &[u8]) -> Result<Self, Box<StdError + Sync + Send>> {
|
||||
Vec::<u8>::from_sql(ty, raw).map(SessionId)
|
||||
}
|
||||
|
||||
fn accepts(ty: &Type) -> bool {
|
||||
// This is super weird!
|
||||
<Vec<u8> as FromSql>::accepts(ty)
|
||||
}
|
||||
}
|
||||
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
let done = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None, &handle)
|
||||
.then(|c| {
|
||||
c.unwrap().batch_execute(
|
||||
"CREATE DOMAIN pg_temp.session_id AS bytea \
|
||||
CHECK(octet_length(VALUE) = 16);
|
||||
CREATE \
|
||||
TABLE pg_temp.foo (id pg_temp.session_id);",
|
||||
)
|
||||
})
|
||||
.and_then(|c| c.prepare("INSERT INTO pg_temp.foo (id) VALUES ($1)"))
|
||||
.and_then(|(s, c)| {
|
||||
let id = SessionId(b"0123456789abcdef".to_vec());
|
||||
c.execute(&s, &[&id])
|
||||
})
|
||||
.and_then(|(_, c)| c.prepare("SELECT id FROM pg_temp.foo"))
|
||||
.and_then(|(s, c)| c.query(&s, &[]).collect())
|
||||
.map(|(r, _)| {
|
||||
let id = SessionId(b"0123456789abcdef".to_vec());
|
||||
assert_eq!(id, r[0].get(0));
|
||||
});
|
||||
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn composite() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
|
||||
let done = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None, &handle)
|
||||
.then(|c| {
|
||||
c.unwrap().batch_execute(
|
||||
"CREATE TYPE pg_temp.inventory_item AS (
|
||||
name TEXT,
|
||||
supplier INTEGER,
|
||||
price NUMERIC
|
||||
)",
|
||||
)
|
||||
})
|
||||
.and_then(|c| c.prepare("SELECT $1::inventory_item"))
|
||||
.map(|(s, _)| {
|
||||
let type_ = &s.parameters()[0];
|
||||
assert_eq!(type_.name(), "inventory_item");
|
||||
match *type_.kind() {
|
||||
Kind::Composite(ref fields) => {
|
||||
assert_eq!(fields[0].name(), "name");
|
||||
assert_eq!(fields[0].type_(), &Type::TEXT);
|
||||
assert_eq!(fields[1].name(), "supplier");
|
||||
assert_eq!(fields[1].type_(), &Type::INT4);
|
||||
assert_eq!(fields[2].name(), "price");
|
||||
assert_eq!(fields[2].type_(), &Type::NUMERIC);
|
||||
}
|
||||
ref t => panic!("bad type {:?}", t),
|
||||
}
|
||||
});
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enum_() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
|
||||
let done = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None, &handle)
|
||||
.then(|c| {
|
||||
c.unwrap()
|
||||
.batch_execute("CREATE TYPE pg_temp.mood AS ENUM ('sad', 'ok', 'happy');")
|
||||
})
|
||||
.and_then(|c| c.prepare("SELECT $1::mood"))
|
||||
.map(|(s, _)| {
|
||||
let type_ = &s.parameters()[0];
|
||||
assert_eq!(type_.name(), "mood");
|
||||
match *type_.kind() {
|
||||
Kind::Enum(ref variants) => {
|
||||
assert_eq!(
|
||||
variants,
|
||||
&["sad".to_owned(), "ok".to_owned(), "happy".to_owned()]
|
||||
);
|
||||
}
|
||||
_ => panic!("bad type"),
|
||||
}
|
||||
});
|
||||
|
||||
l.run(done).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cancel() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
|
||||
let done = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None, &handle)
|
||||
.then(move |c| {
|
||||
let c = c.unwrap();
|
||||
let cancel_data = c.cancel_data();
|
||||
let cancel = Interval::new(Duration::from_secs(1), &handle)
|
||||
.unwrap()
|
||||
.into_future()
|
||||
.then(move |r| {
|
||||
assert!(r.is_ok());
|
||||
cancel_query(
|
||||
"postgres://postgres@localhost:5433",
|
||||
TlsMode::None,
|
||||
cancel_data,
|
||||
&handle,
|
||||
)
|
||||
})
|
||||
.then(Ok::<_, ()>);
|
||||
c.batch_execute("SELECT pg_sleep(10)")
|
||||
.then(Ok::<_, ()>)
|
||||
.join(cancel)
|
||||
});
|
||||
|
||||
let (select, cancel) = l.run(done).unwrap();
|
||||
cancel.unwrap();
|
||||
match select {
|
||||
Err((e, _)) => assert_eq!(e.code(), Some(&SqlState::QUERY_CANCELED)),
|
||||
Ok(_) => panic!("unexpected success"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn notifications() {
|
||||
let mut l = Core::new().unwrap();
|
||||
let handle = l.handle();
|
||||
|
||||
let done = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None, &handle)
|
||||
.then(|c| c.unwrap().batch_execute("LISTEN test_notifications"))
|
||||
.and_then(|c1| {
|
||||
Connection::connect("postgres://postgres@localhost:5433", TlsMode::None, &handle).then(
|
||||
|c2| {
|
||||
c2.unwrap()
|
||||
.batch_execute("NOTIFY test_notifications, 'foo'")
|
||||
.map(|_| c1)
|
||||
},
|
||||
)
|
||||
})
|
||||
.and_then(|c| {
|
||||
c.notifications()
|
||||
.into_future()
|
||||
.map_err(|(e, n)| (e, n.into_inner()))
|
||||
})
|
||||
.map(|(n, _)| {
|
||||
let n = n.unwrap();
|
||||
assert_eq!(n.channel, "test_notifications");
|
||||
assert_eq!(n.payload, "foo");
|
||||
});
|
||||
|
||||
l.run(done).unwrap();
|
||||
}
|
@ -1,39 +0,0 @@
|
||||
//! TLS support.
|
||||
|
||||
use futures::Future;
|
||||
use std::error::Error;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
pub use stream::Stream;
|
||||
|
||||
#[cfg(feature = "with-openssl")]
|
||||
pub mod openssl;
|
||||
|
||||
/// A trait implemented by streams returned from `Handshake` implementations.
|
||||
pub trait TlsStream: AsyncRead + AsyncWrite + Send {
|
||||
/// Returns a shared reference to the inner stream.
|
||||
fn get_ref(&self) -> &Stream;
|
||||
|
||||
/// Returns a mutable reference to the inner stream.
|
||||
fn get_mut(&mut self) -> &mut Stream;
|
||||
}
|
||||
|
||||
impl TlsStream for Stream {
|
||||
fn get_ref(&self) -> &Stream {
|
||||
self
|
||||
}
|
||||
|
||||
fn get_mut(&mut self) -> &mut Stream {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// A trait implemented by types that can manage TLS encryption for a stream.
|
||||
pub trait Handshake: 'static + Sync + Send {
|
||||
/// Performs a TLS handshake, returning a wrapped stream.
|
||||
fn handshake(
|
||||
&self,
|
||||
host: &str,
|
||||
stream: Stream,
|
||||
) -> Box<Future<Item = Box<TlsStream>, Error = Box<Error + Sync + Send>> + Send>;
|
||||
}
|
@ -1,59 +0,0 @@
|
||||
//! OpenSSL support.
|
||||
extern crate tokio_openssl;
|
||||
pub extern crate openssl;
|
||||
|
||||
use futures::Future;
|
||||
use self::openssl::ssl::{SslMethod, SslConnector};
|
||||
use self::openssl::error::ErrorStack;
|
||||
use std::error::Error;
|
||||
use self::tokio_openssl::{SslConnectorExt, SslStream};
|
||||
|
||||
use BoxedFuture;
|
||||
use tls::{Stream, TlsStream, Handshake};
|
||||
|
||||
impl TlsStream for SslStream<Stream> {
|
||||
fn get_ref(&self) -> &Stream {
|
||||
self.get_ref().get_ref()
|
||||
}
|
||||
|
||||
fn get_mut(&mut self) -> &mut Stream {
|
||||
self.get_mut().get_mut()
|
||||
}
|
||||
}
|
||||
|
||||
/// A `Handshake` implementation using OpenSSL.
|
||||
pub struct OpenSsl(SslConnector);
|
||||
|
||||
impl OpenSsl {
|
||||
/// Creates a new `OpenSsl` with default settings.
|
||||
pub fn new() -> Result<OpenSsl, ErrorStack> {
|
||||
let connector = SslConnector::builder(SslMethod::tls())?.build();
|
||||
Ok(OpenSsl(connector))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SslConnector> for OpenSsl {
|
||||
fn from(connector: SslConnector) -> OpenSsl {
|
||||
OpenSsl(connector)
|
||||
}
|
||||
}
|
||||
|
||||
impl Handshake for OpenSsl {
|
||||
fn handshake(
|
||||
&self,
|
||||
host: &str,
|
||||
stream: Stream,
|
||||
) -> Box<Future<Item = Box<TlsStream>, Error = Box<Error + Sync + Send>> + Send> {
|
||||
self.0
|
||||
.connect_async(host, stream)
|
||||
.map(|s| {
|
||||
let s: Box<TlsStream> = Box::new(s);
|
||||
s
|
||||
})
|
||||
.map_err(|e| {
|
||||
let e: Box<Error + Sync + Send> = Box::new(e);
|
||||
e
|
||||
})
|
||||
.boxed2()
|
||||
}
|
||||
}
|
@ -1,90 +0,0 @@
|
||||
//! Transactions.
|
||||
|
||||
use futures::Future;
|
||||
use futures_state_stream::StateStream;
|
||||
|
||||
use {Connection, BoxedFuture, BoxedStateStream};
|
||||
use error::Error;
|
||||
use stmt::Statement;
|
||||
use types::ToSql;
|
||||
use rows::Row;
|
||||
|
||||
/// An in progress Postgres transaction.
|
||||
#[derive(Debug)]
|
||||
pub struct Transaction(Connection);
|
||||
|
||||
impl Transaction {
|
||||
pub(crate) fn new(c: Connection) -> Transaction {
|
||||
Transaction(c)
|
||||
}
|
||||
|
||||
/// Like `Connection::batch_execute`.
|
||||
pub fn batch_execute(
|
||||
self,
|
||||
query: &str,
|
||||
) -> Box<Future<Item = Transaction, Error = (Error, Transaction)> + Send> {
|
||||
self.0
|
||||
.batch_execute(query)
|
||||
.map(Transaction)
|
||||
.map_err(transaction_err)
|
||||
.boxed2()
|
||||
}
|
||||
|
||||
/// Like `Connection::prepare`.
|
||||
pub fn prepare(
|
||||
self,
|
||||
query: &str,
|
||||
) -> Box<Future<Item = (Statement, Transaction), Error = (Error, Transaction)> + Send> {
|
||||
self.0
|
||||
.prepare(query)
|
||||
.map(|(s, c)| (s, Transaction(c)))
|
||||
.map_err(transaction_err)
|
||||
.boxed2()
|
||||
}
|
||||
|
||||
/// Like `Connection::execute`.
|
||||
pub fn execute(
|
||||
self,
|
||||
statement: &Statement,
|
||||
params: &[&ToSql],
|
||||
) -> Box<Future<Item = (u64, Transaction), Error = (Error, Transaction)> + Send> {
|
||||
self.0
|
||||
.execute(statement, params)
|
||||
.map(|(n, c)| (n, Transaction(c)))
|
||||
.map_err(transaction_err)
|
||||
.boxed2()
|
||||
}
|
||||
|
||||
/// Like `Connection::query`.
|
||||
pub fn query(
|
||||
self,
|
||||
statement: &Statement,
|
||||
params: &[&ToSql],
|
||||
) -> Box<StateStream<Item = Row, State = Transaction, Error = Error> + Send> {
|
||||
self.0
|
||||
.query(statement, params)
|
||||
.map_state(Transaction)
|
||||
.boxed2()
|
||||
}
|
||||
|
||||
/// Commits the transaction.
|
||||
pub fn commit(self) -> Box<Future<Item = Connection, Error = (Error, Connection)> + Send> {
|
||||
self.finish("COMMIT")
|
||||
}
|
||||
|
||||
/// Rolls back the transaction.
|
||||
pub fn rollback(self) -> Box<Future<Item = Connection, Error = (Error, Connection)> + Send> {
|
||||
self.finish("ROLLBACK")
|
||||
}
|
||||
|
||||
fn finish(
|
||||
self,
|
||||
query: &str,
|
||||
) -> Box<Future<Item = Connection, Error = (Error, Connection)> + Send> {
|
||||
self.0.simple_query(query).map(|(_, c)| c).boxed2()
|
||||
}
|
||||
}
|
||||
|
||||
fn transaction_err((e, c): (Error, Connection)) -> (Error, Transaction) {
|
||||
(e, Transaction(c))
|
||||
}
|
Loading…
Reference in New Issue
Block a user