Scope memories by API token and add shared-token e2e coverage

This commit is contained in:
Agent Zero
2026-04-01 23:30:58 -04:00
parent 98baa27c90
commit 026ae27366
17 changed files with 1096 additions and 428 deletions

View File

@@ -7,7 +7,9 @@ memory.
## External Memory System ## External Memory System
- Use the exact MCP tools `openbrain.store`, `openbrain.query`, and `openbrain.purge` - Use the exact MCP tools `openbrain.store`, `openbrain.query`, and `openbrain.purge`
- Always use the exact `agent_id` value `openbrain` - Memory visibility is determined by the API token in the MCP client config, not by `agent_id`
- On `openbrain.store`, use `agent_id` only as a provenance label for the storing agent when that label is useful
- On `openbrain.query`, do not send `agent_id` for normal retrieval; use `source_agent_id` only when you intentionally want to filter by source agent
- Do not hardcode live credentials into the repository - Do not hardcode live credentials into the repository
- Before answering requests that may depend on prior sessions, project history, user preferences, ongoing work, named people, named projects, deployments, debugging history, or handoff context, call `openbrain.query` first - Before answering requests that may depend on prior sessions, project history, user preferences, ongoing work, named people, named projects, deployments, debugging history, or handoff context, call `openbrain.query` first
- Use noun-heavy search phrases with exact names, tool names, acronyms, hostnames, and document names - Use noun-heavy search phrases with exact names, tool names, acronyms, hostnames, and document names
@@ -19,7 +21,7 @@ memory.
- Use metadata when helpful for tags such as `category`, `project`, `source`, `status`, `aliases`, and `confidence` - Use metadata when helpful for tags such as `category`, `project`, `source`, `status`, `aliases`, and `confidence`
- If `openbrain.query` returns no useful result, state that OpenBrain has no stored context for that topic, answer from general reasoning if possible, and ask one focused follow-up if the missing information is durable and useful - If `openbrain.query` returns no useful result, state that OpenBrain has no stored context for that topic, answer from general reasoning if possible, and ask one focused follow-up if the missing information is durable and useful
- If retrieved memories conflict, ask which fact is current, then store the corrected source-of-truth fact - If retrieved memories conflict, ask which fact is current, then store the corrected source-of-truth fact
- Use `openbrain.purge` cautiously because it is coarse-grained; it deletes by `agent_id` and optionally before a timestamp, not by individual memory ID - Use `openbrain.purge` cautiously because it is coarse-grained; it deletes memories visible to the current API token and can optionally narrow by `source_agent_id` and `before`, not by individual memory ID
- For ordinary corrections, prefer storing the new source-of-truth fact instead of purging unless cleanup or reset is explicitly requested - For ordinary corrections, prefer storing the new source-of-truth fact instead of purging unless cleanup or reset is explicitly requested
## Agent Identity & Source Tagging ## Agent Identity & Source Tagging

View File

