server.rs (7460B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2025, 2026 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 use std::{fmt::Debug, io::Write, pin::Pin}; 18 19 use axum::{ 20 Router, 21 body::{Body, Bytes}, 22 extract::Query, 23 http::{ 24 HeaderMap, HeaderValue, Method, StatusCode, Uri, 25 header::{self, AsHeaderName, IntoHeaderName}, 26 }, 27 }; 28 use flate2::{Compression, write::ZlibEncoder}; 29 use http_body_util::BodyExt as _; 30 use serde::{Deserialize, Serialize, de::DeserializeOwned}; 31 use taler_common::{api_common::ErrorDetail, error_code::ErrorCode}; 32 use tower::ServiceExt as _; 33 use tracing::warn; 34 use url::Url; 35 36 pub trait TestServer { 37 fn method(&self, method: Method, path: &str) -> TestRequest; 38 39 fn get(&self, path: &str) -> TestRequest { 40 self.method(Method::GET, path) 41 } 42 43 fn post(&self, path: &str) -> TestRequest { 44 self.method(Method::POST, path) 45 } 46 } 47 48 impl TestServer for Router { 49 fn method(&self, method: Method, path: &str) -> TestRequest { 50 let url = format!("https://example{path}"); 51 TestRequest { 52 router: self.clone(), 53 method, 54 url: url.parse().unwrap(), 55 body: None, 56 headers: HeaderMap::new(), 57 } 58 } 59 } 60 61 pub struct TestRequest { 62 router: Router, 63 method: Method, 64 url: Url, 65 body: Option<Vec<u8>>, 66 headers: HeaderMap, 67 } 68 69 impl TestRequest { 70 #[track_caller] 71 pub fn query<T: Serialize>(mut self, k: &str, v: T) -> Self { 72 let mut pairs = self.url.query_pairs_mut(); 73 let serializer = serde_urlencoded::Serializer::new(&mut pairs); 74 [(k, v)].serialize(serializer).unwrap(); 75 drop(pairs); 76 self 77 } 78 79 pub fn json<T: Serialize>(mut self, body: &T) -> Self { 80 assert!(self.body.is_none()); 81 let bytes = serde_json::to_vec(body).unwrap(); 82 self.body = Some(bytes); 83 self.headers.insert( 84 header::CONTENT_TYPE, 85 HeaderValue::from_static("application/json"), 86 ); 87 self 88 } 89 90 pub fn deflate(mut self) -> Self { 91 let body = self.body.unwrap(); 92 let mut encoder = ZlibEncoder::new(Vec::new(), Compression::fast()); 93 encoder.write_all(&body).unwrap(); 94 let compressed = encoder.finish().unwrap(); 95 self.body = Some(compressed); 96 self.headers.insert( 97 header::CONTENT_ENCODING, 98 HeaderValue::from_static("deflate"), 99 ); 100 self 101 } 102 103 pub fn remove(mut self, k: impl AsHeaderName) -> Self { 104 self.headers.remove(k); 105 self 106 } 107 108 pub fn header<V>(mut self, k: impl IntoHeaderName, v: V) -> Self 109 where 110 V: TryInto<HeaderValue>, 111 V::Error: Debug, 112 { 113 self.headers.insert(k, v.try_into().unwrap()); 114 self 115 } 116 117 async fn send(self) -> TestResponse { 118 let TestRequest { 119 router, 120 method, 121 url: uri, 122 body, 123 headers, 124 } = self; 125 let uri = Uri::try_from(uri.as_str()).unwrap(); 126 let mut builder = axum::http::request::Builder::new() 127 .method(&method) 128 .uri(&uri); 129 for (k, v) in headers { 130 if let Some(k) = k { 131 builder = builder.header(k, v); 132 } else { 133 builder = builder.header("", v); 134 } 135 } 136 137 let resp = router 138 .oneshot(builder.body(Body::from(body.unwrap_or_default())).unwrap()) 139 .await 140 .unwrap(); 141 let (parts, body) = resp.into_parts(); 142 let bytes = body.collect().await.unwrap(); 143 TestResponse { 144 bytes: bytes.to_bytes(), 145 method, 146 status: parts.status, 147 uri, 148 } 149 } 150 } 151 152 impl IntoFuture for TestRequest { 153 type Output = TestResponse; 154 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output>>>; 155 156 fn into_future(self) -> Self::IntoFuture { 157 Box::pin(self.send()) 158 } 159 } 160 161 pub struct TestResponse { 162 bytes: Bytes, 163 method: Method, 164 uri: Uri, 165 status: StatusCode, 166 } 167 168 impl TestResponse { 169 #[track_caller] 170 pub fn is_implemented(&self) -> bool { 171 if self.status == StatusCode::NOT_IMPLEMENTED { 172 let err: ErrorDetail = self.json_parse(); 173 warn!( 174 "{} is not implemented: {}", 175 self.uri.path(), 176 err.hint.unwrap_or_default() 177 ); 178 false 179 } else { 180 true 181 } 182 } 183 184 #[track_caller] 185 pub fn json_parse<'de, T: Deserialize<'de>>(&'de self) -> T { 186 let TestResponse { 187 status, 188 bytes, 189 method, 190 uri, 191 } = self; 192 match serde_json::from_slice(bytes) { 193 Ok(body) => body, 194 Err(err) => match serde_json::from_slice::<serde_json::Value>(bytes) { 195 Ok(raw) => panic!("{method} {uri} {status} invalid JSON schema: {err}\n{raw}"), 196 Err(err) => panic!( 197 "{method} {uri} {status} invalid JSON body: {err}\n{}", 198 String::from_utf8_lossy(bytes) 199 ), 200 }, 201 } 202 } 203 204 #[track_caller] 205 pub fn assert_status(&self, expected: StatusCode) { 206 let TestResponse { 207 status, 208 bytes, 209 method, 210 uri, 211 } = self; 212 if expected != *status { 213 if status.is_success() || bytes.is_empty() { 214 panic!("{method} {uri} expected {expected} got {status}"); 215 } else { 216 let err: ErrorDetail = self.json_parse(); 217 let description = err.hint.unwrap_or_default(); 218 panic!( 219 "{method} {uri} expected {expected} got {status}: {} {description}", 220 err.code 221 ); 222 } 223 } 224 } 225 226 #[track_caller] 227 pub fn assert_ok_json<'de, T: Deserialize<'de>>(&'de self) -> T { 228 self.assert_ok(); 229 self.json_parse() 230 } 231 232 #[track_caller] 233 pub fn assert_ok(&self) { 234 self.assert_status(StatusCode::OK); 235 } 236 237 #[track_caller] 238 pub fn assert_no_content(&self) { 239 self.assert_status(StatusCode::NO_CONTENT); 240 } 241 242 #[track_caller] 243 pub fn assert_error(&self, error_code: ErrorCode) { 244 let (status_code, _) = error_code.metadata(); 245 self.assert_error_status(error_code, StatusCode::from_u16(status_code).unwrap()); 246 } 247 248 #[track_caller] 249 pub fn assert_error_status(&self, error_code: ErrorCode, status: StatusCode) { 250 self.assert_status(status); 251 let err: ErrorDetail = self.json_parse(); 252 assert_eq!(error_code as u32, err.code); 253 } 254 255 #[track_caller] 256 pub fn query<T: DeserializeOwned>(&self) -> T { 257 Query::try_from_uri(&self.uri).unwrap().0 258 } 259 }