server.rs (10973B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2025, 2026 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 use std::{fmt::Debug, io::Write, pin::Pin}; 18 19 use axum::{ 20 Router, 21 body::{Body, Bytes}, 22 extract::{Query, Request, State}, 23 http::{ 24 HeaderMap, HeaderValue, Method, StatusCode, Uri, 25 header::{self, AUTHORIZATION, AsHeaderName, IntoHeaderName}, 26 uri::PathAndQuery, 27 }, 28 middleware::{self}, 29 }; 30 use flate2::{Compression, write::ZlibEncoder}; 31 use http_body_util::BodyExt as _; 32 use serde::{Deserialize, Serialize, de::DeserializeOwned}; 33 use taler_common::{api::ErrorDetail, encoding::base64, error_code::ErrorCode}; 34 use tower::ServiceExt as _; 35 use tracing::warn; 36 use url::Url; 37 38 pub trait TestServer { 39 fn prefix(&self, prefix: &'static str) -> Self; 40 fn suffix(&self, suffix: &'static str) -> Self; 41 42 fn request(&self, method: Method, path: impl AsRef<str>) -> TestRequest; 43 44 fn get(&self, path: impl AsRef<str>) -> TestRequest { 45 self.request(Method::GET, path) 46 } 47 48 fn post(&self, path: impl AsRef<str>) -> TestRequest { 49 self.request(Method::POST, path) 50 } 51 52 fn delete(&self, path: impl AsRef<str>) -> TestRequest { 53 self.request(Method::DELETE, path) 54 } 55 } 56 57 impl TestServer for Router { 58 fn prefix(&self, prefix: &'static str) -> Self { 59 Router::new() 60 .fallback_service(self.clone().into_service()) 61 .layer(middleware::map_request_with_state( 62 prefix, 63 async |State(prefix): State<&'static str>, mut req: Request| { 64 let uri = req.uri().clone(); 65 let mut parts = uri.into_parts(); 66 67 let path_and_query = parts.path_and_query.unwrap(); 68 let current_path = path_and_query.path(); 69 70 let new_path_and_query = match path_and_query.query() { 71 Some(query) => format!("{prefix}{current_path}?{query}"), 72 None => format!("{prefix}{current_path}"), 73 }; 74 75 let new_pq = PathAndQuery::from_maybe_shared(new_path_and_query).unwrap(); 76 parts.path_and_query = Some(new_pq); 77 *req.uri_mut() = Uri::from_parts(parts).unwrap(); 78 req 79 }, 80 )) 81 } 82 83 fn suffix(&self, suffix: &'static str) -> Self { 84 Router::new() 85 .fallback_service(self.clone().into_service()) 86 .layer(middleware::map_request_with_state( 87 suffix.trim_start_matches('/'), 88 async |State(suffix): State<&'static str>, mut req: Request| { 89 let uri = req.uri().clone(); 90 let mut parts = uri.into_parts(); 91 92 let path_and_query = parts.path_and_query.unwrap(); 93 let current_path = path_and_query.path(); 94 95 let new_path_and_query = match path_and_query.query() { 96 Some(query) => format!("{current_path}{suffix}?{query}"), 97 None => format!("{current_path}{suffix}"), 98 }; 99 let new_pq = PathAndQuery::from_maybe_shared(new_path_and_query).unwrap(); 100 parts.path_and_query = Some(new_pq); 101 *req.uri_mut() = Uri::from_parts(parts).unwrap(); 102 req 103 }, 104 )) 105 } 106 107 fn request(&self, method: Method, path: impl AsRef<str>) -> TestRequest { 108 let url = format!("https://example{}", path.as_ref()); 109 TestRequest { 110 router: self.clone(), 111 method, 112 url: url.parse().unwrap(), 113 body: None, 114 headers: HeaderMap::new(), 115 } 116 } 117 } 118 119 pub struct TestRequest { 120 router: Router, 121 method: Method, 122 pub url: Url, 123 body: Option<Bytes>, 124 headers: HeaderMap, 125 } 126 127 impl TestRequest { 128 #[track_caller] 129 pub fn query<T: Serialize>(mut self, k: &str, v: T) -> Self { 130 let mut pairs = self.url.query_pairs_mut(); 131 let serializer = serde_urlencoded::Serializer::new(&mut pairs); 132 [(k, v)].serialize(serializer).unwrap(); 133 drop(pairs); 134 self 135 } 136 137 pub fn json<T: Serialize>(mut self, body: T) -> Self { 138 assert!(self.body.is_none()); 139 let bytes = serde_json::to_vec(&body).unwrap(); 140 self.body = Some(bytes.into()); 141 self.headers.insert( 142 header::CONTENT_TYPE, 143 HeaderValue::from_static("application/json"), 144 ); 145 self 146 } 147 148 pub fn raw_json(mut self, raw: Bytes) -> Self { 149 assert!(self.body.is_none()); 150 self.body = Some(raw); 151 self.headers.insert( 152 header::CONTENT_TYPE, 153 HeaderValue::from_static("application/json"), 154 ); 155 self 156 } 157 158 pub fn deflate(mut self) -> Self { 159 let body = self.body.unwrap(); 160 let mut encoder = ZlibEncoder::new(Vec::with_capacity(body.len() / 4), Compression::fast()); 161 encoder.write_all(&body).unwrap(); 162 let compressed = encoder.finish().unwrap(); 163 self.body = Some(compressed.into()); 164 self.headers.insert( 165 header::CONTENT_ENCODING, 166 HeaderValue::from_static("deflate"), 167 ); 168 self 169 } 170 171 pub fn remove(mut self, k: impl AsHeaderName) -> Self { 172 self.headers.remove(k); 173 self 174 } 175 176 pub fn header<V>(mut self, k: impl IntoHeaderName, v: V) -> Self 177 where 178 V: TryInto<HeaderValue>, 179 V::Error: Debug, 180 { 181 self.headers.insert(k, v.try_into().unwrap()); 182 self 183 } 184 185 pub fn basic_auth(self, username: &str, password: &str) -> Self { 186 self.header( 187 AUTHORIZATION, 188 format!( 189 "Basic {}", 190 base64::fmt(format!("{username}:{password}").as_bytes()) 191 ), 192 ) 193 } 194 195 async fn send(self) -> TestResponse { 196 let Self { 197 router, 198 method, 199 url, 200 body: req_body, 201 headers, 202 } = self; 203 let uri = Uri::try_from(url.as_str()).unwrap(); 204 let mut req = axum::http::request::Builder::new() 205 .method(&method) 206 .uri(&uri) 207 .body(req_body.clone().map(Body::from).unwrap_or_else(Body::empty)) 208 .unwrap(); 209 *req.headers_mut() = headers; 210 let resp = router.clone().oneshot(req).await.unwrap(); 211 let (parts, body) = resp.into_parts(); 212 let bytes = body.collect().await.unwrap(); 213 TestResponse { 214 router, 215 req_body: req_body.unwrap_or_default(), 216 res_body: bytes.to_bytes(), 217 method, 218 status: parts.status, 219 uri, 220 } 221 } 222 } 223 224 impl IntoFuture for TestRequest { 225 type Output = TestResponse; 226 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>; 227 228 fn into_future(self) -> Self::IntoFuture { 229 Box::pin(self.send()) 230 } 231 } 232 233 #[must_use] 234 pub struct TestResponse { 235 pub router: Router, 236 pub req_body: Bytes, 237 res_body: Bytes, 238 pub method: Method, 239 pub uri: Uri, 240 pub status: StatusCode, 241 } 242 243 impl TestResponse { 244 #[track_caller] 245 pub fn is_implemented(&self) -> bool { 246 if self.status == StatusCode::NOT_IMPLEMENTED { 247 let err: ErrorDetail = self.json_parse(); 248 warn!( 249 "{} is not implemented: {}", 250 self.uri.path(), 251 err.hint.unwrap_or_default() 252 ); 253 false 254 } else { 255 true 256 } 257 } 258 259 #[track_caller] 260 pub fn json_parse<'de, T: Deserialize<'de>>(&'de self) -> T { 261 let Self { 262 status, 263 res_body: bytes, 264 method, 265 uri, 266 .. 267 } = self; 268 match serde_json::from_slice(bytes) { 269 Ok(body) => body, 270 Err(err) => panic!( 271 "{method} {uri} {status} invalid JSON body: {err}\n{}", 272 String::from_utf8_lossy(bytes) 273 ), 274 } 275 } 276 277 #[track_caller] 278 pub fn assert_status(&self, expected: StatusCode) { 279 let Self { 280 status, 281 res_body: bytes, 282 method, 283 uri, 284 .. 285 } = self; 286 if expected != *status { 287 if status.is_success() || bytes.is_empty() { 288 panic!("{method} {uri} expected {expected} got {status}"); 289 } else { 290 let err: ErrorDetail = self.json_parse(); 291 let error: ErrorCode = ErrorCode::try_from(err.code).expect("Unknown error code"); 292 let description = err.hint.unwrap_or_default(); 293 panic!("{method} {uri} expected {expected} got {status}: {error} {description}"); 294 } 295 } 296 } 297 298 #[track_caller] 299 pub fn assert_ok_json<'de, T: Deserialize<'de>>(&'de self) -> T { 300 self.assert_ok(); 301 self.json_parse() 302 } 303 304 #[track_caller] 305 pub fn assert_accepted_json<'de, T: Deserialize<'de>>(&'de self) -> T { 306 self.assert_accepted(); 307 self.json_parse() 308 } 309 310 #[track_caller] 311 pub fn assert_ok(&self) { 312 self.assert_status(StatusCode::OK); 313 } 314 315 #[track_caller] 316 pub fn assert_accepted(&self) { 317 self.assert_status(StatusCode::ACCEPTED); 318 } 319 320 #[track_caller] 321 pub fn assert_no_content(&self) { 322 self.assert_status(StatusCode::NO_CONTENT); 323 } 324 325 #[track_caller] 326 pub fn assert_not_implemented(&self) { 327 self.assert_status(StatusCode::NOT_IMPLEMENTED); 328 } 329 330 #[track_caller] 331 pub fn assert_error(&self, error_code: ErrorCode) { 332 self.assert_error_status( 333 error_code, 334 StatusCode::from_u16(error_code.status_code()).unwrap(), 335 ); 336 } 337 338 #[track_caller] 339 pub fn assert_bad_request(&self) { 340 self.assert_error(ErrorCode::GENERIC_JSON_INVALID); 341 } 342 343 #[track_caller] 344 pub fn assert_error_status(&self, error_code: ErrorCode, status: StatusCode) { 345 self.assert_status(status); 346 let err: ErrorDetail = self.json_parse(); 347 assert_eq!(error_code, ErrorCode::try_from(err.code).unwrap()); 348 } 349 350 #[track_caller] 351 pub fn query<T: DeserializeOwned>(&self) -> T { 352 Query::try_from_uri(&self.uri).unwrap().0 353 } 354 }