notification.rs (4367B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2024-2025 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 use std::hash::Hash; 18 use std::sync::Arc; 19 20 use dashmap::DashMap; 21 use tokio::sync::watch; 22 23 pub mod de; 24 25 /// Listen for many postgres notification channels using a single connection 26 #[macro_export] 27 macro_rules! notification_listener { 28 ($pool: expr, $($channel:expr => ($($arg:ident: $type:ty),*) $lambda:block),*$(,)?) => { 29 { 30 let mut listener = ::sqlx::postgres::PgListener::connect_with($pool).await?; 31 listener.listen_all([$($channel,)*]).await?; 32 loop { 33 while let Some(notification) = listener.try_recv().await? { 34 tracing::debug!(target: "db-watcher", 35 "db notification: {} - {}", 36 notification.channel(), 37 notification.payload() 38 ); 39 match notification.channel() { 40 $($channel => { 41 let ($($arg,)*): ($($type,)*) = 42 ::taler_api::notification::de::from_str(notification.payload()).unwrap();// TODO error handling 43 $lambda 44 }),* 45 unknown => unreachable!("{}", unknown), 46 } 47 } 48 // TODO wait before reconnect 49 } 50 } 51 52 } 53 } 54 55 pub use notification_listener; 56 57 type CountedNotify<T> = watch::Sender<Option<T>>; 58 59 #[derive(Default)] 60 pub struct NotificationChannel<K: Eq + Hash, V> { 61 map: Arc<DashMap<K, CountedNotify<V>>>, 62 } 63 64 pub struct Listener<K: Eq + Hash + Clone, V> { 65 map: Arc<DashMap<K, CountedNotify<V>>>, 66 channel: watch::Receiver<Option<V>>, 67 key: K, 68 } 69 70 impl<K: Eq + Hash + Clone, V> Listener<K, V> { 71 pub async fn wait_for(mut self, filter: impl Fn(&V) -> bool) { 72 self.channel 73 .wait_for(|it| it.as_ref().map(&filter).unwrap_or(false)) 74 .await 75 .ok(); // If the channel is closed we cannot wait efficiently 76 } 77 } 78 79 impl<K: Eq + Hash + Clone, V> Drop for Listener<K, V> { 80 fn drop(&mut self) { 81 self.map 82 .remove_if(&self.key, |_, it| it.receiver_count() == 1); 83 } 84 } 85 86 impl<K: Eq + Hash + Clone, V> NotificationChannel<K, V> { 87 pub fn listener(&self, key: K) -> Listener<K, V> { 88 let entry = self.map.entry(key.clone()).or_insert_with(|| { 89 let (sender, _) = watch::channel(None); 90 sender 91 }); 92 Listener { 93 map: self.map.clone(), 94 channel: entry.subscribe(), 95 key, 96 } 97 } 98 } 99 100 #[tokio::test] 101 async fn channel_gc() { 102 use std::time::Duration; 103 104 let channel = NotificationChannel::default(); 105 assert_eq!(0, channel.map.len()); 106 107 // Clean in future 108 let listener = channel.listener("test"); 109 assert_eq!(1, channel.map.len()); 110 tokio::time::timeout(Duration::from_millis(0), listener.wait_for(|it| it == 42)) 111 .await 112 .unwrap_err(); 113 assert_eq!(0, channel.map.len()); 114 115 // Clean on drop 116 let first = channel.listener("test"); 117 let second = channel.listener("test"); 118 assert_eq!(1, channel.map.len()); 119 tokio::time::timeout(Duration::from_millis(0), first.wait_for(|it| it == 42)) 120 .await 121 .unwrap_err(); 122 assert_eq!(1, channel.map.len()); 123 drop(second); 124 assert_eq!(0, channel.map.len()); 125 } 126 127 #[tokio::test] 128 async fn wake() { 129 let channel = NotificationChannel::default(); 130 let listener = channel.listener("test"); 131 let task = tokio::spawn(listener.wait_for(|it| *it == 42)); 132 channel.map.entry("test").and_modify(|it| { 133 it.send(Some(42)).unwrap(); 134 }); 135 task.await.unwrap(); 136 }