mirror of
https://gitea.ingwaz.work/Ingwaz/openbrain-mcp.git
synced 2026-03-31 14:49:06 +00:00
Initial public release
This commit is contained in:
531
src/transport.rs
Normal file
531
src/transport.rs
Normal file
@@ -0,0 +1,531 @@
|
||||
//! SSE Transport for MCP Protocol
|
||||
//!
|
||||
//! Implements Server-Sent Events transport for the Model Context Protocol.
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{
|
||||
IntoResponse, Response,
|
||||
sse::{Event, KeepAlive, Sse},
|
||||
},
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use futures::stream::Stream;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
convert::Infallible,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::sync::{RwLock, broadcast, mpsc};
|
||||
use tracing::{error, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{AppState, auth, tools};
|
||||
|
||||
type SessionStore = RwLock<HashMap<String, mpsc::Sender<serde_json::Value>>>;
|
||||
|
||||
/// MCP Server State
|
||||
pub struct McpState {
|
||||
pub app: Arc<AppState>,
|
||||
pub event_tx: broadcast::Sender<McpEvent>,
|
||||
sessions: SessionStore,
|
||||
}
|
||||
|
||||
impl McpState {
|
||||
pub fn new(app: Arc<AppState>) -> Arc<Self> {
|
||||
let (event_tx, _) = broadcast::channel(100);
|
||||
Arc::new(Self {
|
||||
app,
|
||||
event_tx,
|
||||
sessions: RwLock::new(HashMap::new()),
|
||||
})
|
||||
}
|
||||
|
||||
async fn insert_session(
|
||||
&self,
|
||||
session_id: String,
|
||||
tx: mpsc::Sender<serde_json::Value>,
|
||||
) {
|
||||
self.sessions.write().await.insert(session_id, tx);
|
||||
}
|
||||
|
||||
async fn remove_session(&self, session_id: &str) {
|
||||
self.sessions.write().await.remove(session_id);
|
||||
}
|
||||
|
||||
async fn has_session(&self, session_id: &str) -> bool {
|
||||
self.sessions.read().await.contains_key(session_id)
|
||||
}
|
||||
|
||||
async fn send_to_session(
|
||||
&self,
|
||||
session_id: &str,
|
||||
response: &JsonRpcResponse,
|
||||
) -> Result<(), SessionSendError> {
|
||||
let tx = {
|
||||
let sessions = self.sessions.read().await;
|
||||
sessions
|
||||
.get(session_id)
|
||||
.cloned()
|
||||
.ok_or(SessionSendError::NotFound)?
|
||||
};
|
||||
|
||||
let payload =
|
||||
serde_json::to_value(response).expect("serializing JSON-RPC response should succeed");
|
||||
|
||||
if tx.send(payload).await.is_err() {
|
||||
self.remove_session(session_id).await;
|
||||
return Err(SessionSendError::Closed);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
enum SessionSendError {
|
||||
NotFound,
|
||||
Closed,
|
||||
}
|
||||
|
||||
struct SessionGuard {
|
||||
state: Arc<McpState>,
|
||||
session_id: String,
|
||||
}
|
||||
|
||||
impl SessionGuard {
|
||||
fn new(state: Arc<McpState>, session_id: String) -> Self {
|
||||
Self { state, session_id }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for SessionGuard {
|
||||
fn drop(&mut self) {
|
||||
let state = self.state.clone();
|
||||
let session_id = self.session_id.clone();
|
||||
if let Ok(handle) = tokio::runtime::Handle::try_current() {
|
||||
handle.spawn(async move {
|
||||
state.remove_session(&session_id).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// MCP Event for SSE streaming
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct McpEvent {
|
||||
pub id: String,
|
||||
pub event_type: String,
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
|
||||
/// MCP JSON-RPC Request
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct JsonRpcRequest {
|
||||
pub jsonrpc: String,
|
||||
#[serde(default)]
|
||||
pub id: Option<serde_json::Value>,
|
||||
pub method: String,
|
||||
#[serde(default)]
|
||||
pub params: serde_json::Value,
|
||||
}
|
||||
|
||||
/// MCP JSON-RPC Response
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct JsonRpcResponse {
|
||||
pub jsonrpc: String,
|
||||
pub id: serde_json::Value,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<JsonRpcError>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct JsonRpcError {
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct PostMessageQuery {
|
||||
#[serde(default)]
|
||||
session_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Create the MCP router
|
||||
pub fn mcp_router(state: Arc<McpState>) -> Router {
|
||||
Router::new()
|
||||
.route("/sse", get(sse_handler))
|
||||
.route("/message", post(message_handler))
|
||||
.route("/health", get(health_handler))
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
/// SSE endpoint for streaming events
|
||||
async fn sse_handler(
|
||||
State(state): State<Arc<McpState>>,
|
||||
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
|
||||
let mut broadcast_rx = state.event_tx.subscribe();
|
||||
let (session_tx, mut session_rx) = mpsc::channel(32);
|
||||
let session_id = Uuid::new_v4().to_string();
|
||||
let endpoint = session_message_endpoint(&session_id);
|
||||
|
||||
state.insert_session(session_id.clone(), session_tx).await;
|
||||
|
||||
let stream = async_stream::stream! {
|
||||
let _session_guard = SessionGuard::new(state.clone(), session_id.clone());
|
||||
|
||||
// Send endpoint event (required by MCP SSE protocol)
|
||||
// This tells the client where to POST JSON-RPC messages for this session.
|
||||
yield Ok(Event::default()
|
||||
.event("endpoint")
|
||||
.data(endpoint));
|
||||
|
||||
// Send initial tools list so Agent Zero knows what's available
|
||||
let tools_response = JsonRpcResponse {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: serde_json::json!("initial-tools"),
|
||||
result: Some(serde_json::json!({
|
||||
"tools": tools::get_tool_definitions()
|
||||
})),
|
||||
error: None,
|
||||
};
|
||||
|
||||
yield Ok(Event::default()
|
||||
.event("message")
|
||||
.json_data(&tools_response)
|
||||
.unwrap());
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
maybe_message = session_rx.recv() => {
|
||||
match maybe_message {
|
||||
Some(message) => {
|
||||
yield Ok(Event::default()
|
||||
.event("message")
|
||||
.json_data(&message)
|
||||
.unwrap());
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
event = broadcast_rx.recv() => {
|
||||
match event {
|
||||
Ok(event) => {
|
||||
yield Ok(Event::default()
|
||||
.event(&event.event_type)
|
||||
.id(&event.id)
|
||||
.json_data(&event.data)
|
||||
.unwrap());
|
||||
}
|
||||
Err(broadcast::error::RecvError::Lagged(n)) => {
|
||||
warn!("SSE client lagged, missed {} events", n);
|
||||
}
|
||||
Err(broadcast::error::RecvError::Closed) => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))
|
||||
}
|
||||
|
||||
/// Message endpoint for JSON-RPC requests
|
||||
async fn message_handler(
|
||||
State(state): State<Arc<McpState>>,
|
||||
Query(query): Query<PostMessageQuery>,
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<JsonRpcRequest>,
|
||||
) -> Response {
|
||||
info!("Received MCP request: {}", request.method);
|
||||
|
||||
if let Some(session_id) = query.session_id.as_deref() {
|
||||
if !state.has_session(session_id).await {
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
}
|
||||
}
|
||||
|
||||
let request = apply_request_context(request, &headers);
|
||||
let response = dispatch_request(&state, &request).await;
|
||||
|
||||
match query.session_id.as_deref() {
|
||||
Some(session_id) => route_session_response(&state, session_id, response).await,
|
||||
None => match response {
|
||||
Some(response) => Json(response).into_response(),
|
||||
None => StatusCode::ACCEPTED.into_response(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Health check endpoint
|
||||
async fn health_handler() -> Json<serde_json::Value> {
|
||||
Json(serde_json::json!({
|
||||
"status": "healthy",
|
||||
"server": "openbrain-mcp",
|
||||
"version": env!("CARGO_PKG_VERSION")
|
||||
}))
|
||||
}
|
||||
|
||||
fn session_message_endpoint(session_id: &str) -> String {
|
||||
format!("/mcp/message?sessionId={session_id}")
|
||||
}
|
||||
|
||||
async fn route_session_response(
|
||||
state: &Arc<McpState>,
|
||||
session_id: &str,
|
||||
response: Option<JsonRpcResponse>,
|
||||
) -> Response {
|
||||
match response {
|
||||
Some(response) => match state.send_to_session(session_id, &response).await {
|
||||
Ok(()) => StatusCode::ACCEPTED.into_response(),
|
||||
Err(SessionSendError::NotFound) => StatusCode::NOT_FOUND.into_response(),
|
||||
Err(SessionSendError::Closed) => StatusCode::GONE.into_response(),
|
||||
},
|
||||
None => StatusCode::ACCEPTED.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_request_context(
|
||||
mut request: JsonRpcRequest,
|
||||
headers: &HeaderMap,
|
||||
) -> JsonRpcRequest {
|
||||
if let Some(agent_id) = auth::get_optional_agent_id(headers) {
|
||||
inject_agent_id(&mut request, &agent_id);
|
||||
}
|
||||
|
||||
request
|
||||
}
|
||||
|
||||
fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
|
||||
if request.method != "tools/call" {
|
||||
return;
|
||||
}
|
||||
|
||||
if !request.params.is_object() {
|
||||
request.params = serde_json::json!({});
|
||||
}
|
||||
|
||||
let params = request
|
||||
.params
|
||||
.as_object_mut()
|
||||
.expect("params should be an object");
|
||||
let arguments = params
|
||||
.entry("arguments")
|
||||
.or_insert_with(|| serde_json::json!({}));
|
||||
|
||||
if !arguments.is_object() {
|
||||
*arguments = serde_json::json!({});
|
||||
}
|
||||
|
||||
arguments
|
||||
.as_object_mut()
|
||||
.expect("arguments should be an object")
|
||||
.entry("agent_id".to_string())
|
||||
.or_insert_with(|| serde_json::json!(agent_id));
|
||||
}
|
||||
|
||||
async fn dispatch_request(
|
||||
state: &Arc<McpState>,
|
||||
request: &JsonRpcRequest,
|
||||
) -> Option<JsonRpcResponse> {
|
||||
match request.method.as_str() {
|
||||
"initialize" => request
|
||||
.id
|
||||
.clone()
|
||||
.map(|id| success_response(id, initialize_result())),
|
||||
"ping" => request
|
||||
.id
|
||||
.clone()
|
||||
.map(|id| success_response(id, serde_json::json!({}))),
|
||||
"tools/list" => request.id.clone().map(|id| {
|
||||
success_response(
|
||||
id,
|
||||
serde_json::json!({
|
||||
"tools": tools::get_tool_definitions()
|
||||
}),
|
||||
)
|
||||
}),
|
||||
"tools/call" => handle_tools_call(state, request).await,
|
||||
"notifications/initialized" => {
|
||||
info!("Received MCP initialized notification");
|
||||
None
|
||||
}
|
||||
_ => request.id.clone().map(|id| {
|
||||
error_response(
|
||||
id,
|
||||
-32601,
|
||||
format!("Method not found: {}", request.method),
|
||||
None,
|
||||
)
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn initialize_result() -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"protocolVersion": "2024-11-05",
|
||||
"serverInfo": {
|
||||
"name": "openbrain-mcp",
|
||||
"version": env!("CARGO_PKG_VERSION")
|
||||
},
|
||||
"capabilities": {
|
||||
"tools": {
|
||||
"listChanged": false
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn success_response(
|
||||
id: serde_json::Value,
|
||||
result: serde_json::Value,
|
||||
) -> JsonRpcResponse {
|
||||
JsonRpcResponse {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id,
|
||||
result: Some(result),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn error_response(
|
||||
id: serde_json::Value,
|
||||
code: i32,
|
||||
message: String,
|
||||
data: Option<serde_json::Value>,
|
||||
) -> JsonRpcResponse {
|
||||
JsonRpcResponse {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id,
|
||||
result: None,
|
||||
error: Some(JsonRpcError {
|
||||
code,
|
||||
message,
|
||||
data,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle tools/call request
|
||||
async fn handle_tools_call(
|
||||
state: &Arc<McpState>,
|
||||
request: &JsonRpcRequest,
|
||||
) -> Option<JsonRpcResponse> {
|
||||
let params = &request.params;
|
||||
let tool_name = params.get("name").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let arguments = params
|
||||
.get("arguments")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::json!({}));
|
||||
|
||||
let result = match tools::execute_tool(&state.app, tool_name, arguments).await {
|
||||
Ok(result) => Ok(serde_json::json!({
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": result
|
||||
}]
|
||||
})),
|
||||
Err(e) => {
|
||||
let full_error = format!("{:#}", e);
|
||||
error!("Tool execution error: {}", full_error);
|
||||
Err(JsonRpcError {
|
||||
code: -32000,
|
||||
message: full_error,
|
||||
data: None,
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
match (request.id.clone(), result) {
|
||||
(Some(id), Ok(result)) => Some(success_response(id, result)),
|
||||
(Some(id), Err(error)) => Some(JsonRpcResponse {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id,
|
||||
result: None,
|
||||
error: Some(error),
|
||||
}),
|
||||
(None, Ok(_)) => None,
|
||||
(None, Err(error)) => {
|
||||
error!(
|
||||
"Tool execution failed for notification '{}': {}",
|
||||
request.method, error.message
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn injects_agent_id_when_missing_from_tool_arguments() {
|
||||
let mut request = JsonRpcRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: Some(serde_json::json!("1")),
|
||||
method: "tools/call".to_string(),
|
||||
params: serde_json::json!({
|
||||
"name": "query",
|
||||
"arguments": {
|
||||
"query": "editor preferences"
|
||||
}
|
||||
}),
|
||||
};
|
||||
|
||||
inject_agent_id(&mut request, "agent-from-header");
|
||||
|
||||
assert_eq!(
|
||||
request
|
||||
.params
|
||||
.get("arguments")
|
||||
.and_then(|value| value.get("agent_id"))
|
||||
.and_then(|value| value.as_str()),
|
||||
Some("agent-from-header")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preserves_explicit_agent_id() {
|
||||
let mut request = JsonRpcRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: Some(serde_json::json!("1")),
|
||||
method: "tools/call".to_string(),
|
||||
params: serde_json::json!({
|
||||
"name": "query",
|
||||
"arguments": {
|
||||
"agent_id": "explicit-agent"
|
||||
}
|
||||
}),
|
||||
};
|
||||
|
||||
inject_agent_id(&mut request, "agent-from-header");
|
||||
|
||||
assert_eq!(
|
||||
request
|
||||
.params
|
||||
.get("arguments")
|
||||
.and_then(|value| value.get("agent_id"))
|
||||
.and_then(|value| value.as_str()),
|
||||
Some("explicit-agent")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_endpoint_uses_camel_case_query_param() {
|
||||
assert_eq!(
|
||||
session_message_endpoint("abc123"),
|
||||
"/mcp/message?sessionId=abc123"
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user