diff --git a/AGENTS.md b/AGENTS.md
index 922f3a2..f8d0e84 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -7,7 +7,9 @@ memory.
 ## External Memory System
 - Use the exact MCP tools `openbrain.store`, `openbrain.query`, and `openbrain.purge`
-- Always use the exact `agent_id` value `openbrain`
+- Memory visibility is determined by the API token in the MCP client config, not by `agent_id`
+- On `openbrain.store`, use `agent_id` only as a provenance label for the storing agent when that label is useful
+- On `openbrain.query`, do not send `agent_id` for normal retrieval; use `source_agent_id` only when you intentionally want to filter by source agent
 - Do not hardcode live credentials into the repository
 - Before answering requests that may depend on prior sessions, project history, user preferences, ongoing work, named people, named projects, deployments, debugging history, or handoff context, call `openbrain.query` first
 - Use noun-heavy search phrases with exact names, tool names, acronyms, hostnames, and document names
@@ -19,7 +21,7 @@ memory.
 - Use metadata when helpful for tags such as `category`, `project`, `source`, `status`, `aliases`, and `confidence`
 - If `openbrain.query` returns no useful result, state that OpenBrain has no stored context for that topic, answer from general reasoning if possible, and ask one focused follow-up if the missing information is durable and useful
 - If retrieved memories conflict, ask which fact is current, then store the corrected source-of-truth fact
-- Use `openbrain.purge` cautiously because it is coarse-grained; it deletes by `agent_id` and optionally before a timestamp, not by individual memory ID
+- Use `openbrain.purge` cautiously because it is coarse-grained; it deletes memories visible to the current API token and can optionally narrow by `source_agent_id` and `before`, not by individual memory ID
 - For ordinary corrections, prefer storing the new source-of-truth fact instead of purging unless cleanup or reset is explicitly requested
 
 ## Agent Identity & Source Tagging
diff --git a/README.md b/README.md
index 08f3312..a472519 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ OpenBrain is a Model Context Protocol (MCP) server that provides AI agents with
 - 🏠 **Local Embeddings**: No external API calls - uses ONNX runtime with all-MiniLM-L6-v2
 - 🐘 **PostgreSQL + pgvector**: Production-grade vector storage with HNSW indexing
 - 🔌 **MCP Protocol**: Streamable HTTP plus legacy HTTP+SSE compatibility
-- 🔐 **Multi-Agent Support**: Isolated memory namespaces per agent
+- 🔐 **Shared Memory by Token**: Agents using the same API token share memory visibility while retaining source-agent provenance
 - ♻️ **Deduplicated Ingest**: Near-duplicate facts are merged instead of stored repeatedly
 - ⚡ **High Performance**: Rust implementation with async I/O
 
@@ -20,8 +20,8 @@ OpenBrain is a Model Context Protocol (MCP) server that provides AI agents with
 |------|-------------|
 | `store` | Store a memory with automatic embedding generation, optional TTL, and automatic
deduplication |
 | `batch_store` | Store 1-50 memories atomically in a single call with the same deduplication rules |
-| `query` | Search memories by semantic similarity |
-| `purge` | Delete memories by agent ID or time range |
+| `query` | Search memories by semantic similarity, optionally filtering by source agent |
+| `purge` | Delete memories visible to the current API token, optionally filtering by source agent or time range |
 
 ## Quick Start
 
@@ -127,8 +127,8 @@ If you want prod e2e coverage without leaving a standing CI key on the server, t
 ### Deduplication on Ingest
 
 OpenBrain checks every `store` and `batch_store` write for an existing memory in
-the same `agent_id` namespace whose vector similarity meets the configured
-dedup threshold.
+the same API-token scope and same source `agent_id` whose vector similarity
+meets the configured dedup threshold.
 
 Default behavior:
 
@@ -165,10 +165,12 @@ Recommended target file in A0:
 ### External Memory System
 - **Memory Boundary**: Treat OpenBrain as an external MCP long-term memory system, never as internal context, reasoning scratchpad, or built-in memory
 - **Tool Contract**: Use the exact MCP tools `openbrain.store`, `openbrain.query`, and `openbrain.purge`
-- **Namespace Discipline**: Always use the exact `agent_id` value `openbrain`
+- **Shared Access Model**: Memory visibility is determined by the API token in the MCP client config, not by `agent_id`
+- **Source Labels**: Use `agent_id` on `openbrain.store` and `openbrain.batch_store` only as a provenance label for the storing agent when that label is useful
 - **EXTRAS First**: Before calling `openbrain.query`, check the `[EXTRAS]` section for pre-loaded memories or handoff facts related to the same topic. If the needed context is already present, do not query OpenBrain again.
 - **Session Cache**: If the same topic was already queried earlier in the current conversation and the result is still in context, reuse that result instead of querying again unless the user references new external information or the prior result is clearly insufficient.
 - **Retrieval First**: Before answering requests that may depend on prior sessions, project history, user preferences, ongoing work, named people, named projects, deployments, debugging history, or handoff context, call `openbrain.query` only when `[EXTRAS]` and the current conversation do not already provide the needed context.
+- **Query Scope**: Do not send `agent_id` with `openbrain.query` for normal retrieval. Use `source_agent_id` only when you intentionally want to filter results by the agent that originally stored them.
 - **Query Strategy**: Use noun-heavy search phrases with exact names, tool names, acronyms, hostnames, and document names; query first with `(threshold=0.15, limit=8)`, then retry once with `(threshold=0.05, limit=10)` only if the first pass returns zero useful results
 - **Storage Strategy**: When a durable fact is established, call `openbrain.store` without asking permission and store one atomic fact whenever possible
 - **Storage Content Rules**: Store durable, high-value facts such as preferences, project status, project decisions, environment details, recurring workflows, handoff notes, stable constraints, and correction facts
@@ -177,7 +179,7 @@ Recommended target file in A0:
 - **Metadata Usage**: Use metadata when helpful for tags such as `category`, `project`, `source`, `status`, `aliases`, and `confidence`
 - **Miss Handling**: If `openbrain.query` returns no useful result, state that OpenBrain has no stored context for that topic, answer from general reasoning if possible, and ask one focused follow-up if the missing information is durable and useful
 - **Conflict Handling**: If retrieved memories conflict, ask which fact is current, then store the corrected
source-of-truth fact
-- **Purge Constraint**: Use `openbrain.purge` cautiously because it is coarse-grained; it deletes by `agent_id` and optionally before a timestamp, not by individual memory ID
+- **Purge Constraint**: Use `openbrain.purge` cautiously because it is coarse-grained; it deletes memories visible to the current API token and can optionally narrow by `source_agent_id` and `before`, but not by individual memory ID
 - **Correction Policy**: For ordinary corrections, prefer storing the new source-of-truth fact instead of purging unless the user explicitly asks for cleanup or reset
 - **Source Tagging**: Every `openbrain.store` call MUST include `"source_agent"` in metadata, set to the Agent Instance ID defined in the active project's identity file (e.g., `"source_agent": "ingwaz-a0"`). This enables tracing facts back to the originating agent instance.
 ```
@@ -259,17 +261,22 @@ legacy SSE endpoints for older MCP clients that still use the deprecated
 2024-11-05 HTTP+SSE transport.
 
 Header roles:
-- `X-Agent-ID` is the memory namespace. Keep this stable if multiple clients
-  should share the same OpenBrain memories.
-- `X-Agent-Type` is an optional client profile label for logging and config
-  clarity, such as `agent-zero` or `codex`.
+- If two clients use the same API token, they can read and write the same
+  OpenBrain memories.
+- `X-Agent-ID` is an optional source-agent label for logs and store provenance.
+  It does not control memory visibility.
+- `X-Agent-Type` is an optional client-program label such as `agent-zero`,
+  `codex`, or `claude-code`. It does not select transport server-side; the URL
+  path does that.
+- `agent_id` on `store` and `batch_store` is provenance. `source_agent_id` on
+  `query` and `purge` is an optional provenance filter.
 
 ### Example: Codex Configuration
 
 ```toml
 [mcp_servers.openbrain]
 url = "https://memory.example.com/mcp"
-http_headers = { "X-API-Key" = "YOUR_OPENBRAIN_API_KEY", "X-Agent-ID" = "openbrain", "X-Agent-Type" = "codex" }
+http_headers = { "X-API-Key" = "YOUR_OPENBRAIN_API_KEY", "X-Agent-ID" = "codex-desktop", "X-Agent-Type" = "codex" }
 ```
 
 ### Example: Agent Zero Configuration
@@ -281,7 +288,7 @@ http_headers = { "X-API-Key" = "YOUR_OPENBRAIN_API_KEY", "X-Agent-ID" = "openbra
       "url": "https://memory.example.com/mcp/sse",
       "headers": {
         "X-API-Key": "YOUR_OPENBRAIN_API_KEY",
-        "X-Agent-ID": "openbrain",
+        "X-Agent-ID": "agent-zero",
         "X-Agent-Type": "agent-zero"
       }
     }
@@ -290,7 +297,8 @@ http_headers = { "X-API-Key" = "YOUR_OPENBRAIN_API_KEY", "X-Agent-ID" = "openbra
 ```
 
 Agent Zero should keep using the legacy HTTP+SSE transport unless and until its
-client runtime supports streamable HTTP. Codex should use `/mcp`.
+client runtime supports streamable HTTP. Codex should use `/mcp`. If both
+clients use the same API token, they already share memory visibility.
 
 ### Example: Store a Memory
 
@@ -303,7 +311,7 @@ client runtime supports streamable HTTP. Codex should use `/mcp`.
     "name": "store",
     "arguments": {
       "content": "The user prefers dark mode and uses vim keybindings",
-      "agent_id": "assistant-1",
+      "agent_id": "agent-zero",
       "ttl": "7d",
       "metadata": {"source": "preferences"}
     }
@@ -322,7 +330,6 @@ client runtime supports streamable HTTP. Codex should use `/mcp`.
     "name": "query",
     "arguments": {
       "query": "What are the user's editor preferences?",
-      "agent_id": "assistant-1",
       "limit": 5,
       "threshold": 0.6
     }
