mirror of
https://gitea.ingwaz.work/Ingwaz/openbrain-mcp.git
synced 2026-03-31 14:49:06 +00:00
Merge branch 'codex/issue-12-repair'
This commit is contained in:
@@ -10,8 +10,14 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
env:
|
||||||
CARGO_TERM_COLOR: always
|
CARGO_TERM_COLOR: always
|
||||||
VPS_HOST: ${{ vars.VPS_HOST }}
|
VPS_HOST: ${{ secrets.VPS_HOST }}
|
||||||
VPS_USER: ${{ vars.VPS_USER }}
|
VPS_USER: ${{ secrets.VPS_USER }}
|
||||||
|
OPENBRAIN__DATABASE__HOST: ${{ secrets.OPENBRAIN__DATABASE__HOST }}
|
||||||
|
OPENBRAIN__DATABASE__PORT: ${{ secrets.OPENBRAIN__DATABASE__PORT }}
|
||||||
|
OPENBRAIN__DATABASE__NAME: ${{ secrets.OPENBRAIN__DATABASE__NAME }}
|
||||||
|
OPENBRAIN__DATABASE__USER: ${{ secrets.OPENBRAIN__DATABASE__USER }}
|
||||||
|
OPENBRAIN__DATABASE__PASSWORD: ${{ secrets.OPENBRAIN__DATABASE__PASSWORD }}
|
||||||
|
OPENBRAIN__DATABASE__POOL_SIZE: ${{ secrets.OPENBRAIN__DATABASE__POOL_SIZE }}
|
||||||
DEPLOY_DIR: /opt/openbrain-mcp
|
DEPLOY_DIR: /opt/openbrain-mcp
|
||||||
SERVICE_NAME: openbrain-mcp
|
SERVICE_NAME: openbrain-mcp
|
||||||
steps:
|
steps:
|
||||||
@@ -65,7 +71,7 @@ jobs:
|
|||||||
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master'
|
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master'
|
||||||
run: |
|
run: |
|
||||||
set -euxo pipefail
|
set -euxo pipefail
|
||||||
: "${VPS_HOST:?Set repository variable VPS_HOST}"
|
: "${VPS_HOST:?Set repository secret VPS_HOST}"
|
||||||
: "${VPS_USER:=root}"
|
: "${VPS_USER:=root}"
|
||||||
install -d -m 700 "$HOME/.ssh"
|
install -d -m 700 "$HOME/.ssh"
|
||||||
printf '%s\n' "${{ secrets.VPS_SSH_KEY }}" > "$HOME/.ssh/deploy_key"
|
printf '%s\n' "${{ secrets.VPS_SSH_KEY }}" > "$HOME/.ssh/deploy_key"
|
||||||
@@ -76,7 +82,7 @@ jobs:
|
|||||||
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master'
|
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master'
|
||||||
run: |
|
run: |
|
||||||
set -euxo pipefail
|
set -euxo pipefail
|
||||||
: "${VPS_HOST:?Set repository variable VPS_HOST}"
|
: "${VPS_HOST:?Set repository secret VPS_HOST}"
|
||||||
: "${VPS_USER:=root}"
|
: "${VPS_USER:=root}"
|
||||||
SSH="ssh -i $HOME/.ssh/deploy_key -o IdentitiesOnly=yes"
|
SSH="ssh -i $HOME/.ssh/deploy_key -o IdentitiesOnly=yes"
|
||||||
SCP="scp -i $HOME/.ssh/deploy_key -o IdentitiesOnly=yes"
|
SCP="scp -i $HOME/.ssh/deploy_key -o IdentitiesOnly=yes"
|
||||||
@@ -94,11 +100,29 @@ jobs:
|
|||||||
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master'
|
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master'
|
||||||
run: |
|
run: |
|
||||||
set -euxo pipefail
|
set -euxo pipefail
|
||||||
: "${VPS_HOST:?Set repository variable VPS_HOST}"
|
: "${VPS_HOST:?Set repository secret VPS_HOST}"
|
||||||
: "${VPS_USER:=root}"
|
: "${VPS_USER:=root}"
|
||||||
SSH="ssh -i $HOME/.ssh/deploy_key -o IdentitiesOnly=yes"
|
SSH="ssh -i $HOME/.ssh/deploy_key -o IdentitiesOnly=yes"
|
||||||
|
|
||||||
$SSH "$VPS_USER@$VPS_HOST" "DEPLOY_DIR=$DEPLOY_DIR SERVICE_USER=openbrain SERVICE_GROUP=openbrain ORT_VERSION=1.24.3 bash -s" <<'EOS'
|
: "${OPENBRAIN__DATABASE__HOST:?Set repository secret OPENBRAIN__DATABASE__HOST}"
|
||||||
|
: "${OPENBRAIN__DATABASE__NAME:?Set repository secret OPENBRAIN__DATABASE__NAME}"
|
||||||
|
: "${OPENBRAIN__DATABASE__USER:?Set repository secret OPENBRAIN__DATABASE__USER}"
|
||||||
|
: "${OPENBRAIN__DATABASE__PASSWORD:?Set repository secret OPENBRAIN__DATABASE__PASSWORD}"
|
||||||
|
: "${OPENBRAIN__DATABASE__PORT:=5432}"
|
||||||
|
: "${OPENBRAIN__DATABASE__POOL_SIZE:=10}"
|
||||||
|
|
||||||
|
$SSH "$VPS_USER@$VPS_HOST" "\
|
||||||
|
DEPLOY_DIR=$DEPLOY_DIR \
|
||||||
|
SERVICE_USER=openbrain \
|
||||||
|
SERVICE_GROUP=openbrain \
|
||||||
|
ORT_VERSION=1.24.3 \
|
||||||
|
OPENBRAIN__DATABASE__HOST='$OPENBRAIN__DATABASE__HOST' \
|
||||||
|
OPENBRAIN__DATABASE__PORT='$OPENBRAIN__DATABASE__PORT' \
|
||||||
|
OPENBRAIN__DATABASE__NAME='$OPENBRAIN__DATABASE__NAME' \
|
||||||
|
OPENBRAIN__DATABASE__USER='$OPENBRAIN__DATABASE__USER' \
|
||||||
|
OPENBRAIN__DATABASE__PASSWORD='$OPENBRAIN__DATABASE__PASSWORD' \
|
||||||
|
OPENBRAIN__DATABASE__POOL_SIZE='$OPENBRAIN__DATABASE__POOL_SIZE' \
|
||||||
|
bash -s" <<'EOS'
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
DEPLOY_DIR="${DEPLOY_DIR:-/opt/openbrain-mcp}"
|
DEPLOY_DIR="${DEPLOY_DIR:-/opt/openbrain-mcp}"
|
||||||
SERVICE_USER="${SERVICE_USER:-openbrain}"
|
SERVICE_USER="${SERVICE_USER:-openbrain}"
|
||||||
@@ -147,13 +171,21 @@ jobs:
|
|||||||
upsert_env() {
|
upsert_env() {
|
||||||
local key="$1"
|
local key="$1"
|
||||||
local value="$2"
|
local value="$2"
|
||||||
|
local escaped_value
|
||||||
|
escaped_value="$(printf '%s' "$value" | sed -e 's/[\\&|]/\\&/g')"
|
||||||
if grep -qE "^${key}=" "$ENV_FILE"; then
|
if grep -qE "^${key}=" "$ENV_FILE"; then
|
||||||
sed -i "s|^${key}=.*|${key}=${value}|" "$ENV_FILE"
|
sed -i "s|^${key}=.*|${key}=${escaped_value}|" "$ENV_FILE"
|
||||||
else
|
else
|
||||||
printf '%s=%s\n' "$key" "$value" >> "$ENV_FILE"
|
printf '%s=%s\n' "$key" "$value" >> "$ENV_FILE"
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
|
upsert_env "OPENBRAIN__DATABASE__HOST" "$OPENBRAIN__DATABASE__HOST"
|
||||||
|
upsert_env "OPENBRAIN__DATABASE__PORT" "$OPENBRAIN__DATABASE__PORT"
|
||||||
|
upsert_env "OPENBRAIN__DATABASE__NAME" "$OPENBRAIN__DATABASE__NAME"
|
||||||
|
upsert_env "OPENBRAIN__DATABASE__USER" "$OPENBRAIN__DATABASE__USER"
|
||||||
|
upsert_env "OPENBRAIN__DATABASE__PASSWORD" "$OPENBRAIN__DATABASE__PASSWORD"
|
||||||
|
upsert_env "OPENBRAIN__DATABASE__POOL_SIZE" "$OPENBRAIN__DATABASE__POOL_SIZE"
|
||||||
upsert_env "OPENBRAIN__EMBEDDING__MODEL_PATH" "$DEPLOY_DIR/models/all-MiniLM-L6-v2"
|
upsert_env "OPENBRAIN__EMBEDDING__MODEL_PATH" "$DEPLOY_DIR/models/all-MiniLM-L6-v2"
|
||||||
upsert_env "ORT_DYLIB_PATH" "$DEPLOY_DIR/lib/libonnxruntime.so"
|
upsert_env "ORT_DYLIB_PATH" "$DEPLOY_DIR/lib/libonnxruntime.so"
|
||||||
upsert_env "OPENBRAIN__SERVER__HOST" "0.0.0.0"
|
upsert_env "OPENBRAIN__SERVER__HOST" "0.0.0.0"
|
||||||
@@ -173,7 +205,7 @@ jobs:
|
|||||||
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master'
|
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master'
|
||||||
run: |
|
run: |
|
||||||
set -euxo pipefail
|
set -euxo pipefail
|
||||||
: "${VPS_HOST:?Set repository variable VPS_HOST}"
|
: "${VPS_HOST:?Set repository secret VPS_HOST}"
|
||||||
: "${VPS_USER:=root}"
|
: "${VPS_USER:=root}"
|
||||||
SSH="ssh -i $HOME/.ssh/deploy_key -o IdentitiesOnly=yes"
|
SSH="ssh -i $HOME/.ssh/deploy_key -o IdentitiesOnly=yes"
|
||||||
|
|
||||||
|
|||||||
23
AGENTS.md
Normal file
23
AGENTS.md
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# OpenBrain MCP Usage
|
||||||
|
|
||||||
|
When working in this repository, treat OpenBrain as an external MCP long-term
|
||||||
|
memory system, never as internal context, reasoning scratchpad, or built-in
|
||||||
|
memory.
|
||||||
|
|
||||||
|
## External Memory System
|
||||||
|
|
||||||
|
- Use the exact MCP tools `openbrain.store`, `openbrain.query`, and `openbrain.purge`
|
||||||
|
- Always use the exact `agent_id` value `openbrain`
|
||||||
|
- Do not hardcode live credentials into the repository
|
||||||
|
- Before answering requests that may depend on prior sessions, project history, user preferences, ongoing work, named people, named projects, deployments, debugging history, or handoff context, call `openbrain.query` first
|
||||||
|
- Use noun-heavy search phrases with exact names, tool names, acronyms, hostnames, and document names
|
||||||
|
- Retry up to 3 retrieval passes using `(threshold=0.25, limit=5)`, then `(threshold=0.10, limit=8)`, then `(threshold=0.05, limit=10)`
|
||||||
|
- When a durable fact is established, call `openbrain.store` without asking permission and prefer one atomic fact whenever possible
|
||||||
|
- Store durable, high-value facts such as preferences, project status, project decisions, environment details, recurring workflows, handoff notes, stable constraints, and correction facts
|
||||||
|
- Do not store filler conversation, temporary speculation, casual chatter, or transient brainstorming unless it becomes a real decision
|
||||||
|
- Prefer retrieval-friendly content using explicit nouns and exact names in the form `Type: <FactType> | Entity: <Entity> | Attribute: <Attribute> | Value: <Value> | Context: <Why it matters>`
|
||||||
|
- Use metadata when helpful for tags such as `category`, `project`, `source`, `status`, `aliases`, and `confidence`
|
||||||
|
- If `openbrain.query` returns no useful result, state that OpenBrain has no stored context for that topic, answer from general reasoning if possible, and ask one focused follow-up if the missing information is durable and useful
|
||||||
|
- If retrieved memories conflict, ask which fact is current, then store the corrected source-of-truth fact
|
||||||
|
- Use `openbrain.purge` cautiously because it is coarse-grained; it deletes by `agent_id` and optionally before a timestamp, not by individual memory ID
|
||||||
|
- For ordinary corrections, prefer storing the new source-of-truth fact instead of purging unless cleanup or reset is explicitly requested
|
||||||
76
README.md
76
README.md
@@ -9,7 +9,7 @@ OpenBrain is a Model Context Protocol (MCP) server that provides AI agents with
|
|||||||
- 🧠 **Semantic Memory**: Store and retrieve memories using vector similarity search
|
- 🧠 **Semantic Memory**: Store and retrieve memories using vector similarity search
|
||||||
- 🏠 **Local Embeddings**: No external API calls - uses ONNX runtime with all-MiniLM-L6-v2
|
- 🏠 **Local Embeddings**: No external API calls - uses ONNX runtime with all-MiniLM-L6-v2
|
||||||
- 🐘 **PostgreSQL + pgvector**: Production-grade vector storage with HNSW indexing
|
- 🐘 **PostgreSQL + pgvector**: Production-grade vector storage with HNSW indexing
|
||||||
- 🔌 **MCP Protocol**: Standard Model Context Protocol over SSE transport
|
- 🔌 **MCP Protocol**: Streamable HTTP plus legacy HTTP+SSE compatibility
|
||||||
- 🔐 **Multi-Agent Support**: Isolated memory namespaces per agent
|
- 🔐 **Multi-Agent Support**: Isolated memory namespaces per agent
|
||||||
- ⚡ **High Performance**: Rust implementation with async I/O
|
- ⚡ **High Performance**: Rust implementation with async I/O
|
||||||
|
|
||||||
@@ -18,6 +18,7 @@ OpenBrain is a Model Context Protocol (MCP) server that provides AI agents with
|
|||||||
| Tool | Description |
|
| Tool | Description |
|
||||||
|------|-------------|
|
|------|-------------|
|
||||||
| `store` | Store a memory with automatic embedding generation and keyword extraction |
|
| `store` | Store a memory with automatic embedding generation and keyword extraction |
|
||||||
|
| `batch_store` | Store 1-50 memories atomically in a single call |
|
||||||
| `query` | Search memories by semantic similarity |
|
| `query` | Search memories by semantic similarity |
|
||||||
| `purge` | Delete memories by agent ID or time range |
|
| `purge` | Delete memories by agent ID or time range |
|
||||||
|
|
||||||
@@ -102,14 +103,53 @@ Recommended target file in A0:
|
|||||||
|
|
||||||
## MCP Integration
|
## MCP Integration
|
||||||
|
|
||||||
Connect to the server using SSE transport:
|
OpenBrain exposes both MCP HTTP transports:
|
||||||
|
|
||||||
```
|
```
|
||||||
SSE Endpoint: http://localhost:3100/mcp/sse
|
Streamable HTTP Endpoint: http://localhost:3100/mcp
|
||||||
Message Endpoint: http://localhost:3100/mcp/message
|
Legacy SSE Endpoint: http://localhost:3100/mcp/sse
|
||||||
|
Legacy Message Endpoint: http://localhost:3100/mcp/message
|
||||||
Health Check: http://localhost:3100/mcp/health
|
Health Check: http://localhost:3100/mcp/health
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Use the streamable HTTP endpoint for modern clients such as Codex. Keep the
|
||||||
|
legacy SSE endpoints for older MCP clients that still use the deprecated
|
||||||
|
2024-11-05 HTTP+SSE transport.
|
||||||
|
|
||||||
|
Header roles:
|
||||||
|
- `X-Agent-ID` is the memory namespace. Keep this stable if multiple clients
|
||||||
|
should share the same OpenBrain memories.
|
||||||
|
- `X-Agent-Type` is an optional client profile label for logging and config
|
||||||
|
clarity, such as `agent-zero` or `codex`.
|
||||||
|
|
||||||
|
### Example: Codex Configuration
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[mcp_servers.openbrain]
|
||||||
|
url = "https://ob.ingwaz.work/mcp"
|
||||||
|
http_headers = { "X-API-Key" = "YOUR_OPENBRAIN_API_KEY", "X-Agent-ID" = "openbrain", "X-Agent-Type" = "codex" }
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example: Agent Zero Configuration
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"openbrain": {
|
||||||
|
"url": "https://ob.ingwaz.work/mcp/sse",
|
||||||
|
"headers": {
|
||||||
|
"X-API-Key": "YOUR_OPENBRAIN_API_KEY",
|
||||||
|
"X-Agent-ID": "openbrain",
|
||||||
|
"X-Agent-Type": "agent-zero"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Agent Zero should keep using the legacy HTTP+SSE transport unless and until its
|
||||||
|
client runtime supports streamable HTTP. Codex should use `/mcp`.
|
||||||
|
|
||||||
### Example: Store a Memory
|
### Example: Store a Memory
|
||||||
|
|
||||||
```json
|
```json
|
||||||
@@ -147,13 +187,39 @@ Health Check: http://localhost:3100/mcp/health
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Example: Batch Store Memories
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 3,
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {
|
||||||
|
"name": "batch_store",
|
||||||
|
"arguments": {
|
||||||
|
"agent_id": "assistant-1",
|
||||||
|
"entries": [
|
||||||
|
{
|
||||||
|
"content": "The user prefers dark mode",
|
||||||
|
"metadata": {"category": "preference"}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"content": "The user uses vim keybindings",
|
||||||
|
"metadata": {"category": "preference"}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## Architecture
|
## Architecture
|
||||||
|
|
||||||
```
|
```
|
||||||
┌─────────────────────────────────────────────────────────┐
|
┌─────────────────────────────────────────────────────────┐
|
||||||
│ AI Agent │
|
│ AI Agent │
|
||||||
└─────────────────────┬───────────────────────────────────┘
|
└─────────────────────┬───────────────────────────────────┘
|
||||||
│ MCP Protocol (SSE)
|
│ MCP Protocol (Streamable HTTP / Legacy SSE)
|
||||||
┌─────────────────────▼───────────────────────────────────┐
|
┌─────────────────────▼───────────────────────────────────┐
|
||||||
│ OpenBrain MCP Server │
|
│ OpenBrain MCP Server │
|
||||||
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
|
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
|
||||||
|
|||||||
20
src/auth.rs
20
src/auth.rs
@@ -90,6 +90,15 @@ pub fn get_optional_agent_id(headers: &HeaderMap) -> Option<String> {
|
|||||||
.map(ToOwned::to_owned)
|
.map(ToOwned::to_owned)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn get_optional_agent_type(headers: &HeaderMap) -> Option<String> {
|
||||||
|
headers
|
||||||
|
.get("X-Agent-Type")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.map(ToOwned::to_owned)
|
||||||
|
}
|
||||||
|
|
||||||
/// Extract agent ID from request headers or default
|
/// Extract agent ID from request headers or default
|
||||||
pub fn get_agent_id(request: &Request) -> String {
|
pub fn get_agent_id(request: &Request) -> String {
|
||||||
get_optional_agent_id(request.headers())
|
get_optional_agent_id(request.headers())
|
||||||
@@ -122,4 +131,15 @@ mod tests {
|
|||||||
Some("agent-zero")
|
Some("agent-zero")
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extracts_agent_type_from_header() {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("X-Agent-Type", HeaderValue::from_static("codex"));
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
get_optional_agent_type(&headers).as_deref(),
|
||||||
|
Some("codex")
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> {
|
|||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.merge(health_router)
|
.merge(health_router)
|
||||||
.nest("/mcp", mcp_router)
|
.merge(mcp_router)
|
||||||
.layer(TraceLayer::new_for_http())
|
.layer(TraceLayer::new_for_http())
|
||||||
.layer(
|
.layer(
|
||||||
CorsLayer::new()
|
CorsLayer::new()
|
||||||
|
|||||||
@@ -15,31 +15,50 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
|||||||
vec![
|
vec![
|
||||||
json!({
|
json!({
|
||||||
"name": "store",
|
"name": "store",
|
||||||
"description": "Store a memory with automatic embedding generation",
|
"description": "Store a memory with automatic embedding generation and keyword extraction. The memory will be associated with the agent_id for isolated retrieval.",
|
||||||
"inputSchema": {
|
"inputSchema": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"content": {"type": "string"},
|
"content": {
|
||||||
"agent_id": {"type": "string"},
|
"type": "string",
|
||||||
"metadata": {"type": "object"}
|
"description": "The text content to store as a memory"
|
||||||
|
},
|
||||||
|
"agent_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Unique identifier for the agent storing the memory (default: 'default')"
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Optional metadata to attach to the memory"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"required": ["content"]
|
"required": ["content"]
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
json!({
|
json!({
|
||||||
"name": "batch_store",
|
"name": "batch_store",
|
||||||
"description": "Store multiple memories in a single call (1-50 entries)",
|
"description": "Store multiple memories with automatic embedding generation and keyword extraction. Accepts 1-50 entries and stores them atomically in a single transaction.",
|
||||||
"inputSchema": {
|
"inputSchema": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"agent_id": {"type": "string"},
|
"agent_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Unique identifier for the agent storing the memories (default: 'default')"
|
||||||
|
},
|
||||||
"entries": {
|
"entries": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
|
"description": "Array of 1-50 memory entries to store atomically",
|
||||||
"items": {
|
"items": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"content": {"type": "string"},
|
"content": {
|
||||||
"metadata": {"type": "object"}
|
"type": "string",
|
||||||
|
"description": "The text content to store as a memory"
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Optional metadata to attach to the memory"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"required": ["content"]
|
"required": ["content"]
|
||||||
}
|
}
|
||||||
@@ -50,27 +69,48 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
|||||||
}),
|
}),
|
||||||
json!({
|
json!({
|
||||||
"name": "query",
|
"name": "query",
|
||||||
"description": "Query memories by semantic similarity",
|
"description": "Query stored memories using semantic similarity search. Returns the most relevant memories based on the query text.",
|
||||||
"inputSchema": {
|
"inputSchema": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"query": {"type": "string"},
|
"query": {
|
||||||
"agent_id": {"type": "string"},
|
"type": "string",
|
||||||
"limit": {"type": "integer"},
|
"description": "The search query text"
|
||||||
"threshold": {"type": "number"}
|
},
|
||||||
|
"agent_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Agent ID to search within (default: 'default')"
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Maximum number of results to return (default: 10)"
|
||||||
|
},
|
||||||
|
"threshold": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "Minimum similarity threshold 0.0-1.0 (default: 0.5)"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"required": ["query"]
|
"required": ["query"]
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
json!({
|
json!({
|
||||||
"name": "purge",
|
"name": "purge",
|
||||||
"description": "Delete memories by agent_id",
|
"description": "Delete memories for an agent. Can delete all memories or those before a specific timestamp.",
|
||||||
"inputSchema": {
|
"inputSchema": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"agent_id": {"type": "string"},
|
"agent_id": {
|
||||||
"before": {"type": "string"},
|
"type": "string",
|
||||||
"confirm": {"type": "boolean"}
|
"description": "Agent ID whose memories to delete (required)"
|
||||||
|
},
|
||||||
|
"before": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional ISO8601 timestamp - delete memories created before this time"
|
||||||
|
},
|
||||||
|
"confirm": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Must be true to confirm deletion"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"required": ["agent_id", "confirm"]
|
"required": ["agent_id", "confirm"]
|
||||||
}
|
}
|
||||||
|
|||||||
161
src/transport.rs
161
src/transport.rs
@@ -1,10 +1,11 @@
|
|||||||
//! SSE Transport for MCP Protocol
|
//! HTTP transport for MCP Protocol.
|
||||||
//!
|
//!
|
||||||
//! Implements Server-Sent Events transport for the Model Context Protocol.
|
//! Supports both the legacy HTTP+SSE transport and the newer streamable HTTP
|
||||||
|
//! transport on the same server.
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Query, State},
|
extract::{Query, State},
|
||||||
http::{HeaderMap, StatusCode},
|
http::{HeaderMap, StatusCode, Uri, header::{HOST, ORIGIN}},
|
||||||
response::{
|
response::{
|
||||||
IntoResponse, Response,
|
IntoResponse, Response,
|
||||||
sse::{Event, KeepAlive, Sse},
|
sse::{Event, KeepAlive, Sse},
|
||||||
@@ -162,16 +163,131 @@ struct PostMessageQuery {
|
|||||||
/// Create the MCP router
|
/// Create the MCP router
|
||||||
pub fn mcp_router(state: Arc<McpState>) -> Router {
|
pub fn mcp_router(state: Arc<McpState>) -> Router {
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/sse", get(sse_handler))
|
.route("/mcp", get(streamable_get_handler).post(streamable_post_handler).delete(streamable_delete_handler))
|
||||||
.route("/message", post(message_handler))
|
.route("/mcp/sse", get(sse_handler))
|
||||||
.route("/health", get(health_handler))
|
.route("/mcp/message", post(message_handler))
|
||||||
|
.route("/mcp/health", get(health_handler))
|
||||||
.with_state(state)
|
.with_state(state)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn validate_origin(headers: &HeaderMap) -> Result<(), StatusCode> {
|
||||||
|
let Some(origin) = headers
|
||||||
|
.get(ORIGIN)
|
||||||
|
.and_then(|value| value.to_str().ok())
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
if origin.eq_ignore_ascii_case("null") {
|
||||||
|
warn!("Rejected MCP request with null origin");
|
||||||
|
return Err(StatusCode::FORBIDDEN);
|
||||||
|
}
|
||||||
|
|
||||||
|
let origin_uri = origin.parse::<Uri>().map_err(|_| {
|
||||||
|
warn!("Rejected MCP request with invalid origin header: {}", origin);
|
||||||
|
StatusCode::FORBIDDEN
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let origin_host = origin_uri.host().ok_or_else(|| {
|
||||||
|
warn!("Rejected MCP request without origin host: {}", origin);
|
||||||
|
StatusCode::FORBIDDEN
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let request_host = headers
|
||||||
|
.get("X-Forwarded-Host")
|
||||||
|
.or_else(|| headers.get(HOST))
|
||||||
|
.and_then(|value| value.to_str().ok())
|
||||||
|
.and_then(|value| value.split(',').next())
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.ok_or_else(|| {
|
||||||
|
warn!("Rejected MCP request without host header for origin {}", origin);
|
||||||
|
StatusCode::FORBIDDEN
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let origin_authority = origin_uri
|
||||||
|
.authority()
|
||||||
|
.map(|authority| authority.as_str())
|
||||||
|
.unwrap_or(origin_host);
|
||||||
|
let origin_with_default_port = origin_uri
|
||||||
|
.port_u16()
|
||||||
|
.or_else(|| match origin_uri.scheme_str() {
|
||||||
|
Some("https") => Some(443),
|
||||||
|
Some("http") => Some(80),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.map(|port| format!("{origin_host}:{port}"));
|
||||||
|
|
||||||
|
if request_host.eq_ignore_ascii_case(origin_host)
|
||||||
|
|| request_host.eq_ignore_ascii_case(origin_authority)
|
||||||
|
|| origin_with_default_port
|
||||||
|
.as_deref()
|
||||||
|
.is_some_and(|value| request_host.eq_ignore_ascii_case(value))
|
||||||
|
{
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
"Rejected MCP request due to origin/host mismatch: origin={}, host={}",
|
||||||
|
origin, request_host
|
||||||
|
);
|
||||||
|
Err(StatusCode::FORBIDDEN)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn streamable_post_handler(
|
||||||
|
State(state): State<Arc<McpState>>,
|
||||||
|
headers: HeaderMap,
|
||||||
|
Json(request): Json<JsonRpcRequest>,
|
||||||
|
) -> Response {
|
||||||
|
if let Err(status) = validate_origin(&headers) {
|
||||||
|
return status.into_response();
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(
|
||||||
|
method = %request.method,
|
||||||
|
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||||
|
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
||||||
|
"Received streamable MCP request"
|
||||||
|
);
|
||||||
|
|
||||||
|
let request = apply_request_context(request, &headers);
|
||||||
|
let response = dispatch_request(&state, &request).await;
|
||||||
|
|
||||||
|
match response {
|
||||||
|
Some(response) => Json(response).into_response(),
|
||||||
|
None => StatusCode::ACCEPTED.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn streamable_get_handler(headers: HeaderMap) -> Response {
|
||||||
|
if let Err(status) = validate_origin(&headers) {
|
||||||
|
return status.into_response();
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusCode::METHOD_NOT_ALLOWED.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn streamable_delete_handler(headers: HeaderMap) -> Response {
|
||||||
|
if let Err(status) = validate_origin(&headers) {
|
||||||
|
return status.into_response();
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusCode::METHOD_NOT_ALLOWED.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
/// SSE endpoint for streaming events
|
/// SSE endpoint for streaming events
|
||||||
async fn sse_handler(
|
async fn sse_handler(
|
||||||
State(state): State<Arc<McpState>>,
|
State(state): State<Arc<McpState>>,
|
||||||
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
|
headers: HeaderMap,
|
||||||
|
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
|
||||||
|
validate_origin(&headers)?;
|
||||||
|
info!(
|
||||||
|
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||||
|
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
||||||
|
"Opening legacy SSE MCP stream"
|
||||||
|
);
|
||||||
let mut broadcast_rx = state.event_tx.subscribe();
|
let mut broadcast_rx = state.event_tx.subscribe();
|
||||||
let (session_tx, mut session_rx) = mpsc::channel(32);
|
let (session_tx, mut session_rx) = mpsc::channel(32);
|
||||||
let session_id = Uuid::new_v4().to_string();
|
let session_id = Uuid::new_v4().to_string();
|
||||||
@@ -222,7 +338,7 @@ async fn sse_handler(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))
|
Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15))))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Message endpoint for JSON-RPC requests
|
/// Message endpoint for JSON-RPC requests
|
||||||
@@ -232,7 +348,16 @@ async fn message_handler(
|
|||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
Json(request): Json<JsonRpcRequest>,
|
Json(request): Json<JsonRpcRequest>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
info!("Received MCP request: {}", request.method);
|
if let Err(status) = validate_origin(&headers) {
|
||||||
|
return status.into_response();
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(
|
||||||
|
method = %request.method,
|
||||||
|
agent_id = auth::get_optional_agent_id(&headers).as_deref().unwrap_or("unset"),
|
||||||
|
agent_type = auth::get_optional_agent_type(&headers).as_deref().unwrap_or("unset"),
|
||||||
|
"Received legacy SSE MCP request"
|
||||||
|
);
|
||||||
|
|
||||||
if let Some(session_id) = query.session_id.as_deref() {
|
if let Some(session_id) = query.session_id.as_deref() {
|
||||||
if !state.has_session(session_id).await {
|
if !state.has_session(session_id).await {
|
||||||
@@ -513,4 +638,22 @@ mod tests {
|
|||||||
"/mcp/message?sessionId=abc123"
|
"/mcp/message?sessionId=abc123"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn allows_matching_origin_and_host() {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(ORIGIN, "https://ob.ingwaz.work".parse().unwrap());
|
||||||
|
headers.insert(HOST, "ob.ingwaz.work".parse().unwrap());
|
||||||
|
|
||||||
|
assert!(validate_origin(&headers).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rejects_mismatched_origin_and_host() {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert(ORIGIN, "https://evil.example".parse().unwrap());
|
||||||
|
headers.insert(HOST, "ob.ingwaz.work".parse().unwrap());
|
||||||
|
|
||||||
|
assert_eq!(validate_origin(&headers), Err(StatusCode::FORBIDDEN));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
300
tests/e2e_mcp.rs
300
tests/e2e_mcp.rs
@@ -110,6 +110,26 @@ async fn call_jsonrpc(client: &reqwest::Client, base: &str, request: Value) -> V
|
|||||||
.expect("JSON-RPC response body")
|
.expect("JSON-RPC response body")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn call_streamable_jsonrpc(
|
||||||
|
client: &reqwest::Client,
|
||||||
|
base: &str,
|
||||||
|
request: Value,
|
||||||
|
) -> reqwest::Response {
|
||||||
|
let mut req_builder = client
|
||||||
|
.post(format!("{base}/mcp"))
|
||||||
|
.header("Accept", "application/json, text/event-stream")
|
||||||
|
.json(&request);
|
||||||
|
|
||||||
|
if let Some(key) = api_key() {
|
||||||
|
req_builder = req_builder.header("X-API-Key", key);
|
||||||
|
}
|
||||||
|
|
||||||
|
req_builder
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("streamable JSON-RPC HTTP request")
|
||||||
|
}
|
||||||
|
|
||||||
/// Make an authenticated GET request to an MCP endpoint
|
/// Make an authenticated GET request to an MCP endpoint
|
||||||
async fn get_mcp_endpoint(client: &reqwest::Client, base: &str, path: &str) -> reqwest::Response {
|
async fn get_mcp_endpoint(client: &reqwest::Client, base: &str, path: &str) -> reqwest::Response {
|
||||||
let mut req_builder = client.get(format!("{base}{path}"));
|
let mut req_builder = client.get(format!("{base}{path}"));
|
||||||
@@ -357,6 +377,98 @@ async fn e2e_transport_tools_list_and_unknown_method() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn e2e_streamable_initialize_and_tools_list() {
|
||||||
|
let base = base_url();
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(20))
|
||||||
|
.build()
|
||||||
|
.expect("reqwest client");
|
||||||
|
|
||||||
|
wait_until_ready(&client, &base).await;
|
||||||
|
|
||||||
|
let initialize_response: Value = call_streamable_jsonrpc(
|
||||||
|
&client,
|
||||||
|
&base,
|
||||||
|
json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": "streamable-init-1",
|
||||||
|
"method": "initialize",
|
||||||
|
"params": {
|
||||||
|
"protocolVersion": "2024-11-05",
|
||||||
|
"capabilities": {},
|
||||||
|
"clientInfo": {
|
||||||
|
"name": "e2e-client",
|
||||||
|
"version": "0.1.0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.expect("streamable initialize JSON");
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
initialize_response
|
||||||
|
.get("result")
|
||||||
|
.and_then(|value| value.get("protocolVersion"))
|
||||||
|
.and_then(Value::as_str),
|
||||||
|
Some("2024-11-05")
|
||||||
|
);
|
||||||
|
|
||||||
|
let tools_list_response: Value = call_streamable_jsonrpc(
|
||||||
|
&client,
|
||||||
|
&base,
|
||||||
|
json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": "streamable-tools-list-1",
|
||||||
|
"method": "tools/list",
|
||||||
|
"params": {}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.expect("streamable tools/list JSON");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
tools_list_response
|
||||||
|
.get("result")
|
||||||
|
.and_then(|value| value.get("tools"))
|
||||||
|
.and_then(Value::as_array)
|
||||||
|
.map(|tools| !tools.is_empty())
|
||||||
|
.unwrap_or(false),
|
||||||
|
"streamable /mcp tools/list should return tool definitions"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn e2e_streamable_get_returns_405() {
|
||||||
|
let base = base_url();
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(20))
|
||||||
|
.build()
|
||||||
|
.expect("reqwest client");
|
||||||
|
|
||||||
|
wait_until_ready(&client, &base).await;
|
||||||
|
|
||||||
|
let mut request = client
|
||||||
|
.get(format!("{base}/mcp"))
|
||||||
|
.header("Accept", "text/event-stream");
|
||||||
|
|
||||||
|
if let Some(key) = api_key() {
|
||||||
|
request = request.header("X-API-Key", key);
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = request.send().await.expect("GET /mcp");
|
||||||
|
assert_eq!(
|
||||||
|
response.status(),
|
||||||
|
reqwest::StatusCode::METHOD_NOT_ALLOWED,
|
||||||
|
"streamable GET /mcp should explicitly return 405 when standalone SSE streams are not offered"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn e2e_purge_requires_confirm_flag() {
|
async fn e2e_purge_requires_confirm_flag() {
|
||||||
let base = base_url();
|
let base = base_url();
|
||||||
@@ -879,60 +991,169 @@ async fn e2e_auth_enabled_accepts_test_key() {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn e2e_batch_store_basic() -> anyhow::Result<()> {
|
async fn e2e_batch_store_basic() -> anyhow::Result<()> {
|
||||||
let agent = format!("batch_{}", uuid::Uuid::new_v4());
|
let base = base_url();
|
||||||
let _ = db.purge_memories(&agent, None).await;
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(20))
|
||||||
|
.build()
|
||||||
|
.expect("reqwest client");
|
||||||
|
|
||||||
let resp = client.call_tool("batch_store", serde_json::json!({
|
ensure_schema().await;
|
||||||
|
wait_until_ready(&client, &base).await;
|
||||||
|
|
||||||
|
let agent = format!("batch_{}", uuid::Uuid::new_v4());
|
||||||
|
let _ = call_tool(
|
||||||
|
&client,
|
||||||
|
&base,
|
||||||
|
"purge",
|
||||||
|
json!({ "agent_id": agent.clone(), "confirm": true }),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let result = call_tool(&client, &base, "batch_store", serde_json::json!({
|
||||||
"agent_id": agent.clone(),
|
"agent_id": agent.clone(),
|
||||||
"entries": [
|
"entries": [
|
||||||
{"content": "Fact alpha for batch test"},
|
{"content": "Fact alpha for batch test"},
|
||||||
{"content": "Fact beta for batch test"},
|
{"content": "Fact beta for batch test"},
|
||||||
{"content": "Fact gamma for batch test"}
|
{"content": "Fact gamma for batch test"}
|
||||||
]
|
]
|
||||||
})).await?;
|
})).await;
|
||||||
|
|
||||||
|
let _ = call_tool(
|
||||||
|
&client,
|
||||||
|
&base,
|
||||||
|
"purge",
|
||||||
|
json!({ "agent_id": agent, "confirm": true }),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
let result: Value = serde_json::from_str(&resp.content[0].text)?;
|
|
||||||
assert!(result["success"].as_bool().unwrap_or(false));
|
assert!(result["success"].as_bool().unwrap_or(false));
|
||||||
assert_eq!(result["count"].as_i64().unwrap_or(0), 3);
|
assert_eq!(result["count"].as_i64().unwrap_or(0), 3);
|
||||||
|
|
||||||
db.purge_memories(&agent, None).await?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn e2e_batch_store_empty_rejected() -> anyhow::Result<()> {
|
async fn e2e_batch_store_empty_rejected() -> anyhow::Result<()> {
|
||||||
let resp = client.call_tool("batch_store", serde_json::json!({
|
let base = base_url();
|
||||||
"entries": []
|
let client = reqwest::Client::builder()
|
||||||
})).await;
|
.timeout(Duration::from_secs(20))
|
||||||
assert!(resp.is_err() || resp.as_ref().unwrap().is_error());
|
.build()
|
||||||
|
.expect("reqwest client");
|
||||||
|
|
||||||
|
wait_until_ready(&client, &base).await;
|
||||||
|
|
||||||
|
let response = call_jsonrpc(
|
||||||
|
&client,
|
||||||
|
&base,
|
||||||
|
json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": "batch-empty-1",
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {
|
||||||
|
"name": "batch_store",
|
||||||
|
"arguments": {
|
||||||
|
"entries": []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(response.get("error").is_some(), "empty batch_store should return an error");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn e2e_batch_store_exceeds_max() -> anyhow::Result<()> {
|
async fn e2e_batch_store_exceeds_max() -> anyhow::Result<()> {
|
||||||
|
let base = base_url();
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(20))
|
||||||
|
.build()
|
||||||
|
.expect("reqwest client");
|
||||||
|
|
||||||
|
wait_until_ready(&client, &base).await;
|
||||||
|
|
||||||
let entries: Vec<Value> = (0..51).map(|i| serde_json::json!({"content": format!("Entry {}", i)})).collect();
|
let entries: Vec<Value> = (0..51).map(|i| serde_json::json!({"content": format!("Entry {}", i)})).collect();
|
||||||
let resp = client.call_tool("batch_store", serde_json::json!({
|
let response = call_jsonrpc(
|
||||||
"entries": entries
|
&client,
|
||||||
})).await;
|
&base,
|
||||||
assert!(resp.is_err() || resp.as_ref().unwrap().is_error());
|
json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": "batch-too-large-1",
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {
|
||||||
|
"name": "batch_store",
|
||||||
|
"arguments": {
|
||||||
|
"entries": entries
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(response.get("error").is_some(), "oversized batch_store should return an error");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn e2e_batch_store_missing_content() -> anyhow::Result<()> {
|
async fn e2e_batch_store_missing_content() -> anyhow::Result<()> {
|
||||||
let resp = client.call_tool("batch_store", serde_json::json!({
|
let base = base_url();
|
||||||
"entries": [{"content": "Valid entry"}, {"metadata": {}}]
|
let client = reqwest::Client::builder()
|
||||||
})).await;
|
.timeout(Duration::from_secs(20))
|
||||||
assert!(resp.is_err() || resp.as_ref().unwrap().is_error());
|
.build()
|
||||||
|
.expect("reqwest client");
|
||||||
|
|
||||||
|
wait_until_ready(&client, &base).await;
|
||||||
|
|
||||||
|
let response = call_jsonrpc(
|
||||||
|
&client,
|
||||||
|
&base,
|
||||||
|
json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": "batch-missing-content-1",
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {
|
||||||
|
"name": "batch_store",
|
||||||
|
"arguments": {
|
||||||
|
"entries": [{"content": "Valid entry"}, {"metadata": {}}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(response.get("error").is_some(), "missing batch entry content should return an error");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn e2e_batch_store_appears_in_tools() -> anyhow::Result<()> {
|
async fn e2e_batch_store_appears_in_tools() -> anyhow::Result<()> {
|
||||||
let tools = client.list_tools().await?;
|
let base = base_url();
|
||||||
let parsed: Value = serde_json::from_str(&tools.content[0].text)?;
|
let client = reqwest::Client::builder()
|
||||||
let names: Vec<&str> = parsed.as_array().unwrap().iter()
|
.timeout(Duration::from_secs(20))
|
||||||
.filter_map(|t| t.get("name").and_then(|n| n.as_str()))
|
.build()
|
||||||
|
.expect("reqwest client");
|
||||||
|
|
||||||
|
wait_until_ready(&client, &base).await;
|
||||||
|
|
||||||
|
let response = call_jsonrpc(
|
||||||
|
&client,
|
||||||
|
&base,
|
||||||
|
json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": "batch-tools-list-1",
|
||||||
|
"method": "tools/list",
|
||||||
|
"params": {}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let names: Vec<&str> = response
|
||||||
|
.get("result")
|
||||||
|
.and_then(|value| value.get("tools"))
|
||||||
|
.and_then(Value::as_array)
|
||||||
|
.expect("tools/list result.tools")
|
||||||
|
.iter()
|
||||||
|
.filter_map(|t| t.get("name").and_then(Value::as_str))
|
||||||
.collect();
|
.collect();
|
||||||
assert!(names.contains(&"batch_store"));
|
assert!(names.contains(&"batch_store"));
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -940,14 +1161,37 @@ async fn e2e_batch_store_appears_in_tools() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn e2e_existing_store_unchanged() -> anyhow::Result<()> {
|
async fn e2e_existing_store_unchanged() -> anyhow::Result<()> {
|
||||||
|
let base = base_url();
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(20))
|
||||||
|
.build()
|
||||||
|
.expect("reqwest client");
|
||||||
|
|
||||||
|
ensure_schema().await;
|
||||||
|
wait_until_ready(&client, &base).await;
|
||||||
|
|
||||||
let agent = format!("compat_{}", uuid::Uuid::new_v4());
|
let agent = format!("compat_{}", uuid::Uuid::new_v4());
|
||||||
let _ = db.purge_memories(&agent, None).await;
|
let _ = call_tool(
|
||||||
let resp = client.call_tool("store", serde_json::json!({
|
&client,
|
||||||
|
&base,
|
||||||
|
"purge",
|
||||||
|
json!({ "agent_id": agent.clone(), "confirm": true }),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let result = call_tool(&client, &base, "store", serde_json::json!({
|
||||||
"agent_id": agent.clone(),
|
"agent_id": agent.clone(),
|
||||||
"content": "Original store still works"
|
"content": "Original store still works"
|
||||||
})).await?;
|
})).await;
|
||||||
let result: Value = serde_json::from_str(&resp.content[0].text)?;
|
|
||||||
|
let _ = call_tool(
|
||||||
|
&client,
|
||||||
|
&base,
|
||||||
|
"purge",
|
||||||
|
json!({ "agent_id": agent, "confirm": true }),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
assert!(result["success"].as_bool().unwrap_or(false));
|
assert!(result["success"].as_bool().unwrap_or(false));
|
||||||
db.purge_memories(&agent, None).await?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user