commit a5436336f53f8c1dc8951e0269e926571374bb51
parent ab7dc9432feec1f6a6303ef1c98d3d2a4b60857a
Author: Henrique Chan Carvalho Machado <henriqueccmachado@tecnico.ulisboa.pt>
Date: Tue, 25 Nov 2025 22:21:19 +0100
oauth2_gateway: add unix socket support, add /token code validation
Diffstat:
8 files changed, 160 insertions(+), 217 deletions(-)
diff --git a/oauth2_gateway/Cargo.toml b/oauth2_gateway/Cargo.toml
@@ -3,7 +3,7 @@ name = "oauth2-gateway"
version = "0.0.1"
edition = "2024"
-[lib] # For tests
+[lib]
name = "oauth2_gateway"
path = "src/lib.rs"
@@ -54,6 +54,7 @@ dotenvy = "0.15"
# Cryptography
rand = "0.8.5"
+bcrypt = "0.15"
base64 = "0.22.1"
# Database
@@ -62,4 +63,4 @@ sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "uuid", "chro
[dev-dependencies]
tempfile = "3.8"
wiremock = "0.6"
-serial_test = "3.2"
-\ No newline at end of file
+serial_test = "3.2.0"
+\ No newline at end of file
diff --git a/oauth2_gateway/src/config.rs b/oauth2_gateway/src/config.rs
@@ -11,8 +11,34 @@ pub struct Config {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
- pub host: String,
- pub port: u16,
+ pub host: Option<String>,
+ pub port: Option<u16>,
+ pub socket_path: Option<String>,
+}
+
+impl ServerConfig {
+ pub fn validate(&self) -> Result<()> {
+ let has_tcp = self.host.is_some() || self.port.is_some();
+ let has_unix = self.socket_path.is_some();
+
+ if has_tcp && has_unix {
+ anyhow::bail!("Cannot specify both TCP (host/port) and Unix socket (socket_path)");
+ }
+
+ if !has_tcp && !has_unix {
+ anyhow::bail!("Must specify either TCP (host/port) or Unix socket (socket_path)");
+ }
+
+ if has_tcp && (self.host.is_none() || self.port.is_none()) {
+ anyhow::bail!("Host and port must be specified for TCP");
+ }
+
+ Ok(())
+ }
+
+ pub fn is_unix_socket(&self) -> bool {
+ self.socket_path.is_some()
+ }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -29,18 +55,22 @@ impl Config {
.section(Some("server"))
.context("Missing [server] section")?;
+ let host = server_section.get("host").map(|s| s.to_string());
+ let port = server_section
+ .get("port")
+ .map(|s| s.parse::<u16>())
+ .transpose()
+ .context("Invalid port")?;
+ let socket_path = server_section.get("socket_path").map(|s| s.to_string());
+
let server = ServerConfig {
- host: server_section
- .get("host")
- .unwrap_or("127.0.0.1")
- .to_string(),
- port: server_section
- .get("port")
- .unwrap_or("9090")
- .parse()
- .context("Invalid port")?,
+ host,
+ port,
+ socket_path,
};
+ server.validate()?;
+
let database_section = ini
.section(Some("database"))
.context("Missing [database] section")?;
@@ -59,130 +89,3 @@ impl Config {
}
}
-#[cfg(test)]
-mod tests {
- use super::*;
- use std::io::Write;
- use tempfile::NamedTempFile;
-
- #[test]
- fn test_config_load_valid() {
- let mut temp_file = NamedTempFile::new().unwrap();
- writeln!(
- temp_file,
- r#"
-[server]
-host = 0.0.0.0
-port = 3000
-
-[database]
-url = postgresql://localhost/oauth2gw_test
-"#
- )
- .unwrap();
-
- let config = Config::from_file(temp_file.path()).unwrap();
-
- assert_eq!(config.server.host, "0.0.0.0");
- assert_eq!(config.server.port, 3000);
- assert_eq!(config.database.url, "postgresql://localhost/oauth2gw_test");
- }
-
- #[test]
- fn test_config_defaults() {
- let mut temp_file = NamedTempFile::new().unwrap();
- writeln!(
- temp_file,
- r#"
-[server]
-# host and port not specified, should use defaults
-
-[database]
-url = postgresql://localhost/oauth2gw
-"#
- )
- .unwrap();
-
- let config = Config::from_file(temp_file.path()).unwrap();
-
- assert_eq!(config.server.host, "127.0.0.1");
- assert_eq!(config.server.port, 9090);
- }
-
- #[test]
- fn test_config_missing_server_section() {
- let mut temp_file = NamedTempFile::new().unwrap();
- writeln!(
- temp_file,
- r#"
-[database]
-url = postgresql://localhost/oauth2gw
-"#
- )
- .unwrap();
-
- let result = Config::from_file(temp_file.path());
- assert!(result.is_err());
- assert!(result.unwrap_err().to_string().contains("server"));
- }
-
- #[test]
- fn test_config_missing_database_section() {
- let mut temp_file = NamedTempFile::new().unwrap();
- writeln!(
- temp_file,
- r#"
-[server]
-host = 127.0.0.1
-port = 9090
-"#
- )
- .unwrap();
-
- let result = Config::from_file(temp_file.path());
- assert!(result.is_err());
- assert!(result.unwrap_err().to_string().contains("database"));
- }
-
- #[test]
- fn test_config_missing_database_url() {
- let mut temp_file = NamedTempFile::new().unwrap();
- writeln!(
- temp_file,
- r#"
-[server]
-host = 127.0.0.1
-port = 9090
-
-[database]
-# url is missing
-"#
- )
- .unwrap();
-
- let result = Config::from_file(temp_file.path());
- assert!(result.is_err());
- assert!(result.unwrap_err().to_string().contains("database.url"));
- }
-
- #[test]
- fn test_config_invalid_port() {
- let mut temp_file = NamedTempFile::new().unwrap();
- writeln!(
- temp_file,
- r#"
-[server]
-host = 127.0.0.1
-port = not_a_number
-
-[database]
-url = postgresql://localhost/oauth2gw
-"#
- )
- .unwrap();
-
- let result = Config::from_file(temp_file.path());
- assert!(result.is_err());
- assert!(result.unwrap_err().to_string().contains("port"));
- }
-}
diff --git a/oauth2_gateway/src/db/authorization_codes.rs b/oauth2_gateway/src/db/authorization_codes.rs
@@ -7,7 +7,6 @@ use chrono::{DateTime, Utc};
use super::sessions::SessionStatus;
-/// Authorization code record
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct AuthorizationCode {
pub id: Uuid,
@@ -19,14 +18,13 @@ pub struct AuthorizationCode {
pub created_at: DateTime<Utc>,
}
-/// Result of code exchange query - includes session and existing token data for idempotency
#[derive(Debug, Clone)]
pub struct CodeExchangeData {
+ pub client_id: Uuid,
pub code_id: Uuid,
pub was_already_used: bool,
pub session_id: Uuid,
pub session_status: SessionStatus,
- /// Existing token if one was already created (for idempotent response)
pub existing_token: Option<String>,
pub existing_token_expires_at: Option<DateTime<Utc>>,
}
@@ -56,13 +54,7 @@ pub async fn create_authorization_code(
Ok(auth_code)
}
-/// Atomically mark code as used and fetch session + existing token data
-///
-/// This is the idempotent query for /token endpoint:
-/// - Capture OLD used value before update
-/// - JOINs with session to get status
-/// - LEFT JOINs with access_tokens to get existing token (for idempotent response)
-/// - Returns None if code doesn't exist or is expired
+/// Mark code as used and fetch session + existing token data
///
/// Used by the /token endpoint
pub async fn get_code_for_token_exchange(
@@ -90,6 +82,7 @@ pub async fn get_code_for_token_exchange(
uc.id AS code_id,
cd.was_already_used,
uc.session_id,
+ vs.client_id,
vs.status AS session_status,
at.token AS existing_token,
at.expires_at AS existing_token_expires_at
@@ -107,6 +100,7 @@ pub async fn get_code_for_token_exchange(
Ok(result.map(|row: sqlx::postgres::PgRow| {
use sqlx::Row;
CodeExchangeData {
+ client_id: row.get("code_id"),
code_id: row.get("code_id"),
was_already_used: row.get("was_already_used"),
session_id: row.get("session_id"),
diff --git a/oauth2_gateway/src/db/clients.rs b/oauth2_gateway/src/db/clients.rs
@@ -30,6 +30,8 @@ pub async fn register_client(
let api_path = verifier_management_api_path
.unwrap_or("/management/api/verifications");
+ let secret_hash = bcrypt::hash(client_secret, bcrypt::DEFAULT_COST)?;
+
let client = sqlx::query_as::<_, Client>(
r#"
INSERT INTO oauth2gw.clients
@@ -40,7 +42,7 @@ pub async fn register_client(
"#
)
.bind(client_id)
- .bind(client_secret)
+ .bind(secret_hash)
.bind(webhook_url)
.bind(verifier_url)
.bind(api_path)
@@ -103,15 +105,23 @@ pub async fn authenticate_client(
SELECT id, client_id, secret_hash, webhook_url, verifier_url,
verifier_management_api_path, created_at, updated_at
FROM oauth2gw.clients
- WHERE client_id = $1 AND secret_hash = $2
+ WHERE client_id = $1
"#
)
.bind(client_id)
- .bind(client_secret)
.fetch_optional(pool)
.await?;
- Ok(client)
+ match client {
+ Some(c) => {
+ if bcrypt::verify(client_secret, &c.secret_hash)? {
+ Ok(Some(c))
+ } else {
+ Ok(None)
+ }
+ }
+ None => Ok(None)
+ }
}
/// Update client configuration
@@ -188,4 +198,4 @@ pub async fn list_clients(pool: &PgPool) -> Result<Vec<Client>> {
.await?;
Ok(clients)
-}
-\ No newline at end of file
+}
diff --git a/oauth2_gateway/src/db/notification_webhooks.rs b/oauth2_gateway/src/db/notification_webhooks.rs
@@ -21,7 +21,6 @@ pub struct PendingWebhook {
/// Fetch pending webhooks ready to be sent
///
-/// JOINs with authorization_codes to get the code value.
/// Only returns webhooks where next_attempt <= current epoch time.
///
/// Used by the background worker
diff --git a/oauth2_gateway/src/handlers.rs b/oauth2_gateway/src/handlers.rs
@@ -28,10 +28,11 @@ pub async fn setup(
Path(client_id): Path<String>,
Json(request): Json<SetupRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
- tracing::info!("Setup request for client: {}, scope: {}",
- client_id, request.scope);
+
+ tracing::info!("Setup request for client: {}, scope: {}", client_id, request.scope);
let nonce = crypto::generate_nonce();
+
tracing::debug!("Generated nonce: {}", nonce);
let session = crate::db::sessions::create_session(
@@ -40,25 +41,22 @@ pub async fn setup(
&nonce,
&request.scope,
15, // 15 minutes expiration
- )
- .await
+ ).await
.map_err(|e| {
tracing::error!("Failed to create session: {}", e);
- (StatusCode::INTERNAL_SERVER_ERROR,
- Json(ErrorResponse::new("internal_error")))
+ (StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse::new("internal_error")))
})?;
+
let session = match session {
Some(s) => s,
None => {
tracing::warn!("Client not found: {}", client_id);
- return Err((StatusCode::NOT_FOUND,
- Json(ErrorResponse::new("client_not_found"))))
+ return Err((StatusCode::NOT_FOUND, Json(ErrorResponse::new("client_not_found"))))
}
};
- tracing::info!("Created session {} for client {} with nonce {}",
- session.id, client_id, nonce);
+ tracing::info!("Created session {} for client {} with nonce {}", session.id, client_id, nonce);
Ok((StatusCode::OK, Json(SetupResponse { nonce })))
}
@@ -68,8 +66,9 @@ pub async fn authorize(
State(state): State<AppState>,
Query(params): Query<AuthorizeQuery>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
- tracing::info!("Authorize request for client: {}, nonce: {}",
- params.client_id, params.nonce);
+
+ tracing::info!("Authorize request for client: {}, nonce: {}",params.client_id, params.nonce);
+
// Validate response_type
if params.response_type != "code" {
@@ -77,19 +76,19 @@ pub async fn authorize(
Json(ErrorResponse::new("invalid_request"))));
}
+
// Fetch session and client data (idempotent)
let session_data = crate::db::sessions::get_session_for_authorize(
&state.pool,
¶ms.nonce,
¶ms.client_id,
- )
- .await
+ ).await
.map_err(|e| {
tracing::error!("DB error in authorize: {}", e);
- (StatusCode::INTERNAL_SERVER_ERROR,
- Json(ErrorResponse::new("internal_error")))
+ (StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse::new("internal_error")))
})?;
+
let data = match session_data {
Some(d) => d,
None => {
@@ -99,27 +98,30 @@ pub async fn authorize(
}
};
+
// Backend validation
if data.expires_at < Utc::now() {
tracing::warn!("Session expired: {}", data.session_id);
- return Err((StatusCode::GONE,
- Json(ErrorResponse::new("session_expired"))));
+ return Err((StatusCode::GONE, Json(ErrorResponse::new("session_expired"))));
}
+
// Check status for idempotency
match data.status {
SessionStatus::Authorized => {
// Already authorized - return cached response
- tracing::info!("Session {} already authorized, returning cached response",
- data.session_id);
+ tracing::info!("Session {} already authorized, returning cached response", data.session_id);
+
let verification_id = data.request_id
.and_then(|id| uuid::Uuid::parse_str(&id).ok())
.unwrap_or(uuid::Uuid::nil());
+
return Ok((StatusCode::OK, Json(AuthorizeResponse {
verification_id,
verification_url: data.verification_url.unwrap_or_default(),
})));
}
+
SessionStatus::Pending => {
// Proceed with authorization
}
@@ -135,9 +137,7 @@ pub async fn authorize(
let presentation_definition = build_presentation_definition(&data.scope);
// Call Swiyu Verifier
- let verifier_url = format!("{}{}",
- data.verifier_url,
- data.verifier_management_api_path);
+ let verifier_url = format!("{}{}", data.verifier_url, data.verifier_management_api_path);
let verifier_request = SwiyuCreateVerificationRequest {
accepted_issuer_dids: default_accepted_issuer_dids(),
@@ -150,11 +150,11 @@ pub async fn authorize(
dcql_query: None,
};
- // convert verification request to json for debug view
- tracing::debug!("Swiyu verifier request: {}", serde_json::to_string_pretty(&verifier_request).unwrap());
+ tracing::debug!("Swiyu verifier request: {}", serde_json::to_string_pretty(&verifier_request).unwrap());
tracing::debug!("Calling Swiyu verifier at: {}", verifier_url);
+
let verifier_response = state.http_client
.post(&verifier_url)
.json(&verifier_request)
@@ -162,27 +162,27 @@ pub async fn authorize(
.await
.map_err(|e| {
tracing::error!("Failed to call Swiyu verifier: {}", e);
- (StatusCode::BAD_GATEWAY,
- Json(ErrorResponse::new("verifier_unavailable")))
+ (StatusCode::BAD_GATEWAY, Json(ErrorResponse::new("verifier_unavailable")))
})?;
+
if !verifier_response.status().is_success() {
let status = verifier_response.status();
let body = verifier_response.text().await.unwrap_or_default();
tracing::error!("Swiyu verifier returned error {}: {}", status, body);
- return Err((StatusCode::BAD_GATEWAY,
- Json(ErrorResponse::new("verifier_error"))));
+ return Err((StatusCode::BAD_GATEWAY, Json(ErrorResponse::new("verifier_error"))));
}
+
let swiyu_response: SwiyuManagementResponse = verifier_response
.json()
.await
.map_err(|e| {
tracing::error!("Failed to parse Swiyu response: {}", e);
- (StatusCode::BAD_GATEWAY,
- Json(ErrorResponse::new("verifier_invalid_response")))
+ (StatusCode::BAD_GATEWAY, Json(ErrorResponse::new("verifier_invalid_response")))
})?;
+
// Update session with verifier data
let result = crate::db::sessions::update_session_authorized(
&state.pool,
@@ -190,14 +190,13 @@ pub async fn authorize(
&swiyu_response.verification_url,
&swiyu_response.id.to_string(),
swiyu_response.request_nonce.as_deref(),
- )
- .await
+ ).await
.map_err(|e| {
tracing::error!("Failed to update session: {}", e);
- (StatusCode::INTERNAL_SERVER_ERROR,
- Json(ErrorResponse::new("internal_error")))
+ (StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse::new("internal_error")))
})?;
+
tracing::info!("Session {} authorized, verification_id: {}",
data.session_id, swiyu_response.id);
@@ -216,8 +215,7 @@ fn build_presentation_definition(scope: &str) -> PresentationDefinition {
let attributes: Vec<&str> = scope.split_whitespace().collect();
- tracing::debug!("Building presentation definition for attributes: {:?}",
- attributes);
+ tracing::debug!("Building presentation definition for attributes: {:?}", attributes);
// First field: $.vct with filter for credential type
let vct_field = Field {
@@ -231,6 +229,7 @@ fn build_presentation_definition(scope: &str) -> PresentationDefinition {
}),
};
+
// Attribute fields from scope
let mut fields: Vec<Field> = vec![vct_field];
for attr in &attributes {
@@ -243,6 +242,7 @@ fn build_presentation_definition(scope: &str) -> PresentationDefinition {
});
}
+
let mut format = HashMap::new();
format.insert(
"vc+sd-jwt".to_string(),
@@ -274,6 +274,7 @@ pub async fn token(
State(state): State<AppState>,
Json(request): Json<TokenRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
+
tracing::info!("Token request for code: {}", request.code);
// Validate grant_type
@@ -284,27 +285,52 @@ pub async fn token(
));
}
+ // Authenticate client
+ let client = crate::db::clients::authenticate_client(
+ &state.pool,
+ &request.client_id,
+ &request.client_secret,
+ ).await
+ .map_err(|e| {
+ tracing::error!("DB error during client authentication: {}", e);
+ (StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse::new("internal_error")))
+ })?;
+
+ let client = match client {
+ Some(c) => c,
+ None => {
+ tracing::warn!("Client authentication failed for {}", request.client_id);
+ return Err((StatusCode::UNAUTHORIZED, Json(ErrorResponse::new("invalid_client"))));
+ }
+ };
+
+
// Fetch code (idempotent)
let code_data = crate::db::authorization_codes::get_code_for_token_exchange(
&state.pool,
&request.code,
- )
- .await
+ ).await
.map_err(|e| {
tracing::error!("DB error in token exchange: {}", e);
- (StatusCode::INTERNAL_SERVER_ERROR,
- Json(ErrorResponse::new("internal_error")))
+ (StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse::new("internal_error")))
})?;
let data = match code_data {
Some(d) => d,
None => {
tracing::warn!("Authorization code not found or expired: {}", request.code);
- return Err((StatusCode::BAD_REQUEST,
- Json(ErrorResponse::new("invalid_grant"))));
+ return Err((StatusCode::BAD_REQUEST, Json(ErrorResponse::new("invalid_grant"))));
}
};
+ // Verify the authorization code belongs to the client
+ if data.client_id != client.id {
+ tracing::warn!("Authorization code {} does not belong to the client {}",
+ request.code, request.client_id);
+
+ return Err((StatusCode::BAD_REQUEST, Json(ErrorResponse::new("invalid_grant"))));
+ }
+
// Check for existing token
if let Some(existing_token) = data.existing_token {
tracing::info!("Token already exists for session {}, returning cached response",
@@ -319,16 +345,14 @@ pub async fn token(
// Check if code was already used
if data.was_already_used {
tracing::warn!("Authorization code {} was already used", request.code);
- return Err((StatusCode::BAD_REQUEST,
- Json(ErrorResponse::new("invalid_grant"))));
+ return Err((StatusCode::BAD_REQUEST, Json(ErrorResponse::new("invalid_grant"))));
}
// Validate session status
if data.session_status != SessionStatus::Verified {
tracing::warn!("Session {} not in verified status: {:?}",
data.session_id, data.session_status);
- return Err((StatusCode::BAD_REQUEST,
- Json(ErrorResponse::new("invalid_grant"))));
+ return Err((StatusCode::BAD_REQUEST, Json(ErrorResponse::new("invalid_grant"))));
}
// Generate new token and complete session
@@ -338,12 +362,10 @@ pub async fn token(
data.session_id,
&access_token,
3600, // 1 hour
- )
- .await
+ ).await
.map_err(|e| {
tracing::error!("Failed to create token: {}", e);
- (StatusCode::INTERNAL_SERVER_ERROR,
- Json(ErrorResponse::new("internal_error")))
+ (StatusCode::INTERNAL_SERVER_ERROR, Json(ErrorResponse::new("internal_error")))
})?;
tracing::info!("Token created for session {}", data.session_id);
diff --git a/oauth2_gateway/src/main.rs b/oauth2_gateway/src/main.rs
@@ -49,13 +49,29 @@ async fn main() -> Result<()> {
.layer(TraceLayer::new_for_http())
.with_state(state);
- let addr = format!("{}:{}", config.server.host, config.server.port);
- let listener = tokio::net::TcpListener::bind(&addr).await?;
+ if config.server.is_unix_socket() {
+ let socket_path = config.server.socket_path.as_ref().unwrap();
- tracing::info!("Server listening on {}", addr);
- tracing::info!("Health check available at: http://{}/health", addr);
+ if std::path::Path::new(socket_path).exists() {
+ tracing::warn!("Removing existing socket file: {}", socket_path);
+ std::fs::remove_file(socket_path)?;
+ }
- axum::serve(listener, app).await?;
+ let listener = tokio::net::UnixListener::bind(socket_path)?;
+ tracing::info!("Server listening on Unix socket: {}", socket_path);
+
+ axum::serve(listener, app).await?;
+ } else {
+ let host = config.server.host.as_ref().unwrap();
+ let port = config.server.port.unwrap();
+ let addr = format!("{}:{}", host, port);
+
+ let listener = tokio::net::TcpListener::bind(&addr).await?;
+ tracing::info!("Server listening on {}", addr);
+ tracing::info!("Health check available at: http://{}/health", addr);
+
+ axum::serve(listener, app).await?;
+ }
Ok(())
}
diff --git a/oauth2_gateway/src/models.rs b/oauth2_gateway/src/models.rs
@@ -2,7 +2,6 @@ use serde::{Deserialize, Serialize};
use uuid::Uuid;
use std::collections::HashMap;
-// Setup endpoint
#[derive(Debug, Deserialize, Serialize)]
pub struct SetupRequest {
pub scope: String,
@@ -20,7 +19,6 @@ pub struct AuthorizeQuery {
pub nonce: String,
}
-// Authorize endpoint
#[derive(Debug, Deserialize, Serialize)]
pub struct AuthorizeResponse {
#[serde(rename = "verificationId")]
@@ -35,6 +33,8 @@ pub struct AuthorizeResponse {
pub struct TokenRequest {
pub grant_type: String,
pub code: String,
+ pub client_id: String,
+ pub client_secret: String,
}
#[derive(Debug, Deserialize, Serialize)]
@@ -95,7 +95,6 @@ pub fn default_accepted_issuer_dids() -> Vec<String> {
pub struct SwiyuCreateVerificationRequest {
#[serde(default = "default_accepted_issuer_dids")]
pub accepted_issuer_dids: Vec<String>,
-
#[serde(skip_serializing_if = "Option::is_none")]
pub trust_anchors: Option<Vec<TrustAnchor>>,