Fix MCP transport compatibility and batch_store e2e coverage

This commit is contained in:
Agent Zero
2026-03-22 03:18:08 +00:00
parent d7140eb7f3
commit 26c96b41dd
5 changed files with 329 additions and 15 deletions

View File

@@ -1,10 +1,11 @@
//! SSE Transport for MCP Protocol
//! HTTP transport for MCP Protocol.
//!
//! Implements Server-Sent Events transport for the Model Context Protocol.
//! Supports both the legacy HTTP+SSE transport and the newer streamable HTTP
//! transport on the same server.
use axum::{
extract::{Query, State},
http::{HeaderMap, StatusCode},
http::{HeaderMap, StatusCode, Uri, header::{HOST, ORIGIN}},
response::{
IntoResponse, Response,
sse::{Event, KeepAlive, Sse},
@@ -162,16 +163,131 @@ struct PostMessageQuery {
/// Create the MCP router
pub fn mcp_router(state: Arc<McpState>) -> Router {
Router::new()
.route("/sse", get(sse_handler))
.route("/message", post(message_handler))
.route("/health", get(health_handler))
.route("/mcp", get(streamable_get_handler).post(streamable_post_handler).delete(streamable_delete_handler))
.route("/mcp/sse", get(sse_handler))
.route("/mcp/message", post(message_handler))
.route("/mcp/health", get(health_handler))
.with_state(state)
}
fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
let Some(origin) = headers
.get(ORIGIN)
.and_then(|value| value.to_str().ok())
.map(str::trim)
.filter(|value| !value.is_empty())
else {
return Ok(());
};
if origin.eq_ignore_ascii_case("null") {
warn!("Rejected MCP request with null origin");
return Err(StatusCode::FORBIDDEN);
}
let origin_uri = origin.parse::<Uri>().map_err(|_| {
warn!("Rejected MCP request with invalid origin header: {}", origin);
StatusCode::FORBIDDEN
})?;
let origin_host = origin_uri.host().ok_or_else(|| {
warn!("Rejected MCP request without origin host: {}", origin);
StatusCode::FORBIDDEN
})?;
let request_host = headers
.get("X-Forwarded-Host")
.or_else(|| headers.get(HOST))
.and_then(|value| value.to_str().ok())
.and_then(|value| value.split(',').next())
.map(str::trim)
.filter(|value| !value.is_empty())
.ok_or_else(|| {
warn!("Rejected MCP request without host header for origin {}", origin);
StatusCode::FORBIDDEN
})?;
let origin_authority = origin_uri
.authority()
.map(|authority| authority.as_str())
.unwrap_or(origin_host);
let origin_with_default_port = origin_uri
.port_u16()
.or_else(|| match origin_uri.scheme_str() {
Some("https") => Some(443),
Some("http") => Some(80),
_ => None,
})
.map(|port| format!("{origin_host}:{port}"));
if request_host.eq_ignore_ascii_case(origin_host)
|| request_host.eq_ignore_ascii_case(origin_authority)
|| origin_with_default_port
.as_deref()
.is_some_and(|value| request_host.eq_ignore_ascii_case(value))
{
Ok(())
} else {
warn!(
"Rejected MCP request due to origin/host mismatch: origin={}, host={}",
origin, request_host
);
Err(StatusCode::FORBIDDEN)
}
}
async fn streamable_post_handler(
State(state): State<Arc<McpState>>,
headers: HeaderMap,
Json(request): Json<JsonRpcRequest>,
) -> Response {
if let Err(status) = validate_origin(&headers) {
return status.into_response();
}
info!(
method = %request.method,
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
"Received streamable MCP request"
);
let request = apply_request_context(request, &headers);
let response = dispatch_request(&state, &request).await;
match response {
Some(response) => Json(response).into_response(),
None => StatusCode::ACCEPTED.into_response(),
}
}
async fn streamable_get_handler(headers: HeaderMap) -> Response {
if let Err(status) = validate_origin(&headers) {
return status.into_response();
}
StatusCode::METHOD_NOT_ALLOWED.into_response()
}
async fn streamable_delete_handler(headers: HeaderMap) -> Response {
if let Err(status) = validate_origin(&headers) {
return status.into_response();
}
StatusCode::METHOD_NOT_ALLOWED.into_response()
}
/// SSE endpoint for streaming events
async fn sse_handler(
State(state): State<Arc<McpState>>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
headers: HeaderMap,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
validate_origin(&headers)?;
info!(
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
"Opening legacy SSE MCP stream"
);
let mut broadcast_rx = state.event_tx.subscribe();
let (session_tx, mut session_rx) = mpsc::channel(32);
let session_id = Uuid::new_v4().to_string();
@@ -222,7 +338,7 @@ async fn sse_handler(
}
};
Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))
Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15))))
}
/// Message endpoint for JSON-RPC requests
@@ -232,7 +348,16 @@ async fn message_handler(
headers: HeaderMap,
Json(request): Json<JsonRpcRequest>,
) -> Response {
info!("Received MCP request: {}", request.method);
if let Err(status) = validate_origin(&headers) {
return status.into_response();
}
info!(
method = %request.method,
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
"Received legacy SSE MCP request"
);
if let Some(session_id) = query.session_id.as_deref() {
if !state.has_session(session_id).await {
@@ -513,4 +638,22 @@ mod tests {
"/mcp/message?sessionId=abc123"
);
}
#[test]
fn allows_matching_origin_and_host() {
let mut headers = HeaderMap::new();
headers.insert(ORIGIN, "https://ob.ingwaz.work".parse().unwrap());
headers.insert(HOST, "ob.ingwaz.work".parse().unwrap());
assert!(validate_origin(&headers).is_ok());
}
#[test]
fn rejects_mismatched_origin_and_host() {
let mut headers = HeaderMap::new();
headers.insert(ORIGIN, "https://evil.example".parse().unwrap());
headers.insert(HOST, "ob.ingwaz.work".parse().unwrap());
assert_eq!(validate_origin(&headers), Err(StatusCode::FORBIDDEN));
}
}