taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

notification.rs (4367B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2024-2025 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use std::hash::Hash;
     18 use std::sync::Arc;
     19 
     20 use dashmap::DashMap;
     21 use tokio::sync::watch;
     22 
     23 pub mod de;
     24 
     25 /// Listen for many postgres notification channels using a single connection
     26 #[macro_export]
     27 macro_rules! notification_listener {
     28     ($pool: expr, $($channel:expr => ($($arg:ident: $type:ty),*) $lambda:block),*$(,)?) => {
     29         {
     30             let mut listener = ::sqlx::postgres::PgListener::connect_with($pool).await?;
     31             listener.listen_all([$($channel,)*]).await?;
     32             loop {
     33                 while let Some(notification) = listener.try_recv().await? {
     34                     tracing::debug!(target: "db-watcher",
     35                         "db notification: {} - {}",
     36                         notification.channel(),
     37                         notification.payload()
     38                     );
     39                     match notification.channel() {
     40                         $($channel => {
     41                             let ($($arg,)*): ($($type,)*) =
     42                                 ::taler_api::notification::de::from_str(notification.payload()).unwrap();// TODO error handling
     43                             $lambda
     44                         }),*
     45                         unknown => unreachable!("{}", unknown),
     46                     }
     47                 }
     48                 // TODO wait before reconnect
     49             }
     50         }
     51 
     52     }
     53 }
     54 
     55 pub use notification_listener;
     56 
     57 type CountedNotify<T> = watch::Sender<Option<T>>;
     58 
     59 #[derive(Default)]
     60 pub struct NotificationChannel<K: Eq + Hash, V> {
     61     map: Arc<DashMap<K, CountedNotify<V>>>,
     62 }
     63 
     64 pub struct Listener<K: Eq + Hash + Clone, V> {
     65     map: Arc<DashMap<K, CountedNotify<V>>>,
     66     channel: watch::Receiver<Option<V>>,
     67     key: K,
     68 }
     69 
     70 impl<K: Eq + Hash + Clone, V> Listener<K, V> {
     71     pub async fn wait_for(mut self, filter: impl Fn(&V) -> bool) {
     72         self.channel
     73             .wait_for(|it| it.as_ref().map(&filter).unwrap_or(false))
     74             .await
     75             .ok(); // If the channel is closed we cannot wait efficiently
     76     }
     77 }
     78 
     79 impl<K: Eq + Hash + Clone, V> Drop for Listener<K, V> {
     80     fn drop(&mut self) {
     81         self.map
     82             .remove_if(&self.key, |_, it| it.receiver_count() == 1);
     83     }
     84 }
     85 
     86 impl<K: Eq + Hash + Clone, V> NotificationChannel<K, V> {
     87     pub fn listener(&self, key: K) -> Listener<K, V> {
     88         let entry = self.map.entry(key.clone()).or_insert_with(|| {
     89             let (sender, _) = watch::channel(None);
     90             sender
     91         });
     92         Listener {
     93             map: self.map.clone(),
     94             channel: entry.subscribe(),
     95             key,
     96         }
     97     }
     98 }
     99 
    100 #[tokio::test]
    101 async fn channel_gc() {
    102     use std::time::Duration;
    103 
    104     let channel = NotificationChannel::default();
    105     assert_eq!(0, channel.map.len());
    106 
    107     // Clean in future
    108     let listener = channel.listener("test");
    109     assert_eq!(1, channel.map.len());
    110     tokio::time::timeout(Duration::from_millis(0), listener.wait_for(|it| it == 42))
    111         .await
    112         .unwrap_err();
    113     assert_eq!(0, channel.map.len());
    114 
    115     // Clean on drop
    116     let first = channel.listener("test");
    117     let second = channel.listener("test");
    118     assert_eq!(1, channel.map.len());
    119     tokio::time::timeout(Duration::from_millis(0), first.wait_for(|it| it == 42))
    120         .await
    121         .unwrap_err();
    122     assert_eq!(1, channel.map.len());
    123     drop(second);
    124     assert_eq!(0, channel.map.len());
    125 }
    126 
    127 #[tokio::test]
    128 async fn wake() {
    129     let channel = NotificationChannel::default();
    130     let listener = channel.listener("test");
    131     let task = tokio::spawn(listener.wait_for(|it| *it == 42));
    132     channel.map.entry("test").and_modify(|it| {
    133         it.send(Some(42)).unwrap();
    134     });
    135     task.await.unwrap();
    136 }