taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

sse.rs (7689B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use std::pin::Pin;
     18 
     19 use compact_str::CompactString;
     20 use futures_util::{Stream, StreamExt as _, stream};
     21 use tokio_util::{
     22     bytes::Bytes,
     23     codec::{FramedRead, LinesCodec, LinesCodecError},
     24     io::StreamReader,
     25 };
     26 use tracing::trace;
     27 
     28 #[derive(Debug, Default, PartialEq, Eq)]
     29 pub struct SseMessage {
     30     pub event: CompactString,
     31     pub data: String,
     32 }
     33 
     34 type SseStream = dyn Stream<Item = std::result::Result<String, LinesCodecError>> + Send;
     35 
     36 /// Server-sent event client
     37 pub struct SseClient {
     38     pub last_event_id: Option<CompactString>,
     39     pub reconnection_time: Option<u64>,
     40     stream: Pin<Box<SseStream>>,
     41 }
     42 
     43 impl SseClient {
     44     pub fn new() -> Self {
     45         Self {
     46             last_event_id: None,
     47             reconnection_time: None,
     48             stream: Box::pin(stream::empty()),
     49         }
     50     }
     51 
     52     pub fn connect<E: std::error::Error + Send + Sync + 'static>(
     53         &mut self,
     54         stream: impl Stream<Item = Result<Bytes, E>> + 'static + Send,
     55     ) {
     56         let stream = stream.map(|it| it.map_err(std::io::Error::other));
     57         let lines = FramedRead::new(StreamReader::new(stream), LinesCodec::new());
     58         self.stream = Box::pin(lines);
     59     }
     60 
     61     pub async fn next(&mut self) -> Option<Result<SseMessage, LinesCodecError>> {
     62         // TODO add tests
     63         let mut event = CompactString::new("message");
     64         let mut data = None::<String>;
     65         while let Some(res) = self.stream.next().await {
     66             let line = match res {
     67                 Ok(line) => line,
     68                 Err(e) => return Some(Err(e)),
     69             };
     70             // Parse line
     71             let (field, value): (&str, &str) = if line.is_empty() {
     72                 if let Some(data) = data.take() {
     73                     return Some(Ok(SseMessage { event, data }));
     74                 } else {
     75                     event = CompactString::new("message");
     76                     continue;
     77                 }
     78             } else if let Some(comment) = line.strip_prefix(':') {
     79                 trace!(target: "sse", "{comment}");
     80                 continue;
     81             } else if let Some((k, v)) = line.split_once(':') {
     82                 (k, v.strip_prefix(' ').unwrap_or(v))
     83             } else {
     84                 (&line, "")
     85             };
     86 
     87             // Process field
     88             match field {
     89                 "event" => event = CompactString::new(value),
     90                 "data" => match data.as_mut() {
     91                     Some(data) => {
     92                         data.push('\n');
     93                         data.push_str(value);
     94                     }
     95                     None => data = Some(value.to_string()),
     96                 },
     97                 "id" => {
     98                     if !value.contains('\0') {
     99                         self.last_event_id = Some(CompactString::new(value))
    100                     }
    101                 }
    102                 "retry" => {
    103                     if value.as_bytes().iter().all(|c| c.is_ascii_digit())
    104                         && let Ok(int) = value.parse::<u64>()
    105                     {
    106                         self.reconnection_time = Some(int)
    107                     }
    108                 }
    109                 _ => continue,
    110             }
    111         }
    112         None
    113     }
    114 }
    115 
    116 impl Default for SseClient {
    117     fn default() -> Self {
    118         Self::new()
    119     }
    120 }
    121 
    122 #[tokio::test]
    123 pub async fn protocol() {
    124     pub async fn test(
    125         stream: &'static str,
    126         result: &[(&str, &str)],
    127         last_event_id: Option<&str>,
    128         reconnection_time: Option<u64>,
    129     ) {
    130         let stream = stream::iter(
    131             stream
    132                 .as_bytes()
    133                 .chunks(12)
    134                 .map(|chunk| std::io::Result::Ok(Bytes::from_static(chunk))),
    135         );
    136         let mut client = SseClient::new();
    137         client.connect(stream);
    138         let mut res = Vec::new();
    139         while let Some(msg) = client.next().await {
    140             res.push(msg.unwrap());
    141         }
    142         assert_eq!(
    143             result,
    144             &res.iter()
    145                 .map(|m| (m.event.as_ref(), m.data.as_ref()))
    146                 .collect::<Vec<_>>()
    147         );
    148         assert_eq!(client.last_event_id.as_deref(), last_event_id);
    149         assert_eq!(client.reconnection_time, reconnection_time);
    150     }
    151 
    152     macro_rules! check {
    153         // Handle multiple tuples + optional id and retry
    154         ($stream:expr $(, ($e:expr, $d:expr))* $(, id: $id:expr)? $(, retry: $retry:expr)?) => {
    155             test(
    156                 $stream,
    157                 &[ $( ($e, $d) ),* ],
    158                 { let mut _id = None; $( _id = Some($id); )? _id },
    159                 { let mut _r = None; $( _r = Some($retry); )? _r }
    160             ).await
    161         };
    162     }
    163 
    164     check!("data\n\n", ("message", ""));
    165     check!("data:key:value\n\n", ("message", "key:value"));
    166     check!("data: value\n\n", ("message", "value"));
    167     check!("data:first\ndata:second\n\n", ("message", "first\nsecond"));
    168 
    169     check!(
    170         "event:first\nevent:second\ndata:test\n\n",
    171         ("second", "test")
    172     );
    173 
    174     check!("data:test\r\n\r\n", ("message", "test"));
    175     check!("data:test\r\r");
    176     check!("data:line1\r\ndata:line2\n\n", ("message", "line1\nline2"));
    177     check!("data:test\n");
    178     check!("data:test\n\n\n", ("message", "test"));
    179     check!("data:\ndata:\n\n", ("message", "\n"));
    180     check!("data:\ndata:content\ndata:\n\n", ("message", "\ncontent\n"));
    181     check!("data:Hello δΈ–η•Œ 🌍\n\n", ("message", "Hello δΈ–η•Œ 🌍"));
    182     check!("data:   \n\n", ("message", "  "));
    183     check!("id:123\ndata:test\n\n", ("message", "test"), id: "123");
    184     check!("id:first\nid:second\ndata:test\n\n", ("message", "test"), id: "second");
    185     check!("id:first\nid:second\nid:\ndata:test\n\n", ("message", "test"), id: "");
    186     check!("id:\ndata:test\n\n", ("message", "test"), id: "");
    187     check!(
    188         "id:test:123\ndata:test\n\n",
    189         ("message", "test"),
    190         id: "test:123"
    191     );
    192     check!("id:123\x00456\ndata:test\n\n", ("message", "test"));
    193     check!("id:123\n\n", id: "123");
    194     check!(
    195         "id:123\ndata:first\n\ndata:second\n\n",
    196         ("message", "first"), ("message", "second"),
    197         id: "123"
    198     );
    199     check!("event:customEvent\ndata:test\n\n", ("customEvent", "test"));
    200     check!("event:\ndata:test\n\n", ("", "test"));
    201     check!("event:my event\ndata:test\n\n", ("my event", "test"));
    202 
    203     check!("retry:3000\n\n", retry: 3000);
    204     check!("retry:0\n\n", retry: 0);
    205     check!("retry:abc\n\n");
    206     check!("retry:-1000\n\n");
    207     check!("retry:1000.5\n\n");
    208     check!("retry:1000\nretry:2000\n\n", retry: 2000);
    209 
    210     check!(":comment\n\n");
    211     check!(": comment\n\n");
    212     check!(":comment\ndata:test\n\n", ("message", "test"));
    213     check!("unknown:value\ndata:test\n\n", ("message", "test"));
    214     check!("datta:test\n\n");
    215     check!(" data:test\n\n");
    216     check!("data :test\n\n");
    217     check!("data:\tvalue\n\n", ("message", "\tvalue"));
    218     check!("id:123\n\n", id: "123");
    219     check!("event:test\n\n");
    220     check!("event:test\n\ndata:value\n\n", ("message", "value"));
    221 }