Runtime connect

This commit is contained in:
Steven Fackler 2019-07-22 20:17:29 -07:00
parent 89501f66d9
commit 32fe52490e
8 changed files with 346 additions and 45 deletions

View File

@ -1,7 +1,13 @@
//! Connection configuration.
#[cfg(feature = "runtime")]
use crate::connect::connect;
use crate::connect_raw::connect_raw;
#[cfg(feature = "runtime")]
use crate::tls::MakeTlsConnect;
use crate::tls::TlsConnect;
#[cfg(feature = "runtime")]
use crate::Socket;
use crate::{Client, Connection, Error};
use std::borrow::Cow;
#[cfg(unix)]
@ -367,6 +373,17 @@ impl Config {
Ok(())
}
/// Opens a connection to a PostgreSQL database.
///
/// Requires the `runtime` Cargo feature (enabled by default).
#[cfg(feature = "runtime")]
pub async fn connect<T>(&self, tls: T) -> Result<(Client, Connection<Socket, T::Stream>), Error>
where
T: MakeTlsConnect<Socket>,
{
connect(tls, self).await
}
/// Connects to a PostgreSQL database over an arbitrary stream.
///
/// All of the settings other than `user`, `password`, `dbname`, `options`, and `application` name are ignored.

View File

@ -1 +1,60 @@
use crate::config::{Host, TargetSessionAttrs};
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, Socket};
pub async fn connect<T>(
mut tls: T,
config: &Config,
) -> Result<(Client, Connection<Socket, T::Stream>), Error>
where
T: MakeTlsConnect<Socket>,
{
if config.host.is_empty() {
return Err(Error::config("host missing".into()));
}
if config.port.len() > 1 && config.port.len() != config.host.len() {
return Err(Error::config("invalid number of ports".into()));
}
let mut error = None;
for (i, host) in config.host.iter().enumerate() {
let hostname = match host {
Host::Tcp(host) => &**host,
// postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
#[cfg(unix)]
Host::Unix(_) => "",
};
let tls = tls
.make_tls_connect(hostname)
.map_err(|e| Error::tls(e.into()))?;
match connect_once(i, tls, config).await {
Ok((client, connection)) => return Ok((client, connection)),
Err(e) => error = Some(e),
}
}
return Err(error.unwrap());
}
async fn connect_once<T>(
idx: usize,
tls: T,
config: &Config,
) -> Result<(Client, Connection<Socket, T::Stream>), Error>
where
T: TlsConnect<Socket>,
{
let socket = connect_socket(idx, config).await?;
let (client, connection) = connect_raw(socket, tls, config, Some(idx)).await?;
if let TargetSessionAttrs::ReadWrite = config.target_session_attrs {
unimplemented!()
}
Ok((client, connection))
}

View File

@ -0,0 +1,74 @@
use crate::config::Host;
use crate::{Config, Error, Socket};
use std::future::Future;
use std::io;
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::timer::Timeout;
pub async fn connect_socket(idx: usize, config: &Config) -> Result<Socket, Error> {
let port = *config
.port
.get(idx)
.or_else(|| config.port.get(0))
.unwrap_or(&5432);
match &config.host[idx] {
Host::Tcp(host) => {
let addrs = match host.parse::<IpAddr>() {
Ok(ip) => {
// avoid dealing with blocking DNS entirely if possible
vec![SocketAddr::new(ip, port)].into_iter()
}
Err(_) => {
// FIXME what do?
(&**host, port).to_socket_addrs().map_err(Error::connect)?
}
};
let mut error = None;
for addr in addrs {
let new_error = match connect_timeout(TcpStream::connect(&addr), config).await {
Ok(socket) => return Ok(Socket::new_tcp(socket)),
Err(e) => e,
};
error = Some(new_error);
}
let error = error.unwrap_or_else(|| {
Error::connect(io::Error::new(
io::ErrorKind::InvalidData,
"resolved 0 addresses",
))
});
Err(error)
}
#[cfg(unix)]
Host::Unix(path) => {
let socket = connect_timeout(UnixStream::connect(path), config).await?;
Ok(Socket::new_unix(socket))
}
}
}
async fn connect_timeout<F, T>(connect: F, config: &Config) -> Result<T, Error>
where
F: Future<Output = io::Result<T>>,
{
match config.connect_timeout {
Some(connect_timeout) => match Timeout::new(connect, connect_timeout).await {
Ok(Ok(socket)) => Ok(socket),
Ok(Err(e)) => Err(Error::connect(e)),
Err(_) => Err(Error::connect(io::Error::new(
io::ErrorKind::TimedOut,
"connection timed out",
))),
},
None => match connect.await {
Ok(socket) => Ok(socket),
Err(e) => Err(Error::connect(e)),
},
}
}

