api.rs (7182B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2025 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 use std::{ 18 sync::{ 19 Arc, 20 atomic::{AtomicU32, Ordering}, 21 }, 22 time::Instant, 23 }; 24 25 use axum::{ 26 extract::{Request, State}, 27 middleware::{self, Next}, 28 response::Response, 29 }; 30 use revenue::Revenue; 31 use taler_common::{error_code::ErrorCode, types::amount::Amount}; 32 use tokio::signal; 33 use tracing::{Level, info}; 34 use wire::WireGateway; 35 36 use crate::{ 37 Listener, Serve, 38 auth::{AuthMethod, AuthMiddlewareState}, 39 error::{ApiResult, failure, failure_code}, 40 }; 41 42 pub mod revenue; 43 pub mod wire; 44 45 pub use axum::Router; 46 47 pub trait TalerApi: Send + Sync + 'static { 48 fn currency(&self) -> &str; 49 fn implementation(&self) -> Option<&str>; 50 fn check_currency(&self, amount: &Amount) -> ApiResult<()> { 51 let currency = self.currency(); 52 if amount.currency.as_ref() != currency { 53 Err(failure( 54 ErrorCode::GENERIC_CURRENCY_MISMATCH, 55 format!( 56 "wrong currency expected {} got {}", 57 currency, amount.currency 58 ), 59 )) 60 } else { 61 Ok(()) 62 } 63 } 64 } 65 66 pub trait RouterUtils { 67 fn auth(self, auth: AuthMethod, realm: &str) -> Self; 68 } 69 70 impl<S: Send + Clone + Sync + 'static> RouterUtils for Router<S> { 71 fn auth(self, auth: AuthMethod, realm: &str) -> Self { 72 self.route_layer(middleware::from_fn_with_state( 73 Arc::new(AuthMiddlewareState::new(auth, realm)), 74 crate::auth::auth_middleware, 75 )) 76 } 77 } 78 79 pub trait TalerRouter { 80 fn wire_gateway<T: WireGateway>(self, api: Arc<T>, auth: AuthMethod) -> Self; 81 fn revenue<T: Revenue>(self, api: Arc<T>, auth: AuthMethod) -> Self; 82 fn finalize(self) -> Self; 83 fn serve( 84 self, 85 serve: Serve, 86 lifetime: Option<u32>, 87 ) -> impl std::future::Future<Output = std::io::Result<()>> + Send; 88 } 89 90 impl TalerRouter for Router { 91 fn wire_gateway<T: WireGateway>(self, api: Arc<T>, auth: AuthMethod) -> Self { 92 self.nest("/taler-wire-gateway", wire::router(api, auth)) 93 } 94 95 fn revenue<T: Revenue>(self, api: Arc<T>, auth: AuthMethod) -> Self { 96 self.nest("/taler-revenue", revenue::router(api, auth)) 97 } 98 99 fn finalize(self) -> Router { 100 self.method_not_allowed_fallback(|| async { 101 failure_code(ErrorCode::GENERIC_METHOD_INVALID) 102 }) 103 .fallback(|| async { failure_code(ErrorCode::GENERIC_ENDPOINT_UNKNOWN) }) 104 .layer(middleware::from_fn(logger_middleware)) 105 } 106 107 async fn serve(mut self, serve: Serve, lifetime: Option<u32>) -> std::io::Result<()> { 108 let listener = serve.resolve()?; 109 110 let notify = Arc::new(tokio::sync::Notify::new()); 111 if let Some(lifetime) = lifetime { 112 self = self.layer(middleware::from_fn_with_state( 113 Arc::new(LifetimeMiddlewareState { 114 notify: notify.clone(), 115 lifetime: AtomicU32::new(lifetime), 116 }), 117 lifetime_middleware, 118 )) 119 } 120 let router = self.finalize(); 121 let signal = shutdown_signal(notify); 122 match listener { 123 Listener::Tcp(tcp_listener) => { 124 axum::serve(tcp_listener, router) 125 .with_graceful_shutdown(signal) 126 .await?; 127 } 128 Listener::Unix(unix_listener) => { 129 axum::serve(unix_listener, router) 130 .with_graceful_shutdown(signal) 131 .await?; 132 } 133 } 134 135 info!(target: "api", "Server stopped"); 136 Ok(()) 137 } 138 } 139 140 struct LifetimeMiddlewareState { 141 lifetime: AtomicU32, 142 notify: Arc<tokio::sync::Notify>, 143 } 144 145 async fn lifetime_middleware( 146 State(state): State<Arc<LifetimeMiddlewareState>>, 147 request: Request, 148 next: Next, 149 ) -> Response { 150 let mut current = state.lifetime.load(Ordering::Relaxed); 151 while current != 0 { 152 match state.lifetime.compare_exchange_weak( 153 current, 154 current - 1, 155 Ordering::Relaxed, 156 Ordering::Relaxed, 157 ) { 158 Ok(_) => break, 159 Err(new) => current = new, 160 } 161 } 162 if current == 0 { 163 state.notify.notify_one(); 164 } 165 next.run(request).await 166 } 167 168 /** Wait for manual shutdown or system signal shutdown */ 169 async fn shutdown_signal(manual_shutdown: Arc<tokio::sync::Notify>) { 170 let ctrl_c = async { 171 signal::ctrl_c() 172 .await 173 .expect("failed to install Ctrl+C handler"); 174 }; 175 176 #[cfg(unix)] 177 let terminate = async { 178 signal::unix::signal(signal::unix::SignalKind::terminate()) 179 .expect("failed to install signal handler") 180 .recv() 181 .await; 182 }; 183 184 #[cfg(not(unix))] 185 let terminate = std::future::pending::<()>(); 186 187 let manual = async { manual_shutdown.notified().await }; 188 189 tokio::select! { 190 _ = ctrl_c => {}, 191 _ = terminate => {}, 192 _ = manual => {} 193 } 194 } 195 196 #[macro_export] 197 macro_rules! dyn_event { 198 ($lvl:ident, $($arg:tt)+) => { 199 match $lvl { 200 ::tracing::Level::TRACE => ::tracing::trace!($($arg)+), 201 ::tracing::Level::DEBUG => ::tracing::debug!($($arg)+), 202 ::tracing::Level::INFO => ::tracing::info!($($arg)+), 203 ::tracing::Level::WARN => ::tracing::warn!($($arg)+), 204 ::tracing::Level::ERROR => ::tracing::error!($($arg)+), 205 } 206 }; 207 } 208 209 /** Taler API logger */ 210 async fn logger_middleware(request: Request, next: Next) -> Response { 211 let request_info = format!( 212 "{} {}", 213 request.method(), 214 request 215 .uri() 216 .path_and_query() 217 .map(|it| it.as_str()) 218 .unwrap_or_default() 219 ); 220 let now = Instant::now(); 221 let response = next.run(request).await; 222 let elapsed = now.elapsed(); 223 let status = response.status(); 224 let level = match status.as_u16() { 225 400..500 => Level::WARN, 226 500..600 => Level::ERROR, 227 _ => Level::INFO, 228 }; 229 230 if let Some(log) = response.extensions().get::<Box<str>>() { 231 dyn_event!(level, target: "api", 232 "{} {request_info} {}ms: {log}", 233 response.status(), 234 elapsed.as_millis() 235 ); 236 } else { 237 dyn_event!(level, target: "api", 238 "{} {request_info} {}ms", 239 response.status(), 240 elapsed.as_millis() 241 ); 242 } 243 response 244 }