Files
openbrain-mcp/src/auth.rs

146 lines
3.9 KiB
Rust

//! Authentication module for OpenBrain MCP
//!
//! Provides API key-based authentication for securing the MCP endpoints.
use axum::{
extract::{Request, State},
http::{HeaderMap, StatusCode, header::AUTHORIZATION},
middleware::Next,
response::Response,
};
use sha2::{Digest, Sha256};
use std::sync::Arc;
use tracing::warn;
use crate::AppState;
/// Compute the SHA-256 digest of an API key and return it hex-encoded.
///
/// The middleware compares these digests rather than the raw key strings.
pub fn hash_api_key(key: &str) -> String {
    // One-shot digest: equivalent to creating a hasher, feeding it the key
    // bytes once, and finalizing.
    let digest = Sha256::digest(key.as_bytes());
    hex::encode(digest)
}
/// Axum middleware that gates a request behind API key authentication.
///
/// When `state.config.auth.enabled` is false the request is forwarded
/// untouched. Otherwise the credential is read with [`extract_api_key`]
/// (either the `X-API-Key` header or an `Authorization: Bearer …` token),
/// hashed, and compared against the hash of every configured key.
///
/// # Errors
///
/// Returns [`StatusCode::UNAUTHORIZED`] when no credential is supplied or
/// when the supplied credential matches none of the configured keys.
pub async fn auth_middleware(
    State(state): State<Arc<AppState>>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // Authentication switched off: let everything through.
    if !state.config.auth.enabled {
        return Ok(next.run(request).await);
    }

    // Guard: a credential must be present at all.
    let Some(presented_key) = extract_api_key(request.headers()) else {
        warn!("Missing API key or bearer token in request");
        return Err(StatusCode::UNAUTHORIZED);
    };

    // Compare digests, not the raw key strings.
    let presented_hash = hash_api_key(&presented_key);
    let is_known_key = state
        .config
        .auth
        .api_keys
        .iter()
        .any(|configured| hash_api_key(configured) == presented_hash);

    if !is_known_key {
        warn!("Invalid API key or bearer token attempted");
        return Err(StatusCode::UNAUTHORIZED);
    }

    Ok(next.run(request).await)
}
/// Pull the caller's credential out of the request headers.
///
/// The `X-API-Key` header is consulted first. If it is absent, not
/// representable as a string, or blank after trimming, the `Authorization`
/// header is tried instead and must have the form `Bearer <token>` (scheme
/// matched case-insensitively, split on the first space).
///
/// Returns `None` when neither header yields a non-empty value.
fn extract_api_key(headers: &HeaderMap) -> Option<String> {
    // Preferred source: the dedicated API key header.
    let direct_key = headers
        .get("X-API-Key")
        .and_then(|raw| raw.to_str().ok())
        .map(str::trim)
        .filter(|key| !key.is_empty());
    if let Some(key) = direct_key {
        return Some(key.to_owned());
    }

    // Fallback: a bearer token in the Authorization header.
    let authorization = headers.get(AUTHORIZATION)?.to_str().ok()?;
    let (scheme, remainder) = authorization.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("bearer") {
        return None;
    }

    let token = remainder.trim();
    if token.is_empty() {
        None
    } else {
        Some(token.to_owned())
    }
}
/// Read the `X-Agent-ID` header, if the request carries a usable one.
///
/// Returns `None` when the header is missing, cannot be read as a string,
/// or is empty once surrounding whitespace is trimmed.
pub fn get_optional_agent_id(headers: &HeaderMap) -> Option<String> {
    let agent_id = headers.get("X-Agent-ID")?.to_str().ok()?.trim();
    // A blank header is treated the same as an absent one.
    (!agent_id.is_empty()).then(|| agent_id.to_owned())
}
/// Read the `X-Agent-Type` header, if the request carries a usable one.
///
/// Returns `None` when the header is missing, cannot be read as a string,
/// or is empty once surrounding whitespace is trimmed.
pub fn get_optional_agent_type(headers: &HeaderMap) -> Option<String> {
    let raw = headers.get("X-Agent-Type")?;
    match raw.to_str().map(str::trim) {
        Ok(agent_type) if !agent_type.is_empty() => Some(agent_type.to_owned()),
        // Unreadable or blank values count as "not provided".
        _ => None,
    }
}
/// Resolve the agent ID for a request.
///
/// Uses the value found by [`get_optional_agent_id`] when there is one and
/// falls back to the literal `"default"` otherwise.
pub fn get_agent_id(request: &Request) -> String {
    match get_optional_agent_id(request.headers()) {
        Some(agent_id) => agent_id,
        None => "default".to_string(),
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the header-parsing helpers in this module.
    use super::*;
    use axum::http::{HeaderValue, header::AUTHORIZATION};

    /// A well-formed bearer token is returned without the scheme prefix.
    #[test]
    fn extracts_api_key_from_bearer_header() {
        let mut headers = HeaderMap::new();
        headers.insert(
            AUTHORIZATION,
            HeaderValue::from_static("Bearer test-token"),
        );
        assert_eq!(extract_api_key(&headers).as_deref(), Some("test-token"));
    }

    /// The bearer scheme is matched case-insensitively.
    #[test]
    fn bearer_scheme_is_case_insensitive() {
        let mut headers = HeaderMap::new();
        headers.insert(
            AUTHORIZATION,
            HeaderValue::from_static("bearer test-token"),
        );
        assert_eq!(extract_api_key(&headers).as_deref(), Some("test-token"));
    }

    /// The dedicated header is accepted and surrounding whitespace removed.
    #[test]
    fn extracts_api_key_from_x_api_key_header() {
        let mut headers = HeaderMap::new();
        headers.insert("X-API-Key", HeaderValue::from_static("  secret-key  "));
        assert_eq!(extract_api_key(&headers).as_deref(), Some("secret-key"));
    }

    /// When both credentials are present, `X-API-Key` wins.
    #[test]
    fn x_api_key_takes_precedence_over_bearer() {
        let mut headers = HeaderMap::new();
        headers.insert("X-API-Key", HeaderValue::from_static("from-key"));
        headers.insert(
            AUTHORIZATION,
            HeaderValue::from_static("Bearer from-bearer"),
        );
        assert_eq!(extract_api_key(&headers).as_deref(), Some("from-key"));
    }

    /// A blank `X-API-Key` does not block the bearer fallback.
    #[test]
    fn blank_x_api_key_falls_back_to_bearer() {
        let mut headers = HeaderMap::new();
        headers.insert("X-API-Key", HeaderValue::from_static("   "));
        headers.insert(
            AUTHORIZATION,
            HeaderValue::from_static("Bearer fallback"),
        );
        assert_eq!(extract_api_key(&headers).as_deref(), Some("fallback"));
    }

    /// Authorization schemes other than bearer yield no credential.
    #[test]
    fn rejects_non_bearer_authorization_scheme() {
        let mut headers = HeaderMap::new();
        headers.insert(
            AUTHORIZATION,
            HeaderValue::from_static("Basic dXNlcjpwYXNz"),
        );
        assert_eq!(extract_api_key(&headers), None);
    }

    /// No relevant headers at all means no credential.
    #[test]
    fn returns_none_when_no_credentials_present() {
        let headers = HeaderMap::new();
        assert_eq!(extract_api_key(&headers), None);
    }

    #[test]
    fn extracts_agent_id_from_header() {
        let mut headers = HeaderMap::new();
        headers.insert("X-Agent-ID", HeaderValue::from_static("agent-zero"));
        assert_eq!(
            get_optional_agent_id(&headers).as_deref(),
            Some("agent-zero")
        );
    }

    #[test]
    fn extracts_agent_type_from_header() {
        let mut headers = HeaderMap::new();
        headers.insert("X-Agent-Type", HeaderValue::from_static("codex"));
        assert_eq!(
            get_optional_agent_type(&headers).as_deref(),
            Some("codex")
        );
    }

    /// Missing or blank agent headers are reported as absent.
    #[test]
    fn agent_headers_absent_or_blank_yield_none() {
        let mut headers = HeaderMap::new();
        assert_eq!(get_optional_agent_id(&headers), None);
        assert_eq!(get_optional_agent_type(&headers), None);

        headers.insert("X-Agent-ID", HeaderValue::from_static("  "));
        headers.insert("X-Agent-Type", HeaderValue::from_static("  "));
        assert_eq!(get_optional_agent_id(&headers), None);
        assert_eq!(get_optional_agent_type(&headers), None);
    }
}