taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

server.rs (7534B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2025 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use std::{fmt::Debug, pin::Pin};
     18 
     19 use axum::{
     20     Router,
     21     body::{Body, Bytes},
     22     extract::Query,
     23     http::{
     24         HeaderMap, HeaderValue, Method, StatusCode, Uri,
     25         header::{self, AsHeaderName, IntoHeaderName},
     26     },
     27 };
     28 use http_body_util::BodyExt as _;
     29 use libdeflater::CompressionLvl;
     30 use serde::{Deserialize, Serialize, de::DeserializeOwned};
     31 use taler_common::{api_common::ErrorDetail, error_code::ErrorCode};
     32 use tower::ServiceExt as _;
     33 use tracing::warn;
     34 use url::Url;
     35 
     36 pub trait TestServer {
     37     fn method(&self, method: Method, path: &str) -> TestRequest;
     38 
     39     fn get(&self, path: &str) -> TestRequest {
     40         self.method(Method::GET, path)
     41     }
     42 
     43     fn post(&self, path: &str) -> TestRequest {
     44         self.method(Method::POST, path)
     45     }
     46 }
     47 
     48 impl TestServer for Router {
     49     fn method(&self, method: Method, path: &str) -> TestRequest {
     50         let url = format!("https://example{path}");
     51         TestRequest {
     52             router: self.clone(),
     53             method,
     54             url: url.parse().unwrap(),
     55             body: None,
     56             headers: HeaderMap::new(),
     57         }
     58     }
     59 }
     60 
     61 pub struct TestRequest {
     62     router: Router,
     63     method: Method,
     64     url: Url,
     65     body: Option<Vec<u8>>,
     66     headers: HeaderMap,
     67 }
     68 
     69 impl TestRequest {
     70     #[track_caller]
     71     pub fn query<T: Serialize>(mut self, k: &str, v: T) -> Self {
     72         let mut pairs = self.url.query_pairs_mut();
     73         let serializer = serde_urlencoded::Serializer::new(&mut pairs);
     74         [(k, v)].serialize(serializer).unwrap();
     75         drop(pairs);
     76         self
     77     }
     78 
     79     pub fn json<T: Serialize>(mut self, body: &T) -> Self {
     80         assert!(self.body.is_none());
     81         let bytes = serde_json::to_vec(body).unwrap();
     82         self.body = Some(bytes);
     83         self.headers.insert(
     84             header::CONTENT_TYPE,
     85             HeaderValue::from_static("application/json"),
     86         );
     87         self
     88     }
     89 
     90     pub fn deflate(mut self) -> Self {
     91         let body = self.body.unwrap();
     92         let mut compressor = libdeflater::Compressor::new(CompressionLvl::fastest());
     93         let mut compressed = vec![0; compressor.zlib_compress_bound(body.len())];
     94         let nb = compressor.zlib_compress(&body, &mut compressed).unwrap();
     95         compressed.truncate(nb);
     96         self.body = Some(compressed);
     97         self.headers.insert(
     98             header::CONTENT_ENCODING,
     99             HeaderValue::from_static("deflate"),
    100         );
    101         self
    102     }
    103 
    104     pub fn remove(mut self, k: impl AsHeaderName) -> Self {
    105         self.headers.remove(k);
    106         self
    107     }
    108 
    109     pub fn header<V>(mut self, k: impl IntoHeaderName, v: V) -> Self
    110     where
    111         V: TryInto<HeaderValue>,
    112         V::Error: Debug,
    113     {
    114         self.headers.insert(k, v.try_into().unwrap());
    115         self
    116     }
    117 
    118     async fn send(self) -> TestResponse {
    119         let TestRequest {
    120             router,
    121             method,
    122             url: uri,
    123             body,
    124             headers,
    125         } = self;
    126         let uri = Uri::try_from(uri.as_str()).unwrap();
    127         let mut builder = axum::http::request::Builder::new()
    128             .method(&method)
    129             .uri(&uri);
    130         for (k, v) in headers {
    131             if let Some(k) = k {
    132                 builder = builder.header(k, v);
    133             } else {
    134                 builder = builder.header("", v);
    135             }
    136         }
    137 
    138         let resp = router
    139             .oneshot(builder.body(Body::from(body.unwrap_or_default())).unwrap())
    140             .await
    141             .unwrap();
    142         let (parts, body) = resp.into_parts();
    143         let bytes = body.collect().await.unwrap();
    144         TestResponse {
    145             bytes: bytes.to_bytes(),
    146             method,
    147             status: parts.status,
    148             uri,
    149         }
    150     }
    151 }
    152 
    153 impl IntoFuture for TestRequest {
    154     type Output = TestResponse;
    155     type IntoFuture = Pin<Box<dyn Future<Output = Self::Output>>>;
    156 
    157     fn into_future(self) -> Self::IntoFuture {
    158         Box::pin(self.send())
    159     }
    160 }
    161 
    162 pub struct TestResponse {
    163     bytes: Bytes,
    164     method: Method,
    165     uri: Uri,
    166     status: StatusCode,
    167 }
    168 
    169 impl TestResponse {
    170     #[track_caller]
    171     pub fn is_implemented(&self) -> bool {
    172         if self.status == StatusCode::NOT_IMPLEMENTED {
    173             let err: ErrorDetail = self.json_parse();
    174             warn!(
    175                 "{} is not implemented: {}",
    176                 self.uri.path(),
    177                 err.hint.unwrap_or_default()
    178             );
    179             false
    180         } else {
    181             true
    182         }
    183     }
    184 
    185     #[track_caller]
    186     pub fn json_parse<'de, T: Deserialize<'de>>(&'de self) -> T {
    187         let TestResponse {
    188             status,
    189             bytes,
    190             method,
    191             uri,
    192         } = self;
    193         match serde_json::from_slice(bytes) {
    194             Ok(body) => body,
    195             Err(err) => match serde_json::from_slice::<serde_json::Value>(bytes) {
    196                 Ok(raw) => panic!("{method} {uri} {status} invalid JSON schema: {err}\n{raw}"),
    197                 Err(err) => panic!(
    198                     "{method} {uri} {status} invalid JSON body: {err}\n{}",
    199                     String::from_utf8_lossy(bytes)
    200                 ),
    201             },
    202         }
    203     }
    204 
    205     #[track_caller]
    206     pub fn assert_status(&self, expected: StatusCode) {
    207         let TestResponse {
    208             status,
    209             bytes,
    210             method,
    211             uri,
    212         } = self;
    213         if expected != *status {
    214             if status.is_success() || bytes.is_empty() {
    215                 panic!("{method} {uri} expected {expected} got {status}");
    216             } else {
    217                 let err: ErrorDetail = self.json_parse();
    218                 let description = err.hint.unwrap_or_default();
    219                 panic!(
    220                     "{method} {uri} expected {expected} got {status}: {} {description}",
    221                     err.code
    222                 );
    223             }
    224         }
    225     }
    226 
    227     #[track_caller]
    228     pub fn assert_ok_json<'de, T: Deserialize<'de>>(&'de self) -> T {
    229         self.assert_ok();
    230         self.json_parse()
    231     }
    232 
    233     #[track_caller]
    234     pub fn assert_ok(&self) {
    235         self.assert_status(StatusCode::OK);
    236     }
    237 
    238     #[track_caller]
    239     pub fn assert_no_content(&self) {
    240         self.assert_status(StatusCode::NO_CONTENT);
    241     }
    242 
    243     #[track_caller]
    244     pub fn assert_error(&self, error_code: ErrorCode) {
    245         let (status_code, _) = error_code.metadata();
    246         self.assert_error_status(error_code, StatusCode::from_u16(status_code).unwrap());
    247     }
    248 
    249     #[track_caller]
    250     pub fn assert_error_status(&self, error_code: ErrorCode, status: StatusCode) {
    251         self.assert_status(status);
    252         let err: ErrorDetail = self.json_parse();
    253         assert_eq!(error_code as u32, err.code);
    254     }
    255 
    256     #[track_caller]
    257     pub fn query<T: DeserializeOwned>(&self) -> T {
    258         Query::try_from_uri(&self.uri).unwrap().0
    259     }
    260 }