use crate::client::SocketConfig; use crate::config::{Host, TargetSessionAttrs}; use crate::connect_raw::connect_raw; use crate::connect_socket::connect_socket; use crate::tls::{MakeTlsConnect, TlsConnect}; use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket}; use futures::{future, pin_mut, Future, FutureExt, Stream}; use std::io; use std::task::Poll; pub async fn connect( mut tls: T, config: &Config, ) -> Result<(Client, Connection), Error> where T: MakeTlsConnect, { if config.host.is_empty() { return Err(Error::config("host missing".into())); } if config.port.len() > 1 && config.port.len() != config.host.len() { return Err(Error::config("invalid number of ports".into())); } let mut error = None; for (i, host) in config.host.iter().enumerate() { let port = config .port .get(i) .or_else(|| config.port.get(0)) .copied() .unwrap_or(5432); let hostname = match host { Host::Tcp(host) => host.as_str(), // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter #[cfg(unix)] Host::Unix(_) => "", }; let tls = tls .make_tls_connect(hostname) .map_err(|e| Error::tls(e.into()))?; match connect_once(host, port, tls, config).await { Ok((client, connection)) => return Ok((client, connection)), Err(e) => error = Some(e), } } Err(error.unwrap()) } async fn connect_once( host: &Host, port: u16, tls: T, config: &Config, ) -> Result<(Client, Connection), Error> where T: TlsConnect, { let socket = connect_socket( host, port, config.connect_timeout, config.keepalives, config.keepalives_idle, ) .await?; let (mut client, mut connection) = connect_raw(socket, tls, config).await?; if let TargetSessionAttrs::ReadWrite = config.target_session_attrs { let rows = client.simple_query_raw("SHOW transaction_read_only"); pin_mut!(rows); let rows = future::poll_fn(|cx| { if connection.poll_unpin(cx)?.is_ready() { return Poll::Ready(Err(Error::closed())); } rows.as_mut().poll(cx) }) .await?; pin_mut!(rows); loop { let next = future::poll_fn(|cx| { if connection.poll_unpin(cx)?.is_ready() { return Poll::Ready(Some(Err(Error::closed()))); } rows.as_mut().poll_next(cx) }); match next.await.transpose()? { Some(SimpleQueryMessage::Row(row)) => { if row.try_get(0)? == Some("on") { return Err(Error::connect(io::Error::new( io::ErrorKind::PermissionDenied, "database does not allow writes", ))); } else { break; } } Some(_) => {} None => return Err(Error::unexpected_message()), } } } client.set_socket_config(SocketConfig { host: host.clone(), port, connect_timeout: config.connect_timeout, keepalives: config.keepalives, keepalives_idle: config.keepalives_idle, }); Ok((client, connection)) }