mirror of
https://gitea.ingwaz.work/Ingwaz/openbrain-mcp.git
synced 2026-06-16 06:17:08 +00:00
Scope memories by API token and add shared-token e2e coverage
This commit is contained in:
175
src/transport.rs
175
src/transport.rs
@@ -5,27 +5,25 @@
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::{HeaderMap, StatusCode, Uri, header::{HOST, ORIGIN}},
|
||||
http::{
|
||||
header::{HOST, ORIGIN},
|
||||
HeaderMap, StatusCode, Uri,
|
||||
},
|
||||
response::{
|
||||
IntoResponse, Response,
|
||||
sse::{Event, KeepAlive, Sse},
|
||||
IntoResponse, Response,
|
||||
},
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use futures::stream::Stream;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
convert::Infallible,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::sync::{RwLock, broadcast, mpsc};
|
||||
use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
|
||||
use tokio::sync::{broadcast, mpsc, RwLock};
|
||||
use tracing::{error, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{AppState, auth, tools};
|
||||
use crate::{auth, tools, AppState};
|
||||
|
||||
type SessionStore = RwLock<HashMap<String, mpsc::Sender<serde_json::Value>>>;
|
||||
|
||||
@@ -46,11 +44,7 @@ impl McpState {
|
||||
})
|
||||
}
|
||||
|
||||
async fn insert_session(
|
||||
&self,
|
||||
session_id: String,
|
||||
tx: mpsc::Sender<serde_json::Value>,
|
||||
) {
|
||||
async fn insert_session(&self, session_id: String, tx: mpsc::Sender<serde_json::Value>) {
|
||||
self.sessions.write().await.insert(session_id, tx);
|
||||
}
|
||||
|
||||
@@ -163,7 +157,12 @@ struct PostMessageQuery {
|
||||
/// Create the MCP router
|
||||
pub fn mcp_router(state: Arc<McpState>) -> Router {
|
||||
Router::new()
|
||||
.route("/mcp", get(streamable_get_handler).post(streamable_post_handler).delete(streamable_delete_handler))
|
||||
.route(
|
||||
"/mcp",
|
||||
get(streamable_get_handler)
|
||||
.post(streamable_post_handler)
|
||||
.delete(streamable_delete_handler),
|
||||
)
|
||||
.route("/mcp/sse", get(sse_handler))
|
||||
.route("/mcp/message", post(message_handler))
|
||||
.route("/mcp/health", get(health_handler))
|
||||
@@ -186,7 +185,10 @@ fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
|
||||
}
|
||||
|
||||
let origin_uri = origin.parse::<Uri>().map_err(|_| {
|
||||
warn!("Rejected MCP request with invalid origin header: {}", origin);
|
||||
warn!(
|
||||
"Rejected MCP request with invalid origin header: {}",
|
||||
origin
|
||||
);
|
||||
StatusCode::FORBIDDEN
|
||||
})?;
|
||||
|
||||
@@ -203,7 +205,10 @@ fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.ok_or_else(|| {
|
||||
warn!("Rejected MCP request without host header for origin {}", origin);
|
||||
warn!(
|
||||
"Rejected MCP request without host header for origin {}",
|
||||
origin
|
||||
);
|
||||
StatusCode::FORBIDDEN
|
||||
})?;
|
||||
|
||||
@@ -247,12 +252,12 @@ async fn streamable_post_handler(
|
||||
|
||||
info!(
|
||||
method = %request.method,
|
||||
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||
client_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
||||
"Received streamable MCP request"
|
||||
);
|
||||
|
||||
let request = apply_request_context(request, &headers);
|
||||
let request = apply_request_context(request, &headers, state.app.config.auth.enabled);
|
||||
let response = dispatch_request(&state, &request).await;
|
||||
|
||||
match response {
|
||||
@@ -284,8 +289,12 @@ async fn sse_handler(
|
||||
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
|
||||
validate_origin(&headers)?;
|
||||
info!(
|
||||
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
||||
client_id = auth::get_optional_agent_id(&headers)
|
||||
.as_deref()
|
||||
.unwrap_or("unset"),
|
||||
agent_type = auth::get_optional_agent_type(&headers)
|
||||
.as_deref()
|
||||
.unwrap_or("unset"),
|
||||
"Opening legacy SSE MCP stream"
|
||||
);
|
||||
let mut broadcast_rx = state.event_tx.subscribe();
|
||||
@@ -354,7 +363,7 @@ async fn message_handler(
|
||||
|
||||
info!(
|
||||
method = %request.method,
|
||||
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||
client_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
||||
"Received legacy SSE MCP request"
|
||||
);
|
||||
@@ -365,7 +374,7 @@ async fn message_handler(
|
||||
}
|
||||
}
|
||||
|
||||
let request = apply_request_context(request, &headers);
|
||||
let request = apply_request_context(request, &headers, state.app.config.auth.enabled);
|
||||
let response = dispatch_request(&state, &request).await;
|
||||
|
||||
match query.session_id.as_deref() {
|
||||
@@ -408,17 +417,10 @@ async fn route_session_response(
|
||||
fn apply_request_context(
|
||||
mut request: JsonRpcRequest,
|
||||
headers: &HeaderMap,
|
||||
auth_enabled: bool,
|
||||
) -> JsonRpcRequest {
|
||||
if let Some(agent_id) = auth::get_optional_agent_id(headers) {
|
||||
inject_agent_id(&mut request, &agent_id);
|
||||
}
|
||||
|
||||
request
|
||||
}
|
||||
|
||||
fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
|
||||
if request.method != "tools/call" {
|
||||
return;
|
||||
return request;
|
||||
}
|
||||
|
||||
if !request.params.is_object() {
|
||||
@@ -429,6 +431,11 @@ fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
|
||||
.params
|
||||
.as_object_mut()
|
||||
.expect("params should be an object");
|
||||
let tool_name = params
|
||||
.get("name")
|
||||
.and_then(|value| value.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let arguments = params
|
||||
.entry("arguments")
|
||||
.or_insert_with(|| serde_json::json!({}));
|
||||
@@ -437,11 +444,22 @@ fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
|
||||
*arguments = serde_json::json!({});
|
||||
}
|
||||
|
||||
arguments
|
||||
let arguments = arguments
|
||||
.as_object_mut()
|
||||
.expect("arguments should be an object")
|
||||
.entry("agent_id".to_string())
|
||||
.or_insert_with(|| serde_json::json!(agent_id));
|
||||
.expect("arguments should be an object");
|
||||
arguments.insert(
|
||||
tools::INTERNAL_AUTH_SCOPE_ARG.to_string(),
|
||||
serde_json::json!(auth::get_auth_scope(headers, auth_enabled)),
|
||||
);
|
||||
|
||||
if matches!(tool_name.as_str(), "store" | "batch_store") && !arguments.contains_key("agent_id")
|
||||
{
|
||||
if let Some(agent_id) = auth::get_optional_agent_id(headers) {
|
||||
arguments.insert("agent_id".to_string(), serde_json::json!(agent_id));
|
||||
}
|
||||
}
|
||||
|
||||
request
|
||||
}
|
||||
|
||||
async fn dispatch_request(
|
||||
@@ -496,10 +514,7 @@ fn initialize_result() -> serde_json::Value {
|
||||
})
|
||||
}
|
||||
|
||||
fn success_response(
|
||||
id: serde_json::Value,
|
||||
result: serde_json::Value,
|
||||
) -> JsonRpcResponse {
|
||||
fn success_response(id: serde_json::Value, result: serde_json::Value) -> JsonRpcResponse {
|
||||
JsonRpcResponse {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id,
|
||||
@@ -578,10 +593,11 @@ async fn handle_tools_call(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::http::HeaderValue;
|
||||
|
||||
#[test]
|
||||
fn injects_agent_id_when_missing_from_tool_arguments() {
|
||||
let mut request = JsonRpcRequest {
|
||||
fn request_context_injects_auth_scope_for_tool_calls() {
|
||||
let request = JsonRpcRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: Some(serde_json::json!("1")),
|
||||
method: "tools/call".to_string(),
|
||||
@@ -592,8 +608,39 @@ mod tests {
|
||||
}
|
||||
}),
|
||||
};
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("X-API-Key", HeaderValue::from_static("test-token"));
|
||||
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
|
||||
|
||||
inject_agent_id(&mut request, "agent-from-header");
|
||||
let request = apply_request_context(request, &headers, true);
|
||||
|
||||
assert_eq!(
|
||||
request
|
||||
.params
|
||||
.get("arguments")
|
||||
.and_then(|value| value.get(tools::INTERNAL_AUTH_SCOPE_ARG))
|
||||
.and_then(|value| value.as_str()),
|
||||
Some(auth::hash_api_key("test-token").as_str())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_context_does_not_inject_query_filter_from_header() {
|
||||
let request = JsonRpcRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: Some(serde_json::json!("1")),
|
||||
method: "tools/call".to_string(),
|
||||
params: serde_json::json!({
|
||||
"name": "query",
|
||||
"arguments": {
|
||||
"query": "editor preferences"
|
||||
}
|
||||
}),
|
||||
};
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
|
||||
|
||||
let request = apply_request_context(request, &headers, false);
|
||||
|
||||
assert_eq!(
|
||||
request
|
||||
@@ -601,25 +648,55 @@ mod tests {
|
||||
.get("arguments")
|
||||
.and_then(|value| value.get("agent_id"))
|
||||
.and_then(|value| value.as_str()),
|
||||
Some("agent-from-header")
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preserves_explicit_agent_id() {
|
||||
let mut request = JsonRpcRequest {
|
||||
fn request_context_injects_store_agent_id_from_header_when_missing() {
|
||||
let request = JsonRpcRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: Some(serde_json::json!("1")),
|
||||
method: "tools/call".to_string(),
|
||||
params: serde_json::json!({
|
||||
"name": "query",
|
||||
"name": "store",
|
||||
"arguments": {
|
||||
"content": "prefers dark mode"
|
||||
}
|
||||
}),
|
||||
};
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
|
||||
|
||||
let request = apply_request_context(request, &headers, false);
|
||||
|
||||
assert_eq!(
|
||||
request
|
||||
.params
|
||||
.get("arguments")
|
||||
.and_then(|value| value.get("agent_id"))
|
||||
.and_then(|value| value.as_str()),
|
||||
Some("codex-desktop")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_context_preserves_explicit_store_agent_id() {
|
||||
let request = JsonRpcRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: Some(serde_json::json!("1")),
|
||||
method: "tools/call".to_string(),
|
||||
params: serde_json::json!({
|
||||
"name": "store",
|
||||
"arguments": {
|
||||
"agent_id": "explicit-agent"
|
||||
}
|
||||
}),
|
||||
};
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
|
||||
|
||||
inject_agent_id(&mut request, "agent-from-header");
|
||||
let request = apply_request_context(request, &headers, false);
|
||||
|
||||
assert_eq!(
|
||||
request
|
||||
|
||||
Reference in New Issue
Block a user