Runtime connect
This commit is contained in:
parent
89501f66d9
commit
32fe52490e
@ -1,7 +1,13 @@
|
||||
//! Connection configuration.
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::connect::connect;
|
||||
use crate::connect_raw::connect_raw;
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::tls::MakeTlsConnect;
|
||||
use crate::tls::TlsConnect;
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::Socket;
|
||||
use crate::{Client, Connection, Error};
|
||||
use std::borrow::Cow;
|
||||
#[cfg(unix)]
|
||||
@ -367,6 +373,17 @@ impl Config {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Opens a connection to a PostgreSQL database.
|
||||
///
|
||||
/// Requires the `runtime` Cargo feature (enabled by default).
|
||||
#[cfg(feature = "runtime")]
|
||||
pub async fn connect<T>(&self, tls: T) -> Result<(Client, Connection<Socket, T::Stream>), Error>
|
||||
where
|
||||
T: MakeTlsConnect<Socket>,
|
||||
{
|
||||
connect(tls, self).await
|
||||
}
|
||||
|
||||
/// Connects to a PostgreSQL database over an arbitrary stream.
|
||||
///
|
||||
/// All of the settings other than `user`, `password`, `dbname`, `options`, and `application` name are ignored.
|
||||
|
@ -1 +1,60 @@
|
||||
use crate::config::{Host, TargetSessionAttrs};
|
||||
use crate::connect_raw::connect_raw;
|
||||
use crate::connect_socket::connect_socket;
|
||||
use crate::tls::{MakeTlsConnect, TlsConnect};
|
||||
use crate::{Client, Config, Connection, Error, Socket};
|
||||
|
||||
pub async fn connect<T>(
|
||||
mut tls: T,
|
||||
config: &Config,
|
||||
) -> Result<(Client, Connection<Socket, T::Stream>), Error>
|
||||
where
|
||||
T: MakeTlsConnect<Socket>,
|
||||
{
|
||||
if config.host.is_empty() {
|
||||
return Err(Error::config("host missing".into()));
|
||||
}
|
||||
|
||||
if config.port.len() > 1 && config.port.len() != config.host.len() {
|
||||
return Err(Error::config("invalid number of ports".into()));
|
||||
}
|
||||
|
||||
let mut error = None;
|
||||
for (i, host) in config.host.iter().enumerate() {
|
||||
let hostname = match host {
|
||||
Host::Tcp(host) => &**host,
|
||||
// postgres doesn't support TLS over unix sockets, so the choice here doesn't matter
|
||||
#[cfg(unix)]
|
||||
Host::Unix(_) => "",
|
||||
};
|
||||
|
||||
let tls = tls
|
||||
.make_tls_connect(hostname)
|
||||
.map_err(|e| Error::tls(e.into()))?;
|
||||
|
||||
match connect_once(i, tls, config).await {
|
||||
Ok((client, connection)) => return Ok((client, connection)),
|
||||
Err(e) => error = Some(e),
|
||||
}
|
||||
}
|
||||
|
||||
return Err(error.unwrap());
|
||||
}
|
||||
|
||||
async fn connect_once<T>(
|
||||
idx: usize,
|
||||
tls: T,
|
||||
config: &Config,
|
||||
) -> Result<(Client, Connection<Socket, T::Stream>), Error>
|
||||
where
|
||||
T: TlsConnect<Socket>,
|
||||
{
|
||||
let socket = connect_socket(idx, config).await?;
|
||||
let (client, connection) = connect_raw(socket, tls, config, Some(idx)).await?;
|
||||
|
||||
if let TargetSessionAttrs::ReadWrite = config.target_session_attrs {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
Ok((client, connection))
|
||||
}
|
||||
|
74
tokio-postgres/src/connect_socket.rs
Normal file
74
tokio-postgres/src/connect_socket.rs
Normal file
@ -0,0 +1,74 @@
|
||||
use crate::config::Host;
|
||||
use crate::{Config, Error, Socket};
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};
|
||||
use tokio::net::TcpStream;
|
||||
#[cfg(unix)]
|
||||
use tokio::net::UnixStream;
|
||||
use tokio::timer::Timeout;
|
||||
|
||||
pub async fn connect_socket(idx: usize, config: &Config) -> Result<Socket, Error> {
|
||||
let port = *config
|
||||
.port
|
||||
.get(idx)
|
||||
.or_else(|| config.port.get(0))
|
||||
.unwrap_or(&5432);
|
||||
|
||||
match &config.host[idx] {
|
||||
Host::Tcp(host) => {
|
||||
let addrs = match host.parse::<IpAddr>() {
|
||||
Ok(ip) => {
|
||||
// avoid dealing with blocking DNS entirely if possible
|
||||
vec![SocketAddr::new(ip, port)].into_iter()
|
||||
}
|
||||
Err(_) => {
|
||||
// FIXME what do?
|
||||
(&**host, port).to_socket_addrs().map_err(Error::connect)?
|
||||
}
|
||||
};
|
||||
|
||||
let mut error = None;
|
||||
for addr in addrs {
|
||||
let new_error = match connect_timeout(TcpStream::connect(&addr), config).await {
|
||||
Ok(socket) => return Ok(Socket::new_tcp(socket)),
|
||||
Err(e) => e,
|
||||
};
|
||||
error = Some(new_error);
|
||||
}
|
||||
|
||||
let error = error.unwrap_or_else(|| {
|
||||
Error::connect(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"resolved 0 addresses",
|
||||
))
|
||||
});
|
||||
Err(error)
|
||||
}
|
||||
#[cfg(unix)]
|
||||
Host::Unix(path) => {
|
||||
let socket = connect_timeout(UnixStream::connect(path), config).await?;
|
||||
Ok(Socket::new_unix(socket))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_timeout<F, T>(connect: F, config: &Config) -> Result<T, Error>
|
||||
where
|
||||
F: Future<Output = io::Result<T>>,
|
||||
{
|
||||
match config.connect_timeout {
|
||||
Some(connect_timeout) => match Timeout::new(connect, connect_timeout).await {
|
||||
Ok(Ok(socket)) => Ok(socket),
|
||||
Ok(Err(e)) => Err(Error::connect(e)),
|
||||
Err(_) => Err(Error::connect(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"connection timed out",
|
||||
))),
|
||||
},
|
||||
None => match connect.await {
|
||||
Ok(socket) => Ok(socket),
|
||||
Err(e) => Err(Error::connect(e)),
|
||||
},
|
||||
}
|
||||
}
|
@ -112,19 +112,48 @@
|
||||
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
|
||||
#![feature(async_await)]
|
||||
|
||||
pub use client::Client;
|
||||
pub use config::Config;
|
||||
pub use connection::Connection;
|
||||
pub use error::Error;
|
||||
pub use crate::client::Client;
|
||||
pub use crate::config::Config;
|
||||
pub use crate::connection::Connection;
|
||||
pub use crate::error::Error;
|
||||
#[cfg(feature = "runtime")]
|
||||
pub use crate::socket::Socket;
|
||||
#[cfg(feature = "runtime")]
|
||||
use crate::tls::MakeTlsConnect;
|
||||
pub use crate::tls::NoTls;
|
||||
|
||||
mod client;
|
||||
mod codec;
|
||||
pub mod config;
|
||||
#[cfg(feature = "runtime")]
|
||||
mod connect;
|
||||
mod connect_raw;
|
||||
#[cfg(feature = "runtime")]
|
||||
mod connect_socket;
|
||||
mod connect_tls;
|
||||
mod connection;
|
||||
pub mod error;
|
||||
mod maybe_tls_stream;
|
||||
#[cfg(feature = "runtime")]
|
||||
mod socket;
|
||||
pub mod tls;
|
||||
pub mod types;
|
||||
|
||||
/// A convenience function which parses a connection string and connects to the database.
|
||||
///
|
||||
/// See the documentation for [`Config`] for details on the connection string format.
|
||||
///
|
||||
/// Requires the `runtime` Cargo feature (enabled by default).
|
||||
///
|
||||
/// [`Config`]: ./Config.t.html
|
||||
#[cfg(feature = "runtime")]
|
||||
pub async fn connect<T>(
|
||||
config: &str,
|
||||
tls: T,
|
||||
) -> Result<(Client, Connection<Socket, T::Stream>), Error>
|
||||
where
|
||||
T: MakeTlsConnect<Socket>,
|
||||
{
|
||||
let config = config.parse::<Config>()?;
|
||||
config.connect(tls).await
|
||||
}
|
||||
|
115
tokio-postgres/src/socket.rs
Normal file
115
tokio-postgres/src/socket.rs
Normal file
@ -0,0 +1,115 @@
|
||||
use bytes::{Buf, BufMut};
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
#[cfg(unix)]
|
||||
use tokio::net::UnixStream;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Inner {
|
||||
Tcp(TcpStream),
|
||||
Unix(UnixStream),
|
||||
}
|
||||
|
||||
/// The standard stream type used by the crate.
|
||||
///
|
||||
/// Requires the `runtime` Cargo feature (enabled by default).
|
||||
#[derive(Debug)]
|
||||
pub struct Socket(Inner);
|
||||
|
||||
impl Socket {
|
||||
pub(crate) fn new_tcp(stream: TcpStream) -> Socket {
|
||||
Socket(Inner::Tcp(stream))
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub(crate) fn new_unix(stream: UnixStream) -> Socket {
|
||||
Socket(Inner::Unix(stream))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for Socket {
|
||||
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
|
||||
match &self.0 {
|
||||
Inner::Tcp(s) => s.prepare_uninitialized_buffer(buf),
|
||||
#[cfg(unix)]
|
||||
Inner::Unix(s) => s.prepare_uninitialized_buffer(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
match &mut self.0 {
|
||||
Inner::Tcp(s) => Pin::new(s).poll_read(cx, buf),
|
||||
#[cfg(unix)]
|
||||
Inner::Unix(s) => Pin::new(s).poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_read_buf<B>(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut B,
|
||||
) -> Poll<io::Result<usize>>
|
||||
where
|
||||
Self: Sized,
|
||||
B: BufMut,
|
||||
{
|
||||
match &mut self.0 {
|
||||
Inner::Tcp(s) => Pin::new(s).poll_read_buf(cx, buf),
|
||||
#[cfg(unix)]
|
||||
Inner::Unix(s) => Pin::new(s).poll_read_buf(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for Socket {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
match &mut self.0 {
|
||||
Inner::Tcp(s) => Pin::new(s).poll_write(cx, buf),
|
||||
#[cfg(unix)]
|
||||
Inner::Unix(s) => Pin::new(s).poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
match &mut self.0 {
|
||||
Inner::Tcp(s) => Pin::new(s).poll_flush(cx),
|
||||
#[cfg(unix)]
|
||||
Inner::Unix(s) => Pin::new(s).poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
match &mut self.0 {
|
||||
Inner::Tcp(s) => Pin::new(s).poll_shutdown(cx),
|
||||
#[cfg(unix)]
|
||||
Inner::Unix(s) => Pin::new(s).poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_write_buf<B>(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut B,
|
||||
) -> Poll<io::Result<usize>>
|
||||
where
|
||||
Self: Sized,
|
||||
B: Buf,
|
||||
{
|
||||
match &mut self.0 {
|
||||
Inner::Tcp(s) => Pin::new(s).poll_write_buf(cx, buf),
|
||||
#[cfg(unix)]
|
||||
Inner::Unix(s) => Pin::new(s).poll_write_buf(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
@ -41,7 +41,7 @@ pub trait MakeTlsConnect<S> {
|
||||
type Stream: AsyncRead + AsyncWrite + Unpin;
|
||||
/// The `TlsConnect` implementation created by this type.
|
||||
type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
|
||||
/// The error type retured by the `TlsConnect` implementation.
|
||||
/// The error type returned by the `TlsConnect` implementation.
|
||||
type Error: Into<Box<dyn Error + Sync + Send>>;
|
||||
|
||||
/// Creates a new `TlsConnect`or.
|
||||
@ -73,6 +73,17 @@ pub trait TlsConnect<S> {
|
||||
/// This can be used when `sslmode` is `none` or `prefer`.
|
||||
pub struct NoTls;
|
||||
|
||||
#[cfg(feature = "runtime")]
|
||||
impl<S> MakeTlsConnect<S> for NoTls {
|
||||
type Stream = NoTlsStream;
|
||||
type TlsConnect = NoTls;
|
||||
type Error = NoTlsError;
|
||||
|
||||
fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
|
||||
Ok(NoTls)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> TlsConnect<S> for NoTls {
|
||||
type Stream = NoTlsStream;
|
||||
type Error = NoTlsError;
|
||||
|
@ -7,9 +7,9 @@ use tokio_postgres::tls::{NoTls, NoTlsStream};
|
||||
use tokio_postgres::{Client, Config, Connection, Error};
|
||||
|
||||
mod parse;
|
||||
/*
|
||||
#[cfg(feature = "runtime")]
|
||||
mod runtime;
|
||||
/*
|
||||
mod types;
|
||||
*/
|
||||
|
||||
|
@ -5,70 +5,65 @@ use tokio::timer::Delay;
|
||||
use tokio_postgres::error::SqlState;
|
||||
use tokio_postgres::NoTls;
|
||||
|
||||
fn smoke_test(s: &str) {
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
let connect = tokio_postgres::connect(s, NoTls);
|
||||
let (mut client, connection) = runtime.block_on(connect).unwrap();
|
||||
async fn smoke_test(s: &str) {
|
||||
let (mut client, connection) = tokio_postgres::connect(s, NoTls).await.unwrap();
|
||||
/*
|
||||
let connection = connection.map_err(|e| panic!("{}", e));
|
||||
runtime.spawn(connection);
|
||||
|
||||
let execute = client.simple_query("SELECT 1").for_each(|_| Ok(()));
|
||||
runtime.block_on(execute).unwrap();
|
||||
*/
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[tokio::test]
|
||||
#[ignore] // FIXME doesn't work with our docker-based tests :(
|
||||
fn unix_socket() {
|
||||
smoke_test("host=/var/run/postgresql port=5433 user=postgres");
|
||||
async fn unix_socket() {
|
||||
smoke_test("host=/var/run/postgresql port=5433 user=postgres").await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tcp() {
|
||||
smoke_test("host=localhost port=5433 user=postgres")
|
||||
#[tokio::test]
|
||||
async fn tcp() {
|
||||
smoke_test("host=localhost port=5433 user=postgres").await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiple_hosts_one_port() {
|
||||
smoke_test("host=foobar.invalid,localhost port=5433 user=postgres");
|
||||
#[tokio::test]
|
||||
async fn multiple_hosts_one_port() {
|
||||
smoke_test("host=foobar.invalid,localhost port=5433 user=postgres").await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn multiple_hosts_multiple_ports() {
|
||||
smoke_test("host=foobar.invalid,localhost port=5432,5433 user=postgres");
|
||||
#[tokio::test]
|
||||
async fn multiple_hosts_multiple_ports() {
|
||||
smoke_test("host=foobar.invalid,localhost port=5432,5433 user=postgres").await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrong_port_count() {
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
let f = tokio_postgres::connect("host=localhost port=5433,5433 user=postgres", NoTls);
|
||||
runtime.block_on(f).err().unwrap();
|
||||
|
||||
let f = tokio_postgres::connect(
|
||||
"host=localhost,localhost,localhost port=5433,5433 user=postgres",
|
||||
NoTls,
|
||||
);
|
||||
runtime.block_on(f).err().unwrap();
|
||||
#[tokio::test]
|
||||
async fn wrong_port_count() {
|
||||
tokio_postgres::connect("host=localhost port=5433,5433 user=postgres", NoTls)
|
||||
.await
|
||||
.err()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_session_attrs_ok() {
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
let f = tokio_postgres::connect(
|
||||
/*
|
||||
#[tokio::test]
|
||||
async fn target_session_attrs_ok() {
|
||||
tokio_postgres::connect(
|
||||
"host=localhost port=5433 user=postgres target_session_attrs=read-write",
|
||||
NoTls,
|
||||
);
|
||||
runtime.block_on(f).unwrap();
|
||||
)
|
||||
.await
|
||||
.err()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn target_session_attrs_err() {
|
||||
let mut runtime = Runtime::new().unwrap();
|
||||
let f = tokio_postgres::connect(
|
||||
#[tokio::test]
|
||||
async fn target_session_attrs_err() {
|
||||
tokio_postgres::connect(
|
||||
"host=localhost port=5433 user=postgres target_session_attrs=read-write
|
||||
options='-c default_transaction_read_only=on'",
|
||||
NoTls,
|
||||
);
|
||||
runtime.block_on(f).err().unwrap();
|
||||
).await.err().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -100,3 +95,4 @@ fn cancel_query() {
|
||||
|
||||
let ((), ()) = runtime.block_on(sleep.join(cancel)).unwrap();
|
||||
}
|
||||
*/
|
||||
|
Loading…
Reference in New Issue
Block a user