Files
openbrain-mcp/src/transport.rs

660 lines
19 KiB
Rust

//! HTTP transport for MCP Protocol.
//!
//! Supports both the legacy HTTP+SSE transport and the newer streamable HTTP
//! transport on the same server.
use axum::{
extract::{Query, State},
http::{HeaderMap, StatusCode, Uri, header::{HOST, ORIGIN}},
response::{
IntoResponse, Response,
sse::{Event, KeepAlive, Sse},
},
routing::{get, post},
Json, Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
convert::Infallible,
sync::Arc,
time::Duration,
};
use tokio::sync::{RwLock, broadcast, mpsc};
use tracing::{error, info, warn};
use uuid::Uuid;
use crate::{AppState, auth, tools};
/// Table of live legacy-SSE sessions: session id → sender that feeds JSON-RPC
/// messages into that session's SSE stream (see `sse_handler`).
type SessionStore = RwLock<HashMap<String, mpsc::Sender<serde_json::Value>>>;
/// MCP Server State
///
/// Shared (behind an `Arc`) by every handler installed by `mcp_router`.
pub struct McpState {
    /// Application state passed to tool execution (`tools::execute_tool`).
    pub app: Arc<AppState>,
    /// Broadcast channel; every open legacy SSE stream subscribes to it and
    /// forwards each [`McpEvent`] to its client.
    pub event_tx: broadcast::Sender<McpEvent>,
    /// Sessions opened through the legacy `/mcp/sse` endpoint.
    sessions: SessionStore,
}
impl McpState {
    /// Build the shared MCP state for `app`.
    ///
    /// Starts with no registered sessions and a broadcast channel (capacity
    /// 100) for server-wide [`McpEvent`]s. The initial receiver is discarded;
    /// SSE streams subscribe on demand.
    pub fn new(app: Arc<AppState>) -> Arc<Self> {
        let (event_tx, _) = broadcast::channel(100);
        let sessions = RwLock::new(HashMap::new());
        Arc::new(Self {
            app,
            event_tx,
            sessions,
        })
    }

    /// Register `tx` as the outbound channel of session `session_id`,
    /// replacing any sender previously stored under the same id.
    async fn insert_session(
        &self,
        session_id: String,
        tx: mpsc::Sender<serde_json::Value>,
    ) {
        let mut table = self.sessions.write().await;
        table.insert(session_id, tx);
    }

    /// Forget session `session_id`; a no-op when it is not registered.
    async fn remove_session(&self, session_id: &str) {
        let mut table = self.sessions.write().await;
        table.remove(session_id);
    }

    /// Whether a session with this id is currently registered.
    async fn has_session(&self, session_id: &str) -> bool {
        let table = self.sessions.read().await;
        table.contains_key(session_id)
    }

