server.rs (7210B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2025, 2026 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 use std::{fmt::Debug, io::Write, pin::Pin}; 18 19 use axum::{ 20 Router, 21 body::{Body, Bytes}, 22 extract::Query, 23 http::{ 24 HeaderMap, HeaderValue, Method, StatusCode, Uri, 25 header::{self, AsHeaderName, IntoHeaderName}, 26 }, 27 }; 28 use flate2::{Compression, write::ZlibEncoder}; 29 use http_body_util::BodyExt as _; 30 use serde::{Deserialize, Serialize, de::DeserializeOwned}; 31 use taler_common::{api_common::ErrorDetail, error_code::ErrorCode}; 32 use tower::ServiceExt as _; 33 use tracing::warn; 34 use url::Url; 35 36 pub trait TestServer { 37 fn method(&self, method: Method, path: &str) -> TestRequest; 38 39 fn get(&self, path: &str) -> TestRequest { 40 self.method(Method::GET, path) 41 } 42 43 fn post(&self, path: &str) -> TestRequest { 44 self.method(Method::POST, path) 45 } 46 47 fn delete(&self, path: &str) -> TestRequest { 48 self.method(Method::DELETE, path) 49 } 50 } 51 52 impl TestServer for Router { 53 fn method(&self, method: Method, path: &str) -> TestRequest { 54 let url = format!("https://example{path}"); 55 TestRequest { 56 router: self.clone(), 57 method, 58 url: url.parse().unwrap(), 59 body: None, 60 headers: HeaderMap::new(), 61 } 62 } 63 } 64 65 pub struct TestRequest { 66 router: Router, 67 method: Method, 68 url: Url, 69 body: Option<Vec<u8>>, 70 headers: HeaderMap, 71 } 72 73 impl TestRequest { 74 #[track_caller] 75 pub fn query<T: Serialize>(mut self, k: &str, v: T) -> Self { 76 let mut pairs = self.url.query_pairs_mut(); 77 let serializer = serde_urlencoded::Serializer::new(&mut pairs); 78 [(k, v)].serialize(serializer).unwrap(); 79 drop(pairs); 80 self 81 } 82 83 pub fn json<T: Serialize>(mut self, body: &T) -> Self { 84 assert!(self.body.is_none()); 85 let bytes = serde_json::to_vec(body).unwrap(); 86 self.body = Some(bytes); 87 self.headers.insert( 88 header::CONTENT_TYPE, 89 HeaderValue::from_static("application/json"), 90 ); 91 self 92 } 93 94 pub fn deflate(mut self) -> Self { 95 let body = self.body.unwrap(); 96 let mut encoder = ZlibEncoder::new(Vec::with_capacity(body.len() / 4), Compression::fast()); 97 encoder.write_all(&body).unwrap(); 98 let compressed = encoder.finish().unwrap(); 99 self.body = Some(compressed); 100 self.headers.insert( 101 header::CONTENT_ENCODING, 102 HeaderValue::from_static("deflate"), 103 ); 104 self 105 } 106 107 pub fn remove(mut self, k: impl AsHeaderName) -> Self { 108 self.headers.remove(k); 109 self 110 } 111 112 pub fn header<V>(mut self, k: impl IntoHeaderName, v: V) -> Self 113 where 114 V: TryInto<HeaderValue>, 115 V::Error: Debug, 116 { 117 self.headers.insert(k, v.try_into().unwrap()); 118 self 119 } 120 121 async fn send(self) -> TestResponse { 122 let TestRequest { 123 router, 124 method, 125 url: uri, 126 body, 127 headers, 128 } = self; 129 let uri = Uri::try_from(uri.as_str()).unwrap(); 130 let mut req = axum::http::request::Builder::new() 131 .method(&method) 132 .uri(&uri) 133 .body(body.map(Body::from).unwrap_or_else(Body::empty)) 134 .unwrap(); 135 *req.headers_mut() = headers; 136 let resp = router.oneshot(req).await.unwrap(); 137 let (parts, body) = resp.into_parts(); 138 let bytes = body.collect().await.unwrap(); 139 TestResponse { 140 bytes: bytes.to_bytes(), 141 method, 142 status: parts.status, 143 uri, 144 } 145 } 146 } 147 148 impl IntoFuture for TestRequest { 149 type Output = TestResponse; 150 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output>>>; 151 152 fn into_future(self) -> Self::IntoFuture { 153 Box::pin(self.send()) 154 } 155 } 156 157 #[must_use] 158 pub struct TestResponse { 159 bytes: Bytes, 160 method: Method, 161 uri: Uri, 162 pub status: StatusCode, 163 } 164 165 impl TestResponse { 166 #[track_caller] 167 pub fn is_implemented(&self) -> bool { 168 if self.status == StatusCode::NOT_IMPLEMENTED { 169 let err: ErrorDetail = self.json_parse(); 170 warn!( 171 "{} is not implemented: {}", 172 self.uri.path(), 173 err.hint.unwrap_or_default() 174 ); 175 false 176 } else { 177 true 178 } 179 } 180 181 #[track_caller] 182 pub fn json_parse<'de, T: Deserialize<'de>>(&'de self) -> T { 183 let TestResponse { 184 status, 185 bytes, 186 method, 187 uri, 188 } = self; 189 match serde_json::from_slice(bytes) { 190 Ok(body) => body, 191 Err(err) => panic!( 192 "{method} {uri} {status} invalid JSON body: {err}\n{}", 193 String::from_utf8_lossy(bytes) 194 ), 195 } 196 } 197 198 #[track_caller] 199 pub fn assert_status(&self, expected: StatusCode) { 200 let TestResponse { 201 status, 202 bytes, 203 method, 204 uri, 205 } = self; 206 if expected != *status { 207 if status.is_success() || bytes.is_empty() { 208 panic!("{method} {uri} expected {expected} got {status}"); 209 } else { 210 let err: ErrorDetail = self.json_parse(); 211 let description = err.hint.unwrap_or_default(); 212 panic!( 213 "{method} {uri} expected {expected} got {status}: {} {description}", 214 err.code 215 ); 216 } 217 } 218 } 219 220 #[track_caller] 221 pub fn assert_ok_json<'de, T: Deserialize<'de>>(&'de self) -> T { 222 self.assert_ok(); 223 self.json_parse() 224 } 225 226 #[track_caller] 227 pub fn assert_ok(&self) { 228 self.assert_status(StatusCode::OK); 229 } 230 231 #[track_caller] 232 pub fn assert_no_content(&self) { 233 self.assert_status(StatusCode::NO_CONTENT); 234 } 235 236 #[track_caller] 237 pub fn assert_error(&self, error_code: ErrorCode) { 238 let (status_code, _) = error_code.metadata(); 239 self.assert_error_status(error_code, StatusCode::from_u16(status_code).unwrap()); 240 } 241 242 #[track_caller] 243 pub fn assert_error_status(&self, error_code: ErrorCode, status: StatusCode) { 244 self.assert_status(status); 245 let err: ErrorDetail = self.json_parse(); 246 assert_eq!(error_code as u32, err.code); 247 } 248 249 #[track_caller] 250 pub fn query<T: DeserializeOwned>(&self) -> T { 251 Query::try_from_uri(&self.uri).unwrap().0 252 } 253 }