Fix MCP transport compatibility and batch_store e2e coverage

This commit is contained in:
Agent Zero
2026-03-22 03:18:08 +00:00
parent d7140eb7f3
commit 26c96b41dd
5 changed files with 329 additions and 15 deletions

View File

@@ -9,7 +9,7 @@ OpenBrain is a Model Context Protocol (MCP) server that provides AI agents with
- 🧠 **Semantic Memory**: Store and retrieve memories using vector similarity search - 🧠 **Semantic Memory**: Store and retrieve memories using vector similarity search
- 🏠 **Local Embeddings**: No external API calls - uses ONNX runtime with all-MiniLM-L6-v2 - 🏠 **Local Embeddings**: No external API calls - uses ONNX runtime with all-MiniLM-L6-v2
- 🐘 **PostgreSQL + pgvector**: Production-grade vector storage with HNSW indexing - 🐘 **PostgreSQL + pgvector**: Production-grade vector storage with HNSW indexing
- 🔌 **MCP Protocol**: Standard Model Context Protocol over SSE transport - 🔌 **MCP Protocol**: Streamable HTTP plus legacy HTTP+SSE compatibility
- 🔐 **Multi-Agent Support**: Isolated memory namespaces per agent - 🔐 **Multi-Agent Support**: Isolated memory namespaces per agent
-**High Performance**: Rust implementation with async I/O -**High Performance**: Rust implementation with async I/O
@@ -103,14 +103,53 @@ Recommended target file in A0:
## MCP Integration ## MCP Integration
Connect to the server using SSE transport: OpenBrain exposes both MCP HTTP transports:
``` ```
SSE Endpoint: http://localhost:3100/mcp/sse Streamable HTTP Endpoint: http://localhost:3100/mcp
Message Endpoint: http://localhost:3100/mcp/message Legacy SSE Endpoint: http://localhost:3100/mcp/sse
Legacy Message Endpoint: http://localhost:3100/mcp/message
Health Check: http://localhost:3100/mcp/health Health Check: http://localhost:3100/mcp/health
``` ```
Use the streamable HTTP endpoint for modern clients such as Codex. Keep the
legacy SSE endpoints for older MCP clients that still use the deprecated
2024-11-05 HTTP+SSE transport.
Header roles:
- `X-Agent-ID` is the memory namespace. Keep this stable if multiple clients
should share the same OpenBrain memories.
- `X-Agent-Type` is an optional client profile label for logging and config
clarity, such as `agent-zero` or `codex`.
### Example: Codex Configuration
```toml
[mcp_servers.openbrain]
url = "https://ob.ingwaz.work/mcp"
http_headers = { "X-API-Key" = "YOUR_OPENBRAIN_API_KEY", "X-Agent-ID" = "openbrain", "X-Agent-Type" = "codex" }
```
### Example: Agent Zero Configuration
```json
{
"mcpServers": {
"openbrain": {
"url": "https://ob.ingwaz.work/mcp/sse",
"headers": {
"X-API-Key": "YOUR_OPENBRAIN_API_KEY",
"X-Agent-ID": "openbrain",
"X-Agent-Type": "agent-zero"
}
}
}
}
```
Agent Zero should keep using the legacy HTTP+SSE transport unless and until its
client runtime supports streamable HTTP. Codex should use `/mcp`.
### Example: Store a Memory ### Example: Store a Memory
```json ```json
@@ -180,7 +219,7 @@ Health Check: http://localhost:3100/mcp/health
┌─────────────────────────────────────────────────────────┐ ┌─────────────────────────────────────────────────────────┐
│ AI Agent │ │ AI Agent │
└─────────────────────┬───────────────────────────────────┘ └─────────────────────┬───────────────────────────────────┘
│ MCP Protocol (SSE) │ MCP Protocol (Streamable HTTP / Legacy SSE)
┌─────────────────────▼───────────────────────────────────┐ ┌─────────────────────▼───────────────────────────────────┐
│ OpenBrain MCP Server │ │ OpenBrain MCP Server │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │

View File

@@ -90,6 +90,15 @@ pub fn get_optional_agent_id(headers: &HeaderMap) -> Option<String> {
.map(ToOwned::to_owned) .map(ToOwned::to_owned)
} }
pub fn get_optional_agent_type(headers: &HeaderMap) -> Option<String> {
headers
.get("X-Agent-Type")
.and_then(|v| v.to_str().ok())
.map(str::trim)
.filter(|value| !value.is_empty())
.map(ToOwned::to_owned)
}
/// Extract agent ID from request headers or default /// Extract agent ID from request headers or default
pub fn get_agent_id(request: &Request) -> String { pub fn get_agent_id(request: &Request) -> String {
get_optional_agent_id(request.headers()) get_optional_agent_id(request.headers())
@@ -122,4 +131,15 @@ mod tests {
Some("agent-zero") Some("agent-zero")
); );
} }
#[test]
fn extracts_agent_type_from_header() {
let mut headers = HeaderMap::new();
headers.insert("X-Agent-Type", HeaderValue::from_static("codex"));
assert_eq!(
get_optional_agent_type(&headers).as_deref(),
Some("codex")
);
}
} }

