Scope memories by API token and add shared-token e2e coverage

This commit is contained in:
Agent Zero
2026-04-01 23:30:58 -04:00
parent 98baa27c90
commit 026ae27366
17 changed files with 1096 additions and 428 deletions

View File

@@ -5,27 +5,25 @@
use axum::{
extract::{Query, State},
http::{HeaderMap, StatusCode, Uri, header::{HOST, ORIGIN}},
http::{
header::{HOST, ORIGIN},
HeaderMap, StatusCode, Uri,
},
response::{
IntoResponse, Response,
sse::{Event, KeepAlive, Sse},
IntoResponse, Response,
},
routing::{get, post},
Json, Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
convert::Infallible,
sync::Arc,
time::Duration,
};
use tokio::sync::{RwLock, broadcast, mpsc};
use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
use tokio::sync::{broadcast, mpsc, RwLock};
use tracing::{error, info, warn};
use uuid::Uuid;
use crate::{AppState, auth, tools};
use crate::{auth, tools, AppState};
type SessionStore = RwLock<HashMap<String, mpsc::Sender<serde_json::Value>>>;
@@ -46,11 +44,7 @@ impl McpState {
})
}
async fn insert_session(
&self,
session_id: String,
tx: mpsc::Sender<serde_json::Value>,
) {
async fn insert_session(&self, session_id: String, tx: mpsc::Sender<serde_json::Value>) {
self.sessions.write().await.insert(session_id, tx);
}
@@ -163,7 +157,12 @@ struct PostMessageQuery {
/// Create the MCP router
pub fn mcp_router(state: Arc<McpState>) -> Router {
Router::new()
.route("/mcp", get(streamable_get_handler).post(streamable_post_handler).delete(streamable_delete_handler))
.route(
"/mcp",
get(streamable_get_handler)
.post(streamable_post_handler)
.delete(streamable_delete_handler),
)
.route("/mcp/sse", get(sse_handler))
.route("/mcp/message", post(message_handler))
.route("/mcp/health", get(health_handler))
@@ -186,7 +185,10 @@ fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
}
let origin_uri = origin.parse::<Uri>().map_err(|_| {
warn!("Rejected MCP request with invalid origin header: {}", origin);
warn!(
"Rejected MCP request with invalid origin header: {}",
origin
);
StatusCode::FORBIDDEN
})?;
@@ -203,7 +205,10 @@ fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
.map(str::trim)
.filter(|value| !value.is_empty())
.ok_or_else(|| {
warn!("Rejected MCP request without host header for origin {}", origin);
warn!(
"Rejected MCP request without host header for origin {}",
origin
);
StatusCode::FORBIDDEN
})?;
@@ -247,12 +252,12 @@ async fn streamable_post_handler(
info!(
method = %request.method,
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
client_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
"Received streamable MCP request"
);
let request = apply_request_context(request, &headers);
let request = apply_request_context(request, &headers, state.app.config.auth.enabled);
let response = dispatch_request(&state, &request).await;
match response {
@@ -284,8 +289,12 @@ async fn sse_handler(
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
validate_origin(&headers)?;
info!(
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
client_id = auth::get_optional_agent_id(&headers)
.as_deref()
.unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers)
.as_deref()
.unwrap_or("unset"),
"Opening legacy SSE MCP stream"
);
let mut broadcast_rx = state.event_tx.subscribe();
@@ -354,7 +363,7 @@ async fn message_handler(
info!(
method = %request.method,
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
client_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
"Received legacy SSE MCP request"
);
@@ -365,7 +374,7 @@ async fn message_handler(
}
}
let request = apply_request_context(request, &headers);
let request = apply_request_context(request, &headers, state.app.config.auth.enabled);
let response = dispatch_request(&state, &request).await;
match query.session_id.as_deref() {
@@ -408,17 +417,10 @@ async fn route_session_response(
fn apply_request_context(
mut request: JsonRpcRequest,
headers: &HeaderMap,
auth_enabled: bool,
) -> JsonRpcRequest {
if let Some(agent_id) = auth::get_optional_agent_id(headers) {
inject_agent_id(&mut request, &agent_id);
}
request
}
fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
if request.method != "tools/call" {
return;
return request;
}
if !request.params.is_object() {
@@ -429,6 +431,11 @@ fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
.params
.as_object_mut()
.expect("params should be an object");
let tool_name = params
.get("name")
.and_then(|value| value.as_str())
.unwrap_or("")
.to_string();
let arguments = params
.entry("arguments")
.or_insert_with(|| serde_json::json!({}));
@@ -437,11 +444,22 @@ fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
*arguments = serde_json::json!({});
}
arguments
let arguments = arguments
.as_object_mut()
.expect("arguments should be an object")
.entry("agent_id".to_string())
.or_insert_with(|| serde_json::json!(agent_id));
.expect("arguments should be an object");
arguments.insert(
tools::INTERNAL_AUTH_SCOPE_ARG.to_string(),
serde_json::json!(auth::get_auth_scope(headers, auth_enabled)),
);
if matches!(tool_name.as_str(), "store" | "batch_store") && !arguments.contains_key("agent_id")
{
if let Some(agent_id) = auth::get_optional_agent_id(headers) {
arguments.insert("agent_id".to_string(), serde_json::json!(agent_id));
}
}
request
}
async fn dispatch_request(
@@ -496,10 +514,7 @@ fn initialize_result() -> serde_json::Value {
})
}
fn success_response(
id: serde_json::Value,
result: serde_json::Value,
) -> JsonRpcResponse {
fn success_response(id: serde_json::Value, result: serde_json::Value) -> JsonRpcResponse {
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id,
@@ -578,10 +593,11 @@ async fn handle_tools_call(
#[cfg(test)]
mod tests {
use super::*;
use axum::http::HeaderValue;
#[test]
fn injects_agent_id_when_missing_from_tool_arguments() {
let mut request = JsonRpcRequest {
fn request_context_injects_auth_scope_for_tool_calls() {
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(serde_json::json!("1")),
method: "tools/call".to_string(),
@@ -592,8 +608,39 @@ mod tests {
}
}),
};
let mut headers = HeaderMap::new();
headers.insert("X-API-Key", HeaderValue::from_static("test-token"));
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
inject_agent_id(&mut request, "agent-from-header");
let request = apply_request_context(request, &headers, true);
assert_eq!(
request
.params
.get("arguments")
.and_then(|value| value.get(tools::INTERNAL_AUTH_SCOPE_ARG))
.and_then(|value| value.as_str()),
Some(auth::hash_api_key("test-token").as_str())
);
}
#[test]
fn request_context_does_not_inject_query_filter_from_header() {
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(serde_json::json!("1")),
method: "tools/call".to_string(),
params: serde_json::json!({
"name": "query",
"arguments": {
"query": "editor preferences"
}
}),
};
let mut headers = HeaderMap::new();
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
let request = apply_request_context(request, &headers, false);
assert_eq!(
request
@@ -601,25 +648,55 @@ mod tests {
.get("arguments")
.and_then(|value| value.get("agent_id"))
.and_then(|value| value.as_str()),
Some("agent-from-header")
None
);
}
#[test]
fn preserves_explicit_agent_id() {
let mut request = JsonRpcRequest {
fn request_context_injects_store_agent_id_from_header_when_missing() {
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(serde_json::json!("1")),
method: "tools/call".to_string(),
params: serde_json::json!({
"name": "query",
"name": "store",
"arguments": {
"content": "prefers dark mode"
}
}),
};
let mut headers = HeaderMap::new();
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
let request = apply_request_context(request, &headers, false);
assert_eq!(
request
.params
.get("arguments")
.and_then(|value| value.get("agent_id"))
.and_then(|value| value.as_str()),
Some("codex-desktop")
);
}
#[test]
fn request_context_preserves_explicit_store_agent_id() {
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(serde_json::json!("1")),
method: "tools/call".to_string(),
params: serde_json::json!({
"name": "store",
"arguments": {
"agent_id": "explicit-agent"
}
}),
};
let mut headers = HeaderMap::new();
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
inject_agent_id(&mut request, "agent-from-header");
let request = apply_request_context(request, &headers, false);
assert_eq!(
request