commit f6d45d6bce90d68f5917a2c374e9b2cc252f9009
parent b52599cc54f713657e2b839085c9e8b973816dc5
Author: Antoine A <>
Date: Fri, 14 Nov 2025 17:07:49 +0100
common: improve auth logic
Diffstat:
6 files changed, 165 insertions(+), 112 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
@@ -55,22 +55,22 @@ dependencies = [
[[package]]
name = "anstyle-query"
-version = "1.1.4"
+version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9e231f6134f61b71076a3eab506c379d4f36122f2af15a9ff04415ea4c3339e2"
+checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
dependencies = [
- "windows-sys 0.60.2",
+ "windows-sys 0.61.2",
]
[[package]]
name = "anstyle-wincon"
-version = "3.0.10"
+version = "3.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3e0633414522a32ffaac8ac6cc8f748e090c5717661fddeea04219e2344f5f2a"
+checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
dependencies = [
"anstyle",
"once_cell_polyfill",
- "windows-sys 0.60.2",
+ "windows-sys 0.61.2",
]
[[package]]
@@ -199,9 +199,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "bytes"
-version = "1.10.1"
+version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"
+checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3"
[[package]]
name = "cast"
@@ -211,9 +211,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cc"
-version = "1.2.45"
+version = "1.2.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "35900b6c8d709fb1d854671ae27aeaa9eec2f8b01b364e1619a40da3e6fe2afe"
+checksum = "b97463e1064cb1b1c1384ad0a0b9c8abd0988e2a91f52606c80ef14aadb63e36"
dependencies = [
"find-msvc-tools",
"shlex",
@@ -659,9 +659,9 @@ checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d"
[[package]]
name = "find-msvc-tools"
-version = "0.1.4"
+version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "52051878f80a721bb68ebfbc930e07b65ba72f2da88968ea5c06fd6ca3d3a127"
+checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844"
[[package]]
name = "fnv"
@@ -930,9 +930,9 @@ checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hyper"
-version = "1.8.0"
+version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1744436df46f0bde35af3eda22aeaba453aada65d8f1c171cd8a5f59030bd69f"
+checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11"
dependencies = [
"atomic-waker",
"bytes",
@@ -969,9 +969,9 @@ dependencies = [
[[package]]
name = "hyper-util"
-version = "0.1.17"
+version = "0.1.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8"
+checksum = "52e9a2a24dc5c6821e71a7030e1e14b7b632acac55c40e9d2e082c621261bb56"
dependencies = [
"base64",
"bytes",
diff --git a/common/taler-api/src/api.rs b/common/taler-api/src/api.rs
@@ -35,7 +35,7 @@ use wire::WireGateway;
use crate::{
Listener, Serve,
- auth::AuthMethod,
+ auth::{AuthMethod, AuthMiddlewareState},
error::{ApiResult, failure, failure_code},
};
@@ -63,8 +63,20 @@ pub trait TalerApi: Send + Sync + 'static {
}
}
+pub trait RouterUtils {
+ fn auth(self, auth: AuthMethod, realm: &str) -> Self;
+}
+
+impl<S: Send + Clone + Sync + 'static> RouterUtils for Router<S> {
+ fn auth(self, auth: AuthMethod, realm: &str) -> Self {
+ self.route_layer(middleware::from_fn_with_state(
+ Arc::new(AuthMiddlewareState::new(auth, realm)),
+ crate::auth::auth_middleware,
+ ))
+ }
+}
+
pub trait TalerRouter {
- fn auth(self, auth: AuthMethod) -> Self;
fn wire_gateway<T: WireGateway>(self, api: Arc<T>, auth: AuthMethod) -> Self;
fn revenue<T: Revenue>(self, api: Arc<T>, auth: AuthMethod) -> Self;
fn finalize(self) -> Self;
@@ -76,19 +88,12 @@ pub trait TalerRouter {
}
impl TalerRouter for Router {
- fn auth(self, auth: AuthMethod) -> Self {
- self.layer(middleware::from_fn_with_state(
- Arc::new(auth),
- crate::auth::auth_middleware,
- ))
- }
-
fn wire_gateway<T: WireGateway>(self, api: Arc<T>, auth: AuthMethod) -> Self {
- self.nest("/taler-wire-gateway", wire::router(api).auth(auth))
+ self.nest("/taler-wire-gateway", wire::router(api, auth))
}
fn revenue<T: Revenue>(self, api: Arc<T>, auth: AuthMethod) -> Self {
- self.nest("/taler-revenue", revenue::router(api).auth(auth))
+ self.nest("/taler-revenue", revenue::router(api, auth))
}
fn finalize(self) -> Router {
diff --git a/common/taler-api/src/api/revenue.rs b/common/taler-api/src/api/revenue.rs
@@ -29,6 +29,8 @@ use taler_common::{
};
use crate::{
+ api::RouterUtils as _,
+ auth::AuthMethod,
constants::{MAX_PAGE_SIZE, MAX_TIMEOUT_MS, REVENUE_API_VERSION},
error::ApiResult,
};
@@ -42,21 +44,9 @@ pub trait Revenue: TalerApi {
) -> impl std::future::Future<Output = ApiResult<RevenueIncomingHistory>> + Send;
}
-pub fn router<I: Revenue>(state: Arc<I>) -> Router {
+pub fn router<I: Revenue>(state: Arc<I>, auth: AuthMethod) -> Router {
Router::new()
.route(
- "/config",
- get(|State(state): State<Arc<I>>| async move {
- Json(RevenueConfig {
- name: "taler-revenue",
- version: REVENUE_API_VERSION,
- currency: state.currency(),
- implementation: state.implementation(),
- })
- .into_response()
- }),
- )
- .route(
"/history",
get(
|State(state): State<Arc<I>>, Query(params): Query<HistoryParams>| async move {
@@ -70,5 +60,18 @@ pub fn router<I: Revenue>(state: Arc<I>) -> Router {
},
),
)
+ .auth(auth, "taler-revenue")
+ .route(
+ "/config",
+ get(|State(state): State<Arc<I>>| async move {
+ Json(RevenueConfig {
+ name: "taler-revenue",
+ version: REVENUE_API_VERSION,
+ currency: state.currency(),
+ implementation: state.implementation(),
+ })
+ .into_response()
+ }),
+ )
.with_state(state)
}
diff --git a/common/taler-api/src/api/wire.rs b/common/taler-api/src/api/wire.rs
@@ -34,6 +34,8 @@ use taler_common::{
};
use crate::{
+ api::RouterUtils as _,
+ auth::AuthMethod,
constants::{MAX_PAGE_SIZE, MAX_TIMEOUT_MS, WIRE_GATEWAY_API_VERSION},
error::{ApiResult, failure, failure_code, failure_status},
json::Req,
@@ -88,22 +90,9 @@ pub trait WireGateway: TalerApi {
}
}
-pub fn router<I: WireGateway>(state: Arc<I>) -> Router {
+pub fn router<I: WireGateway>(state: Arc<I>, auth: AuthMethod) -> Router {
Router::new()
.route(
- "/config",
- get(|State(state): State<Arc<I>>| async move {
- Json(WireConfig {
- name: "taler-wire-gateway",
- version: WIRE_GATEWAY_API_VERSION,
- currency: state.currency(),
- implementation: state.implementation(),
- support_account_check: state.support_account_check(),
- })
- .into_response()
- }),
- )
- .route(
"/transfer",
post(
|State(state): State<Arc<I>>, Req(req): Req<TransferRequest>| async move {
@@ -197,5 +186,19 @@ pub fn router<I: WireGateway>(state: Arc<I>) -> Router {
},
),
)
+ .auth(auth, "taler-wire-gateway")
+ .route(
+ "/config",
+ get(|State(state): State<Arc<I>>| async move {
+ Json(WireConfig {
+ name: "taler-wire-gateway",
+ version: WIRE_GATEWAY_API_VERSION,
+ currency: state.currency(),
+ implementation: state.implementation(),
+ support_account_check: state.support_account_check(),
+ })
+ .into_response()
+ }),
+ )
.with_state(state)
}
diff --git a/common/taler-api/src/auth.rs b/common/taler-api/src/auth.rs
@@ -18,7 +18,10 @@ use std::sync::Arc;
use axum::{
extract::{Request, State},
- http::header::{self},
+ http::{
+ HeaderValue,
+ header::{self, WWW_AUTHENTICATE},
+ },
middleware::Next,
response::{IntoResponse, Response},
};
@@ -32,18 +35,41 @@ pub enum AuthMethod {
None,
}
+pub struct AuthMiddlewareState {
+ method: AuthMethod,
+ challenge: HeaderValue,
+}
+
+impl AuthMiddlewareState {
+ pub fn new(method: AuthMethod, realm: &str) -> Self {
+ let challenge = match method {
+ AuthMethod::Basic(_) => format!("Basic realm=\"{realm}\" charset=\"UTF-8\""),
+ AuthMethod::Bearer(_) => format!("Bearer realm=\"{realm}\""),
+ AuthMethod::None => format!(""),
+ };
+ Self {
+ challenge: HeaderValue::from_str(&challenge).unwrap(),
+ method,
+ }
+ }
+}
+
pub async fn auth_middleware(
- State(method): State<Arc<AuthMethod>>,
+ State(state): State<Arc<AuthMiddlewareState>>,
req: Request,
next: Next,
) -> Response {
- fn parse_auth<'a>(req: &'a Request, scheme: &str) -> Result<&'a str, crate::error::ApiError> {
+ fn parse_auth<'a>(
+ req: &'a Request,
+ scheme: &'static str,
+ challenge: &HeaderValue,
+ ) -> Result<&'a str, crate::error::ApiError> {
let Some(authorisation) = req.headers().get(header::AUTHORIZATION) else {
- // TODO WWWAuthenticate challenge
return Err(failure(
ErrorCode::GENERIC_UNAUTHORIZED,
"Authorization header not found",
- ));
+ )
+ .with_header(WWW_AUTHENTICATE, challenge.clone()));
};
let Some((hscheme, parameter)) = authorisation
@@ -67,8 +93,8 @@ pub async fn auth_middleware(
Ok(parameter)
}
- match method.as_ref() {
- AuthMethod::Basic(token) => match parse_auth(&req, "Basic") {
+ match &state.method {
+ AuthMethod::Basic(token) => match parse_auth(&req, "Basic", &state.challenge) {
Ok(htoken) => {
if htoken != token {
return failure_code(ErrorCode::GENERIC_TOKEN_UNKNOWN).into_response();
@@ -76,7 +102,7 @@ pub async fn auth_middleware(
}
Err(err) => return err.into_response(),
},
- AuthMethod::Bearer(token) => match parse_auth(&req, "Bearer") {
+ AuthMethod::Bearer(token) => match parse_auth(&req, "Bearer", &state.challenge) {
Ok(htoken) => {
if htoken != token {
return failure_code(ErrorCode::GENERIC_TOKEN_UNKNOWN).into_response();
diff --git a/common/taler-api/src/error.rs b/common/taler-api/src/error.rs
@@ -16,7 +16,7 @@
use axum::{
Json,
- http::StatusCode,
+ http::{HeaderMap, HeaderValue, StatusCode, header::IntoHeaderName},
response::{IntoResponse, Response},
};
use taler_common::{
@@ -31,6 +31,53 @@ pub struct ApiError {
log: Option<Box<str>>,
status: Option<StatusCode>,
path: Option<Box<str>>,
+ headers: HeaderMap,
+}
+
+impl ApiError {
+ pub fn new(code: ErrorCode) -> Self {
+ Self {
+ code,
+ hint: None,
+ log: None,
+ status: None,
+ path: None,
+ headers: HeaderMap::new(),
+ }
+ }
+
+ pub fn with_hint(self, hint: impl Into<Box<str>>) -> Self {
+ Self {
+ hint: Some(hint.into()),
+ ..self
+ }
+ }
+
+ pub fn with_log(self, log: impl Into<Box<str>>) -> Self {
+ Self {
+ log: Some(log.into()),
+ ..self
+ }
+ }
+
+ pub fn with_status(self, code: StatusCode) -> Self {
+ Self {
+ status: Some(code),
+ ..self
+ }
+ }
+
+ pub fn with_path(self, path: impl Into<Box<str>>) -> Self {
+ Self {
+ path: Some(path.into()),
+ ..self
+ }
+ }
+
+ pub fn with_header(mut self, key: impl IntoHeaderName, value: HeaderValue) -> Self {
+ self.headers.append(key, value);
+ self
+ }
}
impl From<sqlx::Error> for ApiError {
@@ -54,43 +101,31 @@ impl From<sqlx::Error> for ApiError {
status: Some(status),
log: Some(format!("db: {value}").into_boxed_str()),
path: None,
+ headers: HeaderMap::new(),
}
}
}
impl From<PaytoErr> for ApiError {
fn from(value: PaytoErr) -> Self {
- Self {
- code: ErrorCode::GENERIC_PAYTO_URI_MALFORMED,
- hint: Some(value.to_string().into_boxed_str()),
- log: None,
- status: None,
- path: None,
- }
+ ApiError::new(ErrorCode::GENERIC_PAYTO_URI_MALFORMED).with_hint(value.to_string())
}
}
impl From<ParamsErr> for ApiError {
fn from(value: ParamsErr) -> Self {
- Self {
- code: ErrorCode::GENERIC_PARAMETER_MALFORMED,
- hint: Some(value.to_string().into_boxed_str()),
- log: None,
- status: None,
- path: Some(value.param.to_owned().into_boxed_str()),
- }
+ ApiError::new(ErrorCode::GENERIC_PARAMETER_MALFORMED)
+ .with_hint(value.to_string())
+ .with_path(value.param)
}
}
impl From<serde_path_to_error::Error<serde_json::Error>> for ApiError {
fn from(value: serde_path_to_error::Error<serde_json::Error>) -> Self {
- Self {
- code: ErrorCode::GENERIC_JSON_INVALID,
- hint: Some(value.inner().to_string().into_boxed_str()),
- log: Some(value.to_string().into_boxed_str()),
- status: None,
- path: Some(value.path().to_string().into_boxed_str()),
- }
+ ApiError::new(ErrorCode::GENERIC_JSON_INVALID)
+ .with_hint(value.inner().to_string())
+ .with_path(value.path().to_string())
+ .with_log(value.to_string())
}
}
@@ -120,6 +155,9 @@ impl IntoResponse for ApiError {
}),
)
.into_response();
+ for (k, v) in self.headers {
+ resp.headers_mut().append(k.unwrap(), v);
+ }
if let Some(log) = log {
resp.extensions_mut().insert(log);
};
@@ -129,41 +167,19 @@ impl IntoResponse for ApiError {
}
pub fn failure_code(code: ErrorCode) -> ApiError {
- ApiError {
- code,
- hint: None,
- log: None,
- status: None,
- path: None,
- }
+ ApiError::new(code)
}
pub fn failure(code: ErrorCode, hint: impl Into<Box<str>>) -> ApiError {
- ApiError {
- code,
- hint: Some(hint.into()),
- log: None,
- status: None,
- path: None,
- }
+ ApiError::new(code).with_hint(hint)
}
pub fn failure_status(code: ErrorCode, hint: impl Into<Box<str>>, status: StatusCode) -> ApiError {
- ApiError {
- code,
- hint: Some(hint.into()),
- log: None,
- status: Some(status),
- path: None,
- }
+ ApiError::new(code).with_hint(hint).with_status(status)
}
pub fn not_implemented(hint: impl Into<Box<str>>) -> ApiError {
- ApiError {
- code: ErrorCode::END,
- hint: Some(hint.into()),
- log: None,
- status: Some(StatusCode::NOT_IMPLEMENTED),
- path: None,
- }
+ ApiError::new(ErrorCode::END)
+ .with_hint(hint)
+ .with_status(StatusCode::NOT_IMPLEMENTED)
}