taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

server.rs (10973B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2025, 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use std::{fmt::Debug, io::Write, pin::Pin};
     18 
     19 use axum::{
     20     Router,
     21     body::{Body, Bytes},
     22     extract::{Query, Request, State},
     23     http::{
     24         HeaderMap, HeaderValue, Method, StatusCode, Uri,
     25         header::{self, AUTHORIZATION, AsHeaderName, IntoHeaderName},
     26         uri::PathAndQuery,
     27     },
     28     middleware::{self},
     29 };
     30 use flate2::{Compression, write::ZlibEncoder};
     31 use http_body_util::BodyExt as _;
     32 use serde::{Deserialize, Serialize, de::DeserializeOwned};
     33 use taler_common::{api::ErrorDetail, encoding::base64, error_code::ErrorCode};
     34 use tower::ServiceExt as _;
     35 use tracing::warn;
     36 use url::Url;
     37 
     38 pub trait TestServer {
     39     fn prefix(&self, prefix: &'static str) -> Self;
     40     fn suffix(&self, suffix: &'static str) -> Self;
     41 
     42     fn request(&self, method: Method, path: impl AsRef<str>) -> TestRequest;
     43 
     44     fn get(&self, path: impl AsRef<str>) -> TestRequest {
     45         self.request(Method::GET, path)
     46     }
     47 
     48     fn post(&self, path: impl AsRef<str>) -> TestRequest {
     49         self.request(Method::POST, path)
     50     }
     51 
     52     fn delete(&self, path: impl AsRef<str>) -> TestRequest {
     53         self.request(Method::DELETE, path)
     54     }
     55 }
     56 
     57 impl TestServer for Router {
     58     fn prefix(&self, prefix: &'static str) -> Self {
     59         Router::new()
     60             .fallback_service(self.clone().into_service())
     61             .layer(middleware::map_request_with_state(
     62                 prefix,
     63                 async |State(prefix): State<&'static str>, mut req: Request| {
     64                     let uri = req.uri().clone();
     65                     let mut parts = uri.into_parts();
     66 
     67                     let path_and_query = parts.path_and_query.unwrap();
     68                     let current_path = path_and_query.path();
     69 
     70                     let new_path_and_query = match path_and_query.query() {
     71                         Some(query) => format!("{prefix}{current_path}?{query}"),
     72                         None => format!("{prefix}{current_path}"),
     73                     };
     74 
     75                     let new_pq = PathAndQuery::from_maybe_shared(new_path_and_query).unwrap();
     76                     parts.path_and_query = Some(new_pq);
     77                     *req.uri_mut() = Uri::from_parts(parts).unwrap();
     78                     req
     79                 },
     80             ))
     81     }
     82 
     83     fn suffix(&self, suffix: &'static str) -> Self {
     84         Router::new()
     85             .fallback_service(self.clone().into_service())
     86             .layer(middleware::map_request_with_state(
     87                 suffix.trim_start_matches('/'),
     88                 async |State(suffix): State<&'static str>, mut req: Request| {
     89                     let uri = req.uri().clone();
     90                     let mut parts = uri.into_parts();
     91 
     92                     let path_and_query = parts.path_and_query.unwrap();
     93                     let current_path = path_and_query.path();
     94 
     95                     let new_path_and_query = match path_and_query.query() {
     96                         Some(query) => format!("{current_path}{suffix}?{query}"),
     97                         None => format!("{current_path}{suffix}"),
     98                     };
     99                     let new_pq = PathAndQuery::from_maybe_shared(new_path_and_query).unwrap();
    100                     parts.path_and_query = Some(new_pq);
    101                     *req.uri_mut() = Uri::from_parts(parts).unwrap();
    102                     req
    103                 },
    104             ))
    105     }
    106 
    107     fn request(&self, method: Method, path: impl AsRef<str>) -> TestRequest {
    108         let url = format!("https://example{}", path.as_ref());
    109         TestRequest {
    110             router: self.clone(),
    111             method,
    112             url: url.parse().unwrap(),
    113             body: None,
    114             headers: HeaderMap::new(),
    115         }
    116     }
    117 }
    118 
    119 pub struct TestRequest {
    120     router: Router,
    121     method: Method,
    122     pub url: Url,
    123     body: Option<Bytes>,
    124     headers: HeaderMap,
    125 }
    126 
    127 impl TestRequest {
    128     #[track_caller]
    129     pub fn query<T: Serialize>(mut self, k: &str, v: T) -> Self {
    130         let mut pairs = self.url.query_pairs_mut();
    131         let serializer = serde_urlencoded::Serializer::new(&mut pairs);
    132         [(k, v)].serialize(serializer).unwrap();
    133         drop(pairs);
    134         self
    135     }
    136 
    137     pub fn json<T: Serialize>(mut self, body: T) -> Self {
    138         assert!(self.body.is_none());
    139         let bytes = serde_json::to_vec(&body).unwrap();
    140         self.body = Some(bytes.into());
    141         self.headers.insert(
    142             header::CONTENT_TYPE,
    143             HeaderValue::from_static("application/json"),
    144         );
    145         self
    146     }
    147 
    148     pub fn raw_json(mut self, raw: Bytes) -> Self {
    149         assert!(self.body.is_none());
    150         self.body = Some(raw);
    151         self.headers.insert(
    152             header::CONTENT_TYPE,
    153             HeaderValue::from_static("application/json"),
    154         );
    155         self
    156     }
    157 
    158     pub fn deflate(mut self) -> Self {
    159         let body = self.body.unwrap();
    160         let mut encoder = ZlibEncoder::new(Vec::with_capacity(body.len() / 4), Compression::fast());
    161         encoder.write_all(&body).unwrap();
    162         let compressed = encoder.finish().unwrap();
    163         self.body = Some(compressed.into());
    164         self.headers.insert(
    165             header::CONTENT_ENCODING,
    166             HeaderValue::from_static("deflate"),
    167         );
    168         self
    169     }
    170 
    171     pub fn remove(mut self, k: impl AsHeaderName) -> Self {
    172         self.headers.remove(k);
    173         self
    174     }
    175 
    176     pub fn header<V>(mut self, k: impl IntoHeaderName, v: V) -> Self
    177     where
    178         V: TryInto<HeaderValue>,
    179         V::Error: Debug,
    180     {
    181         self.headers.insert(k, v.try_into().unwrap());
    182         self
    183     }
    184 
    185     pub fn basic_auth(self, username: &str, password: &str) -> Self {
    186         self.header(
    187             AUTHORIZATION,
    188             format!(
    189                 "Basic {}",
    190                 base64::fmt(format!("{username}:{password}").as_bytes())
    191             ),
    192         )
    193     }
    194 
    195     async fn send(self) -> TestResponse {
    196         let Self {
    197             router,
    198             method,
    199             url,
    200             body: req_body,
    201             headers,
    202         } = self;
    203         let uri = Uri::try_from(url.as_str()).unwrap();
    204         let mut req = axum::http::request::Builder::new()
    205             .method(&method)
    206             .uri(&uri)
    207             .body(req_body.clone().map(Body::from).unwrap_or_else(Body::empty))
    208             .unwrap();
    209         *req.headers_mut() = headers;
    210         let resp = router.clone().oneshot(req).await.unwrap();
    211         let (parts, body) = resp.into_parts();
    212         let bytes = body.collect().await.unwrap();
    213         TestResponse {
    214             router,
    215             req_body: req_body.unwrap_or_default(),
    216             res_body: bytes.to_bytes(),
    217             method,
    218             status: parts.status,
    219             uri,
    220         }
    221     }
    222 }
    223 
    224 impl IntoFuture for TestRequest {
    225     type Output = TestResponse;
    226     type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
    227 
    228     fn into_future(self) -> Self::IntoFuture {
    229         Box::pin(self.send())
    230     }
    231 }
    232 
    233 #[must_use]
    234 pub struct TestResponse {
    235     pub router: Router,
    236     pub req_body: Bytes,
    237     res_body: Bytes,
    238     pub method: Method,
    239     pub uri: Uri,
    240     pub status: StatusCode,
    241 }
    242 
    243 impl TestResponse {
    244     #[track_caller]
    245     pub fn is_implemented(&self) -> bool {
    246         if self.status == StatusCode::NOT_IMPLEMENTED {
    247             let err: ErrorDetail = self.json_parse();
    248             warn!(
    249                 "{} is not implemented: {}",
    250                 self.uri.path(),
    251                 err.hint.unwrap_or_default()
    252             );
    253             false
    254         } else {
    255             true
    256         }
    257     }
    258 
    259     #[track_caller]
    260     pub fn json_parse<'de, T: Deserialize<'de>>(&'de self) -> T {
    261         let Self {
    262             status,
    263             res_body: bytes,
    264             method,
    265             uri,
    266             ..
    267         } = self;
    268         match serde_json::from_slice(bytes) {
    269             Ok(body) => body,
    270             Err(err) => panic!(
    271                 "{method} {uri} {status} invalid JSON body: {err}\n{}",
    272                 String::from_utf8_lossy(bytes)
    273             ),
    274         }
    275     }
    276 
    277     #[track_caller]
    278     pub fn assert_status(&self, expected: StatusCode) {
    279         let Self {
    280             status,
    281             res_body: bytes,
    282             method,
    283             uri,
    284             ..
    285         } = self;
    286         if expected != *status {
    287             if status.is_success() || bytes.is_empty() {
    288                 panic!("{method} {uri} expected {expected} got {status}");
    289             } else {
    290                 let err: ErrorDetail = self.json_parse();
    291                 let error: ErrorCode = ErrorCode::try_from(err.code).expect("Unknown error code");
    292                 let description = err.hint.unwrap_or_default();
    293                 panic!("{method} {uri} expected {expected} got {status}: {error} {description}");
    294             }
    295         }
    296     }
    297 
    298     #[track_caller]
    299     pub fn assert_ok_json<'de, T: Deserialize<'de>>(&'de self) -> T {
    300         self.assert_ok();
    301         self.json_parse()
    302     }
    303 
    304     #[track_caller]
    305     pub fn assert_accepted_json<'de, T: Deserialize<'de>>(&'de self) -> T {
    306         self.assert_accepted();
    307         self.json_parse()
    308     }
    309 
    310     #[track_caller]
    311     pub fn assert_ok(&self) {
    312         self.assert_status(StatusCode::OK);
    313     }
    314 
    315     #[track_caller]
    316     pub fn assert_accepted(&self) {
    317         self.assert_status(StatusCode::ACCEPTED);
    318     }
    319 
    320     #[track_caller]
    321     pub fn assert_no_content(&self) {
    322         self.assert_status(StatusCode::NO_CONTENT);
    323     }
    324 
    325     #[track_caller]
    326     pub fn assert_not_implemented(&self) {
    327         self.assert_status(StatusCode::NOT_IMPLEMENTED);
    328     }
    329 
    330     #[track_caller]
    331     pub fn assert_error(&self, error_code: ErrorCode) {
    332         self.assert_error_status(
    333             error_code,
    334             StatusCode::from_u16(error_code.status_code()).unwrap(),
    335         );
    336     }
    337 
    338     #[track_caller]
    339     pub fn assert_bad_request(&self) {
    340         self.assert_error(ErrorCode::GENERIC_JSON_INVALID);
    341     }
    342 
    343     #[track_caller]
    344     pub fn assert_error_status(&self, error_code: ErrorCode, status: StatusCode) {
    345         self.assert_status(status);
    346         let err: ErrorDetail = self.json_parse();
    347         assert_eq!(error_code, ErrorCode::try_from(err.code).unwrap());
    348     }
    349 
    350     #[track_caller]
    351     pub fn query<T: DeserializeOwned>(&self) -> T {
    352         Query::try_from_uri(&self.uri).unwrap().0
    353     }
    354 }