config.rs (12052B)
1 use anyhow::{Context, Result}; 2 use serde::{Deserialize, Serialize}; 3 use ini::Ini; 4 use std::collections::HashSet; 5 use std::path::Path; 6 7 const MAIN_SECTION: &str = "kych-oauth2-gateway"; 8 9 #[derive(Debug, Clone, Serialize, Deserialize)] 10 pub struct Config { 11 pub server: ServerConfig, 12 pub database: DatabaseConfig, 13 pub crypto: CryptoConfig, 14 pub vc: VcConfig, 15 pub allowed_scopes: Option<Vec<String>>, 16 pub clients: Vec<ClientConfig>, 17 } 18 19 #[derive(Debug, Clone, Serialize, Deserialize)] 20 pub struct ServerConfig { 21 pub host: Option<String>, 22 pub port: Option<u16>, 23 pub socket_path: Option<String>, 24 pub socket_mode: u32, 25 } 26 27 impl ServerConfig { 28 pub fn validate(&self) -> Result<()> { 29 let has_tcp = self.host.is_some() || self.port.is_some(); 30 let has_unix = self.socket_path.is_some(); 31 32 if has_tcp && has_unix { 33 anyhow::bail!("Cannot specify both TCP (HOST/PORT) and Unix socket (UNIXPATH)"); 34 } 35 36 if !has_tcp && !has_unix { 37 anyhow::bail!("Must specify either TCP (HOST/PORT) or Unix socket (UNIXPATH)"); 38 } 39 40 if has_tcp && (self.host.is_none() || self.port.is_none()) { 41 anyhow::bail!("HOST and PORT must both be specified for TCP"); 42 } 43 44 Ok(()) 45 } 46 47 pub fn is_unix_socket(&self) -> bool { 48 self.socket_path.is_some() 49 } 50 } 51 52 #[derive(Debug, Clone, Serialize, Deserialize)] 53 pub struct DatabaseConfig { 54 pub url: String, 55 } 56 57 #[derive(Debug, Clone, Serialize, Deserialize)] 58 pub struct CryptoConfig { 59 pub nonce_bytes: usize, 60 pub token_bytes: usize, 61 pub authorization_code_bytes: usize, 62 pub authorization_code_ttl_minutes: i64, 63 } 64 65 #[derive(Debug, Clone, Serialize, Deserialize)] 66 pub struct ClientConfig { 67 pub section_name: String, 68 pub client_id: String, 69 pub client_secret: String, 70 pub verifier_url: String, 71 pub verifier_management_api_path: String, 72 pub redirect_uri: String, 73 pub accepted_issuer_dids: Option<String>, 74 } 75 76 
#[derive(Debug, Clone, Serialize, Deserialize)] 77 pub struct VcConfig { 78 pub vc_type: String, 79 pub vc_format: String, 80 pub vc_algorithms: Vec<String>, 81 pub vc_claims: HashSet<String>, 82 } 83 84 impl Config { 85 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> { 86 let ini = Ini::load_from_file(path.as_ref()) 87 .context("Failed to load config file")?; 88 89 let main_section = ini 90 .section(Some(MAIN_SECTION)) 91 .context(format!("Missing [{}] section", MAIN_SECTION))?; 92 93 let host = main_section.get("HOST") 94 .filter(|s| !s.is_empty()) 95 .map(|s| s.to_string()); 96 let port = main_section 97 .get("PORT") 98 .filter(|s| !s.is_empty()) 99 .map(|s| s.parse::<u16>()) 100 .transpose() 101 .context("Invalid PORT")?; 102 let socket_path = main_section.get("UNIXPATH") 103 .filter(|s| !s.is_empty()) 104 .map(|s| s.to_string()); 105 let socket_mode = main_section 106 .get("UNIXPATH_MODE") 107 .filter(|s| !s.is_empty()) 108 .map(|s| u32::from_str_radix(s, 8)) 109 .transpose() 110 .context("Invalid UNIXPATH_MODE (expected octal, e.g. 666)")? 111 .unwrap_or(0o666); 112 113 let server = ServerConfig { 114 host, 115 port, 116 socket_path, 117 socket_mode, 118 }; 119 120 server.validate()?; 121 122 let database = DatabaseConfig { 123 url: main_section 124 .get("DATABASE") 125 .context("Missing DATABASE")? 126 .to_string(), 127 }; 128 129 let crypto = CryptoConfig { 130 nonce_bytes: main_section 131 .get("NONCE_BYTES") 132 .context("Missing NONCE_BYTES")? 133 .parse() 134 .context("Invalid NONCE_BYTES")?, 135 token_bytes: main_section 136 .get("TOKEN_BYTES") 137 .context("Missing TOKEN_BYTES")? 138 .parse() 139 .context("Invalid TOKEN_BYTES")?, 140 authorization_code_bytes: main_section 141 .get("AUTH_CODE_BYTES") 142 .context("Missing AUTH_CODE_BYTES")? 
143 .parse() 144 .context("Invalid AUTH_CODE_BYTES")?, 145 authorization_code_ttl_minutes: main_section 146 .get("AUTH_CODE_TTL_MINUTES") 147 .unwrap_or("10") 148 .parse() 149 .context("Invalid AUTH_CODE_TTL_MINUTES")?, 150 }; 151 152 let allowed_scopes = match main_section.get("ALLOWED_SCOPES") { 153 Some(raw) if !raw.trim().is_empty() => Some(parse_allowed_scopes(raw)?), 154 _ => None, 155 }; 156 157 let vc_type = main_section 158 .get("VC_TYPE") 159 .filter(|s| !s.is_empty()) 160 .context("missing required config: VC_TYPE")? 161 .to_string(); 162 163 let vc_format = main_section 164 .get("VC_FORMAT") 165 .filter(|s| !s.is_empty()) 166 .context("missing required config: VC_FORMAT")? 167 .to_string(); 168 169 let vc_algorithms = parse_bracketed_list( 170 main_section 171 .get("VC_ALGORITHMS") 172 .context("missing required config: VC_ALGORITHMS")?, 173 "VC_ALGORITHMS", 174 )?; 175 if vc_algorithms.is_empty() { 176 anyhow::bail!("VC_ALGORITHMS must contain at least one algorithm"); 177 } 178 179 let vc_claims_list = parse_bracketed_list( 180 main_section 181 .get("VC_CLAIMS") 182 .context("missing required config: VC_CLAIMS")?, 183 "VC_CLAIMS", 184 )?; 185 if vc_claims_list.is_empty() { 186 anyhow::bail!("VC_CLAIMS must contain at least one claim"); 187 } 188 let vc_claims: HashSet<String> = vc_claims_list.into_iter().collect(); 189 190 let vc = VcConfig { 191 vc_type, 192 vc_format, 193 vc_algorithms, 194 vc_claims, 195 }; 196 197 let mut clients = Vec::new(); 198 for (section_name, properties) in ini.iter() { 199 let section_name = match section_name { 200 Some(name) if name.starts_with("client_") => name, 201 _ => continue, 202 }; 203 204 let client_id = properties.get("CLIENT_ID") 205 .context(format!("Missing CLIENT_ID in section [{}]", section_name))? 206 .to_string(); 207 let client_secret = properties.get("CLIENT_SECRET") 208 .context(format!("Missing CLIENT_SECRET in section [{}]", section_name))? 
209 .to_string(); 210 let verifier_url = properties.get("VERIFIER_URL") 211 .context(format!("Missing VERIFIER_URL in section [{}]", section_name))? 212 .to_string(); 213 let verifier_management_api_path = properties.get("VERIFIER_MANAGEMENT_API_PATH") 214 .unwrap_or("/management/api/verifications") 215 .to_string(); 216 let redirect_uri = properties.get("REDIRECT_URI") 217 .filter(|s| !s.is_empty()) 218 .context(format!("Missing REDIRECT_URI in section [{}]", section_name))? 219 .to_string(); 220 let accepted_issuer_dids = properties.get("ACCEPTED_ISSUER_DIDS") 221 .filter(|s| !s.is_empty()) 222 .map(|s| s.to_string()); 223 224 clients.push(ClientConfig { 225 section_name: section_name.to_string(), 226 client_id, 227 client_secret, 228 verifier_url, 229 verifier_management_api_path, 230 redirect_uri, 231 accepted_issuer_dids, 232 }); 233 } 234 235 Ok(Config { 236 server, 237 database, 238 crypto, 239 vc, 240 allowed_scopes, 241 clients, 242 }) 243 } 244 } 245 246 fn parse_allowed_scopes(raw: &str) -> Result<Vec<String>> { 247 let trimmed = raw.trim(); 248 let trimmed = trimmed.strip_prefix('{').unwrap_or(trimmed); 249 let trimmed = trimmed.strip_suffix('}').unwrap_or(trimmed); 250 251 let scopes: Vec<String> = trimmed 252 .split(|c: char| c == ',' || c.is_whitespace()) 253 .map(|s| s.trim()) 254 .filter(|s| !s.is_empty()) 255 .map(|s| s.to_string()) 256 .collect(); 257 258 if scopes.is_empty() { 259 anyhow::bail!("ALLOWED_SCOPES must contain at least one scope"); 260 } 261 262 Ok(scopes) 263 } 264 265 fn parse_bracketed_list(value: &str, field_name: &str) -> Result<Vec<String>> { 266 let trimmed = value.trim(); 267 if !trimmed.starts_with('{') || !trimmed.ends_with('}') { 268 anyhow::bail!("invalid {} format: expected {{item1, item2, ...}}", field_name); 269 } 270 let inner = &trimmed[1..trimmed.len() - 1]; 271 let items: Vec<String> = inner 272 .split(',') 273 .map(|s| s.trim().to_string()) 274 .filter(|s| !s.is_empty()) 275 .collect(); 276 Ok(items) 277 } 278 
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ServerConfig` with the default `0o666` socket mode.
    fn make_server(
        host: Option<&str>,
        port: Option<u16>,
        socket_path: Option<&str>,
    ) -> ServerConfig {
        ServerConfig {
            host: host.map(String::from),
            port,
            socket_path: socket_path.map(String::from),
            socket_mode: 0o666,
        }
    }

    /// Asserts that `validate` fails with a message containing `needle`.
    fn assert_validate_err(server: &ServerConfig, needle: &str) {
        let err = server.validate().unwrap_err().to_string();
        assert!(err.contains(needle), "unexpected error: {}", err);
    }

    #[test]
    fn test_server_validate_tcp_ok() {
        let server = make_server(Some("127.0.0.1"), Some(8080), None);
        assert!(server.validate().is_ok());
        assert!(!server.is_unix_socket());
    }

    #[test]
    fn test_server_validate_unix_ok() {
        let server = make_server(None, None, Some("/tmp/kych.sock"));
        assert!(server.validate().is_ok());
        assert!(server.is_unix_socket());
    }

    #[test]
    fn test_server_validate_both_err() {
        let server = make_server(Some("127.0.0.1"), Some(8080), Some("/tmp/kych.sock"));
        assert_validate_err(&server, "Cannot specify both");
    }

    #[test]
    fn test_server_validate_partial_tcp_with_unix_err() {
        // Even a lone HOST conflicts with UNIXPATH.
        let server = make_server(Some("127.0.0.1"), None, Some("/tmp/kych.sock"));
        assert_validate_err(&server, "Cannot specify both");
    }

    #[test]
    fn test_server_validate_neither_err() {
        assert_validate_err(&make_server(None, None, None), "Must specify either");
    }

    #[test]
    fn test_server_validate_missing_port_err() {
        assert_validate_err(&make_server(Some("127.0.0.1"), None, None), "HOST and PORT");
    }

    #[test]
    fn test_server_validate_missing_host_err() {
        assert_validate_err(&make_server(None, Some(8080), None), "HOST and PORT");
    }

    #[test]
    fn test_parse_allowed_scopes_variants() {
        let scopes = parse_allowed_scopes("{a, b c}").unwrap();
        assert_eq!(scopes, vec!["a", "b", "c"]);

        let scopes = parse_allowed_scopes(" a b c ").unwrap();
        assert_eq!(scopes, vec!["a", "b", "c"]);

        let scopes = parse_allowed_scopes("a,b,c").unwrap();
        assert_eq!(scopes, vec!["a", "b", "c"]);

        let scopes = parse_allowed_scopes("{openid,\nprofile}").unwrap();
        assert_eq!(scopes, vec!["openid", "profile"]);
    }

    #[test]
    fn test_parse_allowed_scopes_empty_err() {
        assert!(parse_allowed_scopes("").is_err());
        assert!(parse_allowed_scopes("   ").is_err());
        assert!(parse_allowed_scopes("{}").is_err());
        assert!(parse_allowed_scopes("{ , }").is_err());
    }

    #[test]
    fn test_parse_bracketed_list_valid() {
        let items = parse_bracketed_list("{a, b, c}", "TEST").unwrap();
        assert_eq!(items, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_parse_bracketed_list_single_item() {
        let items = parse_bracketed_list("{ES256}", "TEST").unwrap();
        assert_eq!(items, vec!["ES256"]);
    }

    #[test]
    fn test_parse_bracketed_list_extra_whitespace() {
        let items = parse_bracketed_list("  { a , b }  ", "TEST").unwrap();
        assert_eq!(items, vec!["a", "b"]);
    }

    #[test]
    fn test_parse_bracketed_list_drops_empty_items() {
        let items = parse_bracketed_list("{a,,b,}", "TEST").unwrap();
        assert_eq!(items, vec!["a", "b"]);
    }

    #[test]
    fn test_parse_bracketed_list_missing_braces() {
        let result = parse_bracketed_list("a, b", "TEST");
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("TEST"));
    }

    #[test]
    fn test_parse_bracketed_list_half_braced_err() {
        let err = parse_bracketed_list("{a, b", "VC_CLAIMS").unwrap_err().to_string();
        assert!(err.contains("VC_CLAIMS"));
        assert!(parse_bracketed_list("a, b}", "VC_CLAIMS").is_err());
        assert!(parse_bracketed_list("{", "VC_CLAIMS").is_err());
    }

    #[test]
    fn test_parse_bracketed_list_empty() {
        let items = parse_bracketed_list("{}", "TEST").unwrap();
        assert!(items.is_empty());
    }
}