//! Authentication module for OpenBrain MCP //! //! Provides API key-based authentication for securing the MCP endpoints. use axum::{ extract::{Request, State}, http::{HeaderMap, StatusCode, header::AUTHORIZATION}, middleware::Next, response::Response, }; use sha2::{Digest, Sha256}; use std::sync::Arc; use tracing::warn; use crate::AppState; /// Hash an API key for secure comparison pub fn hash_api_key(key: &str) -> String { let mut hasher = Sha256::new(); hasher.update(key.as_bytes()); hex::encode(hasher.finalize()) } /// Middleware for API key authentication pub async fn auth_middleware( State(state): State>, request: Request, next: Next, ) -> Result { // Skip auth if disabled if !state.config.auth.enabled { return Ok(next.run(request).await); } let api_key = extract_api_key(request.headers()); match api_key { Some(key) => { // Check if key is valid let key_hash = hash_api_key(&key); let valid = state .config .auth .api_keys .iter() .any(|k| hash_api_key(k) == key_hash); if valid { Ok(next.run(request).await) } else { warn!("Invalid API key or bearer token attempted"); Err(StatusCode::UNAUTHORIZED) } } None => { warn!("Missing API key or bearer token in request"); Err(StatusCode::UNAUTHORIZED) } } } fn extract_api_key(headers: &HeaderMap) -> Option { headers .get("X-API-Key") .and_then(|v| v.to_str().ok()) .map(str::trim) .filter(|value| !value.is_empty()) .map(ToOwned::to_owned) .or_else(|| { headers .get(AUTHORIZATION) .and_then(|v| v.to_str().ok()) .and_then(|value| { let (scheme, token) = value.split_once(' ')?; scheme .eq_ignore_ascii_case("bearer") .then_some(token.trim()) }) .filter(|value| !value.is_empty()) .map(ToOwned::to_owned) }) } pub fn get_optional_agent_id(headers: &HeaderMap) -> Option { headers .get("X-Agent-ID") .and_then(|v| v.to_str().ok()) .map(str::trim) .filter(|value| !value.is_empty()) .map(ToOwned::to_owned) } pub fn get_optional_agent_type(headers: &HeaderMap) -> Option { headers .get("X-Agent-Type") .and_then(|v| v.to_str().ok()) .map(str::trim) .filter(|value| !value.is_empty()) .map(ToOwned::to_owned) } /// Extract agent ID from request headers or default pub fn get_agent_id(request: &Request) -> String { get_optional_agent_id(request.headers()) .unwrap_or_else(|| "default".to_string()) } #[cfg(test)] mod tests { use super::*; use axum::http::{HeaderValue, header::AUTHORIZATION}; #[test] fn extracts_api_key_from_bearer_header() { let mut headers = HeaderMap::new(); headers.insert( AUTHORIZATION, HeaderValue::from_static("Bearer test-token"), ); assert_eq!(extract_api_key(&headers).as_deref(), Some("test-token")); } #[test] fn extracts_agent_id_from_header() { let mut headers = HeaderMap::new(); headers.insert("X-Agent-ID", HeaderValue::from_static("agent-zero")); assert_eq!( get_optional_agent_id(&headers).as_deref(), Some("agent-zero") ); } #[test] fn extracts_agent_type_from_header() { let mut headers = HeaderMap::new(); headers.insert("X-Agent-Type", HeaderValue::from_static("codex")); assert_eq!( get_optional_agent_type(&headers).as_deref(), Some("codex") ); } }