Merge pull request #655 from benesch/notice-callback

Permit configuring the notice callback
This commit is contained in:
Steven Fackler 2020-09-22 20:51:55 -04:00 committed by GitHub
commit eabcc28657
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 8 deletions

View File

@ -4,13 +4,16 @@
use crate::connection::Connection;
use crate::Client;
use log::info;
use std::fmt;
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio::runtime;
#[doc(inline)]
pub use tokio_postgres::config::{ChannelBinding, Host, SslMode, TargetSessionAttrs};
use tokio_postgres::error::DbError;
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
use tokio_postgres::{Error, Socket};
@ -90,6 +93,7 @@ use tokio_postgres::{Error, Socket};
#[derive(Clone)]
pub struct Config {
config: tokio_postgres::Config,
notice_callback: Arc<dyn Fn(DbError) + Send + Sync>,
}
impl fmt::Debug for Config {
@ -109,9 +113,7 @@ impl Default for Config {
impl Config {
/// Creates a new configuration.
pub fn new() -> Config {
Config {
config: tokio_postgres::Config::new(),
}
tokio_postgres::Config::new().into()
}
/// Sets the user to authenticate with.
@ -307,6 +309,25 @@ impl Config {
self.config.get_channel_binding()
}
/// Sets the notice callback.
///
/// This callback will be invoked with the contents of every
/// [`AsyncMessage::Notice`] that is received by the connection. Notices use
/// the same structure as errors, but they are not "errors" per-se.
///
/// Notices are distinct from notifications, which are instead accessible
/// via the [`Notifications`] API.
///
/// [`AsyncMessage::Notice`]: tokio_postgres::AsyncMessage::Notice
/// [`Notifications`]: crate::Notifications
pub fn notice_callback<F>(&mut self, f: F) -> &mut Config
where
F: Fn(DbError) + Send + Sync + 'static,
{
self.notice_callback = Arc::new(f);
self
}
/// Opens a connection to a PostgreSQL database.
pub fn connect<T>(&self, tls: T) -> Result<Client, Error>
where
@ -323,7 +344,7 @@ impl Config {
let (client, connection) = runtime.block_on(self.config.connect(tls))?;
let connection = Connection::new(runtime, connection);
let connection = Connection::new(runtime, connection, self.notice_callback.clone());
Ok(Client::new(connection, client))
}
}
@ -338,6 +359,11 @@ impl FromStr for Config {
impl From<tokio_postgres::Config> for Config {
fn from(config: tokio_postgres::Config) -> Config {
Config { config }
Config {
config,
notice_callback: Arc::new(|notice| {
info!("{}: {}", notice.severity(), notice.message())
}),
}
}
}

View File

@ -1,24 +1,30 @@
use crate::{Error, Notification};
use futures::future;
use futures::{pin_mut, Stream};
use log::info;
use std::collections::VecDeque;
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::runtime::Runtime;
use tokio_postgres::error::DbError;
use tokio_postgres::AsyncMessage;
pub struct Connection {
runtime: Runtime,
connection: Pin<Box<dyn Stream<Item = Result<AsyncMessage, Error>> + Send>>,
notifications: VecDeque<Notification>,
notice_callback: Arc<dyn Fn(DbError)>,
}
impl Connection {
pub fn new<S, T>(runtime: Runtime, connection: tokio_postgres::Connection<S, T>) -> Connection
pub fn new<S, T>(
runtime: Runtime,
connection: tokio_postgres::Connection<S, T>,
notice_callback: Arc<dyn Fn(DbError)>,
) -> Connection
where
S: AsyncRead + AsyncWrite + Unpin + 'static + Send,
T: AsyncRead + AsyncWrite + Unpin + 'static + Send,
@ -27,6 +33,7 @@ impl Connection {
runtime,
connection: Box::pin(ConnectionStream { connection }),
notifications: VecDeque::new(),
notice_callback,
}
}
@ -55,6 +62,7 @@ impl Connection {
{
let connection = &mut self.connection;
let notifications = &mut self.notifications;
let notice_callback = &mut self.notice_callback;
self.runtime.block_on({
future::poll_fn(|cx| {
let done = loop {
@ -63,7 +71,7 @@ impl Connection {
notifications.push_back(notification);
}
Poll::Ready(Some(Ok(AsyncMessage::Notice(notice)))) => {
info!("{}: {}", notice.severity(), notice.message());
notice_callback(notice)
}
Poll::Ready(Some(Ok(_))) => {}
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),

View File

@ -1,4 +1,6 @@
use std::io::{Read, Write};
use std::str::FromStr;
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
use tokio_postgres::error::SqlState;
@ -476,6 +478,22 @@ fn notifications_timeout_iter() {
assert_eq!(notifications[1].payload(), "world");
}
#[test]
fn notice_callback() {
let (notice_tx, notice_rx) = mpsc::sync_channel(64);
let mut client = Config::from_str("host=localhost port=5433 user=postgres")
.unwrap()
.notice_callback(move |n| notice_tx.send(n).unwrap())
.connect(NoTls)
.unwrap();
client
.batch_execute("DO $$BEGIN RAISE NOTICE 'custom'; END$$")
.unwrap();
assert_eq!(notice_rx.recv().unwrap().message(), "custom");
}
#[test]
fn explicit_close() {
let client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();