authorization_codes.rs (4361B)
1 // Database operations for authorization_codes table 2 3 use sqlx::PgPool; 4 use anyhow::Result; 5 use uuid::Uuid; 6 use chrono::{DateTime, Utc}; 7 8 use super::sessions::SessionStatus; 9 10 #[derive(Debug, Clone, sqlx::FromRow)] 11 pub struct AuthorizationCode { 12 pub id: Uuid, 13 pub session_id: Uuid, 14 pub code: String, 15 pub expires_at: DateTime<Utc>, 16 pub used: bool, 17 pub used_at: Option<DateTime<Utc>>, 18 pub created_at: DateTime<Utc>, 19 } 20 21 #[derive(Debug, Clone)] 22 pub struct CodeExchangeData { 23 pub client_id: Uuid, 24 pub code_id: Uuid, 25 pub was_already_used: bool, 26 pub session_id: Uuid, 27 pub session_status: SessionStatus, 28 pub existing_token: Option<String>, 29 pub existing_token_expires_at: Option<DateTime<Utc>>, 30 pub redirect_uri: Option<String>, 31 } 32 33 /// Create a new authorization code for a session 34 /// 35 /// Called after verification completes successfully 36 pub async fn create_authorization_code( 37 pool: &PgPool, 38 session_id: Uuid, 39 code: &str, 40 expires_in_minutes: i64, 41 ) -> Result<AuthorizationCode> { 42 let auth_code = sqlx::query_as::<_, AuthorizationCode>( 43 r#" 44 INSERT INTO oauth2gw.authorization_codes (session_id, code, expires_at) 45 VALUES ($1, $2, NOW() + $3 * INTERVAL '1 minute') 46 RETURNING id, session_id, code, expires_at, used, used_at, created_at 47 "# 48 ) 49 .bind(session_id) 50 .bind(code) 51 .bind(expires_in_minutes) 52 .fetch_one(pool) 53 .await?; 54 55 Ok(auth_code) 56 } 57 58 /// Mark code as used and fetch session + existing token data 59 /// 60 /// Used by the /token endpoint 61 pub async fn get_code_for_token_exchange( 62 pool: &PgPool, 63 code: &str, 64 ) -> Result<Option<CodeExchangeData>> { 65 // Use CTE to capture old 'used' value before the UPDATE changes it 66 let result = sqlx::query( 67 r#" 68 WITH code_data AS ( 69 SELECT id, used AS was_already_used, session_id 70 FROM oauth2gw.authorization_codes 71 WHERE code = $1 AND expires_at > NOW() 72 FOR UPDATE 73 ), 74 updated_code AS ( 75 UPDATE oauth2gw.authorization_codes ac 76 SET used = TRUE, 77 used_at = CASE WHEN NOT ac.used THEN NOW() ELSE ac.used_at END 78 FROM code_data cd 79 WHERE ac.id = cd.id 80 RETURNING ac.id, ac.session_id 81 ) 82 SELECT 83 uc.id AS code_id, 84 cd.was_already_used, 85 uc.session_id, 86 vs.client_id, 87 vs.status AS session_status, 88 at.token AS existing_token, 89 at.expires_at AS existing_token_expires_at, 90 vs.redirect_uri 91 FROM updated_code uc 92 JOIN code_data cd ON uc.id = cd.id 93 JOIN oauth2gw.verification_sessions vs ON vs.id = uc.session_id 94 LEFT JOIN oauth2gw.access_tokens at 95 ON at.session_id = vs.id AND at.revoked = FALSE 96 "# 97 ) 98 .bind(code) 99 .fetch_optional(pool) 100 .await?; 101 102 Ok(result.map(|row: sqlx::postgres::PgRow| { 103 use sqlx::Row; 104 CodeExchangeData { 105 client_id: row.get("client_id"), 106 code_id: row.get("code_id"), 107 was_already_used: row.get("was_already_used"), 108 session_id: row.get("session_id"), 109 session_status: row.get("session_status"), 110 existing_token: row.get("existing_token"), 111 existing_token_expires_at: row.get("existing_token_expires_at"), 112 redirect_uri: row.get("redirect_uri"), 113 } 114 })) 115 } 116 117 /// Get authorization code by session ID 118 pub async fn get_code_by_session( 119 pool: &PgPool, 120 session_id: Uuid, 121 ) -> Result<Option<AuthorizationCode>> { 122 let auth_code = sqlx::query_as::<_, AuthorizationCode>( 123 r#" 124 SELECT id, session_id, code, expires_at, used, used_at, created_at 125 FROM oauth2gw.authorization_codes 126 WHERE session_id = $1 127 ORDER BY created_at DESC 128 LIMIT 1 129 "# 130 ) 131 .bind(session_id) 132 .fetch_optional(pool) 133 .await?; 134 135 Ok(auth_code) 136 } 137 138 /// Delete expired authorization codes (garbage collection) 139 pub async fn delete_expired_codes(pool: &PgPool) -> Result<u64> { 140 let result = sqlx::query( 141 r#" 142 DELETE FROM oauth2gw.authorization_codes 143 WHERE expires_at < CURRENT_TIMESTAMP 144 "# 145 ) 146 .execute(pool) 147 .await?; 148 149 Ok(result.rows_affected()) 150 }