View File

@ -112,19 +112,48 @@
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
#![feature(async_await)]
pub use client::Client;
pub use config::Config;
pub use connection::Connection;
pub use error::Error;
pub use crate::client::Client;
pub use crate::config::Config;
pub use crate::connection::Connection;
pub use crate::error::Error;
#[cfg(feature = "runtime")]
pub use crate::socket::Socket;
#[cfg(feature = "runtime")]
use crate::tls::MakeTlsConnect;
pub use crate::tls::NoTls;
mod client;
mod codec;
pub mod config;
#[cfg(feature = "runtime")]
mod connect;
mod connect_raw;
#[cfg(feature = "runtime")]
mod connect_socket;
mod connect_tls;
mod connection;
pub mod error;
mod maybe_tls_stream;
#[cfg(feature = "runtime")]
mod socket;
pub mod tls;
pub mod types;
/// A convenience function which parses a connection string and connects to the database.
///
/// See the documentation for [`Config`] for details on the connection string format.
///
/// Requires the `runtime` Cargo feature (enabled by default).
///
/// [`Config`]: ./Config.t.html
#[cfg(feature = "runtime")]
pub async fn connect<T>(
config: &str,
tls: T,
) -> Result<(Client, Connection<Socket, T::Stream>), Error>
where
T: MakeTlsConnect<Socket>,
{
let config = config.parse::<Config>()?;
config.connect(tls).await
}

View File

@ -0,0 +1,115 @@
use bytes::{Buf, BufMut};
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
#[derive(Debug)]
enum Inner {
Tcp(TcpStream),
Unix(UnixStream),
}
/// The standard stream type used by the crate.
///
/// Requires the `runtime` Cargo feature (enabled by default).
#[derive(Debug)]
pub struct Socket(Inner);
impl Socket {
pub(crate) fn new_tcp(stream: TcpStream) -> Socket {
Socket(Inner::Tcp(stream))
}
#[cfg(unix)]
pub(crate) fn new_unix(stream: UnixStream) -> Socket {
Socket(Inner::Unix(stream))
}
}
impl AsyncRead for Socket {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
match &self.0 {
Inner::Tcp(s) => s.prepare_uninitialized_buffer(buf),
#[cfg(unix)]
Inner::Unix(s) => s.prepare_uninitialized_buffer(buf),
}
}
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match &mut self.0 {
Inner::Tcp(s) => Pin::new(s).poll_read(cx, buf),
#[cfg(unix)]
Inner::Unix(s) => Pin::new(s).poll_read(cx, buf),
}
}
fn poll_read_buf<B>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
B: BufMut,
{
match &mut self.0 {
Inner::Tcp(s) => Pin::new(s).poll_read_buf(cx, buf),
#[cfg(unix)]
Inner::Unix(s) => Pin::new(s).poll_read_buf(cx, buf),
}
}
}
impl AsyncWrite for Socket {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match &mut self.0 {
Inner::Tcp(s) => Pin::new(s).poll_write(cx, buf),
#[cfg(unix)]
Inner::Unix(s) => Pin::new(s).poll_write(cx, buf),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut self.0 {
Inner::Tcp(s) => Pin::new(s).poll_flush(cx),
#[cfg(unix)]
Inner::Unix(s) => Pin::new(s).poll_flush(cx),
}
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut self.0 {
Inner::Tcp(s) => Pin::new(s).poll_shutdown(cx),
#[cfg(unix)]
Inner::Unix(s) => Pin::new(s).poll_shutdown(cx),
}
}
fn poll_write_buf<B>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut B,
) -> Poll<io::Result<usize>>
where
Self: Sized,
B: Buf,
{
match &mut self.0 {
Inner::Tcp(s) => Pin::new(s).poll_write_buf(cx, buf),
#[cfg(unix)]
Inner::Unix(s) => Pin::new(s).poll_write_buf(cx, buf),
}
}
}

View File

