mirror of
https://gitea.ingwaz.work/Ingwaz/openbrain-mcp.git
synced 2026-06-15 22:07:08 +00:00
Merge pull request 'Scope memories by API token and add shared-token e2e coverage' (#28) from codex/shared-token-memory into main
Reviewed-on: Ingwaz/openbrain-mcp#28
This commit is contained in:
@@ -7,7 +7,9 @@ memory.
|
|||||||
## External Memory System
|
## External Memory System
|
||||||
|
|
||||||
- Use the exact MCP tools `openbrain.store`, `openbrain.query`, and `openbrain.purge`
|
- Use the exact MCP tools `openbrain.store`, `openbrain.query`, and `openbrain.purge`
|
||||||
- Always use the exact `agent_id` value `openbrain`
|
- Memory visibility is determined by the API token in the MCP client config, not by `agent_id`
|
||||||
|
- On `openbrain.store`, use `agent_id` only as a provenance label for the storing agent when that label is useful
|
||||||
|
- On `openbrain.query`, do not send `agent_id` for normal retrieval; use `source_agent_id` only when you intentionally want to filter by source agent
|
||||||
- Do not hardcode live credentials into the repository
|
- Do not hardcode live credentials into the repository
|
||||||
- Before answering requests that may depend on prior sessions, project history, user preferences, ongoing work, named people, named projects, deployments, debugging history, or handoff context, call `openbrain.query` first
|
- Before answering requests that may depend on prior sessions, project history, user preferences, ongoing work, named people, named projects, deployments, debugging history, or handoff context, call `openbrain.query` first
|
||||||
- Use noun-heavy search phrases with exact names, tool names, acronyms, hostnames, and document names
|
- Use noun-heavy search phrases with exact names, tool names, acronyms, hostnames, and document names
|
||||||
@@ -19,7 +21,7 @@ memory.
|
|||||||
- Use metadata when helpful for tags such as `category`, `project`, `source`, `status`, `aliases`, and `confidence`
|
- Use metadata when helpful for tags such as `category`, `project`, `source`, `status`, `aliases`, and `confidence`
|
||||||
- If `openbrain.query` returns no useful result, state that OpenBrain has no stored context for that topic, answer from general reasoning if possible, and ask one focused follow-up if the missing information is durable and useful
|
- If `openbrain.query` returns no useful result, state that OpenBrain has no stored context for that topic, answer from general reasoning if possible, and ask one focused follow-up if the missing information is durable and useful
|
||||||
- If retrieved memories conflict, ask which fact is current, then store the corrected source-of-truth fact
|
- If retrieved memories conflict, ask which fact is current, then store the corrected source-of-truth fact
|
||||||
- Use `openbrain.purge` cautiously because it is coarse-grained; it deletes by `agent_id` and optionally before a timestamp, not by individual memory ID
|
- Use `openbrain.purge` cautiously because it is coarse-grained; it deletes memories visible to the current API token and can optionally narrow by `source_agent_id` and `before`, but not by individual memory ID
|
||||||
- For ordinary corrections, prefer storing the new source-of-truth fact instead of purging unless cleanup or reset is explicitly requested
|
- For ordinary corrections, prefer storing the new source-of-truth fact instead of purging unless cleanup or reset is explicitly requested
|
||||||
|
|
||||||
## Agent Identity & Source Tagging
|
## Agent Identity & Source Tagging
|
||||||
|
|||||||
41
README.md
41
README.md
@@ -10,7 +10,7 @@ OpenBrain is a Model Context Protocol (MCP) server that provides AI agents with
|
|||||||
- 🏠 **Local Embeddings**: No external API calls - uses ONNX runtime with all-MiniLM-L6-v2
|
- 🏠 **Local Embeddings**: No external API calls - uses ONNX runtime with all-MiniLM-L6-v2
|
||||||
- 🐘 **PostgreSQL + pgvector**: Production-grade vector storage with HNSW indexing
|
- 🐘 **PostgreSQL + pgvector**: Production-grade vector storage with HNSW indexing
|
||||||
- 🔌 **MCP Protocol**: Streamable HTTP plus legacy HTTP+SSE compatibility
|
- 🔌 **MCP Protocol**: Streamable HTTP plus legacy HTTP+SSE compatibility
|
||||||
- 🔐 **Multi-Agent Support**: Isolated memory namespaces per agent
|
- 🔐 **Shared Memory by Token**: Agents using the same API token share memory visibility while retaining source-agent provenance
|
||||||
- ♻️ **Deduplicated Ingest**: Near-duplicate facts are merged instead of stored repeatedly
|
- ♻️ **Deduplicated Ingest**: Near-duplicate facts are merged instead of stored repeatedly
|
||||||
- ⚡ **High Performance**: Rust implementation with async I/O
|
- ⚡ **High Performance**: Rust implementation with async I/O
|
||||||
|
|
||||||
@@ -20,8 +20,8 @@ OpenBrain is a Model Context Protocol (MCP) server that provides AI agents with
|
|||||||
|------|-------------|
|
|------|-------------|
|
||||||
| `store` | Store a memory with automatic embedding generation, optional TTL, and automatic deduplication |
|
| `store` | Store a memory with automatic embedding generation, optional TTL, and automatic deduplication |
|
||||||
| `batch_store` | Store 1-50 memories atomically in a single call with the same deduplication rules |
|
| `batch_store` | Store 1-50 memories atomically in a single call with the same deduplication rules |
|
||||||
| `query` | Search memories by semantic similarity |
|
| `query` | Search memories by semantic similarity, optionally filtering by source agent |
|
||||||
| `purge` | Delete memories by agent ID or time range |
|
| `purge` | Delete memories visible to the current API token, optionally filtering by source agent or time range |
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
@@ -127,8 +127,8 @@ If you want prod e2e coverage without leaving a standing CI key on the server, t
|
|||||||
### Deduplication on Ingest
|
### Deduplication on Ingest
|
||||||
|
|
||||||
OpenBrain checks every `store` and `batch_store` write for an existing memory in
|
OpenBrain checks every `store` and `batch_store` write for an existing memory in
|
||||||
the same `agent_id` namespace whose vector similarity meets the configured
|
the same API-token scope and same source `agent_id` whose vector similarity
|
||||||
dedup threshold.
|
meets the configured dedup threshold.
|
||||||
|
|
||||||
Default behavior:
|
Default behavior:
|
||||||
|
|
||||||
@@ -165,10 +165,12 @@ Recommended target file in A0:
|
|||||||
### External Memory System
|
### External Memory System
|
||||||
- **Memory Boundary**: Treat OpenBrain as an external MCP long-term memory system, never as internal context, reasoning scratchpad, or built-in memory
|
- **Memory Boundary**: Treat OpenBrain as an external MCP long-term memory system, never as internal context, reasoning scratchpad, or built-in memory
|
||||||
- **Tool Contract**: Use the exact MCP tools `openbrain.store`, `openbrain.query`, and `openbrain.purge`
|
- **Tool Contract**: Use the exact MCP tools `openbrain.store`, `openbrain.query`, and `openbrain.purge`
|
||||||
- **Namespace Discipline**: Always use the exact `agent_id` value `openbrain`
|
- **Shared Access Model**: Memory visibility is determined by the API token in the MCP client config, not by `agent_id`
|
||||||
|
- **Source Labels**: Use `agent_id` on `openbrain.store` and `openbrain.batch_store` only as a provenance label for the storing agent when that label is useful
|
||||||
- **EXTRAS First**: Before calling `openbrain.query`, check the `[EXTRAS]` section for pre-loaded memories or handoff facts related to the same topic. If the needed context is already present, do not query OpenBrain again.
|
- **EXTRAS First**: Before calling `openbrain.query`, check the `[EXTRAS]` section for pre-loaded memories or handoff facts related to the same topic. If the needed context is already present, do not query OpenBrain again.
|
||||||
- **Session Cache**: If the same topic was already queried earlier in the current conversation and the result is still in context, reuse that result instead of querying again unless the user references new external information or the prior result is clearly insufficient.
|
- **Session Cache**: If the same topic was already queried earlier in the current conversation and the result is still in context, reuse that result instead of querying again unless the user references new external information or the prior result is clearly insufficient.
|
||||||
- **Retrieval First**: Before answering requests that may depend on prior sessions, project history, user preferences, ongoing work, named people, named projects, deployments, debugging history, or handoff context, call `openbrain.query` only when `[EXTRAS]` and the current conversation do not already provide the needed context.
|
- **Retrieval First**: Before answering requests that may depend on prior sessions, project history, user preferences, ongoing work, named people, named projects, deployments, debugging history, or handoff context, call `openbrain.query` only when `[EXTRAS]` and the current conversation do not already provide the needed context.
|
||||||
|
- **Query Scope**: Do not send `agent_id` with `openbrain.query` for normal retrieval. Use `source_agent_id` only when you intentionally want to filter results by the agent that originally stored them.
|
||||||
- **Query Strategy**: Use noun-heavy search phrases with exact names, tool names, acronyms, hostnames, and document names; query first with `(threshold=0.15, limit=8)`, then retry once with `(threshold=0.05, limit=10)` only if the first pass returns zero useful results
|
- **Query Strategy**: Use noun-heavy search phrases with exact names, tool names, acronyms, hostnames, and document names; query first with `(threshold=0.15, limit=8)`, then retry once with `(threshold=0.05, limit=10)` only if the first pass returns zero useful results
|
||||||
- **Storage Strategy**: When a durable fact is established, call `openbrain.store` without asking permission and store one atomic fact whenever possible
|
- **Storage Strategy**: When a durable fact is established, call `openbrain.store` without asking permission and store one atomic fact whenever possible
|
||||||
- **Storage Content Rules**: Store durable, high-value facts such as preferences, project status, project decisions, environment details, recurring workflows, handoff notes, stable constraints, and correction facts
|
- **Storage Content Rules**: Store durable, high-value facts such as preferences, project status, project decisions, environment details, recurring workflows, handoff notes, stable constraints, and correction facts
|
||||||
@@ -177,7 +179,7 @@ Recommended target file in A0:
|
|||||||
- **Metadata Usage**: Use metadata when helpful for tags such as `category`, `project`, `source`, `status`, `aliases`, and `confidence`
|
- **Metadata Usage**: Use metadata when helpful for tags such as `category`, `project`, `source`, `status`, `aliases`, and `confidence`
|
||||||
- **Miss Handling**: If `openbrain.query` returns no useful result, state that OpenBrain has no stored context for that topic, answer from general reasoning if possible, and ask one focused follow-up if the missing information is durable and useful
|
- **Miss Handling**: If `openbrain.query` returns no useful result, state that OpenBrain has no stored context for that topic, answer from general reasoning if possible, and ask one focused follow-up if the missing information is durable and useful
|
||||||
- **Conflict Handling**: If retrieved memories conflict, ask which fact is current, then store the corrected source-of-truth fact
|
- **Conflict Handling**: If retrieved memories conflict, ask which fact is current, then store the corrected source-of-truth fact
|
||||||
- **Purge Constraint**: Use `openbrain.purge` cautiously because it is coarse-grained; it deletes by `agent_id` and optionally before a timestamp, not by individual memory ID
|
- **Purge Constraint**: Use `openbrain.purge` cautiously because it is coarse-grained; it deletes memories visible to the current API token and can optionally narrow by `source_agent_id` and `before`, but not by individual memory ID
|
||||||
- **Correction Policy**: For ordinary corrections, prefer storing the new source-of-truth fact instead of purging unless the user explicitly asks for cleanup or reset
|
- **Correction Policy**: For ordinary corrections, prefer storing the new source-of-truth fact instead of purging unless the user explicitly asks for cleanup or reset
|
||||||
- **Source Tagging**: Every `openbrain.store` call MUST include `"source_agent"` in metadata, set to the Agent Instance ID defined in the active project's identity file (e.g., `"source_agent": "ingwaz-a0"`). This enables tracing facts back to the originating agent instance.
|
- **Source Tagging**: Every `openbrain.store` call MUST include `"source_agent"` in metadata, set to the Agent Instance ID defined in the active project's identity file (e.g., `"source_agent": "ingwaz-a0"`). This enables tracing facts back to the originating agent instance.
|
||||||
```
|
```
|
||||||
@@ -259,17 +261,22 @@ legacy SSE endpoints for older MCP clients that still use the deprecated
|
|||||||
2024-11-05 HTTP+SSE transport.
|
2024-11-05 HTTP+SSE transport.
|
||||||
|
|
||||||
Header roles:
|
Header roles:
|
||||||
- `X-Agent-ID` is the memory namespace. Keep this stable if multiple clients
|
- If two clients use the same API token, they can read and write the same
|
||||||
should share the same OpenBrain memories.
|
OpenBrain memories.
|
||||||
- `X-Agent-Type` is an optional client profile label for logging and config
|
- `X-Agent-ID` is an optional source-agent label for logs and store provenance.
|
||||||
clarity, such as `agent-zero` or `codex`.
|
It does not control memory visibility.
|
||||||
|
- `X-Agent-Type` is an optional client-program label such as `agent-zero`,
|
||||||
|
`codex`, or `claude-code`. It does not select transport server-side; the URL
|
||||||
|
path does that.
|
||||||
|
- `agent_id` on `store` and `batch_store` is provenance. `source_agent_id` on
|
||||||
|
`query` and `purge` is an optional provenance filter.
|
||||||
|
|
||||||
### Example: Codex Configuration
|
### Example: Codex Configuration
|
||||||
|
|
||||||
```toml
|
```toml
|
||||||
[mcp_servers.openbrain]
|
[mcp_servers.openbrain]
|
||||||
url = "https://memory.example.com/mcp"
|
url = "https://memory.example.com/mcp"
|
||||||
http_headers = { "X-API-Key" = "YOUR_OPENBRAIN_API_KEY", "X-Agent-ID" = "openbrain", "X-Agent-Type" = "codex" }
|
http_headers = { "X-API-Key" = "YOUR_OPENBRAIN_API_KEY", "X-Agent-ID" = "codex-desktop", "X-Agent-Type" = "codex" }
|
||||||
```
|
```
|
||||||
|
|
||||||
### Example: Agent Zero Configuration
|
### Example: Agent Zero Configuration
|
||||||
@@ -281,7 +288,7 @@ http_headers = { "X-API-Key" = "YOUR_OPENBRAIN_API_KEY", "X-Agent-ID" = "openbra
|
|||||||
"url": "https://memory.example.com/mcp/sse",
|
"url": "https://memory.example.com/mcp/sse",
|
||||||
"headers": {
|
"headers": {
|
||||||
"X-API-Key": "YOUR_OPENBRAIN_API_KEY",
|
"X-API-Key": "YOUR_OPENBRAIN_API_KEY",
|
||||||
"X-Agent-ID": "openbrain",
|
"X-Agent-ID": "agent-zero",
|
||||||
"X-Agent-Type": "agent-zero"
|
"X-Agent-Type": "agent-zero"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -290,7 +297,8 @@ http_headers = { "X-API-Key" = "YOUR_OPENBRAIN_API_KEY", "X-Agent-ID" = "openbra
|
|||||||
```
|
```
|
||||||
|
|
||||||
Agent Zero should keep using the legacy HTTP+SSE transport unless and until its
|
Agent Zero should keep using the legacy HTTP+SSE transport unless and until its
|
||||||
client runtime supports streamable HTTP. Codex should use `/mcp`.
|
client runtime supports streamable HTTP. Codex should use `/mcp`. If both
|
||||||
|
clients use the same API token, they already share memory visibility.
|
||||||
|
|
||||||
### Example: Store a Memory
|
### Example: Store a Memory
|
||||||
|
|
||||||
@@ -303,7 +311,7 @@ client runtime supports streamable HTTP. Codex should use `/mcp`.
|
|||||||
"name": "store",
|
"name": "store",
|
||||||
"arguments": {
|
"arguments": {
|
||||||
"content": "The user prefers dark mode and uses vim keybindings",
|
"content": "The user prefers dark mode and uses vim keybindings",
|
||||||
"agent_id": "assistant-1",
|
"agent_id": "agent-zero",
|
||||||
"ttl": "7d",
|
"ttl": "7d",
|
||||||
"metadata": {"source": "preferences"}
|
"metadata": {"source": "preferences"}
|
||||||
}
|
}
|
||||||
@@ -322,7 +330,6 @@ client runtime supports streamable HTTP. Codex should use `/mcp`.
|
|||||||
"name": "query",
|
"name": "query",
|
||||||
"arguments": {
|
"arguments": {
|
||||||
"query": "What are the user's editor preferences?",
|
"query": "What are the user's editor preferences?",
|
||||||
"agent_id": "assistant-1",
|
|
||||||
"limit": 5,
|
"limit": 5,
|
||||||
"threshold": 0.6
|
"threshold": 0.6
|
||||||
}
|
}
|
||||||
@@ -340,7 +347,7 @@ client runtime supports streamable HTTP. Codex should use `/mcp`.
|
|||||||
"params": {
|
"params": {
|
||||||
"name": "batch_store",
|
"name": "batch_store",
|
||||||
"arguments": {
|
"arguments": {
|
||||||
"agent_id": "assistant-1",
|
"agent_id": "codex",
|
||||||
"entries": [
|
"entries": [
|
||||||
{
|
{
|
||||||
"content": "The user prefers dark mode",
|
"content": "The user prefers dark mode",
|
||||||
|
|||||||
8
migrations/V4__auth_scope_shared_memory.sql
Normal file
8
migrations/V4__auth_scope_shared_memory.sql
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
ALTER TABLE memories
|
||||||
|
ADD COLUMN IF NOT EXISTS auth_scope VARCHAR(255) NOT NULL DEFAULT 'public';
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_memories_auth_scope
|
||||||
|
ON memories (auth_scope);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_memories_auth_scope_agent
|
||||||
|
ON memories (auth_scope, agent_id);
|
||||||
43
src/auth.rs
43
src/auth.rs
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Request, State},
|
extract::{Request, State},
|
||||||
http::{HeaderMap, StatusCode, header::AUTHORIZATION},
|
http::{header::AUTHORIZATION, HeaderMap, StatusCode},
|
||||||
middleware::Next,
|
middleware::Next,
|
||||||
response::Response,
|
response::Response,
|
||||||
};
|
};
|
||||||
@@ -14,6 +14,8 @@ use tracing::warn;
|
|||||||
|
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
|
|
||||||
|
pub const PUBLIC_AUTH_SCOPE: &str = "public";
|
||||||
|
|
||||||
/// Hash an API key for secure comparison
|
/// Hash an API key for secure comparison
|
||||||
pub fn hash_api_key(key: &str) -> String {
|
pub fn hash_api_key(key: &str) -> String {
|
||||||
let mut hasher = Sha256::new();
|
let mut hasher = Sha256::new();
|
||||||
@@ -99,24 +101,25 @@ pub fn get_optional_agent_type(headers: &HeaderMap) -> Option<String> {
|
|||||||
.map(ToOwned::to_owned)
|
.map(ToOwned::to_owned)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extract agent ID from request headers or default
|
pub fn get_auth_scope(headers: &HeaderMap, auth_enabled: bool) -> String {
|
||||||
pub fn get_agent_id(request: &Request) -> String {
|
if !auth_enabled {
|
||||||
get_optional_agent_id(request.headers())
|
return PUBLIC_AUTH_SCOPE.to_string();
|
||||||
.unwrap_or_else(|| "default".to_string())
|
}
|
||||||
|
|
||||||
|
extract_api_key(headers)
|
||||||
|
.map(|key| hash_api_key(&key))
|
||||||
|
.unwrap_or_else(|| PUBLIC_AUTH_SCOPE.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use axum::http::{HeaderValue, header::AUTHORIZATION};
|
use axum::http::{header::AUTHORIZATION, HeaderValue};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn extracts_api_key_from_bearer_header() {
|
fn extracts_api_key_from_bearer_header() {
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert(
|
headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer test-token"));
|
||||||
AUTHORIZATION,
|
|
||||||
HeaderValue::from_static("Bearer test-token"),
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(extract_api_key(&headers).as_deref(), Some("test-token"));
|
assert_eq!(extract_api_key(&headers).as_deref(), Some("test-token"));
|
||||||
}
|
}
|
||||||
@@ -137,9 +140,21 @@ mod tests {
|
|||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert("X-Agent-Type", HeaderValue::from_static("codex"));
|
headers.insert("X-Agent-Type", HeaderValue::from_static("codex"));
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(get_optional_agent_type(&headers).as_deref(), Some("codex"));
|
||||||
get_optional_agent_type(&headers).as_deref(),
|
}
|
||||||
Some("codex")
|
|
||||||
);
|
#[test]
|
||||||
|
fn derives_auth_scope_from_api_key_when_enabled() {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("X-API-Key", HeaderValue::from_static("test-token"));
|
||||||
|
|
||||||
|
assert_eq!(get_auth_scope(&headers, true), hash_api_key("test-token"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn uses_public_scope_when_auth_disabled() {
|
||||||
|
let headers = HeaderMap::new();
|
||||||
|
|
||||||
|
assert_eq!(get_auth_scope(&headers, false), PUBLIC_AUTH_SCOPE);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -94,29 +94,50 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
match Option::<StringOrVec>::deserialize(deserializer)? {
|
match Option::<StringOrVec>::deserialize(deserializer)? {
|
||||||
Some(StringOrVec::String(s)) => {
|
Some(StringOrVec::String(s)) => Ok(s
|
||||||
Ok(s.split(',')
|
.split(',')
|
||||||
.map(|k| k.trim().to_string())
|
.map(|k| k.trim().to_string())
|
||||||
.filter(|k| !k.is_empty())
|
.filter(|k| !k.is_empty())
|
||||||
.collect())
|
.collect()),
|
||||||
}
|
|
||||||
Some(StringOrVec::Vec(v)) => Ok(v),
|
Some(StringOrVec::Vec(v)) => Ok(v),
|
||||||
None => Ok(Vec::new()),
|
None => Ok(Vec::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default value functions
|
// Default value functions
|
||||||
fn default_host() -> String { "0.0.0.0".to_string() }
|
fn default_host() -> String {
|
||||||
fn default_port() -> u16 { 3100 }
|
"0.0.0.0".to_string()
|
||||||
fn default_db_port() -> u16 { 5432 }
|
}
|
||||||
fn default_pool_size() -> usize { 10 }
|
fn default_port() -> u16 {
|
||||||
fn default_model_path() -> String { "models/all-MiniLM-L6-v2".to_string() }
|
3100
|
||||||
fn default_embedding_dim() -> usize { 384 }
|
}
|
||||||
fn default_vector_weight() -> f32 { 0.6 }
|
fn default_db_port() -> u16 {
|
||||||
fn default_text_weight() -> f32 { 0.4 }
|
5432
|
||||||
fn default_dedup_threshold() -> f32 { 0.90 }
|
}
|
||||||
fn default_cleanup_interval_seconds() -> u64 { 300 }
|
fn default_pool_size() -> usize {
|
||||||
fn default_auth_enabled() -> bool { false }
|
10
|
||||||
|
}
|
||||||
|
fn default_model_path() -> String {
|
||||||
|
"models/all-MiniLM-L6-v2".to_string()
|
||||||
|
}
|
||||||
|
fn default_embedding_dim() -> usize {
|
||||||
|
384
|
||||||
|
}
|
||||||
|
fn default_vector_weight() -> f32 {
|
||||||
|
0.6
|
||||||
|
}
|
||||||
|
fn default_text_weight() -> f32 {
|
||||||
|
0.4
|
||||||
|
}
|
||||||
|
fn default_dedup_threshold() -> f32 {
|
||||||
|
0.90
|
||||||
|
}
|
||||||
|
fn default_cleanup_interval_seconds() -> u64 {
|
||||||
|
300
|
||||||
|
}
|
||||||
|
fn default_auth_enabled() -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
/// Load configuration from environment variables
|
/// Load configuration from environment variables
|
||||||
|
|||||||
103
src/db.rs
103
src/db.rs
@@ -9,9 +9,9 @@ use tokio_postgres::NoTls;
|
|||||||
use tracing::info;
|
use tracing::info;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::config::DatabaseConfig;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use serde_json::{Map, Value};
|
use serde_json::{Map, Value};
|
||||||
use crate::config::DatabaseConfig;
|
|
||||||
|
|
||||||
/// Database wrapper with connection pool
|
/// Database wrapper with connection pool
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@@ -75,6 +75,7 @@ fn merge_metadata(existing: &Value, incoming: &Value) -> Value {
|
|||||||
|
|
||||||
async fn find_dedup_match<C>(
|
async fn find_dedup_match<C>(
|
||||||
client: &C,
|
client: &C,
|
||||||
|
auth_scope: &str,
|
||||||
agent_id: &str,
|
agent_id: &str,
|
||||||
embedding: &Vector,
|
embedding: &Vector,
|
||||||
threshold: f64,
|
threshold: f64,
|
||||||
@@ -87,13 +88,14 @@ where
|
|||||||
r#"
|
r#"
|
||||||
SELECT id, metadata, expires_at
|
SELECT id, metadata, expires_at
|
||||||
FROM memories
|
FROM memories
|
||||||
WHERE agent_id = $1
|
WHERE auth_scope = $1
|
||||||
|
AND agent_id = $2
|
||||||
AND (expires_at IS NULL OR expires_at > NOW())
|
AND (expires_at IS NULL OR expires_at > NOW())
|
||||||
AND (1 - (embedding <=> $2)) >= $3
|
AND (1 - (embedding <=> $3)) >= $4
|
||||||
ORDER BY (1 - (embedding <=> $2)) DESC, created_at DESC
|
ORDER BY (1 - (embedding <=> $3)) DESC, created_at DESC
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
"#,
|
"#,
|
||||||
&[&agent_id, embedding, &threshold],
|
&[&auth_scope, &agent_id, embedding, &threshold],
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.context("Failed to check for duplicate memory")?;
|
.context("Failed to check for duplicate memory")?;
|
||||||
@@ -120,13 +122,19 @@ impl Database {
|
|||||||
.context("Failed to create database pool")?;
|
.context("Failed to create database pool")?;
|
||||||
|
|
||||||
// Test connection
|
// Test connection
|
||||||
let client = pool.get().await.context("Failed to get database connection")?;
|
let client = pool
|
||||||
|
.get()
|
||||||
|
.await
|
||||||
|
.context("Failed to get database connection")?;
|
||||||
client
|
client
|
||||||
.simple_query("SELECT 1")
|
.simple_query("SELECT 1")
|
||||||
.await
|
.await
|
||||||
.context("Failed to execute test query")?;
|
.context("Failed to execute test query")?;
|
||||||
|
|
||||||
info!("Database connection pool created with {} connections", config.pool_size);
|
info!(
|
||||||
|
"Database connection pool created with {} connections",
|
||||||
|
config.pool_size
|
||||||
|
);
|
||||||
|
|
||||||
Ok(Self { pool })
|
Ok(Self { pool })
|
||||||
}
|
}
|
||||||
@@ -134,6 +142,7 @@ impl Database {
|
|||||||
/// Store a memory record
|
/// Store a memory record
|
||||||
pub async fn store_memory(
|
pub async fn store_memory(
|
||||||
&self,
|
&self,
|
||||||
|
auth_scope: &str,
|
||||||
agent_id: &str,
|
agent_id: &str,
|
||||||
content: &str,
|
content: &str,
|
||||||
embedding: &[f32],
|
embedding: &[f32],
|
||||||
@@ -146,7 +155,9 @@ impl Database {
|
|||||||
let vector = Vector::from(embedding.to_vec());
|
let vector = Vector::from(embedding.to_vec());
|
||||||
let dedup_threshold = dedup_threshold as f64;
|
let dedup_threshold = dedup_threshold as f64;
|
||||||
|
|
||||||
if let Some(existing) = find_dedup_match(&client, agent_id, &vector, dedup_threshold).await? {
|
if let Some(existing) =
|
||||||
|
find_dedup_match(&client, auth_scope, agent_id, &vector, dedup_threshold).await?
|
||||||
|
{
|
||||||
let merged_metadata = merge_metadata(&existing.metadata, &metadata);
|
let merged_metadata = merge_metadata(&existing.metadata, &metadata);
|
||||||
let refreshed_expires_at = expires_at.or(existing.expires_at);
|
let refreshed_expires_at = expires_at.or(existing.expires_at);
|
||||||
|
|
||||||
@@ -176,10 +187,10 @@ impl Database {
|
|||||||
client
|
client
|
||||||
.execute(
|
.execute(
|
||||||
r#"
|
r#"
|
||||||
INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at)
|
INSERT INTO memories (id, auth_scope, agent_id, content, embedding, keywords, metadata, expires_at)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
"#,
|
"#,
|
||||||
&[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
|
&[&id, &auth_scope, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.context("Failed to store memory")?;
|
.context("Failed to store memory")?;
|
||||||
@@ -194,7 +205,8 @@ impl Database {
|
|||||||
/// Query memories by vector similarity
|
/// Query memories by vector similarity
|
||||||
pub async fn query_memories(
|
pub async fn query_memories(
|
||||||
&self,
|
&self,
|
||||||
agent_id: &str,
|
auth_scope: &str,
|
||||||
|
source_agent_id: Option<&str>,
|
||||||
query_text: &str,
|
query_text: &str,
|
||||||
embedding: &[f32],
|
embedding: &[f32],
|
||||||
limit: i64,
|
limit: i64,
|
||||||
@@ -230,7 +242,8 @@ impl Database {
|
|||||||
END AS text_score
|
END AS text_score
|
||||||
FROM memories
|
FROM memories
|
||||||
CROSS JOIN search_query
|
CROSS JOIN search_query
|
||||||
WHERE memories.agent_id = $3
|
WHERE memories.auth_scope = $3
|
||||||
|
AND ($4::text IS NULL OR memories.agent_id = $4)
|
||||||
AND (memories.expires_at IS NULL OR memories.expires_at > NOW())
|
AND (memories.expires_at IS NULL OR memories.expires_at > NOW())
|
||||||
),
|
),
|
||||||
ranked AS (
|
ranked AS (
|
||||||
@@ -251,18 +264,19 @@ impl Database {
|
|||||||
text_score,
|
text_score,
|
||||||
CASE
|
CASE
|
||||||
WHEN has_text_match = 1
|
WHEN has_text_match = 1
|
||||||
THEN (($5 * vector_score) + ($6 * text_score))::real
|
THEN (($6 * vector_score) + ($7 * text_score))::real
|
||||||
ELSE vector_score
|
ELSE vector_score
|
||||||
END AS hybrid_score
|
END AS hybrid_score
|
||||||
FROM ranked
|
FROM ranked
|
||||||
WHERE vector_score >= $4 OR text_score > 0
|
WHERE vector_score >= $5 OR text_score > 0
|
||||||
ORDER BY hybrid_score DESC, vector_score DESC
|
ORDER BY hybrid_score DESC, vector_score DESC
|
||||||
LIMIT $7
|
LIMIT $8
|
||||||
"#,
|
"#,
|
||||||
&[
|
&[
|
||||||
&vector,
|
&vector,
|
||||||
&query_text,
|
&query_text,
|
||||||
&agent_id,
|
&auth_scope,
|
||||||
|
&source_agent_id,
|
||||||
&threshold,
|
&threshold,
|
||||||
&vector_weight,
|
&vector_weight,
|
||||||
&text_weight,
|
&text_weight,
|
||||||
@@ -296,37 +310,47 @@ impl Database {
|
|||||||
Ok(matches)
|
Ok(matches)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Delete memories by agent_id and optional filters
|
/// Delete memories visible to an auth scope, optionally narrowed by source agent and by a `created_at` cutoff (`before`)
|
||||||
pub async fn purge_memories(
|
pub async fn purge_memories(
|
||||||
&self,
|
&self,
|
||||||
agent_id: &str,
|
auth_scope: &str,
|
||||||
|
source_agent_id: Option<&str>,
|
||||||
before: Option<chrono::DateTime<chrono::Utc>>,
|
before: Option<chrono::DateTime<chrono::Utc>>,
|
||||||
) -> Result<u64> {
|
) -> Result<u64> {
|
||||||
let client = self.pool.get().await?;
|
let client = self.pool.get().await?;
|
||||||
|
|
||||||
let count = if let Some(before_ts) = before {
|
let count = client
|
||||||
client
|
.execute(
|
||||||
.execute(
|
r#"
|
||||||
"DELETE FROM memories WHERE agent_id = $1 AND created_at < $2",
|
DELETE FROM memories
|
||||||
&[&agent_id, &before_ts],
|
WHERE auth_scope = $1
|
||||||
)
|
AND ($2::text IS NULL OR agent_id = $2)
|
||||||
.await?
|
AND ($3::timestamptz IS NULL OR created_at < $3)
|
||||||
} else {
|
"#,
|
||||||
client
|
&[&auth_scope, &source_agent_id, &before],
|
||||||
.execute("DELETE FROM memories WHERE agent_id = $1", &[&agent_id])
|
)
|
||||||
.await?
|
.await?;
|
||||||
};
|
|
||||||
|
|
||||||
Ok(count)
|
Ok(count)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get memory count for an agent
|
/// Count non-expired memories visible to an auth scope, optionally filtered by source agent
|
||||||
pub async fn count_memories(&self, agent_id: &str) -> Result<i64> {
|
pub async fn count_memories(
|
||||||
|
&self,
|
||||||
|
auth_scope: &str,
|
||||||
|
source_agent_id: Option<&str>,
|
||||||
|
) -> Result<i64> {
|
||||||
let client = self.pool.get().await?;
|
let client = self.pool.get().await?;
|
||||||
let row = client
|
let row = client
|
||||||
.query_one(
|
.query_one(
|
||||||
"SELECT COUNT(*) as count FROM memories WHERE agent_id = $1 AND (expires_at IS NULL OR expires_at > NOW())",
|
r#"
|
||||||
&[&agent_id],
|
SELECT COUNT(*) as count
|
||||||
|
FROM memories
|
||||||
|
WHERE auth_scope = $1
|
||||||
|
AND ($2::text IS NULL OR agent_id = $2)
|
||||||
|
AND (expires_at IS NULL OR expires_at > NOW())
|
||||||
|
"#,
|
||||||
|
&[&auth_scope, &source_agent_id],
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(row.get("count"))
|
Ok(row.get("count"))
|
||||||
@@ -346,7 +370,6 @@ impl Database {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Result for a single batch entry
|
/// Result for a single batch entry
|
||||||
#[derive(Debug, Clone, Serialize)]
|
#[derive(Debug, Clone, Serialize)]
|
||||||
pub struct BatchStoreResult {
|
pub struct BatchStoreResult {
|
||||||
@@ -360,6 +383,7 @@ impl Database {
|
|||||||
/// Store multiple memories in a single transaction
|
/// Store multiple memories in a single transaction
|
||||||
pub async fn batch_store_memories(
|
pub async fn batch_store_memories(
|
||||||
&self,
|
&self,
|
||||||
|
auth_scope: &str,
|
||||||
agent_id: &str,
|
agent_id: &str,
|
||||||
entries: Vec<(
|
entries: Vec<(
|
||||||
String,
|
String,
|
||||||
@@ -378,7 +402,8 @@ impl Database {
|
|||||||
for (content, metadata, embedding, keywords, expires_at) in entries {
|
for (content, metadata, embedding, keywords, expires_at) in entries {
|
||||||
let vector = Vector::from(embedding);
|
let vector = Vector::from(embedding);
|
||||||
if let Some(existing) =
|
if let Some(existing) =
|
||||||
find_dedup_match(&transaction, agent_id, &vector, dedup_threshold).await?
|
find_dedup_match(&transaction, auth_scope, agent_id, &vector, dedup_threshold)
|
||||||
|
.await?
|
||||||
{
|
{
|
||||||
let merged_metadata = merge_metadata(&existing.metadata, &metadata);
|
let merged_metadata = merge_metadata(&existing.metadata, &metadata);
|
||||||
let refreshed_expires_at = expires_at.or(existing.expires_at);
|
let refreshed_expires_at = expires_at.or(existing.expires_at);
|
||||||
@@ -404,8 +429,8 @@ impl Database {
|
|||||||
} else {
|
} else {
|
||||||
let id = Uuid::new_v4();
|
let id = Uuid::new_v4();
|
||||||
transaction.execute(
|
transaction.execute(
|
||||||
r#"INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7)"#,
|
r#"INSERT INTO memories (id, auth_scope, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#,
|
||||||
&[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
|
&[&id, &auth_scope, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
|
||||||
).await?;
|
).await?;
|
||||||
results.push(BatchStoreResult {
|
results.push(BatchStoreResult {
|
||||||
id: id.to_string(),
|
id: id.to_string(),
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
//! Embedding engine using local ONNX models
|
//! Embedding engine using local ONNX models
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
use ort::session::{builder::GraphOptimizationLevel, Session};
|
||||||
use ort::value::Value;
|
use ort::value::Value;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::sync::Once;
|
use std::sync::Once;
|
||||||
@@ -92,34 +92,39 @@ impl EmbeddingEngine {
|
|||||||
let model_path = PathBuf::from(&config.model_path);
|
let model_path = PathBuf::from(&config.model_path);
|
||||||
let dimension = config.dimension;
|
let dimension = config.dimension;
|
||||||
|
|
||||||
info!("Loading ONNX model from {:?}", model_path.join("model.onnx"));
|
info!(
|
||||||
|
"Loading ONNX model from {:?}",
|
||||||
|
model_path.join("model.onnx")
|
||||||
|
);
|
||||||
|
|
||||||
// Use spawn_blocking to avoid blocking the async runtime
|
// Use spawn_blocking to avoid blocking the async runtime
|
||||||
let (session, tokenizer) = tokio::task::spawn_blocking(move || -> Result<(Session, Tokenizer)> {
|
let (session, tokenizer) =
|
||||||
// Initialize ONNX Runtime first
|
tokio::task::spawn_blocking(move || -> Result<(Session, Tokenizer)> {
|
||||||
init_ort_sync(&dylib_path)?;
|
// Initialize ONNX Runtime first
|
||||||
|
init_ort_sync(&dylib_path)?;
|
||||||
|
|
||||||
info!("Creating ONNX session...");
|
info!("Creating ONNX session...");
|
||||||
|
|
||||||
// Load ONNX model with ort 2.0 API
|
// Load ONNX model with ort 2.0 API
|
||||||
let session = Session::builder()
|
let session = Session::builder()
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to create session builder: {:?}", e))?
|
.map_err(|e| anyhow::anyhow!("Failed to create session builder: {:?}", e))?
|
||||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to set optimization level: {:?}", e))?
|
.map_err(|e| anyhow::anyhow!("Failed to set optimization level: {:?}", e))?
|
||||||
.with_intra_threads(4)
|
.with_intra_threads(4)
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to set intra threads: {:?}", e))?
|
.map_err(|e| anyhow::anyhow!("Failed to set intra threads: {:?}", e))?
|
||||||
.commit_from_file(model_path.join("model.onnx"))
|
.commit_from_file(model_path.join("model.onnx"))
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to load ONNX model: {:?}", e))?;
|
.map_err(|e| anyhow::anyhow!("Failed to load ONNX model: {:?}", e))?;
|
||||||
|
|
||||||
info!("ONNX model loaded, loading tokenizer...");
|
info!("ONNX model loaded, loading tokenizer...");
|
||||||
|
|
||||||
// Load tokenizer
|
// Load tokenizer
|
||||||
let tokenizer = Tokenizer::from_file(model_path.join("tokenizer.json"))
|
let tokenizer = Tokenizer::from_file(model_path.join("tokenizer.json"))
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
|
||||||
|
|
||||||
info!("Tokenizer loaded successfully");
|
info!("Tokenizer loaded successfully");
|
||||||
Ok((session, tokenizer))
|
Ok((session, tokenizer))
|
||||||
}).await
|
})
|
||||||
|
.await
|
||||||
.map_err(|e| anyhow::anyhow!("Spawn blocking failed: {:?}", e))??;
|
.map_err(|e| anyhow::anyhow!("Spawn blocking failed: {:?}", e))??;
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
@@ -136,12 +141,17 @@ impl EmbeddingEngine {
|
|||||||
|
|
||||||
/// Generate embedding for a single text
|
/// Generate embedding for a single text
|
||||||
pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
|
pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
|
||||||
let encoding = self.tokenizer
|
let encoding = self
|
||||||
|
.tokenizer
|
||||||
.encode(text, true)
|
.encode(text, true)
|
||||||
.map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
|
||||||
|
|
||||||
let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
|
let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
|
||||||
let attention_mask: Vec<i64> = encoding.get_attention_mask().iter().map(|&x| x as i64).collect();
|
let attention_mask: Vec<i64> = encoding
|
||||||
|
.get_attention_mask()
|
||||||
|
.iter()
|
||||||
|
.map(|&x| x as i64)
|
||||||
|
.collect();
|
||||||
let token_type_ids: Vec<i64> = encoding.get_type_ids().iter().map(|&x| x as i64).collect();
|
let token_type_ids: Vec<i64> = encoding.get_type_ids().iter().map(|&x| x as i64).collect();
|
||||||
|
|
||||||
let seq_len = input_ids.len();
|
let seq_len = input_ids.len();
|
||||||
@@ -158,12 +168,15 @@ impl EmbeddingEngine {
|
|||||||
"token_type_ids" => token_type_ids_tensor,
|
"token_type_ids" => token_type_ids_tensor,
|
||||||
];
|
];
|
||||||
|
|
||||||
let mut session_guard = self.session.lock()
|
let mut session_guard = self
|
||||||
|
.session
|
||||||
|
.lock()
|
||||||
.map_err(|e| anyhow::anyhow!("Session lock poisoned: {}", e))?;
|
.map_err(|e| anyhow::anyhow!("Session lock poisoned: {}", e))?;
|
||||||
let outputs = session_guard.run(inputs)?;
|
let outputs = session_guard.run(inputs)?;
|
||||||
|
|
||||||
// Extract output
|
// Extract output
|
||||||
let output = outputs.get("last_hidden_state")
|
let output = outputs
|
||||||
|
.get("last_hidden_state")
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing last_hidden_state output"))?;
|
.ok_or_else(|| anyhow::anyhow!("Missing last_hidden_state output"))?;
|
||||||
|
|
||||||
// Get the tensor data
|
// Get the tensor data
|
||||||
@@ -210,16 +223,18 @@ pub fn extract_keywords(text: &str, limit: usize) -> Vec<String> {
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
let stop_words: std::collections::HashSet<&str> = [
|
let stop_words: std::collections::HashSet<&str> = [
|
||||||
"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
|
"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by",
|
||||||
"of", "with", "by", "from", "is", "are", "was", "were", "be", "been",
|
"from", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "do",
|
||||||
"being", "have", "has", "had", "do", "does", "did", "will", "would",
|
"does", "did", "will", "would", "could", "should", "may", "might", "must", "shall", "can",
|
||||||
"could", "should", "may", "might", "must", "shall", "can", "this",
|
"this", "that", "these", "those", "i", "you", "he", "she", "it", "we", "they", "what",
|
||||||
"that", "these", "those", "i", "you", "he", "she", "it", "we", "they",
|
"which", "who", "whom", "whose", "where", "when", "why", "how", "all", "each", "every",
|
||||||
"what", "which", "who", "whom", "whose", "where", "when", "why", "how",
|
"both", "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own",
|
||||||
"all", "each", "every", "both", "few", "more", "most", "other", "some",
|
"same", "so", "than", "too", "very", "just", "also", "now", "here", "there", "then",
|
||||||
"such", "no", "nor", "not", "only", "own", "same", "so", "than", "too",
|
"once", "if",
|
||||||
"very", "just", "also", "now", "here", "there", "then", "once", "if",
|
]
|
||||||
].iter().cloned().collect();
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
|
||||||
let mut word_counts: HashMap<String, usize> = HashMap::new();
|
let mut word_counts: HashMap<String, usize> = HashMap::new();
|
||||||
|
|
||||||
@@ -238,7 +253,8 @@ pub fn extract_keywords(text: &str, limit: usize) -> Vec<String> {
|
|||||||
let mut sorted: Vec<_> = word_counts.into_iter().collect();
|
let mut sorted: Vec<_> = word_counts.into_iter().collect();
|
||||||
sorted.sort_by(|a, b| b.1.cmp(&a.1));
|
sorted.sort_by(|a, b| b.1.cmp(&a.1));
|
||||||
|
|
||||||
sorted.into_iter()
|
sorted
|
||||||
|
.into_iter()
|
||||||
.take(limit)
|
.take(limit)
|
||||||
.map(|(word, _)| word)
|
.map(|(word, _)| word)
|
||||||
.collect()
|
.collect()
|
||||||
|
|||||||
28
src/lib.rs
28
src/lib.rs
@@ -5,18 +5,18 @@ pub mod config;
|
|||||||
pub mod db;
|
pub mod db;
|
||||||
pub mod embedding;
|
pub mod embedding;
|
||||||
pub mod migrations;
|
pub mod migrations;
|
||||||
pub mod ttl;
|
|
||||||
pub mod tools;
|
pub mod tools;
|
||||||
pub mod transport;
|
pub mod transport;
|
||||||
|
pub mod ttl;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use axum::{Router, Json, http::StatusCode, middleware};
|
use axum::{http::StatusCode, middleware, Json, Router};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
use tower_http::cors::{Any, CorsLayer};
|
use tower_http::cors::{Any, CorsLayer};
|
||||||
use tower_http::trace::TraceLayer;
|
use tower_http::trace::TraceLayer;
|
||||||
use tracing::{info, error};
|
use tracing::{error, info};
|
||||||
|
|
||||||
use crate::auth::auth_middleware;
|
use crate::auth::auth_middleware;
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
@@ -60,15 +60,15 @@ async fn readiness_handler(
|
|||||||
match readiness {
|
match readiness {
|
||||||
ReadinessState::Ready => (
|
ReadinessState::Ready => (
|
||||||
StatusCode::OK,
|
StatusCode::OK,
|
||||||
Json(json!({"status": "ready", "embedding": true}))
|
Json(json!({"status": "ready", "embedding": true})),
|
||||||
),
|
),
|
||||||
ReadinessState::Initializing => (
|
ReadinessState::Initializing => (
|
||||||
StatusCode::SERVICE_UNAVAILABLE,
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
Json(json!({"status": "initializing", "embedding": false}))
|
Json(json!({"status": "initializing", "embedding": false})),
|
||||||
),
|
),
|
||||||
ReadinessState::Failed(err) => (
|
ReadinessState::Failed(err) => (
|
||||||
StatusCode::SERVICE_UNAVAILABLE,
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
Json(json!({"status": "failed", "error": err}))
|
Json(json!({"status": "failed", "error": err})),
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -92,7 +92,10 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> {
|
|||||||
|
|
||||||
loop {
|
loop {
|
||||||
attempt += 1;
|
attempt += 1;
|
||||||
info!("Initializing embedding engine (attempt {}/{})", attempt, max_retries);
|
info!(
|
||||||
|
"Initializing embedding engine (attempt {}/{})",
|
||||||
|
attempt, max_retries
|
||||||
|
);
|
||||||
|
|
||||||
match EmbeddingEngine::new(&embedding_config).await {
|
match EmbeddingEngine::new(&embedding_config).await {
|
||||||
Ok(engine) => {
|
Ok(engine) => {
|
||||||
@@ -120,9 +123,8 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> {
|
|||||||
let cleanup_state = state.clone();
|
let cleanup_state = state.clone();
|
||||||
let cleanup_interval_seconds = config.ttl.cleanup_interval_seconds;
|
let cleanup_interval_seconds = config.ttl.cleanup_interval_seconds;
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(
|
let mut interval =
|
||||||
cleanup_interval_seconds,
|
tokio::time::interval(tokio::time::Duration::from_secs(cleanup_interval_seconds));
|
||||||
));
|
|
||||||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
@@ -150,8 +152,10 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> {
|
|||||||
.with_state(state.clone());
|
.with_state(state.clone());
|
||||||
|
|
||||||
// Build MCP router with auth middleware
|
// Build MCP router with auth middleware
|
||||||
let mcp_router = transport::mcp_router(mcp_state)
|
let mcp_router = transport::mcp_router(mcp_state).layer(middleware::from_fn_with_state(
|
||||||
.layer(middleware::from_fn_with_state(state.clone(), auth_middleware));
|
state.clone(),
|
||||||
|
auth_middleware,
|
||||||
|
));
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.merge(health_router)
|
.merge(health_router)
|
||||||
|
|||||||
@@ -17,7 +17,10 @@ async fn main() -> Result<()> {
|
|||||||
.with(tracing_subscriber::fmt::layer())
|
.with(tracing_subscriber::fmt::layer())
|
||||||
.init();
|
.init();
|
||||||
|
|
||||||
info!("Starting OpenBrain MCP Server v{}", env!("CARGO_PKG_VERSION"));
|
info!(
|
||||||
|
"Starting OpenBrain MCP Server v{}",
|
||||||
|
env!("CARGO_PKG_VERSION")
|
||||||
|
);
|
||||||
|
|
||||||
// Load configuration
|
// Load configuration
|
||||||
let config = Config::load()?;
|
let config = Config::load()?;
|
||||||
|
|||||||
@@ -3,12 +3,14 @@
|
|||||||
//! Accepts 1-50 entries per call, generates embeddings for each,
|
//! Accepts 1-50 entries per call, generates embeddings for each,
|
||||||
//! stores all in a single DB transaction, returns individual IDs/status.
|
//! stores all in a single DB transaction, returns individual IDs/status.
|
||||||
|
|
||||||
use anyhow::{Context, Result, anyhow};
|
use anyhow::{anyhow, Context, Result};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
|
use crate::auth::PUBLIC_AUTH_SCOPE;
|
||||||
use crate::embedding::extract_keywords;
|
use crate::embedding::extract_keywords;
|
||||||
|
use crate::tools::INTERNAL_AUTH_SCOPE_ARG;
|
||||||
use crate::ttl::expires_at_from_ttl;
|
use crate::ttl::expires_at_from_ttl;
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
|
|
||||||
@@ -18,7 +20,7 @@ const MAX_BATCH_SIZE: usize = 50;
|
|||||||
/// Execute the batch_store tool
|
/// Execute the batch_store tool
|
||||||
///
|
///
|
||||||
/// Accepts:
|
/// Accepts:
|
||||||
/// - `agent_id`: Optional agent identifier (defaults to "default")
|
/// - `agent_id`: Optional source agent label (defaults to "default")
|
||||||
/// - `ttl`: Optional default TTL string applied to entries without their own ttl
|
/// - `ttl`: Optional default TTL string applied to entries without their own ttl
|
||||||
/// - `entries`: Array of 1-50 entries, each with `content` (required) and `metadata` (optional)
|
/// - `entries`: Array of 1-50 entries, each with `content` (required) and `metadata` (optional)
|
||||||
///
|
///
|
||||||
@@ -38,6 +40,10 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
.get("agent_id")
|
.get("agent_id")
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.unwrap_or("default");
|
.unwrap_or("default");
|
||||||
|
let auth_scope = arguments
|
||||||
|
.get(INTERNAL_AUTH_SCOPE_ARG)
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or(PUBLIC_AUTH_SCOPE);
|
||||||
|
|
||||||
let entries = arguments
|
let entries = arguments
|
||||||
.get("entries")
|
.get("entries")
|
||||||
@@ -47,7 +53,9 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
|
|
||||||
// 3. Validate batch size
|
// 3. Validate batch size
|
||||||
if entries.is_empty() {
|
if entries.is_empty() {
|
||||||
return Err(anyhow!("Empty entries array not allowed - must provide 1-50 entries"));
|
return Err(anyhow!(
|
||||||
|
"Empty entries array not allowed - must provide 1-50 entries"
|
||||||
|
));
|
||||||
}
|
}
|
||||||
if entries.len() > MAX_BATCH_SIZE {
|
if entries.len() > MAX_BATCH_SIZE {
|
||||||
return Err(anyhow!(
|
return Err(anyhow!(
|
||||||
@@ -69,7 +77,10 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
let content = entry
|
let content = entry
|
||||||
.get("content")
|
.get("content")
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.context(format!("Entry at index {} missing required field: content", idx))?;
|
.context(format!(
|
||||||
|
"Entry at index {} missing required field: content",
|
||||||
|
idx
|
||||||
|
))?;
|
||||||
|
|
||||||
if content.is_empty() {
|
if content.is_empty() {
|
||||||
return Err(anyhow!(
|
return Err(anyhow!(
|
||||||
@@ -82,10 +93,7 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
.get("metadata")
|
.get("metadata")
|
||||||
.cloned()
|
.cloned()
|
||||||
.unwrap_or(serde_json::json!({}));
|
.unwrap_or(serde_json::json!({}));
|
||||||
let ttl = entry
|
let ttl = entry.get("ttl").and_then(|v| v.as_str()).or(default_ttl);
|
||||||
.get("ttl")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.or(default_ttl);
|
|
||||||
let expires_at = expires_at_from_ttl(ttl)
|
let expires_at = expires_at_from_ttl(ttl)
|
||||||
.with_context(|| format!("Invalid ttl for entry at index {}", idx))?;
|
.with_context(|| format!("Invalid ttl for entry at index {}", idx))?;
|
||||||
|
|
||||||
@@ -97,13 +105,24 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
// Extract keywords
|
// Extract keywords
|
||||||
let keywords = extract_keywords(content, 10);
|
let keywords = extract_keywords(content, 10);
|
||||||
|
|
||||||
processed_entries.push((content.to_string(), metadata, embedding, keywords, expires_at));
|
processed_entries.push((
|
||||||
|
content.to_string(),
|
||||||
|
metadata,
|
||||||
|
embedding,
|
||||||
|
keywords,
|
||||||
|
expires_at,
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. Batch DB insert (single transaction for atomicity)
|
// 5. Batch DB insert (single transaction for atomicity)
|
||||||
let results = state
|
let results = state
|
||||||
.db
|
.db
|
||||||
.batch_store_memories(agent_id, processed_entries, state.config.dedup.threshold)
|
.batch_store_memories(
|
||||||
|
auth_scope,
|
||||||
|
agent_id,
|
||||||
|
processed_entries,
|
||||||
|
state.config.dedup.threshold,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.context("Failed to batch store memories")?;
|
.context("Failed to batch store memories")?;
|
||||||
|
|
||||||
@@ -114,5 +133,6 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
"success": true,
|
"success": true,
|
||||||
"results": results,
|
"results": results,
|
||||||
"count": count
|
"count": count
|
||||||
}).to_string())
|
})
|
||||||
|
.to_string())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
//! MCP Tools for OpenBrain
|
//! MCP Tools for OpenBrain
|
||||||
|
|
||||||
pub mod batch_store;
|
pub mod batch_store;
|
||||||
|
pub mod purge;
|
||||||
pub mod query;
|
pub mod query;
|
||||||
pub mod store;
|
pub mod store;
|
||||||
pub mod purge;
|
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
@@ -11,11 +11,13 @@ use std::sync::Arc;
|
|||||||
|
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
|
|
||||||
|
pub const INTERNAL_AUTH_SCOPE_ARG: &str = "_auth_scope";
|
||||||
|
|
||||||
pub fn get_tool_definitions() -> Vec<Value> {
|
pub fn get_tool_definitions() -> Vec<Value> {
|
||||||
vec![
|
vec![
|
||||||
json!({
|
json!({
|
||||||
"name": "store",
|
"name": "store",
|
||||||
"description": "Store a memory with automatic embedding generation and keyword extraction. Near-duplicate memories for the same agent are deduplicated automatically by similarity, with metadata merged and timestamps refreshed.",
|
"description": "Store a memory with automatic embedding generation and keyword extraction. Near-duplicate memories for the same API-token scope and same source agent are deduplicated automatically by similarity, with metadata merged and timestamps refreshed.",
|
||||||
"inputSchema": {
|
"inputSchema": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
@@ -25,7 +27,7 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
|||||||
},
|
},
|
||||||
"agent_id": {
|
"agent_id": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Unique identifier for the agent storing the memory (default: 'default')"
|
"description": "Optional source agent label recorded with the memory. If omitted, the server may fall back to X-Agent-ID or 'default'."
|
||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -47,7 +49,7 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
|||||||
"properties": {
|
"properties": {
|
||||||
"agent_id": {
|
"agent_id": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Unique identifier for the agent storing the memories (default: 'default')"
|
"description": "Optional source agent label recorded with each stored memory. If omitted, the server may fall back to X-Agent-ID or 'default'."
|
||||||
},
|
},
|
||||||
"ttl": {
|
"ttl": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -89,9 +91,14 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The search query text"
|
"description": "The search query text"
|
||||||
},
|
},
|
||||||
|
"source_agent_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional provenance filter that only returns memories stored by the specified agent label"
|
||||||
|
},
|
||||||
"agent_id": {
|
"agent_id": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Agent ID to search within (default: 'default')"
|
"description": "Deprecated legacy alias. Query visibility is scoped by API token, not by agent_id.",
|
||||||
|
"deprecated": true
|
||||||
},
|
},
|
||||||
"limit": {
|
"limit": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
@@ -107,13 +114,18 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
|||||||
}),
|
}),
|
||||||
json!({
|
json!({
|
||||||
"name": "purge",
|
"name": "purge",
|
||||||
"description": "Delete memories for an agent. Can delete all memories or those before a specific timestamp.",
|
"description": "Delete memories visible to the current API token. Can optionally filter by source agent label or by time range.",
|
||||||
"inputSchema": {
|
"inputSchema": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
"source_agent_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional provenance filter that only deletes memories stored by the specified agent label"
|
||||||
|
},
|
||||||
"agent_id": {
|
"agent_id": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Agent ID whose memories to delete (required)"
|
"description": "Deprecated legacy alias for source_agent_id",
|
||||||
|
"deprecated": true
|
||||||
},
|
},
|
||||||
"before": {
|
"before": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -124,9 +136,9 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
|||||||
"description": "Must be true to confirm deletion"
|
"description": "Must be true to confirm deletion"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["agent_id", "confirm"]
|
"required": ["confirm"]
|
||||||
}
|
}
|
||||||
})
|
}),
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//! Purge Tool - Delete memories by agent_id or time range
|
//! Purge Tool - Delete memories visible to the current token with optional filters
|
||||||
|
|
||||||
use anyhow::{bail, Context, Result};
|
use anyhow::{bail, Context, Result};
|
||||||
use chrono::DateTime;
|
use chrono::DateTime;
|
||||||
@@ -6,15 +6,21 @@ use serde_json::Value;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
|
use crate::auth::PUBLIC_AUTH_SCOPE;
|
||||||
|
use crate::tools::INTERNAL_AUTH_SCOPE_ARG;
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
|
|
||||||
/// Execute the purge tool
|
/// Execute the purge tool
|
||||||
pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String> {
|
pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String> {
|
||||||
// Extract parameters
|
// Extract parameters
|
||||||
let agent_id = arguments
|
let source_agent_id = arguments
|
||||||
.get("agent_id")
|
.get("source_agent_id")
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.context("Missing required parameter: agent_id")?;
|
.or_else(|| arguments.get("agent_id").and_then(|v| v.as_str()));
|
||||||
|
let auth_scope = arguments
|
||||||
|
.get(INTERNAL_AUTH_SCOPE_ARG)
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or(PUBLIC_AUTH_SCOPE);
|
||||||
|
|
||||||
let confirm = arguments
|
let confirm = arguments
|
||||||
.get("confirm")
|
.get("confirm")
|
||||||
@@ -36,15 +42,18 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
// Get current count before purge
|
// Get current count before purge
|
||||||
let count_before = state
|
let count_before = state
|
||||||
.db
|
.db
|
||||||
.count_memories(agent_id)
|
.count_memories(auth_scope, source_agent_id)
|
||||||
.await
|
.await
|
||||||
.context("Failed to count memories")?;
|
.context("Failed to count memories")?;
|
||||||
|
|
||||||
if count_before == 0 {
|
if count_before == 0 {
|
||||||
info!("No memories found for agent '{}'", agent_id);
|
info!(
|
||||||
|
"No memories found to purge for auth scope '{}' with source_agent_id={:?}",
|
||||||
|
auth_scope, source_agent_id
|
||||||
|
);
|
||||||
return Ok(serde_json::json!({
|
return Ok(serde_json::json!({
|
||||||
"success": true,
|
"success": true,
|
||||||
"agent_id": agent_id,
|
"source_agent_id_filter": source_agent_id,
|
||||||
"deleted": 0,
|
"deleted": 0,
|
||||||
"message": "No memories found to purge"
|
"message": "No memories found to purge"
|
||||||
})
|
})
|
||||||
@@ -52,25 +61,25 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
}
|
}
|
||||||
|
|
||||||
warn!(
|
warn!(
|
||||||
"Purging memories for agent '{}' (before={:?})",
|
"Purging memories for auth scope '{}' with source_agent_id={:?} (before={:?})",
|
||||||
agent_id, before
|
auth_scope, source_agent_id, before
|
||||||
);
|
);
|
||||||
|
|
||||||
// Execute purge
|
// Execute purge
|
||||||
let deleted = state
|
let deleted = state
|
||||||
.db
|
.db
|
||||||
.purge_memories(agent_id, before)
|
.purge_memories(auth_scope, source_agent_id, before)
|
||||||
.await
|
.await
|
||||||
.context("Failed to purge memories")?;
|
.context("Failed to purge memories")?;
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Purged {} memories for agent '{}'",
|
"Purged {} memories for auth scope '{}' with source_agent_id={:?}",
|
||||||
deleted, agent_id
|
deleted, auth_scope, source_agent_id
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(serde_json::json!({
|
Ok(serde_json::json!({
|
||||||
"success": true,
|
"success": true,
|
||||||
"agent_id": agent_id,
|
"source_agent_id_filter": source_agent_id,
|
||||||
"deleted": deleted,
|
"deleted": deleted,
|
||||||
"had_before_filter": before.is_some(),
|
"had_before_filter": before.is_some(),
|
||||||
"message": format!("Successfully purged {} memories", deleted)
|
"message": format!("Successfully purged {} memories", deleted)
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
//! Query Tool - Search memories by semantic similarity
|
//! Query Tool - Search memories by semantic similarity
|
||||||
|
|
||||||
use anyhow::{Context, Result, anyhow};
|
use anyhow::{anyhow, Context, Result};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
|
use crate::auth::PUBLIC_AUTH_SCOPE;
|
||||||
|
use crate::tools::INTERNAL_AUTH_SCOPE_ARG;
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
|
|
||||||
/// Execute the query tool
|
/// Execute the query tool
|
||||||
@@ -21,10 +23,14 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.context("Missing required parameter: query")?;
|
.context("Missing required parameter: query")?;
|
||||||
|
|
||||||
let agent_id = arguments
|
let source_agent_id = arguments
|
||||||
.get("agent_id")
|
.get("source_agent_id")
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.unwrap_or("default");
|
.filter(|value| !value.is_empty());
|
||||||
|
let auth_scope = arguments
|
||||||
|
.get(INTERNAL_AUTH_SCOPE_ARG)
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or(PUBLIC_AUTH_SCOPE);
|
||||||
|
|
||||||
let limit = arguments
|
let limit = arguments
|
||||||
.get("limit")
|
.get("limit")
|
||||||
@@ -42,8 +48,8 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
);
|
);
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Querying memories for agent '{}': '{}' (limit={}, threshold={}, vector_weight={}, text_weight={})",
|
"Querying memories for auth scope '{}' with source_agent_id={:?}: '{}' (limit={}, threshold={}, vector_weight={}, text_weight={})",
|
||||||
agent_id, query_text, limit, threshold, vector_weight, text_weight
|
auth_scope, source_agent_id, query_text, limit, threshold, vector_weight, text_weight
|
||||||
);
|
);
|
||||||
|
|
||||||
// Generate embedding for query using Arc<EmbeddingEngine>
|
// Generate embedding for query using Arc<EmbeddingEngine>
|
||||||
@@ -55,7 +61,8 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
let matches = state
|
let matches = state
|
||||||
.db
|
.db
|
||||||
.query_memories(
|
.query_memories(
|
||||||
agent_id,
|
auth_scope,
|
||||||
|
source_agent_id,
|
||||||
query_text,
|
query_text,
|
||||||
&query_embedding,
|
&query_embedding,
|
||||||
limit,
|
limit,
|
||||||
@@ -79,6 +86,7 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
"vector_score": m.vector_score,
|
"vector_score": m.vector_score,
|
||||||
"text_score": m.text_score,
|
"text_score": m.text_score,
|
||||||
"hybrid_score": m.hybrid_score,
|
"hybrid_score": m.hybrid_score,
|
||||||
|
"agent_id": m.record.agent_id,
|
||||||
"keywords": m.record.keywords,
|
"keywords": m.record.keywords,
|
||||||
"metadata": m.record.metadata,
|
"metadata": m.record.metadata,
|
||||||
"created_at": m.record.created_at.to_rfc3339(),
|
"created_at": m.record.created_at.to_rfc3339(),
|
||||||
@@ -89,8 +97,8 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
|
|
||||||
Ok(serde_json::json!({
|
Ok(serde_json::json!({
|
||||||
"success": true,
|
"success": true,
|
||||||
"agent_id": agent_id,
|
|
||||||
"query": query_text,
|
"query": query_text,
|
||||||
|
"source_agent_id_filter": source_agent_id,
|
||||||
"vector_weight": vector_weight,
|
"vector_weight": vector_weight,
|
||||||
"text_weight": text_weight,
|
"text_weight": text_weight,
|
||||||
"count": results.len(),
|
"count": results.len(),
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
//! Store Tool - Store memories with automatic embeddings
|
//! Store Tool - Store memories with automatic embeddings
|
||||||
|
|
||||||
use anyhow::{Context, Result, anyhow};
|
use anyhow::{anyhow, Context, Result};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
|
||||||
|
use crate::auth::PUBLIC_AUTH_SCOPE;
|
||||||
use crate::embedding::extract_keywords;
|
use crate::embedding::extract_keywords;
|
||||||
|
use crate::tools::INTERNAL_AUTH_SCOPE_ARG;
|
||||||
use crate::ttl::expires_at_from_ttl;
|
use crate::ttl::expires_at_from_ttl;
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
|
|
||||||
@@ -27,6 +29,10 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
.get("agent_id")
|
.get("agent_id")
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.unwrap_or("default");
|
.unwrap_or("default");
|
||||||
|
let auth_scope = arguments
|
||||||
|
.get(INTERNAL_AUTH_SCOPE_ARG)
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or(PUBLIC_AUTH_SCOPE);
|
||||||
|
|
||||||
let metadata = arguments
|
let metadata = arguments
|
||||||
.get("metadata")
|
.get("metadata")
|
||||||
@@ -54,6 +60,7 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
let id = state
|
let id = state
|
||||||
.db
|
.db
|
||||||
.store_memory(
|
.store_memory(
|
||||||
|
auth_scope,
|
||||||
agent_id,
|
agent_id,
|
||||||
content,
|
content,
|
||||||
&embedding,
|
&embedding,
|
||||||
@@ -67,7 +74,11 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Memory {} with ID: {}",
|
"Memory {} with ID: {}",
|
||||||
if id.deduplicated { "deduplicated" } else { "stored" },
|
if id.deduplicated {
|
||||||
|
"deduplicated"
|
||||||
|
} else {
|
||||||
|
"stored"
|
||||||
|
},
|
||||||
id.id
|
id.id
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
175
src/transport.rs
175
src/transport.rs
@@ -5,27 +5,25 @@
|
|||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Query, State},
|
extract::{Query, State},
|
||||||
http::{HeaderMap, StatusCode, Uri, header::{HOST, ORIGIN}},
|
http::{
|
||||||
|
header::{HOST, ORIGIN},
|
||||||
|
HeaderMap, StatusCode, Uri,
|
||||||
|
},
|
||||||
response::{
|
response::{
|
||||||
IntoResponse, Response,
|
|
||||||
sse::{Event, KeepAlive, Sse},
|
sse::{Event, KeepAlive, Sse},
|
||||||
|
IntoResponse, Response,
|
||||||
},
|
},
|
||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
Json, Router,
|
Json, Router,
|
||||||
};
|
};
|
||||||
use futures::stream::Stream;
|
use futures::stream::Stream;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{
|
use std::{collections::HashMap, convert::Infallible, sync::Arc, time::Duration};
|
||||||
collections::HashMap,
|
use tokio::sync::{broadcast, mpsc, RwLock};
|
||||||
convert::Infallible,
|
|
||||||
sync::Arc,
|
|
||||||
time::Duration,
|
|
||||||
};
|
|
||||||
use tokio::sync::{RwLock, broadcast, mpsc};
|
|
||||||
use tracing::{error, info, warn};
|
use tracing::{error, info, warn};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{AppState, auth, tools};
|
use crate::{auth, tools, AppState};
|
||||||
|
|
||||||
type SessionStore = RwLock<HashMap<String, mpsc::Sender<serde_json::Value>>>;
|
type SessionStore = RwLock<HashMap<String, mpsc::Sender<serde_json::Value>>>;
|
||||||
|
|
||||||
@@ -46,11 +44,7 @@ impl McpState {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn insert_session(
|
async fn insert_session(&self, session_id: String, tx: mpsc::Sender<serde_json::Value>) {
|
||||||
&self,
|
|
||||||
session_id: String,
|
|
||||||
tx: mpsc::Sender<serde_json::Value>,
|
|
||||||
) {
|
|
||||||
self.sessions.write().await.insert(session_id, tx);
|
self.sessions.write().await.insert(session_id, tx);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,7 +157,12 @@ struct PostMessageQuery {
|
|||||||
/// Create the MCP router
|
/// Create the MCP router
|
||||||
pub fn mcp_router(state: Arc<McpState>) -> Router {
|
pub fn mcp_router(state: Arc<McpState>) -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/mcp", get(streamable_get_handler).post(streamable_post_handler).delete(streamable_delete_handler))
|
.route(
|
||||||
|
"/mcp",
|
||||||
|
get(streamable_get_handler)
|
||||||
|
.post(streamable_post_handler)
|
||||||
|
.delete(streamable_delete_handler),
|
||||||
|
)
|
||||||
.route("/mcp/sse", get(sse_handler))
|
.route("/mcp/sse", get(sse_handler))
|
||||||
.route("/mcp/message", post(message_handler))
|
.route("/mcp/message", post(message_handler))
|
||||||
.route("/mcp/health", get(health_handler))
|
.route("/mcp/health", get(health_handler))
|
||||||
@@ -186,7 +185,10 @@ fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let origin_uri = origin.parse::<Uri>().map_err(|_| {
|
let origin_uri = origin.parse::<Uri>().map_err(|_| {
|
||||||
warn!("Rejected MCP request with invalid origin header: {}", origin);
|
warn!(
|
||||||
|
"Rejected MCP request with invalid origin header: {}",
|
||||||
|
origin
|
||||||
|
);
|
||||||
StatusCode::FORBIDDEN
|
StatusCode::FORBIDDEN
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
@@ -203,7 +205,10 @@ fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
|
|||||||
.map(str::trim)
|
.map(str::trim)
|
||||||
.filter(|value| !value.is_empty())
|
.filter(|value| !value.is_empty())
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
warn!("Rejected MCP request without host header for origin {}", origin);
|
warn!(
|
||||||
|
"Rejected MCP request without host header for origin {}",
|
||||||
|
origin
|
||||||
|
);
|
||||||
StatusCode::FORBIDDEN
|
StatusCode::FORBIDDEN
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
@@ -247,12 +252,12 @@ async fn streamable_post_handler(
|
|||||||
|
|
||||||
info!(
|
info!(
|
||||||
method = %request.method,
|
method = %request.method,
|
||||||
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
client_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||||
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
||||||
"Received streamable MCP request"
|
"Received streamable MCP request"
|
||||||
);
|
);
|
||||||
|
|
||||||
let request = apply_request_context(request, &headers);
|
let request = apply_request_context(request, &headers, state.app.config.auth.enabled);
|
||||||
let response = dispatch_request(&state, &request).await;
|
let response = dispatch_request(&state, &request).await;
|
||||||
|
|
||||||
match response {
|
match response {
|
||||||
@@ -284,8 +289,12 @@ async fn sse_handler(
|
|||||||
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
|
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
|
||||||
validate_origin(&headers)?;
|
validate_origin(&headers)?;
|
||||||
info!(
|
info!(
|
||||||
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
client_id = auth::get_optional_agent_id(&headers)
|
||||||
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
.as_deref()
|
||||||
|
.unwrap_or("unset"),
|
||||||
|
agent_type = auth::get_optional_agent_type(&headers)
|
||||||
|
.as_deref()
|
||||||
|
.unwrap_or("unset"),
|
||||||
"Opening legacy SSE MCP stream"
|
"Opening legacy SSE MCP stream"
|
||||||
);
|
);
|
||||||
let mut broadcast_rx = state.event_tx.subscribe();
|
let mut broadcast_rx = state.event_tx.subscribe();
|
||||||
@@ -354,7 +363,7 @@ async fn message_handler(
|
|||||||
|
|
||||||
info!(
|
info!(
|
||||||
method = %request.method,
|
method = %request.method,
|
||||||
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
client_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||||
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
||||||
"Received legacy SSE MCP request"
|
"Received legacy SSE MCP request"
|
||||||
);
|
);
|
||||||
@@ -365,7 +374,7 @@ async fn message_handler(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let request = apply_request_context(request, &headers);
|
let request = apply_request_context(request, &headers, state.app.config.auth.enabled);
|
||||||
let response = dispatch_request(&state, &request).await;
|
let response = dispatch_request(&state, &request).await;
|
||||||
|
|
||||||
match query.session_id.as_deref() {
|
match query.session_id.as_deref() {
|
||||||
@@ -408,17 +417,10 @@ async fn route_session_response(
|
|||||||
fn apply_request_context(
|
fn apply_request_context(
|
||||||
mut request: JsonRpcRequest,
|
mut request: JsonRpcRequest,
|
||||||
headers: &HeaderMap,
|
headers: &HeaderMap,
|
||||||
|
auth_enabled: bool,
|
||||||
) -> JsonRpcRequest {
|
) -> JsonRpcRequest {
|
||||||
if let Some(agent_id) = auth::get_optional_agent_id(headers) {
|
|
||||||
inject_agent_id(&mut request, &agent_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
request
|
|
||||||
}
|
|
||||||
|
|
||||||
fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
|
|
||||||
if request.method != "tools/call" {
|
if request.method != "tools/call" {
|
||||||
return;
|
return request;
|
||||||
}
|
}
|
||||||
|
|
||||||
if !request.params.is_object() {
|
if !request.params.is_object() {
|
||||||
@@ -429,6 +431,11 @@ fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
|
|||||||
.params
|
.params
|
||||||
.as_object_mut()
|
.as_object_mut()
|
||||||
.expect("params should be an object");
|
.expect("params should be an object");
|
||||||
|
let tool_name = params
|
||||||
|
.get("name")
|
||||||
|
.and_then(|value| value.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
let arguments = params
|
let arguments = params
|
||||||
.entry("arguments")
|
.entry("arguments")
|
||||||
.or_insert_with(|| serde_json::json!({}));
|
.or_insert_with(|| serde_json::json!({}));
|
||||||
@@ -437,11 +444,22 @@ fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
|
|||||||
*arguments = serde_json::json!({});
|
*arguments = serde_json::json!({});
|
||||||
}
|
}
|
||||||
|
|
||||||
arguments
|
let arguments = arguments
|
||||||
.as_object_mut()
|
.as_object_mut()
|
||||||
.expect("arguments should be an object")
|
.expect("arguments should be an object");
|
||||||
.entry("agent_id".to_string())
|
arguments.insert(
|
||||||
.or_insert_with(|| serde_json::json!(agent_id));
|
tools::INTERNAL_AUTH_SCOPE_ARG.to_string(),
|
||||||
|
serde_json::json!(auth::get_auth_scope(headers, auth_enabled)),
|
||||||
|
);
|
||||||
|
|
||||||
|
if matches!(tool_name.as_str(), "store" | "batch_store") && !arguments.contains_key("agent_id")
|
||||||
|
{
|
||||||
|
if let Some(agent_id) = auth::get_optional_agent_id(headers) {
|
||||||
|
arguments.insert("agent_id".to_string(), serde_json::json!(agent_id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
request
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn dispatch_request(
|
async fn dispatch_request(
|
||||||
@@ -496,10 +514,7 @@ fn initialize_result() -> serde_json::Value {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn success_response(
|
fn success_response(id: serde_json::Value, result: serde_json::Value) -> JsonRpcResponse {
|
||||||
id: serde_json::Value,
|
|
||||||
result: serde_json::Value,
|
|
||||||
) -> JsonRpcResponse {
|
|
||||||
JsonRpcResponse {
|
JsonRpcResponse {
|
||||||
jsonrpc: "2.0".to_string(),
|
jsonrpc: "2.0".to_string(),
|
||||||
id,
|
id,
|
||||||
@@ -578,10 +593,11 @@ async fn handle_tools_call(
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use axum::http::HeaderValue;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn injects_agent_id_when_missing_from_tool_arguments() {
|
fn request_context_injects_auth_scope_for_tool_calls() {
|
||||||
let mut request = JsonRpcRequest {
|
let request = JsonRpcRequest {
|
||||||
jsonrpc: "2.0".to_string(),
|
jsonrpc: "2.0".to_string(),
|
||||||
id: Some(serde_json::json!("1")),
|
id: Some(serde_json::json!("1")),
|
||||||
method: "tools/call".to_string(),
|
method: "tools/call".to_string(),
|
||||||
@@ -592,8 +608,39 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("X-API-Key", HeaderValue::from_static("test-token"));
|
||||||
|
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
|
||||||
|
|
||||||
inject_agent_id(&mut request, "agent-from-header");
|
let request = apply_request_context(request, &headers, true);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
request
|
||||||
|
.params
|
||||||
|
.get("arguments")
|
||||||
|
.and_then(|value| value.get(tools::INTERNAL_AUTH_SCOPE_ARG))
|
||||||
|
.and_then(|value| value.as_str()),
|
||||||
|
Some(auth::hash_api_key("test-token").as_str())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_context_does_not_inject_query_filter_from_header() {
|
||||||
|
let request = JsonRpcRequest {
|
||||||
|
jsonrpc: "2.0".to_string(),
|
||||||
|
id: Some(serde_json::json!("1")),
|
||||||
|
method: "tools/call".to_string(),
|
||||||
|
params: serde_json::json!({
|
||||||
|
"name": "query",
|
||||||
|
"arguments": {
|
||||||
|
"query": "editor preferences"
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
|
||||||
|
|
||||||
|
let request = apply_request_context(request, &headers, false);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
request
|
request
|
||||||
@@ -601,25 +648,55 @@ mod tests {
|
|||||||
.get("arguments")
|
.get("arguments")
|
||||||
.and_then(|value| value.get("agent_id"))
|
.and_then(|value| value.get("agent_id"))
|
||||||
.and_then(|value| value.as_str()),
|
.and_then(|value| value.as_str()),
|
||||||
Some("agent-from-header")
|
None
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn preserves_explicit_agent_id() {
|
fn request_context_injects_store_agent_id_from_header_when_missing() {
|
||||||
let mut request = JsonRpcRequest {
|
let request = JsonRpcRequest {
|
||||||
jsonrpc: "2.0".to_string(),
|
jsonrpc: "2.0".to_string(),
|
||||||
id: Some(serde_json::json!("1")),
|
id: Some(serde_json::json!("1")),
|
||||||
method: "tools/call".to_string(),
|
method: "tools/call".to_string(),
|
||||||
params: serde_json::json!({
|
params: serde_json::json!({
|
||||||
"name": "query",
|
"name": "store",
|
||||||
|
"arguments": {
|
||||||
|
"content": "prefers dark mode"
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
|
||||||
|
|
||||||
|
let request = apply_request_context(request, &headers, false);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
request
|
||||||
|
.params
|
||||||
|
.get("arguments")
|
||||||
|
.and_then(|value| value.get("agent_id"))
|
||||||
|
.and_then(|value| value.as_str()),
|
||||||
|
Some("codex-desktop")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_context_preserves_explicit_store_agent_id() {
|
||||||
|
let request = JsonRpcRequest {
|
||||||
|
jsonrpc: "2.0".to_string(),
|
||||||
|
id: Some(serde_json::json!("1")),
|
||||||
|
method: "tools/call".to_string(),
|
||||||
|
params: serde_json::json!({
|
||||||
|
"name": "store",
|
||||||
"arguments": {
|
"arguments": {
|
||||||
"agent_id": "explicit-agent"
|
"agent_id": "explicit-agent"
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("X-Agent-ID", HeaderValue::from_static("codex-desktop"));
|
||||||
|
|
||||||
inject_agent_id(&mut request, "agent-from-header");
|
let request = apply_request_context(request, &headers, false);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
request
|
request
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use anyhow::{Result, anyhow};
|
use anyhow::{anyhow, Result};
|
||||||
use chrono::{DateTime, Duration, Utc};
|
use chrono::{DateTime, Duration, Utc};
|
||||||
|
|
||||||
pub fn parse_ttl_spec(ttl: &str) -> Result<Duration> {
|
pub fn parse_ttl_spec(ttl: &str) -> Result<Duration> {
|
||||||
@@ -25,7 +25,9 @@ pub fn parse_ttl_spec(ttl: &str) -> Result<Duration> {
|
|||||||
.parse()
|
.parse()
|
||||||
.map_err(|_| anyhow!("invalid ttl '{ttl}'. Duration value must be a positive integer"))?;
|
.map_err(|_| anyhow!("invalid ttl '{ttl}'. Duration value must be a positive integer"))?;
|
||||||
if value <= 0 {
|
if value <= 0 {
|
||||||
return Err(anyhow!("invalid ttl '{ttl}'. Duration value must be greater than zero"));
|
return Err(anyhow!(
|
||||||
|
"invalid ttl '{ttl}'. Duration value must be greater than zero"
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let total_seconds = value
|
let total_seconds = value
|
||||||
|
|||||||
764
tests/e2e_mcp.rs
764
tests/e2e_mcp.rs
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user