server.rs (7534B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2025 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 use std::{fmt::Debug, pin::Pin}; 18 19 use axum::{ 20 Router, 21 body::{Body, Bytes}, 22 extract::Query, 23 http::{ 24 HeaderMap, HeaderValue, Method, StatusCode, Uri, 25 header::{self, AsHeaderName, IntoHeaderName}, 26 }, 27 }; 28 use http_body_util::BodyExt as _; 29 use libdeflater::CompressionLvl; 30 use serde::{Deserialize, Serialize, de::DeserializeOwned}; 31 use taler_common::{api_common::ErrorDetail, error_code::ErrorCode}; 32 use tower::ServiceExt as _; 33 use tracing::warn; 34 use url::Url; 35 36 pub trait TestServer { 37 fn method(&self, method: Method, path: &str) -> TestRequest; 38 39 fn get(&self, path: &str) -> TestRequest { 40 self.method(Method::GET, path) 41 } 42 43 fn post(&self, path: &str) -> TestRequest { 44 self.method(Method::POST, path) 45 } 46 } 47 48 impl TestServer for Router { 49 fn method(&self, method: Method, path: &str) -> TestRequest { 50 let url = format!("https://example{path}"); 51 TestRequest { 52 router: self.clone(), 53 method, 54 url: url.parse().unwrap(), 55 body: None, 56 headers: HeaderMap::new(), 57 } 58 } 59 } 60 61 pub struct TestRequest { 62 router: Router, 63 method: Method, 64 url: Url, 65 body: Option<Vec<u8>>, 66 headers: HeaderMap, 67 } 68 69 impl TestRequest { 70 #[track_caller] 71 pub fn query<T: Serialize>(mut self, k: &str, v: T) -> Self { 72 let mut pairs = self.url.query_pairs_mut(); 73 let serializer = serde_urlencoded::Serializer::new(&mut pairs); 74 [(k, v)].serialize(serializer).unwrap(); 75 drop(pairs); 76 self 77 } 78 79 pub fn json<T: Serialize>(mut self, body: &T) -> Self { 80 assert!(self.body.is_none()); 81 let bytes = serde_json::to_vec(body).unwrap(); 82 self.body = Some(bytes); 83 self.headers.insert( 84 header::CONTENT_TYPE, 85 HeaderValue::from_static("application/json"), 86 ); 87 self 88 } 89 90 pub fn deflate(mut self) -> Self { 91 let body = self.body.unwrap(); 92 let mut compressor = libdeflater::Compressor::new(CompressionLvl::fastest()); 93 let mut compressed = vec![0; compressor.zlib_compress_bound(body.len())]; 94 let nb = compressor.zlib_compress(&body, &mut compressed).unwrap(); 95 compressed.truncate(nb); 96 self.body = Some(compressed); 97 self.headers.insert( 98 header::CONTENT_ENCODING, 99 HeaderValue::from_static("deflate"), 100 ); 101 self 102 } 103 104 pub fn remove(mut self, k: impl AsHeaderName) -> Self { 105 self.headers.remove(k); 106 self 107 } 108 109 pub fn header<V>(mut self, k: impl IntoHeaderName, v: V) -> Self 110 where 111 V: TryInto<HeaderValue>, 112 V::Error: Debug, 113 { 114 self.headers.insert(k, v.try_into().unwrap()); 115 self 116 } 117 118 async fn send(self) -> TestResponse { 119 let TestRequest { 120 router, 121 method, 122 url: uri, 123 body, 124 headers, 125 } = self; 126 let uri = Uri::try_from(uri.as_str()).unwrap(); 127 let mut builder = axum::http::request::Builder::new() 128 .method(&method) 129 .uri(&uri); 130 for (k, v) in headers { 131 if let Some(k) = k { 132 builder = builder.header(k, v); 133 } else { 134 builder = builder.header("", v); 135 } 136 } 137 138 let resp = router 139 .oneshot(builder.body(Body::from(body.unwrap_or_default())).unwrap()) 140 .await 141 .unwrap(); 142 let (parts, body) = resp.into_parts(); 143 let bytes = body.collect().await.unwrap(); 144 TestResponse { 145 bytes: bytes.to_bytes(), 146 method, 147 status: parts.status, 148 uri, 149 } 150 } 151 } 152 153 impl IntoFuture for TestRequest { 154 type Output = TestResponse; 155 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output>>>; 156 157 fn into_future(self) -> Self::IntoFuture { 158 Box::pin(self.send()) 159 } 160 } 161 162 pub struct TestResponse { 163 bytes: Bytes, 164 method: Method, 165 uri: Uri, 166 status: StatusCode, 167 } 168 169 impl TestResponse { 170 #[track_caller] 171 pub fn is_implemented(&self) -> bool { 172 if self.status == StatusCode::NOT_IMPLEMENTED { 173 let err: ErrorDetail = self.json_parse(); 174 warn!( 175 "{} is not implemented: {}", 176 self.uri.path(), 177 err.hint.unwrap_or_default() 178 ); 179 false 180 } else { 181 true 182 } 183 } 184 185 #[track_caller] 186 pub fn json_parse<'de, T: Deserialize<'de>>(&'de self) -> T { 187 let TestResponse { 188 status, 189 bytes, 190 method, 191 uri, 192 } = self; 193 match serde_json::from_slice(bytes) { 194 Ok(body) => body, 195 Err(err) => match serde_json::from_slice::<serde_json::Value>(bytes) { 196 Ok(raw) => panic!("{method} {uri} {status} invalid JSON schema: {err}\n{raw}"), 197 Err(err) => panic!( 198 "{method} {uri} {status} invalid JSON body: {err}\n{}", 199 String::from_utf8_lossy(bytes) 200 ), 201 }, 202 } 203 } 204 205 #[track_caller] 206 pub fn assert_status(&self, expected: StatusCode) { 207 let TestResponse { 208 status, 209 bytes, 210 method, 211 uri, 212 } = self; 213 if expected != *status { 214 if status.is_success() || bytes.is_empty() { 215 panic!("{method} {uri} expected {expected} got {status}"); 216 } else { 217 let err: ErrorDetail = self.json_parse(); 218 let description = err.hint.unwrap_or_default(); 219 panic!( 220 "{method} {uri} expected {expected} got {status}: {} {description}", 221 err.code 222 ); 223 } 224 } 225 } 226 227 #[track_caller] 228 pub fn assert_ok_json<'de, T: Deserialize<'de>>(&'de self) -> T { 229 self.assert_ok(); 230 self.json_parse() 231 } 232 233 #[track_caller] 234 pub fn assert_ok(&self) { 235 self.assert_status(StatusCode::OK); 236 } 237 238 #[track_caller] 239 pub fn assert_no_content(&self) { 240 self.assert_status(StatusCode::NO_CONTENT); 241 } 242 243 #[track_caller] 244 pub fn assert_error(&self, error_code: ErrorCode) { 245 let (status_code, _) = error_code.metadata(); 246 self.assert_error_status(error_code, StatusCode::from_u16(status_code).unwrap()); 247 } 248 249 #[track_caller] 250 pub fn assert_error_status(&self, error_code: ErrorCode, status: StatusCode) { 251 self.assert_status(status); 252 let err: ErrorDetail = self.json_parse(); 253 assert_eq!(error_code as u32, err.code); 254 } 255 256 #[track_caller] 257 pub fn query<T: DeserializeOwned>(&self) -> T { 258 Query::try_from_uri(&self.uri).unwrap().0 259 } 260 }