//! HTTP transport for MCP Protocol. //! //! Supports both the legacy HTTP+SSE transport and the newer streamable HTTP //! transport on the same server. use axum::{ extract::{Query, State}, http::{HeaderMap, StatusCode, Uri, header::{HOST, ORIGIN}}, response::{ IntoResponse, Response, sse::{Event, KeepAlive, Sse}, }, routing::{get, post}, Json, Router, }; use futures::stream::Stream; use serde::{Deserialize, Serialize}; use std::{ collections::HashMap, convert::Infallible, sync::Arc, time::Duration, }; use tokio::sync::{RwLock, broadcast, mpsc}; use tracing::{error, info, warn}; use uuid::Uuid; use crate::{AppState, auth, tools}; type SessionStore = RwLock>>; /// MCP Server State pub struct McpState { pub app: Arc, pub event_tx: broadcast::Sender, sessions: SessionStore, } impl McpState { pub fn new(app: Arc) -> Arc { let (event_tx, _) = broadcast::channel(100); Arc::new(Self { app, event_tx, sessions: RwLock::new(HashMap::new()), }) } async fn insert_session( &self, session_id: String, tx: mpsc::Sender, ) { self.sessions.write().await.insert(session_id, tx); } async fn remove_session(&self, session_id: &str) { self.sessions.write().await.remove(session_id); } async fn has_session(&self, session_id: &str) -> bool { self.sessions.read().await.contains_key(session_id) } async fn send_to_session( &self, session_id: &str, response: &JsonRpcResponse, ) -> Result<(), SessionSendError> { let tx = { let sessions = self.sessions.read().await; sessions .get(session_id) .cloned() .ok_or(SessionSendError::NotFound)? }; let payload = serde_json::to_value(response).expect("serializing JSON-RPC response should succeed"); if tx.send(payload).await.is_err() { self.remove_session(session_id).await; return Err(SessionSendError::Closed); } Ok(()) } } enum SessionSendError { NotFound, Closed, } struct SessionGuard { state: Arc, session_id: String, } impl SessionGuard { fn new(state: Arc, session_id: String) -> Self { Self { state, session_id } } } impl Drop for SessionGuard { fn drop(&mut self) { let state = self.state.clone(); let session_id = self.session_id.clone(); if let Ok(handle) = tokio::runtime::Handle::try_current() { handle.spawn(async move { state.remove_session(&session_id).await; }); } } } /// MCP Event for SSE streaming #[derive(Clone, Debug, Serialize)] pub struct McpEvent { pub id: String, pub event_type: String, pub data: serde_json::Value, } /// MCP JSON-RPC Request #[derive(Debug, Deserialize)] pub struct JsonRpcRequest { pub jsonrpc: String, #[serde(default)] pub id: Option, pub method: String, #[serde(default)] pub params: serde_json::Value, } /// MCP JSON-RPC Response #[derive(Debug, Serialize)] pub struct JsonRpcResponse { pub jsonrpc: String, pub id: serde_json::Value, #[serde(skip_serializing_if = "Option::is_none")] pub result: Option, #[serde(skip_serializing_if = "Option::is_none")] pub error: Option, } #[derive(Debug, Serialize)] pub struct JsonRpcError { pub code: i32, pub message: String, #[serde(skip_serializing_if = "Option::is_none")] pub data: Option, } #[derive(Debug, Default, Deserialize)] #[serde(rename_all = "camelCase")] struct PostMessageQuery { #[serde(default)] session_id: Option, } /// Create the MCP router pub fn mcp_router(state: Arc) -> Router { Router::new() .route("/mcp", get(streamable_get_handler).post(streamable_post_handler).delete(streamable_delete_handler)) .route("/mcp/sse", get(sse_handler)) .route("/mcp/message", post(message_handler)) .route("/mcp/health", get(health_handler)) .with_state(state) } fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> { let Some(origin) = headers .get(ORIGIN) .and_then(|value| value.to_str().ok()) .map(str::trim) .filter(|value| !value.is_empty()) else { return Ok(()); }; if origin.eq_ignore_ascii_case("null") { warn!("Rejected MCP request with null origin"); return Err(StatusCode::FORBIDDEN); } let origin_uri = origin.parse::().map_err(|_| { warn!("Rejected MCP request with invalid origin header: {}", origin); StatusCode::FORBIDDEN })?; let origin_host = origin_uri.host().ok_or_else(|| { warn!("Rejected MCP request without origin host: {}", origin); StatusCode::FORBIDDEN })?; let request_host = headers .get("X-Forwarded-Host") .or_else(|| headers.get(HOST)) .and_then(|value| value.to_str().ok()) .and_then(|value| value.split(',').next()) .map(str::trim) .filter(|value| !value.is_empty()) .ok_or_else(|| { warn!("Rejected MCP request without host header for origin {}", origin); StatusCode::FORBIDDEN })?; let origin_authority = origin_uri .authority() .map(|authority| authority.as_str()) .unwrap_or(origin_host); let origin_with_default_port = origin_uri .port_u16() .or_else(|| match origin_uri.scheme_str() { Some("https") => Some(443), Some("http") => Some(80), _ => None, }) .map(|port| format!("{origin_host}:{port}")); if request_host.eq_ignore_ascii_case(origin_host) || request_host.eq_ignore_ascii_case(origin_authority) || origin_with_default_port .as_deref() .is_some_and(|value| request_host.eq_ignore_ascii_case(value)) { Ok(()) } else { warn!( "Rejected MCP request due to origin/host mismatch: origin={}, host={}", origin, request_host ); Err(StatusCode::FORBIDDEN) } } async fn streamable_post_handler( State(state): State>, headers: HeaderMap, Json(request): Json, ) -> Response { if let Err(status) = validate_origin(&headers) { return status.into_response(); } info!( method = %request.method, agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"), agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"), "Received streamable MCP request" ); let request = apply_request_context(request, &headers); let response = dispatch_request(&state, &request).await; match response { Some(response) => Json(response).into_response(), None => StatusCode::ACCEPTED.into_response(), } } async fn streamable_get_handler(headers: HeaderMap) -> Response { if let Err(status) = validate_origin(&headers) { return status.into_response(); } StatusCode::METHOD_NOT_ALLOWED.into_response() } async fn streamable_delete_handler(headers: HeaderMap) -> Response { if let Err(status) = validate_origin(&headers) { return status.into_response(); } StatusCode::METHOD_NOT_ALLOWED.into_response() } /// SSE endpoint for streaming events async fn sse_handler( State(state): State>, headers: HeaderMap, ) -> Result>>, StatusCode> { validate_origin(&headers)?; info!( agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"), agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"), "Opening legacy SSE MCP stream" ); let mut broadcast_rx = state.event_tx.subscribe(); let (session_tx, mut session_rx) = mpsc::channel(32); let session_id = Uuid::new_v4().to_string(); let endpoint = session_message_endpoint(&session_id); state.insert_session(session_id.clone(), session_tx).await; let stream = async_stream::stream! { let _session_guard = SessionGuard::new(state.clone(), session_id.clone()); // Send endpoint event (required by MCP SSE protocol) // This tells the client where to POST JSON-RPC messages for this session. yield Ok(Event::default() .event("endpoint") .data(endpoint)); loop { tokio::select! { maybe_message = session_rx.recv() => { match maybe_message { Some(message) => { yield Ok(Event::default() .event("message") .json_data(&message) .unwrap()); } None => break, } } event = broadcast_rx.recv() => { match event { Ok(event) => { yield Ok(Event::default() .event(&event.event_type) .id(&event.id) .json_data(&event.data) .unwrap()); } Err(broadcast::error::RecvError::Lagged(n)) => { warn!("SSE client lagged, missed {} events", n); } Err(broadcast::error::RecvError::Closed) => { break; } } } } } }; Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))) } /// Message endpoint for JSON-RPC requests async fn message_handler( State(state): State>, Query(query): Query, headers: HeaderMap, Json(request): Json, ) -> Response { if let Err(status) = validate_origin(&headers) { return status.into_response(); } info!( method = %request.method, agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"), agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"), "Received legacy SSE MCP request" ); if let Some(session_id) = query.session_id.as_deref() { if !state.has_session(session_id).await { return StatusCode::NOT_FOUND.into_response(); } } let request = apply_request_context(request, &headers); let response = dispatch_request(&state, &request).await; match query.session_id.as_deref() { Some(session_id) => route_session_response(&state, session_id, response).await, None => match response { Some(response) => Json(response).into_response(), None => StatusCode::ACCEPTED.into_response(), }, } } /// Health check endpoint async fn health_handler() -> Json { Json(serde_json::json!({ "status": "healthy", "server": "openbrain-mcp", "version": env!("CARGO_PKG_VERSION") })) } fn session_message_endpoint(session_id: &str) -> String { format!("/mcp/message?sessionId={session_id}") } async fn route_session_response( state: &Arc, session_id: &str, response: Option, ) -> Response { match response { Some(response) => match state.send_to_session(session_id, &response).await { Ok(()) => StatusCode::ACCEPTED.into_response(), Err(SessionSendError::NotFound) => StatusCode::NOT_FOUND.into_response(), Err(SessionSendError::Closed) => StatusCode::GONE.into_response(), }, None => StatusCode::ACCEPTED.into_response(), } } fn apply_request_context( mut request: JsonRpcRequest, headers: &HeaderMap, ) -> JsonRpcRequest { if let Some(agent_id) = auth::get_optional_agent_id(headers) { inject_agent_id(&mut request, &agent_id); } request } fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) { if request.method != "tools/call" { return; } if !request.params.is_object() { request.params = serde_json::json!({}); } let params = request .params .as_object_mut() .expect("params should be an object"); let arguments = params .entry("arguments") .or_insert_with(|| serde_json::json!({})); if !arguments.is_object() { *arguments = serde_json::json!({}); } arguments .as_object_mut() .expect("arguments should be an object") .entry("agent_id".to_string()) .or_insert_with(|| serde_json::json!(agent_id)); } async fn dispatch_request( state: &Arc, request: &JsonRpcRequest, ) -> Option { match request.method.as_str() { "initialize" => request .id .clone() .map(|id| success_response(id, initialize_result())), "ping" => request .id .clone() .map(|id| success_response(id, serde_json::json!({}))), "tools/list" => request.id.clone().map(|id| { success_response( id, serde_json::json!({ "tools": tools::get_tool_definitions() }), ) }), "tools/call" => handle_tools_call(state, request).await, "notifications/initialized" => { info!("Received MCP initialized notification"); None } _ => request.id.clone().map(|id| { error_response( id, -32601, format!("Method not found: {}", request.method), None, ) }), } } fn initialize_result() -> serde_json::Value { serde_json::json!({ "protocolVersion": "2024-11-05", "serverInfo": { "name": "openbrain-mcp", "version": env!("CARGO_PKG_VERSION") }, "capabilities": { "tools": { "listChanged": false } } }) } fn success_response( id: serde_json::Value, result: serde_json::Value, ) -> JsonRpcResponse { JsonRpcResponse { jsonrpc: "2.0".to_string(), id, result: Some(result), error: None, } } fn error_response( id: serde_json::Value, code: i32, message: String, data: Option, ) -> JsonRpcResponse { JsonRpcResponse { jsonrpc: "2.0".to_string(), id, result: None, error: Some(JsonRpcError { code, message, data, }), } } /// Handle tools/call request async fn handle_tools_call( state: &Arc, request: &JsonRpcRequest, ) -> Option { let params = &request.params; let tool_name = params.get("name").and_then(|v| v.as_str()).unwrap_or(""); let arguments = params .get("arguments") .cloned() .unwrap_or(serde_json::json!({})); let result = match tools::execute_tool(&state.app, tool_name, arguments).await { Ok(result) => Ok(serde_json::json!({ "content": [{ "type": "text", "text": result }] })), Err(e) => { let full_error = format!("{:#}", e); error!("Tool execution error: {}", full_error); Err(JsonRpcError { code: -32000, message: full_error, data: None, }) } }; match (request.id.clone(), result) { (Some(id), Ok(result)) => Some(success_response(id, result)), (Some(id), Err(error)) => Some(JsonRpcResponse { jsonrpc: "2.0".to_string(), id, result: None, error: Some(error), }), (None, Ok(_)) => None, (None, Err(error)) => { error!( "Tool execution failed for notification '{}': {}", request.method, error.message ); None } } } #[cfg(test)] mod tests { use super::*; #[test] fn injects_agent_id_when_missing_from_tool_arguments() { let mut request = JsonRpcRequest { jsonrpc: "2.0".to_string(), id: Some(serde_json::json!("1")), method: "tools/call".to_string(), params: serde_json::json!({ "name": "query", "arguments": { "query": "editor preferences" } }), }; inject_agent_id(&mut request, "agent-from-header"); assert_eq!( request .params .get("arguments") .and_then(|value| value.get("agent_id")) .and_then(|value| value.as_str()), Some("agent-from-header") ); } #[test] fn preserves_explicit_agent_id() { let mut request = JsonRpcRequest { jsonrpc: "2.0".to_string(), id: Some(serde_json::json!("1")), method: "tools/call".to_string(), params: serde_json::json!({ "name": "query", "arguments": { "agent_id": "explicit-agent" } }), }; inject_agent_id(&mut request, "agent-from-header"); assert_eq!( request .params .get("arguments") .and_then(|value| value.get("agent_id")) .and_then(|value| value.as_str()), Some("explicit-agent") ); } #[test] fn session_endpoint_uses_camel_case_query_param() { assert_eq!( session_message_endpoint("abc123"), "/mcp/message?sessionId=abc123" ); } #[test] fn allows_matching_origin_and_host() { let mut headers = HeaderMap::new(); headers.insert(ORIGIN, "https://ob.ingwaz.work".parse().unwrap()); headers.insert(HOST, "ob.ingwaz.work".parse().unwrap()); assert!(validate_origin(&headers).is_ok()); } #[test] fn rejects_mismatched_origin_and_host() { let mut headers = HeaderMap::new(); headers.insert(ORIGIN, "https://evil.example".parse().unwrap()); headers.insert(HOST, "ob.ingwaz.work".parse().unwrap()); assert_eq!(validate_origin(&headers), Err(StatusCode::FORBIDDEN)); } }