authorization_codes.rs (4243B)
1 // Database operations for authorization_codes table 2 3 use sqlx::PgPool; 4 use anyhow::Result; 5 use uuid::Uuid; 6 use chrono::{DateTime, Utc}; 7 8 use super::sessions::SessionStatus; 9 10 #[derive(Debug, Clone, sqlx::FromRow)] 11 pub struct AuthorizationCode { 12 pub id: Uuid, 13 pub session_id: Uuid, 14 pub code: String, 15 pub expires_at: DateTime<Utc>, 16 pub used: bool, 17 pub used_at: Option<DateTime<Utc>>, 18 pub created_at: DateTime<Utc>, 19 } 20 21 #[derive(Debug, Clone)] 22 pub struct CodeExchangeData { 23 pub client_id: Uuid, 24 pub code_id: Uuid, 25 pub was_already_used: bool, 26 pub session_id: Uuid, 27 pub session_status: SessionStatus, 28 pub existing_token: Option<String>, 29 pub existing_token_expires_at: Option<DateTime<Utc>>, 30 } 31 32 /// Create a new authorization code for a session 33 /// 34 /// Called after verification completes successfully 35 pub async fn create_authorization_code( 36 pool: &PgPool, 37 session_id: Uuid, 38 code: &str, 39 expires_in_minutes: i64, 40 ) -> Result<AuthorizationCode> { 41 let auth_code = sqlx::query_as::<_, AuthorizationCode>( 42 r#" 43 INSERT INTO oauth2gw.authorization_codes (session_id, code, expires_at) 44 VALUES ($1, $2, NOW() + $3 * INTERVAL '1 minute') 45 RETURNING id, session_id, code, expires_at, used, used_at, created_at 46 "# 47 ) 48 .bind(session_id) 49 .bind(code) 50 .bind(expires_in_minutes) 51 .fetch_one(pool) 52 .await?; 53 54 Ok(auth_code) 55 } 56 57 /// Mark code as used and fetch session + existing token data 58 /// 59 /// Used by the /token endpoint 60 pub async fn get_code_for_token_exchange( 61 pool: &PgPool, 62 code: &str, 63 ) -> Result<Option<CodeExchangeData>> { 64 // Use CTE to capture old 'used' value before the UPDATE changes it 65 let result = sqlx::query( 66 r#" 67 WITH code_data AS ( 68 SELECT id, used AS was_already_used, session_id 69 FROM oauth2gw.authorization_codes 70 WHERE code = $1 AND expires_at > NOW() 71 FOR UPDATE 72 ), 73 updated_code AS ( 74 UPDATE oauth2gw.authorization_codes ac 75 SET used = TRUE, 76 used_at = CASE WHEN NOT ac.used THEN NOW() ELSE ac.used_at END 77 FROM code_data cd 78 WHERE ac.id = cd.id 79 RETURNING ac.id, ac.session_id 80 ) 81 SELECT 82 uc.id AS code_id, 83 cd.was_already_used, 84 uc.session_id, 85 vs.client_id, 86 vs.status AS session_status, 87 at.token AS existing_token, 88 at.expires_at AS existing_token_expires_at 89 FROM updated_code uc 90 JOIN code_data cd ON uc.id = cd.id 91 JOIN oauth2gw.verification_sessions vs ON vs.id = uc.session_id 92 LEFT JOIN oauth2gw.access_tokens at 93 ON at.session_id = vs.id AND at.revoked = FALSE 94 "# 95 ) 96 .bind(code) 97 .fetch_optional(pool) 98 .await?; 99 100 Ok(result.map(|row: sqlx::postgres::PgRow| { 101 use sqlx::Row; 102 CodeExchangeData { 103 client_id: row.get("client_id"), 104 code_id: row.get("code_id"), 105 was_already_used: row.get("was_already_used"), 106 session_id: row.get("session_id"), 107 session_status: row.get("session_status"), 108 existing_token: row.get("existing_token"), 109 existing_token_expires_at: row.get("existing_token_expires_at"), 110 } 111 })) 112 } 113 114 /// Get authorization code by session ID 115 pub async fn get_code_by_session( 116 pool: &PgPool, 117 session_id: Uuid, 118 ) -> Result<Option<AuthorizationCode>> { 119 let auth_code = sqlx::query_as::<_, AuthorizationCode>( 120 r#" 121 SELECT id, session_id, code, expires_at, used, used_at, created_at 122 FROM oauth2gw.authorization_codes 123 WHERE session_id = $1 124 ORDER BY created_at DESC 125 LIMIT 1 126 "# 127 ) 128 .bind(session_id) 129 .fetch_optional(pool) 130 .await?; 131 132 Ok(auth_code) 133 } 134 135 /// Delete expired authorization codes (garbage collection) 136 pub async fn delete_expired_codes(pool: &PgPool) -> Result<u64> { 137 let result = sqlx::query( 138 r#" 139 DELETE FROM oauth2gw.authorization_codes 140 WHERE expires_at < CURRENT_TIMESTAMP 141 "# 142 ) 143 .execute(pool) 144 .await?; 145 146 Ok(result.rows_affected()) 147 }