From 26c96b41dd2a0f1b2fd050d17c96462f051af390 Mon Sep 17 00:00:00 2001 From: Agent Zero Date: Sun, 22 Mar 2026 03:18:08 +0000 Subject: [PATCH] Fix MCP transport compatibility and batch_store e2e coverage --- README.md | 49 +++++++++++++-- src/auth.rs | 20 ++++++ src/lib.rs | 2 +- src/transport.rs | 161 ++++++++++++++++++++++++++++++++++++++++++++--- tests/e2e_mcp.rs | 112 +++++++++++++++++++++++++++++++++ 5 files changed, 329 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 9a039d1..c254d55 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ OpenBrain is a Model Context Protocol (MCP) server that provides AI agents with - 🧠 **Semantic Memory**: Store and retrieve memories using vector similarity search - 🏠 **Local Embeddings**: No external API calls - uses ONNX runtime with all-MiniLM-L6-v2 - 🐘 **PostgreSQL + pgvector**: Production-grade vector storage with HNSW indexing -- 🔌 **MCP Protocol**: Standard Model Context Protocol over SSE transport +- 🔌 **MCP Protocol**: Streamable HTTP plus legacy HTTP+SSE compatibility - 🔐 **Multi-Agent Support**: Isolated memory namespaces per agent - ⚡ **High Performance**: Rust implementation with async I/O @@ -103,14 +103,53 @@ Recommended target file in A0: ## MCP Integration -Connect to the server using SSE transport: +OpenBrain exposes both MCP HTTP transports: ``` -SSE Endpoint: http://localhost:3100/mcp/sse -Message Endpoint: http://localhost:3100/mcp/message +Streamable HTTP Endpoint: http://localhost:3100/mcp +Legacy SSE Endpoint: http://localhost:3100/mcp/sse +Legacy Message Endpoint: http://localhost:3100/mcp/message Health Check: http://localhost:3100/mcp/health ``` +Use the streamable HTTP endpoint for modern clients such as Codex. Keep the +legacy SSE endpoints for older MCP clients that still use the deprecated +2024-11-05 HTTP+SSE transport. + +Header roles: +- `X-Agent-ID` is the memory namespace. Keep this stable if multiple clients + should share the same OpenBrain memories. +- `X-Agent-Type` is an optional client profile label for logging and config + clarity, such as `agent-zero` or `codex`. + +### Example: Codex Configuration + +```toml +[mcp_servers.openbrain] +url = "https://ob.ingwaz.work/mcp" +http_headers = { "X-API-Key" = "YOUR_OPENBRAIN_API_KEY", "X-Agent-ID" = "openbrain", "X-Agent-Type" = "codex" } +``` + +### Example: Agent Zero Configuration + +```json +{ + "mcpServers": { + "openbrain": { + "url": "https://ob.ingwaz.work/mcp/sse", + "headers": { + "X-API-Key": "YOUR_OPENBRAIN_API_KEY", + "X-Agent-ID": "openbrain", + "X-Agent-Type": "agent-zero" + } + } + } +} +``` + +Agent Zero should keep using the legacy HTTP+SSE transport unless and until its +client runtime supports streamable HTTP. Codex should use `/mcp`. + ### Example: Store a Memory ```json @@ -180,7 +219,7 @@ Health Check: http://localhost:3100/mcp/health ┌─────────────────────────────────────────────────────────┐ │ AI Agent │ └─────────────────────┬───────────────────────────────────┘ - │ MCP Protocol (SSE) + │ MCP Protocol (Streamable HTTP / Legacy SSE) ┌─────────────────────▼───────────────────────────────────┐ │ OpenBrain MCP Server │ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ diff --git a/src/auth.rs b/src/auth.rs index ffba822..8238c01 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -90,6 +90,15 @@ pub fn get_optional_agent_id(headers: &HeaderMap) -> Option { .map(ToOwned::to_owned) } +pub fn get_optional_agent_type(headers: &HeaderMap) -> Option { + headers + .get("X-Agent-Type") + .and_then(|v| v.to_str().ok()) + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(ToOwned::to_owned) +} + /// Extract agent ID from request headers or default pub fn get_agent_id(request: &Request) -> String { get_optional_agent_id(request.headers()) @@ -122,4 +131,15 @@ mod tests { Some("agent-zero") ); } + + #[test] + fn extracts_agent_type_from_header() { + let mut headers = HeaderMap::new(); + headers.insert("X-Agent-Type", HeaderValue::from_static("codex")); + + assert_eq!( + get_optional_agent_type(&headers).as_deref(), + Some("codex") + ); + } } diff --git a/src/lib.rs b/src/lib.rs index 26e8e90..808f9c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -130,7 +130,7 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> { let app = Router::new() .merge(health_router) - .nest("/mcp", mcp_router) + .merge(mcp_router) .layer(TraceLayer::new_for_http()) .layer( CorsLayer::new() diff --git a/src/transport.rs b/src/transport.rs index 12b2414..72ffec7 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -1,10 +1,11 @@ -//! SSE Transport for MCP Protocol +//! HTTP transport for MCP Protocol. //! -//! Implements Server-Sent Events transport for the Model Context Protocol. +//! Supports both the legacy HTTP+SSE transport and the newer streamable HTTP +//! transport on the same server. use axum::{ extract::{Query, State}, - http::{HeaderMap, StatusCode}, + http::{HeaderMap, StatusCode, Uri, header::{HOST, ORIGIN}}, response::{ IntoResponse, Response, sse::{Event, KeepAlive, Sse}, @@ -162,16 +163,131 @@ struct PostMessageQuery { /// Create the MCP router pub fn mcp_router(state: Arc) -> Router { Router::new() - .route("/sse", get(sse_handler)) - .route("/message", post(message_handler)) - .route("/health", get(health_handler)) + .route("/mcp", get(streamable_get_handler).post(streamable_post_handler).delete(streamable_delete_handler)) + .route("/mcp/sse", get(sse_handler)) + .route("/mcp/message", post(message_handler)) + .route("/mcp/health", get(health_handler)) .with_state(state) } +fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> { + let Some(origin) = headers + .get(ORIGIN) + .and_then(|value| value.to_str().ok()) + .map(str::trim) + .filter(|value| !value.is_empty()) + else { + return Ok(()); + }; + + if origin.eq_ignore_ascii_case("null") { + warn!("Rejected MCP request with null origin"); + return Err(StatusCode::FORBIDDEN); + } + + let origin_uri = origin.parse::().map_err(|_| { + warn!("Rejected MCP request with invalid origin header: {}", origin); + StatusCode::FORBIDDEN + })?; + + let origin_host = origin_uri.host().ok_or_else(|| { + warn!("Rejected MCP request without origin host: {}", origin); + StatusCode::FORBIDDEN + })?; + + let request_host = headers + .get("X-Forwarded-Host") + .or_else(|| headers.get(HOST)) + .and_then(|value| value.to_str().ok()) + .and_then(|value| value.split(',').next()) + .map(str::trim) + .filter(|value| !value.is_empty()) + .ok_or_else(|| { + warn!("Rejected MCP request without host header for origin {}", origin); + StatusCode::FORBIDDEN + })?; + + let origin_authority = origin_uri + .authority() + .map(|authority| authority.as_str()) + .unwrap_or(origin_host); + let origin_with_default_port = origin_uri + .port_u16() + .or_else(|| match origin_uri.scheme_str() { + Some("https") => Some(443), + Some("http") => Some(80), + _ => None, + }) + .map(|port| format!("{origin_host}:{port}")); + + if request_host.eq_ignore_ascii_case(origin_host) + || request_host.eq_ignore_ascii_case(origin_authority) + || origin_with_default_port + .as_deref() + .is_some_and(|value| request_host.eq_ignore_ascii_case(value)) + { + Ok(()) + } else { + warn!( + "Rejected MCP request due to origin/host mismatch: origin={}, host={}", + origin, request_host + ); + Err(StatusCode::FORBIDDEN) + } +} + +async fn streamable_post_handler( + State(state): State>, + headers: HeaderMap, + Json(request): Json, +) -> Response { + if let Err(status) = validate_origin(&headers) { + return status.into_response(); + } + + info!( + method = %request.method, + agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"), + agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"), + "Received streamable MCP request" + ); + + let request = apply_request_context(request, &headers); + let response = dispatch_request(&state, &request).await; + + match response { + Some(response) => Json(response).into_response(), + None => StatusCode::ACCEPTED.into_response(), + } +} + +async fn streamable_get_handler(headers: HeaderMap) -> Response { + if let Err(status) = validate_origin(&headers) { + return status.into_response(); + } + + StatusCode::METHOD_NOT_ALLOWED.into_response() +} + +async fn streamable_delete_handler(headers: HeaderMap) -> Response { + if let Err(status) = validate_origin(&headers) { + return status.into_response(); + } + + StatusCode::METHOD_NOT_ALLOWED.into_response() +} + /// SSE endpoint for streaming events async fn sse_handler( State(state): State>, -) -> Sse>> { + headers: HeaderMap, +) -> Result>>, StatusCode> { + validate_origin(&headers)?; + info!( + agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"), + agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"), + "Opening legacy SSE MCP stream" + ); let mut broadcast_rx = state.event_tx.subscribe(); let (session_tx, mut session_rx) = mpsc::channel(32); let session_id = Uuid::new_v4().to_string(); @@ -222,7 +338,7 @@ async fn sse_handler( } }; - Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15))) + Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))) } /// Message endpoint for JSON-RPC requests @@ -232,7 +348,16 @@ async fn message_handler( headers: HeaderMap, Json(request): Json, ) -> Response { - info!("Received MCP request: {}", request.method); + if let Err(status) = validate_origin(&headers) { + return status.into_response(); + } + + info!( + method = %request.method, + agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"), + agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"), + "Received legacy SSE MCP request" + ); if let Some(session_id) = query.session_id.as_deref() { if !state.has_session(session_id).await { @@ -513,4 +638,22 @@ mod tests { "/mcp/message?sessionId=abc123" ); } + + #[test] + fn allows_matching_origin_and_host() { + let mut headers = HeaderMap::new(); + headers.insert(ORIGIN, "https://ob.ingwaz.work".parse().unwrap()); + headers.insert(HOST, "ob.ingwaz.work".parse().unwrap()); + + assert!(validate_origin(&headers).is_ok()); + } + + #[test] + fn rejects_mismatched_origin_and_host() { + let mut headers = HeaderMap::new(); + headers.insert(ORIGIN, "https://evil.example".parse().unwrap()); + headers.insert(HOST, "ob.ingwaz.work".parse().unwrap()); + + assert_eq!(validate_origin(&headers), Err(StatusCode::FORBIDDEN)); + } } diff --git a/tests/e2e_mcp.rs b/tests/e2e_mcp.rs index 8b7f844..8003a55 100644 --- a/tests/e2e_mcp.rs +++ b/tests/e2e_mcp.rs @@ -110,6 +110,26 @@ async fn call_jsonrpc(client: &reqwest::Client, base: &str, request: Value) -> V .expect("JSON-RPC response body") } +async fn call_streamable_jsonrpc( + client: &reqwest::Client, + base: &str, + request: Value, +) -> reqwest::Response { + let mut req_builder = client + .post(format!("{base}/mcp")) + .header("Accept", "application/json, text/event-stream") + .json(&request); + + if let Some(key) = api_key() { + req_builder = req_builder.header("X-API-Key", key); + } + + req_builder + .send() + .await + .expect("streamable JSON-RPC HTTP request") +} + /// Make an authenticated GET request to an MCP endpoint async fn get_mcp_endpoint(client: &reqwest::Client, base: &str, path: &str) -> reqwest::Response { let mut req_builder = client.get(format!("{base}{path}")); @@ -357,6 +377,98 @@ async fn e2e_transport_tools_list_and_unknown_method() { ); } +#[tokio::test] +async fn e2e_streamable_initialize_and_tools_list() { + let base = base_url(); + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(20)) + .build() + .expect("reqwest client"); + + wait_until_ready(&client, &base).await; + + let initialize_response: Value = call_streamable_jsonrpc( + &client, + &base, + json!({ + "jsonrpc": "2.0", + "id": "streamable-init-1", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "e2e-client", + "version": "0.1.0" + } + } + }), + ) + .await + .json() + .await + .expect("streamable initialize JSON"); + + assert_eq!( + initialize_response + .get("result") + .and_then(|value| value.get("protocolVersion")) + .and_then(Value::as_str), + Some("2024-11-05") + ); + + let tools_list_response: Value = call_streamable_jsonrpc( + &client, + &base, + json!({ + "jsonrpc": "2.0", + "id": "streamable-tools-list-1", + "method": "tools/list", + "params": {} + }), + ) + .await + .json() + .await + .expect("streamable tools/list JSON"); + + assert!( + tools_list_response + .get("result") + .and_then(|value| value.get("tools")) + .and_then(Value::as_array) + .map(|tools| !tools.is_empty()) + .unwrap_or(false), + "streamable /mcp tools/list should return tool definitions" + ); +} + +#[tokio::test] +async fn e2e_streamable_get_returns_405() { + let base = base_url(); + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(20)) + .build() + .expect("reqwest client"); + + wait_until_ready(&client, &base).await; + + let mut request = client + .get(format!("{base}/mcp")) + .header("Accept", "text/event-stream"); + + if let Some(key) = api_key() { + request = request.header("X-API-Key", key); + } + + let response = request.send().await.expect("GET /mcp"); + assert_eq!( + response.status(), + reqwest::StatusCode::METHOD_NOT_ALLOWED, + "streamable GET /mcp should explicitly return 405 when standalone SSE streams are not offered" + ); +} + #[tokio::test] async fn e2e_purge_requires_confirm_flag() { let base = base_url();