mirror of
https://gitea.ingwaz.work/Ingwaz/openbrain-mcp.git
synced 2026-06-15 22:07:08 +00:00
Scope memories by API token and add shared-token e2e coverage
This commit is contained in:
43
src/auth.rs
43
src/auth.rs
@@ -4,7 +4,7 @@
|
||||
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::{HeaderMap, StatusCode, header::AUTHORIZATION},
|
||||
http::{header::AUTHORIZATION, HeaderMap, StatusCode},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
@@ -14,6 +14,8 @@ use tracing::warn;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
pub const PUBLIC_AUTH_SCOPE: &str = "public";
|
||||
|
||||
/// Hash an API key for secure comparison
|
||||
pub fn hash_api_key(key: &str) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
@@ -99,24 +101,25 @@ pub fn get_optional_agent_type(headers: &HeaderMap) -> Option<String> {
|
||||
.map(ToOwned::to_owned)
|
||||
}
|
||||
|
||||
/// Extract agent ID from request headers or default
|
||||
pub fn get_agent_id(request: &Request) -> String {
|
||||
get_optional_agent_id(request.headers())
|
||||
.unwrap_or_else(|| "default".to_string())
|
||||
pub fn get_auth_scope(headers: &HeaderMap, auth_enabled: bool) -> String {
|
||||
if !auth_enabled {
|
||||
return PUBLIC_AUTH_SCOPE.to_string();
|
||||
}
|
||||
|
||||
extract_api_key(headers)
|
||||
.map(|key| hash_api_key(&key))
|
||||
.unwrap_or_else(|| PUBLIC_AUTH_SCOPE.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::http::{HeaderValue, header::AUTHORIZATION};
|
||||
use axum::http::{header::AUTHORIZATION, HeaderValue};
|
||||
|
||||
#[test]
|
||||
fn extracts_api_key_from_bearer_header() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(
|
||||
AUTHORIZATION,
|
||||
HeaderValue::from_static("Bearer test-token"),
|
||||
);
|
||||
headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer test-token"));
|
||||
|
||||
assert_eq!(extract_api_key(&headers).as_deref(), Some("test-token"));
|
||||
}
|
||||
@@ -137,9 +140,21 @@ mod tests {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("X-Agent-Type", HeaderValue::from_static("codex"));
|
||||
|
||||
assert_eq!(
|
||||
get_optional_agent_type(&headers).as_deref(),
|
||||
Some("codex")
|
||||
);
|
||||
assert_eq!(get_optional_agent_type(&headers).as_deref(), Some("codex"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn derives_auth_scope_from_api_key_when_enabled() {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("X-API-Key", HeaderValue::from_static("test-token"));
|
||||
|
||||
assert_eq!(get_auth_scope(&headers, true), hash_api_key("test-token"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn uses_public_scope_when_auth_disabled() {
|
||||
let headers = HeaderMap::new();
|
||||
|
||||
assert_eq!(get_auth_scope(&headers, false), PUBLIC_AUTH_SCOPE);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,29 +94,50 @@ where
|
||||
}
|
||||
|
||||
match Option::<StringOrVec>::deserialize(deserializer)? {
|
||||
Some(StringOrVec::String(s)) => {
|
||||
Ok(s.split(',')
|
||||
.map(|k| k.trim().to_string())
|
||||
.filter(|k| !k.is_empty())
|
||||
.collect())
|
||||
}
|
||||
Some(StringOrVec::String(s)) => Ok(s
|
||||
.split(',')
|
||||
.map(|k| k.trim().to_string())
|
||||
.filter(|k| !k.is_empty())
|
||||
.collect()),
|
||||
Some(StringOrVec::Vec(v)) => Ok(v),
|
||||
None => Ok(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
// Default value functions
|
||||
fn default_host() -> String { "0.0.0.0".to_string() }
|
||||
fn default_port() -> u16 { 3100 }
|
||||
fn default_db_port() -> u16 { 5432 }
|
||||
fn default_pool_size() -> usize { 10 }
|
||||
fn default_model_path() -> String { "models/all-MiniLM-L6-v2".to_string() }
|
||||
fn default_embedding_dim() -> usize { 384 }
|
||||
fn default_vector_weight() -> f32 { 0.6 }
|
||||
fn default_text_weight() -> f32 { 0.4 }
|
||||
fn default_dedup_threshold() -> f32 { 0.90 }
|
||||
fn default_cleanup_interval_seconds() -> u64 { 300 }
|
||||
fn default_auth_enabled() -> bool { false }
|
||||
fn default_host() -> String {
|
||||
"0.0.0.0".to_string()
|
||||
}
|
||||
fn default_port() -> u16 {
|
||||
3100
|
||||
}
|
||||
fn default_db_port() -> u16 {
|
||||
5432
|
||||
}
|
||||
fn default_pool_size() -> usize {
|
||||
10
|
||||
}
|
||||
fn default_model_path() -> String {
|
||||
"models/all-MiniLM-L6-v2".to_string()
|
||||
}
|
||||
fn default_embedding_dim() -> usize {
|
||||
384
|
||||
}
|
||||
fn default_vector_weight() -> f32 {
|
||||
0.6
|
||||
}
|
||||
fn default_text_weight() -> f32 {
|
||||
0.4
|
||||
}
|
||||
fn default_dedup_threshold() -> f32 {
|
||||
0.90
|
||||
}
|
||||
fn default_cleanup_interval_seconds() -> u64 {
|
||||
300
|
||||
}
|
||||
fn default_auth_enabled() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Load configuration from environment variables
|
||||
|
||||
103
src/db.rs
103
src/db.rs
@@ -9,9 +9,9 @@ use tokio_postgres::NoTls;
|
||||
use tracing::info;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::config::DatabaseConfig;
|
||||
use serde::Serialize;
|
||||
use serde_json::{Map, Value};
|
||||
use crate::config::DatabaseConfig;
|
||||
|
||||
/// Database wrapper with connection pool
|
||||
#[derive(Clone)]
|
||||
@@ -75,6 +75,7 @@ fn merge_metadata(existing: &Value, incoming: &Value) -> Value {
|
||||
|
||||
async fn find_dedup_match<C>(
|
||||
client: &C,
|
||||
auth_scope: &str,
|
||||
agent_id: &str,
|
||||
embedding: &Vector,
|
||||
threshold: f64,
|
||||
@@ -87,13 +88,14 @@ where
|
||||
r#"
|
||||
SELECT id, metadata, expires_at
|
||||
FROM memories
|
||||
WHERE agent_id = $1
|
||||
WHERE auth_scope = $1
|
||||
AND agent_id = $2
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
AND (1 - (embedding <=> $2)) >= $3
|
||||
ORDER BY (1 - (embedding <=> $2)) DESC, created_at DESC
|
||||
AND (1 - (embedding <=> $3)) >= $4
|
||||
ORDER BY (1 - (embedding <=> $3)) DESC, created_at DESC
|
||||
LIMIT 1
|
||||
"#,
|
||||
&[&agent_id, embedding, &threshold],
|
||||
&[&auth_scope, &agent_id, embedding, &threshold],
|
||||
)
|
||||
.await
|
||||
.context("Failed to check for duplicate memory")?;
|
||||
@@ -120,13 +122,19 @@ impl Database {
|
||||
.context("Failed to create database pool")?;
|
||||
|
||||
// Test connection
|
||||
let client = pool.get().await.context("Failed to get database connection")?;
|
||||
let client = pool
|
||||
.get()
|
||||
.await
|
||||
.context("Failed to get database connection")?;
|
||||
client
|
||||
.simple_query("SELECT 1")
|
||||
.await
|
||||
.context("Failed to execute test query")?;
|
||||
|
||||
info!("Database connection pool created with {} connections", config.pool_size);
|
||||
info!(
|
||||
"Database connection pool created with {} connections",
|
||||
config.pool_size
|
||||
);
|
||||
|
||||
Ok(Self { pool })
|
||||
}
|
||||
@@ -134,6 +142,7 @@ impl Database {
|
||||
/// Store a memory record
|
||||
pub async fn store_memory(
|
||||
&self,
|
||||
auth_scope: &str,
|
||||
agent_id: &str,
|
||||
content: &str,
|
||||
embedding: &[f32],
|
||||
@@ -146,7 +155,9 @@ impl Database {
|
||||
let vector = Vector::from(embedding.to_vec());
|
||||
let dedup_threshold = dedup_threshold as f64;
|
||||
|
||||
if let Some(existing) = find_dedup_match(&client, agent_id, &vector, dedup_threshold).await? {
|
||||
if let Some(existing) =
|
||||
find_dedup_match(&client, auth_scope, agent_id, &vector, dedup_threshold).await?
|
||||
{
|
||||
let merged_metadata = merge_metadata(&existing.metadata, &metadata);
|
||||
let refreshed_expires_at = expires_at.or(existing.expires_at);
|
||||
|
||||
@@ -176,10 +187,10 @@ impl Database {
|
||||
client
|
||||
.execute(
|
||||
r#"
|
||||
INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
INSERT INTO memories (id, auth_scope, agent_id, content, embedding, keywords, metadata, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
"#,
|
||||
&[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
|
||||
&[&id, &auth_scope, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
|
||||
)
|
||||
.await
|
||||
.context("Failed to store memory")?;
|
||||
@@ -194,7 +205,8 @@ impl Database {
|
||||
/// Query memories by vector similarity
|
||||
pub async fn query_memories(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
auth_scope: &str,
|
||||
source_agent_id: Option<&str>,
|
||||
query_text: &str,
|
||||
embedding: &[f32],
|
||||
limit: i64,
|
||||
@@ -230,7 +242,8 @@ impl Database {
|
||||
END AS text_score
|
||||
FROM memories
|
||||
CROSS JOIN search_query
|
||||
WHERE memories.agent_id = $3
|
||||
WHERE memories.auth_scope = $3
|
||||
AND ($4::text IS NULL OR memories.agent_id = $4)
|
||||
AND (memories.expires_at IS NULL OR memories.expires_at > NOW())
|
||||
),
|
||||
ranked AS (
|
||||
@@ -251,18 +264,19 @@ impl Database {
|
||||
text_score,
|
||||
CASE
|
||||
WHEN has_text_match = 1
|
||||
THEN (($5 * vector_score) + ($6 * text_score))::real
|
||||
THEN (($6 * vector_score) + ($7 * text_score))::real
|
||||
ELSE vector_score
|
||||
END AS hybrid_score
|
||||
FROM ranked
|
||||
WHERE vector_score >= $4 OR text_score > 0
|
||||
WHERE vector_score >= $5 OR text_score > 0
|
||||
ORDER BY hybrid_score DESC, vector_score DESC
|
||||
LIMIT $7
|
||||
LIMIT $8
|
||||
"#,
|
||||
&[
|
||||
&vector,
|
||||
&query_text,
|
||||
&agent_id,
|
||||
&auth_scope,
|
||||
&source_agent_id,
|
||||
&threshold,
|
||||
&vector_weight,
|
||||
&text_weight,
|
||||
@@ -296,37 +310,47 @@ impl Database {
|
||||
Ok(matches)
|
||||
}
|
||||
|
||||
/// Delete memories by agent_id and optional filters
|
||||
/// Delete memories visible to an auth scope with an optional provenance filter
|
||||
pub async fn purge_memories(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
auth_scope: &str,
|
||||
source_agent_id: Option<&str>,
|
||||
before: Option<chrono::DateTime<chrono::Utc>>,
|
||||
) -> Result<u64> {
|
||||
let client = self.pool.get().await?;
|
||||
|
||||
let count = if let Some(before_ts) = before {
|
||||
client
|
||||
.execute(
|
||||
"DELETE FROM memories WHERE agent_id = $1 AND created_at < $2",
|
||||
&[&agent_id, &before_ts],
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
client
|
||||
.execute("DELETE FROM memories WHERE agent_id = $1", &[&agent_id])
|
||||
.await?
|
||||
};
|
||||
let count = client
|
||||
.execute(
|
||||
r#"
|
||||
DELETE FROM memories
|
||||
WHERE auth_scope = $1
|
||||
AND ($2::text IS NULL OR agent_id = $2)
|
||||
AND ($3::timestamptz IS NULL OR created_at < $3)
|
||||
"#,
|
||||
&[&auth_scope, &source_agent_id, &before],
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
/// Get memory count for an agent
|
||||
pub async fn count_memories(&self, agent_id: &str) -> Result<i64> {
|
||||
/// Get memory count for a token-visible scope and optional provenance filter
|
||||
pub async fn count_memories(
|
||||
&self,
|
||||
auth_scope: &str,
|
||||
source_agent_id: Option<&str>,
|
||||
) -> Result<i64> {
|
||||
let client = self.pool.get().await?;
|
||||
let row = client
|
||||
.query_one(
|
||||
"SELECT COUNT(*) as count FROM memories WHERE agent_id = $1 AND (expires_at IS NULL OR expires_at > NOW())",
|
||||
&[&agent_id],
|
||||
r#"
|
||||
SELECT COUNT(*) as count
|
||||
FROM memories
|
||||
WHERE auth_scope = $1
|
||||
AND ($2::text IS NULL OR agent_id = $2)
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
"#,
|
||||
&[&auth_scope, &source_agent_id],
|
||||
)
|
||||
.await?;
|
||||
Ok(row.get("count"))
|
||||
@@ -346,7 +370,6 @@ impl Database {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Result for a single batch entry
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct BatchStoreResult {
|
||||
@@ -360,6 +383,7 @@ impl Database {
|
||||
/// Store multiple memories in a single transaction
|
||||
pub async fn batch_store_memories(
|
||||
&self,
|
||||
auth_scope: &str,
|
||||
agent_id: &str,
|
||||
entries: Vec<(
|
||||
String,
|
||||
@@ -378,7 +402,8 @@ impl Database {
|
||||
for (content, metadata, embedding, keywords, expires_at) in entries {
|
||||
let vector = Vector::from(embedding);
|
||||
if let Some(existing) =
|
||||
find_dedup_match(&transaction, agent_id, &vector, dedup_threshold).await?
|
||||
find_dedup_match(&transaction, auth_scope, agent_id, &vector, dedup_threshold)
|
||||
.await?
|
||||
{
|
||||
let merged_metadata = merge_metadata(&existing.metadata, &metadata);
|
||||
let refreshed_expires_at = expires_at.or(existing.expires_at);
|
||||
@@ -404,8 +429,8 @@ impl Database {
|
||||
} else {
|
||||
let id = Uuid::new_v4();
|
||||
transaction.execute(
|
||||
r#"INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7)"#,
|
||||
&[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
|
||||
r#"INSERT INTO memories (id, auth_scope, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#,
|
||||
&[&id, &auth_scope, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
|
||||
).await?;
|
||||
results.push(BatchStoreResult {
|
||||
id: id.to_string(),
|
||||
|
||||
130
src/embedding.rs
130
src/embedding.rs
@@ -1,7 +1,7 @@
|
||||
//! Embedding engine using local ONNX models
|
||||
|
||||
use anyhow::Result;
|
||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||
use ort::session::{builder::GraphOptimizationLevel, Session};
|
||||
use ort::value::Value;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Once;
|
||||
@@ -15,9 +15,9 @@ static ORT_INIT: Once = Once::new();
|
||||
/// Initialize ONNX Runtime synchronously (called inside spawn_blocking)
|
||||
fn init_ort_sync(dylib_path: &str) -> Result<()> {
|
||||
info!("Initializing ONNX Runtime from: {}", dylib_path);
|
||||
|
||||
|
||||
let mut init_error: Option<String> = None;
|
||||
|
||||
|
||||
ORT_INIT.call_once(|| {
|
||||
info!("ORT_INIT.call_once - starting initialization");
|
||||
match ort::init_from(dylib_path) {
|
||||
@@ -43,7 +43,7 @@ fn init_ort_sync(dylib_path: &str) -> Result<()> {
|
||||
if let Some(err) = init_error {
|
||||
return Err(anyhow::anyhow!("{}", err));
|
||||
}
|
||||
|
||||
|
||||
info!("ONNX Runtime initialization complete");
|
||||
Ok(())
|
||||
}
|
||||
@@ -91,35 +91,40 @@ impl EmbeddingEngine {
|
||||
|
||||
let model_path = PathBuf::from(&config.model_path);
|
||||
let dimension = config.dimension;
|
||||
|
||||
info!("Loading ONNX model from {:?}", model_path.join("model.onnx"));
|
||||
|
||||
|
||||
info!(
|
||||
"Loading ONNX model from {:?}",
|
||||
model_path.join("model.onnx")
|
||||
);
|
||||
|
||||
// Use spawn_blocking to avoid blocking the async runtime
|
||||
let (session, tokenizer) = tokio::task::spawn_blocking(move || -> Result<(Session, Tokenizer)> {
|
||||
// Initialize ONNX Runtime first
|
||||
init_ort_sync(&dylib_path)?;
|
||||
|
||||
info!("Creating ONNX session...");
|
||||
|
||||
// Load ONNX model with ort 2.0 API
|
||||
let session = Session::builder()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to create session builder: {:?}", e))?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to set optimization level: {:?}", e))?
|
||||
.with_intra_threads(4)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to set intra threads: {:?}", e))?
|
||||
.commit_from_file(model_path.join("model.onnx"))
|
||||
.map_err(|e| anyhow::anyhow!("Failed to load ONNX model: {:?}", e))?;
|
||||
let (session, tokenizer) =
|
||||
tokio::task::spawn_blocking(move || -> Result<(Session, Tokenizer)> {
|
||||
// Initialize ONNX Runtime first
|
||||
init_ort_sync(&dylib_path)?;
|
||||
|
||||
info!("ONNX model loaded, loading tokenizer...");
|
||||
|
||||
// Load tokenizer
|
||||
let tokenizer = Tokenizer::from_file(model_path.join("tokenizer.json"))
|
||||
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
|
||||
info!("Creating ONNX session...");
|
||||
|
||||
info!("Tokenizer loaded successfully");
|
||||
Ok((session, tokenizer))
|
||||
}).await
|
||||
// Load ONNX model with ort 2.0 API
|
||||
let session = Session::builder()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to create session builder: {:?}", e))?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to set optimization level: {:?}", e))?
|
||||
.with_intra_threads(4)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to set intra threads: {:?}", e))?
|
||||
.commit_from_file(model_path.join("model.onnx"))
|
||||
.map_err(|e| anyhow::anyhow!("Failed to load ONNX model: {:?}", e))?;
|
||||
|
||||
info!("ONNX model loaded, loading tokenizer...");
|
||||
|
||||
// Load tokenizer
|
||||
let tokenizer = Tokenizer::from_file(model_path.join("tokenizer.json"))
|
||||
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
|
||||
|
||||
info!("Tokenizer loaded successfully");
|
||||
Ok((session, tokenizer))
|
||||
})
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Spawn blocking failed: {:?}", e))??;
|
||||
|
||||
info!(
|
||||
@@ -136,12 +141,17 @@ impl EmbeddingEngine {
|
||||
|
||||
/// Generate embedding for a single text
|
||||
pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
|
||||
let encoding = self.tokenizer
|
||||
let encoding = self
|
||||
.tokenizer
|
||||
.encode(text, true)
|
||||
.map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
|
||||
|
||||
let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
|
||||
let attention_mask: Vec<i64> = encoding.get_attention_mask().iter().map(|&x| x as i64).collect();
|
||||
let attention_mask: Vec<i64> = encoding
|
||||
.get_attention_mask()
|
||||
.iter()
|
||||
.map(|&x| x as i64)
|
||||
.collect();
|
||||
let token_type_ids: Vec<i64> = encoding.get_type_ids().iter().map(|&x| x as i64).collect();
|
||||
|
||||
let seq_len = input_ids.len();
|
||||
@@ -157,22 +167,25 @@ impl EmbeddingEngine {
|
||||
"attention_mask" => attention_mask_tensor,
|
||||
"token_type_ids" => token_type_ids_tensor,
|
||||
];
|
||||
|
||||
let mut session_guard = self.session.lock()
|
||||
|
||||
let mut session_guard = self
|
||||
.session
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Session lock poisoned: {}", e))?;
|
||||
let outputs = session_guard.run(inputs)?;
|
||||
|
||||
// Extract output
|
||||
let output = outputs.get("last_hidden_state")
|
||||
let output = outputs
|
||||
.get("last_hidden_state")
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing last_hidden_state output"))?;
|
||||
|
||||
|
||||
// Get the tensor data
|
||||
let (shape, data) = output.try_extract_tensor::<f32>()?;
|
||||
|
||||
|
||||
// Mean pooling over sequence dimension
|
||||
let hidden_size = *shape.last().unwrap_or(&384) as usize;
|
||||
let seq_len = data.len() / hidden_size;
|
||||
|
||||
|
||||
let mut embedding = vec![0.0f32; hidden_size];
|
||||
for i in 0..seq_len {
|
||||
for j in 0..hidden_size {
|
||||
@@ -182,7 +195,7 @@ impl EmbeddingEngine {
|
||||
for val in &mut embedding {
|
||||
*val /= seq_len as f32;
|
||||
}
|
||||
|
||||
|
||||
// L2 normalize
|
||||
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 0.0 {
|
||||
@@ -190,7 +203,7 @@ impl EmbeddingEngine {
|
||||
*val /= norm;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
@@ -208,37 +221,40 @@ impl EmbeddingEngine {
|
||||
/// Extract keywords from text using simple frequency analysis
|
||||
pub fn extract_keywords(text: &str, limit: usize) -> Vec<String> {
|
||||
use std::collections::HashMap;
|
||||
|
||||
|
||||
let stop_words: std::collections::HashSet<&str> = [
|
||||
"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
|
||||
"of", "with", "by", "from", "is", "are", "was", "were", "be", "been",
|
||||
"being", "have", "has", "had", "do", "does", "did", "will", "would",
|
||||
"could", "should", "may", "might", "must", "shall", "can", "this",
|
||||
"that", "these", "those", "i", "you", "he", "she", "it", "we", "they",
|
||||
"what", "which", "who", "whom", "whose", "where", "when", "why", "how",
|
||||
"all", "each", "every", "both", "few", "more", "most", "other", "some",
|
||||
"such", "no", "nor", "not", "only", "own", "same", "so", "than", "too",
|
||||
"very", "just", "also", "now", "here", "there", "then", "once", "if",
|
||||
].iter().cloned().collect();
|
||||
|
||||
"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by",
|
||||
"from", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "do",
|
||||
"does", "did", "will", "would", "could", "should", "may", "might", "must", "shall", "can",
|
||||
"this", "that", "these", "those", "i", "you", "he", "she", "it", "we", "they", "what",
|
||||
"which", "who", "whom", "whose", "where", "when", "why", "how", "all", "each", "every",
|
||||
"both", "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own",
|
||||
"same", "so", "than", "too", "very", "just", "also", "now", "here", "there", "then",
|
||||
"once", "if",
|
||||
]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let mut word_counts: HashMap<String, usize> = HashMap::new();
|
||||
|
||||
|
||||
for word in text.split_whitespace() {
|
||||
let clean: String = word
|
||||
.chars()
|
||||
.filter(|c| c.is_alphanumeric())
|
||||
.collect::<String>()
|
||||
.to_lowercase();
|
||||
|
||||
|
||||
if clean.len() > 2 && !stop_words.contains(clean.as_str()) {
|
||||
*word_counts.entry(clean).or_insert(0) += 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let mut sorted: Vec<_> = word_counts.into_iter().collect();
|
||||
sorted.sort_by(|a, b| b.1.cmp(&a.1));
|
||||
|
||||
sorted.into_iter()
|
||||
|
||||
sorted
|
||||
.into_iter()
|
||||
.take(limit)
|
||||
.map(|(word, _)| word)
|
||||
.collect()
|
||||
|
||||
34
src/lib.rs
34
src/lib.rs
@@ -5,18 +5,18 @@ pub mod config;
|
||||
pub mod db;
|
||||
pub mod embedding;
|
||||
pub mod migrations;
|
||||
pub mod ttl;
|
||||
pub mod tools;
|
||||
pub mod transport;
|
||||
pub mod ttl;
|
||||
|
||||
use anyhow::Result;
|
||||
use axum::{Router, Json, http::StatusCode, middleware};
|
||||
use axum::{http::StatusCode, middleware, Json, Router};
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tracing::{info, error};
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::auth::auth_middleware;
|
||||
use crate::config::Config;
|
||||
@@ -60,15 +60,15 @@ async fn readiness_handler(
|
||||
match readiness {
|
||||
ReadinessState::Ready => (
|
||||
StatusCode::OK,
|
||||
Json(json!({"status": "ready", "embedding": true}))
|
||||
Json(json!({"status": "ready", "embedding": true})),
|
||||
),
|
||||
ReadinessState::Initializing => (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({"status": "initializing", "embedding": false}))
|
||||
Json(json!({"status": "initializing", "embedding": false})),
|
||||
),
|
||||
ReadinessState::Failed(err) => (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(json!({"status": "failed", "error": err}))
|
||||
Json(json!({"status": "failed", "error": err})),
|
||||
),
|
||||
}
|
||||
}
|
||||
@@ -89,11 +89,14 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> {
|
||||
tokio::spawn(async move {
|
||||
let max_retries = 3;
|
||||
let mut attempt = 0;
|
||||
|
||||
|
||||
loop {
|
||||
attempt += 1;
|
||||
info!("Initializing embedding engine (attempt {}/{})", attempt, max_retries);
|
||||
|
||||
info!(
|
||||
"Initializing embedding engine (attempt {}/{})",
|
||||
attempt, max_retries
|
||||
);
|
||||
|
||||
match EmbeddingEngine::new(&embedding_config).await {
|
||||
Ok(engine) => {
|
||||
let engine = Arc::new(engine);
|
||||
@@ -120,9 +123,8 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> {
|
||||
let cleanup_state = state.clone();
|
||||
let cleanup_interval_seconds = config.ttl.cleanup_interval_seconds;
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(
|
||||
cleanup_interval_seconds,
|
||||
));
|
||||
let mut interval =
|
||||
tokio::time::interval(tokio::time::Duration::from_secs(cleanup_interval_seconds));
|
||||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
|
||||
loop {
|
||||
@@ -148,10 +150,12 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> {
|
||||
.route("/health", axum::routing::get(health_handler))
|
||||
.route("/ready", axum::routing::get(readiness_handler))
|
||||
.with_state(state.clone());
|
||||
|
||||
|
||||
// Build MCP router with auth middleware
|
||||
let mcp_router = transport::mcp_router(mcp_state)
|
||||
.layer(middleware::from_fn_with_state(state.clone(), auth_middleware));
|
||||
let mcp_router = transport::mcp_router(mcp_state).layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
auth_middleware,
|
||||
));
|
||||
|
||||
let app = Router::new()
|
||||
.merge(health_router)
|
||||
|
||||
@@ -17,7 +17,10 @@ async fn main() -> Result<()> {
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
|
||||
info!("Starting OpenBrain MCP Server v{}", env!("CARGO_PKG_VERSION"));
|
||||
info!(
|
||||
"Starting OpenBrain MCP Server v{}",
|
||||
env!("CARGO_PKG_VERSION")
|
||||
);
|
||||
|
||||
// Load configuration
|
||||
let config = Config::load()?;
|
||||
|
||||
@@ -3,12 +3,14 @@
|
||||
//! Accepts 1-50 entries per call, generates embeddings for each,
|
||||
//! stores all in a single DB transaction, returns individual IDs/status.
|
||||
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tracing::info;
|
||||
|
||||
use crate::auth::PUBLIC_AUTH_SCOPE;
|
||||
use crate::embedding::extract_keywords;
|
||||
use crate::tools::INTERNAL_AUTH_SCOPE_ARG;
|
||||
use crate::ttl::expires_at_from_ttl;
|
||||
use crate::AppState;
|
||||
|
||||
@@ -18,7 +20,7 @@ const MAX_BATCH_SIZE: usize = 50;
|
||||
/// Execute the batch_store tool
|
||||
///
|
||||
/// Accepts:
|
||||
/// - `agent_id`: Optional agent identifier (defaults to "default")
|
||||
/// - `agent_id`: Optional source agent label (defaults to "default")
|
||||
/// - `ttl`: Optional default TTL string applied to entries without their own ttl
|
||||
/// - `entries`: Array of 1-50 entries, each with `content` (required) and `metadata` (optional)
|
||||
///
|
||||
@@ -38,6 +40,10 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
.get("agent_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("default");
|
||||
let auth_scope = arguments
|
||||
.get(INTERNAL_AUTH_SCOPE_ARG)
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(PUBLIC_AUTH_SCOPE);
|
||||
|
||||
let entries = arguments
|
||||
.get("entries")
|
||||
@@ -47,7 +53,9 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
|
||||
// 3. Validate batch size
|
||||
if entries.is_empty() {
|
||||
return Err(anyhow!("Empty entries array not allowed - must provide 1-50 entries"));
|
||||
return Err(anyhow!(
|
||||
"Empty entries array not allowed - must provide 1-50 entries"
|
||||
));
|
||||
}
|
||||
if entries.len() > MAX_BATCH_SIZE {
|
||||
return Err(anyhow!(
|
||||
@@ -69,7 +77,10 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
let content = entry
|
||||
.get("content")
|
||||
.and_then(|v| v.as_str())
|
||||
.context(format!("Entry at index {} missing required field: content", idx))?;
|
||||
.context(format!(
|
||||
"Entry at index {} missing required field: content",
|
||||
idx
|
||||
))?;
|
||||
|
||||
if content.is_empty() {
|
||||
return Err(anyhow!(
|
||||
@@ -82,10 +93,7 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
.get("metadata")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::json!({}));
|
||||
let ttl = entry
|
||||
.get("ttl")
|
||||
.and_then(|v| v.as_str())
|
||||
.or(default_ttl);
|
||||
let ttl = entry.get("ttl").and_then(|v| v.as_str()).or(default_ttl);
|
||||
let expires_at = expires_at_from_ttl(ttl)
|
||||
.with_context(|| format!("Invalid ttl for entry at index {}", idx))?;
|
||||
|
||||
@@ -97,13 +105,24 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
// Extract keywords
|
||||
let keywords = extract_keywords(content, 10);
|
||||
|
||||
processed_entries.push((content.to_string(), metadata, embedding, keywords, expires_at));
|
||||
processed_entries.push((
|
||||
content.to_string(),
|
||||
metadata,
|
||||
embedding,
|
||||
keywords,
|
||||
expires_at,
|
||||
));
|
||||
}
|
||||
|
||||
// 5. Batch DB insert (single transaction for atomicity)
|
||||
let results = state
|
||||
.db
|
||||
.batch_store_memories(agent_id, processed_entries, state.config.dedup.threshold)
|
||||
.batch_store_memories(
|
||||
auth_scope,
|
||||
agent_id,
|
||||
processed_entries,
|
||||
state.config.dedup.threshold,
|
||||
)
|
||||
.await
|
||||
.context("Failed to batch store memories")?;
|
||||
|
||||
@@ -114,5 +133,6 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
"success": true,
|
||||
"results": results,
|
||||
"count": count
|
||||
}).to_string())
|
||||
})
|
||||
.to_string())
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
//! MCP Tools for OpenBrain
|
||||
|
||||
pub mod batch_store;
|
||||
pub mod purge;
|
||||
pub mod query;
|
||||
pub mod store;
|
||||
pub mod purge;
|
||||
|
||||
use anyhow::Result;
|
||||
use serde_json::{json, Value};
|
||||
@@ -11,11 +11,13 @@ use std::sync::Arc;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
pub const INTERNAL_AUTH_SCOPE_ARG: &str = "_auth_scope";
|
||||
|
||||
pub fn get_tool_definitions() -> Vec<Value> {
|
||||
vec![
|
||||
json!({
|
||||
"name": "store",
|
||||
"description": "Store a memory with automatic embedding generation and keyword extraction. Near-duplicate memories for the same agent are deduplicated automatically by similarity, with metadata merged and timestamps refreshed.",
|
||||
"description": "Store a memory with automatic embedding generation and keyword extraction. Near-duplicate memories for the same API-token scope and same source agent are deduplicated automatically by similarity, with metadata merged and timestamps refreshed.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -25,7 +27,7 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
||||
},
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier for the agent storing the memory (default: 'default')"
|
||||
"description": "Optional source agent label recorded with the memory. If omitted, the server may fall back to X-Agent-ID or 'default'."
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
@@ -47,7 +49,7 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier for the agent storing the memories (default: 'default')"
|
||||
"description": "Optional source agent label recorded with each stored memory. If omitted, the server may fall back to X-Agent-ID or 'default'."
|
||||
},
|
||||
"ttl": {
|
||||
"type": "string",
|
||||
@@ -89,9 +91,14 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
||||
"type": "string",
|
||||
"description": "The search query text"
|
||||
},
|
||||
"source_agent_id": {
|
||||
"type": "string",
|
||||
"description": "Optional provenance filter that only returns memories stored by the specified agent label"
|
||||
},
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": "Agent ID to search within (default: 'default')"
|
||||
"description": "Deprecated legacy alias. Query visibility is scoped by API token, not by agent_id.",
|
||||
"deprecated": true
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
@@ -107,13 +114,18 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
||||
}),
|
||||
json!({
|
||||
"name": "purge",
|
||||
"description": "Delete memories for an agent. Can delete all memories or those before a specific timestamp.",
|
||||
"description": "Delete memories visible to the current API token. Can optionally filter by source agent label or by time range.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source_agent_id": {
|
||||
"type": "string",
|
||||
"description": "Optional provenance filter that only deletes memories stored by the specified agent label"
|
||||
},
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": "Agent ID whose memories to delete (required)"
|
||||
"description": "Deprecated legacy alias for source_agent_id",
|
||||
"deprecated": true
|
||||
},
|
||||
"before": {
|
||||
"type": "string",
|
||||
@@ -124,9 +136,9 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
||||
"description": "Must be true to confirm deletion"
|
||||
}
|
||||
},
|
||||
"required": ["agent_id", "confirm"]
|
||||
"required": ["confirm"]
|
||||
}
|
||||
})
|
||||
}),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//! Purge Tool - Delete memories by agent_id or time range
|
||||
//! Purge Tool - Delete memories visible to the current token with optional filters
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use chrono::DateTime;
|
||||
@@ -6,15 +6,21 @@ use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::auth::PUBLIC_AUTH_SCOPE;
|
||||
use crate::tools::INTERNAL_AUTH_SCOPE_ARG;
|
||||
use crate::AppState;
|
||||
|
||||
/// Execute the purge tool
|
||||
pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String> {
|
||||
// Extract parameters
|
||||
let agent_id = arguments
|
||||
.get("agent_id")
|
||||
let source_agent_id = arguments
|
||||
.get("source_agent_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.context("Missing required parameter: agent_id")?;
|
||||
.or_else(|| arguments.get("agent_id").and_then(|v| v.as_str()));
|
||||
let auth_scope = arguments
|
||||
.get(INTERNAL_AUTH_SCOPE_ARG)
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(PUBLIC_AUTH_SCOPE);
|
||||
|
||||
let confirm = arguments
|
||||
.get("confirm")
|
||||
@@ -36,15 +42,18 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
// Get current count before purge
|
||||
let count_before = state
|
||||
.db
|
||||
.count_memories(agent_id)
|
||||
.count_memories(auth_scope, source_agent_id)
|
||||
.await
|
||||
.context("Failed to count memories")?;
|
||||
|
||||
if count_before == 0 {
|
||||
info!("No memories found for agent '{}'", agent_id);
|
||||
info!(
|
||||
"No memories found to purge for auth scope '{}' with source_agent_id={:?}",
|
||||
auth_scope, source_agent_id
|
||||
);
|
||||
return Ok(serde_json::json!({
|
||||
"success": true,
|
||||
"agent_id": agent_id,
|
||||
"source_agent_id_filter": source_agent_id,
|
||||
"deleted": 0,
|
||||
"message": "No memories found to purge"
|
||||
})
|
||||
@@ -52,25 +61,25 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
}
|
||||
|
||||
warn!(
|
||||
"Purging memories for agent '{}' (before={:?})",
|
||||
agent_id, before
|
||||
"Purging memories for auth scope '{}' with source_agent_id={:?} (before={:?})",
|
||||
auth_scope, source_agent_id, before
|
||||
);
|
||||
|
||||
// Execute purge
|
||||
let deleted = state
|
||||
.db
|
||||
.purge_memories(agent_id, before)
|
||||
.purge_memories(auth_scope, source_agent_id, before)
|
||||
.await
|
||||
.context("Failed to purge memories")?;
|
||||
|
||||
info!(
|
||||
"Purged {} memories for agent '{}'",
|
||||
deleted, agent_id
|
||||
"Purged {} memories for auth scope '{}' with source_agent_id={:?}",
|
||||
deleted, auth_scope, source_agent_id
|
||||
);
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"success": true,
|
||||
"agent_id": agent_id,
|
||||
"source_agent_id_filter": source_agent_id,
|
||||
"deleted": deleted,
|
||||
"had_before_filter": before.is_some(),
|
||||
"message": format!("Successfully purged {} memories", deleted)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
//! Query Tool - Search memories by semantic similarity
|
||||
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tracing::info;
|
||||
|
||||
use crate::auth::PUBLIC_AUTH_SCOPE;
|
||||
use crate::tools::INTERNAL_AUTH_SCOPE_ARG;
|
||||
use crate::AppState;
|
||||
|
||||
/// Execute the query tool
|
||||
@@ -21,10 +23,14 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
.and_then(|v| v.as_str())
|
||||
.context("Missing required parameter: query")?;
|
||||
|
||||
let agent_id = arguments
|
||||
.get("agent_id")
|
||||
let source_agent_id = arguments
|
||||
.get("source_agent_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("default");
|
||||
.filter(|value| !value.is_empty());
|
||||
let auth_scope = arguments
|
||||
.get(INTERNAL_AUTH_SCOPE_ARG)
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(PUBLIC_AUTH_SCOPE);
|
||||
|
||||
let limit = arguments
|
||||
.get("limit")
|
||||
@@ -42,8 +48,8 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
);
|
||||
|
||||
info!(
|
||||
"Querying memories for agent '{}': '{}' (limit={}, threshold={}, vector_weight={}, text_weight={})",
|
||||
agent_id, query_text, limit, threshold, vector_weight, text_weight
|
||||
"Querying memories for auth scope '{}' with source_agent_id={:?}: '{}' (limit={}, threshold={}, vector_weight={}, text_weight={})",
|
||||
auth_scope, source_agent_id, query_text, limit, threshold, vector_weight, text_weight
|
||||
);
|
||||
|
||||
// Generate embedding for query using Arc<EmbeddingEngine>
|
||||
@@ -55,7 +61,8 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
let matches = state
|
||||
.db
|
||||
.query_memories(
|
||||
agent_id,
|
||||
auth_scope,
|
||||
source_agent_id,
|
||||
query_text,
|
||||
&query_embedding,
|
||||
limit,
|
||||
@@ -79,6 +86,7 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
"vector_score": m.vector_score,
|
||||
"text_score": m.text_score,
|
||||
"hybrid_score": m.hybrid_score,
|
||||
"agent_id": m.record.agent_id,
|
||||
"keywords": m.record.keywords,
|
||||
"metadata": m.record.metadata,
|
||||
"created_at": m.record.created_at.to_rfc3339(),
|
||||
@@ -89,8 +97,8 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"success": true,
|
||||
"agent_id": agent_id,
|
||||
"query": query_text,
|
||||
"source_agent_id_filter": source_agent_id,
|
||||
"vector_weight": vector_weight,
|
||||
"text_weight": text_weight,
|
||||
"count": results.len(),
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
//! Store Tool - Store memories with automatic embeddings
|
||||
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tracing::info;
|
||||
|
||||
use crate::auth::PUBLIC_AUTH_SCOPE;
|
||||
use crate::embedding::extract_keywords;
|
||||
use crate::tools::INTERNAL_AUTH_SCOPE_ARG;
|
||||
use crate::ttl::expires_at_from_ttl;
|
||||
use crate::AppState;
|
||||
|
||||
@@ -27,6 +29,10 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
.get("agent_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("default");
|
||||
let auth_scope = arguments
|
||||
.get(INTERNAL_AUTH_SCOPE_ARG)
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(PUBLIC_AUTH_SCOPE);
|
||||
|
||||
let metadata = arguments
|
||||
.get("metadata")
|
||||
@@ -54,6 +60,7 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
let id = state
|
||||
.db
|
||||
.store_memory(
|
||||
auth_scope,
|
||||
agent_id,
|
||||
content,
|
||||
&embedding,
|
||||
@@ -67,7 +74,11 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
|
||||
info!(
|
||||
"Memory {} with ID: {}",
|
||||
if id.deduplicated { "deduplicated" } else { "stored" },
|
||||
if id.deduplicated {
|
||||
"deduplicated"
|
||||
} else {
|
||||
"stored"
|
||||
},
|
||||
id.id
|
||||
);
|
||||
|
||||
|
||||
175
src/transport.rs
175
src/transport.rs
@@ -5,27 +5,25 @@
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::{HeaderMap, StatusCode, Uri, header::{HOST, ORIGIN}},
|
||||
http::{
|
||||
header::{HOST, ORIGIN},
|
||||
HeaderMap, StatusCode, Uri,
|
||||
},
|
||||
response::{
|
||||
IntoResponse, Response,
|
||||
sse::{Event, KeepAlive, Sse},
|
||||
IntoResponse, Response,
|
||||
},
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use futures::stream::Stream;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
convert::Infallible,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::sync::{RwLock, broadcast, mpsc};
|
||||
use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
|
||||
use tokio::sync::{broadcast, mpsc, RwLock};
|
||||
use tracing::{error, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{AppState, auth, tools};
|
||||
use crate::{auth, tools, AppState};
|
||||
|
||||
type SessionStore = RwLock<HashMap<String, mpsc::Sender<serde_json::Value>>>;
|
||||
|
||||
@@ -46,11 +44,7 @@ impl McpState {
|
||||
})
|
||||
}
|
||||
|
||||
async fn insert_session(
|
||||
&self,
|
||||
session_id: String,
|
||||
tx: mpsc::Sender<serde_json::Value>,
|
||||
) {
|
||||
async fn insert_session(&self, session_id: String, tx: mpsc::Sender<serde_json::Value>) {
|
||||
self.sessions.write().await.insert(session_id, tx);
|
||||
}
|
||||
|
||||
@@ -163,7 +157,12 @@ struct PostMessageQuery {
|
||||
/// Create the MCP router
|
||||
pub fn mcp_router(state: Arc<McpState>) -> Router {
|
||||
Router::new()
|
||||
.route("/mcp", get(streamable_get_handler).post(streamable_post_handler).delete(streamable_delete_handler))
|
||||
.route(
|
||||
"/mcp",
|
||||
get(streamable_get_handler)
|
||||
.post(streamable_post_handler)
|
||||
.delete(streamable_delete_handler),
|
||||
)
|
||||
.route("/mcp/sse", get(sse_handler))
|
||||
.route("/mcp/message", post(message_handler))
|
||||
.route("/mcp/health", get(health_handler))
|
||||
@@ -186,7 +185,10 @@ fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
|
||||
}
|
||||
|
||||
let origin_uri = origin.parse::<Uri>().map_err(|_| {
|
||||
warn!("Rejected MCP request with invalid origin header: {}", origin);
|
||||
warn!(
|
||||
"Rejected MCP request with invalid origin header: {}",
|
||||
origin
|
||||
);
|
||||
StatusCode::FORBIDDEN
|
||||
})?;
|
||||
|
||||
@@ -203,7 +205,10 @@ fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.ok_or_else(|| {
|
||||
warn!("Rejected MCP request without host header for origin {}", origin);
|
||||
warn!(
|
||||
"Rejected MCP request without host header for origin {}",
|
||||
origin
|
||||
);
|
||||
StatusCode::FORBIDDEN
|
||||
})?;
|
||||
|
||||
@@ -247,12 +252,12 @@ async fn streamable_post_handler(
|
||||
|
||||
info!(
|
||||
method = %request.method,
|
||||
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||
client_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
||||
"Received streamable MCP request"
|
||||
);
|
||||
|
||||
let request = apply_request_context(request, &headers);
|
||||
let request = apply_request_context(request, &headers, state.app.config.auth.enabled);
|
||||
let response = dispatch_request(&state, &request).await;
|
||||
|
||||
match response {
|
||||
@@ -284,8 +289,12 @@ async fn sse_handler(
|
||||
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
|
||||
validate_origin(&headers)?;
|
||||
info!(
|
||||
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
||||
client_id = auth::get_optional_agent_id(&headers)
|
||||
.as_deref()
|
||||
.unwrap_or("unset"),
|
||||
agent_type = auth::get_optional_agent_type(&headers)
|
||||
.as_deref()
|
||||
.unwrap_or("unset"),
|
||||
"Opening legacy SSE MCP stream"
|
||||
);
|
||||
let mut broadcast_rx = state.event_tx.subscribe();
|
||||
@@ -354,7 +363,7 @@ async fn message_handler(
|
||||
|
||||
info!(
|
||||
method = %request.method,
|
||||
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||
client_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
||||
"Received legacy SSE MCP request"
|
||||
);
|
||||
@@ -365,7 +374,7 @@ async fn message_handler(
|
||||
}
|
||||
}
|
||||
|
||||
let request = apply_request_context(request, &headers);
|
||||
let request = apply_request_context(request, &headers, state.app.config.auth.enabled);
|
||||
let response = dispatch_request(&state, &request).await;
|
||||
|
||||
match query.session_id.as_deref() {
|
||||
@@ -408,17 +417,10 @@ async fn route_session_response(
|
||||
fn apply_request_context(
|
||||
mut request: JsonRpcRequest,
|
||||
headers: &HeaderMap,
|
||||
auth_enabled: bool,
|
||||
) -> JsonRpcRequest {
|
||||
if let Some(agent_id) = auth::get_optional_agent_id(headers) {
|
||||
inject_agent_id(&mut request, &agent_id);
|
||||
}
|
||||
|
||||
request
|
||||
}
|
||||
|
||||
fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
|
||||
if request.method != "tools/call" {
|
||||
return;
|
||||
return request;
|
||||
}
|
||||
|
||||
if !request.params.is_object() {
|
||||
@@ -429,6 +431,11 @@ fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
|
||||
.params
|
||||
.as_object_mut()
|
||||
.expect("params should be an object");
|
||||
let tool_name = params
|
||||
.get("name")
|
||||
.and_then(|value| value.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let arguments = params
|
||||
.entry("arguments")
|
||||
.or_insert_with(|| serde_json::json!({}));
|
||||
@@ -437,11 +444,22 @@ fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
|
||||
*arguments = serde_json::json!({});
|
||||
}
|
||||
|
||||
arguments
|
||||
let arguments = arguments
|
||||
.as_object_mut()
|
||||
.expect("arguments should be an object")
|
||||
.entry("agent_id".to_string())
|
||||
.or_insert_with(|| serde_json::json!(agent_id));
|
||||
.expect("arguments should be an object");
|
||||
arguments.insert(
|
||||
tools::INTERNAL_AUTH_SCOPE_ARG.to_string(),
|
||||
serde_json::json!(auth::get_auth_scope(headers, auth_enabled)),
|
||||
);
|
||||
|
||||
if matches!(tool_name.as_str(), "store" | "batch_store") && !arguments.contains_key("agent_id")
|
||||
{
|
||||
if let Some(agent_id) = auth::get_optional_agent_id(headers) {
|
||||
arguments.insert("agent_id".to_string(), serde_json::json!(agent_id));
|
||||
}
|
||||
}
|
||||
|
||||
request
|
||||
}
|
||||
|
||||
async fn dispatch_request(
|
||||
@@ -496,10 +514,7 @@ fn initialize_result() -> serde_json::Value {
|
||||
})
|
||||
}
|
||||
|
||||
fn success_response(
|
||||
id: serde_json::Value,
|
||||
result: serde_json::Value,
|
||||
) -> JsonRpcResponse {
|
||||
fn success_response(id: serde_json::Value, result: serde_json::Value) -> JsonRpcResponse {
|
||||
JsonRpcResponse {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id,
|
||||
@@ -578,10 +593,11 @@ async fn handle_tools_call(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::http::HeaderValue;
|
||||
|
||||
#[test]
|
||||
fn injects_agent_id_when_missing_from_tool_arguments() {
|
||||
let mut request = JsonRpcRequest {
|
||||
fn request_context_injects_auth_scope_for_tool_calls() {
|
||||
let request = JsonRpcRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: Some(serde_json::json!("1")),
|
||||
method: "tools/call".to_string(),
|
||||
@@ -592,8 +608,39 @@ mod tests {
|
||||
}
|
||||
}),
|
||||
};
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("X-API-Key", HeaderValue::from_static("test-token"));
|
||||
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
|
||||
|
||||
inject_agent_id(&mut request, "agent-from-header");
|
||||
let request = apply_request_context(request, &headers, true);
|
||||
|
||||
assert_eq!(
|
||||
request
|
||||
.params
|
||||
.get("arguments")
|
||||
.and_then(|value| value.get(tools::INTERNAL_AUTH_SCOPE_ARG))
|
||||
.and_then(|value| value.as_str()),
|
||||
Some(auth::hash_api_key("test-token").as_str())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_context_does_not_inject_query_filter_from_header() {
|
||||
let request = JsonRpcRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: Some(serde_json::json!("1")),
|
||||
method: "tools/call".to_string(),
|
||||
params: serde_json::json!({
|
||||
"name": "query",
|
||||
"arguments": {
|
||||
"query": "editor preferences"
|
||||
}
|
||||
}),
|
||||
};
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
|
||||
|
||||
let request = apply_request_context(request, &headers, false);
|
||||
|
||||
assert_eq!(
|
||||
request
|
||||
@@ -601,25 +648,55 @@ mod tests {
|
||||
.get("arguments")
|
||||
.and_then(|value| value.get("agent_id"))
|
||||
.and_then(|value| value.as_str()),
|
||||
Some("agent-from-header")
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preserves_explicit_agent_id() {
|
||||
let mut request = JsonRpcRequest {
|
||||
fn request_context_injects_store_agent_id_from_header_when_missing() {
|
||||
let request = JsonRpcRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: Some(serde_json::json!("1")),
|
||||
method: "tools/call".to_string(),
|
||||
params: serde_json::json!({
|
||||
"name": "query",
|
||||
"name": "store",
|
||||
"arguments": {
|
||||
"content": "prefers dark mode"
|
||||
}
|
||||
}),
|
||||
};
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
|
||||
|
||||
let request = apply_request_context(request, &headers, false);
|
||||
|
||||
assert_eq!(
|
||||
request
|
||||
.params
|
||||
.get("arguments")
|
||||
.and_then(|value| value.get("agent_id"))
|
||||
.and_then(|value| value.as_str()),
|
||||
Some("codex-desktop")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_context_preserves_explicit_store_agent_id() {
|
||||
let request = JsonRpcRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: Some(serde_json::json!("1")),
|
||||
method: "tools/call".to_string(),
|
||||
params: serde_json::json!({
|
||||
"name": "store",
|
||||
"arguments": {
|
||||
"agent_id": "explicit-agent"
|
||||
}
|
||||
}),
|
||||
};
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
|
||||
|
||||
inject_agent_id(&mut request, "agent-from-header");
|
||||
let request = apply_request_context(request, &headers, false);
|
||||
|
||||
assert_eq!(
|
||||
request
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::{anyhow, Result};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
|
||||
pub fn parse_ttl_spec(ttl: &str) -> Result<Duration> {
|
||||
@@ -25,7 +25,9 @@ pub fn parse_ttl_spec(ttl: &str) -> Result<Duration> {
|
||||
.parse()
|
||||
.map_err(|_| anyhow!("invalid ttl '{ttl}'. Duration value must be a positive integer"))?;
|
||||
if value <= 0 {
|
||||
return Err(anyhow!("invalid ttl '{ttl}'. Duration value must be greater than zero"));
|
||||
return Err(anyhow!(
|
||||
"invalid ttl '{ttl}'. Duration value must be greater than zero"
|
||||
));
|
||||
}
|
||||
|
||||
let total_seconds = value
|
||||
|
||||
Reference in New Issue
Block a user