mirror of
https://gitea.ingwaz.work/Ingwaz/openbrain-mcp.git
synced 2026-03-31 14:49:06 +00:00
660 lines
19 KiB
Rust
660 lines
19 KiB
Rust
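//! HTTP transport for the Model Context Protocol (MCP).
//!
//! Supports both the legacy HTTP+SSE transport and the newer streamable HTTP
//! transport on the same server:
//!
//! - `/mcp` — streamable HTTP (`POST` carries JSON-RPC messages; `GET` and
//!   `DELETE` answer `405 Method Not Allowed`).
//! - `/mcp/sse` + `/mcp/message?sessionId=…` — legacy HTTP+SSE.
//! - `/mcp/health` — health check.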
|
|
|
|
use axum::{
|
|
extract::{Query, State},
|
|
http::{HeaderMap, StatusCode, Uri, header::{HOST, ORIGIN}},
|
|
response::{
|
|
IntoResponse, Response,
|
|
sse::{Event, KeepAlive, Sse},
|
|
},
|
|
routing::{get, post},
|
|
Json, Router,
|
|
};
|
|
use futures::stream::Stream;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::{
|
|
collections::HashMap,
|
|
convert::Infallible,
|
|
sync::Arc,
|
|
time::Duration,
|
|
};
|
|
use tokio::sync::{RwLock, broadcast, mpsc};
|
|
use tracing::{error, info, warn};
|
|
use uuid::Uuid;
|
|
|
|
use crate::{AppState, auth, tools};
|
|
|
|
/// Open legacy SSE sessions: session id -> sender half of that session's
/// outbound message channel (drained by the session's SSE stream).
type SessionStore = RwLock<HashMap<String, mpsc::Sender<serde_json::Value>>>;
|
|
|
|
/// MCP Server State
///
/// Shared behind an `Arc` by every MCP route handler in this module.
pub struct McpState {
    /// Application state; passed to `tools::execute_tool` for `tools/call`.
    pub app: Arc<AppState>,
    /// Broadcast sender; every open legacy SSE stream subscribes to it and
    /// forwards the events it receives to its client.
    pub event_tx: broadcast::Sender<McpEvent>,
    /// Registered legacy SSE sessions (see `SessionStore`).
    sessions: SessionStore,
}
|
|
|
|
impl McpState {
|
|
pub fn new(app: Arc<AppState>) -> Arc<Self> {
|
|
let (event_tx, _) = broadcast::channel(100);
|
|
Arc::new(Self {
|
|
app,
|
|
event_tx,
|
|
sessions: RwLock::new(HashMap::new()),
|
|
})
|
|
}
|
|
|
|
async fn insert_session(
|
|
&self,
|
|
session_id: String,
|
|
tx: mpsc::Sender<serde_json::Value>,
|
|
) {
|
|
self.sessions.write().await.insert(session_id, tx);
|
|
}
|
|
|
|
async fn remove_session(&self, session_id: &str) {
|
|
self.sessions.write().await.remove(session_id);
|
|
}
|
|
|
|
async fn has_session(&self, session_id: &str) -> bool {
|
|
self.sessions.read().await.contains_key(session_id)
|
|
}
|
|
|
|
async fn send_to_session(
|
|
&self,
|
|
session_id: &str,
|
|
response: &JsonRpcResponse,
|
|
) -> Result<(), SessionSendError> {
|
|
let tx = {
|
|
let sessions = self.sessions.read().await;
|
|
sessions
|
|
.get(session_id)
|
|
.cloned()
|
|
.ok_or(SessionSendError::NotFound)?
|
|
};
|
|
|
|
let payload =
|
|
serde_json::to_value(response).expect("serializing JSON-RPC response should succeed");
|
|
|
|
if tx.send(payload).await.is_err() {
|
|
self.remove_session(session_id).await;
|
|
return Err(SessionSendError::Closed);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
/// Why a response could not be delivered to a legacy SSE session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SessionSendError {
    /// No session with the requested id is registered.
    NotFound,
    /// The session existed but its SSE stream has gone away.
    Closed,
}
|
|
|
|
/// Removes a session from the store when dropped.
///
/// Held by the legacy SSE stream built in `sse_handler`, so the session entry
/// is cleaned up when that stream goes away. Removal is asynchronous (it is
/// spawned onto the current Tokio runtime by the `Drop` impl).
struct SessionGuard {
    /// State whose session store holds the entry.
    state: Arc<McpState>,
    /// Key of the entry to remove on drop.
    session_id: String,
}
|
|
|
|
impl SessionGuard {
|
|
fn new(state: Arc<McpState>, session_id: String) -> Self {
|
|
Self { state, session_id }
|
|
}
|
|
}
|
|
|
|
impl Drop for SessionGuard {
|
|
fn drop(&mut self) {
|
|
let state = self.state.clone();
|
|
let session_id = self.session_id.clone();
|
|
if let Ok(handle) = tokio::runtime::Handle::try_current() {
|
|
handle.spawn(async move {
|
|
state.remove_session(&session_id).await;
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
/// MCP Event for SSE streaming
///
/// Published on `McpState::event_tx`. Every open legacy SSE stream forwards it
/// as an SSE event named `event_type`, with `id` as the SSE event id and `data`
/// serialized as JSON.
#[derive(Clone, Debug, Serialize)]
pub struct McpEvent {
    /// SSE event id.
    pub id: String,
    /// SSE event name.
    pub event_type: String,
    /// Event payload, sent as JSON.
    pub data: serde_json::Value,
}
|
|
|
|
/// MCP JSON-RPC Request
///
/// One incoming JSON-RPC message. A message without an `id` is treated as a
/// notification by this module and does not get a response.
#[derive(Debug, Deserialize)]
pub struct JsonRpcRequest {
    /// Protocol version string from the client (not checked by this module).
    pub jsonrpc: String,
    /// Request id echoed back in the response. `None` when absent; note that an
    /// explicit JSON `null` also deserializes to `None`.
    #[serde(default)]
    pub id: Option<serde_json::Value>,
    /// Method name, e.g. `initialize`, `tools/list`, `tools/call`.
    pub method: String,
    /// Method parameters; `Value::Null` when omitted.
    #[serde(default)]
    pub params: serde_json::Value,
}
|
|
|
|
/// MCP JSON-RPC Response
///
/// The constructors in this module set exactly one of `result` / `error`; the
/// field left as `None` is omitted from the serialized JSON.
#[derive(Debug, Serialize)]
pub struct JsonRpcResponse {
    /// Always `"2.0"` when built by this module.
    pub jsonrpc: String,
    /// Id copied from the request being answered.
    pub id: serde_json::Value,
    /// Success payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<serde_json::Value>,
    /// Failure payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
}
|
|
|
|
/// JSON-RPC error object carried in `JsonRpcResponse::error`.
#[derive(Debug, Serialize)]
pub struct JsonRpcError {
    /// Error code; this module uses `-32601` for unknown methods and `-32000`
    /// for tool execution failures.
    pub code: i32,
    /// Human-readable description.
    pub message: String,
    /// Optional structured details; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}
|
|
|
|
/// Query string accepted by the legacy `/mcp/message` endpoint.
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "camelCase")]
struct PostMessageQuery {
    /// Target SSE session, read from `?sessionId=…`. When absent, the response
    /// is returned inline in the POST body instead of over SSE.
    #[serde(default)]
    session_id: Option<String>,
}
|
|
|
|
/// Create the MCP router
|
|
pub fn mcp_router(state: Arc<McpState>) -> Router {
|
|
Router::new()
|
|
.route("/mcp", get(streamable_get_handler).post(streamable_post_handler).delete(streamable_delete_handler))
|
|
.route("/mcp/sse", get(sse_handler))
|
|
.route("/mcp/message", post(message_handler))
|
|
.route("/mcp/health", get(health_handler))
|
|
.with_state(state)
|
|
}
|
|
|
|
fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
|
|
let Some(origin) = headers
|
|
.get(ORIGIN)
|
|
.and_then(|value| value.to_str().ok())
|
|
.map(str::trim)
|
|
.filter(|value| !value.is_empty())
|
|
else {
|
|
return Ok(());
|
|
};
|
|
|
|
if origin.eq_ignore_ascii_case("null") {
|
|
warn!("Rejected MCP request with null origin");
|
|
return Err(StatusCode::FORBIDDEN);
|
|
}
|
|
|
|
let origin_uri = origin.parse::<Uri>().map_err(|_| {
|
|
warn!("Rejected MCP request with invalid origin header: {}", origin);
|
|
StatusCode::FORBIDDEN
|
|
})?;
|
|
|
|
let origin_host = origin_uri.host().ok_or_else(|| {
|
|
warn!("Rejected MCP request without origin host: {}", origin);
|
|
StatusCode::FORBIDDEN
|
|
})?;
|
|
|
|
let request_host = headers
|
|
.get("X-Forwarded-Host")
|
|
.or_else(|| headers.get(HOST))
|
|
.and_then(|value| value.to_str().ok())
|
|
.and_then(|value| value.split(',').next())
|
|
.map(str::trim)
|
|
.filter(|value| !value.is_empty())
|
|
.ok_or_else(|| {
|
|
warn!("Rejected MCP request without host header for origin {}", origin);
|
|
StatusCode::FORBIDDEN
|
|
})?;
|
|
|
|
let origin_authority = origin_uri
|
|
.authority()
|
|
.map(|authority| authority.as_str())
|
|
.unwrap_or(origin_host);
|
|
let origin_with_default_port = origin_uri
|
|
.port_u16()
|
|
.or_else(|| match origin_uri.scheme_str() {
|
|
Some("https") => Some(443),
|
|
Some("http") => Some(80),
|
|
_ => None,
|
|
})
|
|
.map(|port| format!("{origin_host}:{port}"));
|
|
|
|
if request_host.eq_ignore_ascii_case(origin_host)
|
|
|| request_host.eq_ignore_ascii_case(origin_authority)
|
|
|| origin_with_default_port
|
|
.as_deref()
|
|
.is_some_and(|value| request_host.eq_ignore_ascii_case(value))
|
|
{
|
|
Ok(())
|
|
} else {
|
|
warn!(
|
|
"Rejected MCP request due to origin/host mismatch: origin={}, host={}",
|
|
origin, request_host
|
|
);
|
|
Err(StatusCode::FORBIDDEN)
|
|
}
|
|
}
|
|
|
|
async fn streamable_post_handler(
|
|
State(state): State<Arc<McpState>>,
|
|
headers: HeaderMap,
|
|
Json(request): Json<JsonRpcRequest>,
|
|
) -> Response {
|
|
if let Err(status) = validate_origin(&headers) {
|
|
return status.into_response();
|
|
}
|
|
|
|
info!(
|
|
method = %request.method,
|
|
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
|
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
|
"Received streamable MCP request"
|
|
);
|
|
|
|
let request = apply_request_context(request, &headers);
|
|
let response = dispatch_request(&state, &request).await;
|
|
|
|
match response {
|
|
Some(response) => Json(response).into_response(),
|
|
None => StatusCode::ACCEPTED.into_response(),
|
|
}
|
|
}
|
|
|
|
async fn streamable_get_handler(headers: HeaderMap) -> Response {
|
|
if let Err(status) = validate_origin(&headers) {
|
|
return status.into_response();
|
|
}
|
|
|
|
StatusCode::METHOD_NOT_ALLOWED.into_response()
|
|
}
|
|
|
|
async fn streamable_delete_handler(headers: HeaderMap) -> Response {
|
|
if let Err(status) = validate_origin(&headers) {
|
|
return status.into_response();
|
|
}
|
|
|
|
StatusCode::METHOD_NOT_ALLOWED.into_response()
|
|
}
|
|
|
|
/// SSE endpoint for streaming events
|
|
async fn sse_handler(
|
|
State(state): State<Arc<McpState>>,
|
|
headers: HeaderMap,
|
|
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
|
|
validate_origin(&headers)?;
|
|
info!(
|
|
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
|
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
|
"Opening legacy SSE MCP stream"
|
|
);
|
|
let mut broadcast_rx = state.event_tx.subscribe();
|
|
let (session_tx, mut session_rx) = mpsc::channel(32);
|
|
let session_id = Uuid::new_v4().to_string();
|
|
let endpoint = session_message_endpoint(&session_id);
|
|
|
|
state.insert_session(session_id.clone(), session_tx).await;
|
|
|
|
let stream = async_stream::stream! {
|
|
let _session_guard = SessionGuard::new(state.clone(), session_id.clone());
|
|
|
|
// Send endpoint event (required by MCP SSE protocol)
|
|
// This tells the client where to POST JSON-RPC messages for this session.
|
|
yield Ok(Event::default()
|
|
.event("endpoint")
|
|
.data(endpoint));
|
|
|
|
loop {
|
|
tokio::select! {
|
|
maybe_message = session_rx.recv() => {
|
|
match maybe_message {
|
|
Some(message) => {
|
|
yield Ok(Event::default()
|
|
.event("message")
|
|
.json_data(&message)
|
|
.unwrap());
|
|
}
|
|
None => break,
|
|
}
|
|
}
|
|
event = broadcast_rx.recv() => {
|
|
match event {
|
|
Ok(event) => {
|
|
yield Ok(Event::default()
|
|
.event(&event.event_type)
|
|
.id(&event.id)
|
|
.json_data(&event.data)
|
|
.unwrap());
|
|
}
|
|
Err(broadcast::error::RecvError::Lagged(n)) => {
|
|
warn!("SSE client lagged, missed {} events", n);
|
|
}
|
|
Err(broadcast::error::RecvError::Closed) => {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15))))
|
|
}
|
|
|
|
/// Message endpoint for JSON-RPC requests
|
|
async fn message_handler(
|
|
State(state): State<Arc<McpState>>,
|
|
Query(query): Query<PostMessageQuery>,
|
|
headers: HeaderMap,
|
|
Json(request): Json<JsonRpcRequest>,
|
|
) -> Response {
|
|
if let Err(status) = validate_origin(&headers) {
|
|
return status.into_response();
|
|
}
|
|
|
|
info!(
|
|
method = %request.method,
|
|
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
|
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
|
"Received legacy SSE MCP request"
|
|
);
|
|
|
|
if let Some(session_id) = query.session_id.as_deref() {
|
|
if !state.has_session(session_id).await {
|
|
return StatusCode::NOT_FOUND.into_response();
|
|
}
|
|
}
|
|
|
|
let request = apply_request_context(request, &headers);
|
|
let response = dispatch_request(&state, &request).await;
|
|
|
|
match query.session_id.as_deref() {
|
|
Some(session_id) => route_session_response(&state, session_id, response).await,
|
|
None => match response {
|
|
Some(response) => Json(response).into_response(),
|
|
None => StatusCode::ACCEPTED.into_response(),
|
|
},
|
|
}
|
|
}
|
|
|
|
/// Health check endpoint
|
|
async fn health_handler() -> Json<serde_json::Value> {
|
|
Json(serde_json::json!({
|
|
"status": "healthy",
|
|
"server": "openbrain-mcp",
|
|
"version": env!("CARGO_PKG_VERSION")
|
|
}))
|
|
}
|
|
|
|
/// Relative URL a legacy SSE client must POST its JSON-RPC messages to.
///
/// The id is appended verbatim (not percent-encoded); callers pass UUID strings.
fn session_message_endpoint(session_id: &str) -> String {
    const PREFIX: &str = "/mcp/message?sessionId=";
    let mut endpoint = String::with_capacity(PREFIX.len() + session_id.len());
    endpoint.push_str(PREFIX);
    endpoint.push_str(session_id);
    endpoint
}
|
|
|
|
async fn route_session_response(
|
|
state: &Arc<McpState>,
|
|
session_id: &str,
|
|
response: Option<JsonRpcResponse>,
|
|
) -> Response {
|
|
match response {
|
|
Some(response) => match state.send_to_session(session_id, &response).await {
|
|
Ok(()) => StatusCode::ACCEPTED.into_response(),
|
|
Err(SessionSendError::NotFound) => StatusCode::NOT_FOUND.into_response(),
|
|
Err(SessionSendError::Closed) => StatusCode::GONE.into_response(),
|
|
},
|
|
None => StatusCode::ACCEPTED.into_response(),
|
|
}
|
|
}
|
|
|
|
fn apply_request_context(
|
|
mut request: JsonRpcRequest,
|
|
headers: &HeaderMap,
|
|
) -> JsonRpcRequest {
|
|
if let Some(agent_id) = auth::get_optional_agent_id(headers) {
|
|
inject_agent_id(&mut request, &agent_id);
|
|
}
|
|
|
|
request
|
|
}
|
|
|
|
fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
|
|
if request.method != "tools/call" {
|
|
return;
|
|
}
|
|
|
|
if !request.params.is_object() {
|
|
request.params = serde_json::json!({});
|
|
}
|
|
|
|
let params = request
|
|
.params
|
|
.as_object_mut()
|
|
.expect("params should be an object");
|
|
let arguments = params
|
|
.entry("arguments")
|
|
.or_insert_with(|| serde_json::json!({}));
|
|
|
|
if !arguments.is_object() {
|
|
*arguments = serde_json::json!({});
|
|
}
|
|
|
|
arguments
|
|
.as_object_mut()
|
|
.expect("arguments should be an object")
|
|
.entry("agent_id".to_string())
|
|
.or_insert_with(|| serde_json::json!(agent_id));
|
|
}
|
|
|
|
async fn dispatch_request(
|
|
state: &Arc<McpState>,
|
|
request: &JsonRpcRequest,
|
|
) -> Option<JsonRpcResponse> {
|
|
match request.method.as_str() {
|
|
"initialize" => request
|
|
.id
|
|
.clone()
|
|
.map(|id| success_response(id, initialize_result())),
|
|
"ping" => request
|
|
.id
|
|
.clone()
|
|
.map(|id| success_response(id, serde_json::json!({}))),
|
|
"tools/list" => request.id.clone().map(|id| {
|
|
success_response(
|
|
id,
|
|
serde_json::json!({
|
|
"tools": tools::get_tool_definitions()
|
|
}),
|
|
)
|
|
}),
|
|
"tools/call" => handle_tools_call(state, request).await,
|
|
"notifications/initialized" => {
|
|
info!("Received MCP initialized notification");
|
|
None
|
|
}
|
|
_ => request.id.clone().map(|id| {
|
|
error_response(
|
|
id,
|
|
-32601,
|
|
format!("Method not found: {}", request.method),
|
|
None,
|
|
)
|
|
}),
|
|
}
|
|
}
|
|
|
|
/// Result payload for `initialize`.
///
/// Advertises protocol version `2024-11-05`, the server name and crate version,
/// and a `tools` capability whose list never changes (`listChanged: false`).
fn initialize_result() -> serde_json::Value {
    let server_info = serde_json::json!({
        "name": "openbrain-mcp",
        "version": env!("CARGO_PKG_VERSION"),
    });
    let capabilities = serde_json::json!({
        "tools": {
            "listChanged": false,
        },
    });

    serde_json::json!({
        "protocolVersion": "2024-11-05",
        "serverInfo": server_info,
        "capabilities": capabilities,
    })
}
|
|
|
|
fn success_response(
|
|
id: serde_json::Value,
|
|
result: serde_json::Value,
|
|
) -> JsonRpcResponse {
|
|
JsonRpcResponse {
|
|
jsonrpc: "2.0".to_string(),
|
|
id,
|
|
result: Some(result),
|
|
error: None,
|
|
}
|
|
}
|
|
|
|
fn error_response(
|
|
id: serde_json::Value,
|
|
code: i32,
|
|
message: String,
|
|
data: Option<serde_json::Value>,
|
|
) -> JsonRpcResponse {
|
|
JsonRpcResponse {
|
|
jsonrpc: "2.0".to_string(),
|
|
id,
|
|
result: None,
|
|
error: Some(JsonRpcError {
|
|
code,
|
|
message,
|
|
data,
|
|
}),
|
|
}
|
|
}
|
|
|
|
/// Handle tools/call request
|
|
async fn handle_tools_call(
|
|
state: &Arc<McpState>,
|
|
request: &JsonRpcRequest,
|
|
) -> Option<JsonRpcResponse> {
|
|
let params = &request.params;
|
|
let tool_name = params.get("name").and_then(|v| v.as_str()).unwrap_or("");
|
|
let arguments = params
|
|
.get("arguments")
|
|
.cloned()
|
|
.unwrap_or(serde_json::json!({}));
|
|
|
|
let result = match tools::execute_tool(&state.app, tool_name, arguments).await {
|
|
Ok(result) => Ok(serde_json::json!({
|
|
"content": [{
|
|
"type": "text",
|
|
"text": result
|
|
}]
|
|
})),
|
|
Err(e) => {
|
|
let full_error = format!("{:#}", e);
|
|
error!("Tool execution error: {}", full_error);
|
|
Err(JsonRpcError {
|
|
code: -32000,
|
|
message: full_error,
|
|
data: None,
|
|
})
|
|
}
|
|
};
|
|
|
|
match (request.id.clone(), result) {
|
|
(Some(id), Ok(result)) => Some(success_response(id, result)),
|
|
(Some(id), Err(error)) => Some(JsonRpcResponse {
|
|
jsonrpc: "2.0".to_string(),
|
|
id,
|
|
result: None,
|
|
error: Some(error),
|
|
}),
|
|
(None, Ok(_)) => None,
|
|
(None, Err(error)) => {
|
|
error!(
|
|
"Tool execution failed for notification '{}': {}",
|
|
request.method, error.message
|
|
);
|
|
None
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a JSON-RPC request with a fixed id for the given method and params.
    fn make_request(method: &str, params: serde_json::Value) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(serde_json::json!("1")),
            method: method.to_string(),
            params,
        }
    }

    /// Builds a header map from lowercase `(name, value)` pairs.
    fn header_map(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, value.parse().unwrap());
        }
        map
    }

    /// Reads `params.arguments.agent_id` as a string, if present.
    fn injected_agent_id(request: &JsonRpcRequest) -> Option<&str> {
        request
            .params
            .get("arguments")
            .and_then(|value| value.get("agent_id"))
            .and_then(|value| value.as_str())
    }

    #[test]
    fn injects_agent_id_when_missing_from_tool_arguments() {
        let mut request = make_request(
            "tools/call",
            serde_json::json!({
                "name": "query",
                "arguments": {
                    "query": "editor preferences"
                }
            }),
        );

        inject_agent_id(&mut request, "agent-from-header");

        assert_eq!(injected_agent_id(&request), Some("agent-from-header"));
        // Existing arguments are kept alongside the injected id.
        assert_eq!(
            request.params["arguments"]["query"],
            serde_json::json!("editor preferences")
        );
    }

    #[test]
    fn preserves_explicit_agent_id() {
        let mut request = make_request(
            "tools/call",
            serde_json::json!({
                "name": "query",
                "arguments": {
                    "agent_id": "explicit-agent"
                }
            }),
        );

        inject_agent_id(&mut request, "agent-from-header");

        assert_eq!(injected_agent_id(&request), Some("explicit-agent"));
    }

    #[test]
    fn replaces_non_object_arguments_before_injecting() {
        let mut request = make_request(
            "tools/call",
            serde_json::json!({
                "name": "query",
                "arguments": "not-an-object"
            }),
        );

        inject_agent_id(&mut request, "agent-from-header");

        assert_eq!(
            request.params["arguments"],
            serde_json::json!({ "agent_id": "agent-from-header" })
        );
    }

    #[test]
    fn creates_params_for_tool_call_without_params() {
        let mut request = make_request("tools/call", serde_json::Value::Null);

        inject_agent_id(&mut request, "agent-from-header");

        assert_eq!(
            request.params,
            serde_json::json!({ "arguments": { "agent_id": "agent-from-header" } })
        );
    }

    #[test]
    fn leaves_non_tool_calls_untouched() {
        let mut request = make_request("tools/list", serde_json::Value::Null);

        inject_agent_id(&mut request, "agent-from-header");

        assert_eq!(request.params, serde_json::Value::Null);
    }

    #[test]
    fn session_endpoint_uses_camel_case_query_param() {
        assert_eq!(
            session_message_endpoint("abc123"),
            "/mcp/message?sessionId=abc123"
        );
    }

    #[test]
    fn allows_matching_origin_and_host() {
        let headers = header_map(&[
            ("origin", "https://ob.ingwaz.work"),
            ("host", "ob.ingwaz.work"),
        ]);

        assert!(validate_origin(&headers).is_ok());
    }

    #[test]
    fn rejects_mismatched_origin_and_host() {
        let headers = header_map(&[
            ("origin", "https://evil.example"),
            ("host", "ob.ingwaz.work"),
        ]);

        assert_eq!(validate_origin(&headers), Err(StatusCode::FORBIDDEN));
    }

    #[test]
    fn allows_requests_without_origin() {
        let headers = header_map(&[("host", "ob.ingwaz.work")]);

        assert!(validate_origin(&headers).is_ok());
    }

    #[test]
    fn rejects_null_origin() {
        let headers = header_map(&[("origin", "null"), ("host", "ob.ingwaz.work")]);

        assert_eq!(validate_origin(&headers), Err(StatusCode::FORBIDDEN));
    }

    #[test]
    fn allows_host_with_explicit_default_port() {
        let headers = header_map(&[
            ("origin", "https://ob.ingwaz.work"),
            ("host", "ob.ingwaz.work:443"),
        ]);

        assert!(validate_origin(&headers).is_ok());
    }

    #[test]
    fn rejects_host_port_not_matching_origin_scheme_default() {
        let headers = header_map(&[
            ("origin", "http://ob.ingwaz.work"),
            ("host", "ob.ingwaz.work:443"),
        ]);

        assert_eq!(validate_origin(&headers), Err(StatusCode::FORBIDDEN));
    }

    #[test]
    fn prefers_forwarded_host_over_host_header() {
        let headers = header_map(&[
            ("origin", "https://ob.ingwaz.work"),
            ("host", "127.0.0.1:8080"),
            ("x-forwarded-host", "ob.ingwaz.work"),
        ]);

        assert!(validate_origin(&headers).is_ok());
    }

    #[test]
    fn rejects_origin_without_any_host_header() {
        let headers = header_map(&[("origin", "https://ob.ingwaz.work")]);

        assert_eq!(validate_origin(&headers), Err(StatusCode::FORBIDDEN));
    }
}
|