@ -41,7 +41,7 @@ pub trait MakeTlsConnect<S> {
type Stream: AsyncRead + AsyncWrite + Unpin;
/// The `TlsConnect` implementation created by this type.
type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
/// The error type retured by the `TlsConnect` implementation.
/// The error type returned by the `TlsConnect` implementation.
type Error: Into<Box<dyn Error + Sync + Send>>;
/// Creates a new `TlsConnect`or.
@ -73,6 +73,17 @@ pub trait TlsConnect<S> {
/// This can be used when `sslmode` is `none` or `prefer`.
pub struct NoTls;
#[cfg(feature = "runtime")]
impl<S> MakeTlsConnect<S> for NoTls {
type Stream = NoTlsStream;
type TlsConnect = NoTls;
type Error = NoTlsError;
fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
Ok(NoTls)
}
}
impl<S> TlsConnect<S> for NoTls {
type Stream = NoTlsStream;
type Error = NoTlsError;

View File

@ -7,9 +7,9 @@ use tokio_postgres::tls::{NoTls, NoTlsStream};
use tokio_postgres::{Client, Config, Connection, Error};
mod parse;
/*
#[cfg(feature = "runtime")]
mod runtime;
/*
mod types;
*/

View File

@ -5,70 +5,65 @@ use tokio::timer::Delay;
use tokio_postgres::error::SqlState;
use tokio_postgres::NoTls;
fn smoke_test(s: &str) {
let mut runtime = Runtime::new().unwrap();
let connect = tokio_postgres::connect(s, NoTls);
let (mut client, connection) = runtime.block_on(connect).unwrap();
async fn smoke_test(s: &str) {
let (mut client, connection) = tokio_postgres::connect(s, NoTls).await.unwrap();
/*
let connection = connection.map_err(|e| panic!("{}", e));
runtime.spawn(connection);
let execute = client.simple_query("SELECT 1").for_each(|_| Ok(()));
runtime.block_on(execute).unwrap();
*/
}
#[test]
#[tokio::test]
#[ignore] // FIXME doesn't work with our docker-based tests :(
fn unix_socket() {
smoke_test("host=/var/run/postgresql port=5433 user=postgres");
async fn unix_socket() {
smoke_test("host=/var/run/postgresql port=5433 user=postgres").await;
}
#[test]
fn tcp() {
smoke_test("host=localhost port=5433 user=postgres")
#[tokio::test]
async fn tcp() {
smoke_test("host=localhost port=5433 user=postgres").await;
}
#[test]
fn multiple_hosts_one_port() {
smoke_test("host=foobar.invalid,localhost port=5433 user=postgres");
#[tokio::test]
async fn multiple_hosts_one_port() {
smoke_test("host=foobar.invalid,localhost port=5433 user=postgres").await;
}
#[test]
fn multiple_hosts_multiple_ports() {
smoke_test("host=foobar.invalid,localhost port=5432,5433 user=postgres");
#[tokio::test]
async fn multiple_hosts_multiple_ports() {
smoke_test("host=foobar.invalid,localhost port=5432,5433 user=postgres").await;
}
#[test]
fn wrong_port_count() {
let mut runtime = Runtime::new().unwrap();
let f = tokio_postgres::connect("host=localhost port=5433,5433 user=postgres", NoTls);
runtime.block_on(f).err().unwrap();
let f = tokio_postgres::connect(
"host=localhost,localhost,localhost port=5433,5433 user=postgres",
NoTls,
);
runtime.block_on(f).err().unwrap();
#[tokio::test]
async fn wrong_port_count() {
tokio_postgres::connect("host=localhost port=5433,5433 user=postgres", NoTls)
.await
.err()
.unwrap();
}
#[test]
fn target_session_attrs_ok() {
let mut runtime = Runtime::new().unwrap();
let f = tokio_postgres::connect(
/*
#[tokio::test]
async fn target_session_attrs_ok() {
tokio_postgres::connect(
"host=localhost port=5433 user=postgres target_session_attrs=read-write",
NoTls,
);
runtime.block_on(f).unwrap();
)
.await
.err()
.unwrap();
}
#[test]
fn target_session_attrs_err() {
let mut runtime = Runtime::new().unwrap();
let f = tokio_postgres::connect(
#[tokio::test]
async fn target_session_attrs_err() {
tokio_postgres::connect(
"host=localhost port=5433 user=postgres target_session_attrs=read-write
options='-c default_transaction_read_only=on'",
NoTls,
);
runtime.block_on(f).err().unwrap();
).await.err().unwrap();
}
#[test]
@ -100,3 +95,4 @@ fn cancel_query() {
let ((), ()) = runtime.block_on(sleep.join(cancel)).unwrap();
}
*/