mirror of
https://gitea.ingwaz.work/Ingwaz/openbrain-mcp.git
synced 2026-03-31 14:49:06 +00:00
Fix MCP transport compatibility and batch_store e2e coverage
This commit is contained in:
161
src/transport.rs
161
src/transport.rs
@@ -1,10 +1,11 @@
|
||||
//! SSE Transport for MCP Protocol
|
||||
//! HTTP transport for MCP Protocol.
|
||||
//!
|
||||
//! Implements Server-Sent Events transport for the Model Context Protocol.
|
||||
//! Supports both the legacy HTTP+SSE transport and the newer streamable HTTP
|
||||
//! transport on the same server.
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::{HeaderMap, StatusCode},
|
||||
http::{HeaderMap, StatusCode, Uri, header::{HOST, ORIGIN}},
|
||||
response::{
|
||||
IntoResponse, Response,
|
||||
sse::{Event, KeepAlive, Sse},
|
||||
@@ -162,16 +163,131 @@ struct PostMessageQuery {
|
||||
/// Create the MCP router
|
||||
pub fn mcp_router(state: Arc<McpState>) -> Router {
|
||||
Router::new()
|
||||
.route("/sse", get(sse_handler))
|
||||
.route("/message", post(message_handler))
|
||||
.route("/health", get(health_handler))
|
||||
.route("/mcp", get(streamable_get_handler).post(streamable_post_handler).delete(streamable_delete_handler))
|
||||
.route("/mcp/sse", get(sse_handler))
|
||||
.route("/mcp/message", post(message_handler))
|
||||
.route("/mcp/health", get(health_handler))
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
|
||||
let Some(origin) = headers
|
||||
.get(ORIGIN)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
if origin.eq_ignore_ascii_case("null") {
|
||||
warn!("Rejected MCP request with null origin");
|
||||
return Err(StatusCode::FORBIDDEN);
|
||||
}
|
||||
|
||||
let origin_uri = origin.parse::<Uri>().map_err(|_| {
|
||||
warn!("Rejected MCP request with invalid origin header: {}", origin);
|
||||
StatusCode::FORBIDDEN
|
||||
})?;
|
||||
|
||||
let origin_host = origin_uri.host().ok_or_else(|| {
|
||||
warn!("Rejected MCP request without origin host: {}", origin);
|
||||
StatusCode::FORBIDDEN
|
||||
})?;
|
||||
|
||||
let request_host = headers
|
||||
.get("X-Forwarded-Host")
|
||||
.or_else(|| headers.get(HOST))
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.and_then(|value| value.split(',').next())
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.ok_or_else(|| {
|
||||
warn!("Rejected MCP request without host header for origin {}", origin);
|
||||
StatusCode::FORBIDDEN
|
||||
})?;
|
||||
|
||||
let origin_authority = origin_uri
|
||||
.authority()
|
||||
.map(|authority| authority.as_str())
|
||||
.unwrap_or(origin_host);
|
||||
let origin_with_default_port = origin_uri
|
||||
.port_u16()
|
||||
.or_else(|| match origin_uri.scheme_str() {
|
||||
Some("https") => Some(443),
|
||||
Some("http") => Some(80),
|
||||
_ => None,
|
||||
})
|
||||
.map(|port| format!("{origin_host}:{port}"));
|
||||
|
||||
if request_host.eq_ignore_ascii_case(origin_host)
|
||||
|| request_host.eq_ignore_ascii_case(origin_authority)
|
||||
|| origin_with_default_port
|
||||
.as_deref()
|
||||
.is_some_and(|value| request_host.eq_ignore_ascii_case(value))
|
||||
{
|
||||
Ok(())
|
||||
} else {
|
||||
warn!(
|
||||
"Rejected MCP request due to origin/host mismatch: origin={}, host={}",
|
||||
origin, request_host
|
||||
);
|
||||
Err(StatusCode::FORBIDDEN)
|
||||
}
|
||||
}
|
||||
|
||||
async fn streamable_post_handler(
|
||||
State(state): State<Arc<McpState>>,
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<JsonRpcRequest>,
|
||||
) -> Response {
|
||||
if let Err(status) = validate_origin(&headers) {
|
||||
return status.into_response();
|
||||
}
|
||||
|
||||
info!(
|
||||
method = %request.method,
|
||||
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
||||
"Received streamable MCP request"
|
||||
);
|
||||
|
||||
let request = apply_request_context(request, &headers);
|
||||
let response = dispatch_request(&state, &request).await;
|
||||
|
||||
match response {
|
||||
Some(response) => Json(response).into_response(),
|
||||
None => StatusCode::ACCEPTED.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn streamable_get_handler(headers: HeaderMap) -> Response {
|
||||
if let Err(status) = validate_origin(&headers) {
|
||||
return status.into_response();
|
||||
}
|
||||
|
||||
StatusCode::METHOD_NOT_ALLOWED.into_response()
|
||||
}
|
||||
|
||||
async fn streamable_delete_handler(headers: HeaderMap) -> Response {
|
||||
if let Err(status) = validate_origin(&headers) {
|
||||
return status.into_response();
|
||||
}
|
||||
|
||||
StatusCode::METHOD_NOT_ALLOWED.into_response()
|
||||
}
|
||||
|
||||
/// SSE endpoint for streaming events
|
||||
async fn sse_handler(
|
||||
State(state): State<Arc<McpState>>,
|
||||
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
|
||||
headers: HeaderMap,
|
||||
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
|
||||
validate_origin(&headers)?;
|
||||
info!(
|
||||
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
||||
"Opening legacy SSE MCP stream"
|
||||
);
|
||||
let mut broadcast_rx = state.event_tx.subscribe();
|
||||
let (session_tx, mut session_rx) = mpsc::channel(32);
|
||||
let session_id = Uuid::new_v4().to_string();
|
||||
@@ -222,7 +338,7 @@ async fn sse_handler(
|
||||
}
|
||||
};
|
||||
|
||||
Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))
|
||||
Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15))))
|
||||
}
|
||||
|
||||
/// Message endpoint for JSON-RPC requests
|
||||
@@ -232,7 +348,16 @@ async fn message_handler(
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<JsonRpcRequest>,
|
||||
) -> Response {
|
||||
info!("Received MCP request: {}", request.method);
|
||||
if let Err(status) = validate_origin(&headers) {
|
||||
return status.into_response();
|
||||
}
|
||||
|
||||
info!(
|
||||
method = %request.method,
|
||||
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
||||
"Received legacy SSE MCP request"
|
||||
);
|
||||
|
||||
if let Some(session_id) = query.session_id.as_deref() {
|
||||
if !state.has_session(session_id).await {
|
||||
@@ -513,4 +638,22 @@ mod tests {
|
||||
"/mcp/message?sessionId=abc123"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allows_matching_origin_and_host() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(ORIGIN, "https://ob.ingwaz.work".parse().unwrap());
|
||||
headers.insert(HOST, "ob.ingwaz.work".parse().unwrap());
|
||||
|
||||
assert!(validate_origin(&headers).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_mismatched_origin_and_host() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(ORIGIN, "https://evil.example".parse().unwrap());
|
||||
headers.insert(HOST, "ob.ingwaz.work".parse().unwrap());
|
||||
|
||||
assert_eq!(validate_origin(&headers), Err(StatusCode::FORBIDDEN));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user