taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

api.rs (7182B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2025 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use std::{
     18     sync::{
     19         Arc,
     20         atomic::{AtomicU32, Ordering},
     21     },
     22     time::Instant,
     23 };
     24 
     25 use axum::{
     26     extract::{Request, State},
     27     middleware::{self, Next},
     28     response::Response,
     29 };
     30 use revenue::Revenue;
     31 use taler_common::{error_code::ErrorCode, types::amount::Amount};
     32 use tokio::signal;
     33 use tracing::{Level, info};
     34 use wire::WireGateway;
     35 
     36 use crate::{
     37     Listener, Serve,
     38     auth::{AuthMethod, AuthMiddlewareState},
     39     error::{ApiResult, failure, failure_code},
     40 };
     41 
     42 pub mod revenue;
     43 pub mod wire;
     44 
     45 pub use axum::Router;
     46 
     47 pub trait TalerApi: Send + Sync + 'static {
     48     fn currency(&self) -> &str;
     49     fn implementation(&self) -> Option<&str>;
     50     fn check_currency(&self, amount: &Amount) -> ApiResult<()> {
     51         let currency = self.currency();
     52         if amount.currency.as_ref() != currency {
     53             Err(failure(
     54                 ErrorCode::GENERIC_CURRENCY_MISMATCH,
     55                 format!(
     56                     "wrong currency expected {} got {}",
     57                     currency, amount.currency
     58                 ),
     59             ))
     60         } else {
     61             Ok(())
     62         }
     63     }
     64 }
     65 
     66 pub trait RouterUtils {
     67     fn auth(self, auth: AuthMethod, realm: &str) -> Self;
     68 }
     69 
     70 impl<S: Send + Clone + Sync + 'static> RouterUtils for Router<S> {
     71     fn auth(self, auth: AuthMethod, realm: &str) -> Self {
     72         self.route_layer(middleware::from_fn_with_state(
     73             Arc::new(AuthMiddlewareState::new(auth, realm)),
     74             crate::auth::auth_middleware,
     75         ))
     76     }
     77 }
     78 
     79 pub trait TalerRouter {
     80     fn wire_gateway<T: WireGateway>(self, api: Arc<T>, auth: AuthMethod) -> Self;
     81     fn revenue<T: Revenue>(self, api: Arc<T>, auth: AuthMethod) -> Self;
     82     fn finalize(self) -> Self;
     83     fn serve(
     84         self,
     85         serve: Serve,
     86         lifetime: Option<u32>,
     87     ) -> impl std::future::Future<Output = std::io::Result<()>> + Send;
     88 }
     89 
     90 impl TalerRouter for Router {
     91     fn wire_gateway<T: WireGateway>(self, api: Arc<T>, auth: AuthMethod) -> Self {
     92         self.nest("/taler-wire-gateway", wire::router(api, auth))
     93     }
     94 
     95     fn revenue<T: Revenue>(self, api: Arc<T>, auth: AuthMethod) -> Self {
     96         self.nest("/taler-revenue", revenue::router(api, auth))
     97     }
     98 
     99     fn finalize(self) -> Router {
    100         self.method_not_allowed_fallback(|| async {
    101             failure_code(ErrorCode::GENERIC_METHOD_INVALID)
    102         })
    103         .fallback(|| async { failure_code(ErrorCode::GENERIC_ENDPOINT_UNKNOWN) })
    104         .layer(middleware::from_fn(logger_middleware))
    105     }
    106 
    107     async fn serve(mut self, serve: Serve, lifetime: Option<u32>) -> std::io::Result<()> {
    108         let listener = serve.resolve()?;
    109 
    110         let notify = Arc::new(tokio::sync::Notify::new());
    111         if let Some(lifetime) = lifetime {
    112             self = self.layer(middleware::from_fn_with_state(
    113                 Arc::new(LifetimeMiddlewareState {
    114                     notify: notify.clone(),
    115                     lifetime: AtomicU32::new(lifetime),
    116                 }),
    117                 lifetime_middleware,
    118             ))
    119         }
    120         let router = self.finalize();
    121         let signal = shutdown_signal(notify);
    122         match listener {
    123             Listener::Tcp(tcp_listener) => {
    124                 axum::serve(tcp_listener, router)
    125                     .with_graceful_shutdown(signal)
    126                     .await?;
    127             }
    128             Listener::Unix(unix_listener) => {
    129                 axum::serve(unix_listener, router)
    130                     .with_graceful_shutdown(signal)
    131                     .await?;
    132             }
    133         }
    134 
    135         info!(target: "api", "Server stopped");
    136         Ok(())
    137     }
    138 }
    139 
    140 struct LifetimeMiddlewareState {
    141     lifetime: AtomicU32,
    142     notify: Arc<tokio::sync::Notify>,
    143 }
    144 
    145 async fn lifetime_middleware(
    146     State(state): State<Arc<LifetimeMiddlewareState>>,
    147     request: Request,
    148     next: Next,
    149 ) -> Response {
    150     let mut current = state.lifetime.load(Ordering::Relaxed);
    151     while current != 0 {
    152         match state.lifetime.compare_exchange_weak(
    153             current,
    154             current - 1,
    155             Ordering::Relaxed,
    156             Ordering::Relaxed,
    157         ) {
    158             Ok(_) => break,
    159             Err(new) => current = new,
    160         }
    161     }
    162     if current == 0 {
    163         state.notify.notify_one();
    164     }
    165     next.run(request).await
    166 }
    167 
    168 /** Wait for manual shutdown or system signal shutdown */
    169 async fn shutdown_signal(manual_shutdown: Arc<tokio::sync::Notify>) {
    170     let ctrl_c = async {
    171         signal::ctrl_c()
    172             .await
    173             .expect("failed to install Ctrl+C handler");
    174     };
    175 
    176     #[cfg(unix)]
    177     let terminate = async {
    178         signal::unix::signal(signal::unix::SignalKind::terminate())
    179             .expect("failed to install signal handler")
    180             .recv()
    181             .await;
    182     };
    183 
    184     #[cfg(not(unix))]
    185     let terminate = std::future::pending::<()>();
    186 
    187     let manual = async { manual_shutdown.notified().await };
    188 
    189     tokio::select! {
    190         _ = ctrl_c => {},
    191         _ = terminate => {},
    192         _ = manual => {}
    193     }
    194 }
    195 
    196 #[macro_export]
    197 macro_rules! dyn_event {
    198     ($lvl:ident, $($arg:tt)+) => {
    199         match $lvl {
    200             ::tracing::Level::TRACE => ::tracing::trace!($($arg)+),
    201             ::tracing::Level::DEBUG => ::tracing::debug!($($arg)+),
    202             ::tracing::Level::INFO => ::tracing::info!($($arg)+),
    203             ::tracing::Level::WARN => ::tracing::warn!($($arg)+),
    204             ::tracing::Level::ERROR => ::tracing::error!($($arg)+),
    205         }
    206     };
    207 }
    208 
    209 /** Taler API logger */
    210 async fn logger_middleware(request: Request, next: Next) -> Response {
    211     let request_info = format!(
    212         "{} {}",
    213         request.method(),
    214         request
    215             .uri()
    216             .path_and_query()
    217             .map(|it| it.as_str())
    218             .unwrap_or_default()
    219     );
    220     let now = Instant::now();
    221     let response = next.run(request).await;
    222     let elapsed = now.elapsed();
    223     let status = response.status();
    224     let level = match status.as_u16() {
    225         400..500 => Level::WARN,
    226         500..600 => Level::ERROR,
    227         _ => Level::INFO,
    228     };
    229 
    230     if let Some(log) = response.extensions().get::<Box<str>>() {
    231         dyn_event!(level, target: "api",
    232             "{} {request_info} {}ms: {log}",
    233             response.status(),
    234             elapsed.as_millis()
    235         );
    236     } else {
    237         dyn_event!(level, target: "api",
    238             "{} {request_info} {}ms",
    239             response.status(),
    240             elapsed.as_millis()
    241         );
    242     }
    243     response
    244 }