main.rs (2844B)
1 use anyhow::Result; 2 use axum::{ 3 Router, 4 routing::{get, post}, 5 }; 6 use clap::Parser; 7 use oauth2_gateway::{config::Config, db, handlers, state::AppState}; 8 use std::{fs, os::unix::fs::PermissionsExt}; 9 use tower_http::trace::TraceLayer; 10 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; 11 12 #[derive(Parser, Debug)] 13 #[command(version)] 14 struct Args { 15 /// Configuration 16 #[arg(short = 'c', long = "config", value_name = "FILE")] 17 config: String, 18 } 19 20 #[tokio::main] 21 async fn main() -> Result<()> { 22 // Init logging, tracing 23 tracing_subscriber::registry() 24 .with( 25 tracing_subscriber::EnvFilter::try_from_default_env() 26 .unwrap_or_else(|_| "oauth2_gateway=info,tower_http=info,sqlx=warn".into()), 27 ) 28 .with( 29 tracing_subscriber::fmt::layer() 30 .compact() 31 .with_ansi(false) 32 .with_timer(tracing_subscriber::fmt::time::LocalTime::rfc_3339()), 33 ) 34 .init(); 35 36 let args = Args::parse(); 37 38 tracing::info!("Starting OAuth2 Gateway v{}", env!("CARGO_PKG_VERSION")); 39 tracing::info!("Loading configuration from: {}", args.config); 40 41 let config = Config::from_file(&args.config)?; 42 43 tracing::info!("Connecting to database: {}", config.database.url); 44 let pool = db::create_pool(&config.database.url).await?; 45 46 let state = AppState::new(config.clone(), pool); 47 48 let app = Router::new() 49 .route("/health", get(handlers::health_check)) 50 .route("/setup/{client_id}", post(handlers::setup)) 51 .route("/authorize/{nonce}", get(handlers::authorize)) 52 .route("/token", post(handlers::token)) 53 .route("/info", get(handlers::info)) 54 .route("/notification", post(handlers::notification_webhook)) 55 .layer(TraceLayer::new_for_http()) 56 .with_state(state); 57 58 if config.server.is_unix_socket() { 59 let socket_path = config.server.socket_path.as_ref().unwrap(); 60 61 if std::path::Path::new(socket_path).exists() { 62 tracing::warn!("Removing existing socket file: {}", socket_path); 63 std::fs::remove_file(socket_path)?; 64 } 65 66 let listener = tokio::net::UnixListener::bind(socket_path)?; 67 let permissions = std::fs::Permissions::from_mode(0o766); 68 let _ = fs::set_permissions(socket_path, permissions); 69 70 tracing::info!("Server listening on Unix socket: {}", socket_path); 71 72 axum::serve(listener, app).await?; 73 } else { 74 let host = config.server.host.as_ref().unwrap(); 75 let port = config.server.port.unwrap(); 76 let addr = format!("{}:{}", host, port); 77 78 let listener = tokio::net::TcpListener::bind(&addr).await?; 79 tracing::info!("Server listening on {}", addr); 80 81 axum::serve(listener, app).await?; 82 } 83 84 Ok(()) 85 }