Initial public release

This commit is contained in:
Agent Zero
2026-03-07 13:41:36 -05:00
commit 774982dc5a
22 changed files with 3517 additions and 0 deletions

531
src/transport.rs Normal file
View File

@@ -0,0 +1,531 @@
//! SSE Transport for MCP Protocol
//!
//! Implements Server-Sent Events transport for the Model Context Protocol.
use axum::{
extract::{Query, State},
http::{HeaderMap, StatusCode},
response::{
IntoResponse, Response,
sse::{Event, KeepAlive, Sse},
},
routing::{get, post},
Json, Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
convert::Infallible,
sync::Arc,
time::Duration,
};
use tokio::sync::{RwLock, broadcast, mpsc};
use tracing::{error, info, warn};
use uuid::Uuid;
use crate::{AppState, auth, tools};
type SessionStore = RwLock<HashMap<String, mpsc::Sender<serde_json::Value>>>;
/// MCP Server State
pub struct McpState {
pub app: Arc<AppState>,
pub event_tx: broadcast::Sender<McpEvent>,
sessions: SessionStore,
}
impl McpState {
pub fn new(app: Arc<AppState>) -> Arc<Self> {
let (event_tx, _) = broadcast::channel(100);
Arc::new(Self {
app,
event_tx,
sessions: RwLock::new(HashMap::new()),
})
}
async fn insert_session(
&self,
session_id: String,
tx: mpsc::Sender<serde_json::Value>,
) {
self.sessions.write().await.insert(session_id, tx);
}
async fn remove_session(&self, session_id: &str) {
self.sessions.write().await.remove(session_id);
}
async fn has_session(&self, session_id: &str) -> bool {
self.sessions.read().await.contains_key(session_id)
}
async fn send_to_session(
&self,
session_id: &str,
response: &JsonRpcResponse,
) -> Result<(), SessionSendError> {
let tx = {
let sessions = self.sessions.read().await;
sessions
.get(session_id)
.cloned()
.ok_or(SessionSendError::NotFound)?
};
let payload =
serde_json::to_value(response).expect("serializing JSON-RPC response should succeed");
if tx.send(payload).await.is_err() {
self.remove_session(session_id).await;
return Err(SessionSendError::Closed);
}
Ok(())
}
}
enum SessionSendError {
NotFound,
Closed,
}
struct SessionGuard {
state: Arc<McpState>,
session_id: String,
}
impl SessionGuard {
fn new(state: Arc<McpState>, session_id: String) -> Self {
Self { state, session_id }
}
}
impl Drop for SessionGuard {
fn drop(&mut self) {
let state = self.state.clone();
let session_id = self.session_id.clone();
if let Ok(handle) = tokio::runtime::Handle::try_current() {
handle.spawn(async move {
state.remove_session(&session_id).await;
});
}
}
}
/// MCP Event for SSE streaming
#[derive(Clone, Debug, Serialize)]
pub struct McpEvent {
pub id: String,
pub event_type: String,
pub data: serde_json::Value,
}
/// MCP JSON-RPC Request
#[derive(Debug, Deserialize)]
pub struct JsonRpcRequest {
pub jsonrpc: String,
#[serde(default)]
pub id: Option<serde_json::Value>,
pub method: String,
#[serde(default)]
pub params: serde_json::Value,
}
/// MCP JSON-RPC Response
#[derive(Debug, Serialize)]
pub struct JsonRpcResponse {
pub jsonrpc: String,
pub id: serde_json::Value,
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<JsonRpcError>,
}
#[derive(Debug, Serialize)]
pub struct JsonRpcError {
pub code: i32,
pub message: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<serde_json::Value>,
}
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "camelCase")]
struct PostMessageQuery {
#[serde(default)]
session_id: Option<String>,
}
/// Create the MCP router
pub fn mcp_router(state: Arc<McpState>) -> Router {
Router::new()
.route("/sse", get(sse_handler))
.route("/message", post(message_handler))
.route("/health", get(health_handler))
.with_state(state)
}
/// SSE endpoint for streaming events
async fn sse_handler(
State(state): State<Arc<McpState>>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let mut broadcast_rx = state.event_tx.subscribe();
let (session_tx, mut session_rx) = mpsc::channel(32);
let session_id = Uuid::new_v4().to_string();
let endpoint = session_message_endpoint(&session_id);
state.insert_session(session_id.clone(), session_tx).await;
let stream = async_stream::stream! {
let _session_guard = SessionGuard::new(state.clone(), session_id.clone());
// Send endpoint event (required by MCP SSE protocol)
// This tells the client where to POST JSON-RPC messages for this session.
yield Ok(Event::default()
.event("endpoint")
.data(endpoint));
// Send initial tools list so Agent Zero knows what's available
let tools_response = JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: serde_json::json!("initial-tools"),
result: Some(serde_json::json!({
"tools": tools::get_tool_definitions()
})),
error: None,
};
yield Ok(Event::default()
.event("message")
.json_data(&tools_response)
.unwrap());
loop {
tokio::select! {
maybe_message = session_rx.recv() => {
match maybe_message {
Some(message) => {
yield Ok(Event::default()
.event("message")
.json_data(&message)
.unwrap());
}
None => break,
}
}
event = broadcast_rx.recv() => {
match event {
Ok(event) => {
yield Ok(Event::default()
.event(&event.event_type)
.id(&event.id)
.json_data(&event.data)
.unwrap());
}
Err(broadcast::error::RecvError::Lagged(n)) => {
warn!("SSE client lagged, missed {} events", n);
}
Err(broadcast::error::RecvError::Closed) => {
break;
}
}
}
}
}
};
Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))
}
/// Message endpoint for JSON-RPC requests
async fn message_handler(
State(state): State<Arc<McpState>>,
Query(query): Query<PostMessageQuery>,
headers: HeaderMap,
Json(request): Json<JsonRpcRequest>,
) -> Response {
info!("Received MCP request: {}", request.method);
if let Some(session_id) = query.session_id.as_deref() {
if !state.has_session(session_id).await {
return StatusCode::NOT_FOUND.into_response();
}
}
let request = apply_request_context(request, &headers);
let response = dispatch_request(&state, &request).await;
match query.session_id.as_deref() {
Some(session_id) => route_session_response(&state, session_id, response).await,
None => match response {
Some(response) => Json(response).into_response(),
None => StatusCode::ACCEPTED.into_response(),
},
}
}
/// Health check endpoint
async fn health_handler() -> Json<serde_json::Value> {
Json(serde_json::json!({
"status": "healthy",
"server": "openbrain-mcp",
"version": env!("CARGO_PKG_VERSION")
}))
}
fn session_message_endpoint(session_id: &str) -> String {
format!("/mcp/message?sessionId={session_id}")
}
async fn route_session_response(
state: &Arc<McpState>,
session_id: &str,
response: Option<JsonRpcResponse>,
) -> Response {
match response {
Some(response) => match state.send_to_session(session_id, &response).await {
Ok(()) => StatusCode::ACCEPTED.into_response(),
Err(SessionSendError::NotFound) => StatusCode::NOT_FOUND.into_response(),
Err(SessionSendError::Closed) => StatusCode::GONE.into_response(),
},
None => StatusCode::ACCEPTED.into_response(),
}
}
fn apply_request_context(
mut request: JsonRpcRequest,
headers: &HeaderMap,
) -> JsonRpcRequest {
if let Some(agent_id) = auth::get_optional_agent_id(headers) {
inject_agent_id(&mut request, &agent_id);
}
request
}
fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
if request.method != "tools/call" {
return;
}
if !request.params.is_object() {
request.params = serde_json::json!({});
}
let params = request
.params
.as_object_mut()
.expect("params should be an object");
let arguments = params
.entry("arguments")
.or_insert_with(|| serde_json::json!({}));
if !arguments.is_object() {
*arguments = serde_json::json!({});
}
arguments
.as_object_mut()
.expect("arguments should be an object")
.entry("agent_id".to_string())
.or_insert_with(|| serde_json::json!(agent_id));
}
async fn dispatch_request(
state: &Arc<McpState>,
request: &JsonRpcRequest,
) -> Option<JsonRpcResponse> {
match request.method.as_str() {
"initialize" => request
.id
.clone()
.map(|id| success_response(id, initialize_result())),
"ping" => request
.id
.clone()
.map(|id| success_response(id, serde_json::json!({}))),
"tools/list" => request.id.clone().map(|id| {
success_response(
id,
serde_json::json!({
"tools": tools::get_tool_definitions()
}),
)
}),
"tools/call" => handle_tools_call(state, request).await,
"notifications/initialized" => {
info!("Received MCP initialized notification");
None
}
_ => request.id.clone().map(|id| {
error_response(
id,
-32601,
format!("Method not found: {}", request.method),
None,
)
}),
}
}
fn initialize_result() -> serde_json::Value {
serde_json::json!({
"protocolVersion": "2024-11-05",
"serverInfo": {
"name": "openbrain-mcp",
"version": env!("CARGO_PKG_VERSION")
},
"capabilities": {
"tools": {
"listChanged": false
}
}
})
}
fn success_response(
id: serde_json::Value,
result: serde_json::Value,
) -> JsonRpcResponse {
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id,
result: Some(result),
error: None,
}
}
fn error_response(
id: serde_json::Value,
code: i32,
message: String,
data: Option<serde_json::Value>,
) -> JsonRpcResponse {
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id,
result: None,
error: Some(JsonRpcError {
code,
message,
data,
}),
}
}
/// Handle tools/call request
async fn handle_tools_call(
state: &Arc<McpState>,
request: &JsonRpcRequest,
) -> Option<JsonRpcResponse> {
let params = &request.params;
let tool_name = params.get("name").and_then(|v| v.as_str()).unwrap_or("");
let arguments = params
.get("arguments")
.cloned()
.unwrap_or(serde_json::json!({}));
let result = match tools::execute_tool(&state.app, tool_name, arguments).await {
Ok(result) => Ok(serde_json::json!({
"content": [{
"type": "text",
"text": result
}]
})),
Err(e) => {
let full_error = format!("{:#}", e);
error!("Tool execution error: {}", full_error);
Err(JsonRpcError {
code: -32000,
message: full_error,
data: None,
})
}
};
match (request.id.clone(), result) {
(Some(id), Ok(result)) => Some(success_response(id, result)),
(Some(id), Err(error)) => Some(JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id,
result: None,
error: Some(error),
}),
(None, Ok(_)) => None,
(None, Err(error)) => {
error!(
"Tool execution failed for notification '{}': {}",
request.method, error.message
);
None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn injects_agent_id_when_missing_from_tool_arguments() {
let mut request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(serde_json::json!("1")),
method: "tools/call".to_string(),
params: serde_json::json!({
"name": "query",
"arguments": {
"query": "editor preferences"
}
}),
};
inject_agent_id(&mut request, "agent-from-header");
assert_eq!(
request
.params
.get("arguments")
.and_then(|value| value.get("agent_id"))
.and_then(|value| value.as_str()),
Some("agent-from-header")
);
}
#[test]
fn preserves_explicit_agent_id() {
let mut request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(serde_json::json!("1")),
method: "tools/call".to_string(),
params: serde_json::json!({
"name": "query",
"arguments": {
"agent_id": "explicit-agent"
}
}),
};
inject_agent_id(&mut request, "agent-from-header");
assert_eq!(
request
.params
.get("arguments")
.and_then(|value| value.get("agent_id"))
.and_then(|value| value.as_str()),
Some("explicit-agent")
);
}
#[test]
fn session_endpoint_uses_camel_case_query_param() {
assert_eq!(
session_message_endpoint("abc123"),
"/mcp/message?sessionId=abc123"
);
}
}