commit 933dabf6110208e342b74115862f5abed06b4fea
parent d87dfa3dac50ac6ccef95f10c261aee92bd6aff1
Author: Antoine A <>
Date: Fri, 28 Mar 2025 11:20:43 +0100
common: improve route and middleware logic
Diffstat:
6 files changed, 71 insertions(+), 49 deletions(-)
diff --git a/common/taler-api/src/api.rs b/common/taler-api/src/api.rs
@@ -62,54 +62,38 @@ pub trait TalerApi: Send + Sync + 'static {
}
}
-pub struct TalerApiBuilder {
- router: Router,
+pub trait TalerRouter {
+ fn auth(self, auth: AuthMethod) -> Self;
+ fn finalize(self) -> Self;
+ fn serve(
+ self,
+ serve: Serve,
+ lifetime: Option<u32>,
+ ) -> impl std::future::Future<Output = std::io::Result<()>> + Send;
}
-impl TalerApiBuilder {
- pub fn new() -> Self {
- Self {
- router: Router::new(),
- }
- }
-
- pub fn wire_gateway<T: WireGateway>(mut self, api: Arc<T>, auth: AuthMethod) -> Self {
- self.router = self.router.nest(
- "/taler-wire-gateway",
- wire::router(api).layer(middleware::from_fn_with_state(
- Arc::new(auth),
- crate::auth::auth_middleware,
- )),
- );
- self
- }
-
- pub fn revenue<T: Revenue>(mut self, api: Arc<T>, auth: AuthMethod) -> Self {
- self.router = self.router.nest(
- "/taler-revenue",
- revenue::router(api).layer(middleware::from_fn_with_state(
- Arc::new(auth),
- crate::auth::auth_middleware,
- )),
- );
- self
+impl TalerRouter for Router {
+ fn auth(self, auth: AuthMethod) -> Self {
+ self.layer(middleware::from_fn_with_state(
+ Arc::new(auth),
+ crate::auth::auth_middleware,
+ ))
}
- pub fn finalize(self) -> Router {
- self.router
- .method_not_allowed_fallback(|| async {
- failure_code(ErrorCode::GENERIC_METHOD_INVALID)
- })
- .fallback(|| async { failure_code(ErrorCode::GENERIC_ENDPOINT_UNKNOWN) })
- .layer(middleware::from_fn(logger_middleware))
+ fn finalize(self) -> Router {
+ self.method_not_allowed_fallback(|| async {
+ failure_code(ErrorCode::GENERIC_METHOD_INVALID)
+ })
+ .fallback(|| async { failure_code(ErrorCode::GENERIC_ENDPOINT_UNKNOWN) })
+ .layer(middleware::from_fn(logger_middleware))
}
- pub async fn serve(mut self, serve: Serve, lifetime: Option<u32>) -> std::io::Result<()> {
+ async fn serve(mut self, serve: Serve, lifetime: Option<u32>) -> std::io::Result<()> {
let listener = serve.resolve()?;
let notify = Arc::new(tokio::sync::Notify::new());
if let Some(lifetime) = lifetime {
- self.router = self.router.layer(middleware::from_fn_with_state(
+ self = self.layer(middleware::from_fn_with_state(
Arc::new(LifetimeMiddlewareState {
notify: notify.clone(),
lifetime: AtomicU32::new(lifetime),
@@ -137,6 +121,44 @@ impl TalerApiBuilder {
}
}
+pub struct TalerApiBuilder {
+ router: Router,
+}
+
+impl TalerApiBuilder {
+ pub fn new() -> Self {
+ Self {
+ router: Router::new(),
+ }
+ }
+
+ pub fn wire_gateway<T: WireGateway>(mut self, api: Arc<T>, auth: AuthMethod) -> Self {
+ self.router = self.router.nest(
+ "/taler-wire-gateway",
+ wire::router(api).layer(middleware::from_fn_with_state(
+ Arc::new(auth),
+ crate::auth::auth_middleware,
+ )),
+ );
+ self
+ }
+
+ pub fn revenue<T: Revenue>(mut self, api: Arc<T>, auth: AuthMethod) -> Self {
+ self.router = self.router.nest(
+ "/taler-revenue",
+ revenue::router(api).layer(middleware::from_fn_with_state(
+ Arc::new(auth),
+ crate::auth::auth_middleware,
+ )),
+ );
+ self
+ }
+
+ pub fn build(self) -> Router {
+ self.router
+ }
+}
+
struct LifetimeMiddlewareState {
lifetime: AtomicU32,
notify: Arc<tokio::sync::Notify>,
diff --git a/common/taler-api/src/json.rs b/common/taler-api/src/json.rs
@@ -41,7 +41,6 @@ where
async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
// Check content type
- println!("{:?}", req.headers());
match req.headers().get(header::CONTENT_TYPE) {
Some(header) => {
if header != "application/json" {
@@ -117,13 +116,13 @@ where
let bytes = if compressed {
let mut buf = vec![0; MAX_BODY_LENGTH];
- match libdeflater::Decompressor::new().deflate_decompress(&bytes, &mut buf) {
+ match libdeflater::Decompressor::new().zlib_decompress(&bytes, &mut buf) {
Ok(it) => Bytes::copy_from_slice(&buf[..it]),
Err(it) => match it {
libdeflater::DecompressionError::BadData => {
return Err(failure(
ErrorCode::GENERIC_COMPRESSION_INVALID,
- "Failed to decompress body: invalid gzip",
+ "Failed to decompress body: invalid compression",
));
}
libdeflater::DecompressionError::InsufficientSpace => {
diff --git a/common/taler-api/tests/common/mod.rs b/common/taler-api/tests/common/mod.rs
@@ -20,10 +20,10 @@ use axum::Router;
use db::notification_listener;
use sqlx::PgPool;
use taler_api::{
- api::{TalerApi, TalerApiBuilder, revenue::Revenue, wire::WireGateway},
+ api::{revenue::Revenue, wire::WireGateway, TalerApi, TalerApiBuilder, TalerRouter as _},
auth::AuthMethod,
db::IncomingType,
- error::{ApiResult, failure},
+ error::{failure, ApiResult},
};
use taler_common::{
api_params::{History, Page},
@@ -203,5 +203,5 @@ pub async fn setup() -> (Router, PgPool) {
let pool = db_test_setup("taler-api").await;
let api = test_api(pool.clone(), "EUR".to_string()).await;
- (api.finalize(), pool)
+ (api.build().finalize(), pool)
}
diff --git a/common/taler-test-utils/src/server.rs b/common/taler-test-utils/src/server.rs
@@ -89,8 +89,8 @@ impl TestRequest {
pub fn deflate(mut self) -> Self {
let body = self.body.unwrap();
let mut compressor = libdeflater::Compressor::new(CompressionLvl::fastest());
- let mut compressed = vec![0; compressor.deflate_compress_bound(body.len())];
- let nb = compressor.deflate_compress(&body, &mut compressed).unwrap();
+ let mut compressed = vec![0; compressor.zlib_compress_bound(body.len())];
+ let nb = compressor.zlib_compress(&body, &mut compressed).unwrap();
compressed.truncate(nb);
self.body = Some(compressed);
self.headers.insert(
diff --git a/taler-magnet-bank/src/main.rs b/taler-magnet-bank/src/main.rs
@@ -17,7 +17,7 @@
use std::sync::Arc;
use clap::Parser;
-use taler_api::api::TalerApiBuilder;
+use taler_api::api::{TalerApiBuilder, TalerRouter as _};
use taler_common::{
CommonArgs,
cli::ConfigCmd,
@@ -119,7 +119,7 @@ async fn app(args: Args, cfg: Config) -> anyhow::Result<()> {
if let Some(cfg) = cfg.revenue {
builder = builder.revenue(api, cfg.auth);
}
- builder.serve(cfg.serve, None).await?;
+ builder.build().serve(cfg.serve, None).await?;
}
}
Command::Worker { transient: _ } => {
diff --git a/taler-magnet-bank/tests/api.rs b/taler-magnet-bank/tests/api.rs
@@ -17,7 +17,7 @@
use std::sync::Arc;
use sqlx::PgPool;
-use taler_api::{api::TalerApiBuilder, auth::AuthMethod, subject::OutgoingSubject};
+use taler_api::{api::{TalerApiBuilder, TalerRouter as _}, auth::AuthMethod, subject::OutgoingSubject};
use taler_common::{
api_common::ShortHashCode,
api_wire::{OutgoingHistory, TransferState},
@@ -41,6 +41,7 @@ async fn setup() -> (Router, PgPool) {
let server = TalerApiBuilder::new()
.wire_gateway(api.clone(), AuthMethod::None)
.revenue(api, AuthMethod::None)
+ .build()
.finalize();
(server, pool)