    /// Push `response` onto the SSE stream of session `session_id`.
    ///
    /// # Errors
    /// - [`SessionSendError::NotFound`] if no such session is registered.
    /// - [`SessionSendError::Closed`] if the stream's receiver is gone; the
    ///   stale session is removed before returning.
    ///
    /// # Panics
    /// If `response` cannot be serialized to JSON.
    async fn send_to_session(
        &self,
        session_id: &str,
        response: &JsonRpcResponse,
    ) -> Result<(), SessionSendError> {
        // Clone the sender out of the table so the read lock is released
        // before waiting for channel capacity below.
        let sender = match self.sessions.read().await.get(session_id) {
            Some(tx) => tx.clone(),
            None => return Err(SessionSendError::NotFound),
        };
        let payload =
            serde_json::to_value(response).expect("serializing JSON-RPC response should succeed");
        match sender.send(payload).await {
            Ok(()) => Ok(()),
            Err(_) => {
                // Receiver dropped: the SSE stream is gone, so drop the session too.
                self.remove_session(session_id).await;
                Err(SessionSendError::Closed)
            }
        }
    }
}
/// Why `McpState::send_to_session` could not deliver a response.
enum SessionSendError {
    /// No session with the given id is registered.
    NotFound,
    /// The session's receiver was dropped (its SSE stream ended); the session
    /// has been removed from the table.
    Closed,
}
struct SessionGuard {
state: Arc<McpState>,
session_id: String,
}
impl SessionGuard {
fn new(state: Arc<McpState>, session_id: String) -> Self {
Self { state, session_id }
}
}
impl Drop for SessionGuard {
fn drop(&mut self) {
let state = self.state.clone();
let session_id = self.session_id.clone();
if let Ok(handle) = tokio::runtime::Handle::try_current() {
handle.spawn(async move {
state.remove_session(&session_id).await;
});
}
}
}
/// MCP Event for SSE streaming
///
/// Sent through `McpState::event_tx` and forwarded by every open legacy SSE
/// stream as one SSE event.
#[derive(Clone, Debug, Serialize)]
pub struct McpEvent {
    /// Emitted as the SSE `id:` field.
    pub id: String,
    /// Emitted as the SSE `event:` name.
    pub event_type: String,
    /// Serialized as JSON into the SSE `data:` field.
    pub data: serde_json::Value,
}
/// MCP JSON-RPC Request
///
/// A message without an `id` is a notification: `dispatch_request` never
/// produces a response for it.
#[derive(Debug, Deserialize)]
pub struct JsonRpcRequest {
    /// Protocol version string; not validated by the handlers in this file.
    pub jsonrpc: String,
    /// Request id; `None` when the field is absent (or JSON `null`).
    #[serde(default)]
    pub id: Option<serde_json::Value>,
    /// Method name, e.g. `initialize`, `tools/list`, `tools/call`.
    pub method: String,
    /// Method parameters; `Value::Null` when the field is absent.
    #[serde(default)]
    pub params: serde_json::Value,
}
/// MCP JSON-RPC Response
///
/// Exactly one of `result` / `error` is set by the constructors in this file
/// (`success_response`, `error_response`); the unset one is omitted from the
/// serialized JSON.
#[derive(Debug, Serialize)]
pub struct JsonRpcResponse {
    /// Always `"2.0"` when built by this module.
    pub jsonrpc: String,
    /// Echo of the request id.
    pub id: serde_json::Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
}
/// JSON-RPC error object carried in [`JsonRpcResponse::error`].
#[derive(Debug, Serialize)]
pub struct JsonRpcError {
    /// Numeric error code (e.g. `-32601` method not found, `-32000` tool failure).
    pub code: i32,
    /// Human-readable description.
    pub message: String,
    /// Optional structured details; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}
/// Query string of `POST /mcp/message`.
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "camelCase")]
struct PostMessageQuery {
    /// Legacy SSE session id, read from the `sessionId` parameter; optional.
    #[serde(default)]
    session_id: Option<String>,
}
/// Create the MCP router.
///
/// Routes:
/// - `/mcp` — streamable HTTP transport: `POST` handles JSON-RPC messages,
///   `GET` and `DELETE` are answered with `405` after the origin check.
/// - `/mcp/sse` — legacy HTTP+SSE event stream.
/// - `/mcp/message` — legacy JSON-RPC POST endpoint paired with `/mcp/sse`.
/// - `/mcp/health` — health check.
pub fn mcp_router(state: Arc<McpState>) -> Router {
    let streamable = get(streamable_get_handler)
        .post(streamable_post_handler)
        .delete(streamable_delete_handler);

    Router::new()
        .route("/mcp", streamable)
        .route("/mcp/sse", get(sse_handler))
        .route("/mcp/message", post(message_handler))
        .route("/mcp/health", get(health_handler))
        .with_state(state)
}
/// Reject browser requests whose `Origin` does not match the host the request
/// was addressed to.
///
/// Requests with no `Origin` header (or an empty one) are allowed, presumably
/// so non-browser clients that omit it keep working. An `Origin` of `null`, an
/// unparseable origin, or one without a host is rejected.
///
/// The request host is the first comma-separated value of `X-Forwarded-Host`
/// when that header is present, otherwise `Host`. The check passes when the
/// request host equals (ASCII case-insensitively) any of:
/// - the origin's bare host,
/// - the origin's authority (`host[:port]`) as written,
/// - `host:port` built from the origin's explicit port, or its scheme's
///   default port (443 for `https`, 80 for `http`).
///
/// NOTE(review): the bare-host comparison ignores ports. A request host with
/// no port therefore accepts an origin on *any* port of the same host
/// (e.g. `Host: localhost` vs `Origin: http://localhost:8080`).
///
/// NOTE(review): if `X-Forwarded-Host` is present but empty or not valid
/// UTF-8, `Host` is not consulted and the request is rejected. The forwarded
/// header is trusted as-is; confirm the fronting proxy always sets or strips it.
///
/// # Errors
/// Returns `StatusCode::FORBIDDEN` on every rejection, after logging a warning.
fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
    // No (usable) Origin header: nothing to compare, allow.
    let Some(origin) = headers
        .get(ORIGIN)
        .and_then(|value| value.to_str().ok())
        .map(str::trim)
        .filter(|value| !value.is_empty())
    else {
        return Ok(());
    };
    // Opaque origins (sandboxed iframes, file://, ...) are sent as "null".
    if origin.eq_ignore_ascii_case("null") {
        warn!("Rejected MCP request with null origin");
        return Err(StatusCode::FORBIDDEN);
    }
    let origin_uri = origin.parse::<Uri>().map_err(|_| {
        warn!("Rejected MCP request with invalid origin header: {}", origin);
        StatusCode::FORBIDDEN
    })?;
    let origin_host = origin_uri.host().ok_or_else(|| {
        warn!("Rejected MCP request without origin host: {}", origin);
        StatusCode::FORBIDDEN
    })?;
    // Prefer the proxy-supplied host; only its first entry is used when a
    // proxy chain appended several.
    let request_host = headers
        .get("X-Forwarded-Host")
        .or_else(|| headers.get(HOST))
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.split(',').next())
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .ok_or_else(|| {
            warn!("Rejected MCP request without host header for origin {}", origin);
            StatusCode::FORBIDDEN
        })?;
    let origin_authority = origin_uri
        .authority()
        .map(|authority| authority.as_str())
        .unwrap_or(origin_host);
    // "host:port" with the port made explicit, so a request host that carries
    // the scheme's default port still matches an origin that omits it.
    let origin_with_default_port = origin_uri
        .port_u16()
        .or_else(|| match origin_uri.scheme_str() {
            Some("https") => Some(443),
            Some("http") => Some(80),
            _ => None,
        })
        .map(|port| format!("{origin_host}:{port}"));
    if request_host.eq_ignore_ascii_case(origin_host)
        || request_host.eq_ignore_ascii_case(origin_authority)
        || origin_with_default_port
            .as_deref()
            .is_some_and(|value| request_host.eq_ignore_ascii_case(value))
    {
        Ok(())
    } else {
        warn!(
            "Rejected MCP request due to origin/host mismatch: origin={}, host={}",
            origin, request_host
        );
        Err(StatusCode::FORBIDDEN)
    }
}
/// Streamable HTTP transport: handle one JSON-RPC message POSTed to `/mcp`.
///
/// Behavior:
/// - Any response is returned directly as the JSON body.
/// - Messages that produce no response (notifications) are acknowledged with
///   `202 Accepted`.
/// - Requests failing the origin check get `403 Forbidden`.
async fn streamable_post_handler(
    State(state): State<Arc<McpState>>,
    headers: HeaderMap,
    Json(request): Json<JsonRpcRequest>,
) -> Response {
    if let Err(rejection) = validate_origin(&headers) {
        return rejection.into_response();
    }
    info!(
        method = %request.method,
        agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
        agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
        "Received streamable MCP request"
    );
    let contextual = apply_request_context(request, &headers);
    dispatch_request(&state, &contextual).await.map_or_else(
        || StatusCode::ACCEPTED.into_response(),
        |reply| Json(reply).into_response(),
    )
}
/// `GET /mcp`: this server does not open a standalone stream for the
/// streamable HTTP transport. Once the origin check passes, the request is
/// answered with `405 Method Not Allowed`.
async fn streamable_get_handler(headers: HeaderMap) -> Response {
    match validate_origin(&headers) {
        Ok(()) => StatusCode::METHOD_NOT_ALLOWED.into_response(),
        Err(rejection) => rejection.into_response(),
    }
}
/// `DELETE /mcp`: the streamable transport keeps no server-side session to
/// terminate. Once the origin check passes, the request is answered with
/// `405 Method Not Allowed`.
async fn streamable_delete_handler(headers: HeaderMap) -> Response {
    match validate_origin(&headers) {
        Ok(()) => StatusCode::METHOD_NOT_ALLOWED.into_response(),
        Err(rejection) => rejection.into_response(),
    }
}
/// SSE endpoint for streaming events
async fn sse_handler(
State(state): State<Arc<McpState>>,
headers: HeaderMap,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
validate_origin(&headers)?;
info!(
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
"Opening legacy SSE MCP stream"
);
let mut broadcast_rx = state.event_tx.subscribe();
let (session_tx, mut session_rx) = mpsc::channel(32);
let session_id = Uuid::new_v4().to_string();
let endpoint = session_message_endpoint(&session_id);
state.insert_session(session_id.clone(), session_tx).await;
let stream = async_stream::stream! {
let _session_guard = SessionGuard::new(state.clone(), session_id.clone());
// Send endpoint event (required by MCP SSE protocol)
// This tells the client where to POST JSON-RPC messages for this session.
yield Ok(Event::default()
.event("endpoint")
.data(endpoint));
loop {
tokio::select! {
maybe_message = session_rx.recv() => {
match maybe_message {
Some(message) => {
yield Ok(Event::default()
.event("message")
.json_data(&message)
.unwrap());
}
None => break,
}
}
event = broadcast_rx.recv() => {
match event {
Ok(event) => {
yield Ok(Event::default()
.event(&event.event_type)
.id(&event.id)
.json_data(&event.data)
.unwrap());
}
Err(broadcast::error::RecvError::Lagged(n)) => {
warn!("SSE client lagged, missed {} events", n);
}
Err(broadcast::error::RecvError::Closed) => {
break;
}
}
}
}
}
};
Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15))))
}
/// Message endpoint for JSON-RPC requests (legacy HTTP+SSE transport).
///
/// With a `sessionId` query parameter:
/// - the session must exist, otherwise the answer is `404 Not Found`;
/// - any response is pushed onto that session's SSE stream, and the POST
///   itself is answered by `route_session_response`.
///
/// Without one, the response is returned inline as JSON, or `202 Accepted`
/// when there is none (notifications).
async fn message_handler(
    State(state): State<Arc<McpState>>,
    Query(query): Query<PostMessageQuery>,
    headers: HeaderMap,
    Json(request): Json<JsonRpcRequest>,
) -> Response {
    if let Err(rejection) = validate_origin(&headers) {
        return rejection.into_response();
    }
    info!(
        method = %request.method,
        agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
        agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
        "Received legacy SSE MCP request"
    );
    let session_id = query.session_id.as_deref();
    // Fail fast on unknown sessions so the request is not executed for nobody.
    if let Some(id) = session_id {
        if !state.has_session(id).await {
            return StatusCode::NOT_FOUND.into_response();
        }
    }
    let contextual = apply_request_context(request, &headers);
    let reply = dispatch_request(&state, &contextual).await;
    if let Some(id) = session_id {
        return route_session_response(&state, id, reply).await;
    }
    reply.map_or_else(
        || StatusCode::ACCEPTED.into_response(),
        |body| Json(body).into_response(),
    )
}
/// Health check endpoint.
///
/// Always reports `healthy`, together with the server name and the crate
/// version baked in at compile time.
async fn health_handler() -> Json<serde_json::Value> {
    let body = serde_json::json!({
        "status": "healthy",
        "server": "openbrain-mcp",
        "version": env!("CARGO_PKG_VERSION")
    });
    Json(body)
}
/// Relative URL a legacy SSE client must POST its JSON-RPC messages to.
///
/// The session id is carried in the camelCase `sessionId` query parameter.
fn session_message_endpoint(session_id: &str) -> String {
    const PREFIX: &str = "/mcp/message?sessionId=";
    let mut url = String::with_capacity(PREFIX.len() + session_id.len());
    url.push_str(PREFIX);
    url.push_str(session_id);
    url
}
/// Deliver `response` over the legacy SSE stream of `session_id` and choose
/// the HTTP status for the originating POST.
///
/// Status mapping:
/// - `202 Accepted` when delivered, or when there is nothing to deliver.
/// - `404 Not Found` when the session is unknown.
/// - `410 Gone` when the session's stream has closed.
async fn route_session_response(
    state: &Arc<McpState>,
    session_id: &str,
    response: Option<JsonRpcResponse>,
) -> Response {
    let Some(reply) = response else {
        return StatusCode::ACCEPTED.into_response();
    };
    let status = match state.send_to_session(session_id, &reply).await {
        Ok(()) => StatusCode::ACCEPTED,
        Err(SessionSendError::NotFound) => StatusCode::NOT_FOUND,
        Err(SessionSendError::Closed) => StatusCode::GONE,
    };
    status.into_response()
}
/// Enrich `request` with caller context taken from the HTTP headers.
///
/// Currently this means injecting the header-supplied agent id into
/// `tools/call` arguments (see `inject_agent_id`); without an agent id header
/// the request is returned unchanged.
fn apply_request_context(
    mut request: JsonRpcRequest,
    headers: &HeaderMap,
) -> JsonRpcRequest {
    let header_agent = auth::get_optional_agent_id(headers);
    if let Some(agent_id) = header_agent.as_deref() {
        inject_agent_id(&mut request, agent_id);
    }
    request
}
/// Default `params.arguments.agent_id` of a `tools/call` request to `agent_id`.
///
/// Other methods are left untouched. Inside `tools/call`:
/// - an `agent_id` already present in the arguments wins;
/// - a non-object `params` or `arguments` is replaced by an empty object
///   before the id is inserted.
fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
    /// Coerce `value` into a JSON object (replacing any other value with `{}`)
    /// and return its map.
    fn ensure_object(
        value: &mut serde_json::Value,
    ) -> &mut serde_json::Map<String, serde_json::Value> {
        if !value.is_object() {
            *value = serde_json::Value::Object(serde_json::Map::new());
        }
        match value {
            serde_json::Value::Object(map) => map,
            _ => unreachable!("value was just coerced into an object"),
        }
    }

    if request.method != "tools/call" {
        return;
    }
    let params = ensure_object(&mut request.params);
    let arguments = ensure_object(
        params
            .entry("arguments")
            .or_insert(serde_json::Value::Null),
    );
    arguments
        .entry("agent_id")
        .or_insert_with(|| serde_json::Value::String(agent_id.to_owned()));
}
/// Route one JSON-RPC message to its handler.
///
/// Returns `None` when no response must be sent: for
/// `notifications/initialized`, and for any message without an `id`.
/// `tools/call` is always handed to `handle_tools_call`, which applies the same
/// rule itself. Unknown methods with an id get a `-32601` error.
async fn dispatch_request(
    state: &Arc<McpState>,
    request: &JsonRpcRequest,
) -> Option<JsonRpcResponse> {
    let method = request.method.as_str();
    if method == "tools/call" {
        return handle_tools_call(state, request).await;
    }
    if method == "notifications/initialized" {
        info!("Received MCP initialized notification");
        return None;
    }
    // Everything below only answers requests; notifications get no reply.
    let id = request.id.clone()?;
    let reply = match method {
        "initialize" => success_response(id, initialize_result()),
        "ping" => success_response(id, serde_json::json!({})),
        "tools/list" => success_response(
            id,
            serde_json::json!({ "tools": tools::get_tool_definitions() }),
        ),
        _ => error_response(
            id,
            -32601,
            format!("Method not found: {}", request.method),
            None,
        ),
    };
    Some(reply)
}
/// `result` payload for the MCP `initialize` handshake.
///
/// Advertises protocol version `2024-11-05` and server info (name plus crate
/// version). The only capability is `tools`, whose list never changes at
/// runtime.
fn initialize_result() -> serde_json::Value {
    let server_info = serde_json::json!({
        "name": "openbrain-mcp",
        "version": env!("CARGO_PKG_VERSION")
    });
    let capabilities = serde_json::json!({
        "tools": { "listChanged": false }
    });
    serde_json::json!({
        "protocolVersion": "2024-11-05",
        "serverInfo": server_info,
        "capabilities": capabilities
    })
}
/// Build a JSON-RPC 2.0 success response carrying `result` for request `id`.
fn success_response(
    id: serde_json::Value,
    result: serde_json::Value,
) -> JsonRpcResponse {
    let result = Some(result);
    JsonRpcResponse {
        jsonrpc: String::from("2.0"),
        id,
        result,
        error: None,
    }
}
/// Build a JSON-RPC 2.0 error response for request `id`.
///
/// `code`, `message` and the optional `data` become the response's error
/// object; `result` is left unset.
fn error_response(
    id: serde_json::Value,
    code: i32,
    message: String,
    data: Option<serde_json::Value>,
) -> JsonRpcResponse {
    let failure = JsonRpcError {
        code,
        message,
        data,
    };
    JsonRpcResponse {
        jsonrpc: String::from("2.0"),
        id,
        result: None,
        error: Some(failure),
    }
}
/// Handle tools/call request
async fn handle_tools_call(
state: &Arc<McpState>,
request: &JsonRpcRequest,
) -> Option<JsonRpcResponse> {
let params = &request.params;
let tool_name = params.get("name").and_then(|v| v.as_str()).unwrap_or("");
let arguments = params
.get("arguments")
.cloned()
.unwrap_or(serde_json::json!({}));
let result = match tools::execute_tool(&state.app, tool_name, arguments).await {
Ok(result) => Ok(serde_json::json!({
"content": [{
"type": "text",
"text": result
}]
})),
Err(e) => {
let full_error = format!("{:#}", e);
error!("Tool execution error: {}", full_error);
Err(JsonRpcError {
code: -32000,
message: full_error,
data: None,
})
}
};
match (request.id.clone(), result) {
(Some(id), Ok(result)) => Some(success_response(id, result)),
(Some(id), Err(error)) => Some(JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id,
result: None,
error: Some(error),
}),
(None, Ok(_)) => None,
(None, Err(error)) => {
error!(
"Tool execution failed for notification '{}': {}",
request.method, error.message
);
None
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `tools/call` request carrying `params`.
    fn tools_call(params: serde_json::Value) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(serde_json::json!("1")),
            method: "tools/call".to_string(),
            params,
        }
    }

    /// Read `params.arguments.agent_id` as a string, if present.
    fn agent_id_argument(request: &JsonRpcRequest) -> Option<&str> {
        request
            .params
            .get("arguments")
            .and_then(|value| value.get("agent_id"))
            .and_then(|value| value.as_str())
    }

    /// Build a header map with the given `Origin` and `Host` values.
    fn origin_and_host(origin: &str, host: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(ORIGIN, origin.parse().unwrap());
        headers.insert(HOST, host.parse().unwrap());
        headers
    }

    #[test]
    fn injects_agent_id_when_missing_from_tool_arguments() {
        let mut request = tools_call(serde_json::json!({
            "name": "query",
            "arguments": {
                "query": "editor preferences"
            }
        }));
        inject_agent_id(&mut request, "agent-from-header");
        assert_eq!(agent_id_argument(&request), Some("agent-from-header"));
    }

    #[test]
    fn preserves_explicit_agent_id() {
        let mut request = tools_call(serde_json::json!({
            "name": "query",
            "arguments": {
                "agent_id": "explicit-agent"
            }
        }));
        inject_agent_id(&mut request, "agent-from-header");
        assert_eq!(agent_id_argument(&request), Some("explicit-agent"));
    }

    #[test]
    fn does_not_inject_into_other_methods() {
        let mut request = tools_call(serde_json::json!({}));
        request.method = "tools/list".to_string();
        inject_agent_id(&mut request, "agent-from-header");
        assert_eq!(request.params, serde_json::json!({}));
    }

    #[test]
    fn replaces_non_object_arguments_before_injecting() {
        let mut request = tools_call(serde_json::json!({
            "name": "query",
            "arguments": "oops"
        }));
        inject_agent_id(&mut request, "agent-from-header");
        assert_eq!(
            request.params["arguments"],
            serde_json::json!({ "agent_id": "agent-from-header" })
        );
    }

    #[test]
    fn creates_params_when_absent() {
        let mut request = tools_call(serde_json::Value::Null);
        inject_agent_id(&mut request, "agent-from-header");
        assert_eq!(
            request.params,
            serde_json::json!({ "arguments": { "agent_id": "agent-from-header" } })
        );
    }

    #[test]
    fn session_endpoint_uses_camel_case_query_param() {
        assert_eq!(
            session_message_endpoint("abc123"),
            "/mcp/message?sessionId=abc123"
        );
    }

    #[test]
    fn allows_matching_origin_and_host() {
        let headers = origin_and_host("https://ob.ingwaz.work", "ob.ingwaz.work");
        assert!(validate_origin(&headers).is_ok());
    }

    #[test]
    fn rejects_mismatched_origin_and_host() {
        let headers = origin_and_host("https://evil.example", "ob.ingwaz.work");
        assert_eq!(validate_origin(&headers), Err(StatusCode::FORBIDDEN));
    }

    #[test]
    fn allows_request_without_origin() {
        assert!(validate_origin(&HeaderMap::new()).is_ok());
    }

    #[test]
    fn rejects_null_origin() {
        let headers = origin_and_host("null", "ob.ingwaz.work");
        assert_eq!(validate_origin(&headers), Err(StatusCode::FORBIDDEN));
    }

    #[test]
    fn allows_host_with_explicit_default_port() {
        let headers = origin_and_host("https://ob.ingwaz.work", "ob.ingwaz.work:443");
        assert!(validate_origin(&headers).is_ok());
    }

    #[test]
    fn rejects_origin_with_mismatched_port() {
        let headers = origin_and_host("http://localhost:8080", "localhost:3000");
        assert_eq!(validate_origin(&headers), Err(StatusCode::FORBIDDEN));
    }

    #[test]
    fn prefers_forwarded_host_over_host_header() {
        let mut headers = origin_and_host("https://public.example", "127.0.0.1:8080");
        headers.insert("x-forwarded-host", "public.example".parse().unwrap());
        assert!(validate_origin(&headers).is_ok());
    }

    #[test]
    fn error_response_omits_result_member() {
        let value = serde_json::to_value(error_response(
            serde_json::json!(7),
            -32601,
            "Method not found: x".to_string(),
            None,
        ))
        .unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "jsonrpc": "2.0",
                "id": 7,
                "error": { "code": -32601, "message": "Method not found: x" }
            })
        );
    }

    #[test]
    fn success_response_omits_error_member() {
        let value = serde_json::to_value(success_response(
            serde_json::json!("a"),
            serde_json::json!({ "ok": true }),
        ))
        .unwrap();
        assert_eq!(
            value,
            serde_json::json!({ "jsonrpc": "2.0", "id": "a", "result": { "ok": true } })
        );
    }

    #[test]
    fn initialize_advertises_protocol_version() {
        assert_eq!(initialize_result()["protocolVersion"], "2024-11-05");
    }
}