View File

@@ -130,7 +130,7 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> {
let app = Router::new() let app = Router::new()
.merge(health_router) .merge(health_router)
.nest("/mcp", mcp_router) .merge(mcp_router)
.layer(TraceLayer::new_for_http()) .layer(TraceLayer::new_for_http())
.layer( .layer(
CorsLayer::new() CorsLayer::new()

View File

@@ -1,10 +1,11 @@
//! SSE Transport for MCP Protocol //! HTTP transport for MCP Protocol.
//! //!
//! Implements Server-Sent Events transport for the Model Context Protocol. //! Supports both the legacy HTTP+SSE transport and the newer streamable HTTP
//! transport on the same server.
use axum::{ use axum::{
extract::{Query, State}, extract::{Query, State},
http::{HeaderMap, StatusCode}, http::{HeaderMap, StatusCode, Uri, header::{HOST, ORIGIN}},
response::{ response::{
IntoResponse, Response, IntoResponse, Response,
sse::{Event, KeepAlive, Sse}, sse::{Event, KeepAlive, Sse},
@@ -162,16 +163,131 @@ struct PostMessageQuery {
/// Create the MCP router /// Create the MCP router
pub fn mcp_router(state: Arc<McpState>) -> Router { pub fn mcp_router(state: Arc<McpState>) -> Router {
Router::new() Router::new()
.route("/sse", get(sse_handler)) .route("/mcp", get(streamable_get_handler).post(streamable_post_handler).delete(streamable_delete_handler))
.route("/message", post(message_handler)) .route("/mcp/sse", get(sse_handler))
.route("/health", get(health_handler)) .route("/mcp/message", post(message_handler))
.route("/mcp/health", get(health_handler))
.with_state(state) .with_state(state)
} }
fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
let Some(origin) = headers
.get(ORIGIN)
.and_then(|value| value.to_str().ok())
.map(str::trim)
.filter(|value| !value.is_empty())
else {
return Ok(());
};
if origin.eq_ignore_ascii_case("null") {
warn!("Rejected MCP request with null origin");
return Err(StatusCode::FORBIDDEN);
}
let origin_uri = origin.parse::<Uri>().map_err(|_| {
warn!("Rejected MCP request with invalid origin header: {}", origin);
StatusCode::FORBIDDEN
})?;
let origin_host = origin_uri.host().ok_or_else(|| {
warn!("Rejected MCP request without origin host: {}", origin);
StatusCode::FORBIDDEN
})?;
let request_host = headers
.get("X-Forwarded-Host")
.or_else(|| headers.get(HOST))
.and_then(|value| value.to_str().ok())
.and_then(|value| value.split(',').next())
.map(str::trim)
.filter(|value| !value.is_empty())
.ok_or_else(|| {
warn!("Rejected MCP request without host header for origin {}", origin);
StatusCode::FORBIDDEN
})?;
let origin_authority = origin_uri
.authority()
.map(|authority| authority.as_str())
.unwrap_or(origin_host);
let origin_with_default_port = origin_uri
.port_u16()
.or_else(|| match origin_uri.scheme_str() {
Some("https") => Some(443),
Some("http") => Some(80),
_ => None,
})
.map(|port| format!("{origin_host}:{port}"));
if request_host.eq_ignore_ascii_case(origin_host)
|| request_host.eq_ignore_ascii_case(origin_authority)
|| origin_with_default_port
.as_deref()
.is_some_and(|value| request_host.eq_ignore_ascii_case(value))
{
Ok(())
} else {
warn!(
"Rejected MCP request due to origin/host mismatch: origin={}, host={}",
origin, request_host
);
Err(StatusCode::FORBIDDEN)
}
}
async fn streamable_post_handler(
State(state): State<Arc<McpState>>,
headers: HeaderMap,
Json(request): Json<JsonRpcRequest>,
) -> Response {
if let Err(status) = validate_origin(&headers) {
return status.into_response();
}
info!(
method = %request.method,
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
"Received streamable MCP request"
);
let request = apply_request_context(request, &headers);
let response = dispatch_request(&state, &request).await;
match response {
Some(response) => Json(response).into_response(),
None => StatusCode::ACCEPTED.into_response(),
}
}
async fn streamable_get_handler(headers: HeaderMap) -> Response {
if let Err(status) = validate_origin(&headers) {
return status.into_response();
}
StatusCode::METHOD_NOT_ALLOWED.into_response()
}
async fn streamable_delete_handler(headers: HeaderMap) -> Response {
if let Err(status) = validate_origin(&headers) {
return status.into_response();
}
StatusCode::METHOD_NOT_ALLOWED.into_response()
}
/// SSE endpoint for streaming events /// SSE endpoint for streaming events
async fn sse_handler( async fn sse_handler(
State(state): State<Arc<McpState>>, State(state): State<Arc<McpState>>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> { headers: HeaderMap,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
validate_origin(&headers)?;
info!(
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
"Opening legacy SSE MCP stream"
);
let mut broadcast_rx = state.event_tx.subscribe(); let mut broadcast_rx = state.event_tx.subscribe();
let (session_tx, mut session_rx) = mpsc::channel(32); let (session_tx, mut session_rx) = mpsc::channel(32);
let session_id = Uuid::new_v4().to_string(); let session_id = Uuid::new_v4().to_string();
@@ -222,7 +338,7 @@ async fn sse_handler(
} }
}; };
Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15))) Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15))))
} }
/// Message endpoint for JSON-RPC requests /// Message endpoint for JSON-RPC requests
@@ -232,7 +348,16 @@ async fn message_handler(
headers: HeaderMap, headers: HeaderMap,
Json(request): Json<JsonRpcRequest>, Json(request): Json<JsonRpcRequest>,
) -> Response { ) -> Response {
info!("Received MCP request: {}", request.method); if let Err(status) = validate_origin(&headers) {
return status.into_response();
}
info!(
method = %request.method,
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
"Received legacy SSE MCP request"
);
if let Some(session_id) = query.session_id.as_deref() { if let Some(session_id) = query.session_id.as_deref() {
if !state.has_session(session_id).await { if !state.has_session(session_id).await {
@@ -513,4 +638,22 @@ mod tests {
"/mcp/message?sessionId=abc123" "/mcp/message?sessionId=abc123"
); );
} }
#[test]
fn allows_matching_origin_and_host() {
let mut headers = HeaderMap::new();
headers.insert(ORIGIN, "https://ob.ingwaz.work".parse().unwrap());
headers.insert(HOST, "ob.ingwaz.work".parse().unwrap());
assert!(validate_origin(&headers).is_ok());
}
#[test]
fn rejects_mismatched_origin_and_host() {
let mut headers = HeaderMap::new();
headers.insert(ORIGIN, "https://evil.example".parse().unwrap());
headers.insert(HOST, "ob.ingwaz.work".parse().unwrap());
assert_eq!(validate_origin(&headers), Err(StatusCode::FORBIDDEN));
}
} }