@@ -10,7 +10,7 @@ OpenBrain is a Model Context Protocol (MCP) server that provides AI agents with
- 🏠 **Local Embeddings**: No external API calls - uses ONNX runtime with all-MiniLM-L6-v2 - 🏠 **Local Embeddings**: No external API calls - uses ONNX runtime with all-MiniLM-L6-v2
- 🐘 **PostgreSQL + pgvector**: Production-grade vector storage with HNSW indexing - 🐘 **PostgreSQL + pgvector**: Production-grade vector storage with HNSW indexing
- 🔌 **MCP Protocol**: Streamable HTTP plus legacy HTTP+SSE compatibility - 🔌 **MCP Protocol**: Streamable HTTP plus legacy HTTP+SSE compatibility
- 🔐 **Multi-Agent Support**: Isolated memory namespaces per agent - 🔐 **Shared Memory by Token**: Agents using the same API token share memory visibility while retaining source-agent provenance
- ♻️ **Deduplicated Ingest**: Near-duplicate facts are merged instead of stored repeatedly - ♻️ **Deduplicated Ingest**: Near-duplicate facts are merged instead of stored repeatedly
-**High Performance**: Rust implementation with async I/O -**High Performance**: Rust implementation with async I/O
@@ -20,8 +20,8 @@ OpenBrain is a Model Context Protocol (MCP) server that provides AI agents with
|------|-------------| |------|-------------|
| `store` | Store a memory with automatic embedding generation, optional TTL, and automatic deduplication | | `store` | Store a memory with automatic embedding generation, optional TTL, and automatic deduplication |
| `batch_store` | Store 1-50 memories atomically in a single call with the same deduplication rules | | `batch_store` | Store 1-50 memories atomically in a single call with the same deduplication rules |
| `query` | Search memories by semantic similarity | | `query` | Search memories by semantic similarity, optionally filtering by source agent |
| `purge` | Delete memories by agent ID or time range | | `purge` | Delete memories visible to the current API token, optionally filtering by source agent or time range |
## Quick Start ## Quick Start
@@ -127,8 +127,8 @@ If you want prod e2e coverage without leaving a standing CI key on the server, t
### Deduplication on Ingest ### Deduplication on Ingest
OpenBrain checks every `store` and `batch_store` write for an existing memory in OpenBrain checks every `store` and `batch_store` write for an existing memory in
the same `agent_id` namespace whose vector similarity meets the configured the same API-token scope and same source `agent_id` whose vector similarity
dedup threshold. meets the configured dedup threshold.
Default behavior: Default behavior:
@@ -165,10 +165,12 @@ Recommended target file in A0:
### External Memory System ### External Memory System
- **Memory Boundary**: Treat OpenBrain as an external MCP long-term memory system, never as internal context, reasoning scratchpad, or built-in memory - **Memory Boundary**: Treat OpenBrain as an external MCP long-term memory system, never as internal context, reasoning scratchpad, or built-in memory
- **Tool Contract**: Use the exact MCP tools `openbrain.store`, `openbrain.query`, and `openbrain.purge` - **Tool Contract**: Use the exact MCP tools `openbrain.store`, `openbrain.query`, and `openbrain.purge`
- **Namespace Discipline**: Always use the exact `agent_id` value `openbrain` - **Shared Access Model**: Memory visibility is determined by the API token in the MCP client config, not by `agent_id`
- **Source Labels**: Use `agent_id` on `openbrain.store` and `openbrain.batch_store` only as a provenance label for the storing agent when that label is useful
- **EXTRAS First**: Before calling `openbrain.query`, check the `[EXTRAS]` section for pre-loaded memories or handoff facts related to the same topic. If the needed context is already present, do not query OpenBrain again. - **EXTRAS First**: Before calling `openbrain.query`, check the `[EXTRAS]` section for pre-loaded memories or handoff facts related to the same topic. If the needed context is already present, do not query OpenBrain again.
- **Session Cache**: If the same topic was already queried earlier in the current conversation and the result is still in context, reuse that result instead of querying again unless the user references new external information or the prior result is clearly insufficient. - **Session Cache**: If the same topic was already queried earlier in the current conversation and the result is still in context, reuse that result instead of querying again unless the user references new external information or the prior result is clearly insufficient.
- **Retrieval First**: Before answering requests that may depend on prior sessions, project history, user preferences, ongoing work, named people, named projects, deployments, debugging history, or handoff context, call `openbrain.query` only when `[EXTRAS]` and the current conversation do not already provide the needed context. - **Retrieval First**: Before answering requests that may depend on prior sessions, project history, user preferences, ongoing work, named people, named projects, deployments, debugging history, or handoff context, call `openbrain.query` only when `[EXTRAS]` and the current conversation do not already provide the needed context.
- **Query Scope**: Do not send `agent_id` with `openbrain.query` for normal retrieval. Use `source_agent_id` only when you intentionally want to filter results by the agent that originally stored them.
- **Query Strategy**: Use noun-heavy search phrases with exact names, tool names, acronyms, hostnames, and document names; query first with `(threshold=0.15, limit=8)`, then retry once with `(threshold=0.05, limit=10)` only if the first pass returns zero useful results - **Query Strategy**: Use noun-heavy search phrases with exact names, tool names, acronyms, hostnames, and document names; query first with `(threshold=0.15, limit=8)`, then retry once with `(threshold=0.05, limit=10)` only if the first pass returns zero useful results
- **Storage Strategy**: When a durable fact is established, call `openbrain.store` without asking permission and store one atomic fact whenever possible - **Storage Strategy**: When a durable fact is established, call `openbrain.store` without asking permission and store one atomic fact whenever possible
- **Storage Content Rules**: Store durable, high-value facts such as preferences, project status, project decisions, environment details, recurring workflows, handoff notes, stable constraints, and correction facts - **Storage Content Rules**: Store durable, high-value facts such as preferences, project status, project decisions, environment details, recurring workflows, handoff notes, stable constraints, and correction facts
@@ -177,7 +179,7 @@ Recommended target file in A0:
- **Metadata Usage**: Use metadata when helpful for tags such as `category`, `project`, `source`, `status`, `aliases`, and `confidence` - **Metadata Usage**: Use metadata when helpful for tags such as `category`, `project`, `source`, `status`, `aliases`, and `confidence`
- **Miss Handling**: If `openbrain.query` returns no useful result, state that OpenBrain has no stored context for that topic, answer from general reasoning if possible, and ask one focused follow-up if the missing information is durable and useful - **Miss Handling**: If `openbrain.query` returns no useful result, state that OpenBrain has no stored context for that topic, answer from general reasoning if possible, and ask one focused follow-up if the missing information is durable and useful
- **Conflict Handling**: If retrieved memories conflict, ask which fact is current, then store the corrected source-of-truth fact - **Conflict Handling**: If retrieved memories conflict, ask which fact is current, then store the corrected source-of-truth fact
- **Purge Constraint**: Use `openbrain.purge` cautiously because it is coarse-grained; it deletes by `agent_id` and optionally before a timestamp, not by individual memory ID - **Purge Constraint**: Use `openbrain.purge` cautiously because it is coarse-grained; it deletes memories visible to the current API token and can optionally narrow by `source_agent_id` and `before`, but not by individual memory ID
- **Correction Policy**: For ordinary corrections, prefer storing the new source-of-truth fact instead of purging unless the user explicitly asks for cleanup or reset - **Correction Policy**: For ordinary corrections, prefer storing the new source-of-truth fact instead of purging unless the user explicitly asks for cleanup or reset
- **Source Tagging**: Every `openbrain.store` call MUST include `"source_agent"` in metadata, set to the Agent Instance ID defined in the active project's identity file (e.g., `"source_agent": "ingwaz-a0"`). This enables tracing facts back to the originating agent instance. - **Source Tagging**: Every `openbrain.store` call MUST include `"source_agent"` in metadata, set to the Agent Instance ID defined in the active project's identity file (e.g., `"source_agent": "ingwaz-a0"`). This enables tracing facts back to the originating agent instance.
``` ```
@@ -259,17 +261,22 @@ legacy SSE endpoints for older MCP clients that still use the deprecated
2024-11-05 HTTP+SSE transport. 2024-11-05 HTTP+SSE transport.
Header roles: Header roles:
- `X-Agent-ID` is the memory namespace. Keep this stable if multiple clients - If two clients use the same API token, they can read and write the same
should share the same OpenBrain memories. OpenBrain memories.
- `X-Agent-Type` is an optional client profile label for logging and config - `X-Agent-ID` is an optional source-agent label for logs and store provenance.
clarity, such as `agent-zero` or `codex`. It does not control memory visibility.
- `X-Agent-Type` is an optional client-program label such as `agent-zero`,
`codex`, or `claude-code`. It does not select transport server-side; the URL
path does that.
- `agent_id` on `store` and `batch_store` is provenance. `source_agent_id` on
`query` and `purge` is an optional provenance filter.
### Example: Codex Configuration ### Example: Codex Configuration
```toml ```toml
[mcp_servers.openbrain] [mcp_servers.openbrain]
url = "https://memory.example.com/mcp" url = "https://memory.example.com/mcp"
http_headers = { "X-API-Key" = "YOUR_OPENBRAIN_API_KEY", "X-Agent-ID" = "openbrain", "X-Agent-Type" = "codex" } http_headers = { "X-API-Key" = "YOUR_OPENBRAIN_API_KEY", "X-Agent-ID" = "codex-desktop", "X-Agent-Type" = "codex" }
``` ```
### Example: Agent Zero Configuration ### Example: Agent Zero Configuration
@@ -281,7 +288,7 @@ http_headers = { "X-API-Key" = "YOUR_OPENBRAIN_API_KEY", "X-Agent-ID" = "openbra
"url": "https://memory.example.com/mcp/sse", "url": "https://memory.example.com/mcp/sse",
"headers": { "headers": {
"X-API-Key": "YOUR_OPENBRAIN_API_KEY", "X-API-Key": "YOUR_OPENBRAIN_API_KEY",
"X-Agent-ID": "openbrain", "X-Agent-ID": "agent-zero",
"X-Agent-Type": "agent-zero" "X-Agent-Type": "agent-zero"
} }
} }
@@ -290,7 +297,8 @@ http_headers = { "X-API-Key" = "YOUR_OPENBRAIN_API_KEY", "X-Agent-ID" = "openbra
``` ```
Agent Zero should keep using the legacy HTTP+SSE transport unless and until its Agent Zero should keep using the legacy HTTP+SSE transport unless and until its
client runtime supports streamable HTTP. Codex should use `/mcp`. client runtime supports streamable HTTP. Codex should use `/mcp`. If both
clients use the same API token, they already share memory visibility.
### Example: Store a Memory ### Example: Store a Memory
@@ -303,7 +311,7 @@ client runtime supports streamable HTTP. Codex should use `/mcp`.
"name": "store", "name": "store",
"arguments": { "arguments": {
"content": "The user prefers dark mode and uses vim keybindings", "content": "The user prefers dark mode and uses vim keybindings",
"agent_id": "assistant-1", "agent_id": "agent-zero",
"ttl": "7d", "ttl": "7d",
"metadata": {"source": "preferences"} "metadata": {"source": "preferences"}
} }
@@ -322,7 +330,6 @@ client runtime supports streamable HTTP. Codex should use `/mcp`.
"name": "query", "name": "query",
"arguments": { "arguments": {
"query": "What are the user's editor preferences?", "query": "What are the user's editor preferences?",
"agent_id": "assistant-1",
"limit": 5, "limit": 5,
"threshold": 0.6 "threshold": 0.6
} }
@@ -340,7 +347,7 @@ client runtime supports streamable HTTP. Codex should use `/mcp`.
"params": { "params": {
"name": "batch_store", "name": "batch_store",
"arguments": { "arguments": {
"agent_id": "assistant-1", "agent_id": "codex",
"entries": [ "entries": [
{ {
"content": "The user prefers dark mode", "content": "The user prefers dark mode",

View File

@@ -0,0 +1,8 @@
ALTER TABLE memories
ADD COLUMN IF NOT EXISTS auth_scope VARCHAR(255) NOT NULL DEFAULT 'public';
CREATE INDEX IF NOT EXISTS idx_memories_auth_scope
ON memories (auth_scope);
CREATE INDEX IF NOT EXISTS idx_memories_auth_scope_agent
ON memories (auth_scope, agent_id);

View File

@@ -4,7 +4,7 @@
use axum::{ use axum::{
extract::{Request, State}, extract::{Request, State},
http::{HeaderMap, StatusCode, header::AUTHORIZATION}, http::{header::AUTHORIZATION, HeaderMap, StatusCode},
middleware::Next, middleware::Next,
response::Response, response::Response,
}; };
@@ -14,6 +14,8 @@ use tracing::warn;
use crate::AppState; use crate::AppState;
pub const PUBLIC_AUTH_SCOPE: &str = "public";
/// Hash an API key for secure comparison /// Hash an API key for secure comparison
pub fn hash_api_key(key: &str) -> String { pub fn hash_api_key(key: &str) -> String {
let mut hasher = Sha256::new(); let mut hasher = Sha256::new();
@@ -99,24 +101,25 @@ pub fn get_optional_agent_type(headers: &HeaderMap) -> Option<String> {
.map(ToOwned::to_owned) .map(ToOwned::to_owned)
} }
/// Extract agent ID from request headers or default pub fn get_auth_scope(headers: &HeaderMap, auth_enabled: bool) -> String {
pub fn get_agent_id(request: &Request) -> String { if !auth_enabled {
get_optional_agent_id(request.headers()) return PUBLIC_AUTH_SCOPE.to_string();
.unwrap_or_else(|| "default".to_string()) }
extract_api_key(headers)
.map(|key| hash_api_key(&key))
.unwrap_or_else(|| PUBLIC_AUTH_SCOPE.to_string())
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use axum::http::{HeaderValue, header::AUTHORIZATION}; use axum::http::{header::AUTHORIZATION, HeaderValue};
#[test] #[test]
fn extracts_api_key_from_bearer_header() { fn extracts_api_key_from_bearer_header() {
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert( headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer test-token"));
AUTHORIZATION,
HeaderValue::from_static("Bearer test-token"),
);
assert_eq!(extract_api_key(&headers).as_deref(), Some("test-token")); assert_eq!(extract_api_key(&headers).as_deref(), Some("test-token"));
} }
@@ -137,9 +140,21 @@ mod tests {
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert("X-Agent-Type", HeaderValue::from_static("codex")); headers.insert("X-Agent-Type", HeaderValue::from_static("codex"));
assert_eq!( assert_eq!(get_optional_agent_type(&headers).as_deref(), Some("codex"));
get_optional_agent_type(&headers).as_deref(), }
Some("codex")
); #[test]
fn derives_auth_scope_from_api_key_when_enabled() {
let mut headers = HeaderMap::new();
headers.insert("X-API-Key", HeaderValue::from_static("test-token"));
assert_eq!(get_auth_scope(&headers, true), hash_api_key("test-token"));
}
#[test]
fn uses_public_scope_when_auth_disabled() {
let headers = HeaderMap::new();
assert_eq!(get_auth_scope(&headers, false), PUBLIC_AUTH_SCOPE);
} }
} }

View File

@@ -94,29 +94,50 @@ where
} }
match Option::<StringOrVec>::deserialize(deserializer)? { match Option::<StringOrVec>::deserialize(deserializer)? {
Some(StringOrVec::String(s)) => { Some(StringOrVec::String(s)) => Ok(s
Ok(s.split(',') .split(',')
.map(|k| k.trim().to_string()) .map(|k| k.trim().to_string())
.filter(|k| !k.is_empty()) .filter(|k| !k.is_empty())
.collect()) .collect()),
}
Some(StringOrVec::Vec(v)) => Ok(v), Some(StringOrVec::Vec(v)) => Ok(v),
None => Ok(Vec::new()), None => Ok(Vec::new()),
} }
} }
// Default value functions // Default value functions
fn default_host() -> String { "0.0.0.0".to_string() } fn default_host() -> String {
fn default_port() -> u16 { 3100 } "0.0.0.0".to_string()
fn default_db_port() -> u16 { 5432 } }
fn default_pool_size() -> usize { 10 } fn default_port() -> u16 {
fn default_model_path() -> String { "models/all-MiniLM-L6-v2".to_string() } 3100
fn default_embedding_dim() -> usize { 384 } }
fn default_vector_weight() -> f32 { 0.6 } fn default_db_port() -> u16 {
fn default_text_weight() -> f32 { 0.4 } 5432
fn default_dedup_threshold() -> f32 { 0.90 } }
fn default_cleanup_interval_seconds() -> u64 { 300 } fn default_pool_size() -> usize {
fn default_auth_enabled() -> bool { false } 10
}
fn default_model_path() -> String {
"models/all-MiniLM-L6-v2".to_string()
}
fn default_embedding_dim() -> usize {
384
}
fn default_vector_weight() -> f32 {
0.6
}
fn default_text_weight() -> f32 {
0.4
}
fn default_dedup_threshold() -> f32 {
0.90
}
fn default_cleanup_interval_seconds() -> u64 {
300
}
fn default_auth_enabled() -> bool {
false
}
impl Config { impl Config {
/// Load configuration from environment variables /// Load configuration from environment variables

103
src/db.rs
View File

@@ -9,9 +9,9 @@ use tokio_postgres::NoTls;
use tracing::info; use tracing::info;
use uuid::Uuid; use uuid::Uuid;
use crate::config::DatabaseConfig;
use serde::Serialize; use serde::Serialize;
use serde_json::{Map, Value}; use serde_json::{Map, Value};
use crate::config::DatabaseConfig;
/// Database wrapper with connection pool /// Database wrapper with connection pool
#[derive(Clone)] #[derive(Clone)]
@@ -75,6 +75,7 @@ fn merge_metadata(existing: &Value, incoming: &Value) -> Value {
async fn find_dedup_match<C>( async fn find_dedup_match<C>(
client: &C, client: &C,
auth_scope: &str,
agent_id: &str, agent_id: &str,
embedding: &Vector, embedding: &Vector,
threshold: f64, threshold: f64,
@@ -87,13 +88,14 @@ where
r#" r#"
SELECT id, metadata, expires_at SELECT id, metadata, expires_at
FROM memories FROM memories
WHERE agent_id = $1 WHERE auth_scope = $1
AND agent_id = $2
AND (expires_at IS NULL OR expires_at > NOW()) AND (expires_at IS NULL OR expires_at > NOW())
AND (1 - (embedding <=> $2)) >= $3 AND (1 - (embedding <=> $3)) >= $4
ORDER BY (1 - (embedding <=> $2)) DESC, created_at DESC ORDER BY (1 - (embedding <=> $3)) DESC, created_at DESC
LIMIT 1 LIMIT 1
"#, "#,
&[&agent_id, embedding, &threshold], &[&auth_scope, &agent_id, embedding, &threshold],
) )
.await .await
.context("Failed to check for duplicate memory")?; .context("Failed to check for duplicate memory")?;
@@ -120,13 +122,19 @@ impl Database {
.context("Failed to create database pool")?; .context("Failed to create database pool")?;
// Test connection // Test connection
let client = pool.get().await.context("Failed to get database connection")?; let client = pool
.get()
.await
.context("Failed to get database connection")?;
client client
.simple_query("SELECT 1") .simple_query("SELECT 1")
.await .await
.context("Failed to execute test query")?; .context("Failed to execute test query")?;
info!("Database connection pool created with {} connections", config.pool_size); info!(
"Database connection pool created with {} connections",
config.pool_size
);
Ok(Self { pool }) Ok(Self { pool })
} }
@@ -134,6 +142,7 @@ impl Database {
/// Store a memory record /// Store a memory record
pub async fn store_memory( pub async fn store_memory(
&self, &self,
auth_scope: &str,
agent_id: &str, agent_id: &str,
content: &str, content: &str,
embedding: &[f32], embedding: &[f32],
@@ -146,7 +155,9 @@ impl Database {
let vector = Vector::from(embedding.to_vec()); let vector = Vector::from(embedding.to_vec());
let dedup_threshold = dedup_threshold as f64; let dedup_threshold = dedup_threshold as f64;
if let Some(existing) = find_dedup_match(&client, agent_id, &vector, dedup_threshold).await? { if let Some(existing) =
find_dedup_match(&client, auth_scope, agent_id, &vector, dedup_threshold).await?
{
let merged_metadata = merge_metadata(&existing.metadata, &metadata); let merged_metadata = merge_metadata(&existing.metadata, &metadata);
let refreshed_expires_at = expires_at.or(existing.expires_at); let refreshed_expires_at = expires_at.or(existing.expires_at);
@@ -176,10 +187,10 @@ impl Database {
client client
.execute( .execute(
r#" r#"
INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at) INSERT INTO memories (id, auth_scope, agent_id, content, embedding, keywords, metadata, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
"#, "#,
&[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at], &[&id, &auth_scope, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
) )
.await .await
.context("Failed to store memory")?; .context("Failed to store memory")?;
@@ -194,7 +205,8 @@ impl Database {
/// Query memories by vector similarity /// Query memories by vector similarity
pub async fn query_memories( pub async fn query_memories(
&self, &self,
agent_id: &str, auth_scope: &str,
source_agent_id: Option<&str>,
query_text: &str, query_text: &str,
embedding: &[f32], embedding: &[f32],
limit: i64, limit: i64,
@@ -230,7 +242,8 @@ impl Database {
END AS text_score END AS text_score
FROM memories FROM memories
CROSS JOIN search_query CROSS JOIN search_query
WHERE memories.agent_id = $3 WHERE memories.auth_scope = $3
AND ($4::text IS NULL OR memories.agent_id = $4)
AND (memories.expires_at IS NULL OR memories.expires_at > NOW()) AND (memories.expires_at IS NULL OR memories.expires_at > NOW())
), ),
ranked AS ( ranked AS (
@@ -251,18 +264,19 @@ impl Database {
text_score, text_score,
CASE CASE
WHEN has_text_match = 1 WHEN has_text_match = 1
THEN (($5 * vector_score) + ($6 * text_score))::real THEN (($6 * vector_score) + ($7 * text_score))::real
ELSE vector_score ELSE vector_score
END AS hybrid_score END AS hybrid_score
FROM ranked FROM ranked
WHERE vector_score >= $4 OR text_score > 0 WHERE vector_score >= $5 OR text_score > 0
ORDER BY hybrid_score DESC, vector_score DESC ORDER BY hybrid_score DESC, vector_score DESC
LIMIT $7 LIMIT $8
"#, "#,
&[ &[
&vector, &vector,
&query_text, &query_text,
&agent_id, &auth_scope,
&source_agent_id,
&threshold, &threshold,
&vector_weight, &vector_weight,
&text_weight, &text_weight,
@@ -296,37 +310,47 @@ impl Database {
Ok(matches) Ok(matches)
} }
/// Delete memories by agent_id and optional filters /// Delete memories visible to an auth scope with an optional provenance filter
pub async fn purge_memories( pub async fn purge_memories(
&self, &self,
agent_id: &str, auth_scope: &str,
source_agent_id: Option<&str>,
before: Option<chrono::DateTime<chrono::Utc>>, before: Option<chrono::DateTime<chrono::Utc>>,
) -> Result<u64> { ) -> Result<u64> {
let client = self.pool.get().await?; let client = self.pool.get().await?;
let count = if let Some(before_ts) = before { let count = client
client .execute(
.execute( r#"
"DELETE FROM memories WHERE agent_id = $1 AND created_at < $2", DELETE FROM memories
&[&agent_id, &before_ts], WHERE auth_scope = $1
) AND ($2::text IS NULL OR agent_id = $2)
.await? AND ($3::timestamptz IS NULL OR created_at < $3)
} else { "#,
client &[&auth_scope, &source_agent_id, &before],
.execute("DELETE FROM memories WHERE agent_id = $1", &[&agent_id]) )
.await? .await?;
};
Ok(count) Ok(count)
} }
/// Get memory count for an agent /// Get memory count for a token-visible scope and optional provenance filter
pub async fn count_memories(&self, agent_id: &str) -> Result<i64> { pub async fn count_memories(
&self,
auth_scope: &str,
source_agent_id: Option<&str>,
) -> Result<i64> {
let client = self.pool.get().await?; let client = self.pool.get().await?;
let row = client let row = client
.query_one( .query_one(
"SELECT COUNT(*) as count FROM memories WHERE agent_id = $1 AND (expires_at IS NULL OR expires_at > NOW())", r#"
&[&agent_id], SELECT COUNT(*) as count
FROM memories
WHERE auth_scope = $1
AND ($2::text IS NULL OR agent_id = $2)
AND (expires_at IS NULL OR expires_at > NOW())
"#,
&[&auth_scope, &source_agent_id],
) )
.await?; .await?;
Ok(row.get("count")) Ok(row.get("count"))
@@ -346,7 +370,6 @@ impl Database {
} }
} }
/// Result for a single batch entry /// Result for a single batch entry
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct BatchStoreResult { pub struct BatchStoreResult {
@@ -360,6 +383,7 @@ impl Database {
/// Store multiple memories in a single transaction /// Store multiple memories in a single transaction
pub async fn batch_store_memories( pub async fn batch_store_memories(
&self, &self,
auth_scope: &str,
agent_id: &str, agent_id: &str,
entries: Vec<( entries: Vec<(
String, String,
@@ -378,7 +402,8 @@ impl Database {
for (content, metadata, embedding, keywords, expires_at) in entries { for (content, metadata, embedding, keywords, expires_at) in entries {
let vector = Vector::from(embedding); let vector = Vector::from(embedding);
if let Some(existing) = if let Some(existing) =
find_dedup_match(&transaction, agent_id, &vector, dedup_threshold).await? find_dedup_match(&transaction, auth_scope, agent_id, &vector, dedup_threshold)
.await?
{ {
let merged_metadata = merge_metadata(&existing.metadata, &metadata); let merged_metadata = merge_metadata(&existing.metadata, &metadata);
let refreshed_expires_at = expires_at.or(existing.expires_at); let refreshed_expires_at = expires_at.or(existing.expires_at);
@@ -404,8 +429,8 @@ impl Database {
} else { } else {
let id = Uuid::new_v4(); let id = Uuid::new_v4();
transaction.execute( transaction.execute(
r#"INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7)"#, r#"INSERT INTO memories (id, auth_scope, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#,
&[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at], &[&id, &auth_scope, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
).await?; ).await?;
results.push(BatchStoreResult { results.push(BatchStoreResult {
id: id.to_string(), id: id.to_string(),

View File

@@ -1,7 +1,7 @@
//! Embedding engine using local ONNX models //! Embedding engine using local ONNX models
use anyhow::Result; use anyhow::Result;
use ort::session::{Session, builder::GraphOptimizationLevel}; use ort::session::{builder::GraphOptimizationLevel, Session};
use ort::value::Value; use ort::value::Value;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::Once; use std::sync::Once;
@@ -15,9 +15,9 @@ static ORT_INIT: Once = Once::new();
/// Initialize ONNX Runtime synchronously (called inside spawn_blocking) /// Initialize ONNX Runtime synchronously (called inside spawn_blocking)
fn init_ort_sync(dylib_path: &str) -> Result<()> { fn init_ort_sync(dylib_path: &str) -> Result<()> {
info!("Initializing ONNX Runtime from: {}", dylib_path); info!("Initializing ONNX Runtime from: {}", dylib_path);
let mut init_error: Option<String> = None; let mut init_error: Option<String> = None;
ORT_INIT.call_once(|| { ORT_INIT.call_once(|| {
info!("ORT_INIT.call_once - starting initialization"); info!("ORT_INIT.call_once - starting initialization");
match ort::init_from(dylib_path) { match ort::init_from(dylib_path) {
@@ -43,7 +43,7 @@ fn init_ort_sync(dylib_path: &str) -> Result<()> {
if let Some(err) = init_error { if let Some(err) = init_error {
return Err(anyhow::anyhow!("{}", err)); return Err(anyhow::anyhow!("{}", err));
} }
info!("ONNX Runtime initialization complete"); info!("ONNX Runtime initialization complete");
Ok(()) Ok(())
} }
@@ -91,35 +91,40 @@ impl EmbeddingEngine {
let model_path = PathBuf::from(&config.model_path); let model_path = PathBuf::from(&config.model_path);
let dimension = config.dimension; let dimension = config.dimension;
info!("Loading ONNX model from {:?}", model_path.join("model.onnx")); info!(
"Loading ONNX model from {:?}",
model_path.join("model.onnx")
);
// Use spawn_blocking to avoid blocking the async runtime // Use spawn_blocking to avoid blocking the async runtime
let (session, tokenizer) = tokio::task::spawn_blocking(move || -> Result<(Session, Tokenizer)> { let (session, tokenizer) =
// Initialize ONNX Runtime first tokio::task::spawn_blocking(move || -> Result<(Session, Tokenizer)> {
init_ort_sync(&dylib_path)?; // Initialize ONNX Runtime first
init_ort_sync(&dylib_path)?;
info!("Creating ONNX session...");
// Load ONNX model with ort 2.0 API
let session = Session::builder()
.map_err(|e| anyhow::anyhow!("Failed to create session builder: {:?}", e))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| anyhow::anyhow!("Failed to set optimization level: {:?}", e))?
.with_intra_threads(4)
.map_err(|e| anyhow::anyhow!("Failed to set intra threads: {:?}", e))?
.commit_from_file(model_path.join("model.onnx"))
.map_err(|e| anyhow::anyhow!("Failed to load ONNX model: {:?}", e))?;
info!("ONNX model loaded, loading tokenizer..."); info!("Creating ONNX session...");
// Load tokenizer
let tokenizer = Tokenizer::from_file(model_path.join("tokenizer.json"))
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
info!("Tokenizer loaded successfully"); // Load ONNX model with ort 2.0 API
Ok((session, tokenizer)) let session = Session::builder()
}).await .map_err(|e| anyhow::anyhow!("Failed to create session builder: {:?}", e))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| anyhow::anyhow!("Failed to set optimization level: {:?}", e))?
.with_intra_threads(4)
.map_err(|e| anyhow::anyhow!("Failed to set intra threads: {:?}", e))?
.commit_from_file(model_path.join("model.onnx"))
.map_err(|e| anyhow::anyhow!("Failed to load ONNX model: {:?}", e))?;
info!("ONNX model loaded, loading tokenizer...");
// Load tokenizer
let tokenizer = Tokenizer::from_file(model_path.join("tokenizer.json"))
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
info!("Tokenizer loaded successfully");
Ok((session, tokenizer))
})
.await
.map_err(|e| anyhow::anyhow!("Spawn blocking failed: {:?}", e))??; .map_err(|e| anyhow::anyhow!("Spawn blocking failed: {:?}", e))??;
info!( info!(
@@ -136,12 +141,17 @@ impl EmbeddingEngine {
/// Generate embedding for a single text /// Generate embedding for a single text
pub fn embed(&self, text: &str) -> Result<Vec<f32>> { pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
let encoding = self.tokenizer let encoding = self
.tokenizer
.encode(text, true) .encode(text, true)
.map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?; .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect(); let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
let attention_mask: Vec<i64> = encoding.get_attention_mask().iter().map(|&x| x as i64).collect(); let attention_mask: Vec<i64> = encoding
.get_attention_mask()
.iter()
.map(|&x| x as i64)
.collect();
let token_type_ids: Vec<i64> = encoding.get_type_ids().iter().map(|&x| x as i64).collect(); let token_type_ids: Vec<i64> = encoding.get_type_ids().iter().map(|&x| x as i64).collect();
let seq_len = input_ids.len(); let seq_len = input_ids.len();
@@ -157,22 +167,25 @@ impl EmbeddingEngine {
"attention_mask" => attention_mask_tensor, "attention_mask" => attention_mask_tensor,
"token_type_ids" => token_type_ids_tensor, "token_type_ids" => token_type_ids_tensor,
]; ];
let mut session_guard = self.session.lock() let mut session_guard = self
.session
.lock()
.map_err(|e| anyhow::anyhow!("Session lock poisoned: {}", e))?; .map_err(|e| anyhow::anyhow!("Session lock poisoned: {}", e))?;
let outputs = session_guard.run(inputs)?; let outputs = session_guard.run(inputs)?;
// Extract output // Extract output
let output = outputs.get("last_hidden_state") let output = outputs
.get("last_hidden_state")
.ok_or_else(|| anyhow::anyhow!("Missing last_hidden_state output"))?; .ok_or_else(|| anyhow::anyhow!("Missing last_hidden_state output"))?;
// Get the tensor data // Get the tensor data
let (shape, data) = output.try_extract_tensor::<f32>()?; let (shape, data) = output.try_extract_tensor::<f32>()?;
// Mean pooling over sequence dimension // Mean pooling over sequence dimension
let hidden_size = *shape.last().unwrap_or(&384) as usize; let hidden_size = *shape.last().unwrap_or(&384) as usize;
let seq_len = data.len() / hidden_size; let seq_len = data.len() / hidden_size;
let mut embedding = vec![0.0f32; hidden_size]; let mut embedding = vec![0.0f32; hidden_size];
for i in 0..seq_len { for i in 0..seq_len {
for j in 0..hidden_size { for j in 0..hidden_size {
@@ -182,7 +195,7 @@ impl EmbeddingEngine {
for val in &mut embedding { for val in &mut embedding {
*val /= seq_len as f32; *val /= seq_len as f32;
} }
// L2 normalize // L2 normalize
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt(); let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 { if norm > 0.0 {
@@ -190,7 +203,7 @@ impl EmbeddingEngine {
*val /= norm; *val /= norm;
} }
} }
Ok(embedding) Ok(embedding)
} }
@@ -208,37 +221,40 @@ impl EmbeddingEngine {
/// Extract keywords from text using simple frequency analysis /// Extract keywords from text using simple frequency analysis
pub fn extract_keywords(text: &str, limit: usize) -> Vec<String> { pub fn extract_keywords(text: &str, limit: usize) -> Vec<String> {
use std::collections::HashMap; use std::collections::HashMap;
let stop_words: std::collections::HashSet<&str> = [ let stop_words: std::collections::HashSet<&str> = [
"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by",
"of", "with", "by", "from", "is", "are", "was", "were", "be", "been", "from", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "do",
"being", "have", "has", "had", "do", "does", "did", "will", "would", "does", "did", "will", "would", "could", "should", "may", "might", "must", "shall", "can",
"could", "should", "may", "might", "must", "shall", "can", "this", "this", "that", "these", "those", "i", "you", "he", "she", "it", "we", "they", "what",
"that", "these", "those", "i", "you", "he", "she", "it", "we", "they", "which", "who", "whom", "whose", "where", "when", "why", "how", "all", "each", "every",
"what", "which", "who", "whom", "whose", "where", "when", "why", "how", "both", "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own",
"all", "each", "every", "both", "few", "more", "most", "other", "some", "same", "so", "than", "too", "very", "just", "also", "now", "here", "there", "then",
"such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "once", "if",
"very", "just", "also", "now", "here", "there", "then", "once", "if", ]
].iter().cloned().collect(); .iter()
.cloned()
.collect();
let mut word_counts: HashMap<String, usize> = HashMap::new(); let mut word_counts: HashMap<String, usize> = HashMap::new();
for word in text.split_whitespace() { for word in text.split_whitespace() {
let clean: String = word let clean: String = word
.chars() .chars()
.filter(|c| c.is_alphanumeric()) .filter(|c| c.is_alphanumeric())
.collect::<String>() .collect::<String>()
.to_lowercase(); .to_lowercase();
if clean.len() > 2 && !stop_words.contains(clean.as_str()) { if clean.len() > 2 && !stop_words.contains(clean.as_str()) {
*word_counts.entry(clean).or_insert(0) += 1; *word_counts.entry(clean).or_insert(0) += 1;
} }
} }
let mut sorted: Vec<_> = word_counts.into_iter().collect(); let mut sorted: Vec<_> = word_counts.into_iter().collect();
sorted.sort_by(|a, b| b.1.cmp(&a.1)); sorted.sort_by(|a, b| b.1.cmp(&a.1));
sorted.into_iter() sorted
.into_iter()
.take(limit) .take(limit)
.map(|(word, _)| word) .map(|(word, _)| word)
.collect() .collect()

View File

@@ -5,18 +5,18 @@ pub mod config;
pub mod db; pub mod db;
pub mod embedding; pub mod embedding;
pub mod migrations; pub mod migrations;
pub mod ttl;
pub mod tools; pub mod tools;
pub mod transport; pub mod transport;
pub mod ttl;
use anyhow::Result; use anyhow::Result;
use axum::{Router, Json, http::StatusCode, middleware}; use axum::{http::StatusCode, middleware, Json, Router};
use serde_json::json; use serde_json::json;
use std::sync::Arc; use std::sync::Arc;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tower_http::cors::{Any, CorsLayer}; use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
use tracing::{info, error}; use tracing::{error, info};
use crate::auth::auth_middleware; use crate::auth::auth_middleware;
use crate::config::Config; use crate::config::Config;
@@ -60,15 +60,15 @@ async fn readiness_handler(
match readiness { match readiness {
ReadinessState::Ready => ( ReadinessState::Ready => (
StatusCode::OK, StatusCode::OK,
Json(json!({"status": "ready", "embedding": true})) Json(json!({"status": "ready", "embedding": true})),
), ),
ReadinessState::Initializing => ( ReadinessState::Initializing => (
StatusCode::SERVICE_UNAVAILABLE, StatusCode::SERVICE_UNAVAILABLE,
Json(json!({"status": "initializing", "embedding": false})) Json(json!({"status": "initializing", "embedding": false})),
), ),
ReadinessState::Failed(err) => ( ReadinessState::Failed(err) => (
StatusCode::SERVICE_UNAVAILABLE, StatusCode::SERVICE_UNAVAILABLE,
Json(json!({"status": "failed", "error": err})) Json(json!({"status": "failed", "error": err})),
), ),
} }
} }
@@ -89,11 +89,14 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> {
tokio::spawn(async move { tokio::spawn(async move {
let max_retries = 3; let max_retries = 3;
let mut attempt = 0; let mut attempt = 0;
loop { loop {
attempt += 1; attempt += 1;
info!("Initializing embedding engine (attempt {}/{})", attempt, max_retries); info!(
"Initializing embedding engine (attempt {}/{})",
attempt, max_retries
);
match EmbeddingEngine::new(&embedding_config).await { match EmbeddingEngine::new(&embedding_config).await {
Ok(engine) => { Ok(engine) => {
let engine = Arc::new(engine); let engine = Arc::new(engine);
@@ -120,9 +123,8 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> {
let cleanup_state = state.clone(); let cleanup_state = state.clone();
let cleanup_interval_seconds = config.ttl.cleanup_interval_seconds; let cleanup_interval_seconds = config.ttl.cleanup_interval_seconds;
tokio::spawn(async move { tokio::spawn(async move {
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs( let mut interval =
cleanup_interval_seconds, tokio::time::interval(tokio::time::Duration::from_secs(cleanup_interval_seconds));
));
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop { loop {
@@ -148,10 +150,12 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> {
.route("/health", axum::routing::get(health_handler)) .route("/health", axum::routing::get(health_handler))
.route("/ready", axum::routing::get(readiness_handler)) .route("/ready", axum::routing::get(readiness_handler))
.with_state(state.clone()); .with_state(state.clone());
// Build MCP router with auth middleware // Build MCP router with auth middleware
let mcp_router = transport::mcp_router(mcp_state) let mcp_router = transport::mcp_router(mcp_state).layer(middleware::from_fn_with_state(
.layer(middleware::from_fn_with_state(state.clone(), auth_middleware)); state.clone(),
auth_middleware,
));
let app = Router::new() let app = Router::new()
.merge(health_router) .merge(health_router)

View File

@@ -17,7 +17,10 @@ async fn main() -> Result<()> {
.with(tracing_subscriber::fmt::layer()) .with(tracing_subscriber::fmt::layer())
.init(); .init();
info!("Starting OpenBrain MCP Server v{}", env!("CARGO_PKG_VERSION")); info!(
"Starting OpenBrain MCP Server v{}",
env!("CARGO_PKG_VERSION")
);
// Load configuration // Load configuration
let config = Config::load()?; let config = Config::load()?;

View File

@@ -3,12 +3,14 @@
//! Accepts 1-50 entries per call, generates embeddings for each, //! Accepts 1-50 entries per call, generates embeddings for each,
//! stores all in a single DB transaction, returns individual IDs/status. //! stores all in a single DB transaction, returns individual IDs/status.
use anyhow::{Context, Result, anyhow}; use anyhow::{anyhow, Context, Result};
use serde_json::Value; use serde_json::Value;
use std::sync::Arc; use std::sync::Arc;
use tracing::info; use tracing::info;
use crate::auth::PUBLIC_AUTH_SCOPE;
use crate::embedding::extract_keywords; use crate::embedding::extract_keywords;
use crate::tools::INTERNAL_AUTH_SCOPE_ARG;
use crate::ttl::expires_at_from_ttl; use crate::ttl::expires_at_from_ttl;
use crate::AppState; use crate::AppState;
@@ -18,7 +20,7 @@ const MAX_BATCH_SIZE: usize = 50;
/// Execute the batch_store tool /// Execute the batch_store tool
/// ///
/// Accepts: /// Accepts:
/// - `agent_id`: Optional agent identifier (defaults to "default") /// - `agent_id`: Optional source agent label (defaults to "default")
/// - `ttl`: Optional default TTL string applied to entries without their own ttl /// - `ttl`: Optional default TTL string applied to entries without their own ttl
/// - `entries`: Array of 1-50 entries, each with `content` (required) and `metadata` (optional) /// - `entries`: Array of 1-50 entries, each with `content` (required) and `metadata` (optional)
/// ///
@@ -38,6 +40,10 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
.get("agent_id") .get("agent_id")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.unwrap_or("default"); .unwrap_or("default");
let auth_scope = arguments
.get(INTERNAL_AUTH_SCOPE_ARG)
.and_then(|v| v.as_str())
.unwrap_or(PUBLIC_AUTH_SCOPE);
let entries = arguments let entries = arguments
.get("entries") .get("entries")
@@ -47,7 +53,9 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
// 3. Validate batch size // 3. Validate batch size
if entries.is_empty() { if entries.is_empty() {
return Err(anyhow!("Empty entries array not allowed - must provide 1-50 entries")); return Err(anyhow!(
"Empty entries array not allowed - must provide 1-50 entries"
));
} }
if entries.len() > MAX_BATCH_SIZE { if entries.len() > MAX_BATCH_SIZE {
return Err(anyhow!( return Err(anyhow!(
@@ -69,7 +77,10 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
let content = entry let content = entry
.get("content") .get("content")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.context(format!("Entry at index {} missing required field: content", idx))?; .context(format!(
"Entry at index {} missing required field: content",
idx
))?;
if content.is_empty() { if content.is_empty() {
return Err(anyhow!( return Err(anyhow!(
@@ -82,10 +93,7 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
.get("metadata") .get("metadata")
.cloned() .cloned()
.unwrap_or(serde_json::json!({})); .unwrap_or(serde_json::json!({}));
let ttl = entry let ttl = entry.get("ttl").and_then(|v| v.as_str()).or(default_ttl);
.get("ttl")
.and_then(|v| v.as_str())
.or(default_ttl);
let expires_at = expires_at_from_ttl(ttl) let expires_at = expires_at_from_ttl(ttl)
.with_context(|| format!("Invalid ttl for entry at index {}", idx))?; .with_context(|| format!("Invalid ttl for entry at index {}", idx))?;
@@ -97,13 +105,24 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
// Extract keywords // Extract keywords
let keywords = extract_keywords(content, 10); let keywords = extract_keywords(content, 10);
processed_entries.push((content.to_string(), metadata, embedding, keywords, expires_at)); processed_entries.push((
content.to_string(),
metadata,
embedding,
keywords,
expires_at,
));
} }
// 5. Batch DB insert (single transaction for atomicity) // 5. Batch DB insert (single transaction for atomicity)
let results = state let results = state
.db .db
.batch_store_memories(agent_id, processed_entries, state.config.dedup.threshold) .batch_store_memories(
auth_scope,
agent_id,
processed_entries,
state.config.dedup.threshold,
)
.await .await
.context("Failed to batch store memories")?; .context("Failed to batch store memories")?;
@@ -114,5 +133,6 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
"success": true, "success": true,
"results": results, "results": results,
"count": count "count": count
}).to_string()) })
.to_string())
} }

View File

@@ -1,9 +1,9 @@
//! MCP Tools for OpenBrain //! MCP Tools for OpenBrain
pub mod batch_store; pub mod batch_store;
pub mod purge;
pub mod query; pub mod query;
pub mod store; pub mod store;
pub mod purge;
use anyhow::Result; use anyhow::Result;
use serde_json::{json, Value}; use serde_json::{json, Value};
@@ -11,11 +11,13 @@ use std::sync::Arc;
use crate::AppState; use crate::AppState;
pub const INTERNAL_AUTH_SCOPE_ARG: &str = "_auth_scope";
pub fn get_tool_definitions() -> Vec<Value> { pub fn get_tool_definitions() -> Vec<Value> {
vec![ vec![
json!({ json!({
"name": "store", "name": "store",
"description": "Store a memory with automatic embedding generation and keyword extraction. Near-duplicate memories for the same agent are deduplicated automatically by similarity, with metadata merged and timestamps refreshed.", "description": "Store a memory with automatic embedding generation and keyword extraction. Near-duplicate memories for the same API-token scope and same source agent are deduplicated automatically by similarity, with metadata merged and timestamps refreshed.",
"inputSchema": { "inputSchema": {
"type": "object", "type": "object",
"properties": { "properties": {
@@ -25,7 +27,7 @@ pub fn get_tool_definitions() -> Vec<Value> {
}, },
"agent_id": { "agent_id": {
"type": "string", "type": "string",
"description": "Unique identifier for the agent storing the memory (default: 'default')" "description": "Optional source agent label recorded with the memory. If omitted, the server may fall back to X-Agent-ID or 'default'."
}, },
"metadata": { "metadata": {
"type": "object", "type": "object",
@@ -47,7 +49,7 @@ pub fn get_tool_definitions() -> Vec<Value> {
"properties": { "properties": {
"agent_id": { "agent_id": {
"type": "string", "type": "string",
"description": "Unique identifier for the agent storing the memories (default: 'default')" "description": "Optional source agent label recorded with each stored memory. If omitted, the server may fall back to X-Agent-ID or 'default'."
}, },
"ttl": { "ttl": {
"type": "string", "type": "string",
@@ -89,9 +91,14 @@ pub fn get_tool_definitions() -> Vec<Value> {
"type": "string", "type": "string",
"description": "The search query text" "description": "The search query text"
}, },
"source_agent_id": {
"type": "string",
"description": "Optional provenance filter that only returns memories stored by the specified agent label"
},
"agent_id": { "agent_id": {
"type": "string", "type": "string",
"description": "Agent ID to search within (default: 'default')" "description": "Deprecated legacy alias. Query visibility is scoped by API token, not by agent_id.",
"deprecated": true
}, },
"limit": { "limit": {
"type": "integer", "type": "integer",
@@ -107,13 +114,18 @@ pub fn get_tool_definitions() -> Vec<Value> {
}), }),
json!({ json!({
"name": "purge", "name": "purge",
"description": "Delete memories for an agent. Can delete all memories or those before a specific timestamp.", "description": "Delete memories visible to the current API token. Can optionally filter by source agent label or by time range.",
"inputSchema": { "inputSchema": {
"type": "object", "type": "object",
"properties": { "properties": {
"source_agent_id": {
"type": "string",
"description": "Optional provenance filter that only deletes memories stored by the specified agent label"
},
"agent_id": { "agent_id": {
"type": "string", "type": "string",
"description": "Agent ID whose memories to delete (required)" "description": "Deprecated legacy alias for source_agent_id",
"deprecated": true
}, },
"before": { "before": {
"type": "string", "type": "string",
@@ -124,9 +136,9 @@ pub fn get_tool_definitions() -> Vec<Value> {
"description": "Must be true to confirm deletion" "description": "Must be true to confirm deletion"
} }
}, },
"required": ["agent_id", "confirm"] "required": ["confirm"]
} }
}) }),
] ]
} }

View File

@@ -1,4 +1,4 @@
//! Purge Tool - Delete memories by agent_id or time range //! Purge Tool - Delete memories visible to the current token with optional filters
use anyhow::{bail, Context, Result}; use anyhow::{bail, Context, Result};
use chrono::DateTime; use chrono::DateTime;
@@ -6,15 +6,21 @@ use serde_json::Value;
use std::sync::Arc; use std::sync::Arc;
use tracing::{info, warn}; use tracing::{info, warn};
use crate::auth::PUBLIC_AUTH_SCOPE;
use crate::tools::INTERNAL_AUTH_SCOPE_ARG;
use crate::AppState; use crate::AppState;
/// Execute the purge tool /// Execute the purge tool
pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String> { pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String> {
// Extract parameters // Extract parameters
let agent_id = arguments let source_agent_id = arguments
.get("agent_id") .get("source_agent_id")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.context("Missing required parameter: agent_id")?; .or_else(|| arguments.get("agent_id").and_then(|v| v.as_str()));
let auth_scope = arguments
.get(INTERNAL_AUTH_SCOPE_ARG)
.and_then(|v| v.as_str())
.unwrap_or(PUBLIC_AUTH_SCOPE);
let confirm = arguments let confirm = arguments
.get("confirm") .get("confirm")
@@ -36,15 +42,18 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
// Get current count before purge // Get current count before purge
let count_before = state let count_before = state
.db .db
.count_memories(agent_id) .count_memories(auth_scope, source_agent_id)
.await .await
.context("Failed to count memories")?; .context("Failed to count memories")?;
if count_before == 0 { if count_before == 0 {
info!("No memories found for agent '{}'", agent_id); info!(
"No memories found to purge for auth scope '{}' with source_agent_id={:?}",
auth_scope, source_agent_id
);
return Ok(serde_json::json!({ return Ok(serde_json::json!({
"success": true, "success": true,
"agent_id": agent_id, "source_agent_id_filter": source_agent_id,
"deleted": 0, "deleted": 0,
"message": "No memories found to purge" "message": "No memories found to purge"
}) })
@@ -52,25 +61,25 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
} }
warn!( warn!(
"Purging memories for agent '{}' (before={:?})", "Purging memories for auth scope '{}' with source_agent_id={:?} (before={:?})",
agent_id, before auth_scope, source_agent_id, before
); );
// Execute purge // Execute purge
let deleted = state let deleted = state
.db .db
.purge_memories(agent_id, before) .purge_memories(auth_scope, source_agent_id, before)
.await .await
.context("Failed to purge memories")?; .context("Failed to purge memories")?;
info!( info!(
"Purged {} memories for agent '{}'", "Purged {} memories for auth scope '{}' with source_agent_id={:?}",
deleted, agent_id deleted, auth_scope, source_agent_id
); );
Ok(serde_json::json!({ Ok(serde_json::json!({
"success": true, "success": true,
"agent_id": agent_id, "source_agent_id_filter": source_agent_id,
"deleted": deleted, "deleted": deleted,
"had_before_filter": before.is_some(), "had_before_filter": before.is_some(),
"message": format!("Successfully purged {} memories", deleted) "message": format!("Successfully purged {} memories", deleted)

View File

@@ -1,10 +1,12 @@
//! Query Tool - Search memories by semantic similarity //! Query Tool - Search memories by semantic similarity
use anyhow::{Context, Result, anyhow}; use anyhow::{anyhow, Context, Result};
use serde_json::Value; use serde_json::Value;
use std::sync::Arc; use std::sync::Arc;
use tracing::info; use tracing::info;
use crate::auth::PUBLIC_AUTH_SCOPE;
use crate::tools::INTERNAL_AUTH_SCOPE_ARG;
use crate::AppState; use crate::AppState;
/// Execute the query tool /// Execute the query tool
@@ -21,10 +23,14 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.context("Missing required parameter: query")?; .context("Missing required parameter: query")?;
let agent_id = arguments let source_agent_id = arguments
.get("agent_id") .get("source_agent_id")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.unwrap_or("default"); .filter(|value| !value.is_empty());
let auth_scope = arguments
.get(INTERNAL_AUTH_SCOPE_ARG)
.and_then(|v| v.as_str())
.unwrap_or(PUBLIC_AUTH_SCOPE);
let limit = arguments let limit = arguments
.get("limit") .get("limit")
@@ -42,8 +48,8 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
); );
info!( info!(
"Querying memories for agent '{}': '{}' (limit={}, threshold={}, vector_weight={}, text_weight={})", "Querying memories for auth scope '{}' with source_agent_id={:?}: '{}' (limit={}, threshold={}, vector_weight={}, text_weight={})",
agent_id, query_text, limit, threshold, vector_weight, text_weight auth_scope, source_agent_id, query_text, limit, threshold, vector_weight, text_weight
); );
// Generate embedding for query using Arc<EmbeddingEngine> // Generate embedding for query using Arc<EmbeddingEngine>
@@ -55,7 +61,8 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
let matches = state let matches = state
.db .db
.query_memories( .query_memories(
agent_id, auth_scope,
source_agent_id,
query_text, query_text,
&query_embedding, &query_embedding,
limit, limit,
@@ -79,6 +86,7 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
"vector_score": m.vector_score, "vector_score": m.vector_score,
"text_score": m.text_score, "text_score": m.text_score,
"hybrid_score": m.hybrid_score, "hybrid_score": m.hybrid_score,
"agent_id": m.record.agent_id,
"keywords": m.record.keywords, "keywords": m.record.keywords,
"metadata": m.record.metadata, "metadata": m.record.metadata,
"created_at": m.record.created_at.to_rfc3339(), "created_at": m.record.created_at.to_rfc3339(),
@@ -89,8 +97,8 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
Ok(serde_json::json!({ Ok(serde_json::json!({
"success": true, "success": true,
"agent_id": agent_id,
"query": query_text, "query": query_text,
"source_agent_id_filter": source_agent_id,
"vector_weight": vector_weight, "vector_weight": vector_weight,
"text_weight": text_weight, "text_weight": text_weight,
"count": results.len(), "count": results.len(),

View File

@@ -1,11 +1,13 @@
//! Store Tool - Store memories with automatic embeddings //! Store Tool - Store memories with automatic embeddings
use anyhow::{Context, Result, anyhow}; use anyhow::{anyhow, Context, Result};
use serde_json::Value; use serde_json::Value;
use std::sync::Arc; use std::sync::Arc;
use tracing::info; use tracing::info;
use crate::auth::PUBLIC_AUTH_SCOPE;
use crate::embedding::extract_keywords; use crate::embedding::extract_keywords;
use crate::tools::INTERNAL_AUTH_SCOPE_ARG;
use crate::ttl::expires_at_from_ttl; use crate::ttl::expires_at_from_ttl;
use crate::AppState; use crate::AppState;
@@ -27,6 +29,10 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
.get("agent_id") .get("agent_id")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.unwrap_or("default"); .unwrap_or("default");
let auth_scope = arguments
.get(INTERNAL_AUTH_SCOPE_ARG)
.and_then(|v| v.as_str())
.unwrap_or(PUBLIC_AUTH_SCOPE);
let metadata = arguments let metadata = arguments
.get("metadata") .get("metadata")
@@ -54,6 +60,7 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
let id = state let id = state
.db .db
.store_memory( .store_memory(
auth_scope,
agent_id, agent_id,
content, content,
&embedding, &embedding,
@@ -67,7 +74,11 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
info!( info!(
"Memory {} with ID: {}", "Memory {} with ID: {}",
if id.deduplicated { "deduplicated" } else { "stored" }, if id.deduplicated {
"deduplicated"
} else {
"stored"
},
id.id id.id
); );

View File

@@ -5,27 +5,25 @@
use axum::{ use axum::{
extract::{Query, State}, extract::{Query, State},
http::{HeaderMap, StatusCode, Uri, header::{HOST, ORIGIN}}, http::{
header::{HOST, ORIGIN},
HeaderMap, StatusCode, Uri,
},
response::{ response::{
IntoResponse, Response,
sse::{Event, KeepAlive, Sse}, sse::{Event, KeepAlive, Sse},
IntoResponse, Response,
}, },
routing::{get, post}, routing::{get, post},
Json, Router, Json, Router,
}; };
use futures::stream::Stream; use futures::stream::Stream;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{ use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
collections::HashMap, use tokio::sync::{broadcast, mpsc, RwLock};
convert::Infallible,
sync::Arc,
time::Duration,
};
use tokio::sync::{RwLock, broadcast, mpsc};
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use uuid::Uuid; use uuid::Uuid;
use crate::{AppState, auth, tools}; use crate::{auth, tools, AppState};
type SessionStore = RwLock<HashMap<String, mpsc::Sender<serde_json::Value>>>; type SessionStore = RwLock<HashMap<String, mpsc::Sender<serde_json::Value>>>;
@@ -46,11 +44,7 @@ impl McpState {
}) })
} }
async fn insert_session( async fn insert_session(&self, session_id: String, tx: mpsc::Sender<serde_json::Value>) {
&self,
session_id: String,
tx: mpsc::Sender<serde_json::Value>,
) {
self.sessions.write().await.insert(session_id, tx); self.sessions.write().await.insert(session_id, tx);
} }
@@ -163,7 +157,12 @@ struct PostMessageQuery {
/// Create the MCP router /// Create the MCP router
pub fn mcp_router(state: Arc<McpState>) -> Router { pub fn mcp_router(state: Arc<McpState>) -> Router {
Router::new() Router::new()
.route("/mcp", get(streamable_get_handler).post(streamable_post_handler).delete(streamable_delete_handler)) .route(
"/mcp",
get(streamable_get_handler)
.post(streamable_post_handler)
.delete(streamable_delete_handler),
)
.route("/mcp/sse", get(sse_handler)) .route("/mcp/sse", get(sse_handler))
.route("/mcp/message", post(message_handler)) .route("/mcp/message", post(message_handler))
.route("/mcp/health", get(health_handler)) .route("/mcp/health", get(health_handler))
@@ -186,7 +185,10 @@ fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
} }
let origin_uri = origin.parse::<Uri>().map_err(|_| { let origin_uri = origin.parse::<Uri>().map_err(|_| {
warn!("Rejected MCP request with invalid origin header: {}", origin); warn!(
"Rejected MCP request with invalid origin header: {}",
origin
);
StatusCode::FORBIDDEN StatusCode::FORBIDDEN
})?; })?;
@@ -203,7 +205,10 @@ fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
.map(str::trim) .map(str::trim)
.filter(|value| !value.is_empty()) .filter(|value| !value.is_empty())
.ok_or_else(|| { .ok_or_else(|| {
warn!("Rejected MCP request without host header for origin {}", origin); warn!(
"Rejected MCP request without host header for origin {}",
origin
);
StatusCode::FORBIDDEN StatusCode::FORBIDDEN
})?; })?;
@@ -247,12 +252,12 @@ async fn streamable_post_handler(
info!( info!(
method = %request.method, method = %request.method,
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"), client_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"), agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
"Received streamable MCP request" "Received streamable MCP request"
); );
let request = apply_request_context(request, &headers); let request = apply_request_context(request, &headers, state.app.config.auth.enabled);
let response = dispatch_request(&state, &request).await; let response = dispatch_request(&state, &request).await;
match response { match response {
@@ -284,8 +289,12 @@ async fn sse_handler(
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> { ) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
validate_origin(&headers)?; validate_origin(&headers)?;
info!( info!(
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"), client_id = auth::get_optional_agent_id(&headers)
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"), .as_deref()
.unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers)
.as_deref()
.unwrap_or("unset"),
"Opening legacy SSE MCP stream" "Opening legacy SSE MCP stream"
); );
let mut broadcast_rx = state.event_tx.subscribe(); let mut broadcast_rx = state.event_tx.subscribe();
@@ -354,7 +363,7 @@ async fn message_handler(
info!( info!(
method = %request.method, method = %request.method,
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"), client_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"), agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
"Received legacy SSE MCP request" "Received legacy SSE MCP request"
); );
@@ -365,7 +374,7 @@ async fn message_handler(
} }
} }
let request = apply_request_context(request, &headers); let request = apply_request_context(request, &headers, state.app.config.auth.enabled);
let response = dispatch_request(&state, &request).await; let response = dispatch_request(&state, &request).await;
match query.session_id.as_deref() { match query.session_id.as_deref() {
@@ -408,17 +417,10 @@ async fn route_session_response(
fn apply_request_context( fn apply_request_context(
mut request: JsonRpcRequest, mut request: JsonRpcRequest,
headers: &HeaderMap, headers: &HeaderMap,
auth_enabled: bool,
) -> JsonRpcRequest { ) -> JsonRpcRequest {
if let Some(agent_id) = auth::get_optional_agent_id(headers) {
inject_agent_id(&mut request, &agent_id);
}
request
}
fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
if request.method != "tools/call" { if request.method != "tools/call" {
return; return request;
} }
if !request.params.is_object() { if !request.params.is_object() {
@@ -429,6 +431,11 @@ fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
.params .params
.as_object_mut() .as_object_mut()
.expect("params should be an object"); .expect("params should be an object");
let tool_name = params
.get("name")
.and_then(|value| value.as_str())
.unwrap_or("")
.to_string();
let arguments = params let arguments = params
.entry("arguments") .entry("arguments")
.or_insert_with(|| serde_json::json!({})); .or_insert_with(|| serde_json::json!({}));
@@ -437,11 +444,22 @@ fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
*arguments = serde_json::json!({}); *arguments = serde_json::json!({});
} }
arguments let arguments = arguments
.as_object_mut() .as_object_mut()
.expect("arguments should be an object") .expect("arguments should be an object");
.entry("agent_id".to_string()) arguments.insert(
.or_insert_with(|| serde_json::json!(agent_id)); tools::INTERNAL_AUTH_SCOPE_ARG.to_string(),
serde_json::json!(auth::get_auth_scope(headers, auth_enabled)),
);
if matches!(tool_name.as_str(), "store" | "batch_store") && !arguments.contains_key("agent_id")
{
if let Some(agent_id) = auth::get_optional_agent_id(headers) {
arguments.insert("agent_id".to_string(), serde_json::json!(agent_id));
}
}
request
} }
async fn dispatch_request( async fn dispatch_request(
@@ -496,10 +514,7 @@ fn initialize_result() -> serde_json::Value {
}) })
} }
fn success_response( fn success_response(id: serde_json::Value, result: serde_json::Value) -> JsonRpcResponse {
id: serde_json::Value,
result: serde_json::Value,
) -> JsonRpcResponse {
JsonRpcResponse { JsonRpcResponse {
jsonrpc: "2.0".to_string(), jsonrpc: "2.0".to_string(),
id, id,
@@ -578,10 +593,11 @@ async fn handle_tools_call(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use axum::http::HeaderValue;
#[test] #[test]
fn injects_agent_id_when_missing_from_tool_arguments() { fn request_context_injects_auth_scope_for_tool_calls() {
let mut request = JsonRpcRequest { let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(), jsonrpc: "2.0".to_string(),
id: Some(serde_json::json!("1")), id: Some(serde_json::json!("1")),
method: "tools/call".to_string(), method: "tools/call".to_string(),
@@ -592,8 +608,39 @@ mod tests {
} }
}), }),
}; };
let mut headers = HeaderMap::new();
headers.insert("X-API-Key", HeaderValue::from_static("test-token"));
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
inject_agent_id(&mut request, "agent-from-header"); let request = apply_request_context(request, &headers, true);
assert_eq!(
request
.params
.get("arguments")
.and_then(|value| value.get(tools::INTERNAL_AUTH_SCOPE_ARG))
.and_then(|value| value.as_str()),
Some(auth::hash_api_key("test-token").as_str())
);
}
#[test]
fn request_context_does_not_inject_query_filter_from_header() {
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(serde_json::json!("1")),
method: "tools/call".to_string(),
params: serde_json::json!({
"name": "query",
"arguments": {
"query": "editor preferences"
}
}),
};
let mut headers = HeaderMap::new();
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
let request = apply_request_context(request, &headers, false);
assert_eq!( assert_eq!(
request request
@@ -601,25 +648,55 @@ mod tests {
.get("arguments") .get("arguments")
.and_then(|value| value.get("agent_id")) .and_then(|value| value.get("agent_id"))
.and_then(|value| value.as_str()), .and_then(|value| value.as_str()),
Some("agent-from-header") None
); );
} }
#[test] #[test]
fn preserves_explicit_agent_id() { fn request_context_injects_store_agent_id_from_header_when_missing() {
let mut request = JsonRpcRequest { let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(), jsonrpc: "2.0".to_string(),
id: Some(serde_json::json!("1")), id: Some(serde_json::json!("1")),
method: "tools/call".to_string(), method: "tools/call".to_string(),
params: serde_json::json!({ params: serde_json::json!({
"name": "query", "name": "store",
"arguments": {
"content": "prefers dark mode"
}
}),
};
let mut headers = HeaderMap::new();
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
let request = apply_request_context(request, &headers, false);
assert_eq!(
request
.params
.get("arguments")
.and_then(|value| value.get("agent_id"))
.and_then(|value| value.as_str()),
Some("codex-desktop")
);
}
#[test]
fn request_context_preserves_explicit_store_agent_id() {
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(serde_json::json!("1")),
method: "tools/call".to_string(),
params: serde_json::json!({
"name": "store",
"arguments": { "arguments": {
"agent_id": "explicit-agent" "agent_id": "explicit-agent"
} }
}), }),
}; };
let mut headers = HeaderMap::new();
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
inject_agent_id(&mut request, "agent-from-header"); let request = apply_request_context(request, &headers, false);
assert_eq!( assert_eq!(
request request

View File

@@ -1,4 +1,4 @@
use anyhow::{Result, anyhow}; use anyhow::{anyhow, Result};
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
pub fn parse_ttl_spec(ttl: &str) -> Result<Duration> { pub fn parse_ttl_spec(ttl: &str) -> Result<Duration> {
@@ -25,7 +25,9 @@ pub fn parse_ttl_spec(ttl: &str) -> Result<Duration> {
.parse() .parse()
.map_err(|_| anyhow!("invalid ttl '{ttl}'. Duration value must be a positive integer"))?; .map_err(|_| anyhow!("invalid ttl '{ttl}'. Duration value must be a positive integer"))?;
if value <= 0 { if value <= 0 {
return Err(anyhow!("invalid ttl '{ttl}'. Duration value must be greater than zero")); return Err(anyhow!(
"invalid ttl '{ttl}'. Duration value must be greater than zero"
));
} }
let total_seconds = value let total_seconds = value

File diff suppressed because it is too large Load Diff