@@ -340,7 +347,7 @@ client runtime supports streamable HTTP. Codex should use `/mcp`.
"params": { "name": "batch_store", "arguments": { - "agent_id": "assistant-1", + "agent_id": "codex", "entries": [ { "content": "The user prefers dark mode", diff --git a/migrations/V4__auth_scope_shared_memory.sql b/migrations/V4__auth_scope_shared_memory.sql new file mode 100644 index 0000000..5ee6033 --- /dev/null +++ b/migrations/V4__auth_scope_shared_memory.sql @@ -0,0 +1,8 @@ +ALTER TABLE memories + ADD COLUMN IF NOT EXISTS auth_scope VARCHAR(255) NOT NULL DEFAULT 'public'; + +CREATE INDEX IF NOT EXISTS idx_memories_auth_scope + ON memories (auth_scope); + +CREATE INDEX IF NOT EXISTS idx_memories_auth_scope_agent + ON memories (auth_scope, agent_id); diff --git a/src/auth.rs b/src/auth.rs index 8238c01..96c7fef 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -4,7 +4,7 @@ use axum::{ extract::{Request, State}, - http::{HeaderMap, StatusCode, header::AUTHORIZATION}, + http::{header::AUTHORIZATION, HeaderMap, StatusCode}, middleware::Next, response::Response, }; @@ -14,6 +14,8 @@ use tracing::warn; use crate::AppState; +pub const PUBLIC_AUTH_SCOPE: &str = "public"; + /// Hash an API key for secure comparison pub fn hash_api_key(key: &str) -> String { let mut hasher = Sha256::new(); @@ -99,24 +101,25 @@ pub fn get_optional_agent_type(headers: &HeaderMap) -> Option { .map(ToOwned::to_owned) } -/// Extract agent ID from request headers or default -pub fn get_agent_id(request: &Request) -> String { - get_optional_agent_id(request.headers()) - .unwrap_or_else(|| "default".to_string()) +pub fn get_auth_scope(headers: &HeaderMap, auth_enabled: bool) -> String { + if !auth_enabled { + return PUBLIC_AUTH_SCOPE.to_string(); + } + + extract_api_key(headers) + .map(|key| hash_api_key(&key)) + .unwrap_or_else(|| PUBLIC_AUTH_SCOPE.to_string()) } #[cfg(test)] mod tests { use super::*; - use axum::http::{HeaderValue, header::AUTHORIZATION}; + use axum::http::{header::AUTHORIZATION, HeaderValue}; #[test] fn extracts_api_key_from_bearer_header() { let mut headers = 
HeaderMap::new(); - headers.insert( - AUTHORIZATION, - HeaderValue::from_static("Bearer test-token"), - ); + headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer test-token")); assert_eq!(extract_api_key(&headers).as_deref(), Some("test-token")); } @@ -137,9 +140,21 @@ mod tests { let mut headers = HeaderMap::new(); headers.insert("X-Agent-Type", HeaderValue::from_static("codex")); - assert_eq!( - get_optional_agent_type(&headers).as_deref(), - Some("codex") - ); + assert_eq!(get_optional_agent_type(&headers).as_deref(), Some("codex")); + } + + #[test] + fn derives_auth_scope_from_api_key_when_enabled() { + let mut headers = HeaderMap::new(); + headers.insert("X-API-Key", HeaderValue::from_static("test-token")); + + assert_eq!(get_auth_scope(&headers, true), hash_api_key("test-token")); + } + + #[test] + fn uses_public_scope_when_auth_disabled() { + let headers = HeaderMap::new(); + + assert_eq!(get_auth_scope(&headers, false), PUBLIC_AUTH_SCOPE); } } diff --git a/src/config.rs b/src/config.rs index b444746..e9aa29f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -94,29 +94,50 @@ where } match Option::::deserialize(deserializer)? 
{ - Some(StringOrVec::String(s)) => { - Ok(s.split(',') - .map(|k| k.trim().to_string()) - .filter(|k| !k.is_empty()) - .collect()) - } + Some(StringOrVec::String(s)) => Ok(s + .split(',') + .map(|k| k.trim().to_string()) + .filter(|k| !k.is_empty()) + .collect()), Some(StringOrVec::Vec(v)) => Ok(v), None => Ok(Vec::new()), } } // Default value functions -fn default_host() -> String { "0.0.0.0".to_string() } -fn default_port() -> u16 { 3100 } -fn default_db_port() -> u16 { 5432 } -fn default_pool_size() -> usize { 10 } -fn default_model_path() -> String { "models/all-MiniLM-L6-v2".to_string() } -fn default_embedding_dim() -> usize { 384 } -fn default_vector_weight() -> f32 { 0.6 } -fn default_text_weight() -> f32 { 0.4 } -fn default_dedup_threshold() -> f32 { 0.90 } -fn default_cleanup_interval_seconds() -> u64 { 300 } -fn default_auth_enabled() -> bool { false } +fn default_host() -> String { + "0.0.0.0".to_string() +} +fn default_port() -> u16 { + 3100 +} +fn default_db_port() -> u16 { + 5432 +} +fn default_pool_size() -> usize { + 10 +} +fn default_model_path() -> String { + "models/all-MiniLM-L6-v2".to_string() +} +fn default_embedding_dim() -> usize { + 384 +} +fn default_vector_weight() -> f32 { + 0.6 +} +fn default_text_weight() -> f32 { + 0.4 +} +fn default_dedup_threshold() -> f32 { + 0.90 +} +fn default_cleanup_interval_seconds() -> u64 { + 300 +} +fn default_auth_enabled() -> bool { + false +} impl Config { /// Load configuration from environment variables diff --git a/src/db.rs b/src/db.rs index a20a86a..5cd70c6 100644 --- a/src/db.rs +++ b/src/db.rs @@ -9,9 +9,9 @@ use tokio_postgres::NoTls; use tracing::info; use uuid::Uuid; +use crate::config::DatabaseConfig; use serde::Serialize; use serde_json::{Map, Value}; -use crate::config::DatabaseConfig; /// Database wrapper with connection pool #[derive(Clone)] @@ -75,6 +75,7 @@ fn merge_metadata(existing: &Value, incoming: &Value) -> Value { async fn find_dedup_match( client: &C, + auth_scope: &str, 
agent_id: &str, embedding: &Vector, threshold: f64, @@ -87,13 +88,14 @@ where r#" SELECT id, metadata, expires_at FROM memories - WHERE agent_id = $1 + WHERE auth_scope = $1 + AND agent_id = $2 AND (expires_at IS NULL OR expires_at > NOW()) - AND (1 - (embedding <=> $2)) >= $3 - ORDER BY (1 - (embedding <=> $2)) DESC, created_at DESC + AND (1 - (embedding <=> $3)) >= $4 + ORDER BY (1 - (embedding <=> $3)) DESC, created_at DESC LIMIT 1 "#, - &[&agent_id, embedding, &threshold], + &[&auth_scope, &agent_id, embedding, &threshold], ) .await .context("Failed to check for duplicate memory")?; @@ -120,13 +122,19 @@ impl Database { .context("Failed to create database pool")?; // Test connection - let client = pool.get().await.context("Failed to get database connection")?; + let client = pool + .get() + .await + .context("Failed to get database connection")?; client .simple_query("SELECT 1") .await .context("Failed to execute test query")?; - info!("Database connection pool created with {} connections", config.pool_size); + info!( + "Database connection pool created with {} connections", + config.pool_size + ); Ok(Self { pool }) } @@ -134,6 +142,7 @@ impl Database { /// Store a memory record pub async fn store_memory( &self, + auth_scope: &str, agent_id: &str, content: &str, embedding: &[f32], @@ -146,7 +155,9 @@ impl Database { let vector = Vector::from(embedding.to_vec()); let dedup_threshold = dedup_threshold as f64; - if let Some(existing) = find_dedup_match(&client, agent_id, &vector, dedup_threshold).await? { + if let Some(existing) = + find_dedup_match(&client, auth_scope, agent_id, &vector, dedup_threshold).await? 
+ { let merged_metadata = merge_metadata(&existing.metadata, &metadata); let refreshed_expires_at = expires_at.or(existing.expires_at); @@ -176,10 +187,10 @@ impl Database { client .execute( r#" - INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at) - VALUES ($1, $2, $3, $4, $5, $6, $7) + INSERT INTO memories (id, auth_scope, agent_id, content, embedding, keywords, metadata, expires_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) "#, - &[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at], + &[&id, &auth_scope, &agent_id, &content, &vector, &keywords, &metadata, &expires_at], ) .await .context("Failed to store memory")?; @@ -194,7 +205,8 @@ impl Database { /// Query memories by vector similarity pub async fn query_memories( &self, - agent_id: &str, + auth_scope: &str, + source_agent_id: Option<&str>, query_text: &str, embedding: &[f32], limit: i64, @@ -230,7 +242,8 @@ impl Database { END AS text_score FROM memories CROSS JOIN search_query - WHERE memories.agent_id = $3 + WHERE memories.auth_scope = $3 + AND ($4::text IS NULL OR memories.agent_id = $4) AND (memories.expires_at IS NULL OR memories.expires_at > NOW()) ), ranked AS ( @@ -251,18 +264,19 @@ impl Database { text_score, CASE WHEN has_text_match = 1 - THEN (($5 * vector_score) + ($6 * text_score))::real + THEN (($6 * vector_score) + ($7 * text_score))::real ELSE vector_score END AS hybrid_score FROM ranked - WHERE vector_score >= $4 OR text_score > 0 + WHERE vector_score >= $5 OR text_score > 0 ORDER BY hybrid_score DESC, vector_score DESC - LIMIT $7 + LIMIT $8 "#, &[ &vector, &query_text, - &agent_id, + &auth_scope, + &source_agent_id, &threshold, &vector_weight, &text_weight, @@ -296,37 +310,47 @@ impl Database { Ok(matches) } - /// Delete memories by agent_id and optional filters + /// Delete memories visible to an auth scope with an optional provenance filter pub async fn purge_memories( &self, - agent_id: &str, + auth_scope: &str, + source_agent_id: 
Option<&str>, before: Option>, ) -> Result { let client = self.pool.get().await?; - let count = if let Some(before_ts) = before { - client - .execute( - "DELETE FROM memories WHERE agent_id = $1 AND created_at < $2", - &[&agent_id, &before_ts], - ) - .await? - } else { - client - .execute("DELETE FROM memories WHERE agent_id = $1", &[&agent_id]) - .await? - }; + let count = client + .execute( + r#" + DELETE FROM memories + WHERE auth_scope = $1 + AND ($2::text IS NULL OR agent_id = $2) + AND ($3::timestamptz IS NULL OR created_at < $3) + "#, + &[&auth_scope, &source_agent_id, &before], + ) + .await?; Ok(count) } - /// Get memory count for an agent - pub async fn count_memories(&self, agent_id: &str) -> Result { + /// Get memory count for a token-visible scope and optional provenance filter + pub async fn count_memories( + &self, + auth_scope: &str, + source_agent_id: Option<&str>, + ) -> Result { let client = self.pool.get().await?; let row = client .query_one( - "SELECT COUNT(*) as count FROM memories WHERE agent_id = $1 AND (expires_at IS NULL OR expires_at > NOW())", - &[&agent_id], + r#" + SELECT COUNT(*) as count + FROM memories + WHERE auth_scope = $1 + AND ($2::text IS NULL OR agent_id = $2) + AND (expires_at IS NULL OR expires_at > NOW()) + "#, + &[&auth_scope, &source_agent_id], ) .await?; Ok(row.get("count")) @@ -346,7 +370,6 @@ impl Database { } } - /// Result for a single batch entry #[derive(Debug, Clone, Serialize)] pub struct BatchStoreResult { @@ -360,6 +383,7 @@ impl Database { /// Store multiple memories in a single transaction pub async fn batch_store_memories( &self, + auth_scope: &str, agent_id: &str, entries: Vec<( String, @@ -378,7 +402,8 @@ impl Database { for (content, metadata, embedding, keywords, expires_at) in entries { let vector = Vector::from(embedding); if let Some(existing) = - find_dedup_match(&transaction, agent_id, &vector, dedup_threshold).await? 
+ find_dedup_match(&transaction, auth_scope, agent_id, &vector, dedup_threshold) + .await? { let merged_metadata = merge_metadata(&existing.metadata, &metadata); let refreshed_expires_at = expires_at.or(existing.expires_at); @@ -404,8 +429,8 @@ impl Database { } else { let id = Uuid::new_v4(); transaction.execute( - r#"INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7)"#, - &[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at], + r#"INSERT INTO memories (id, auth_scope, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#, + &[&id, &auth_scope, &agent_id, &content, &vector, &keywords, &metadata, &expires_at], ).await?; results.push(BatchStoreResult { id: id.to_string(), diff --git a/src/embedding.rs b/src/embedding.rs index ddd0d9f..1d526d4 100644 --- a/src/embedding.rs +++ b/src/embedding.rs @@ -1,7 +1,7 @@ //! Embedding engine using local ONNX models use anyhow::Result; -use ort::session::{Session, builder::GraphOptimizationLevel}; +use ort::session::{builder::GraphOptimizationLevel, Session}; use ort::value::Value; use std::path::{Path, PathBuf}; use std::sync::Once; @@ -15,9 +15,9 @@ static ORT_INIT: Once = Once::new(); /// Initialize ONNX Runtime synchronously (called inside spawn_blocking) fn init_ort_sync(dylib_path: &str) -> Result<()> { info!("Initializing ONNX Runtime from: {}", dylib_path); - + let mut init_error: Option = None; - + ORT_INIT.call_once(|| { info!("ORT_INIT.call_once - starting initialization"); match ort::init_from(dylib_path) { @@ -43,7 +43,7 @@ fn init_ort_sync(dylib_path: &str) -> Result<()> { if let Some(err) = init_error { return Err(anyhow::anyhow!("{}", err)); } - + info!("ONNX Runtime initialization complete"); Ok(()) } @@ -91,35 +91,40 @@ impl EmbeddingEngine { let model_path = PathBuf::from(&config.model_path); let dimension = config.dimension; - - info!("Loading ONNX model from {:?}", 
model_path.join("model.onnx")); - + + info!( + "Loading ONNX model from {:?}", + model_path.join("model.onnx") + ); + // Use spawn_blocking to avoid blocking the async runtime - let (session, tokenizer) = tokio::task::spawn_blocking(move || -> Result<(Session, Tokenizer)> { - // Initialize ONNX Runtime first - init_ort_sync(&dylib_path)?; - - info!("Creating ONNX session..."); - - // Load ONNX model with ort 2.0 API - let session = Session::builder() - .map_err(|e| anyhow::anyhow!("Failed to create session builder: {:?}", e))? - .with_optimization_level(GraphOptimizationLevel::Level3) - .map_err(|e| anyhow::anyhow!("Failed to set optimization level: {:?}", e))? - .with_intra_threads(4) - .map_err(|e| anyhow::anyhow!("Failed to set intra threads: {:?}", e))? - .commit_from_file(model_path.join("model.onnx")) - .map_err(|e| anyhow::anyhow!("Failed to load ONNX model: {:?}", e))?; + let (session, tokenizer) = + tokio::task::spawn_blocking(move || -> Result<(Session, Tokenizer)> { + // Initialize ONNX Runtime first + init_ort_sync(&dylib_path)?; - info!("ONNX model loaded, loading tokenizer..."); - - // Load tokenizer - let tokenizer = Tokenizer::from_file(model_path.join("tokenizer.json")) - .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?; + info!("Creating ONNX session..."); - info!("Tokenizer loaded successfully"); - Ok((session, tokenizer)) - }).await + // Load ONNX model with ort 2.0 API + let session = Session::builder() + .map_err(|e| anyhow::anyhow!("Failed to create session builder: {:?}", e))? + .with_optimization_level(GraphOptimizationLevel::Level3) + .map_err(|e| anyhow::anyhow!("Failed to set optimization level: {:?}", e))? + .with_intra_threads(4) + .map_err(|e| anyhow::anyhow!("Failed to set intra threads: {:?}", e))? 
+ .commit_from_file(model_path.join("model.onnx")) + .map_err(|e| anyhow::anyhow!("Failed to load ONNX model: {:?}", e))?; + + info!("ONNX model loaded, loading tokenizer..."); + + // Load tokenizer + let tokenizer = Tokenizer::from_file(model_path.join("tokenizer.json")) + .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?; + + info!("Tokenizer loaded successfully"); + Ok((session, tokenizer)) + }) + .await .map_err(|e| anyhow::anyhow!("Spawn blocking failed: {:?}", e))??; info!( @@ -136,12 +141,17 @@ impl EmbeddingEngine { /// Generate embedding for a single text pub fn embed(&self, text: &str) -> Result> { - let encoding = self.tokenizer + let encoding = self + .tokenizer .encode(text, true) .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?; let input_ids: Vec = encoding.get_ids().iter().map(|&x| x as i64).collect(); - let attention_mask: Vec = encoding.get_attention_mask().iter().map(|&x| x as i64).collect(); + let attention_mask: Vec = encoding + .get_attention_mask() + .iter() + .map(|&x| x as i64) + .collect(); let token_type_ids: Vec = encoding.get_type_ids().iter().map(|&x| x as i64).collect(); let seq_len = input_ids.len(); @@ -157,22 +167,25 @@ impl EmbeddingEngine { "attention_mask" => attention_mask_tensor, "token_type_ids" => token_type_ids_tensor, ]; - - let mut session_guard = self.session.lock() + + let mut session_guard = self + .session + .lock() .map_err(|e| anyhow::anyhow!("Session lock poisoned: {}", e))?; let outputs = session_guard.run(inputs)?; // Extract output - let output = outputs.get("last_hidden_state") + let output = outputs + .get("last_hidden_state") .ok_or_else(|| anyhow::anyhow!("Missing last_hidden_state output"))?; - + // Get the tensor data let (shape, data) = output.try_extract_tensor::()?; - + // Mean pooling over sequence dimension let hidden_size = *shape.last().unwrap_or(&384) as usize; let seq_len = data.len() / hidden_size; - + let mut embedding = vec![0.0f32; hidden_size]; for i in 0..seq_len 
{ for j in 0..hidden_size { @@ -182,7 +195,7 @@ impl EmbeddingEngine { for val in &mut embedding { *val /= seq_len as f32; } - + // L2 normalize let norm: f32 = embedding.iter().map(|x| x * x).sum::().sqrt(); if norm > 0.0 { @@ -190,7 +203,7 @@ impl EmbeddingEngine { *val /= norm; } } - + Ok(embedding) } @@ -208,37 +221,40 @@ impl EmbeddingEngine { /// Extract keywords from text using simple frequency analysis pub fn extract_keywords(text: &str, limit: usize) -> Vec { use std::collections::HashMap; - + let stop_words: std::collections::HashSet<&str> = [ - "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", - "of", "with", "by", "from", "is", "are", "was", "were", "be", "been", - "being", "have", "has", "had", "do", "does", "did", "will", "would", - "could", "should", "may", "might", "must", "shall", "can", "this", - "that", "these", "those", "i", "you", "he", "she", "it", "we", "they", - "what", "which", "who", "whom", "whose", "where", "when", "why", "how", - "all", "each", "every", "both", "few", "more", "most", "other", "some", - "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", - "very", "just", "also", "now", "here", "there", "then", "once", "if", - ].iter().cloned().collect(); - + "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", + "from", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "do", + "does", "did", "will", "would", "could", "should", "may", "might", "must", "shall", "can", + "this", "that", "these", "those", "i", "you", "he", "she", "it", "we", "they", "what", + "which", "who", "whom", "whose", "where", "when", "why", "how", "all", "each", "every", + "both", "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own", + "same", "so", "than", "too", "very", "just", "also", "now", "here", "there", "then", + "once", "if", + ] + .iter() + .cloned() + .collect(); + let mut word_counts: HashMap = HashMap::new(); - + for word in 
text.split_whitespace() { let clean: String = word .chars() .filter(|c| c.is_alphanumeric()) .collect::() .to_lowercase(); - + if clean.len() > 2 && !stop_words.contains(clean.as_str()) { *word_counts.entry(clean).or_insert(0) += 1; } } - + let mut sorted: Vec<_> = word_counts.into_iter().collect(); sorted.sort_by(|a, b| b.1.cmp(&a.1)); - - sorted.into_iter() + + sorted + .into_iter() .take(limit) .map(|(word, _)| word) .collect() diff --git a/src/lib.rs b/src/lib.rs index c8a77ea..f4b3909 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,18 +5,18 @@ pub mod config; pub mod db; pub mod embedding; pub mod migrations; -pub mod ttl; pub mod tools; pub mod transport; +pub mod ttl; use anyhow::Result; -use axum::{Router, Json, http::StatusCode, middleware}; +use axum::{http::StatusCode, middleware, Json, Router}; use serde_json::json; use std::sync::Arc; use tokio::net::TcpListener; use tower_http::cors::{Any, CorsLayer}; use tower_http::trace::TraceLayer; -use tracing::{info, error}; +use tracing::{error, info}; use crate::auth::auth_middleware; use crate::config::Config; @@ -60,15 +60,15 @@ async fn readiness_handler( match readiness { ReadinessState::Ready => ( StatusCode::OK, - Json(json!({"status": "ready", "embedding": true})) + Json(json!({"status": "ready", "embedding": true})), ), ReadinessState::Initializing => ( StatusCode::SERVICE_UNAVAILABLE, - Json(json!({"status": "initializing", "embedding": false})) + Json(json!({"status": "initializing", "embedding": false})), ), ReadinessState::Failed(err) => ( StatusCode::SERVICE_UNAVAILABLE, - Json(json!({"status": "failed", "error": err})) + Json(json!({"status": "failed", "error": err})), ), } } @@ -89,11 +89,14 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> { tokio::spawn(async move { let max_retries = 3; let mut attempt = 0; - + loop { attempt += 1; - info!("Initializing embedding engine (attempt {}/{})", attempt, max_retries); - + info!( + "Initializing embedding engine (attempt {}/{})", + 
attempt, max_retries + ); + match EmbeddingEngine::new(&embedding_config).await { Ok(engine) => { let engine = Arc::new(engine); @@ -120,9 +123,8 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> { let cleanup_state = state.clone(); let cleanup_interval_seconds = config.ttl.cleanup_interval_seconds; tokio::spawn(async move { - let mut interval = tokio::time::interval(tokio::time::Duration::from_secs( - cleanup_interval_seconds, - )); + let mut interval = + tokio::time::interval(tokio::time::Duration::from_secs(cleanup_interval_seconds)); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); loop { @@ -148,10 +150,12 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> { .route("/health", axum::routing::get(health_handler)) .route("/ready", axum::routing::get(readiness_handler)) .with_state(state.clone()); - + // Build MCP router with auth middleware - let mcp_router = transport::mcp_router(mcp_state) - .layer(middleware::from_fn_with_state(state.clone(), auth_middleware)); + let mcp_router = transport::mcp_router(mcp_state).layer(middleware::from_fn_with_state( + state.clone(), + auth_middleware, + )); let app = Router::new() .merge(health_router) diff --git a/src/main.rs b/src/main.rs index 28e18f6..c6d7948 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,7 +17,10 @@ async fn main() -> Result<()> { .with(tracing_subscriber::fmt::layer()) .init(); - info!("Starting OpenBrain MCP Server v{}", env!("CARGO_PKG_VERSION")); + info!( + "Starting OpenBrain MCP Server v{}", + env!("CARGO_PKG_VERSION") + ); // Load configuration let config = Config::load()?; diff --git a/src/tools/batch_store.rs b/src/tools/batch_store.rs index bb14d43..fd3032d 100644 --- a/src/tools/batch_store.rs +++ b/src/tools/batch_store.rs @@ -3,12 +3,14 @@ //! Accepts 1-50 entries per call, generates embeddings for each, //! stores all in a single DB transaction, returns individual IDs/status. 
-use anyhow::{Context, Result, anyhow}; +use anyhow::{anyhow, Context, Result}; use serde_json::Value; use std::sync::Arc; use tracing::info; +use crate::auth::PUBLIC_AUTH_SCOPE; use crate::embedding::extract_keywords; +use crate::tools::INTERNAL_AUTH_SCOPE_ARG; use crate::ttl::expires_at_from_ttl; use crate::AppState; @@ -18,7 +20,7 @@ const MAX_BATCH_SIZE: usize = 50; /// Execute the batch_store tool /// /// Accepts: -/// - `agent_id`: Optional agent identifier (defaults to "default") +/// - `agent_id`: Optional source agent label (defaults to "default") /// - `ttl`: Optional default TTL string applied to entries without their own ttl /// - `entries`: Array of 1-50 entries, each with `content` (required) and `metadata` (optional) /// @@ -38,6 +40,10 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result .get("agent_id") .and_then(|v| v.as_str()) .unwrap_or("default"); + let auth_scope = arguments + .get(INTERNAL_AUTH_SCOPE_ARG) + .and_then(|v| v.as_str()) + .unwrap_or(PUBLIC_AUTH_SCOPE); let entries = arguments .get("entries") @@ -47,7 +53,9 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result // 3. 
Validate batch size if entries.is_empty() { - return Err(anyhow!("Empty entries array not allowed - must provide 1-50 entries")); + return Err(anyhow!( + "Empty entries array not allowed - must provide 1-50 entries" + )); } if entries.len() > MAX_BATCH_SIZE { return Err(anyhow!( @@ -69,7 +77,10 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result let content = entry .get("content") .and_then(|v| v.as_str()) - .context(format!("Entry at index {} missing required field: content", idx))?; + .context(format!( + "Entry at index {} missing required field: content", + idx + ))?; if content.is_empty() { return Err(anyhow!( @@ -82,10 +93,7 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result .get("metadata") .cloned() .unwrap_or(serde_json::json!({})); - let ttl = entry - .get("ttl") - .and_then(|v| v.as_str()) - .or(default_ttl); + let ttl = entry.get("ttl").and_then(|v| v.as_str()).or(default_ttl); let expires_at = expires_at_from_ttl(ttl) .with_context(|| format!("Invalid ttl for entry at index {}", idx))?; @@ -97,13 +105,24 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result // Extract keywords let keywords = extract_keywords(content, 10); - processed_entries.push((content.to_string(), metadata, embedding, keywords, expires_at)); + processed_entries.push(( + content.to_string(), + metadata, + embedding, + keywords, + expires_at, + )); } // 5. 
Batch DB insert (single transaction for atomicity) let results = state .db - .batch_store_memories(agent_id, processed_entries, state.config.dedup.threshold) + .batch_store_memories( + auth_scope, + agent_id, + processed_entries, + state.config.dedup.threshold, + ) .await .context("Failed to batch store memories")?; @@ -114,5 +133,6 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result "success": true, "results": results, "count": count - }).to_string()) + }) + .to_string()) } diff --git a/src/tools/mod.rs b/src/tools/mod.rs index e9358fb..a70fae7 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -1,9 +1,9 @@ //! MCP Tools for OpenBrain pub mod batch_store; +pub mod purge; pub mod query; pub mod store; -pub mod purge; use anyhow::Result; use serde_json::{json, Value}; @@ -11,11 +11,13 @@ use std::sync::Arc; use crate::AppState; +pub const INTERNAL_AUTH_SCOPE_ARG: &str = "_auth_scope"; + pub fn get_tool_definitions() -> Vec { vec![ json!({ "name": "store", - "description": "Store a memory with automatic embedding generation and keyword extraction. Near-duplicate memories for the same agent are deduplicated automatically by similarity, with metadata merged and timestamps refreshed.", + "description": "Store a memory with automatic embedding generation and keyword extraction. Near-duplicate memories for the same API-token scope and same source agent are deduplicated automatically by similarity, with metadata merged and timestamps refreshed.", "inputSchema": { "type": "object", "properties": { @@ -25,7 +27,7 @@ pub fn get_tool_definitions() -> Vec { }, "agent_id": { "type": "string", - "description": "Unique identifier for the agent storing the memory (default: 'default')" + "description": "Optional source agent label recorded with the memory. If omitted, the server may fall back to X-Agent-ID or 'default'." 
}, "metadata": { "type": "object", @@ -47,7 +49,7 @@ pub fn get_tool_definitions() -> Vec { "properties": { "agent_id": { "type": "string", - "description": "Unique identifier for the agent storing the memories (default: 'default')" + "description": "Optional source agent label recorded with each stored memory. If omitted, the server may fall back to X-Agent-ID or 'default'." }, "ttl": { "type": "string", @@ -89,9 +91,14 @@ pub fn get_tool_definitions() -> Vec { "type": "string", "description": "The search query text" }, + "source_agent_id": { + "type": "string", + "description": "Optional provenance filter that only returns memories stored by the specified agent label" + }, "agent_id": { "type": "string", - "description": "Agent ID to search within (default: 'default')" + "description": "Deprecated legacy alias. Query visibility is scoped by API token, not by agent_id.", + "deprecated": true }, "limit": { "type": "integer", @@ -107,13 +114,18 @@ pub fn get_tool_definitions() -> Vec { }), json!({ "name": "purge", - "description": "Delete memories for an agent. Can delete all memories or those before a specific timestamp.", + "description": "Delete memories visible to the current API token. 
Can optionally filter by source agent label or by time range.", "inputSchema": { "type": "object", "properties": { + "source_agent_id": { + "type": "string", + "description": "Optional provenance filter that only deletes memories stored by the specified agent label" + }, "agent_id": { "type": "string", - "description": "Agent ID whose memories to delete (required)" + "description": "Deprecated legacy alias for source_agent_id", + "deprecated": true }, "before": { "type": "string", @@ -124,9 +136,9 @@ pub fn get_tool_definitions() -> Vec { "description": "Must be true to confirm deletion" } }, - "required": ["agent_id", "confirm"] + "required": ["confirm"] } - }) + }), ] } diff --git a/src/tools/purge.rs b/src/tools/purge.rs index d71017b..0ee6373 100644 --- a/src/tools/purge.rs +++ b/src/tools/purge.rs @@ -1,4 +1,4 @@ -//! Purge Tool - Delete memories by agent_id or time range +//! Purge Tool - Delete memories visible to the current token with optional filters use anyhow::{bail, Context, Result}; use chrono::DateTime; @@ -6,15 +6,21 @@ use serde_json::Value; use std::sync::Arc; use tracing::{info, warn}; +use crate::auth::PUBLIC_AUTH_SCOPE; +use crate::tools::INTERNAL_AUTH_SCOPE_ARG; use crate::AppState; /// Execute the purge tool pub async fn execute(state: &Arc, arguments: Value) -> Result { // Extract parameters - let agent_id = arguments - .get("agent_id") + let source_agent_id = arguments + .get("source_agent_id") .and_then(|v| v.as_str()) - .context("Missing required parameter: agent_id")?; + .or_else(|| arguments.get("agent_id").and_then(|v| v.as_str())); + let auth_scope = arguments + .get(INTERNAL_AUTH_SCOPE_ARG) + .and_then(|v| v.as_str()) + .unwrap_or(PUBLIC_AUTH_SCOPE); let confirm = arguments .get("confirm") @@ -36,15 +42,18 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result // Get current count before purge let count_before = state .db - .count_memories(agent_id) + .count_memories(auth_scope, source_agent_id) .await .context("Failed to 
count memories")?; if count_before == 0 { - info!("No memories found for agent '{}'", agent_id); + info!( + "No memories found to purge for auth scope '{}' with source_agent_id={:?}", + auth_scope, source_agent_id + ); return Ok(serde_json::json!({ "success": true, - "agent_id": agent_id, + "source_agent_id_filter": source_agent_id, "deleted": 0, "message": "No memories found to purge" }) @@ -52,25 +61,25 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result } warn!( - "Purging memories for agent '{}' (before={:?})", - agent_id, before + "Purging memories for auth scope '{}' with source_agent_id={:?} (before={:?})", + auth_scope, source_agent_id, before ); // Execute purge let deleted = state .db - .purge_memories(agent_id, before) + .purge_memories(auth_scope, source_agent_id, before) .await .context("Failed to purge memories")?; info!( - "Purged {} memories for agent '{}'", - deleted, agent_id + "Purged {} memories for auth scope '{}' with source_agent_id={:?}", + deleted, auth_scope, source_agent_id ); Ok(serde_json::json!({ "success": true, - "agent_id": agent_id, + "source_agent_id_filter": source_agent_id, "deleted": deleted, "had_before_filter": before.is_some(), "message": format!("Successfully purged {} memories", deleted) diff --git a/src/tools/query.rs b/src/tools/query.rs index 609e888..df947ae 100644 --- a/src/tools/query.rs +++ b/src/tools/query.rs @@ -1,10 +1,12 @@ //! 
Query Tool - Search memories by semantic similarity -use anyhow::{Context, Result, anyhow}; +use anyhow::{anyhow, Context, Result}; use serde_json::Value; use std::sync::Arc; use tracing::info; +use crate::auth::PUBLIC_AUTH_SCOPE; +use crate::tools::INTERNAL_AUTH_SCOPE_ARG; use crate::AppState; /// Execute the query tool @@ -21,10 +23,14 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result .and_then(|v| v.as_str()) .context("Missing required parameter: query")?; - let agent_id = arguments - .get("agent_id") + let source_agent_id = arguments + .get("source_agent_id") .and_then(|v| v.as_str()) - .unwrap_or("default"); + .filter(|value| !value.is_empty()); + let auth_scope = arguments + .get(INTERNAL_AUTH_SCOPE_ARG) + .and_then(|v| v.as_str()) + .unwrap_or(PUBLIC_AUTH_SCOPE); let limit = arguments .get("limit") @@ -42,8 +48,8 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result ); info!( - "Querying memories for agent '{}': '{}' (limit={}, threshold={}, vector_weight={}, text_weight={})", - agent_id, query_text, limit, threshold, vector_weight, text_weight + "Querying memories for auth scope '{}' with source_agent_id={:?}: '{}' (limit={}, threshold={}, vector_weight={}, text_weight={})", + auth_scope, source_agent_id, query_text, limit, threshold, vector_weight, text_weight ); // Generate embedding for query using Arc @@ -55,7 +61,8 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result let matches = state .db .query_memories( - agent_id, + auth_scope, + source_agent_id, query_text, &query_embedding, limit, @@ -79,6 +86,7 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result "vector_score": m.vector_score, "text_score": m.text_score, "hybrid_score": m.hybrid_score, + "agent_id": m.record.agent_id, "keywords": m.record.keywords, "metadata": m.record.metadata, "created_at": m.record.created_at.to_rfc3339(), @@ -89,8 +97,8 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result Ok(serde_json::json!({ "success": true, 
- "agent_id": agent_id, "query": query_text, + "source_agent_id_filter": source_agent_id, "vector_weight": vector_weight, "text_weight": text_weight, "count": results.len(), diff --git a/src/tools/store.rs b/src/tools/store.rs index e9b118d..e1afa99 100644 --- a/src/tools/store.rs +++ b/src/tools/store.rs @@ -1,11 +1,13 @@ //! Store Tool - Store memories with automatic embeddings -use anyhow::{Context, Result, anyhow}; +use anyhow::{anyhow, Context, Result}; use serde_json::Value; use std::sync::Arc; use tracing::info; +use crate::auth::PUBLIC_AUTH_SCOPE; use crate::embedding::extract_keywords; +use crate::tools::INTERNAL_AUTH_SCOPE_ARG; use crate::ttl::expires_at_from_ttl; use crate::AppState; @@ -27,6 +29,10 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result .get("agent_id") .and_then(|v| v.as_str()) .unwrap_or("default"); + let auth_scope = arguments + .get(INTERNAL_AUTH_SCOPE_ARG) + .and_then(|v| v.as_str()) + .unwrap_or(PUBLIC_AUTH_SCOPE); let metadata = arguments .get("metadata") @@ -54,6 +60,7 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result let id = state .db .store_memory( + auth_scope, agent_id, content, &embedding, @@ -67,7 +74,11 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result info!( "Memory {} with ID: {}", - if id.deduplicated { "deduplicated" } else { "stored" }, + if id.deduplicated { + "deduplicated" + } else { + "stored" + }, id.id ); diff --git a/src/transport.rs b/src/transport.rs index 72ffec7..3895d9a 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -5,27 +5,25 @@ use axum::{ extract::{Query, State}, - http::{HeaderMap, StatusCode, Uri, header::{HOST, ORIGIN}}, + http::{ + header::{HOST, ORIGIN}, + HeaderMap, StatusCode, Uri, + }, response::{ - IntoResponse, Response, sse::{Event, KeepAlive, Sse}, + IntoResponse, Response, }, routing::{get, post}, Json, Router, }; use futures::stream::Stream; use serde::{Deserialize, Serialize}; -use std::{ - collections::HashMap, - convert::Infallible, - 
sync::Arc, - time::Duration, -}; -use tokio::sync::{RwLock, broadcast, mpsc}; +use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration}; +use tokio::sync::{broadcast, mpsc, RwLock}; use tracing::{error, info, warn}; use uuid::Uuid; -use crate::{AppState, auth, tools}; +use crate::{auth, tools, AppState}; type SessionStore = RwLock>>; @@ -46,11 +44,7 @@ impl McpState { }) } - async fn insert_session( - &self, - session_id: String, - tx: mpsc::Sender, - ) { + async fn insert_session(&self, session_id: String, tx: mpsc::Sender) { self.sessions.write().await.insert(session_id, tx); } @@ -163,7 +157,12 @@ struct PostMessageQuery { /// Create the MCP router pub fn mcp_router(state: Arc) -> Router { Router::new() - .route("/mcp", get(streamable_get_handler).post(streamable_post_handler).delete(streamable_delete_handler)) + .route( + "/mcp", + get(streamable_get_handler) + .post(streamable_post_handler) + .delete(streamable_delete_handler), + ) .route("/mcp/sse", get(sse_handler)) .route("/mcp/message", post(message_handler)) .route("/mcp/health", get(health_handler)) @@ -186,7 +185,10 @@ fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> { } let origin_uri = origin.parse::().map_err(|_| { - warn!("Rejected MCP request with invalid origin header: {}", origin); + warn!( + "Rejected MCP request with invalid origin header: {}", + origin + ); StatusCode::FORBIDDEN })?; @@ -203,7 +205,10 @@ fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> { .map(str::trim) .filter(|value| !value.is_empty()) .ok_or_else(|| { - warn!("Rejected MCP request without host header for origin {}", origin); + warn!( + "Rejected MCP request without host header for origin {}", + origin + ); StatusCode::FORBIDDEN })?; @@ -247,12 +252,12 @@ async fn streamable_post_handler( info!( method = %request.method, - agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"), + client_id = 
auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"), agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"), "Received streamable MCP request" ); - let request = apply_request_context(request, &headers); + let request = apply_request_context(request, &headers, state.app.config.auth.enabled); let response = dispatch_request(&state, &request).await; match response { @@ -284,8 +289,12 @@ async fn sse_handler( ) -> Result>>, StatusCode> { validate_origin(&headers)?; info!( - agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"), - agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"), + client_id = auth::get_optional_agent_id(&headers) + .as_deref() + .unwrap_or("unset"), + agent_type = auth::get_optional_agent_type(&headers) + .as_deref() + .unwrap_or("unset"), "Opening legacy SSE MCP stream" ); let mut broadcast_rx = state.event_tx.subscribe(); @@ -354,7 +363,7 @@ async fn message_handler( info!( method = %request.method, - agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"), + client_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"), agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"), "Received legacy SSE MCP request" ); @@ -365,7 +374,7 @@ async fn message_handler( } } - let request = apply_request_context(request, &headers); + let request = apply_request_context(request, &headers, state.app.config.auth.enabled); let response = dispatch_request(&state, &request).await; match query.session_id.as_deref() { @@ -408,17 +417,10 @@ async fn route_session_response( fn apply_request_context( mut request: JsonRpcRequest, headers: &HeaderMap, + auth_enabled: bool, ) -> JsonRpcRequest { - if let Some(agent_id) = auth::get_optional_agent_id(headers) { - inject_agent_id(&mut request, &agent_id); - } - - request -} - -fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) { if request.method 
!= "tools/call" { - return; + return request; } if !request.params.is_object() { @@ -429,6 +431,11 @@ fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) { .params .as_object_mut() .expect("params should be an object"); + let tool_name = params + .get("name") + .and_then(|value| value.as_str()) + .unwrap_or("") + .to_string(); let arguments = params .entry("arguments") .or_insert_with(|| serde_json::json!({})); @@ -437,11 +444,22 @@ fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) { *arguments = serde_json::json!({}); } - arguments + let arguments = arguments .as_object_mut() - .expect("arguments should be an object") - .entry("agent_id".to_string()) - .or_insert_with(|| serde_json::json!(agent_id)); + .expect("arguments should be an object"); + arguments.insert( + tools::INTERNAL_AUTH_SCOPE_ARG.to_string(), + serde_json::json!(auth::get_auth_scope(headers, auth_enabled)), + ); + + if matches!(tool_name.as_str(), "store" | "batch_store") && !arguments.contains_key("agent_id") + { + if let Some(agent_id) = auth::get_optional_agent_id(headers) { + arguments.insert("agent_id".to_string(), serde_json::json!(agent_id)); + } + } + + request } async fn dispatch_request( @@ -496,10 +514,7 @@ fn initialize_result() -> serde_json::Value { }) } -fn success_response( - id: serde_json::Value, - result: serde_json::Value, -) -> JsonRpcResponse { +fn success_response(id: serde_json::Value, result: serde_json::Value) -> JsonRpcResponse { JsonRpcResponse { jsonrpc: "2.0".to_string(), id, @@ -578,10 +593,11 @@ async fn handle_tools_call( #[cfg(test)] mod tests { use super::*; + use axum::http::HeaderValue; #[test] - fn injects_agent_id_when_missing_from_tool_arguments() { - let mut request = JsonRpcRequest { + fn request_context_injects_auth_scope_for_tool_calls() { + let request = JsonRpcRequest { jsonrpc: "2.0".to_string(), id: Some(serde_json::json!("1")), method: "tools/call".to_string(), @@ -592,8 +608,39 @@ mod tests { } }), }; + let mut headers = 
HeaderMap::new(); + headers.insert("X-API-Key", HeaderValue::from_static("test-token")); + headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop")); - inject_agent_id(&mut request, "agent-from-header"); + let request = apply_request_context(request, &headers, true); + + assert_eq!( + request + .params + .get("arguments") + .and_then(|value| value.get(tools::INTERNAL_AUTH_SCOPE_ARG)) + .and_then(|value| value.as_str()), + Some(auth::hash_api_key("test-token").as_str()) + ); + } + + #[test] + fn request_context_does_not_inject_query_filter_from_header() { + let request = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: Some(serde_json::json!("1")), + method: "tools/call".to_string(), + params: serde_json::json!({ + "name": "query", + "arguments": { + "query": "editor preferences" + } + }), + }; + let mut headers = HeaderMap::new(); + headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop")); + + let request = apply_request_context(request, &headers, false); assert_eq!( request @@ -601,25 +648,55 @@ mod tests { .get("arguments") .and_then(|value| value.get("agent_id")) .and_then(|value| value.as_str()), - Some("agent-from-header") + None ); } #[test] - fn preserves_explicit_agent_id() { - let mut request = JsonRpcRequest { + fn request_context_injects_store_agent_id_from_header_when_missing() { + let request = JsonRpcRequest { jsonrpc: "2.0".to_string(), id: Some(serde_json::json!("1")), method: "tools/call".to_string(), params: serde_json::json!({ - "name": "query", + "name": "store", + "arguments": { + "content": "prefers dark mode" + } + }), + }; + let mut headers = HeaderMap::new(); + headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop")); + + let request = apply_request_context(request, &headers, false); + + assert_eq!( + request + .params + .get("arguments") + .and_then(|value| value.get("agent_id")) + .and_then(|value| value.as_str()), + Some("codex-desktop") + ); + } + + #[test] + fn 
request_context_preserves_explicit_store_agent_id() { + let request = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: Some(serde_json::json!("1")), + method: "tools/call".to_string(), + params: serde_json::json!({ + "name": "store", "arguments": { "agent_id": "explicit-agent" } }), }; + let mut headers = HeaderMap::new(); + headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop")); - inject_agent_id(&mut request, "agent-from-header"); + let request = apply_request_context(request, &headers, false); assert_eq!( request diff --git a/src/ttl.rs b/src/ttl.rs index 24a3146..ddd3a44 100644 --- a/src/ttl.rs +++ b/src/ttl.rs @@ -1,4 +1,4 @@ -use anyhow::{Result, anyhow}; +use anyhow::{anyhow, Result}; use chrono::{DateTime, Duration, Utc}; pub fn parse_ttl_spec(ttl: &str) -> Result { @@ -25,7 +25,9 @@ pub fn parse_ttl_spec(ttl: &str) -> Result { .parse() .map_err(|_| anyhow!("invalid ttl '{ttl}'. Duration value must be a positive integer"))?; if value <= 0 { - return Err(anyhow!("invalid ttl '{ttl}'. Duration value must be greater than zero")); + return Err(anyhow!( + "invalid ttl '{ttl}'. 
Duration value must be greater than zero" + )); } let total_seconds = value diff --git a/tests/e2e_mcp.rs b/tests/e2e_mcp.rs index 0176b06..246eac4 100644 --- a/tests/e2e_mcp.rs +++ b/tests/e2e_mcp.rs @@ -15,17 +15,21 @@ fn remote_mode() -> bool { } fn api_key() -> Option<String> { - std::env::var("OPENBRAIN_E2E_API_KEY").ok() + std::env::var("OPENBRAIN_E2E_API_KEY") + .ok() .or_else(|| std::env::var("OPENBRAIN__AUTH__API_KEYS").ok()) .map(|keys| keys.split(',').next().unwrap_or("").trim().to_string()) .filter(|k| !k.is_empty()) } fn db_url() -> String { - let host = std::env::var("OPENBRAIN__DATABASE__HOST").unwrap_or_else(|_| "localhost".to_string()); + let host = + std::env::var("OPENBRAIN__DATABASE__HOST").unwrap_or_else(|_| "localhost".to_string()); let port = std::env::var("OPENBRAIN__DATABASE__PORT").unwrap_or_else(|_| "5432".to_string()); - let name = std::env::var("OPENBRAIN__DATABASE__NAME").unwrap_or_else(|_| "openbrain".to_string()); - let user = std::env::var("OPENBRAIN__DATABASE__USER").unwrap_or_else(|_| "openbrain_svc".to_string()); + let name = + std::env::var("OPENBRAIN__DATABASE__NAME").unwrap_or_else(|_| "openbrain".to_string()); + let user = + std::env::var("OPENBRAIN__DATABASE__USER").unwrap_or_else(|_| "openbrain_svc".to_string()); let password = std::env::var("OPENBRAIN__DATABASE__PASSWORD") .unwrap_or_else(|_| "your_secure_password_here".to_string()); @@ -55,7 +59,10 @@ async fn ensure_schema() { .is_some(); if !vector_exists { - if let Err(e) = client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[]).await { + if let Err(e) = client + .execute("CREATE EXTENSION IF NOT EXISTS vector", &[]) + .await + { panic!( "pgvector extension is not available for this PostgreSQL instance: {e}. 
\ Install pgvector for your active PostgreSQL major version, then run: CREATE EXTENSION vector;" @@ -68,6 +75,7 @@ Install pgvector for your active PostgreSQL major version, then run: CREATE EXTE r#" CREATE TABLE IF NOT EXISTS memories ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + auth_scope VARCHAR(255) NOT NULL DEFAULT 'public', agent_id VARCHAR(255) NOT NULL, content TEXT NOT NULL, embedding vector(384) NOT NULL, @@ -76,6 +84,7 @@ Install pgvector for your active PostgreSQL major version, then run: CREATE EXTE created_at TIMESTAMPTZ DEFAULT NOW(), expires_at TIMESTAMPTZ ); + ALTER TABLE memories ADD COLUMN IF NOT EXISTS auth_scope VARCHAR(255) NOT NULL DEFAULT 'public'; ALTER TABLE memories ADD COLUMN IF NOT EXISTS expires_at TIMESTAMPTZ; ALTER TABLE memories ADD COLUMN IF NOT EXISTS tsv tsvector; CREATE OR REPLACE FUNCTION memories_tsv_trigger() @@ -105,6 +114,8 @@ Install pgvector for your active PostgreSQL major version, then run: CREATE EXTE BEFORE INSERT OR UPDATE OF content, keywords ON memories FOR EACH ROW EXECUTE FUNCTION memories_tsv_trigger(); + CREATE INDEX IF NOT EXISTS idx_memories_auth_scope ON memories(auth_scope); + CREATE INDEX IF NOT EXISTS idx_memories_auth_scope_agent ON memories(auth_scope, agent_id); CREATE INDEX IF NOT EXISTS idx_memories_agent ON memories(agent_id); CREATE INDEX IF NOT EXISTS idx_memories_embedding ON memories USING hnsw (embedding vector_cosine_ops); @@ -135,16 +146,40 @@ async fn wait_until_ready(client: &reqwest::Client, base: &str) { panic!("Server did not become ready at {base}/ready within timeout"); } -async fn call_jsonrpc(client: &reqwest::Client, base: &str, request: Value) -> Value { - let mut req_builder = client - .post(format!("{base}/mcp/message")) - .json(&request); - - // Add API key header if available - if let Some(key) = api_key() { +fn apply_request_headers( + mut req_builder: reqwest::RequestBuilder, + api_key_override: Option<&str>, + extra_headers: &[(&str, &str)], +) -> 
reqwest::RequestBuilder { + if let Some(key) = api_key_override { req_builder = req_builder.header("X-API-Key", key); } - + + for (name, value) in extra_headers { + req_builder = req_builder.header(*name, *value); + } + + req_builder +} + +async fn call_jsonrpc(client: &reqwest::Client, base: &str, request: Value) -> Value { + let api_key = api_key(); + call_jsonrpc_with_options(client, base, request, api_key.as_deref(), &[]).await +} + +async fn call_jsonrpc_with_options( + client: &reqwest::Client, + base: &str, + request: Value, + api_key_override: Option<&str>, + extra_headers: &[(&str, &str)], +) -> Value { + let req_builder = apply_request_headers( + client.post(format!("{base}/mcp/message")).json(&request), + api_key_override, + extra_headers, + ); + req_builder .send() .await @@ -159,14 +194,25 @@ async fn call_streamable_jsonrpc( base: &str, request: Value, ) -> reqwest::Response { - let mut req_builder = client - .post(format!("{base}/mcp")) - .header("Accept", "application/json, text/event-stream") - .json(&request); + let api_key = api_key(); + call_streamable_jsonrpc_with_options(client, base, request, api_key.as_deref(), &[]).await +} - if let Some(key) = api_key() { - req_builder = req_builder.header("X-API-Key", key); - } +async fn call_streamable_jsonrpc_with_options( + client: &reqwest::Client, + base: &str, + request: Value, + api_key_override: Option<&str>, + extra_headers: &[(&str, &str)], +) -> reqwest::Response { + let req_builder = apply_request_headers( + client + .post(format!("{base}/mcp")) + .header("Accept", "application/json, text/event-stream") + .json(&request), + api_key_override, + extra_headers, + ); req_builder .send() @@ -176,12 +222,23 @@ async fn call_streamable_jsonrpc( /// Make an authenticated GET request to an MCP endpoint async fn get_mcp_endpoint(client: &reqwest::Client, base: &str, path: &str) -> reqwest::Response { - let mut req_builder = client.get(format!("{base}{path}")); - - if let Some(key) = api_key() { - 
req_builder = req_builder.header("X-API-Key", key); - } - + let api_key = api_key(); + get_mcp_endpoint_with_options(client, base, path, api_key.as_deref(), &[]).await +} + +async fn get_mcp_endpoint_with_options( + client: &reqwest::Client, + base: &str, + path: &str, + api_key_override: Option<&str>, + extra_headers: &[(&str, &str)], +) -> reqwest::Response { + let req_builder = apply_request_headers( + client.get(format!("{base}{path}")), + api_key_override, + extra_headers, + ); + req_builder.send().await.expect(&format!("GET {path}")) } @@ -208,32 +265,12 @@ async fn read_sse_event( return Some((event_type, data_lines.join("\n"))); } - let chunk = response - .chunk() - .await - .expect("read SSE chunk")?; + let chunk = response.chunk().await.expect("read SSE chunk")?; buffer.push_str(std::str::from_utf8(&chunk).expect("SSE chunk should be valid UTF-8")); } } -async fn call_tool( - client: &reqwest::Client, - base: &str, - tool_name: &str, - arguments: Value, -) -> Value { - let request = json!({ - "jsonrpc": "2.0", - "id": Uuid::new_v4().to_string(), - "method": "tools/call", - "params": { - "name": tool_name, - "arguments": arguments - } - }); - - let response = call_jsonrpc(client, base, request).await; - +fn parse_tool_response(tool_name: &str, response: Value) -> Value { if let Some(error) = response.get("error") { panic!("tools/call for '{tool_name}' failed: {error}"); } @@ -250,6 +287,72 @@ async fn call_tool( serde_json::from_str(text_payload).expect("tool text payload to be valid JSON") } +async fn call_tool( + client: &reqwest::Client, + base: &str, + tool_name: &str, + arguments: Value, +) -> Value { + let api_key = api_key(); + call_tool_with_options(client, base, tool_name, arguments, api_key.as_deref(), &[]).await +} + +async fn call_tool_with_options( + client: &reqwest::Client, + base: &str, + tool_name: &str, + arguments: Value, + api_key_override: Option<&str>, + extra_headers: &[(&str, &str)], +) -> Value { + let request = json!({ + 
"jsonrpc": "2.0", + "id": Uuid::new_v4().to_string(), + "method": "tools/call", + "params": { + "name": tool_name, + "arguments": arguments + } + }); + + let response = + call_jsonrpc_with_options(client, base, request, api_key_override, extra_headers).await; + parse_tool_response(tool_name, response) +} + +async fn call_tool_streamable_with_options( + client: &reqwest::Client, + base: &str, + tool_name: &str, + arguments: Value, + api_key_override: Option<&str>, + extra_headers: &[(&str, &str)], +) -> Value { + let request = json!({ + "jsonrpc": "2.0", + "id": Uuid::new_v4().to_string(), + "method": "tools/call", + "params": { + "name": tool_name, + "arguments": arguments + } + }); + + let response = call_streamable_jsonrpc_with_options( + client, + base, + request, + api_key_override, + extra_headers, + ) + .await; + let response: Value = response + .json() + .await + .expect("streamable tool JSON response body"); + parse_tool_response(tool_name, response) +} + #[tokio::test] async fn e2e_store_query_purge_roundtrip() { let base = base_url(); @@ -299,7 +402,7 @@ async fn e2e_store_query_purge_roundtrip() { &base, "query", json!({ - "agent_id": agent_id, + "source_agent_id": agent_id, "query": "What are the user's editor preferences?", "limit": 5, "threshold": 0.0 @@ -347,7 +450,7 @@ async fn e2e_store_query_purge_roundtrip() { &base, "query", json!({ - "agent_id": agent_id, + "source_agent_id": agent_id, "query": "dark theme vim bindings", "limit": 5, "threshold": 0.0 @@ -395,9 +498,18 @@ async fn e2e_transport_tools_list_and_unknown_method() { .filter_map(|t| t.get("name").and_then(Value::as_str)) .collect(); - assert!(tool_names.contains(&"store"), "tools/list should include store"); - assert!(tool_names.contains(&"query"), "tools/list should include query"); - assert!(tool_names.contains(&"purge"), "tools/list should include purge"); + assert!( + tool_names.contains(&"store"), + "tools/list should include store" + ); + assert!( + tool_names.contains(&"query"), 
+ "tools/list should include query" + ); + assert!( + tool_names.contains(&"purge"), + "tools/list should include purge" + ); let unknown_response = call_jsonrpc( &client, @@ -554,7 +666,7 @@ async fn e2e_purge_requires_confirm_flag() { } #[tokio::test] -async fn e2e_query_isolated_by_agent_id() { +async fn e2e_query_shares_memories_across_agent_ids() { let base = base_url(); let client = reqwest::Client::builder() .timeout(Duration::from_secs(20)) @@ -569,8 +681,20 @@ async fn e2e_query_isolated_by_agent_id() { let a_text = format!("A {} prefers dark mode", Uuid::new_v4()); let b_text = format!("B {} prefers light mode", Uuid::new_v4()); - let _ = call_tool(&client, &base, "purge", json!({ "agent_id": agent_a, "confirm": true })).await; - let _ = call_tool(&client, &base, "purge", json!({ "agent_id": agent_b, "confirm": true })).await; + let _ = call_tool( + &client, + &base, + "purge", + json!({ "agent_id": agent_a, "confirm": true }), + ) + .await; + let _ = call_tool( + &client, + &base, + "purge", + json!({ "agent_id": agent_b, "confirm": true }), + ) + .await; let _ = call_tool( &client, @@ -588,12 +712,11 @@ async fn e2e_query_isolated_by_agent_id() { ) .await; - let query_a = call_tool( + let shared_query = call_tool( &client, &base, "query", json!({ - "agent_id": agent_a, "query": "mode preference", "limit": 10, "threshold": 0.0 @@ -601,7 +724,7 @@ async fn e2e_query_isolated_by_agent_id() { ) .await; - let results = query_a + let results = shared_query .get("results") .and_then(Value::as_array) .expect("query results"); @@ -619,11 +742,62 @@ async fn e2e_query_isolated_by_agent_id() { .unwrap_or(false) }); - assert!(has_a, "agent A query should include agent A memory"); - assert!(!has_b, "agent A query must not include agent B memory"); + assert!(has_a, "shared query should include agent A memory"); + assert!(has_b, "shared query should include agent B memory"); - let _ = call_tool(&client, &base, "purge", json!({ "agent_id": agent_a, "confirm": true 
})).await; - let _ = call_tool(&client, &base, "purge", json!({ "agent_id": agent_b, "confirm": true })).await; + let filtered_query = call_tool( + &client, + &base, + "query", + json!({ + "source_agent_id": agent_a, + "query": "mode preference", + "limit": 10, + "threshold": 0.0 + }), + ) + .await; + + let filtered_results = filtered_query + .get("results") + .and_then(Value::as_array) + .expect("filtered query results"); + let filtered_has_a = filtered_results.iter().any(|item| { + item.get("content") + .and_then(Value::as_str) + .map(|s| s == a_text) + .unwrap_or(false) + }); + let filtered_has_b = filtered_results.iter().any(|item| { + item.get("content") + .and_then(Value::as_str) + .map(|s| s == b_text) + .unwrap_or(false) + }); + + assert!( + filtered_has_a, + "filtered query should include agent A memory" + ); + assert!( + !filtered_has_b, + "filtered query should exclude agent B memory" + ); + + let _ = call_tool( + &client, + &base, + "purge", + json!({ "agent_id": agent_a, "confirm": true }), + ) + .await; + let _ = call_tool( + &client, + &base, + "purge", + json!({ "agent_id": agent_b, "confirm": true }), + ) + .await; } #[tokio::test] @@ -670,19 +844,20 @@ async fn e2e_initialized_notification_is_accepted() { wait_until_ready(&client, &base).await; - let mut request = client - .post(format!("{base}/mcp/message")) - .json(&json!({ - "jsonrpc": "2.0", - "method": "notifications/initialized", - "params": {} - })); + let mut request = client.post(format!("{base}/mcp/message")).json(&json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {} + })); if let Some(key) = api_key() { request = request.header("X-API-Key", key); } - let response = request.send().await.expect("initialized notification request"); + let response = request + .send() + .await + .expect("initialized notification request"); assert_eq!( response.status(), reqwest::StatusCode::ACCEPTED, @@ -741,14 +916,12 @@ async fn e2e_sse_session_routes_posted_response() { 
format!("{base}{endpoint}") }; - let mut post_request = client - .post(post_url) - .json(&json!({ - "jsonrpc": "2.0", - "id": "sse-tools-list-1", - "method": "tools/list", - "params": {} - })); + let mut post_request = client.post(post_url).json(&json!({ + "jsonrpc": "2.0", + "id": "sse-tools-list-1", + "method": "tools/list", + "params": {} + })); if let Some(key) = api_key() { post_request = post_request.header("X-API-Key", key); @@ -875,7 +1048,7 @@ async fn e2e_auth_rejection_without_key() { let auth_enabled = std::env::var("OPENBRAIN__AUTH__ENABLED") .map(|v| v == "true") .unwrap_or(false); - + if !auth_enabled { println!("Skipping auth rejection test - OPENBRAIN__AUTH__ENABLED is not true"); return; @@ -933,6 +1106,18 @@ async fn wait_for_status(url: &str, expected_status: reqwest::StatusCode) { panic!("Timed out waiting for {url} to return status {expected_status}"); } +fn spawn_auth_enabled_test_server(port: u16, test_key: &str) -> std::process::Child { + Command::new(env!("CARGO_BIN_EXE_openbrain-mcp")) + .current_dir(env!("CARGO_MANIFEST_DIR")) + .env("OPENBRAIN__SERVER__PORT", port.to_string()) + .env("OPENBRAIN__AUTH__ENABLED", "true") + .env("OPENBRAIN__AUTH__API_KEYS", test_key) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .spawn() + .expect("spawn openbrain-mcp for auth-enabled e2e test") +} + #[tokio::test] async fn e2e_auth_enabled_accepts_test_key() { if remote_mode() { @@ -946,15 +1131,7 @@ async fn e2e_auth_enabled_accepts_test_key() { let base = format!("http://127.0.0.1:{port}"); let test_key = "e2e-test-key-123"; - let mut server = Command::new(env!("CARGO_BIN_EXE_openbrain-mcp")) - .current_dir(env!("CARGO_MANIFEST_DIR")) - .env("OPENBRAIN__SERVER__PORT", port.to_string()) - .env("OPENBRAIN__AUTH__ENABLED", "true") - .env("OPENBRAIN__AUTH__API_KEYS", test_key) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .spawn() - .expect("spawn openbrain-mcp for auth-enabled e2e test"); + let mut server = spawn_auth_enabled_test_server(port, 
test_key); wait_for_status(&format!("{base}/ready"), reqwest::StatusCode::OK).await; @@ -993,7 +1170,10 @@ async fn e2e_auth_enabled_accepts_test_key() { .await .expect("authorized JSON response"); - assert!(authorized.get("error").is_none(), "valid key should not return JSON-RPC error"); + assert!( + authorized.get("error").is_none(), + "valid key should not return JSON-RPC error" + ); assert!( authorized .get("result") @@ -1033,6 +1213,167 @@ async fn e2e_auth_enabled_accepts_test_key() { let _ = server.wait(); } +#[tokio::test] +async fn e2e_same_token_shares_memories_across_agent_ids_and_agent_types() { + if remote_mode() { + println!("Skipping local auth spawn test in OPENBRAIN_E2E_REMOTE mode"); + return; + } + + ensure_schema().await; + + let port = pick_free_port(); + let base = format!("http://127.0.0.1:{port}"); + let test_key = format!("e2e-shared-token-{}", Uuid::new_v4()); + let mut server = spawn_auth_enabled_test_server(port, &test_key); + + wait_for_status(&format!("{base}/ready"), reqwest::StatusCode::OK).await; + + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(20)) + .build() + .expect("reqwest client"); + + let alpha_agent_id = "agent-alpha"; + let beta_agent_id = "agent-beta"; + let query_agent_id = "agent-gamma"; + let alpha_text = format!("Shared token alpha memory {}", Uuid::new_v4()); + let beta_text = format!("Shared token beta memory {}", Uuid::new_v4()); + + let _ = call_tool_with_options( + &client, + &base, + "store", + json!({ + "content": alpha_text, + "metadata": { "suite": "shared-token-cross-agent", "transport": "legacy" } + }), + Some(test_key.as_str()), + &[ + ("X-Agent-ID", alpha_agent_id), + ("X-Agent-Type", "agent-zero"), + ], + ) + .await; + + let _ = call_tool_streamable_with_options( + &client, + &base, + "store", + json!({ + "content": beta_text, + "metadata": { "suite": "shared-token-cross-agent", "transport": "streamable" } + }), + Some(test_key.as_str()), + &[("X-Agent-ID", beta_agent_id), 
("X-Agent-Type", "codex")], + ) + .await; + + let shared_query = call_tool_streamable_with_options( + &client, + &base, + "query", + json!({ + "query": "shared token memory", + "limit": 10, + "threshold": 0.0 + }), + Some(test_key.as_str()), + &[ + ("X-Agent-ID", query_agent_id), + ("X-Agent-Type", "claude-code"), + ], + ) + .await; + + let results = shared_query["results"] + .as_array() + .expect("shared query results"); + + let has_alpha = results.iter().any(|item| { + item.get("content") + .and_then(Value::as_str) + .map(|content| content == alpha_text) + .unwrap_or(false) + }); + let has_beta = results.iter().any(|item| { + item.get("content") + .and_then(Value::as_str) + .map(|content| content == beta_text) + .unwrap_or(false) + }); + let source_ids: Vec<&str> = results + .iter() + .filter_map(|item| item.get("agent_id").and_then(Value::as_str)) + .collect(); + + assert!(has_alpha, "shared query should return alpha agent memory"); + assert!(has_beta, "shared query should return beta agent memory"); + assert!( + source_ids.contains(&alpha_agent_id), + "results should preserve alpha source agent provenance" + ); + assert!( + source_ids.contains(&beta_agent_id), + "results should preserve beta source agent provenance" + ); + + let filtered_query = call_tool_with_options( + &client, + &base, + "query", + json!({ + "source_agent_id": alpha_agent_id, + "query": "shared token memory", + "limit": 10, + "threshold": 0.0 + }), + Some(test_key.as_str()), + &[ + ("X-Agent-ID", query_agent_id), + ("X-Agent-Type", "claude-code"), + ], + ) + .await; + + let filtered_results = filtered_query["results"] + .as_array() + .expect("filtered query results"); + assert!( + filtered_results.iter().any(|item| { + item.get("content") + .and_then(Value::as_str) + .map(|content| content == alpha_text) + .unwrap_or(false) + }), + "source_agent_id filter should retain alpha memory" + ); + assert!( + !filtered_results.iter().any(|item| { + item.get("content") + .and_then(Value::as_str) + 
.map(|content| content == beta_text) + .unwrap_or(false) + }), + "source_agent_id filter should exclude beta memory" + ); + + let _ = call_tool_with_options( + &client, + &base, + "purge", + json!({ "confirm": true }), + Some(test_key.as_str()), + &[ + ("X-Agent-ID", query_agent_id), + ("X-Agent-Type", "claude-code"), + ], + ) + .await; + + let _ = server.kill(); + let _ = server.wait(); +} // ============================================================================= // Batch Store Tests (Issue #12) @@ -1058,14 +1399,20 @@ async fn e2e_batch_store_basic() -> anyhow::Result<()> { ) .await; - let result = call_tool(&client, &base, "batch_store", serde_json::json!({ - "agent_id": agent.clone(), - "entries": [ - {"content": "Fact alpha for batch test"}, - {"content": "Fact beta for batch test"}, - {"content": "Fact gamma for batch test"} - ] - })).await; + let result = call_tool( + &client, + &base, + "batch_store", + serde_json::json!({ + "agent_id": agent.clone(), + "entries": [ + {"content": "Fact alpha for batch test"}, + {"content": "Fact beta for batch test"}, + {"content": "Fact gamma for batch test"} + ] + }), + ) + .await; let _ = call_tool( &client, @@ -1107,7 +1454,10 @@ async fn e2e_batch_store_empty_rejected() -> anyhow::Result<()> { ) .await; - assert!(response.get("error").is_some(), "empty batch_store should return an error"); + assert!( + response.get("error").is_some(), + "empty batch_store should return an error" + ); Ok(()) } @@ -1121,7 +1471,9 @@ async fn e2e_batch_store_exceeds_max() -> anyhow::Result<()> { wait_until_ready(&client, &base).await; - let entries: Vec<serde_json::Value> = (0..51).map(|i| serde_json::json!({"content": format!("Entry {}", i)})).collect(); + let entries: Vec<serde_json::Value> = (0..51) + .map(|i| serde_json::json!({"content": format!("Entry {}", i)})) + .collect(); let response = call_jsonrpc( &client, &base, @@ -1139,7 +1491,10 @@ async fn e2e_batch_store_exceeds_max() -> anyhow::Result<()> { ) .await; - assert!(response.get("error").is_some(), 
"oversized batch_store should return an error"); + assert!( + response.get("error").is_some(), + "oversized batch_store should return an error" + ); Ok(()) } @@ -1170,7 +1525,10 @@ async fn e2e_batch_store_missing_content() -> anyhow::Result<()> { ) .await; - assert!(response.get("error").is_some(), "missing batch entry content should return an error"); + assert!( + response.get("error").is_some(), + "missing batch entry content should return an error" + ); Ok(()) } @@ -1228,10 +1586,16 @@ async fn e2e_existing_store_unchanged() -> anyhow::Result<()> { ) .await; - let result = call_tool(&client, &base, "store", serde_json::json!({ - "agent_id": agent.clone(), - "content": "Original store still works" - })).await; + let result = call_tool( + &client, + &base, + "store", + serde_json::json!({ + "agent_id": agent.clone(), + "content": "Original store still works" + }), + ) + .await; let _ = call_tool( &client, @@ -1246,7 +1610,6 @@ async fn e2e_existing_store_unchanged() -> anyhow::Result<()> { Ok(()) } - // ============================================================================= // Deduplication Tests (Issue #14) // ============================================================================= @@ -1263,7 +1626,10 @@ async fn e2e_store_deduplicates_and_merges_metadata() -> anyhow::Result<()> { wait_until_ready(&client, &base).await; let agent = format!("dedup_{}", uuid::Uuid::new_v4()); - let content = format!("Dedup fact {} prefers concise replies", uuid::Uuid::new_v4()); + let content = format!( + "Dedup fact {} prefers concise replies", + uuid::Uuid::new_v4() + ); let _ = call_tool( &client, &base, @@ -1272,25 +1638,35 @@ async fn e2e_store_deduplicates_and_merges_metadata() -> anyhow::Result<()> { ) .await; - let first = call_tool(&client, &base, "store", json!({ - "agent_id": agent.clone(), - "content": content.clone(), - "metadata": { - "source": "first", - "keep": true, - "override": "old" - } - })) + let first = call_tool( + &client, + &base, + "store", + 
json!({ + "agent_id": agent.clone(), + "content": content.clone(), + "metadata": { + "source": "first", + "keep": true, + "override": "old" + } + }), + ) .await; assert_eq!(first["deduplicated"].as_bool(), Some(false)); - let first_query = call_tool(&client, &base, "query", json!({ - "agent_id": agent.clone(), - "query": content.clone(), - "limit": 5, - "threshold": 0.0 - })) + let first_query = call_tool( + &client, + &base, + "query", + json!({ + "source_agent_id": agent.clone(), + "query": content.clone(), + "limit": 5, + "threshold": 0.0 + }), + ) .await; let first_created_at = first_query["results"] .as_array() @@ -1302,25 +1678,35 @@ async fn e2e_store_deduplicates_and_merges_metadata() -> anyhow::Result<()> { tokio::time::sleep(Duration::from_millis(1100)).await; - let second = call_tool(&client, &base, "store", json!({ - "agent_id": agent.clone(), - "content": content.clone(), - "metadata": { - "override": "new", - "second": true - } - })) + let second = call_tool( + &client, + &base, + "store", + json!({ + "agent_id": agent.clone(), + "content": content.clone(), + "metadata": { + "override": "new", + "second": true + } + }), + ) .await; assert_eq!(second["deduplicated"].as_bool(), Some(true)); assert_eq!(second["id"], first["id"]); - let query = call_tool(&client, &base, "query", json!({ - "agent_id": agent.clone(), - "query": content.clone(), - "limit": 5, - "threshold": 0.0 - })) + let query = call_tool( + &client, + &base, + "query", + json!({ + "source_agent_id": agent.clone(), + "query": content.clone(), + "limit": 5, + "threshold": 0.0 + }), + ) .await; assert_eq!(query["count"].as_u64(), Some(1)); @@ -1334,9 +1720,7 @@ async fn e2e_store_deduplicates_and_merges_metadata() -> anyhow::Result<()> { assert_eq!(stored["metadata"]["override"], "new"); assert_eq!(stored["metadata"]["second"], true); - let second_created_at = stored["created_at"] - .as_str() - .expect("second created_at"); + let second_created_at = 
stored["created_at"].as_str().expect("second created_at"); assert!( second_created_at > first_created_at.as_str(), "deduplicated write should refresh created_at" @@ -1368,26 +1752,60 @@ async fn e2e_store_dedup_is_agent_scoped() -> anyhow::Result<()> { let agent_b = format!("dedup_scope_b_{}", uuid::Uuid::new_v4()); let content = format!("Shared cross-agent fact {}", uuid::Uuid::new_v4()); - let _ = call_tool(&client, &base, "purge", json!({ "agent_id": agent_a.clone(), "confirm": true })).await; - let _ = call_tool(&client, &base, "purge", json!({ "agent_id": agent_b.clone(), "confirm": true })).await; - - let first = call_tool(&client, &base, "store", json!({ - "agent_id": agent_a.clone(), - "content": content.clone() - })) + let _ = call_tool( + &client, + &base, + "purge", + json!({ "agent_id": agent_a.clone(), "confirm": true }), + ) .await; - let second = call_tool(&client, &base, "store", json!({ - "agent_id": agent_b.clone(), - "content": content.clone() - })) + let _ = call_tool( + &client, + &base, + "purge", + json!({ "agent_id": agent_b.clone(), "confirm": true }), + ) + .await; + + let first = call_tool( + &client, + &base, + "store", + json!({ + "agent_id": agent_a.clone(), + "content": content.clone() + }), + ) + .await; + let second = call_tool( + &client, + &base, + "store", + json!({ + "agent_id": agent_b.clone(), + "content": content.clone() + }), + ) .await; assert_eq!(first["deduplicated"].as_bool(), Some(false)); assert_eq!(second["deduplicated"].as_bool(), Some(false)); assert_ne!(first["id"], second["id"]); - let _ = call_tool(&client, &base, "purge", json!({ "agent_id": agent_a, "confirm": true })).await; - let _ = call_tool(&client, &base, "purge", json!({ "agent_id": agent_b, "confirm": true })).await; + let _ = call_tool( + &client, + &base, + "purge", + json!({ "agent_id": agent_a, "confirm": true }), + ) + .await; + let _ = call_tool( + &client, + &base, + "purge", + json!({ "agent_id": agent_b, "confirm": true }), + ) + .await; Ok(()) 
} @@ -1413,19 +1831,24 @@ async fn e2e_batch_store_deduplicates_within_batch() -> anyhow::Result<()> { ) .await; - let result = call_tool(&client, &base, "batch_store", json!({ - "agent_id": agent.clone(), - "entries": [ - { - "content": content.clone(), - "metadata": { "source": "first", "keep": "yes" } - }, - { - "content": content.clone(), - "metadata": { "source": "second", "merged": "yes" } - } - ] - })) + let result = call_tool( + &client, + &base, + "batch_store", + json!({ + "agent_id": agent.clone(), + "entries": [ + { + "content": content.clone(), + "metadata": { "source": "first", "keep": "yes" } + }, + { + "content": content.clone(), + "metadata": { "source": "second", "merged": "yes" } + } + ] + }), + ) .await; let results = result["results"].as_array().expect("batch results"); @@ -1436,12 +1859,17 @@ async fn e2e_batch_store_deduplicates_within_batch() -> anyhow::Result<()> { assert_eq!(results[1]["status"], "deduplicated"); assert_eq!(results[0]["id"], results[1]["id"]); - let query = call_tool(&client, &base, "query", json!({ - "agent_id": agent.clone(), - "query": content.clone(), - "limit": 5, - "threshold": 0.0 - })) + let query = call_tool( + &client, + &base, + "query", + json!({ + "source_agent_id": agent.clone(), + "query": content.clone(), + "limit": 5, + "threshold": 0.0 + }), + ) .await; assert_eq!(query["count"].as_u64(), Some(1));