View File

@@ -110,6 +110,26 @@ async fn call_jsonrpc(client: &reqwest::Client, base: &str, request: Value) -> V
.expect("JSON-RPC response body") .expect("JSON-RPC response body")
} }
async fn call_streamable_jsonrpc(
client: &reqwest::Client,
base: &str,
request: Value,
) -> reqwest::Response {
let mut req_builder = client
.post(format!("{base}/mcp"))
.header("Accept", "application/json, text/event-stream")
.json(&request);
if let Some(key) = api_key() {
req_builder = req_builder.header("X-API-Key", key);
}
req_builder
.send()
.await
.expect("streamable JSON-RPC HTTP request")
}
/// Make an authenticated GET request to an MCP endpoint /// Make an authenticated GET request to an MCP endpoint
async fn get_mcp_endpoint(client: &reqwest::Client, base: &str, path: &str) -> reqwest::Response { async fn get_mcp_endpoint(client: &reqwest::Client, base: &str, path: &str) -> reqwest::Response {
let mut req_builder = client.get(format!("{base}{path}")); let mut req_builder = client.get(format!("{base}{path}"));
@@ -357,6 +377,98 @@ async fn e2e_transport_tools_list_and_unknown_method() {
); );
} }
#[tokio::test]
async fn e2e_streamable_initialize_and_tools_list() {
let base = base_url();
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(20))
.build()
.expect("reqwest client");
wait_until_ready(&client, &base).await;
let initialize_response: Value = call_streamable_jsonrpc(
&client,
&base,
json!({
"jsonrpc": "2.0",
"id": "streamable-init-1",
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {
"name": "e2e-client",
"version": "0.1.0"
}
}
}),
)
.await
.json()
.await
.expect("streamable initialize JSON");
assert_eq!(
initialize_response
.get("result")
.and_then(|value| value.get("protocolVersion"))
.and_then(Value::as_str),
Some("2024-11-05")
);
let tools_list_response: Value = call_streamable_jsonrpc(
&client,
&base,
json!({
"jsonrpc": "2.0",
"id": "streamable-tools-list-1",
"method": "tools/list",
"params": {}
}),
)
.await
.json()
.await
.expect("streamable tools/list JSON");
assert!(
tools_list_response
.get("result")
.and_then(|value| value.get("tools"))
.and_then(Value::as_array)
.map(|tools| !tools.is_empty())
.unwrap_or(false),
"streamable /mcp tools/list should return tool definitions"
);
}
#[tokio::test]
async fn e2e_streamable_get_returns_405() {
let base = base_url();
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(20))
.build()
.expect("reqwest client");
wait_until_ready(&client, &base).await;
let mut request = client
.get(format!("{base}/mcp"))
.header("Accept", "text/event-stream");
if let Some(key) = api_key() {
request = request.header("X-API-Key", key);
}
let response = request.send().await.expect("GET /mcp");
assert_eq!(
response.status(),
reqwest::StatusCode::METHOD_NOT_ALLOWED,
"streamable GET /mcp should explicitly return 405 when standalone SSE streams are not offered"
);
}
#[tokio::test] #[tokio::test]
async fn e2e_purge_requires_confirm_flag() { async fn e2e_purge_requires_confirm_flag() {
let base = base_url(); let base = base_url();