kych

OAuth 2.0 API for Swiyu to enable Taler integration of Swiyu for KYC (experimental)
Log | Files | Refs

commit a5436336f53f8c1dc8951e0269e926571374bb51
parent ab7dc9432feec1f6a6303ef1c98d3d2a4b60857a
Author: Henrique Chan Carvalho Machado <henriqueccmachado@tecnico.ulisboa.pt>
Date:   Tue, 25 Nov 2025 22:21:19 +0100

oauth2_gateway: add unix socket support, add /token code validation

Diffstat:
Moauth2_gateway/Cargo.toml | 7++++---
Moauth2_gateway/src/config.rs | 179++++++++++++++++++-------------------------------------------------------------
Moauth2_gateway/src/db/authorization_codes.rs | 14++++----------
Moauth2_gateway/src/db/clients.rs | 21+++++++++++++++------
Moauth2_gateway/src/db/notification_webhooks.rs | 1-
Moauth2_gateway/src/handlers.rs | 124++++++++++++++++++++++++++++++++++++++++++++++---------------------------------
Moauth2_gateway/src/main.rs | 26+++++++++++++++++++++-----
Moauth2_gateway/src/models.rs | 5++---
8 files changed, 160 insertions(+), 217 deletions(-)

diff --git a/oauth2_gateway/Cargo.toml b/oauth2_gateway/Cargo.toml @@ -3,7 +3,7 @@ name = "oauth2-gateway" version = "0.0.1" edition = "2024" -[lib] # For tests +[lib] name = "oauth2_gateway" path = "src/lib.rs" @@ -54,6 +54,7 @@ dotenvy = "0.15" # Cryptography rand = "0.8.5" +bcrypt = "0.15" base64 = "0.22.1" # Database @@ -62,4 +63,4 @@ sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "uuid", "chro [dev-dependencies] tempfile = "3.8" wiremock = "0.6" -serial_test = "3.2" -\ No newline at end of file +serial_test = "3.2.0" +\ No newline at end of file diff --git a/oauth2_gateway/src/config.rs b/oauth2_gateway/src/config.rs @@ -11,8 +11,34 @@ pub struct Config { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ServerConfig { - pub host: String, - pub port: u16, + pub host: Option<String>, + pub port: Option<u16>, + pub socket_path: Option<String>, +} + +impl ServerConfig { + pub fn validate(&self) -> Result<()> { + let has_tcp = self.host.is_some() || self.port.is_some(); + let has_unix = self.socket_path.is_some(); + + if has_tcp && has_unix { + anyhow::bail!("Cannot specify both TCP (host/port) and Unix socket (socket_path)"); + } + + if !has_tcp && !has_unix { + anyhow::bail!("Must specify either TCP (host/port) or Unix socket (socket_path)"); + } + + if has_tcp && (self.host.is_none() || self.port.is_none()) { + anyhow::bail!("Host and port must be specified for TCP"); + } + + Ok(()) + } + + pub fn is_unix_socket(&self) -> bool { + self.socket_path.is_some() + } } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -29,18 +55,22 @@ impl Config { .section(Some("server")) .context("Missing [server] section")?; + let host = server_section.get("host").map(|s| s.to_string()); + let port = server_section + .get("port") + .map(|s| s.parse::<u16>()) + .transpose() + .context("Invalid port")?; + let socket_path = server_section.get("socket_path").map(|s| s.to_string()); + let server = ServerConfig { - host: server_section - .get("host") - .unwrap_or("127.0.0.1") - .to_string(), - port: server_section - .get("port") - .unwrap_or("9090") - .parse() - .context("Invalid port")?, + host, + port, + socket_path, }; + server.validate()?; + let database_section = ini .section(Some("database")) .context("Missing [database] section")?; @@ -59,130 +89,3 @@ impl Config { } } -#[cfg(test)] -mod tests { - use super::*; - use std::io::Write; - use tempfile::NamedTempFile; - - #[test] - fn test_config_load_valid() { - let mut temp_file = NamedTempFile::new().unwrap(); - writeln!( - temp_file, - r#" -[server] -host = 0.0.0.0 -port = 3000 - -[database] -url = postgresql://localhost/oauth2gw_test -"# - ) - .unwrap(); - - let config = Config::from_file(temp_file.path()).unwrap(); - - assert_eq!(config.server.host, "0.0.0.0"); - assert_eq!(config.server.port, 3000); - assert_eq!(config.database.url, "postgresql://localhost/oauth2gw_test"); - } - - #[test] - fn test_config_defaults() { - let mut temp_file = NamedTempFile::new().unwrap(); - writeln!( - temp_file, - r#" -[server] -# host and port not specified, should use defaults - -[database] -url = postgresql://localhost/oauth2gw -"# - ) - .unwrap(); - - let config = Config::from_file(temp_file.path()).unwrap(); - - assert_eq!(config.server.host, "127.0.0.1"); - assert_eq!(config.server.port, 9090); - } - - #[test] - fn test_config_missing_server_section() { - let mut temp_file = NamedTempFile::new().unwrap(); - writeln!( - temp_file, - r#" -[database] -url = postgresql://localhost/oauth2gw -"# - ) - .unwrap(); - - let result = Config::from_file(temp_file.path()); - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("server")); - } - - #[test] - fn test_config_missing_database_section() { - let mut temp_file = NamedTempFile::new().unwrap(); - writeln!( - temp_file, - r#" -[server] -host = 127.0.0.1 -port = 9090 -"# - ) - .unwrap(); - - let result = Config::from_file(temp_file.path()); - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("database")); - } - - #[test] - fn test_config_missing_database_url() { - let mut temp_file = NamedTempFile::new().unwrap(); - writeln!( - temp_file, - r#" -[server] -host = 127.0.0.1 -port = 9090 - -[database] -# url is missing -"# - ) - .unwrap(); - - let result = Config::from_file(temp_file.path()); - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("database.url")); - } - - #[test] - fn test_config_invalid_port() { - let mut temp_file = NamedTempFile::new().unwrap(); - writeln!( - temp_file, - r#" -[server] -host = 127.0.0.1 -port = not_a_number - -[database] -url = postgresql://localhost/oauth2gw -"# - ) - .unwrap(); - - let result = Config::from_file(temp_file.path()); - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("port")); - } -} diff --git a/oauth2_gateway/src/db/authorization_codes.rs b/oauth2_gateway/src/db/authorization_codes.rs @@ -7,7 +7,6 @@ use chrono::{DateTime, Utc}; use super::sessions::SessionStatus; -/// Authorization code record #[derive(Debug, Clone, sqlx::FromRow)] pub struct AuthorizationCode { pub id: Uuid, @@ -19,14 +18,13 @@ pub struct AuthorizationCode { pub created_at: DateTime<Utc>, } -/// Result of code exchange query - includes session and existing token data for idempotency #[derive(Debug, Clone)] pub struct CodeExchangeData { + pub client_id: Uuid, pub code_id: Uuid, pub was_already_used: bool, pub session_id: Uuid, pub session_status: SessionStatus, - /// Existing token if one was already created (for idempotent response) pub existing_token: Option<String>, pub existing_token_expires_at: Option<DateTime<Utc>>, } @@ -56,13 +54,7 @@ pub async fn create_authorization_code( Ok(auth_code) } -/// Atomically mark code as used and fetch session + existing token data -/// -/// This is the idempotent query for /token endpoint: -/// - Capture OLD used value before update -/// - JOINs with session to get status -/// - LEFT JOINs with access_tokens to get existing token (for idempotent response) -/// - Returns None if code doesn't exist or is expired +/// Mark code as used and fetch session + existing token data /// /// Used by the /token endpoint pub async fn get_code_for_token_exchange( @@ -90,6 +82,7 @@ pub async fn get_code_for_token_exchange( uc.id AS code_id, cd.was_already_used, uc.session_id, + vs.client_id, vs.status AS session_status, at.token AS existing_token, at.expires_at AS existing_token_expires_at @@ -107,6 +100,7 @@ pub async fn get_code_for_token_exchange( Ok(result.map(|row: sqlx::postgres::PgRow| { use sqlx::Row; CodeExchangeData { + client_id: row.get("code_id"), code_id: row.get("code_id"), was_already_used: row.get("was_already_used"), session_id: row.get("session_id"), diff --git a/oauth2_gateway/src/db/clients.rs b/oauth2_gateway/src/db/clients.rs @@ -30,6 +30,8 @@ pub async fn register_client( let api_path = verifier_management_api_path .unwrap_or("/management/api/verifications"); + let secret_hash = bcrypt::hash(client_secret, bcrypt::DEFAULT_COST)?; + let client = sqlx::query_as::<_, Client>( r#" INSERT INTO oauth2gw.clients @@ -40,7 +42,7 @@ pub async fn register_client( "# ) .bind(client_id) - .bind(client_secret) + .bind(secret_hash) .bind(webhook_url) .bind(verifier_url) .bind(api_path) @@ -103,15 +105,23 @@ pub async fn authenticate_client( SELECT id, client_id, secret_hash, webhook_url, verifier_url, verifier_management_api_path, created_at, updated_at FROM oauth2gw.clients - WHERE client_id = $1 AND secret_hash = $2 + WHERE client_id = $1 "# ) .bind(client_id) - .bind(client_secret) .fetch_optional(pool) .await?; - Ok(client) + match client { + Some(c) => { + if bcrypt::verify(client_secret, &c.secret_hash)? { + Ok(Some(c)) + } else { + Ok(None) + } + } + None => Ok(None) + } } /// Update client configuration @@ -188,4 +198,4 @@ pub async fn list_clients(pool: &PgPool) -> Result<Vec<Client>> { .await?; Ok(clients) -} -\ No newline at end of file +} diff --git a/oauth2_gateway/src/db/notification_webhooks.rs b/oauth2_gateway/src/db/notification_webhooks.rs @@ -21,7 +21,6 @@ pub struct PendingWebhook { /// Fetch pending webhooks ready to be sent /// -/// JOINs with authorization_codes to get the code value. /// Only returns webhooks where next_attempt <= current epoch time. /// /// Used by the background worker diff --git a/oauth2_gateway/src/handlers.rs b/oauth2_gateway/src/handlers.rs @@ -28,10 +28,11 @@ pub async fn setup( Path(client_id): Path<String>, Json(request): Json<SetupRequest>, ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> { - tracing::info!("Setup request for client: {}, scope: {}", - client_id, request.scope); + + tracing::info!("Setup request for client: {}, scope: {}", client_id, request.scope); let nonce = crypto::generate_nonce(); + tracing::debug!("Generated nonce: {}", nonce); let session = crate::db::sessions::create_session( @@ -40,25 +41,22 @@ pub async fn setup( &nonce, &request.scope, 15, // 15 minutes expiration - ) - .await + ).await .map_err(|e| { tracing::error!("Failed to create session: {}", e); - (StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse::new("internal_error"))) + (StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse::new("internal_error"))) })?; + let session = match session { Some(s) => s, None => { tracing::warn!("Client not found: {}", client_id); - return Err((StatusCode::NOT_FOUND, - Json(ErrorResponse::new("client_not_found")))) + return Err((StatusCode::NOT_FOUND, Json(ErrorResponse::new("client_not_found")))) } }; - tracing::info!("Created session {} for client {} with nonce {}", - session.id, client_id, nonce); + tracing::info!("Created session {} for client {} with nonce {}", session.id, client_id, nonce); Ok((StatusCode::OK, Json(SetupResponse { nonce }))) } @@ -68,8 +66,9 @@ pub async fn authorize( State(state): State<AppState>, Query(params): Query<AuthorizeQuery>, ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> { - tracing::info!("Authorize request for client: {}, nonce: {}", - params.client_id, params.nonce); + + tracing::info!("Authorize request for client: {}, nonce: {}",params.client_id, params.nonce); + // Validate response_type if params.response_type != "code" { @@ -77,19 +76,19 @@ pub async fn authorize( Json(ErrorResponse::new("invalid_request")))); } + // Fetch session and client data (idempotent) let session_data = crate::db::sessions::get_session_for_authorize( &state.pool, &params.nonce, &params.client_id, - ) - .await + ).await .map_err(|e| { tracing::error!("DB error in authorize: {}", e); - (StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse::new("internal_error"))) + (StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse::new("internal_error"))) })?; + let data = match session_data { Some(d) => d, None => { @@ -99,27 +98,30 @@ pub async fn authorize( } }; + // Backend validation if data.expires_at < Utc::now() { tracing::warn!("Session expired: {}", data.session_id); - return Err((StatusCode::GONE, - Json(ErrorResponse::new("session_expired")))); + return Err((StatusCode::GONE, Json(ErrorResponse::new("session_expired")))); } + // Check status for idempotency match data.status { SessionStatus::Authorized => { // Already authorized - return cached response - tracing::info!("Session {} already authorized, returning cached response", - data.session_id); + tracing::info!("Session {} already authorized, returning cached response", data.session_id); + let verification_id = data.request_id .and_then(|id| uuid::Uuid::parse_str(&id).ok()) .unwrap_or(uuid::Uuid::nil()); + return Ok((StatusCode::OK, Json(AuthorizeResponse { verification_id, verification_url: data.verification_url.unwrap_or_default(), }))); } + SessionStatus::Pending => { // Proceed with authorization } @@ -135,9 +137,7 @@ pub async fn authorize( let presentation_definition = build_presentation_definition(&data.scope); // Call Swiyu Verifier - let verifier_url = format!("{}{}", - data.verifier_url, - data.verifier_management_api_path); + let verifier_url = format!("{}{}", data.verifier_url, data.verifier_management_api_path); let verifier_request = SwiyuCreateVerificationRequest { accepted_issuer_dids: default_accepted_issuer_dids(), @@ -150,11 +150,11 @@ pub async fn authorize( dcql_query: None, }; - // convert verification request to json for debug view - tracing::debug!("Swiyu verifier request: {}", serde_json::to_string_pretty(&verifier_request).unwrap()); + tracing::debug!("Swiyu verifier request: {}", serde_json::to_string_pretty(&verifier_request).unwrap()); tracing::debug!("Calling Swiyu verifier at: {}", verifier_url); + let verifier_response = state.http_client .post(&verifier_url) .json(&verifier_request) @@ -162,27 +162,27 @@ pub async fn authorize( .await .map_err(|e| { tracing::error!("Failed to call Swiyu verifier: {}", e); - (StatusCode::BAD_GATEWAY, - Json(ErrorResponse::new("verifier_unavailable"))) + (StatusCode::BAD_GATEWAY, Json(ErrorResponse::new("verifier_unavailable"))) })?; + if !verifier_response.status().is_success() { let status = verifier_response.status(); let body = verifier_response.text().await.unwrap_or_default(); tracing::error!("Swiyu verifier returned error {}: {}", status, body); - return Err((StatusCode::BAD_GATEWAY, - Json(ErrorResponse::new("verifier_error")))); + return Err((StatusCode::BAD_GATEWAY, Json(ErrorResponse::new("verifier_error")))); } + let swiyu_response: SwiyuManagementResponse = verifier_response .json() .await .map_err(|e| { tracing::error!("Failed to parse Swiyu response: {}", e); - (StatusCode::BAD_GATEWAY, - Json(ErrorResponse::new("verifier_invalid_response"))) + (StatusCode::BAD_GATEWAY, Json(ErrorResponse::new("verifier_invalid_response"))) })?; + // Update session with verifier data let result = crate::db::sessions::update_session_authorized( &state.pool, @@ -190,14 +190,13 @@ pub async fn authorize( &swiyu_response.verification_url, &swiyu_response.id.to_string(), swiyu_response.request_nonce.as_deref(), - ) - .await + ).await .map_err(|e| { tracing::error!("Failed to update session: {}", e); - (StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse::new("internal_error"))) + (StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse::new("internal_error"))) })?; + tracing::info!("Session {} authorized, verification_id: {}", data.session_id, swiyu_response.id); @@ -216,8 +215,7 @@ fn build_presentation_definition(scope: &str) -> PresentationDefinition { let attributes: Vec<&str> = scope.split_whitespace().collect(); - tracing::debug!("Building presentation definition for attributes: {:?}", - attributes); + tracing::debug!("Building presentation definition for attributes: {:?}", attributes); // First field: $.vct with filter for credential type let vct_field = Field { @@ -231,6 +229,7 @@ fn build_presentation_definition(scope: &str) -> PresentationDefinition { }), }; + // Attribute fields from scope let mut fields: Vec<Field> = vec![vct_field]; for attr in &attributes { @@ -243,6 +242,7 @@ fn build_presentation_definition(scope: &str) -> PresentationDefinition { }); } + let mut format = HashMap::new(); format.insert( "vc+sd-jwt".to_string(), @@ -274,6 +274,7 @@ pub async fn token( State(state): State<AppState>, Json(request): Json<TokenRequest>, ) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> { + tracing::info!("Token request for code: {}", request.code); // Validate grant_type @@ -284,27 +285,52 @@ pub async fn token( )); } + // Authenticate client + let client = crate::db::clients::authenticate_client( + &state.pool, + &request.client_id, + &request.client_secret, + ).await + .map_err(|e| { + tracing::error!("DB error during client authentication: {}", e); + (StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse::new("internal_error"))) + })?; + + let client = match client { + Some(c) => c, + None => { + tracing::warn!("Client authentication failed for {}", request.client_id); + return Err((StatusCode::UNAUTHORIZED, Json(ErrorResponse::new("invalid_client")))); + } + }; + + // Fetch code (idempotent) let code_data = crate::db::authorization_codes::get_code_for_token_exchange( &state.pool, &request.code, - ) - .await + ).await .map_err(|e| { tracing::error!("DB error in token exchange: {}", e); - (StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse::new("internal_error"))) + (StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse::new("internal_error"))) })?; let data = match code_data { Some(d) => d, None => { tracing::warn!("Authorization code not found or expired: {}", request.code); - return Err((StatusCode::BAD_REQUEST, - Json(ErrorResponse::new("invalid_grant")))); + return Err((StatusCode::BAD_REQUEST, Json(ErrorResponse::new("invalid_grant")))); } }; + // Verify the authorization code belongs to the client + if data.client_id != client.id { + tracing::warn!("Authorization code {} does not belong to the client {}", + request.code, request.client_id); + + return Err((StatusCode::BAD_REQUEST, Json(ErrorResponse::new("invalid_grant")))); + } + // Check for existing token if let Some(existing_token) = data.existing_token { tracing::info!("Token already exists for session {}, returning cached response", @@ -319,16 +345,14 @@ pub async fn token( // Check if code was already used if data.was_already_used { tracing::warn!("Authorization code {} was already used", request.code); - return Err((StatusCode::BAD_REQUEST, - Json(ErrorResponse::new("invalid_grant")))); + return Err((StatusCode::BAD_REQUEST, Json(ErrorResponse::new("invalid_grant")))); } // Validate session status if data.session_status != SessionStatus::Verified { tracing::warn!("Session {} not in verified status: {:?}", data.session_id, data.session_status); - return Err((StatusCode::BAD_REQUEST, - Json(ErrorResponse::new("invalid_grant")))); + return Err((StatusCode::BAD_REQUEST, Json(ErrorResponse::new("invalid_grant")))); } // Generate new token and complete session @@ -338,12 +362,10 @@ pub async fn token( data.session_id, &access_token, 3600, // 1 hour - ) - .await + ).await .map_err(|e| { tracing::error!("Failed to create token: {}", e); - (StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse::new("internal_error"))) + (StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse::new("internal_error"))) })?; tracing::info!("Token created for session {}", data.session_id); diff --git a/oauth2_gateway/src/main.rs b/oauth2_gateway/src/main.rs @@ -49,13 +49,29 @@ async fn main() -> Result<()> { .layer(TraceLayer::new_for_http()) .with_state(state); - let addr = format!("{}:{}", config.server.host, config.server.port); - let listener = tokio::net::TcpListener::bind(&addr).await?; + if config.server.is_unix_socket() { + let socket_path = config.server.socket_path.as_ref().unwrap(); - tracing::info!("Server listening on {}", addr); - tracing::info!("Health check available at: http://{}/health", addr); + if std::path::Path::new(socket_path).exists() { + tracing::warn!("Removing existing socket file: {}", socket_path); + std::fs::remove_file(socket_path)?; + } - axum::serve(listener, app).await?; + let listener = tokio::net::UnixListener::bind(socket_path)?; + tracing::info!("Server listening on Unix socket: {}", socket_path); + + axum::serve(listener, app).await?; + } else { + let host = config.server.host.as_ref().unwrap(); + let port = config.server.port.unwrap(); + let addr = format!("{}:{}", host, port); + + let listener = tokio::net::TcpListener::bind(&addr).await?; + tracing::info!("Server listening on {}", addr); + tracing::info!("Health check available at: http://{}/health", addr); + + axum::serve(listener, app).await?; + } Ok(()) } diff --git a/oauth2_gateway/src/models.rs b/oauth2_gateway/src/models.rs @@ -2,7 +2,6 @@ use serde::{Deserialize, Serialize}; use uuid::Uuid; use std::collections::HashMap; -// Setup endpoint #[derive(Debug, Deserialize, Serialize)] pub struct SetupRequest { pub scope: String, @@ -20,7 +19,6 @@ pub struct AuthorizeQuery { pub nonce: String, } -// Authorize endpoint #[derive(Debug, Deserialize, Serialize)] pub struct AuthorizeResponse { #[serde(rename = "verificationId")] @@ -35,6 +33,8 @@ pub struct AuthorizeResponse { pub struct TokenRequest { pub grant_type: String, pub code: String, + pub client_id: String, + pub client_secret: String, } #[derive(Debug, Deserialize, Serialize)] @@ -95,7 +95,6 @@ pub fn default_accepted_issuer_dids() -> Vec<String> { pub struct SwiyuCreateVerificationRequest { #[serde(default = "default_accepted_issuer_dids")] pub accepted_issuer_dids: Vec<String>, - #[serde(skip_serializing_if = "Option::is_none")] pub trust_anchors: Option<Vec<TrustAnchor>>,