taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

notification.rs (4487B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2024, 2025, 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use std::{hash::Hash, sync::Arc};
     18 
     19 use dashmap::DashMap;
     20 use tokio::sync::watch::{self, Receiver};
     21 
     22 pub mod de;
     23 
     24 /// Listen for many postgres notification channels using a single connection
     25 #[macro_export]
     26 macro_rules! notification_listener {
     27     ($pool: expr, $($channel:expr => ($($arg:ident: $type:ty),*) $lambda:block),*$(,)?) => {
     28         {
     29             let mut listener = ::sqlx::postgres::PgListener::connect_with($pool).await?;
     30             listener.listen_all([$($channel,)*]).await?;
     31             loop {
     32                 while let Some(notification) = listener.try_recv().await? {
     33                     tracing::debug!(target: "db-watcher",
     34                         "db notification: {} - {}",
     35                         notification.channel(),
     36                         notification.payload()
     37                     );
     38                     match notification.channel() {
     39                         $($channel => {
     40                             let ($($arg,)*): ($($type,)*) =
     41                                 ::taler_api::notification::de::from_str(notification.payload()).unwrap();// TODO error handling
     42                             $lambda
     43                         }),*
     44                         unknown => unreachable!("{}", unknown),
     45                     }
     46                 }
     47                 // TODO wait before reconnect
     48             }
     49         }
     50 
     51     }
     52 }
     53 
     54 pub use notification_listener;
     55 
     56 type CountedNotify<T> = watch::Sender<Option<T>>;
     57 
     58 #[derive(Default)]
     59 pub struct NotificationChannel<K: Eq + Hash, V> {
     60     map: Arc<DashMap<K, CountedNotify<V>>>,
     61 }
     62 
     63 pub struct Listener<K: Eq + Hash + Clone, V> {
     64     map: Arc<DashMap<K, CountedNotify<V>>>,
     65     channel: watch::Receiver<Option<V>>,
     66     key: K,
     67 }
     68 
     69 impl<K: Eq + Hash + Clone, V> Listener<K, V> {
     70     pub async fn wait_for(mut self, filter: impl Fn(&V) -> bool) {
     71         self.channel
     72             .wait_for(|it| it.as_ref().map(&filter).unwrap_or(false))
     73             .await
     74             .ok(); // If the channel is closed we cannot wait efficiently
     75     }
     76 }
     77 
     78 impl<K: Eq + Hash + Clone, V> Drop for Listener<K, V> {
     79     fn drop(&mut self) {
     80         self.map
     81             .remove_if(&self.key, |_, it| it.receiver_count() == 1);
     82     }
     83 }
     84 
     85 impl<K: Eq + Hash + Clone, V> NotificationChannel<K, V> {
     86     pub fn listener(&self, key: K) -> Listener<K, V> {
     87         let entry = self.map.entry(key.clone()).or_insert_with(|| {
     88             let (sender, _) = watch::channel(None);
     89             sender
     90         });
     91         Listener {
     92             map: self.map.clone(),
     93             channel: entry.subscribe(),
     94             key,
     95         }
     96     }
     97 }
     98 
     99 pub fn dummy_listen<T: Default>() -> Receiver<T> {
    100     tokio::sync::watch::channel(T::default()).1
    101 }
    102 
    103 #[tokio::test]
    104 async fn channel_gc() {
    105     use std::time::Duration;
    106 
    107     let channel = NotificationChannel::default();
    108     assert_eq!(0, channel.map.len());
    109 
    110     // Clean in future
    111     let listener = channel.listener("test");
    112     assert_eq!(1, channel.map.len());
    113     tokio::time::timeout(Duration::from_millis(0), listener.wait_for(|it| it == 42))
    114         .await
    115         .unwrap_err();
    116     assert_eq!(0, channel.map.len());
    117 
    118     // Clean on drop
    119     let first = channel.listener("test");
    120     let second = channel.listener("test");
    121     assert_eq!(1, channel.map.len());
    122     tokio::time::timeout(Duration::from_millis(0), first.wait_for(|it| it == 42))
    123         .await
    124         .unwrap_err();
    125     assert_eq!(1, channel.map.len());
    126     drop(second);
    127     assert_eq!(0, channel.map.len());
    128 }
    129 
    130 #[tokio::test]
    131 async fn wake() {
    132     let channel = NotificationChannel::default();
    133     let listener = channel.listener("test");
    134     let task = tokio::spawn(listener.wait_for(|it| *it == 42));
    135     channel.map.entry("test").and_modify(|it| {
    136         it.send(Some(42)).unwrap();
    137     });
    138     task.await.unwrap();
    139 }