Scope memories by API token and add shared-token e2e coverage

This commit is contained in:
Agent Zero
2026-04-01 23:30:58 -04:00
parent 98baa27c90
commit 026ae27366
17 changed files with 1096 additions and 428 deletions

View File

@@ -4,7 +4,7 @@
use axum::{
extract::{Request, State},
http::{HeaderMap, StatusCode, header::AUTHORIZATION},
http::{header::AUTHORIZATION, HeaderMap, StatusCode},
middleware::Next,
response::Response,
};
@@ -14,6 +14,8 @@ use tracing::warn;
use crate::AppState;
/// Scope name used when a request has no per-token scope.
///
/// `get_auth_scope` returns this value when auth is disabled or when no API key can be
/// extracted from the request headers.
pub const PUBLIC_AUTH_SCOPE: &str = "public";
/// Hash an API key for secure comparison
pub fn hash_api_key(key: &str) -> String {
let mut hasher = Sha256::new();
@@ -99,24 +101,25 @@ pub fn get_optional_agent_type(headers: &HeaderMap) -> Option<String> {
.map(ToOwned::to_owned)
}
/// Extract agent ID from request headers or default
pub fn get_agent_id(request: &Request) -> String {
get_optional_agent_id(request.headers())
.unwrap_or_else(|| "default".to_string())
/// Derive the auth scope for a request from its headers.
///
/// Returns [`PUBLIC_AUTH_SCOPE`] when `auth_enabled` is `false` or when no API key can be
/// extracted from `headers`. Otherwise returns the presented key passed through
/// [`hash_api_key`].
pub fn get_auth_scope(headers: &HeaderMap, auth_enabled: bool) -> String {
    // Headers are only inspected when auth is turned on.
    let presented_key = if auth_enabled {
        extract_api_key(headers)
    } else {
        None
    };

    match presented_key {
        Some(key) => hash_api_key(&key),
        None => PUBLIC_AUTH_SCOPE.to_string(),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use axum::http::{HeaderValue, header::AUTHORIZATION};
use axum::http::{header::AUTHORIZATION, HeaderValue};
#[test]
fn extracts_api_key_from_bearer_header() {
let mut headers = HeaderMap::new();
headers.insert(
AUTHORIZATION,
HeaderValue::from_static("Bearer test-token"),
);
headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer test-token"));
assert_eq!(extract_api_key(&headers).as_deref(), Some("test-token"));
}
@@ -137,9 +140,21 @@ mod tests {
let mut headers = HeaderMap::new();
headers.insert("X-Agent-Type", HeaderValue::from_static("codex"));
assert_eq!(
get_optional_agent_type(&headers).as_deref(),
Some("codex")
);
assert_eq!(get_optional_agent_type(&headers).as_deref(), Some("codex"));
}
/// With auth enabled, the scope is the hash of the key sent in `X-API-Key`.
#[test]
fn derives_auth_scope_from_api_key_when_enabled() {
    let mut request_headers = HeaderMap::new();
    request_headers.insert("X-API-Key", HeaderValue::from_static("test-token"));

    let expected_scope = hash_api_key("test-token");
    let actual_scope = get_auth_scope(&request_headers, true);

    assert_eq!(actual_scope, expected_scope);
}
/// With auth disabled, the scope is always the public scope.
#[test]
fn uses_public_scope_when_auth_disabled() {
    let empty_headers = HeaderMap::new();
    let scope = get_auth_scope(&empty_headers, false);
    assert_eq!(scope, PUBLIC_AUTH_SCOPE);
}
}

View File

@@ -94,29 +94,50 @@ where
}
match Option::<StringOrVec>::deserialize(deserializer)? {
Some(StringOrVec::String(s)) => {
Ok(s.split(',')
.map(|k| k.trim().to_string())
.filter(|k| !k.is_empty())
.collect())
}
Some(StringOrVec::String(s)) => Ok(s
.split(',')
.map(|k| k.trim().to_string())
.filter(|k| !k.is_empty())
.collect()),
Some(StringOrVec::Vec(v)) => Ok(v),
None => Ok(Vec::new()),
}
}
// Default value functions
fn default_host() -> String { "0.0.0.0".to_string() }
fn default_port() -> u16 { 3100 }
fn default_db_port() -> u16 { 5432 }
fn default_pool_size() -> usize { 10 }
fn default_model_path() -> String { "models/all-MiniLM-L6-v2".to_string() }
fn default_embedding_dim() -> usize { 384 }
fn default_vector_weight() -> f32 { 0.6 }
fn default_text_weight() -> f32 { 0.4 }
fn default_dedup_threshold() -> f32 { 0.90 }
fn default_cleanup_interval_seconds() -> u64 { 300 }
fn default_auth_enabled() -> bool { false }
/// Returns the built-in default host value, `"0.0.0.0"`.
fn default_host() -> String {
    String::from("0.0.0.0")
}
/// Returns the built-in default port, `3100`.
fn default_port() -> u16 {
    3100_u16
}
/// Returns the built-in default database port, `5432`.
fn default_db_port() -> u16 {
    5432_u16
}
/// Returns the built-in default pool size, `10`.
fn default_pool_size() -> usize {
    10_usize
}
/// Returns the built-in default model path, `"models/all-MiniLM-L6-v2"`.
fn default_model_path() -> String {
    String::from("models/all-MiniLM-L6-v2")
}
/// Returns the built-in default embedding dimension, `384`.
fn default_embedding_dim() -> usize {
    384_usize
}
/// Returns the built-in default vector weight, `0.6`.
fn default_vector_weight() -> f32 {
    0.6_f32
}
/// Returns the built-in default text weight, `0.4`.
fn default_text_weight() -> f32 {
    0.4_f32
}
/// Returns the built-in default dedup threshold, `0.90`.
fn default_dedup_threshold() -> f32 {
    0.90_f32
}
/// Returns the built-in default cleanup interval in seconds, `300`.
fn default_cleanup_interval_seconds() -> u64 {
    300_u64
}
/// Returns the built-in default for the auth-enabled flag, `false`.
fn default_auth_enabled() -> bool {
false
}
impl Config {
/// Load configuration from environment variables

103
src/db.rs
View File

@@ -9,9 +9,9 @@ use tokio_postgres::NoTls;
use tracing::info;
use uuid::Uuid;
use crate::config::DatabaseConfig;
use serde::Serialize;
use serde_json::{Map, Value};
use crate::config::DatabaseConfig;
/// Database wrapper with connection pool
#[derive(Clone)]
@@ -75,6 +75,7 @@ fn merge_metadata(existing: &Value, incoming: &Value) -> Value {
async fn find_dedup_match<C>(
client: &C,
auth_scope: &str,
agent_id: &str,
embedding: &Vector,
threshold: f64,
@@ -87,13 +88,14 @@ where
r#"
SELECT id, metadata, expires_at
FROM memories
WHERE agent_id = $1
WHERE auth_scope = $1
AND agent_id = $2
AND (expires_at IS NULL OR expires_at > NOW())
AND (1 - (embedding <=> $2)) >= $3
ORDER BY (1 - (embedding <=> $2)) DESC, created_at DESC
AND (1 - (embedding <=> $3)) >= $4
ORDER BY (1 - (embedding <=> $3)) DESC, created_at DESC
LIMIT 1
"#,
&[&agent_id, embedding, &threshold],
&[&auth_scope, &agent_id, embedding, &threshold],
)
.await
.context("Failed to check for duplicate memory")?;
@@ -120,13 +122,19 @@ impl Database {
.context("Failed to create database pool")?;
// Test connection
let client = pool.get().await.context("Failed to get database connection")?;
let client = pool
.get()
.await
.context("Failed to get database connection")?;
client
.simple_query("SELECT 1")
.await
.context("Failed to execute test query")?;
info!("Database connection pool created with {} connections", config.pool_size);
info!(
"Database connection pool created with {} connections",
config.pool_size
);
Ok(Self { pool })
}
@@ -134,6 +142,7 @@ impl Database {
/// Store a memory record
pub async fn store_memory(
&self,
auth_scope: &str,
agent_id: &str,
content: &str,
embedding: &[f32],
@@ -146,7 +155,9 @@ impl Database {
let vector = Vector::from(embedding.to_vec());
let dedup_threshold = dedup_threshold as f64;
if let Some(existing) = find_dedup_match(&client, agent_id, &vector, dedup_threshold).await? {
if let Some(existing) =
find_dedup_match(&client, auth_scope, agent_id, &vector, dedup_threshold).await?
{
let merged_metadata = merge_metadata(&existing.metadata, &metadata);
let refreshed_expires_at = expires_at.or(existing.expires_at);
@@ -176,10 +187,10 @@ impl Database {
client
.execute(
r#"
INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
INSERT INTO memories (id, auth_scope, agent_id, content, embedding, keywords, metadata, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
"#,
&[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
&[&id, &auth_scope, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
)
.await
.context("Failed to store memory")?;
@@ -194,7 +205,8 @@ impl Database {
/// Query memories by vector similarity
pub async fn query_memories(
&self,
agent_id: &str,
auth_scope: &str,
source_agent_id: Option<&str>,
query_text: &str,
embedding: &[f32],
limit: i64,
@@ -230,7 +242,8 @@ impl Database {
END AS text_score
FROM memories
CROSS JOIN search_query
WHERE memories.agent_id = $3
WHERE memories.auth_scope = $3
AND ($4::text IS NULL OR memories.agent_id = $4)
AND (memories.expires_at IS NULL OR memories.expires_at > NOW())
),
ranked AS (
@@ -251,18 +264,19 @@ impl Database {
text_score,
CASE
WHEN has_text_match = 1
THEN (($5 * vector_score) + ($6 * text_score))::real
THEN (($6 * vector_score) + ($7 * text_score))::real
ELSE vector_score
END AS hybrid_score
FROM ranked
WHERE vector_score >= $4 OR text_score > 0
WHERE vector_score >= $5 OR text_score > 0
ORDER BY hybrid_score DESC, vector_score DESC
LIMIT $7
LIMIT $8
"#,
&[
&vector,
&query_text,
&agent_id,
&auth_scope,
&source_agent_id,
&threshold,
&vector_weight,
&text_weight,
@@ -296,37 +310,47 @@ impl Database {
Ok(matches)
}
/// Delete memories by agent_id and optional filters
/// Delete memories visible to an auth scope with an optional provenance filter
pub async fn purge_memories(
&self,
agent_id: &str,
auth_scope: &str,
source_agent_id: Option<&str>,
before: Option<chrono::DateTime<chrono::Utc>>,
) -> Result<u64> {
let client = self.pool.get().await?;
let count = if let Some(before_ts) = before {
client
.execute(
"DELETE FROM memories WHERE agent_id = $1 AND created_at < $2",
&[&agent_id, &before_ts],
)
.await?
} else {
client
.execute("DELETE FROM memories WHERE agent_id = $1", &[&agent_id])
.await?
};
let count = client
.execute(
r#"
DELETE FROM memories
WHERE auth_scope = $1
AND ($2::text IS NULL OR agent_id = $2)
AND ($3::timestamptz IS NULL OR created_at < $3)
"#,
&[&auth_scope, &source_agent_id, &before],
)
.await?;
Ok(count)
}
/// Get memory count for an agent
pub async fn count_memories(&self, agent_id: &str) -> Result<i64> {
/// Get memory count for a token-visible scope and optional provenance filter
pub async fn count_memories(
&self,
auth_scope: &str,
source_agent_id: Option<&str>,
) -> Result<i64> {
let client = self.pool.get().await?;
let row = client
.query_one(
"SELECT COUNT(*) as count FROM memories WHERE agent_id = $1 AND (expires_at IS NULL OR expires_at > NOW())",
&[&agent_id],
r#"
SELECT COUNT(*) as count
FROM memories
WHERE auth_scope = $1
AND ($2::text IS NULL OR agent_id = $2)
AND (expires_at IS NULL OR expires_at > NOW())
"#,
&[&auth_scope, &source_agent_id],
)
.await?;
Ok(row.get("count"))
@@ -346,7 +370,6 @@ impl Database {
}
}
/// Result for a single batch entry
#[derive(Debug, Clone, Serialize)]
pub struct BatchStoreResult {
@@ -360,6 +383,7 @@ impl Database {
/// Store multiple memories in a single transaction
pub async fn batch_store_memories(
&self,
auth_scope: &str,
agent_id: &str,
entries: Vec<(
String,
@@ -378,7 +402,8 @@ impl Database {
for (content, metadata, embedding, keywords, expires_at) in entries {
let vector = Vector::from(embedding);
if let Some(existing) =
find_dedup_match(&transaction, agent_id, &vector, dedup_threshold).await?
find_dedup_match(&transaction, auth_scope, agent_id, &vector, dedup_threshold)
.await?
{
let merged_metadata = merge_metadata(&existing.metadata, &metadata);
let refreshed_expires_at = expires_at.or(existing.expires_at);
@@ -404,8 +429,8 @@ impl Database {
} else {
let id = Uuid::new_v4();
transaction.execute(
r#"INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7)"#,
&[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
r#"INSERT INTO memories (id, auth_scope, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#,
&[&id, &auth_scope, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
).await?;
results.push(BatchStoreResult {
id: id.to_string(),

View File

@@ -1,7 +1,7 @@
//! Embedding engine using local ONNX models
use anyhow::Result;
use ort::session::{Session, builder::GraphOptimizationLevel};
use ort::session::{builder::GraphOptimizationLevel, Session};
use ort::value::Value;
use std::path::{Path, PathBuf};
use std::sync::Once;
@@ -15,9 +15,9 @@ static ORT_INIT: Once = Once::new();
/// Initialize ONNX Runtime synchronously (called inside spawn_blocking)
fn init_ort_sync(dylib_path: &str) -> Result<()> {
info!("Initializing ONNX Runtime from: {}", dylib_path);
let mut init_error: Option<String> = None;
ORT_INIT.call_once(|| {
info!("ORT_INIT.call_once - starting initialization");
match ort::init_from(dylib_path) {
@@ -43,7 +43,7 @@ fn init_ort_sync(dylib_path: &str) -> Result<()> {
if let Some(err) = init_error {
return Err(anyhow::anyhow!("{}", err));
}
info!("ONNX Runtime initialization complete");
Ok(())
}
@@ -91,35 +91,40 @@ impl EmbeddingEngine {
let model_path = PathBuf::from(&config.model_path);
let dimension = config.dimension;
info!("Loading ONNX model from {:?}", model_path.join("model.onnx"));
info!(
"Loading ONNX model from {:?}",
model_path.join("model.onnx")
);
// Use spawn_blocking to avoid blocking the async runtime
let (session, tokenizer) = tokio::task::spawn_blocking(move || -> Result<(Session, Tokenizer)> {
// Initialize ONNX Runtime first
init_ort_sync(&dylib_path)?;
info!("Creating ONNX session...");
// Load ONNX model with ort 2.0 API
let session = Session::builder()
.map_err(|e| anyhow::anyhow!("Failed to create session builder: {:?}", e))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| anyhow::anyhow!("Failed to set optimization level: {:?}", e))?
.with_intra_threads(4)
.map_err(|e| anyhow::anyhow!("Failed to set intra threads: {:?}", e))?
.commit_from_file(model_path.join("model.onnx"))
.map_err(|e| anyhow::anyhow!("Failed to load ONNX model: {:?}", e))?;
let (session, tokenizer) =
tokio::task::spawn_blocking(move || -> Result<(Session, Tokenizer)> {
// Initialize ONNX Runtime first
init_ort_sync(&dylib_path)?;
info!("ONNX model loaded, loading tokenizer...");
// Load tokenizer
let tokenizer = Tokenizer::from_file(model_path.join("tokenizer.json"))
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
info!("Creating ONNX session...");
info!("Tokenizer loaded successfully");
Ok((session, tokenizer))
}).await
// Load ONNX model with ort 2.0 API
let session = Session::builder()
.map_err(|e| anyhow::anyhow!("Failed to create session builder: {:?}", e))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| anyhow::anyhow!("Failed to set optimization level: {:?}", e))?
.with_intra_threads(4)
.map_err(|e| anyhow::anyhow!("Failed to set intra threads: {:?}", e))?
.commit_from_file(model_path.join("model.onnx"))
.map_err(|e| anyhow::anyhow!("Failed to load ONNX model: {:?}", e))?;
info!("ONNX model loaded, loading tokenizer...");
// Load tokenizer
let tokenizer = Tokenizer::from_file(model_path.join("tokenizer.json"))
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
info!("Tokenizer loaded successfully");
Ok((session, tokenizer))
})
.await
.map_err(|e| anyhow::anyhow!("Spawn blocking failed: {:?}", e))??;
info!(
@@ -136,12 +141,17 @@ impl EmbeddingEngine {
/// Generate embedding for a single text
pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
let encoding = self.tokenizer
let encoding = self
.tokenizer
.encode(text, true)
.map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
let attention_mask: Vec<i64> = encoding.get_attention_mask().iter().map(|&x| x as i64).collect();
let attention_mask: Vec<i64> = encoding
.get_attention_mask()
.iter()
.map(|&x| x as i64)
.collect();
let token_type_ids: Vec<i64> = encoding.get_type_ids().iter().map(|&x| x as i64).collect();
let seq_len = input_ids.len();
@@ -157,22 +167,25 @@ impl EmbeddingEngine {
"attention_mask" => attention_mask_tensor,
"token_type_ids" => token_type_ids_tensor,
];
let mut session_guard = self.session.lock()
let mut session_guard = self
.session
.lock()
.map_err(|e| anyhow::anyhow!("Session lock poisoned: {}", e))?;
let outputs = session_guard.run(inputs)?;
// Extract output
let output = outputs.get("last_hidden_state")
let output = outputs
.get("last_hidden_state")
.ok_or_else(|| anyhow::anyhow!("Missing last_hidden_state output"))?;
// Get the tensor data
let (shape, data) = output.try_extract_tensor::<f32>()?;
// Mean pooling over sequence dimension
let hidden_size = *shape.last().unwrap_or(&384) as usize;
let seq_len = data.len() / hidden_size;
let mut embedding = vec![0.0f32; hidden_size];
for i in 0..seq_len {
for j in 0..hidden_size {
@@ -182,7 +195,7 @@ impl EmbeddingEngine {
for val in &mut embedding {
*val /= seq_len as f32;
}
// L2 normalize
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
@@ -190,7 +203,7 @@ impl EmbeddingEngine {
*val /= norm;
}
}
Ok(embedding)
}
@@ -208,37 +221,40 @@ impl EmbeddingEngine {
/// Extract keywords from text using simple frequency analysis
pub fn extract_keywords(text: &str, limit: usize) -> Vec<String> {
use std::collections::HashMap;
let stop_words: std::collections::HashSet<&str> = [
"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
"of", "with", "by", "from", "is", "are", "was", "were", "be", "been",
"being", "have", "has", "had", "do", "does", "did", "will", "would",
"could", "should", "may", "might", "must", "shall", "can", "this",
"that", "these", "those", "i", "you", "he", "she", "it", "we", "they",
"what", "which", "who", "whom", "whose", "where", "when", "why", "how",
"all", "each", "every", "both", "few", "more", "most", "other", "some",
"such", "no", "nor", "not", "only", "own", "same", "so", "than", "too",
"very", "just", "also", "now", "here", "there", "then", "once", "if",
].iter().cloned().collect();
"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by",
"from", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "do",
"does", "did", "will", "would", "could", "should", "may", "might", "must", "shall", "can",
"this", "that", "these", "those", "i", "you", "he", "she", "it", "we", "they", "what",
"which", "who", "whom", "whose", "where", "when", "why", "how", "all", "each", "every",
"both", "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own",
"same", "so", "than", "too", "very", "just", "also", "now", "here", "there", "then",
"once", "if",
]
.iter()
.cloned()
.collect();
let mut word_counts: HashMap<String, usize> = HashMap::new();
for word in text.split_whitespace() {
let clean: String = word
.chars()
.filter(|c| c.is_alphanumeric())
.collect::<String>()
.to_lowercase();
if clean.len() > 2 && !stop_words.contains(clean.as_str()) {
*word_counts.entry(clean).or_insert(0) += 1;
}
}
let mut sorted: Vec<_> = word_counts.into_iter().collect();
sorted.sort_by(|a, b| b.1.cmp(&a.1));
sorted.into_iter()
sorted
.into_iter()
.take(limit)
.map(|(word, _)| word)
.collect()

View File

@@ -5,18 +5,18 @@ pub mod config;
pub mod db;
pub mod embedding;
pub mod migrations;
pub mod ttl;
pub mod tools;
pub mod transport;
pub mod ttl;
use anyhow::Result;
use axum::{Router, Json, http::StatusCode, middleware};
use axum::{http::StatusCode, middleware, Json, Router};
use serde_json::json;
use std::sync::Arc;
use tokio::net::TcpListener;
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use tracing::{info, error};
use tracing::{error, info};
use crate::auth::auth_middleware;
use crate::config::Config;
@@ -60,15 +60,15 @@ async fn readiness_handler(
match readiness {
ReadinessState::Ready => (
StatusCode::OK,
Json(json!({"status": "ready", "embedding": true}))
Json(json!({"status": "ready", "embedding": true})),
),
ReadinessState::Initializing => (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({"status": "initializing", "embedding": false}))
Json(json!({"status": "initializing", "embedding": false})),
),
ReadinessState::Failed(err) => (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({"status": "failed", "error": err}))
Json(json!({"status": "failed", "error": err})),
),
}
}
@@ -89,11 +89,14 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> {
tokio::spawn(async move {
let max_retries = 3;
let mut attempt = 0;
loop {
attempt += 1;
info!("Initializing embedding engine (attempt {}/{})", attempt, max_retries);
info!(
"Initializing embedding engine (attempt {}/{})",
attempt, max_retries
);
match EmbeddingEngine::new(&embedding_config).await {
Ok(engine) => {
let engine = Arc::new(engine);
@@ -120,9 +123,8 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> {
let cleanup_state = state.clone();
let cleanup_interval_seconds = config.ttl.cleanup_interval_seconds;
tokio::spawn(async move {
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(
cleanup_interval_seconds,
));
let mut interval =
tokio::time::interval(tokio::time::Duration::from_secs(cleanup_interval_seconds));
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
@@ -148,10 +150,12 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> {
.route("/health", axum::routing::get(health_handler))
.route("/ready", axum::routing::get(readiness_handler))
.with_state(state.clone());
// Build MCP router with auth middleware
let mcp_router = transport::mcp_router(mcp_state)
.layer(middleware::from_fn_with_state(state.clone(), auth_middleware));
let mcp_router = transport::mcp_router(mcp_state).layer(middleware::from_fn_with_state(
state.clone(),
auth_middleware,
));
let app = Router::new()
.merge(health_router)

View File

@@ -17,7 +17,10 @@ async fn main() -> Result<()> {
.with(tracing_subscriber::fmt::layer())
.init();
info!("Starting OpenBrain MCP Server v{}", env!("CARGO_PKG_VERSION"));
info!(
"Starting OpenBrain MCP Server v{}",
env!("CARGO_PKG_VERSION")
);
// Load configuration
let config = Config::load()?;

View File

@@ -3,12 +3,14 @@
//! Accepts 1-50 entries per call, generates embeddings for each,
//! stores all in a single DB transaction, returns individual IDs/status.
use anyhow::{Context, Result, anyhow};
use anyhow::{anyhow, Context, Result};
use serde_json::Value;
use std::sync::Arc;
use tracing::info;
use crate::auth::PUBLIC_AUTH_SCOPE;
use crate::embedding::extract_keywords;
use crate::tools::INTERNAL_AUTH_SCOPE_ARG;
use crate::ttl::expires_at_from_ttl;
use crate::AppState;
@@ -18,7 +20,7 @@ const MAX_BATCH_SIZE: usize = 50;
/// Execute the batch_store tool
///
/// Accepts:
/// - `agent_id`: Optional agent identifier (defaults to "default")
/// - `agent_id`: Optional source agent label (defaults to "default")
/// - `ttl`: Optional default TTL string applied to entries without their own ttl
/// - `entries`: Array of 1-50 entries, each with `content` (required) and `metadata` (optional)
///
@@ -38,6 +40,10 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
.get("agent_id")
.and_then(|v| v.as_str())
.unwrap_or("default");
let auth_scope = arguments
.get(INTERNAL_AUTH_SCOPE_ARG)
.and_then(|v| v.as_str())
.unwrap_or(PUBLIC_AUTH_SCOPE);
let entries = arguments
.get("entries")
@@ -47,7 +53,9 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
// 3. Validate batch size
if entries.is_empty() {
return Err(anyhow!("Empty entries array not allowed - must provide 1-50 entries"));
return Err(anyhow!(
"Empty entries array not allowed - must provide 1-50 entries"
));
}
if entries.len() > MAX_BATCH_SIZE {
return Err(anyhow!(
@@ -69,7 +77,10 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
let content = entry
.get("content")
.and_then(|v| v.as_str())
.context(format!("Entry at index {} missing required field: content", idx))?;
.context(format!(
"Entry at index {} missing required field: content",
idx
))?;
if content.is_empty() {
return Err(anyhow!(
@@ -82,10 +93,7 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
.get("metadata")
.cloned()
.unwrap_or(serde_json::json!({}));
let ttl = entry
.get("ttl")
.and_then(|v| v.as_str())
.or(default_ttl);
let ttl = entry.get("ttl").and_then(|v| v.as_str()).or(default_ttl);
let expires_at = expires_at_from_ttl(ttl)
.with_context(|| format!("Invalid ttl for entry at index {}", idx))?;
@@ -97,13 +105,24 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
// Extract keywords
let keywords = extract_keywords(content, 10);
processed_entries.push((content.to_string(), metadata, embedding, keywords, expires_at));
processed_entries.push((
content.to_string(),
metadata,
embedding,
keywords,
expires_at,
));
}
// 5. Batch DB insert (single transaction for atomicity)
let results = state
.db
.batch_store_memories(agent_id, processed_entries, state.config.dedup.threshold)
.batch_store_memories(
auth_scope,
agent_id,
processed_entries,
state.config.dedup.threshold,
)
.await
.context("Failed to batch store memories")?;
@@ -114,5 +133,6 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
"success": true,
"results": results,
"count": count
}).to_string())
})
.to_string())
}

View File

@@ -1,9 +1,9 @@
//! MCP Tools for OpenBrain
pub mod batch_store;
pub mod purge;
pub mod query;
pub mod store;
pub mod purge;
use anyhow::Result;
use serde_json::{json, Value};
@@ -11,11 +11,13 @@ use std::sync::Arc;
use crate::AppState;
/// Key under which tool handlers look up the caller's auth scope in their JSON arguments.
///
/// NOTE(review): handlers appear to trust this argument verbatim — confirm the transport layer
/// overwrites or strips any client-supplied `_auth_scope` before dispatching a tool call.
pub const INTERNAL_AUTH_SCOPE_ARG: &str = "_auth_scope";
pub fn get_tool_definitions() -> Vec<Value> {
vec![
json!({
"name": "store",
"description": "Store a memory with automatic embedding generation and keyword extraction. Near-duplicate memories for the same agent are deduplicated automatically by similarity, with metadata merged and timestamps refreshed.",
"description": "Store a memory with automatic embedding generation and keyword extraction. Near-duplicate memories for the same API-token scope and same source agent are deduplicated automatically by similarity, with metadata merged and timestamps refreshed.",
"inputSchema": {
"type": "object",
"properties": {
@@ -25,7 +27,7 @@ pub fn get_tool_definitions() -> Vec<Value> {
},
"agent_id": {
"type": "string",
"description": "Unique identifier for the agent storing the memory (default: 'default')"
"description": "Optional source agent label recorded with the memory. If omitted, the server may fall back to X-Agent-ID or 'default'."
},
"metadata": {
"type": "object",
@@ -47,7 +49,7 @@ pub fn get_tool_definitions() -> Vec<Value> {
"properties": {
"agent_id": {
"type": "string",
"description": "Unique identifier for the agent storing the memories (default: 'default')"
"description": "Optional source agent label recorded with each stored memory. If omitted, the server may fall back to X-Agent-ID or 'default'."
},
"ttl": {
"type": "string",
@@ -89,9 +91,14 @@ pub fn get_tool_definitions() -> Vec<Value> {
"type": "string",
"description": "The search query text"
},
"source_agent_id": {
"type": "string",
"description": "Optional provenance filter that only returns memories stored by the specified agent label"
},
"agent_id": {
"type": "string",
"description": "Agent ID to search within (default: 'default')"
"description": "Deprecated legacy alias. Query visibility is scoped by API token, not by agent_id.",
"deprecated": true
},
"limit": {
"type": "integer",
@@ -107,13 +114,18 @@ pub fn get_tool_definitions() -> Vec<Value> {
}),
json!({
"name": "purge",
"description": "Delete memories for an agent. Can delete all memories or those before a specific timestamp.",
"description": "Delete memories visible to the current API token. Can optionally filter by source agent label or by time range.",
"inputSchema": {
"type": "object",
"properties": {
"source_agent_id": {
"type": "string",
"description": "Optional provenance filter that only deletes memories stored by the specified agent label"
},
"agent_id": {
"type": "string",
"description": "Agent ID whose memories to delete (required)"
"description": "Deprecated legacy alias for source_agent_id",
"deprecated": true
},
"before": {
"type": "string",
@@ -124,9 +136,9 @@ pub fn get_tool_definitions() -> Vec<Value> {
"description": "Must be true to confirm deletion"
}
},
"required": ["agent_id", "confirm"]
"required": ["confirm"]
}
})
}),
]
}

View File

@@ -1,4 +1,4 @@
//! Purge Tool - Delete memories by agent_id or time range
//! Purge Tool - Delete memories visible to the current token with optional filters
use anyhow::{bail, Context, Result};
use chrono::DateTime;
@@ -6,15 +6,21 @@ use serde_json::Value;
use std::sync::Arc;
use tracing::{info, warn};
use crate::auth::PUBLIC_AUTH_SCOPE;
use crate::tools::INTERNAL_AUTH_SCOPE_ARG;
use crate::AppState;
/// Execute the purge tool
pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String> {
// Extract parameters
let agent_id = arguments
.get("agent_id")
let source_agent_id = arguments
.get("source_agent_id")
.and_then(|v| v.as_str())
.context("Missing required parameter: agent_id")?;
.or_else(|| arguments.get("agent_id").and_then(|v| v.as_str()));
let auth_scope = arguments
.get(INTERNAL_AUTH_SCOPE_ARG)
.and_then(|v| v.as_str())
.unwrap_or(PUBLIC_AUTH_SCOPE);
let confirm = arguments
.get("confirm")
@@ -36,15 +42,18 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
// Get current count before purge
let count_before = state
.db
.count_memories(agent_id)
.count_memories(auth_scope, source_agent_id)
.await
.context("Failed to count memories")?;
if count_before == 0 {
info!("No memories found for agent '{}'", agent_id);
info!(
"No memories found to purge for auth scope '{}' with source_agent_id={:?}",
auth_scope, source_agent_id
);
return Ok(serde_json::json!({
"success": true,
"agent_id": agent_id,
"source_agent_id_filter": source_agent_id,
"deleted": 0,
"message": "No memories found to purge"
})
@@ -52,25 +61,25 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
}
warn!(
"Purging memories for agent '{}' (before={:?})",
agent_id, before
"Purging memories for auth scope '{}' with source_agent_id={:?} (before={:?})",
auth_scope, source_agent_id, before
);
// Execute purge
let deleted = state
.db
.purge_memories(agent_id, before)
.purge_memories(auth_scope, source_agent_id, before)
.await
.context("Failed to purge memories")?;
info!(
"Purged {} memories for agent '{}'",
deleted, agent_id
"Purged {} memories for auth scope '{}' with source_agent_id={:?}",
deleted, auth_scope, source_agent_id
);
Ok(serde_json::json!({
"success": true,
"agent_id": agent_id,
"source_agent_id_filter": source_agent_id,
"deleted": deleted,
"had_before_filter": before.is_some(),
"message": format!("Successfully purged {} memories", deleted)

View File

@@ -1,10 +1,12 @@
//! Query Tool - Search memories by semantic similarity
use anyhow::{Context, Result, anyhow};
use anyhow::{anyhow, Context, Result};
use serde_json::Value;
use std::sync::Arc;
use tracing::info;
use crate::auth::PUBLIC_AUTH_SCOPE;
use crate::tools::INTERNAL_AUTH_SCOPE_ARG;
use crate::AppState;
/// Execute the query tool
@@ -21,10 +23,14 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
.and_then(|v| v.as_str())
.context("Missing required parameter: query")?;
let agent_id = arguments
.get("agent_id")
let source_agent_id = arguments
.get("source_agent_id")
.and_then(|v| v.as_str())
.unwrap_or("default");
.filter(|value| !value.is_empty());
let auth_scope = arguments
.get(INTERNAL_AUTH_SCOPE_ARG)
.and_then(|v| v.as_str())
.unwrap_or(PUBLIC_AUTH_SCOPE);
let limit = arguments
.get("limit")
@@ -42,8 +48,8 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
);
info!(
"Querying memories for agent '{}': '{}' (limit={}, threshold={}, vector_weight={}, text_weight={})",
agent_id, query_text, limit, threshold, vector_weight, text_weight
"Querying memories for auth scope '{}' with source_agent_id={:?}: '{}' (limit={}, threshold={}, vector_weight={}, text_weight={})",
auth_scope, source_agent_id, query_text, limit, threshold, vector_weight, text_weight
);
// Generate embedding for query using Arc<EmbeddingEngine>
@@ -55,7 +61,8 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
let matches = state
.db
.query_memories(
agent_id,
auth_scope,
source_agent_id,
query_text,
&query_embedding,
limit,
@@ -79,6 +86,7 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
"vector_score": m.vector_score,
"text_score": m.text_score,
"hybrid_score": m.hybrid_score,
"agent_id": m.record.agent_id,
"keywords": m.record.keywords,
"metadata": m.record.metadata,
"created_at": m.record.created_at.to_rfc3339(),
@@ -89,8 +97,8 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
Ok(serde_json::json!({
"success": true,
"agent_id": agent_id,
"query": query_text,
"source_agent_id_filter": source_agent_id,
"vector_weight": vector_weight,
"text_weight": text_weight,
"count": results.len(),

View File

@@ -1,11 +1,13 @@
//! Store Tool - Store memories with automatic embeddings
use anyhow::{Context, Result, anyhow};
use anyhow::{anyhow, Context, Result};
use serde_json::Value;
use std::sync::Arc;
use tracing::info;
use crate::auth::PUBLIC_AUTH_SCOPE;
use crate::embedding::extract_keywords;
use crate::tools::INTERNAL_AUTH_SCOPE_ARG;
use crate::ttl::expires_at_from_ttl;
use crate::AppState;
@@ -27,6 +29,10 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
.get("agent_id")
.and_then(|v| v.as_str())
.unwrap_or("default");
let auth_scope = arguments
.get(INTERNAL_AUTH_SCOPE_ARG)
.and_then(|v| v.as_str())
.unwrap_or(PUBLIC_AUTH_SCOPE);
let metadata = arguments
.get("metadata")
@@ -54,6 +60,7 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
let id = state
.db
.store_memory(
auth_scope,
agent_id,
content,
&embedding,
@@ -67,7 +74,11 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
info!(
"Memory {} with ID: {}",
if id.deduplicated { "deduplicated" } else { "stored" },
if id.deduplicated {
"deduplicated"
} else {
"stored"
},
id.id
);

View File

@@ -5,27 +5,25 @@
use axum::{
extract::{Query, State},
http::{HeaderMap, StatusCode, Uri, header::{HOST, ORIGIN}},
http::{
header::{HOST, ORIGIN},
HeaderMap, StatusCode, Uri,
},
response::{
IntoResponse, Response,
sse::{Event, KeepAlive, Sse},
IntoResponse, Response,
},
routing::{get, post},
Json, Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
convert::Infallible,
sync::Arc,
time::Duration,
};
use tokio::sync::{RwLock, broadcast, mpsc};
use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
use tokio::sync::{broadcast, mpsc, RwLock};
use tracing::{error, info, warn};
use uuid::Uuid;
use crate::{AppState, auth, tools};
use crate::{auth, tools, AppState};
/// Lock-guarded map from a `String` key to the `mpsc::Sender` that carries
/// `serde_json::Value` messages for that key.
///
/// NOTE(review): presumably the key is the MCP session id and this is the type
/// of `McpState::sessions` — confirm against the struct definition.
type SessionStore = RwLock<HashMap<String, mpsc::Sender<serde_json::Value>>>;
@@ -46,11 +44,7 @@ impl McpState {
})
}
async fn insert_session(
&self,
session_id: String,
tx: mpsc::Sender<serde_json::Value>,
) {
/// Registers `tx` under `session_id` in `self.sessions`.
///
/// Awaits the write lock on `self.sessions` and performs a single `insert`;
/// the value returned by `insert` is discarded.
///
/// NOTE(review): if `sessions` is the `HashMap`-backed `SessionStore`, an
/// existing entry with the same id is replaced and its old sender dropped —
/// confirm against the `McpState` definition.
async fn insert_session(&self, session_id: String, tx: mpsc::Sender<serde_json::Value>) {
self.sessions.write().await.insert(session_id, tx);
}
@@ -163,7 +157,12 @@ struct PostMessageQuery {
/// Create the MCP router
pub fn mcp_router(state: Arc<McpState>) -> Router {
Router::new()
.route("/mcp", get(streamable_get_handler).post(streamable_post_handler).delete(streamable_delete_handler))
.route(
"/mcp",
get(streamable_get_handler)
.post(streamable_post_handler)
.delete(streamable_delete_handler),
)
.route("/mcp/sse", get(sse_handler))
.route("/mcp/message", post(message_handler))
.route("/mcp/health", get(health_handler))
@@ -186,7 +185,10 @@ fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
}
let origin_uri = origin.parse::<Uri>().map_err(|_| {
warn!("Rejected MCP request with invalid origin header: {}", origin);
warn!(
"Rejected MCP request with invalid origin header: {}",
origin
);
StatusCode::FORBIDDEN
})?;
@@ -203,7 +205,10 @@ fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
.map(str::trim)
.filter(|value| !value.is_empty())
.ok_or_else(|| {
warn!("Rejected MCP request without host header for origin {}", origin);
warn!(
"Rejected MCP request without host header for origin {}",
origin
);
StatusCode::FORBIDDEN
})?;
@@ -247,12 +252,12 @@ async fn streamable_post_handler(
info!(
method = %request.method,
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
client_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
"Received streamable MCP request"
);
let request = apply_request_context(request, &headers);
let request = apply_request_context(request, &headers, state.app.config.auth.enabled);
let response = dispatch_request(&state, &request).await;
match response {
@@ -284,8 +289,12 @@ async fn sse_handler(
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
validate_origin(&headers)?;
info!(
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
client_id = auth::get_optional_agent_id(&headers)
.as_deref()
.unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers)
.as_deref()
.unwrap_or("unset"),
"Opening legacy SSE MCP stream"
);
let mut broadcast_rx = state.event_tx.subscribe();
@@ -354,7 +363,7 @@ async fn message_handler(
info!(
method = %request.method,
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
client_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
"Received legacy SSE MCP request"
);
@@ -365,7 +374,7 @@ async fn message_handler(
}
}
let request = apply_request_context(request, &headers);
let request = apply_request_context(request, &headers, state.app.config.auth.enabled);
let response = dispatch_request(&state, &request).await;
match query.session_id.as_deref() {
@@ -408,17 +417,10 @@ async fn route_session_response(
fn apply_request_context(
mut request: JsonRpcRequest,
headers: &HeaderMap,
auth_enabled: bool,
) -> JsonRpcRequest {
if let Some(agent_id) = auth::get_optional_agent_id(headers) {
inject_agent_id(&mut request, &agent_id);
}
request
}
fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
if request.method != "tools/call" {
return;
return request;
}
if !request.params.is_object() {
@@ -429,6 +431,11 @@ fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
.params
.as_object_mut()
.expect("params should be an object");
let tool_name = params
.get("name")
.and_then(|value| value.as_str())
.unwrap_or("")
.to_string();
let arguments = params
.entry("arguments")
.or_insert_with(|| serde_json::json!({}));
@@ -437,11 +444,22 @@ fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
*arguments = serde_json::json!({});
}
arguments
let arguments = arguments
.as_object_mut()
.expect("arguments should be an object")
.entry("agent_id".to_string())
.or_insert_with(|| serde_json::json!(agent_id));
.expect("arguments should be an object");
arguments.insert(
tools::INTERNAL_AUTH_SCOPE_ARG.to_string(),
serde_json::json!(auth::get_auth_scope(headers, auth_enabled)),
);
if matches!(tool_name.as_str(), "store" | "batch_store") && !arguments.contains_key("agent_id")
{
if let Some(agent_id) = auth::get_optional_agent_id(headers) {
arguments.insert("agent_id".to_string(), serde_json::json!(agent_id));
}
}
request
}
async fn dispatch_request(
@@ -496,10 +514,7 @@ fn initialize_result() -> serde_json::Value {
})
}
fn success_response(
id: serde_json::Value,
result: serde_json::Value,
) -> JsonRpcResponse {
fn success_response(id: serde_json::Value, result: serde_json::Value) -> JsonRpcResponse {
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id,
@@ -578,10 +593,11 @@ async fn handle_tools_call(
#[cfg(test)]
mod tests {
use super::*;
use axum::http::HeaderValue;
#[test]
fn injects_agent_id_when_missing_from_tool_arguments() {
let mut request = JsonRpcRequest {
fn request_context_injects_auth_scope_for_tool_calls() {
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(serde_json::json!("1")),
method: "tools/call".to_string(),
@@ -592,8 +608,39 @@ mod tests {
}
}),
};
let mut headers = HeaderMap::new();
headers.insert("X-API-Key", HeaderValue::from_static("test-token"));
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
inject_agent_id(&mut request, "agent-from-header");
let request = apply_request_context(request, &headers, true);
assert_eq!(
request
.params
.get("arguments")
.and_then(|value| value.get(tools::INTERNAL_AUTH_SCOPE_ARG))
.and_then(|value| value.as_str()),
Some(auth::hash_api_key("test-token").as_str())
);
}
#[test]
fn request_context_does_not_inject_query_filter_from_header() {
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(serde_json::json!("1")),
method: "tools/call".to_string(),
params: serde_json::json!({
"name": "query",
"arguments": {
"query": "editor preferences"
}
}),
};
let mut headers = HeaderMap::new();
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
let request = apply_request_context(request, &headers, false);
assert_eq!(
request
@@ -601,25 +648,55 @@ mod tests {
.get("arguments")
.and_then(|value| value.get("agent_id"))
.and_then(|value| value.as_str()),
Some("agent-from-header")
None
);
}
#[test]
fn preserves_explicit_agent_id() {
let mut request = JsonRpcRequest {
fn request_context_injects_store_agent_id_from_header_when_missing() {
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(serde_json::json!("1")),
method: "tools/call".to_string(),
params: serde_json::json!({
"name": "query",
"name": "store",
"arguments": {
"content": "prefers dark mode"
}
}),
};
let mut headers = HeaderMap::new();
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
let request = apply_request_context(request, &headers, false);
assert_eq!(
request
.params
.get("arguments")
.and_then(|value| value.get("agent_id"))
.and_then(|value| value.as_str()),
Some("codex-desktop")
);
}
#[test]
fn request_context_preserves_explicit_store_agent_id() {
let request = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(serde_json::json!("1")),
method: "tools/call".to_string(),
params: serde_json::json!({
"name": "store",
"arguments": {
"agent_id": "explicit-agent"
}
}),
};
let mut headers = HeaderMap::new();
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
inject_agent_id(&mut request, "agent-from-header");
let request = apply_request_context(request, &headers, false);
assert_eq!(
request

View File

@@ -1,4 +1,4 @@
use anyhow::{Result, anyhow};
use anyhow::{anyhow, Result};
use chrono::{DateTime, Duration, Utc};
pub fn parse_ttl_spec(ttl: &str) -> Result<Duration> {
@@ -25,7 +25,9 @@ pub fn parse_ttl_spec(ttl: &str) -> Result<Duration> {
.parse()
.map_err(|_| anyhow!("invalid ttl '{ttl}'. Duration value must be a positive integer"))?;
if value <= 0 {
return Err(anyhow!("invalid ttl '{ttl}'. Duration value must be greater than zero"));
return Err(anyhow!(
"invalid ttl '{ttl}'. Duration value must be greater than zero"
));
}
let total_seconds = value