kych

OAuth 2.0 API for Swiyu to enable Taler integration of Swiyu for KYC (experimental)
Log | Files | Refs | README

config.rs (12052B)


      1 use anyhow::{Context, Result};
      2 use serde::{Deserialize, Serialize};
      3 use ini::Ini;
      4 use std::collections::HashSet;
      5 use std::path::Path;
      6 
      7 const MAIN_SECTION: &str = "kych-oauth2-gateway";
      8 
      9 #[derive(Debug, Clone, Serialize, Deserialize)]
     10 pub struct Config {
     11     pub server: ServerConfig,
     12     pub database: DatabaseConfig,
     13     pub crypto: CryptoConfig,
     14     pub vc: VcConfig,
     15     pub allowed_scopes: Option<Vec<String>>,
     16     pub clients: Vec<ClientConfig>,
     17 }
     18 
     19 #[derive(Debug, Clone, Serialize, Deserialize)]
     20 pub struct ServerConfig {
     21     pub host: Option<String>,
     22     pub port: Option<u16>,
     23     pub socket_path: Option<String>,
     24     pub socket_mode: u32,
     25 }
     26 
     27 impl ServerConfig {
     28     pub fn validate(&self) -> Result<()> {
     29         let has_tcp = self.host.is_some() || self.port.is_some();
     30         let has_unix = self.socket_path.is_some();
     31 
     32         if has_tcp && has_unix {
     33             anyhow::bail!("Cannot specify both TCP (HOST/PORT) and Unix socket (UNIXPATH)");
     34         }
     35 
     36         if !has_tcp && !has_unix {
     37             anyhow::bail!("Must specify either TCP (HOST/PORT) or Unix socket (UNIXPATH)");
     38         }
     39 
     40         if has_tcp && (self.host.is_none() || self.port.is_none()) {
     41             anyhow::bail!("HOST and PORT must both be specified for TCP");
     42         }
     43 
     44         Ok(())
     45     }
     46 
     47     pub fn is_unix_socket(&self) -> bool {
     48         self.socket_path.is_some()
     49     }
     50 }
     51 
     52 #[derive(Debug, Clone, Serialize, Deserialize)]
     53 pub struct DatabaseConfig {
     54     pub url: String,
     55 }
     56 
     57 #[derive(Debug, Clone, Serialize, Deserialize)]
     58 pub struct CryptoConfig {
     59     pub nonce_bytes: usize,
     60     pub token_bytes: usize,
     61     pub authorization_code_bytes: usize,
     62     pub authorization_code_ttl_minutes: i64,
     63 }
     64 
     65 #[derive(Debug, Clone, Serialize, Deserialize)]
     66 pub struct ClientConfig {
     67     pub section_name: String,
     68     pub client_id: String,
     69     pub client_secret: String,
     70     pub verifier_url: String,
     71     pub verifier_management_api_path: String,
     72     pub redirect_uri: String,
     73     pub accepted_issuer_dids: Option<String>,
     74 }
     75 
     76 #[derive(Debug, Clone, Serialize, Deserialize)]
     77 pub struct VcConfig {
     78     pub vc_type: String,
     79     pub vc_format: String,
     80     pub vc_algorithms: Vec<String>,
     81     pub vc_claims: HashSet<String>,
     82 }
     83 
     84 impl Config {
     85     pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
     86         let ini = Ini::load_from_file(path.as_ref())
     87             .context("Failed to load config file")?;
     88 
     89         let main_section = ini
     90             .section(Some(MAIN_SECTION))
     91             .context(format!("Missing [{}] section", MAIN_SECTION))?;
     92 
     93         let host = main_section.get("HOST")
     94             .filter(|s| !s.is_empty())
     95             .map(|s| s.to_string());
     96         let port = main_section
     97             .get("PORT")
     98             .filter(|s| !s.is_empty())
     99             .map(|s| s.parse::<u16>())
    100             .transpose()
    101             .context("Invalid PORT")?;
    102         let socket_path = main_section.get("UNIXPATH")
    103             .filter(|s| !s.is_empty())
    104             .map(|s| s.to_string());
    105         let socket_mode = main_section
    106             .get("UNIXPATH_MODE")
    107             .filter(|s| !s.is_empty())
    108             .map(|s| u32::from_str_radix(s, 8))
    109             .transpose()
    110             .context("Invalid UNIXPATH_MODE (expected octal, e.g. 666)")?
    111             .unwrap_or(0o666);
    112 
    113         let server = ServerConfig {
    114             host,
    115             port,
    116             socket_path,
    117             socket_mode,
    118         };
    119 
    120         server.validate()?;
    121 
    122         let database = DatabaseConfig {
    123             url: main_section
    124                 .get("DATABASE")
    125                 .context("Missing DATABASE")?
    126                 .to_string(),
    127         };
    128 
    129         let crypto = CryptoConfig {
    130             nonce_bytes: main_section
    131                 .get("NONCE_BYTES")
    132                 .context("Missing NONCE_BYTES")?
    133                 .parse()
    134                 .context("Invalid NONCE_BYTES")?,
    135             token_bytes: main_section
    136                 .get("TOKEN_BYTES")
    137                 .context("Missing TOKEN_BYTES")?
    138                 .parse()
    139                 .context("Invalid TOKEN_BYTES")?,
    140             authorization_code_bytes: main_section
    141                 .get("AUTH_CODE_BYTES")
    142                 .context("Missing AUTH_CODE_BYTES")?
    143                 .parse()
    144                 .context("Invalid AUTH_CODE_BYTES")?,
    145             authorization_code_ttl_minutes: main_section
    146                 .get("AUTH_CODE_TTL_MINUTES")
    147                 .unwrap_or("10")
    148                 .parse()
    149                 .context("Invalid AUTH_CODE_TTL_MINUTES")?,
    150         };
    151 
    152         let allowed_scopes = match main_section.get("ALLOWED_SCOPES") {
    153             Some(raw) if !raw.trim().is_empty() => Some(parse_allowed_scopes(raw)?),
    154             _ => None,
    155         };
    156 
    157         let vc_type = main_section
    158             .get("VC_TYPE")
    159             .filter(|s| !s.is_empty())
    160             .context("missing required config: VC_TYPE")?
    161             .to_string();
    162 
    163         let vc_format = main_section
    164             .get("VC_FORMAT")
    165             .filter(|s| !s.is_empty())
    166             .context("missing required config: VC_FORMAT")?
    167             .to_string();
    168 
    169         let vc_algorithms = parse_bracketed_list(
    170             main_section
    171                 .get("VC_ALGORITHMS")
    172                 .context("missing required config: VC_ALGORITHMS")?,
    173             "VC_ALGORITHMS",
    174         )?;
    175         if vc_algorithms.is_empty() {
    176             anyhow::bail!("VC_ALGORITHMS must contain at least one algorithm");
    177         }
    178 
    179         let vc_claims_list = parse_bracketed_list(
    180             main_section
    181                 .get("VC_CLAIMS")
    182                 .context("missing required config: VC_CLAIMS")?,
    183             "VC_CLAIMS",
    184         )?;
    185         if vc_claims_list.is_empty() {
    186             anyhow::bail!("VC_CLAIMS must contain at least one claim");
    187         }
    188         let vc_claims: HashSet<String> = vc_claims_list.into_iter().collect();
    189 
    190         let vc = VcConfig {
    191             vc_type,
    192             vc_format,
    193             vc_algorithms,
    194             vc_claims,
    195         };
    196 
    197         let mut clients = Vec::new();
    198         for (section_name, properties) in ini.iter() {
    199             let section_name = match section_name {
    200                 Some(name) if name.starts_with("client_") => name,
    201                 _ => continue,
    202             };
    203 
    204             let client_id = properties.get("CLIENT_ID")
    205                 .context(format!("Missing CLIENT_ID in section [{}]", section_name))?
    206                 .to_string();
    207             let client_secret = properties.get("CLIENT_SECRET")
    208                 .context(format!("Missing CLIENT_SECRET in section [{}]", section_name))?
    209                 .to_string();
    210             let verifier_url = properties.get("VERIFIER_URL")
    211                 .context(format!("Missing VERIFIER_URL in section [{}]", section_name))?
    212                 .to_string();
    213             let verifier_management_api_path = properties.get("VERIFIER_MANAGEMENT_API_PATH")
    214                 .unwrap_or("/management/api/verifications")
    215                 .to_string();
    216             let redirect_uri = properties.get("REDIRECT_URI")
    217                 .filter(|s| !s.is_empty())
    218                 .context(format!("Missing REDIRECT_URI in section [{}]", section_name))?
    219                 .to_string();
    220             let accepted_issuer_dids = properties.get("ACCEPTED_ISSUER_DIDS")
    221                 .filter(|s| !s.is_empty())
    222                 .map(|s| s.to_string());
    223 
    224             clients.push(ClientConfig {
    225                 section_name: section_name.to_string(),
    226                 client_id,
    227                 client_secret,
    228                 verifier_url,
    229                 verifier_management_api_path,
    230                 redirect_uri,
    231                 accepted_issuer_dids,
    232             });
    233         }
    234 
    235         Ok(Config {
    236             server,
    237             database,
    238             crypto,
    239             vc,
    240             allowed_scopes,
    241             clients,
    242         })
    243     }
    244 }
    245 
    246 fn parse_allowed_scopes(raw: &str) -> Result<Vec<String>> {
    247     let trimmed = raw.trim();
    248     let trimmed = trimmed.strip_prefix('{').unwrap_or(trimmed);
    249     let trimmed = trimmed.strip_suffix('}').unwrap_or(trimmed);
    250 
    251     let scopes: Vec<String> = trimmed
    252         .split(|c: char| c == ',' || c.is_whitespace())
    253         .map(|s| s.trim())
    254         .filter(|s| !s.is_empty())
    255         .map(|s| s.to_string())
    256         .collect();
    257 
    258     if scopes.is_empty() {
    259         anyhow::bail!("ALLOWED_SCOPES must contain at least one scope");
    260     }
    261 
    262     Ok(scopes)
    263 }
    264 
    265 fn parse_bracketed_list(value: &str, field_name: &str) -> Result<Vec<String>> {
    266     let trimmed = value.trim();
    267     if !trimmed.starts_with('{') || !trimmed.ends_with('}') {
    268         anyhow::bail!("invalid {} format: expected {{item1, item2, ...}}", field_name);
    269     }
    270     let inner = &trimmed[1..trimmed.len() - 1];
    271     let items: Vec<String> = inner
    272         .split(',')
    273         .map(|s| s.trim().to_string())
    274         .filter(|s| !s.is_empty())
    275         .collect();
    276     Ok(items)
    277 }
    278 
    279 #[cfg(test)]
    280 mod tests {
    281     use super::*;
    282 
    283     #[test]
    284     fn test_server_validate_tcp_ok() {
    285         let server = ServerConfig {
    286             host: Some("127.0.0.1".to_string()),
    287             port: Some(8080),
    288             socket_path: None,
    289             socket_mode: 0o666,
    290         };
    291 
    292         assert!(server.validate().is_ok());
    293     }
    294 
    295     #[test]
    296     fn test_server_validate_unix_ok() {
    297         let server = ServerConfig {
    298             host: None,
    299             port: None,
    300             socket_path: Some("/tmp/kych.sock".to_string()),
    301             socket_mode: 0o666,
    302         };
    303 
    304         assert!(server.validate().is_ok());
    305     }
    306 
    307     #[test]
    308     fn test_server_validate_both_err() {
    309         let server = ServerConfig {
    310             host: Some("127.0.0.1".to_string()),
    311             port: Some(8080),
    312             socket_path: Some("/tmp/kych.sock".to_string()),
    313             socket_mode: 0o666,
    314         };
    315 
    316         assert!(server.validate().is_err());
    317     }
    318 
    319     #[test]
    320     fn test_server_validate_neither_err() {
    321         let server = ServerConfig {
    322             host: None,
    323             port: None,
    324             socket_path: None,
    325             socket_mode: 0o666,
    326         };
    327 
    328         assert!(server.validate().is_err());
    329     }
    330 
    331     #[test]
    332     fn test_server_validate_missing_port_err() {
    333         let server = ServerConfig {
    334             host: Some("127.0.0.1".to_string()),
    335             port: None,
    336             socket_path: None,
    337             socket_mode: 0o666,
    338         };
    339 
    340         assert!(server.validate().is_err());
    341     }
    342 
    343     #[test]
    344     fn test_parse_allowed_scopes_variants() {
    345         let scopes = parse_allowed_scopes("{a, b c}").unwrap();
    346         assert_eq!(scopes, vec!["a", "b", "c"]);
    347 
    348         let scopes = parse_allowed_scopes("  a b  c ").unwrap();
    349         assert_eq!(scopes, vec!["a", "b", "c"]);
    350 
    351         let scopes = parse_allowed_scopes("a,b,c").unwrap();
    352         assert_eq!(scopes, vec!["a", "b", "c"]);
    353     }
    354 
    355     #[test]
    356     fn test_parse_allowed_scopes_empty_err() {
    357         assert!(parse_allowed_scopes("").is_err());
    358         assert!(parse_allowed_scopes("   ").is_err());
    359         assert!(parse_allowed_scopes("{}").is_err());
    360         assert!(parse_allowed_scopes("{   }").is_err());
    361     }
    362 
    363     #[test]
    364     fn test_parse_bracketed_list_valid() {
    365         let items = parse_bracketed_list("{a, b, c}", "TEST").unwrap();
    366         assert_eq!(items, vec!["a", "b", "c"]);
    367     }
    368 
    369     #[test]
    370     fn test_parse_bracketed_list_single_item() {
    371         let items = parse_bracketed_list("{ES256}", "TEST").unwrap();
    372         assert_eq!(items, vec!["ES256"]);
    373     }
    374 
    375     #[test]
    376     fn test_parse_bracketed_list_extra_whitespace() {
    377         let items = parse_bracketed_list("{ a , b }", "TEST").unwrap();
    378         assert_eq!(items, vec!["a", "b"]);
    379     }
    380 
    381     #[test]
    382     fn test_parse_bracketed_list_missing_braces() {
    383         let result = parse_bracketed_list("a, b", "TEST");
    384         assert!(result.is_err());
    385         let err = result.unwrap_err().to_string();
    386         assert!(err.contains("TEST"));
    387     }
    388 
    389     #[test]
    390     fn test_parse_bracketed_list_empty() {
    391         let items = parse_bracketed_list("{}", "TEST").unwrap();
    392         assert!(items.is_empty());
    393     }
    394 }