diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index bac19e56..de120ea9 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -1,7 +1,13 @@ //! Connection configuration. +#[cfg(feature = "runtime")] +use crate::connect::connect; use crate::connect_raw::connect_raw; +#[cfg(feature = "runtime")] +use crate::tls::MakeTlsConnect; use crate::tls::TlsConnect; +#[cfg(feature = "runtime")] +use crate::Socket; use crate::{Client, Connection, Error}; use std::borrow::Cow; #[cfg(unix)] @@ -367,6 +373,17 @@ impl Config { Ok(()) } + /// Opens a connection to a PostgreSQL database. + /// + /// Requires the `runtime` Cargo feature (enabled by default). + #[cfg(feature = "runtime")] + pub async fn connect(&self, tls: T) -> Result<(Client, Connection), Error> + where + T: MakeTlsConnect, + { + connect(tls, self).await + } + /// Connects to a PostgreSQL database over an arbitrary stream. /// /// All of the settings other than `user`, `password`, `dbname`, `options`, and `application` name are ignored. diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index 8b137891..8bb234d0 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -1 +1,60 @@ +use crate::config::{Host, TargetSessionAttrs}; +use crate::connect_raw::connect_raw; +use crate::connect_socket::connect_socket; +use crate::tls::{MakeTlsConnect, TlsConnect}; +use crate::{Client, Config, Connection, Error, Socket}; +pub async fn connect( + mut tls: T, + config: &Config, +) -> Result<(Client, Connection), Error> +where + T: MakeTlsConnect, +{ + if config.host.is_empty() { + return Err(Error::config("host missing".into())); + } + + if config.port.len() > 1 && config.port.len() != config.host.len() { + return Err(Error::config("invalid number of ports".into())); + } + + let mut error = None; + for (i, host) in config.host.iter().enumerate() { + let hostname = match host { + Host::Tcp(host) => &**host, + // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter + #[cfg(unix)] + Host::Unix(_) => "", + }; + + let tls = tls + .make_tls_connect(hostname) + .map_err(|e| Error::tls(e.into()))?; + + match connect_once(i, tls, config).await { + Ok((client, connection)) => return Ok((client, connection)), + Err(e) => error = Some(e), + } + } + + return Err(error.unwrap()); +} + +async fn connect_once( + idx: usize, + tls: T, + config: &Config, +) -> Result<(Client, Connection), Error> +where + T: TlsConnect, +{ + let socket = connect_socket(idx, config).await?; + let (client, connection) = connect_raw(socket, tls, config, Some(idx)).await?; + + if let TargetSessionAttrs::ReadWrite = config.target_session_attrs { + unimplemented!() + } + + Ok((client, connection)) +} diff --git a/tokio-postgres/src/connect_socket.rs b/tokio-postgres/src/connect_socket.rs new file mode 100644 index 00000000..d88edbc9 --- /dev/null +++ b/tokio-postgres/src/connect_socket.rs @@ -0,0 +1,74 @@ +use crate::config::Host; +use crate::{Config, Error, Socket}; +use std::future::Future; +use std::io; +use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; +use tokio::net::TcpStream; +#[cfg(unix)] +use tokio::net::UnixStream; +use tokio::timer::Timeout; + +pub async fn connect_socket(idx: usize, config: &Config) -> Result { + let port = *config + .port + .get(idx) + .or_else(|| config.port.get(0)) + .unwrap_or(&5432); + + match &config.host[idx] { + Host::Tcp(host) => { + let addrs = match host.parse::() { + Ok(ip) => { + // avoid dealing with blocking DNS entirely if possible + vec![SocketAddr::new(ip, port)].into_iter() + } + Err(_) => { + // FIXME what do? + (&**host, port).to_socket_addrs().map_err(Error::connect)? + } + }; + + let mut error = None; + for addr in addrs { + let new_error = match connect_timeout(TcpStream::connect(&addr), config).await { + Ok(socket) => return Ok(Socket::new_tcp(socket)), + Err(e) => e, + }; + error = Some(new_error); + } + + let error = error.unwrap_or_else(|| { + Error::connect(io::Error::new( + io::ErrorKind::InvalidData, + "resolved 0 addresses", + )) + }); + Err(error) + } + #[cfg(unix)] + Host::Unix(path) => { + let socket = connect_timeout(UnixStream::connect(path), config).await?; + Ok(Socket::new_unix(socket)) + } + } +} + +async fn connect_timeout(connect: F, config: &Config) -> Result +where + F: Future>, +{ + match config.connect_timeout { + Some(connect_timeout) => match Timeout::new(connect, connect_timeout).await { + Ok(Ok(socket)) => Ok(socket), + Ok(Err(e)) => Err(Error::connect(e)), + Err(_) => Err(Error::connect(io::Error::new( + io::ErrorKind::TimedOut, + "connection timed out", + ))), + }, + None => match connect.await { + Ok(socket) => Ok(socket), + Err(e) => Err(Error::connect(e)), + }, + } +} diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index c1bba074..ea4921ee 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -112,19 +112,48 @@ #![warn(rust_2018_idioms, clippy::all, missing_docs)] #![feature(async_await)] -pub use client::Client; -pub use config::Config; -pub use connection::Connection; -pub use error::Error; +pub use crate::client::Client; +pub use crate::config::Config; +pub use crate::connection::Connection; +pub use crate::error::Error; +#[cfg(feature = "runtime")] +pub use crate::socket::Socket; +#[cfg(feature = "runtime")] +use crate::tls::MakeTlsConnect; +pub use crate::tls::NoTls; mod client; mod codec; pub mod config; +#[cfg(feature = "runtime")] mod connect; mod connect_raw; +#[cfg(feature = "runtime")] +mod connect_socket; mod connect_tls; mod connection; pub mod error; mod maybe_tls_stream; +#[cfg(feature = "runtime")] +mod socket; pub mod tls; pub mod types; + +/// A convenience function which parses a connection string and connects to the database. +/// +/// See the documentation for [`Config`] for details on the connection string format. +/// +/// Requires the `runtime` Cargo feature (enabled by default). +/// +/// [`Config`]: ./Config.t.html +#[cfg(feature = "runtime")] +pub async fn connect( + config: &str, + tls: T, +) -> Result<(Client, Connection), Error> +where + T: MakeTlsConnect, +{ + let config = config.parse::()?; + config.connect(tls).await +} diff --git a/tokio-postgres/src/socket.rs b/tokio-postgres/src/socket.rs new file mode 100644 index 00000000..74663cf6 --- /dev/null +++ b/tokio-postgres/src/socket.rs @@ -0,0 +1,115 @@ +use bytes::{Buf, BufMut}; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; +#[cfg(unix)] +use tokio::net::UnixStream; + +#[derive(Debug)] +enum Inner { + Tcp(TcpStream), + Unix(UnixStream), +} + +/// The standard stream type used by the crate. +/// +/// Requires the `runtime` Cargo feature (enabled by default). +#[derive(Debug)] +pub struct Socket(Inner); + +impl Socket { + pub(crate) fn new_tcp(stream: TcpStream) -> Socket { + Socket(Inner::Tcp(stream)) + } + + #[cfg(unix)] + pub(crate) fn new_unix(stream: UnixStream) -> Socket { + Socket(Inner::Unix(stream)) + } +} + +impl AsyncRead for Socket { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + match &self.0 { + Inner::Tcp(s) => s.prepare_uninitialized_buffer(buf), + #[cfg(unix)] + Inner::Unix(s) => s.prepare_uninitialized_buffer(buf), + } + } + + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + match &mut self.0 { + Inner::Tcp(s) => Pin::new(s).poll_read(cx, buf), + #[cfg(unix)] + Inner::Unix(s) => Pin::new(s).poll_read(cx, buf), + } + } + + fn poll_read_buf( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> + where + Self: Sized, + B: BufMut, + { + match &mut self.0 { + Inner::Tcp(s) => Pin::new(s).poll_read_buf(cx, buf), + #[cfg(unix)] + Inner::Unix(s) => Pin::new(s).poll_read_buf(cx, buf), + } + } +} + +impl AsyncWrite for Socket { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match &mut self.0 { + Inner::Tcp(s) => Pin::new(s).poll_write(cx, buf), + #[cfg(unix)] + Inner::Unix(s) => Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut self.0 { + Inner::Tcp(s) => Pin::new(s).poll_flush(cx), + #[cfg(unix)] + Inner::Unix(s) => Pin::new(s).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut self.0 { + Inner::Tcp(s) => Pin::new(s).poll_shutdown(cx), + #[cfg(unix)] + Inner::Unix(s) => Pin::new(s).poll_shutdown(cx), + } + } + + fn poll_write_buf( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut B, + ) -> Poll> + where + Self: Sized, + B: Buf, + { + match &mut self.0 { + Inner::Tcp(s) => Pin::new(s).poll_write_buf(cx, buf), + #[cfg(unix)] + Inner::Unix(s) => Pin::new(s).poll_write_buf(cx, buf), + } + } +} diff --git a/tokio-postgres/src/tls.rs b/tokio-postgres/src/tls.rs index 99300521..1e1adeb2 100644 --- a/tokio-postgres/src/tls.rs +++ b/tokio-postgres/src/tls.rs @@ -41,7 +41,7 @@ pub trait MakeTlsConnect { type Stream: AsyncRead + AsyncWrite + Unpin; /// The `TlsConnect` implementation created by this type. type TlsConnect: TlsConnect; - /// The error type retured by the `TlsConnect` implementation. + /// The error type returned by the `TlsConnect` implementation. type Error: Into>; /// Creates a new `TlsConnect`or. @@ -73,6 +73,17 @@ pub trait TlsConnect { /// This can be used when `sslmode` is `none` or `prefer`. pub struct NoTls; +#[cfg(feature = "runtime")] +impl MakeTlsConnect for NoTls { + type Stream = NoTlsStream; + type TlsConnect = NoTls; + type Error = NoTlsError; + + fn make_tls_connect(&mut self, _: &str) -> Result { + Ok(NoTls) + } +} + impl TlsConnect for NoTls { type Stream = NoTlsStream; type Error = NoTlsError; diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 1403cbfa..9bddea87 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -7,9 +7,9 @@ use tokio_postgres::tls::{NoTls, NoTlsStream}; use tokio_postgres::{Client, Config, Connection, Error}; mod parse; -/* #[cfg(feature = "runtime")] mod runtime; +/* mod types; */ diff --git a/tokio-postgres/tests/test/runtime.rs b/tokio-postgres/tests/test/runtime.rs index 2af9a18d..c18e5be4 100644 --- a/tokio-postgres/tests/test/runtime.rs +++ b/tokio-postgres/tests/test/runtime.rs @@ -5,70 +5,65 @@ use tokio::timer::Delay; use tokio_postgres::error::SqlState; use tokio_postgres::NoTls; -fn smoke_test(s: &str) { - let mut runtime = Runtime::new().unwrap(); - let connect = tokio_postgres::connect(s, NoTls); - let (mut client, connection) = runtime.block_on(connect).unwrap(); +async fn smoke_test(s: &str) { + let (mut client, connection) = tokio_postgres::connect(s, NoTls).await.unwrap(); + /* let connection = connection.map_err(|e| panic!("{}", e)); runtime.spawn(connection); let execute = client.simple_query("SELECT 1").for_each(|_| Ok(())); runtime.block_on(execute).unwrap(); + */ } -#[test] +#[tokio::test] #[ignore] // FIXME doesn't work with our docker-based tests :( -fn unix_socket() { - smoke_test("host=/var/run/postgresql port=5433 user=postgres"); +async fn unix_socket() { + smoke_test("host=/var/run/postgresql port=5433 user=postgres").await; } -#[test] -fn tcp() { - smoke_test("host=localhost port=5433 user=postgres") +#[tokio::test] +async fn tcp() { + smoke_test("host=localhost port=5433 user=postgres").await; } -#[test] -fn multiple_hosts_one_port() { - smoke_test("host=foobar.invalid,localhost port=5433 user=postgres"); +#[tokio::test] +async fn multiple_hosts_one_port() { + smoke_test("host=foobar.invalid,localhost port=5433 user=postgres").await; } -#[test] -fn multiple_hosts_multiple_ports() { - smoke_test("host=foobar.invalid,localhost port=5432,5433 user=postgres"); +#[tokio::test] +async fn multiple_hosts_multiple_ports() { + smoke_test("host=foobar.invalid,localhost port=5432,5433 user=postgres").await; } -#[test] -fn wrong_port_count() { - let mut runtime = Runtime::new().unwrap(); - let f = tokio_postgres::connect("host=localhost port=5433,5433 user=postgres", NoTls); - runtime.block_on(f).err().unwrap(); - - let f = tokio_postgres::connect( - "host=localhost,localhost,localhost port=5433,5433 user=postgres", - NoTls, - ); - runtime.block_on(f).err().unwrap(); +#[tokio::test] +async fn wrong_port_count() { + tokio_postgres::connect("host=localhost port=5433,5433 user=postgres", NoTls) + .await + .err() + .unwrap(); } -#[test] -fn target_session_attrs_ok() { - let mut runtime = Runtime::new().unwrap(); - let f = tokio_postgres::connect( +/* +#[tokio::test] +async fn target_session_attrs_ok() { + tokio_postgres::connect( "host=localhost port=5433 user=postgres target_session_attrs=read-write", NoTls, - ); - runtime.block_on(f).unwrap(); + ) + .await + .err() + .unwrap(); } -#[test] -fn target_session_attrs_err() { - let mut runtime = Runtime::new().unwrap(); - let f = tokio_postgres::connect( +#[tokio::test] +async fn target_session_attrs_err() { + tokio_postgres::connect( "host=localhost port=5433 user=postgres target_session_attrs=read-write options='-c default_transaction_read_only=on'", NoTls, - ); - runtime.block_on(f).err().unwrap(); + ).await.err().unwrap(); } #[test] @@ -100,3 +95,4 @@ fn cancel_query() { let ((), ()) = runtime.block_on(sleep.join(cancel)).unwrap(); } +*/