diff --git a/.circleci/config.yml b/.circleci/config.yml index 324151f1..68edc5c1 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -37,5 +37,6 @@ jobs: - run: cargo fmt --all -- --check - run: cargo clippy --all - run: cargo test --all + - run: cargo test --manifest-path tokio-postgres/Cargo.toml --no-default-features - run: cargo test --manifest-path tokio-postgres/Cargo.toml --all-features - *SAVE_DEPS diff --git a/tokio-postgres-native-tls/Cargo.toml b/tokio-postgres-native-tls/Cargo.toml index 9ecfb2c3..6ba2d66e 100644 --- a/tokio-postgres-native-tls/Cargo.toml +++ b/tokio-postgres-native-tls/Cargo.toml @@ -9,7 +9,7 @@ futures = "0.1" native-tls = "0.2" tokio-io = "0.1" tokio-tls = "0.2" -tokio-postgres = { version = "0.3", path = "../tokio-postgres" } +tokio-postgres = { version = "0.3", path = "../tokio-postgres", default-features = false } [dev-dependencies] tokio = "0.1.7" diff --git a/tokio-postgres-native-tls/src/test.rs b/tokio-postgres-native-tls/src/test.rs index 6c2a8ac6..8e21bf0d 100644 --- a/tokio-postgres-native-tls/src/test.rs +++ b/tokio-postgres-native-tls/src/test.rs @@ -15,7 +15,7 @@ where let handshake = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap()) .map_err(|e| panic!("{}", e)) - .and_then(|s| builder.connect(s, tls)); + .and_then(|s| builder.handshake(s, tls)); let (mut client, connection) = runtime.block_on(handshake).unwrap(); let connection = connection.map_err(|e| panic!("{}", e)); runtime.spawn(connection); diff --git a/tokio-postgres-openssl/Cargo.toml b/tokio-postgres-openssl/Cargo.toml index 3903050e..c875b118 100644 --- a/tokio-postgres-openssl/Cargo.toml +++ b/tokio-postgres-openssl/Cargo.toml @@ -9,7 +9,7 @@ futures = "0.1" openssl = "0.10" tokio-io = "0.1" tokio-openssl = "0.3" -tokio-postgres = { version = "0.3", path = "../tokio-postgres" } +tokio-postgres = { version = "0.3", path = "../tokio-postgres", default-features = false } [dev-dependencies] tokio = "0.1.7" diff --git a/tokio-postgres-openssl/src/test.rs b/tokio-postgres-openssl/src/test.rs index db58b448..6729916b 100644 --- a/tokio-postgres-openssl/src/test.rs +++ b/tokio-postgres-openssl/src/test.rs @@ -15,7 +15,7 @@ where let handshake = TcpStream::connect(&"127.0.0.1:5433".parse().unwrap()) .map_err(|e| panic!("{}", e)) - .and_then(|s| builder.connect(s, tls)); + .and_then(|s| builder.handshake(s, tls)); let (mut client, connection) = runtime.block_on(handshake).unwrap(); let connection = connection.map_err(|e| panic!("{}", e)); runtime.spawn(connection); diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index 0b9eb907..c317c5d6 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -27,6 +27,9 @@ features = [ circle-ci = { repository = "sfackler/rust-postgres" } [features] +default = ["runtime"] +runtime = ["tokio-tcp", "tokio-uds"] + "with-bit-vec-0.5" = ["bit-vec-05"] "with-chrono-0.4" = ["chrono-04"] "with-eui48-0.4" = ["eui48-04"] @@ -48,6 +51,8 @@ tokio-codec = "0.1" tokio-io = "0.1" void = "1.0" +tokio-tcp = { version = "0.1", optional = true } + bit-vec-05 = { version = "0.5", package = "bit-vec", optional = true } chrono-04 = { version = "0.4", package = "chrono", optional = true } eui48-04 = { version = "0.4", package = "eui48", optional = true } @@ -56,6 +61,9 @@ serde-1 = { version = "1.0", package = "serde", optional = true } serde_json-1 = { version = "1.0", package = "serde_json", optional = true } uuid-07 = { version = "0.7", package = "uuid", optional = true } +[target.'cfg(unix)'.dependencies] +tokio-uds = { version = "0.2", optional = true } + [dev-dependencies] tokio = "0.1.7" env_logger = "0.5" diff --git a/tokio-postgres/src/builder.rs b/tokio-postgres/src/builder.rs index f2919c6b..85981e9a 100644 --- a/tokio-postgres/src/builder.rs +++ b/tokio-postgres/src/builder.rs @@ -3,8 +3,8 @@ use std::iter; use std::str::{self, FromStr}; use tokio_io::{AsyncRead, AsyncWrite}; -use crate::proto::ConnectFuture; -use crate::{Connect, Error, TlsMode}; +use crate::proto::HandshakeFuture; +use crate::{Error, Handshake, TlsMode}; #[derive(Clone)] pub struct Builder { @@ -48,12 +48,12 @@ impl Builder { Iter(self.params.iter()) } - pub fn connect(&self, stream: S, tls_mode: T) -> Connect + pub fn handshake(&self, stream: S, tls_mode: T) -> Handshake where S: AsyncRead + AsyncWrite, T: TlsMode, { - Connect(ConnectFuture::new(stream, tls_mode, self.params.clone())) + Handshake(HandshakeFuture::new(stream, tls_mode, self.params.clone())) } } diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index e1f86ce9..c9d77d1c 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -10,6 +10,8 @@ pub use crate::builder::*; pub use crate::error::*; use crate::proto::CancelFuture; pub use crate::row::{Row, RowIndex}; +#[cfg(feature = "runtime")] +pub use crate::socket::Socket; pub use crate::stmt::Column; pub use crate::tls::*; use crate::types::{ToSql, Type}; @@ -18,6 +20,8 @@ mod builder; pub mod error; mod proto; mod row; +#[cfg(feature = "runtime")] +mod socket; mod stmt; mod tls; pub mod types; @@ -156,12 +160,12 @@ where } #[must_use = "futures do nothing unless polled"] -pub struct Connect(proto::ConnectFuture) +pub struct Handshake(proto::HandshakeFuture) where S: AsyncRead + AsyncWrite, T: TlsMode; -impl Future for Connect +impl Future for Handshake where S: AsyncRead + AsyncWrite, T: TlsMode, diff --git a/tokio-postgres/src/proto/connect.rs b/tokio-postgres/src/proto/handshake.rs similarity index 98% rename from tokio-postgres/src/proto/connect.rs rename to tokio-postgres/src/proto/handshake.rs index c2dde72b..3089ffc1 100644 --- a/tokio-postgres/src/proto/connect.rs +++ b/tokio-postgres/src/proto/handshake.rs @@ -16,7 +16,7 @@ use crate::proto::{Client, Connection, PostgresCodec, TlsFuture}; use crate::{CancelData, ChannelBinding, Error, TlsMode}; #[derive(StateMachineFuture)] -pub enum Connect +pub enum Handshake where S: AsyncRead + AsyncWrite, T: TlsMode, @@ -70,7 +70,7 @@ where Failed(Error), } -impl PollConnect for Connect +impl PollHandshake for Handshake where S: AsyncRead + AsyncWrite, T: TlsMode, @@ -319,12 +319,12 @@ where } } -impl ConnectFuture +impl HandshakeFuture where S: AsyncRead + AsyncWrite, T: TlsMode, { - pub fn new(stream: S, tls_mode: T, params: HashMap) -> ConnectFuture { - Connect::start(TlsFuture::new(stream, tls_mode), params) + pub fn new(stream: S, tls_mode: T, params: HashMap) -> HandshakeFuture { + Handshake::start(TlsFuture::new(stream, tls_mode), params) } } diff --git a/tokio-postgres/src/proto/mod.rs b/tokio-postgres/src/proto/mod.rs index b1c82cb3..9d19fa0e 100644 --- a/tokio-postgres/src/proto/mod.rs +++ b/tokio-postgres/src/proto/mod.rs @@ -22,11 +22,11 @@ mod bind; mod cancel; mod client; mod codec; -mod connect; mod connection; mod copy_in; mod copy_out; mod execute; +mod handshake; mod portal; mod prepare; mod query; @@ -42,11 +42,11 @@ pub use crate::proto::bind::BindFuture; pub use crate::proto::cancel::CancelFuture; pub use crate::proto::client::Client; pub use crate::proto::codec::PostgresCodec; -pub use crate::proto::connect::ConnectFuture; pub use crate::proto::connection::Connection; pub use crate::proto::copy_in::CopyInFuture; pub use crate::proto::copy_out::CopyOutStream; pub use crate::proto::execute::ExecuteFuture; +pub use crate::proto::handshake::HandshakeFuture; pub use crate::proto::portal::Portal; pub use crate::proto::prepare::PrepareFuture; pub use crate::proto::query::QueryStream; diff --git a/tokio-postgres/src/socket.rs b/tokio-postgres/src/socket.rs new file mode 100644 index 00000000..32d30b24 --- /dev/null +++ b/tokio-postgres/src/socket.rs @@ -0,0 +1,96 @@ +use bytes::{Buf, BufMut}; +use futures::Poll; +use std::io::{self, Read, Write}; +use tokio_io::{AsyncRead, AsyncWrite}; +use tokio_tcp::TcpStream; +#[cfg(unix)] +use tokio_uds::UnixStream; + +enum Inner { + Tcp(TcpStream), + #[cfg(unix)] + Unix(UnixStream), +} + +pub struct Socket(Inner); + +impl Socket { + pub(crate) fn new_tcp(stream: TcpStream) -> Socket { + Socket(Inner::Tcp(stream)) + } + + #[cfg(unix)] + pub(crate) fn new_unix(stream: UnixStream) -> Socket { + Socket(Inner::Unix(stream)) + } +} + +impl Read for Socket { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match &mut self.0 { + Inner::Tcp(s) => s.read(buf), + #[cfg(unix)] + Inner::Unix(s) => s.read(buf), + } + } +} + +impl AsyncRead for Socket { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + match &self.0 { + Inner::Tcp(s) => s.prepare_uninitialized_buffer(buf), + #[cfg(unix)] + Inner::Unix(s) => s.prepare_uninitialized_buffer(buf), + } + } + + fn read_buf(&mut self, buf: &mut B) -> Poll + where + B: BufMut, + { + match &mut self.0 { + Inner::Tcp(s) => s.read_buf(buf), + #[cfg(unix)] + Inner::Unix(s) => s.read_buf(buf), + } + } +} + +impl Write for Socket { + fn write(&mut self, buf: &[u8]) -> io::Result { + match &mut self.0 { + Inner::Tcp(s) => s.write(buf), + #[cfg(unix)] + Inner::Unix(s) => s.write(buf), + } + } + + fn flush(&mut self) -> io::Result<()> { + match &mut self.0 { + Inner::Tcp(s) => s.flush(), + #[cfg(unix)] + Inner::Unix(s) => s.flush(), + } + } +} + +impl AsyncWrite for Socket { + fn shutdown(&mut self) -> Poll<(), io::Error> { + match &mut self.0 { + Inner::Tcp(s) => s.shutdown(), + #[cfg(unix)] + Inner::Unix(s) => s.shutdown(), + } + } + + fn write_buf(&mut self, buf: &mut B) -> Poll + where + B: Buf, + { + match &mut self.0 { + Inner::Tcp(s) => s.write_buf(buf), + #[cfg(unix)] + Inner::Unix(s) => s.write_buf(buf), + } + } +} diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 15e39137..ad1736f0 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -22,7 +22,7 @@ fn connect( let builder = s.parse::().unwrap(); TcpStream::connect(&"127.0.0.1:5433".parse().unwrap()) .map_err(|e| panic!("{}", e)) - .and_then(move |s| builder.connect(s, NoTls)) + .and_then(move |s| builder.handshake(s, NoTls)) } fn smoke_test(s: &str) {