kych

OAuth 2.0 API for Swiyu to enable Taler integration of Swiyu for KYC (experimental)
Log | Files | Refs | README

authorization_codes.rs (4361B)


      1 // Database operations for authorization_codes table
      2 
      3 use sqlx::PgPool;
      4 use anyhow::Result;
      5 use uuid::Uuid;
      6 use chrono::{DateTime, Utc};
      7 
      8 use super::sessions::SessionStatus;
      9 
/// One row of the `oauth2gw.authorization_codes` table.
///
/// Rows are inserted by `create_authorization_code` and flagged as used by
/// `get_code_for_token_exchange`.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct AuthorizationCode {
    /// Identifier of this authorization code row.
    /// NOTE(review): not supplied by the INSERT in this file, so presumably
    /// generated by a column default — confirm against the schema.
    pub id: Uuid,
    /// Session the code was issued for.
    pub session_id: Uuid,
    /// The authorization code value; the token exchange looks rows up by it.
    pub code: String,
    /// Point in time after which the token exchange no longer accepts the code.
    pub expires_at: DateTime<Utc>,
    /// Whether the code has been presented to the token exchange.
    /// NOTE(review): presumably defaults to FALSE on insert — confirm.
    pub used: bool,
    /// Time the code was first marked as used (later exchanges keep the
    /// original value); `None` when the column is NULL.
    pub used_at: Option<DateTime<Utc>>,
    /// Creation time of the row; used to pick the newest code of a session.
    /// NOTE(review): not supplied by the INSERT in this file, so presumably
    /// filled by a column default — confirm against the schema.
    pub created_at: DateTime<Utc>,
}
     20 
/// Result of `get_code_for_token_exchange`: the authorization code that was
/// just marked as used, joined with its verification session and (if one
/// exists) a non-revoked access token of that session.
#[derive(Debug, Clone)]
pub struct CodeExchangeData {
    /// `client_id` column of the verification session the code belongs to.
    pub client_id: Uuid,
    /// Identifier of the authorization code row.
    pub code_id: Uuid,
    /// Value of the code's `used` flag as read before the exchange set it to
    /// TRUE; `true` therefore means the code had been presented before.
    pub was_already_used: bool,
    /// Identifier of the verification session the code belongs to.
    pub session_id: Uuid,
    /// `status` column of the verification session.
    pub session_status: SessionStatus,
    /// Token of a non-revoked access token row for the session, or `None`
    /// when the LEFT JOIN found no such row.
    pub existing_token: Option<String>,
    /// Expiry of that access token, or `None` when there is no such row.
    pub existing_token_expires_at: Option<DateTime<Utc>>,
    /// `redirect_uri` column of the verification session; `None` when NULL.
    pub redirect_uri: Option<String>,
}
     32 
     33 /// Create a new authorization code for a session
     34 ///
     35 /// Called after verification completes successfully
     36 pub async fn create_authorization_code(
     37     pool: &PgPool,
     38     session_id: Uuid,
     39     code: &str,
     40     expires_in_minutes: i64,
     41 ) -> Result<AuthorizationCode> {
     42     let auth_code = sqlx::query_as::<_, AuthorizationCode>(
     43         r#"
     44         INSERT INTO oauth2gw.authorization_codes (session_id, code, expires_at)
     45         VALUES ($1, $2, NOW() + $3 * INTERVAL '1 minute')
     46         RETURNING id, session_id, code, expires_at, used, used_at, created_at
     47         "#
     48     )
     49     .bind(session_id)
     50     .bind(code)
     51     .bind(expires_in_minutes)
     52     .fetch_one(pool)
     53     .await?;
     54 
     55     Ok(auth_code)
     56 }
     57 
     58 /// Mark code as used and fetch session + existing token data
     59 ///
     60 /// Used by the /token endpoint
     61 pub async fn get_code_for_token_exchange(
     62     pool: &PgPool,
     63     code: &str,
     64 ) -> Result<Option<CodeExchangeData>> {
     65     // Use CTE to capture old 'used' value before the UPDATE changes it
     66     let result = sqlx::query(
     67         r#"
     68         WITH code_data AS (
     69             SELECT id, used AS was_already_used, session_id
     70             FROM oauth2gw.authorization_codes
     71             WHERE code = $1 AND expires_at > NOW()
     72             FOR UPDATE
     73         ),
     74         updated_code AS (
     75             UPDATE oauth2gw.authorization_codes ac
     76             SET used = TRUE,
     77                 used_at = CASE WHEN NOT ac.used THEN NOW() ELSE ac.used_at END
     78             FROM code_data cd
     79             WHERE ac.id = cd.id
     80             RETURNING ac.id, ac.session_id
     81         )
     82         SELECT
     83             uc.id AS code_id,
     84             cd.was_already_used,
     85             uc.session_id,
     86             vs.client_id,
     87             vs.status AS session_status,
     88             at.token AS existing_token,
     89             at.expires_at AS existing_token_expires_at,
     90             vs.redirect_uri
     91         FROM updated_code uc
     92         JOIN code_data cd ON uc.id = cd.id
     93         JOIN oauth2gw.verification_sessions vs ON vs.id = uc.session_id
     94         LEFT JOIN oauth2gw.access_tokens at
     95             ON at.session_id = vs.id AND at.revoked = FALSE
     96         "#
     97     )
     98     .bind(code)
     99     .fetch_optional(pool)
    100     .await?;
    101 
    102     Ok(result.map(|row: sqlx::postgres::PgRow| {
    103         use sqlx::Row;
    104         CodeExchangeData {
    105             client_id: row.get("client_id"),
    106             code_id: row.get("code_id"),
    107             was_already_used: row.get("was_already_used"),
    108             session_id: row.get("session_id"),
    109             session_status: row.get("session_status"),
    110             existing_token: row.get("existing_token"),
    111             existing_token_expires_at: row.get("existing_token_expires_at"),
    112             redirect_uri: row.get("redirect_uri"),
    113         }
    114     }))
    115 }
    116 
    117 /// Get authorization code by session ID
    118 pub async fn get_code_by_session(
    119     pool: &PgPool,
    120     session_id: Uuid,
    121 ) -> Result<Option<AuthorizationCode>> {
    122     let auth_code = sqlx::query_as::<_, AuthorizationCode>(
    123         r#"
    124         SELECT id, session_id, code, expires_at, used, used_at, created_at
    125         FROM oauth2gw.authorization_codes
    126         WHERE session_id = $1
    127         ORDER BY created_at DESC
    128         LIMIT 1
    129         "#
    130     )
    131     .bind(session_id)
    132     .fetch_optional(pool)
    133     .await?;
    134 
    135     Ok(auth_code)
    136 }
    137 
    138 /// Delete expired authorization codes (garbage collection)
    139 pub async fn delete_expired_codes(pool: &PgPool) -> Result<u64> {
    140     let result = sqlx::query(
    141         r#"
    142         DELETE FROM oauth2gw.authorization_codes
    143         WHERE expires_at < CURRENT_TIMESTAMP
    144         "#
    145     )
    146     .execute(pool)
    147     .await?;
    148 
    149     Ok(result.rows_affected())
    150 }