notification.rs (4487B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2024, 2025, 2026 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 use std::{hash::Hash, sync::Arc}; 18 19 use dashmap::DashMap; 20 use tokio::sync::watch::{self, Receiver}; 21 22 pub mod de; 23 24 /// Listen for many postgres notification channels using a single connection 25 #[macro_export] 26 macro_rules! notification_listener { 27 ($pool: expr, $($channel:expr => ($($arg:ident: $type:ty),*) $lambda:block),*$(,)?) => { 28 { 29 let mut listener = ::sqlx::postgres::PgListener::connect_with($pool).await?; 30 listener.listen_all([$($channel,)*]).await?; 31 loop { 32 while let Some(notification) = listener.try_recv().await? { 33 tracing::debug!(target: "db-watcher", 34 "db notification: {} - {}", 35 notification.channel(), 36 notification.payload() 37 ); 38 match notification.channel() { 39 $($channel => { 40 let ($($arg,)*): ($($type,)*) = 41 ::taler_api::notification::de::from_str(notification.payload()).unwrap();// TODO error handling 42 $lambda 43 }),* 44 unknown => unreachable!("{}", unknown), 45 } 46 } 47 // TODO wait before reconnect 48 } 49 } 50 51 } 52 } 53 54 pub use notification_listener; 55 56 type CountedNotify<T> = watch::Sender<Option<T>>; 57 58 #[derive(Default)] 59 pub struct NotificationChannel<K: Eq + Hash, V> { 60 map: Arc<DashMap<K, CountedNotify<V>>>, 61 } 62 63 pub struct Listener<K: Eq + Hash + Clone, V> { 64 map: Arc<DashMap<K, CountedNotify<V>>>, 65 channel: watch::Receiver<Option<V>>, 66 key: K, 67 } 68 69 impl<K: Eq + Hash + Clone, V> Listener<K, V> { 70 pub async fn wait_for(mut self, filter: impl Fn(&V) -> bool) { 71 self.channel 72 .wait_for(|it| it.as_ref().map(&filter).unwrap_or(false)) 73 .await 74 .ok(); // If the channel is closed we cannot wait efficiently 75 } 76 } 77 78 impl<K: Eq + Hash + Clone, V> Drop for Listener<K, V> { 79 fn drop(&mut self) { 80 self.map 81 .remove_if(&self.key, |_, it| it.receiver_count() == 1); 82 } 83 } 84 85 impl<K: Eq + Hash + Clone, V> NotificationChannel<K, V> { 86 pub fn listener(&self, key: K) -> Listener<K, V> { 87 let entry = self.map.entry(key.clone()).or_insert_with(|| { 88 let (sender, _) = watch::channel(None); 89 sender 90 }); 91 Listener { 92 map: self.map.clone(), 93 channel: entry.subscribe(), 94 key, 95 } 96 } 97 } 98 99 pub fn dummy_listen<T: Default>() -> Receiver<T> { 100 tokio::sync::watch::channel(T::default()).1 101 } 102 103 #[tokio::test] 104 async fn channel_gc() { 105 use std::time::Duration; 106 107 let channel = NotificationChannel::default(); 108 assert_eq!(0, channel.map.len()); 109 110 // Clean in future 111 let listener = channel.listener("test"); 112 assert_eq!(1, channel.map.len()); 113 tokio::time::timeout(Duration::from_millis(0), listener.wait_for(|it| it == 42)) 114 .await 115 .unwrap_err(); 116 assert_eq!(0, channel.map.len()); 117 118 // Clean on drop 119 let first = channel.listener("test"); 120 let second = channel.listener("test"); 121 assert_eq!(1, channel.map.len()); 122 tokio::time::timeout(Duration::from_millis(0), first.wait_for(|it| it == 42)) 123 .await 124 .unwrap_err(); 125 assert_eq!(1, channel.map.len()); 126 drop(second); 127 assert_eq!(0, channel.map.len()); 128 } 129 130 #[tokio::test] 131 async fn wake() { 132 let channel = NotificationChannel::default(); 133 let listener = channel.listener("test"); 134 let task = tokio::spawn(listener.wait_for(|it| *it == 42)); 135 channel.map.entry("test").and_modify(|it| { 136 it.send(Some(42)).unwrap(); 137 }); 138 task.await.unwrap(); 139 }