main.rs (3349B)
1 use anyhow::Result; 2 use axum::{ 3 Router, 4 routing::{get, post}, 5 }; 6 use clap::Parser; 7 use kych_oauth2_gateway_lib::{config::Config, db, handlers, state::AppState}; 8 use std::{fs, os::unix::fs::PermissionsExt}; 9 use tower_http::{services::ServeDir, trace::TraceLayer}; 10 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; 11 12 #[derive(Parser, Debug)] 13 #[command(version)] 14 struct Args { 15 #[arg(short = 'c', long = "config", value_name = "FILE")] 16 config: String, 17 18 #[arg(short = 'L', long = "log-level", value_name = "LEVEL", default_value = "INFO")] 19 log_level: String, 20 } 21 22 #[tokio::main] 23 async fn main() -> Result<()> { 24 let args = Args::parse(); 25 26 let level = args.log_level.to_lowercase(); 27 let filter = format!( 28 "kych_oauth2_gateway={},kych_oauth2_gateway_lib={},tower_http={},sqlx=warn", 29 level, level, level 30 ); 31 32 tracing_subscriber::registry() 33 .with( 34 tracing_subscriber::EnvFilter::try_from_default_env() 35 .unwrap_or_else(|_| filter.into()), 36 ) 37 .with( 38 tracing_subscriber::fmt::layer() 39 .compact() 40 .with_ansi(false) 41 .with_timer(tracing_subscriber::fmt::time::LocalTime::rfc_3339()), 42 ) 43 .init(); 44 45 tracing::info!("Starting Kych OAuth2 Gateway v{}", env!("CARGO_PKG_VERSION")); 46 tracing::info!("Loading configuration from: {}", args.config); 47 48 let config = Config::from_file(&args.config)?; 49 50 tracing::info!("Connecting to database: {}", config.database.url); 51 let pool = db::create_pool(&config.database.url).await?; 52 53 let state = AppState::new(config.clone(), pool); 54 55 let app = Router::new() 56 .route("/config", get(handlers::config)) 57 .route("/setup/{client_id}", post(handlers::setup)) 58 .route("/authorize/{nonce}", get(handlers::authorize)) 59 .route("/token", post(handlers::token)) 60 .route("/info", get(handlers::info)) 61 .route("/notification", post(handlers::notification_webhook)) 62 .route("/status/{verification_id}", get(handlers::status)) 63 .route("/finalize/{verification_id}", get(handlers::finalize)) 64 .nest_service("/js", ServeDir::new("js")) 65 .layer(TraceLayer::new_for_http()) 66 .with_state(state); 67 68 if config.server.is_unix_socket() { 69 let socket_path = config.server.socket_path.as_ref().unwrap(); 70 let socket_mode = config.server.socket_mode; 71 72 if std::path::Path::new(socket_path).exists() { 73 tracing::warn!("Removing left-over `{}' from previous execution", socket_path); 74 std::fs::remove_file(socket_path)?; 75 } 76 77 let listener = tokio::net::UnixListener::bind(socket_path)?; 78 let permissions = std::fs::Permissions::from_mode(socket_mode); 79 fs::set_permissions(socket_path, permissions)?; 80 tracing::info!("set socket '{}' to mode {:o}", socket_path, socket_mode); 81 82 axum::serve(listener, app).await?; 83 } else { 84 let host = config.server.host.as_ref().unwrap(); 85 let port = config.server.port.unwrap(); 86 let addr = format!("{}:{}", host, port); 87 88 let listener = tokio::net::TcpListener::bind(&addr).await?; 89 tracing::info!("Server listening on {}", addr); 90 91 axum::serve(listener, app).await?; 92 } 93 94 Ok(()) 95 }