taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

commit 933dabf6110208e342b74115862f5abed06b4fea
parent d87dfa3dac50ac6ccef95f10c261aee92bd6aff1
Author: Antoine A <>
Date:   Fri, 28 Mar 2025 11:20:43 +0100

common: improve route and middleware logic

Diffstat:
Mcommon/taler-api/src/api.rs | 98++++++++++++++++++++++++++++++++++++++++++++++++-------------------------------
Mcommon/taler-api/src/json.rs | 5++---
Mcommon/taler-api/tests/common/mod.rs | 6+++---
Mcommon/taler-test-utils/src/server.rs | 4++--
Mtaler-magnet-bank/src/main.rs | 4++--
Mtaler-magnet-bank/tests/api.rs | 3++-
6 files changed, 71 insertions(+), 49 deletions(-)

diff --git a/common/taler-api/src/api.rs b/common/taler-api/src/api.rs @@ -62,54 +62,38 @@ pub trait TalerApi: Send + Sync + 'static { } } -pub struct TalerApiBuilder { - router: Router, +pub trait TalerRouter { + fn auth(self, auth: AuthMethod) -> Self; + fn finalize(self) -> Self; + fn serve( + self, + serve: Serve, + lifetime: Option<u32>, + ) -> impl std::future::Future<Output = std::io::Result<()>> + Send; } -impl TalerApiBuilder { - pub fn new() -> Self { - Self { - router: Router::new(), - } - } - - pub fn wire_gateway<T: WireGateway>(mut self, api: Arc<T>, auth: AuthMethod) -> Self { - self.router = self.router.nest( - "/taler-wire-gateway", - wire::router(api).layer(middleware::from_fn_with_state( - Arc::new(auth), - crate::auth::auth_middleware, - )), - ); - self - } - - pub fn revenue<T: Revenue>(mut self, api: Arc<T>, auth: AuthMethod) -> Self { - self.router = self.router.nest( - "/taler-revenue", - revenue::router(api).layer(middleware::from_fn_with_state( - Arc::new(auth), - crate::auth::auth_middleware, - )), - ); - self +impl TalerRouter for Router { + fn auth(self, auth: AuthMethod) -> Self { + self.layer(middleware::from_fn_with_state( + Arc::new(auth), + crate::auth::auth_middleware, + )) } - pub fn finalize(self) -> Router { - self.router - .method_not_allowed_fallback(|| async { - failure_code(ErrorCode::GENERIC_METHOD_INVALID) - }) - .fallback(|| async { failure_code(ErrorCode::GENERIC_ENDPOINT_UNKNOWN) }) - .layer(middleware::from_fn(logger_middleware)) + fn finalize(self) -> Router { + self.method_not_allowed_fallback(|| async { + failure_code(ErrorCode::GENERIC_METHOD_INVALID) + }) + .fallback(|| async { failure_code(ErrorCode::GENERIC_ENDPOINT_UNKNOWN) }) + .layer(middleware::from_fn(logger_middleware)) } - pub async fn serve(mut self, serve: Serve, lifetime: Option<u32>) -> std::io::Result<()> { + async fn serve(mut self, serve: Serve, lifetime: Option<u32>) -> std::io::Result<()> { let listener = serve.resolve()?; let notify = Arc::new(tokio::sync::Notify::new()); if let Some(lifetime) = lifetime { - self.router = self.router.layer(middleware::from_fn_with_state( + self = self.layer(middleware::from_fn_with_state( Arc::new(LifetimeMiddlewareState { notify: notify.clone(), lifetime: AtomicU32::new(lifetime), @@ -137,6 +121,44 @@ impl TalerApiBuilder { } } +pub struct TalerApiBuilder { + router: Router, +} + +impl TalerApiBuilder { + pub fn new() -> Self { + Self { + router: Router::new(), + } + } + + pub fn wire_gateway<T: WireGateway>(mut self, api: Arc<T>, auth: AuthMethod) -> Self { + self.router = self.router.nest( + "/taler-wire-gateway", + wire::router(api).layer(middleware::from_fn_with_state( + Arc::new(auth), + crate::auth::auth_middleware, + )), + ); + self + } + + pub fn revenue<T: Revenue>(mut self, api: Arc<T>, auth: AuthMethod) -> Self { + self.router = self.router.nest( + "/taler-revenue", + revenue::router(api).layer(middleware::from_fn_with_state( + Arc::new(auth), + crate::auth::auth_middleware, + )), + ); + self + } + + pub fn build(self) -> Router { + self.router + } +} + struct LifetimeMiddlewareState { lifetime: AtomicU32, notify: Arc<tokio::sync::Notify>, diff --git a/common/taler-api/src/json.rs b/common/taler-api/src/json.rs @@ -41,7 +41,6 @@ where async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> { // Check content type - println!("{:?}", req.headers()); match req.headers().get(header::CONTENT_TYPE) { Some(header) => { if header != "application/json" { @@ -117,13 +116,13 @@ where let bytes = if compressed { let mut buf = vec![0; MAX_BODY_LENGTH]; - match libdeflater::Decompressor::new().deflate_decompress(&bytes, &mut buf) { + match libdeflater::Decompressor::new().zlib_decompress(&bytes, &mut buf) { Ok(it) => Bytes::copy_from_slice(&buf[..it]), Err(it) => match it { libdeflater::DecompressionError::BadData => { return Err(failure( ErrorCode::GENERIC_COMPRESSION_INVALID, - "Failed to decompress body: invalid gzip", + "Failed to decompress body: invalid compression", )); } libdeflater::DecompressionError::InsufficientSpace => { diff --git a/common/taler-api/tests/common/mod.rs b/common/taler-api/tests/common/mod.rs @@ -20,10 +20,10 @@ use axum::Router; use db::notification_listener; use sqlx::PgPool; use taler_api::{ - api::{TalerApi, TalerApiBuilder, revenue::Revenue, wire::WireGateway}, + api::{revenue::Revenue, wire::WireGateway, TalerApi, TalerApiBuilder, TalerRouter as _}, auth::AuthMethod, db::IncomingType, - error::{ApiResult, failure}, + error::{failure, ApiResult}, }; use taler_common::{ api_params::{History, Page}, @@ -203,5 +203,5 @@ pub async fn setup() -> (Router, PgPool) { let pool = db_test_setup("taler-api").await; let api = test_api(pool.clone(), "EUR".to_string()).await; - (api.finalize(), pool) + (api.build().finalize(), pool) } diff --git a/common/taler-test-utils/src/server.rs b/common/taler-test-utils/src/server.rs @@ -89,8 +89,8 @@ impl TestRequest { pub fn deflate(mut self) -> Self { let body = self.body.unwrap(); let mut compressor = libdeflater::Compressor::new(CompressionLvl::fastest()); - let mut compressed = vec![0; compressor.deflate_compress_bound(body.len())]; - let nb = compressor.deflate_compress(&body, &mut compressed).unwrap(); + let mut compressed = vec![0; compressor.zlib_compress_bound(body.len())]; + let nb = compressor.zlib_compress(&body, &mut compressed).unwrap(); compressed.truncate(nb); self.body = Some(compressed); self.headers.insert( diff --git a/taler-magnet-bank/src/main.rs b/taler-magnet-bank/src/main.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use clap::Parser; -use taler_api::api::TalerApiBuilder; +use taler_api::api::{TalerApiBuilder, TalerRouter as _}; use taler_common::{ CommonArgs, cli::ConfigCmd, @@ -119,7 +119,7 @@ async fn app(args: Args, cfg: Config) -> anyhow::Result<()> { if let Some(cfg) = cfg.revenue { builder = builder.revenue(api, cfg.auth); } - builder.serve(cfg.serve, None).await?; + builder.build().serve(cfg.serve, None).await?; } } Command::Worker { transient: _ } => { diff --git a/taler-magnet-bank/tests/api.rs b/taler-magnet-bank/tests/api.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use sqlx::PgPool; -use taler_api::{api::TalerApiBuilder, auth::AuthMethod, subject::OutgoingSubject}; +use taler_api::{api::{TalerApiBuilder, TalerRouter as _}, auth::AuthMethod, subject::OutgoingSubject}; use taler_common::{ api_common::ShortHashCode, api_wire::{OutgoingHistory, TransferState}, @@ -41,6 +41,7 @@ async fn setup() -> (Router, PgPool) { let server = TalerApiBuilder::new() .wire_gateway(api.clone(), AuthMethod::None) .revenue(api, AuthMethod::None) + .build() .finalize(); (server, pool)