mirror of
https://gitea.ingwaz.work/Ingwaz/openbrain-mcp.git
synced 2026-03-31 14:49:06 +00:00
Initial public release
This commit is contained in:
28
.env.example
Normal file
28
.env.example
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# OpenBrain MCP Server Configuration
|
||||||
|
# Copy this to .env and fill in your values
|
||||||
|
|
||||||
|
# Server Configuration
|
||||||
|
OPENBRAIN__SERVER__HOST=0.0.0.0
|
||||||
|
OPENBRAIN__SERVER__PORT=3100
|
||||||
|
|
||||||
|
# Database Configuration (PostgreSQL with pgvector)
|
||||||
|
# This role should own the OpenBrain database objects that migrations manage.
|
||||||
|
OPENBRAIN__DATABASE__HOST=localhost
|
||||||
|
OPENBRAIN__DATABASE__PORT=5432
|
||||||
|
OPENBRAIN__DATABASE__NAME=openbrain
|
||||||
|
OPENBRAIN__DATABASE__USER=openbrain_svc
|
||||||
|
OPENBRAIN__DATABASE__PASSWORD=your_secure_password_here
|
||||||
|
OPENBRAIN__DATABASE__POOL_SIZE=10
|
||||||
|
|
||||||
|
# Embedding Configuration
|
||||||
|
# Path to ONNX model directory (all-MiniLM-L6-v2)
|
||||||
|
OPENBRAIN__EMBEDDING__MODEL_PATH=models/all-MiniLM-L6-v2
|
||||||
|
OPENBRAIN__EMBEDDING__DIMENSION=384
|
||||||
|
|
||||||
|
# Authentication (optional)
|
||||||
|
OPENBRAIN__AUTH__ENABLED=false
|
||||||
|
# Comma-separated list of API keys
|
||||||
|
# OPENBRAIN__AUTH__API_KEYS=key1,key2,key3
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
RUST_LOG=info,openbrain_mcp=debug
|
||||||
213
.gitea/deploy.sh
Executable file
213
.gitea/deploy.sh
Executable file
@@ -0,0 +1,213 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# OpenBrain MCP Deployment Script
|
||||||
|
# Deploys the OpenBrain MCP server to the VPS
|
||||||
|
#
|
||||||
|
# Usage: ./deploy.sh [options]
|
||||||
|
# Options:
|
||||||
|
# --build-local Build on local machine (requires cross-compilation)
|
||||||
|
# --build-remote Build on VPS (default)
|
||||||
|
# --skip-model Skip model download
|
||||||
|
# --restart-only Only restart the service
|
||||||
|
#
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
VPS_HOST="${VPS_HOST:-}"
|
||||||
|
VPS_USER="${VPS_USER:-root}"
|
||||||
|
DEPLOY_DIR="/opt/openbrain-mcp"
|
||||||
|
SERVICE_NAME="openbrain-mcp"
|
||||||
|
SSH_KEY="${SSH_KEY:-/tmp/id_ed25519}"
|
||||||
|
|
||||||
|
# Colors for output
|
||||||
|
RED='\033[0;31m'
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
YELLOW='\033[1;33m'
|
||||||
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
|
log_info() { echo -e "${GREEN}[INFO]${NC} $1"; }
|
||||||
|
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
||||||
|
log_error() { echo -e "${RED}[ERROR]${NC} $1"; }
|
||||||
|
|
||||||
|
# Parse arguments
|
||||||
|
BUILD_REMOTE=true
|
||||||
|
SKIP_MODEL=false
|
||||||
|
RESTART_ONLY=false
|
||||||
|
|
||||||
|
for arg in "$@"; do
|
||||||
|
case $arg in
|
||||||
|
--build-local) BUILD_REMOTE=false ;;
|
||||||
|
--build-remote) BUILD_REMOTE=true ;;
|
||||||
|
--skip-model) SKIP_MODEL=true ;;
|
||||||
|
--restart-only) RESTART_ONLY=true ;;
|
||||||
|
*) log_error "Unknown argument: $arg"; exit 1 ;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
# Get script directory (where .gitea folder is)
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
|
||||||
|
|
||||||
|
log_info "Project root: $PROJECT_ROOT"
|
||||||
|
if [ -z "$VPS_HOST" ]; then
|
||||||
|
log_error "VPS_HOST is required. Export VPS_HOST before running deploy.sh"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
log_info "Deploying to: $VPS_USER@$VPS_HOST:$DEPLOY_DIR"
|
||||||
|
|
||||||
|
# SSH command helper
|
||||||
|
ssh_cmd() {
|
||||||
|
ssh -i "$SSH_KEY" -o StrictHostKeyChecking=no "$VPS_USER@$VPS_HOST" "$@"
|
||||||
|
}
|
||||||
|
|
||||||
|
scp_cmd() {
|
||||||
|
scp -i "$SSH_KEY" -o StrictHostKeyChecking=no "$@"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Restart only mode
|
||||||
|
if [ "$RESTART_ONLY" = true ]; then
|
||||||
|
log_info "Restarting service only..."
|
||||||
|
ssh_cmd "systemctl restart $SERVICE_NAME"
|
||||||
|
ssh_cmd "systemctl status $SERVICE_NAME --no-pager"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Step 1: Create deployment directory on VPS
|
||||||
|
log_info "Creating deployment directory on VPS..."
|
||||||
|
ssh_cmd "mkdir -p $DEPLOY_DIR/{src,models,logs,lib,.gitea}"
|
||||||
|
|
||||||
|
# Step 2: Sync source code to VPS
|
||||||
|
log_info "Syncing source code to VPS..."
|
||||||
|
rsync -avz --delete \
|
||||||
|
-e "ssh -i $SSH_KEY -o StrictHostKeyChecking=no" \
|
||||||
|
--exclude 'target/' \
|
||||||
|
--exclude '.git/' \
|
||||||
|
--exclude '.a0proj/' \
|
||||||
|
--exclude 'models/' \
|
||||||
|
--exclude '*.md' \
|
||||||
|
"$PROJECT_ROOT/" \
|
||||||
|
"$VPS_USER@$VPS_HOST:$DEPLOY_DIR/"
|
||||||
|
|
||||||
|
# Step 3: Copy .env if it doesn't exist on VPS
|
||||||
|
if ! ssh_cmd "test -f $DEPLOY_DIR/.env"; then
|
||||||
|
log_warn ".env not found on VPS. Copying .env.example..."
|
||||||
|
ssh_cmd "cp $DEPLOY_DIR/.env.example $DEPLOY_DIR/.env"
|
||||||
|
log_warn "Please edit $DEPLOY_DIR/.env on VPS with actual credentials!"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Step 4: Download model if needed
|
||||||
|
if [ "$SKIP_MODEL" = false ]; then
|
||||||
|
log_info "Checking/downloading embedding model..."
|
||||||
|
ssh_cmd "bash $DEPLOY_DIR/.gitea/download-model.sh"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Step 5: Build on VPS
|
||||||
|
if [ "$BUILD_REMOTE" = true ]; then
|
||||||
|
log_info "Building on VPS (this may take a while on first run)..."
|
||||||
|
ssh_cmd "cd $DEPLOY_DIR && \
|
||||||
|
source ~/.cargo/env 2>/dev/null || true && \
|
||||||
|
cargo build --release 2>&1"
|
||||||
|
else
|
||||||
|
log_error "Local cross-compilation not yet implemented"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Step 5b: Install the built binary where systemd expects it
|
||||||
|
log_info "Installing built binary..."
|
||||||
|
ssh_cmd "cp $DEPLOY_DIR/target/release/openbrain-mcp $DEPLOY_DIR/openbrain-mcp && chmod +x $DEPLOY_DIR/openbrain-mcp"
|
||||||
|
|
||||||
|
# Step 5c: Bootstrap runtime prerequisites
|
||||||
|
log_info "Bootstrapping runtime prerequisites..."
|
||||||
|
ssh -i "$SSH_KEY" -o StrictHostKeyChecking=no "$VPS_USER@$VPS_HOST" \
|
||||||
|
"DEPLOY_DIR=$DEPLOY_DIR SERVICE_USER=openbrain SERVICE_GROUP=openbrain ORT_VERSION=1.24.3 bash -s" <<'EOS'
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
DEPLOY_DIR="${DEPLOY_DIR:-/opt/openbrain-mcp}"
|
||||||
|
SERVICE_USER="${SERVICE_USER:-openbrain}"
|
||||||
|
SERVICE_GROUP="${SERVICE_GROUP:-openbrain}"
|
||||||
|
ORT_VERSION="${ORT_VERSION:-1.24.3}"
|
||||||
|
|
||||||
|
if command -v apt-get >/dev/null 2>&1; then
|
||||||
|
export DEBIAN_FRONTEND=noninteractive
|
||||||
|
apt-get update
|
||||||
|
apt-get install -y --no-install-recommends ca-certificates curl tar libssl3
|
||||||
|
fi
|
||||||
|
|
||||||
|
if ! getent group "$SERVICE_GROUP" >/dev/null 2>&1; then
|
||||||
|
groupadd --system "$SERVICE_GROUP"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if ! id -u "$SERVICE_USER" >/dev/null 2>&1; then
|
||||||
|
useradd --system --gid "$SERVICE_GROUP" --home "$DEPLOY_DIR" --shell /usr/sbin/nologin "$SERVICE_USER"
|
||||||
|
fi
|
||||||
|
|
||||||
|
install -d -m 0755 "$DEPLOY_DIR" "$DEPLOY_DIR/models" "$DEPLOY_DIR/logs" "$DEPLOY_DIR/lib"
|
||||||
|
|
||||||
|
ARCH="$(uname -m)"
|
||||||
|
case "$ARCH" in
|
||||||
|
x86_64) ORT_ARCH="x64" ;;
|
||||||
|
aarch64|arm64) ORT_ARCH="aarch64" ;;
|
||||||
|
*) echo "Unsupported arch: $ARCH"; exit 1 ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
if [[ ! -f "$DEPLOY_DIR/lib/libonnxruntime.so" ]]; then
|
||||||
|
TMP_DIR="$(mktemp -d)"
|
||||||
|
ORT_TGZ="onnxruntime-linux-${ORT_ARCH}-${ORT_VERSION}.tgz"
|
||||||
|
ORT_URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/${ORT_TGZ}"
|
||||||
|
curl -fL "$ORT_URL" -o "$TMP_DIR/$ORT_TGZ"
|
||||||
|
tar -xzf "$TMP_DIR/$ORT_TGZ" -C "$TMP_DIR"
|
||||||
|
ORT_ROOT="$TMP_DIR/onnxruntime-linux-${ORT_ARCH}-${ORT_VERSION}"
|
||||||
|
cp "$ORT_ROOT/lib/libonnxruntime.so" "$DEPLOY_DIR/lib/libonnxruntime.so"
|
||||||
|
cp "$ORT_ROOT/lib/libonnxruntime.so.${ORT_VERSION}" "$DEPLOY_DIR/lib/libonnxruntime.so.${ORT_VERSION}" || true
|
||||||
|
rm -rf "$TMP_DIR"
|
||||||
|
fi
|
||||||
|
|
||||||
|
ENV_FILE="$DEPLOY_DIR/.env"
|
||||||
|
if [[ ! -f "$ENV_FILE" ]]; then
|
||||||
|
if [[ -f "$DEPLOY_DIR/.env.example" ]]; then
|
||||||
|
cp "$DEPLOY_DIR/.env.example" "$ENV_FILE"
|
||||||
|
else
|
||||||
|
touch "$ENV_FILE"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
upsert_env() {
|
||||||
|
local key="$1"
|
||||||
|
local value="$2"
|
||||||
|
if grep -qE "^${key}=" "$ENV_FILE"; then
|
||||||
|
sed -i "s|^${key}=.*|${key}=${value}|" "$ENV_FILE"
|
||||||
|
else
|
||||||
|
printf '%s=%s\n' "$key" "$value" >> "$ENV_FILE"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
upsert_env "OPENBRAIN__EMBEDDING__MODEL_PATH" "$DEPLOY_DIR/models/all-MiniLM-L6-v2"
|
||||||
|
upsert_env "ORT_DYLIB_PATH" "$DEPLOY_DIR/lib/libonnxruntime.so"
|
||||||
|
upsert_env "OPENBRAIN__SERVER__HOST" "0.0.0.0"
|
||||||
|
|
||||||
|
chmod +x "$DEPLOY_DIR/openbrain-mcp" "$DEPLOY_DIR/.gitea/download-model.sh"
|
||||||
|
chown -R "$SERVICE_USER:$SERVICE_GROUP" "$DEPLOY_DIR"
|
||||||
|
EOS
|
||||||
|
|
||||||
|
# Step 5d: Run database migrations with the newly deployed binary
|
||||||
|
log_info "Running database migrations..."
|
||||||
|
ssh_cmd "cd $DEPLOY_DIR && ./openbrain-mcp migrate"
|
||||||
|
|
||||||
|
# Step 6: Install systemd service
|
||||||
|
log_info "Installing systemd service..."
|
||||||
|
scp_cmd "$SCRIPT_DIR/openbrain.service" "$VPS_USER@$VPS_HOST:/etc/systemd/system/$SERVICE_NAME.service"
|
||||||
|
ssh_cmd "systemctl daemon-reload"
|
||||||
|
ssh_cmd "systemctl enable $SERVICE_NAME"
|
||||||
|
|
||||||
|
# Step 7: Restart service
|
||||||
|
log_info "Restarting service..."
|
||||||
|
ssh_cmd "systemctl restart $SERVICE_NAME"
|
||||||
|
sleep 2
|
||||||
|
|
||||||
|
# Step 8: Check status
|
||||||
|
log_info "Checking service status..."
|
||||||
|
ssh_cmd "systemctl status $SERVICE_NAME --no-pager" || true
|
||||||
|
|
||||||
|
log_info "Deployment complete!"
|
||||||
|
log_info "Service URL: http://$VPS_HOST:3100/mcp/health"
|
||||||
92
.gitea/download-model.sh
Executable file
92
.gitea/download-model.sh
Executable file
@@ -0,0 +1,92 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# Download ONNX embedding model for OpenBrain MCP
|
||||||
|
# Downloads all-MiniLM-L6-v2 from Hugging Face
|
||||||
|
#
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
DEPLOY_DIR="${DEPLOY_DIR:-/opt/openbrain-mcp}"
|
||||||
|
MODEL_DIR="$DEPLOY_DIR/models/all-MiniLM-L6-v2"
|
||||||
|
MODEL_NAME="sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
|
||||||
|
# Colors
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
YELLOW='\033[1;33m'
|
||||||
|
NC='\033[0m'
|
||||||
|
|
||||||
|
log_info() { echo -e "${GREEN}[INFO]${NC} $1"; }
|
||||||
|
log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; }
|
||||||
|
|
||||||
|
# Check if model already exists
|
||||||
|
if [ -f "$MODEL_DIR/model.onnx" ] && [ -f "$MODEL_DIR/tokenizer.json" ]; then
|
||||||
|
log_info "Model already exists at $MODEL_DIR"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
log_info "Downloading embedding model to $MODEL_DIR..."
|
||||||
|
mkdir -p "$MODEL_DIR"
|
||||||
|
|
||||||
|
# Method 1: Try using huggingface-cli if available
|
||||||
|
if command -v huggingface-cli &> /dev/null; then
|
||||||
|
log_info "Using huggingface-cli to download model..."
|
||||||
|
huggingface-cli download "$MODEL_NAME" \
|
||||||
|
--local-dir "$MODEL_DIR" \
|
||||||
|
--include "*.onnx" "*.json" "*.txt" \
|
||||||
|
--exclude "*.bin" "*.safetensors" "*.h5"
|
||||||
|
else
|
||||||
|
# Method 2: Direct download from Hugging Face
|
||||||
|
log_info "Downloading directly from Hugging Face..."
|
||||||
|
|
||||||
|
BASE_URL="https://huggingface.co/$MODEL_NAME/resolve/main"
|
||||||
|
|
||||||
|
# Download ONNX model (we need the optimized one)
|
||||||
|
# First try the onnx directory
|
||||||
|
ONNX_URL="https://huggingface.co/$MODEL_NAME/resolve/main/onnx/model.onnx"
|
||||||
|
|
||||||
|
log_info "Downloading model.onnx..."
|
||||||
|
if ! curl -fSL "$ONNX_URL" -o "$MODEL_DIR/model.onnx" 2>/dev/null; then
|
||||||
|
# Fallback: convert from pytorch (requires python)
|
||||||
|
log_warn "ONNX model not found, will need to convert from PyTorch..."
|
||||||
|
log_warn "Installing optimum for ONNX export..."
|
||||||
|
pip install --quiet optimum[exporters] onnx onnxruntime
|
||||||
|
|
||||||
|
python3 << PYEOF
|
||||||
|
from optimum.onnxruntime import ORTModelForFeatureExtraction
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
model = ORTModelForFeatureExtraction.from_pretrained("$MODEL_NAME", export=True)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("$MODEL_NAME")
|
||||||
|
|
||||||
|
model.save_pretrained("$MODEL_DIR")
|
||||||
|
tokenizer.save_pretrained("$MODEL_DIR")
|
||||||
|
print("Model exported to ONNX successfully!")
|
||||||
|
PYEOF
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Download tokenizer files
|
||||||
|
log_info "Downloading tokenizer.json..."
|
||||||
|
curl -fSL "$BASE_URL/tokenizer.json" -o "$MODEL_DIR/tokenizer.json" 2>/dev/null || true
|
||||||
|
|
||||||
|
log_info "Downloading tokenizer_config.json..."
|
||||||
|
curl -fSL "$BASE_URL/tokenizer_config.json" -o "$MODEL_DIR/tokenizer_config.json" 2>/dev/null || true
|
||||||
|
|
||||||
|
log_info "Downloading config.json..."
|
||||||
|
curl -fSL "$BASE_URL/config.json" -o "$MODEL_DIR/config.json" 2>/dev/null || true
|
||||||
|
|
||||||
|
log_info "Downloading vocab.txt..."
|
||||||
|
curl -fSL "$BASE_URL/vocab.txt" -o "$MODEL_DIR/vocab.txt" 2>/dev/null || true
|
||||||
|
|
||||||
|
log_info "Downloading special_tokens_map.json..."
|
||||||
|
curl -fSL "$BASE_URL/special_tokens_map.json" -o "$MODEL_DIR/special_tokens_map.json" 2>/dev/null || true
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Verify download
|
||||||
|
if [ -f "$MODEL_DIR/model.onnx" ]; then
|
||||||
|
MODEL_SIZE=$(du -h "$MODEL_DIR/model.onnx" | cut -f1)
|
||||||
|
log_info "Model downloaded successfully! Size: $MODEL_SIZE"
|
||||||
|
ls -la "$MODEL_DIR/"
|
||||||
|
else
|
||||||
|
log_warn "Warning: model.onnx not found after download"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
31
.gitea/openbrain.service
Normal file
31
.gitea/openbrain.service
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
[Unit]
|
||||||
|
Description=OpenBrain MCP Server - Vector Memory for AI Agents
|
||||||
|
After=network-online.target postgresql.service
|
||||||
|
Wants=network-online.target postgresql.service
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=simple
|
||||||
|
User=openbrain
|
||||||
|
Group=openbrain
|
||||||
|
WorkingDirectory=/opt/openbrain-mcp
|
||||||
|
EnvironmentFile=/opt/openbrain-mcp/.env
|
||||||
|
ExecStart=/opt/openbrain-mcp/openbrain-mcp
|
||||||
|
Restart=on-failure
|
||||||
|
RestartSec=5
|
||||||
|
StandardOutput=journal
|
||||||
|
StandardError=journal
|
||||||
|
SyslogIdentifier=openbrain-mcp
|
||||||
|
|
||||||
|
# Security hardening
|
||||||
|
NoNewPrivileges=true
|
||||||
|
PrivateTmp=true
|
||||||
|
ProtectSystem=strict
|
||||||
|
ProtectHome=true
|
||||||
|
ReadWritePaths=/opt/openbrain-mcp /opt/openbrain-mcp/logs /opt/openbrain-mcp/models /opt/openbrain-mcp/lib
|
||||||
|
|
||||||
|
# Resource limits
|
||||||
|
LimitNOFILE=65535
|
||||||
|
MemoryMax=1G
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
188
.gitea/workflows/ci-cd.yaml
Normal file
188
.gitea/workflows/ci-cd.yaml
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
name: OpenBrain MCP Build and Deploy
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- '**'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build-and-deploy:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
env:
|
||||||
|
CARGO_TERM_COLOR: always
|
||||||
|
VPS_HOST: ${{ vars.VPS_HOST }}
|
||||||
|
VPS_USER: ${{ vars.VPS_USER }}
|
||||||
|
DEPLOY_DIR: /opt/openbrain-mcp
|
||||||
|
SERVICE_NAME: openbrain-mcp
|
||||||
|
steps:
|
||||||
|
- name: Install prerequisites
|
||||||
|
run: |
|
||||||
|
set -euxo pipefail
|
||||||
|
if command -v apt-get >/dev/null 2>&1; then
|
||||||
|
if command -v sudo >/dev/null 2>&1; then SUDO=sudo; else SUDO=; fi
|
||||||
|
$SUDO apt-get update
|
||||||
|
$SUDO apt-get install -y --no-install-recommends git ca-certificates curl
|
||||||
|
elif command -v apk >/dev/null 2>&1; then
|
||||||
|
apk add --no-cache git ca-certificates curl
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Checkout repository
|
||||||
|
run: |
|
||||||
|
set -euxo pipefail
|
||||||
|
git clone --depth 1 "https://${{ github.token }}@gitea.ingwaz.work/${{ github.repository }}.git" .
|
||||||
|
git fetch origin "${{ github.ref }}"
|
||||||
|
git checkout FETCH_HEAD
|
||||||
|
|
||||||
|
- name: Install build dependencies
|
||||||
|
run: |
|
||||||
|
set -euxo pipefail
|
||||||
|
if command -v sudo >/dev/null 2>&1; then SUDO=sudo; else SUDO=; fi
|
||||||
|
$SUDO apt-get install -y --no-install-recommends build-essential pkg-config libssl-dev openssh-client
|
||||||
|
|
||||||
|
- name: Install Rust toolchain
|
||||||
|
run: |
|
||||||
|
set -euxo pipefail
|
||||||
|
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
|
||||||
|
. "$HOME/.cargo/env"
|
||||||
|
rustc --version
|
||||||
|
cargo --version
|
||||||
|
|
||||||
|
- name: CI checks
|
||||||
|
run: |
|
||||||
|
set -euxo pipefail
|
||||||
|
. "$HOME/.cargo/env"
|
||||||
|
cargo check
|
||||||
|
cargo test --no-run
|
||||||
|
|
||||||
|
- name: Build release
|
||||||
|
run: |
|
||||||
|
set -euxo pipefail
|
||||||
|
. "$HOME/.cargo/env"
|
||||||
|
cargo build --release
|
||||||
|
test -x target/release/openbrain-mcp
|
||||||
|
|
||||||
|
- name: Setup SSH auth
|
||||||
|
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master'
|
||||||
|
run: |
|
||||||
|
set -euxo pipefail
|
||||||
|
: "${VPS_HOST:?Set repository variable VPS_HOST}"
|
||||||
|
: "${VPS_USER:=root}"
|
||||||
|
install -d -m 700 "$HOME/.ssh"
|
||||||
|
printf '%s\n' "${{ secrets.VPS_SSH_KEY }}" > "$HOME/.ssh/deploy_key"
|
||||||
|
chmod 600 "$HOME/.ssh/deploy_key"
|
||||||
|
ssh-keyscan -H "$VPS_HOST" >> "$HOME/.ssh/known_hosts"
|
||||||
|
|
||||||
|
- name: Deploy artifacts
|
||||||
|
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master'
|
||||||
|
run: |
|
||||||
|
set -euxo pipefail
|
||||||
|
: "${VPS_HOST:?Set repository variable VPS_HOST}"
|
||||||
|
: "${VPS_USER:=root}"
|
||||||
|
SSH="ssh -i $HOME/.ssh/deploy_key -o IdentitiesOnly=yes"
|
||||||
|
SCP="scp -i $HOME/.ssh/deploy_key -o IdentitiesOnly=yes"
|
||||||
|
|
||||||
|
# Stop service before deploying to avoid "Text file busy" error
|
||||||
|
$SSH "$VPS_USER@$VPS_HOST" "systemctl stop $SERVICE_NAME 2>/dev/null || true"
|
||||||
|
|
||||||
|
$SSH "$VPS_USER@$VPS_HOST" "mkdir -p $DEPLOY_DIR/.gitea $DEPLOY_DIR/models $DEPLOY_DIR/logs $DEPLOY_DIR/lib"
|
||||||
|
|
||||||
|
$SCP target/release/openbrain-mcp "$VPS_USER@$VPS_HOST:$DEPLOY_DIR/openbrain-mcp"
|
||||||
|
$SCP .gitea/openbrain.service "$VPS_USER@$VPS_HOST:/etc/systemd/system/$SERVICE_NAME.service"
|
||||||
|
$SCP .gitea/download-model.sh "$VPS_USER@$VPS_HOST:$DEPLOY_DIR/.gitea/download-model.sh"
|
||||||
|
|
||||||
|
- name: Bootstrap VPS and restart service
|
||||||
|
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master'
|
||||||
|
run: |
|
||||||
|
set -euxo pipefail
|
||||||
|
: "${VPS_HOST:?Set repository variable VPS_HOST}"
|
||||||
|
: "${VPS_USER:=root}"
|
||||||
|
SSH="ssh -i $HOME/.ssh/deploy_key -o IdentitiesOnly=yes"
|
||||||
|
|
||||||
|
$SSH "$VPS_USER@$VPS_HOST" "DEPLOY_DIR=$DEPLOY_DIR SERVICE_USER=openbrain SERVICE_GROUP=openbrain ORT_VERSION=1.24.3 bash -s" <<'EOS'
|
||||||
|
set -euo pipefail
|
||||||
|
DEPLOY_DIR="${DEPLOY_DIR:-/opt/openbrain-mcp}"
|
||||||
|
SERVICE_USER="${SERVICE_USER:-openbrain}"
|
||||||
|
SERVICE_GROUP="${SERVICE_GROUP:-openbrain}"
|
||||||
|
ORT_VERSION="${ORT_VERSION:-1.24.3}"
|
||||||
|
|
||||||
|
if command -v apt-get >/dev/null 2>&1; then
|
||||||
|
export DEBIAN_FRONTEND=noninteractive
|
||||||
|
apt-get update
|
||||||
|
apt-get install -y --no-install-recommends ca-certificates curl tar libssl3
|
||||||
|
fi
|
||||||
|
|
||||||
|
if ! getent group "$SERVICE_GROUP" >/dev/null 2>&1; then
|
||||||
|
groupadd --system "$SERVICE_GROUP"
|
||||||
|
fi
|
||||||
|
if ! id -u "$SERVICE_USER" >/dev/null 2>&1; then
|
||||||
|
useradd --system --gid "$SERVICE_GROUP" --home "$DEPLOY_DIR" --shell /usr/sbin/nologin "$SERVICE_USER"
|
||||||
|
fi
|
||||||
|
|
||||||
|
install -d -m 0755 "$DEPLOY_DIR" "$DEPLOY_DIR/models" "$DEPLOY_DIR/logs" "$DEPLOY_DIR/lib"
|
||||||
|
|
||||||
|
ARCH="$(uname -m)"
|
||||||
|
case "$ARCH" in
|
||||||
|
x86_64) ORT_ARCH="x64" ;;
|
||||||
|
aarch64|arm64) ORT_ARCH="aarch64" ;;
|
||||||
|
*) echo "Unsupported arch: $ARCH"; exit 1 ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
if [[ ! -f "$DEPLOY_DIR/lib/libonnxruntime.so" ]]; then
|
||||||
|
TMP_DIR="$(mktemp -d)"
|
||||||
|
ORT_TGZ="onnxruntime-linux-${ORT_ARCH}-${ORT_VERSION}.tgz"
|
||||||
|
ORT_URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/${ORT_TGZ}"
|
||||||
|
curl -fL "$ORT_URL" -o "$TMP_DIR/$ORT_TGZ"
|
||||||
|
tar -xzf "$TMP_DIR/$ORT_TGZ" -C "$TMP_DIR"
|
||||||
|
ORT_ROOT="$TMP_DIR/onnxruntime-linux-${ORT_ARCH}-${ORT_VERSION}"
|
||||||
|
cp "$ORT_ROOT/lib/libonnxruntime.so" "$DEPLOY_DIR/lib/libonnxruntime.so"
|
||||||
|
cp "$ORT_ROOT/lib/libonnxruntime.so.${ORT_VERSION}" "$DEPLOY_DIR/lib/libonnxruntime.so.${ORT_VERSION}" || true
|
||||||
|
rm -rf "$TMP_DIR"
|
||||||
|
fi
|
||||||
|
|
||||||
|
ENV_FILE="$DEPLOY_DIR/.env"
|
||||||
|
if [[ ! -f "$ENV_FILE" ]]; then
|
||||||
|
if [[ -f "$DEPLOY_DIR/.env.example" ]]; then cp "$DEPLOY_DIR/.env.example" "$ENV_FILE"; else touch "$ENV_FILE"; fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
upsert_env() {
|
||||||
|
local key="$1"
|
||||||
|
local value="$2"
|
||||||
|
if grep -qE "^${key}=" "$ENV_FILE"; then
|
||||||
|
sed -i "s|^${key}=.*|${key}=${value}|" "$ENV_FILE"
|
||||||
|
else
|
||||||
|
printf '%s=%s\n' "$key" "$value" >> "$ENV_FILE"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
upsert_env "OPENBRAIN__EMBEDDING__MODEL_PATH" "$DEPLOY_DIR/models/all-MiniLM-L6-v2"
|
||||||
|
upsert_env "ORT_DYLIB_PATH" "$DEPLOY_DIR/lib/libonnxruntime.so"
|
||||||
|
upsert_env "OPENBRAIN__SERVER__HOST" "0.0.0.0"
|
||||||
|
|
||||||
|
chmod +x "$DEPLOY_DIR/openbrain-mcp" "$DEPLOY_DIR/.gitea/download-model.sh"
|
||||||
|
chown -R "$SERVICE_USER:$SERVICE_GROUP" "$DEPLOY_DIR"
|
||||||
|
EOS
|
||||||
|
|
||||||
|
$SSH "$VPS_USER@$VPS_HOST" "DEPLOY_DIR=$DEPLOY_DIR bash $DEPLOY_DIR/.gitea/download-model.sh"
|
||||||
|
$SSH "$VPS_USER@$VPS_HOST" "cd $DEPLOY_DIR && ./openbrain-mcp migrate"
|
||||||
|
|
||||||
|
$SSH "$VPS_USER@$VPS_HOST" "systemctl daemon-reload"
|
||||||
|
$SSH "$VPS_USER@$VPS_HOST" "systemctl enable $SERVICE_NAME"
|
||||||
|
$SSH "$VPS_USER@$VPS_HOST" "systemctl restart $SERVICE_NAME"
|
||||||
|
|
||||||
|
- name: Verify deployment
|
||||||
|
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master'
|
||||||
|
run: |
|
||||||
|
set -euxo pipefail
|
||||||
|
: "${VPS_HOST:?Set repository variable VPS_HOST}"
|
||||||
|
: "${VPS_USER:=root}"
|
||||||
|
SSH="ssh -i $HOME/.ssh/deploy_key -o IdentitiesOnly=yes"
|
||||||
|
|
||||||
|
sleep 5
|
||||||
|
$SSH "$VPS_USER@$VPS_HOST" "systemctl status $SERVICE_NAME --no-pager || journalctl -u $SERVICE_NAME --no-pager -n 80"
|
||||||
|
curl -fsS "http://$VPS_HOST:3100/health"
|
||||||
|
curl -fsS "http://$VPS_HOST:3100/ready"
|
||||||
|
|
||||||
|
- name: Cleanup SSH key
|
||||||
|
if: always()
|
||||||
|
run: |
|
||||||
|
rm -f "$HOME/.ssh/deploy_key"
|
||||||
27
.gitignore
vendored
Normal file
27
.gitignore
vendored
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# Rust build artifacts
|
||||||
|
/target/
|
||||||
|
**/*.rs.bk
|
||||||
|
Cargo.lock
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
logs/*.log
|
||||||
|
|
||||||
|
# Models (downloaded at runtime)
|
||||||
|
models/
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
.a0proj
|
||||||
|
CREDENTIALS.md
|
||||||
65
Cargo.toml
Normal file
65
Cargo.toml
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
[package]
|
||||||
|
name = "openbrain-mcp"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
authors = ["Ingwaz <agent@ingwaz.work>"]
|
||||||
|
description = "High-performance vector memory MCP server for AI agents"
|
||||||
|
license = "MIT"
|
||||||
|
repository = "https://gitea.ingwaz.work/Ingwaz/openbrain-mcp"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
# MCP Framework
|
||||||
|
rmcp = { version = "0.1", features = ["server", "transport-sse"] }
|
||||||
|
|
||||||
|
# ONNX Runtime for local embeddings
|
||||||
|
ort = { version = "2.0.0-rc.12", features = ["load-dynamic"] }
|
||||||
|
ndarray = "0.16"
|
||||||
|
tokenizers = "0.21"
|
||||||
|
|
||||||
|
# Database
|
||||||
|
tokio-postgres = { version = "0.7", features = ["with-uuid-1", "with-chrono-0_4", "with-serde_json-1"] }
|
||||||
|
deadpool-postgres = "0.14"
|
||||||
|
pgvector = { version = "0.4", features = ["postgres"] }
|
||||||
|
refinery = { version = "0.9", features = ["tokio-postgres"] }
|
||||||
|
|
||||||
|
# HTTP Server
|
||||||
|
axum = { version = "0.8", features = ["macros"] }
|
||||||
|
axum-extra = { version = "0.10", features = ["typed-header"] }
|
||||||
|
tower = "0.5"
|
||||||
|
tower-http = { version = "0.6", features = ["cors", "trace"] }
|
||||||
|
|
||||||
|
# Async Runtime
|
||||||
|
tokio = { version = "1.44", features = ["full"] }
|
||||||
|
futures = "0.3"
|
||||||
|
async-stream = "0.3"
|
||||||
|
|
||||||
|
# Serialization
|
||||||
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
|
serde_json = "1.0"
|
||||||
|
|
||||||
|
# Security
|
||||||
|
sha2 = "0.10"
|
||||||
|
hex = "0.4"
|
||||||
|
|
||||||
|
# Utilities
|
||||||
|
uuid = { version = "1.16", features = ["v4", "serde"] }
|
||||||
|
chrono = { version = "0.4", features = ["serde"] }
|
||||||
|
thiserror = "2.0"
|
||||||
|
anyhow = "1.0"
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
tracing = "0.1"
|
||||||
|
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
config = "0.15"
|
||||||
|
dotenvy = "0.15"
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
lto = true
|
||||||
|
codegen-units = 1
|
||||||
|
opt-level = 3
|
||||||
|
strip = true
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
|
||||||
161
README.md
Normal file
161
README.md
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
# OpenBrain MCP Server
|
||||||
|
|
||||||
|
**High-performance vector memory for AI agents**
|
||||||
|
|
||||||
|
OpenBrain is a Model Context Protocol (MCP) server that provides AI agents with a persistent, semantic memory system. It uses local ONNX-based embeddings and PostgreSQL with pgvector for efficient similarity search.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- 🧠 **Semantic Memory**: Store and retrieve memories using vector similarity search
|
||||||
|
- 🏠 **Local Embeddings**: No external API calls - uses ONNX runtime with all-MiniLM-L6-v2
|
||||||
|
- 🐘 **PostgreSQL + pgvector**: Production-grade vector storage with HNSW indexing
|
||||||
|
- 🔌 **MCP Protocol**: Standard Model Context Protocol over SSE transport
|
||||||
|
- 🔐 **Multi-Agent Support**: Isolated memory namespaces per agent
|
||||||
|
- ⚡ **High Performance**: Rust implementation with async I/O
|
||||||
|
|
||||||
|
## MCP Tools
|
||||||
|
|
||||||
|
| Tool | Description |
|
||||||
|
|------|-------------|
|
||||||
|
| `store` | Store a memory with automatic embedding generation and keyword extraction |
|
||||||
|
| `query` | Search memories by semantic similarity |
|
||||||
|
| `purge` | Delete memories by agent ID or time range |
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- Rust 1.75+
|
||||||
|
- PostgreSQL 14+ with pgvector extension
|
||||||
|
- ONNX model files (all-MiniLM-L6-v2)
|
||||||
|
|
||||||
|
### Database Setup
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE ROLE openbrain_svc LOGIN PASSWORD 'change-me';
|
||||||
|
CREATE DATABASE openbrain OWNER openbrain_svc;
|
||||||
|
\c openbrain
|
||||||
|
CREATE EXTENSION IF NOT EXISTS vector;
|
||||||
|
```
|
||||||
|
|
||||||
|
Use the same PostgreSQL role for the app and for migrations. Do not create the
|
||||||
|
`memories` table manually as `postgres` or another owner and then run
|
||||||
|
OpenBrain as `openbrain_svc`, because later `ALTER TABLE` migrations will fail
|
||||||
|
with `must be owner of table memories`.
|
||||||
|
|
||||||
|
### Configuration
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp .env.example .env
|
||||||
|
# Edit .env with your database credentials
|
||||||
|
```
|
||||||
|
|
||||||
|
### Build & Run
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo build --release
|
||||||
|
./target/release/openbrain-mcp migrate
|
||||||
|
./target/release/openbrain-mcp
|
||||||
|
```
|
||||||
|
|
||||||
|
### Database Migrations
|
||||||
|
|
||||||
|
This project uses `refinery` with embedded SQL migrations in `migrations/`.
|
||||||
|
|
||||||
|
Run pending migrations explicitly before starting or restarting the service:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./target/release/openbrain-mcp migrate
|
||||||
|
```
|
||||||
|
|
||||||
|
If you use the deploy script or CI workflow in [`.gitea/deploy.sh`](/Users/bobbytables/ai/openbrain-mcp/.gitea/deploy.sh) and [`.gitea/workflows/ci-cd.yaml`](/Users/bobbytables/ai/openbrain-mcp/.gitea/workflows/ci-cd.yaml), they already run this for you.
|
||||||
|
|
||||||
|
## MCP Integration
|
||||||
|
|
||||||
|
Connect to the server using SSE transport:
|
||||||
|
|
||||||
|
```
|
||||||
|
SSE Endpoint: http://localhost:3100/mcp/sse
|
||||||
|
Message Endpoint: http://localhost:3100/mcp/message
|
||||||
|
Health Check: http://localhost:3100/mcp/health
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example: Store a Memory
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 1,
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {
|
||||||
|
"name": "store",
|
||||||
|
"arguments": {
|
||||||
|
"content": "The user prefers dark mode and uses vim keybindings",
|
||||||
|
"agent_id": "assistant-1",
|
||||||
|
"metadata": {"source": "preferences"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example: Query Memories
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 2,
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {
|
||||||
|
"name": "query",
|
||||||
|
"arguments": {
|
||||||
|
"query": "What are the user's editor preferences?",
|
||||||
|
"agent_id": "assistant-1",
|
||||||
|
"limit": 5,
|
||||||
|
"threshold": 0.6
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────────────────┐
|
||||||
|
│ AI Agent │
|
||||||
|
└─────────────────────┬───────────────────────────────────┘
|
||||||
|
│ MCP Protocol (SSE)
|
||||||
|
┌─────────────────────▼───────────────────────────────────┐
|
||||||
|
│ OpenBrain MCP Server │
|
||||||
|
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
|
||||||
|
│ │ store │ │ query │ │ purge │ │
|
||||||
|
│ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │
|
||||||
|
│ │ │ │ │
|
||||||
|
│ ┌──────▼────────────────▼────────────────▼──────┐ │
|
||||||
|
│ │ Embedding Engine (ONNX) │ │
|
||||||
|
│ │ all-MiniLM-L6-v2 (384d) │ │
|
||||||
|
│ └──────────────────────┬────────────────────────┘ │
|
||||||
|
│ │ │
|
||||||
|
│ ┌──────────────────────▼────────────────────────┐ │
|
||||||
|
│ │ PostgreSQL + pgvector │ │
|
||||||
|
│ │ HNSW Index for fast search │ │
|
||||||
|
│ └────────────────────────────────────────────────┘ │
|
||||||
|
└─────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
## Environment Variables
|
||||||
|
|
||||||
|
| Variable | Default | Description |
|
||||||
|
|----------|---------|-------------|
|
||||||
|
| `OPENBRAIN__SERVER__HOST` | `0.0.0.0` | Server bind address |
|
||||||
|
| `OPENBRAIN__SERVER__PORT` | `3100` | Server port |
|
||||||
|
| `OPENBRAIN__DATABASE__HOST` | `localhost` | PostgreSQL host |
|
||||||
|
| `OPENBRAIN__DATABASE__PORT` | `5432` | PostgreSQL port |
|
||||||
|
| `OPENBRAIN__DATABASE__NAME` | `openbrain` | Database name |
|
||||||
|
| `OPENBRAIN__DATABASE__USER` | - | Database user |
|
||||||
|
| `OPENBRAIN__DATABASE__PASSWORD` | - | Database password |
|
||||||
|
| `OPENBRAIN__EMBEDDING__MODEL_PATH` | `models/all-MiniLM-L6-v2` | ONNX model path |
|
||||||
|
| `OPENBRAIN__AUTH__ENABLED` | `false` | Enable API key auth |
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT
|
||||||
38
migrations/V1__baseline_memories.sql
Normal file
38
migrations/V1__baseline_memories.sql
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
-- Run OpenBrain migrations as the same database role that owns these objects.
-- Existing installs with a differently owned memories table must transfer
-- ownership before this baseline migration can apply ALTER TABLE changes.

-- Baseline schema: one row per stored memory. The embedding column width
-- (384) matches the all-MiniLM-L6-v2 model configured by the server.
CREATE TABLE IF NOT EXISTS memories (
    id UUID PRIMARY KEY,
    agent_id VARCHAR(255) NOT NULL,
    content TEXT NOT NULL,
    embedding vector(384) NOT NULL,
    keywords TEXT[] DEFAULT '{}',
    metadata JSONB DEFAULT '{}'::jsonb,
    created_at TIMESTAMPTZ DEFAULT NOW()
);

-- The ALTER statements below bring pre-baseline installs up to this schema
-- in place; on a fresh database they are no-ops against the table above.
ALTER TABLE memories
    ALTER COLUMN agent_id TYPE VARCHAR(255);

ALTER TABLE memories
    ADD COLUMN IF NOT EXISTS keywords TEXT[] DEFAULT '{}';

ALTER TABLE memories
    ADD COLUMN IF NOT EXISTS metadata JSONB DEFAULT '{}'::jsonb;

ALTER TABLE memories
    ADD COLUMN IF NOT EXISTS created_at TIMESTAMPTZ DEFAULT NOW();

-- Re-assert defaults for columns that may pre-date them.
ALTER TABLE memories
    ALTER COLUMN keywords SET DEFAULT '{}';

ALTER TABLE memories
    ALTER COLUMN metadata SET DEFAULT '{}'::jsonb;

ALTER TABLE memories
    ALTER COLUMN created_at SET DEFAULT NOW();

-- Per-agent lookup index, plus an HNSW approximate-nearest-neighbour index
-- over cosine distance (matches the `<=>` operator used by the query path).
CREATE INDEX IF NOT EXISTS idx_memories_agent ON memories(agent_id);
CREATE INDEX IF NOT EXISTS idx_memories_embedding ON memories
    USING hnsw (embedding vector_cosine_ops);
|
||||||
125
src/auth.rs
Normal file
125
src/auth.rs
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
//! Authentication module for OpenBrain MCP
|
||||||
|
//!
|
||||||
|
//! Provides API key-based authentication for securing the MCP endpoints.
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::{Request, State},
|
||||||
|
http::{HeaderMap, StatusCode, header::AUTHORIZATION},
|
||||||
|
middleware::Next,
|
||||||
|
response::Response,
|
||||||
|
};
|
||||||
|
use sha2::{Digest, Sha256};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
|
use crate::AppState;
|
||||||
|
|
||||||
|
/// Hash an API key for secure comparison
|
||||||
|
pub fn hash_api_key(key: &str) -> String {
|
||||||
|
let mut hasher = Sha256::new();
|
||||||
|
hasher.update(key.as_bytes());
|
||||||
|
hex::encode(hasher.finalize())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Middleware for API key authentication
|
||||||
|
pub async fn auth_middleware(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
request: Request,
|
||||||
|
next: Next,
|
||||||
|
) -> Result<Response, StatusCode> {
|
||||||
|
// Skip auth if disabled
|
||||||
|
if !state.config.auth.enabled {
|
||||||
|
return Ok(next.run(request).await);
|
||||||
|
}
|
||||||
|
|
||||||
|
let api_key = extract_api_key(request.headers());
|
||||||
|
|
||||||
|
match api_key {
|
||||||
|
Some(key) => {
|
||||||
|
// Check if key is valid
|
||||||
|
let key_hash = hash_api_key(&key);
|
||||||
|
let valid = state
|
||||||
|
.config
|
||||||
|
.auth
|
||||||
|
.api_keys
|
||||||
|
.iter()
|
||||||
|
.any(|k| hash_api_key(k) == key_hash);
|
||||||
|
|
||||||
|
if valid {
|
||||||
|
Ok(next.run(request).await)
|
||||||
|
} else {
|
||||||
|
warn!("Invalid API key or bearer token attempted");
|
||||||
|
Err(StatusCode::UNAUTHORIZED)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
warn!("Missing API key or bearer token in request");
|
||||||
|
Err(StatusCode::UNAUTHORIZED)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_api_key(headers: &HeaderMap) -> Option<String> {
|
||||||
|
headers
|
||||||
|
.get("X-API-Key")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.map(ToOwned::to_owned)
|
||||||
|
.or_else(|| {
|
||||||
|
headers
|
||||||
|
.get(AUTHORIZATION)
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.and_then(|value| {
|
||||||
|
let (scheme, token) = value.split_once(' ')?;
|
||||||
|
scheme
|
||||||
|
.eq_ignore_ascii_case("bearer")
|
||||||
|
.then_some(token.trim())
|
||||||
|
})
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.map(ToOwned::to_owned)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_optional_agent_id(headers: &HeaderMap) -> Option<String> {
|
||||||
|
headers
|
||||||
|
.get("X-Agent-ID")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.map(ToOwned::to_owned)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract agent ID from request headers or default
|
||||||
|
pub fn get_agent_id(request: &Request) -> String {
|
||||||
|
get_optional_agent_id(request.headers())
|
||||||
|
.unwrap_or_else(|| "default".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderValue, header::AUTHORIZATION};

    // A Bearer token in the Authorization header is accepted in place of
    // an X-API-Key header.
    #[test]
    fn extracts_api_key_from_bearer_header() {
        let mut headers = HeaderMap::new();
        headers.insert(
            AUTHORIZATION,
            HeaderValue::from_static("Bearer test-token"),
        );

        assert_eq!(extract_api_key(&headers).as_deref(), Some("test-token"));
    }

    // X-Agent-ID is surfaced verbatim when present and non-empty.
    #[test]
    fn extracts_agent_id_from_header() {
        let mut headers = HeaderMap::new();
        headers.insert("X-Agent-ID", HeaderValue::from_static("agent-zero"));

        assert_eq!(
            get_optional_agent_id(&headers).as_deref(),
            Some("agent-zero")
        );
    }
}
|
||||||
146
src/config.rs
Normal file
146
src/config.rs
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
//! Configuration management for OpenBrain MCP
|
||||||
|
//!
|
||||||
|
//! Loads configuration from environment variables with sensible defaults.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use serde::{Deserialize, Deserializer};
|
||||||
|
|
||||||
|
/// Main configuration structure
///
/// Top-level settings tree, deserialized from layered defaults plus
/// `OPENBRAIN__SECTION__KEY` environment variables (see [`Config::load`]).
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    /// HTTP bind address/port.
    pub server: ServerConfig,
    /// PostgreSQL connection settings.
    pub database: DatabaseConfig,
    /// ONNX embedding model settings.
    pub embedding: EmbeddingConfig,
    /// API-key authentication settings.
    pub auth: AuthConfig,
}
|
||||||
|
|
||||||
|
/// Server configuration
#[derive(Debug, Clone, Deserialize)]
pub struct ServerConfig {
    /// Bind address (default "0.0.0.0").
    #[serde(default = "default_host")]
    pub host: String,
    /// Bind port (default 3100).
    #[serde(default = "default_port")]
    pub port: u16,
}
|
||||||
|
|
||||||
|
/// Database configuration
///
/// Connection settings for PostgreSQL (with the pgvector extension).
/// `host`, `name`, `user`, and `password` have no serde defaults and must
/// be supplied via the environment.
#[derive(Debug, Clone, Deserialize)]
pub struct DatabaseConfig {
    /// PostgreSQL host.
    pub host: String,
    /// PostgreSQL port (default 5432).
    #[serde(default = "default_db_port")]
    pub port: u16,
    /// Database name.
    pub name: String,
    /// Role used for connections.
    pub user: String,
    /// Password for `user`.
    pub password: String,
    /// Desired connection-pool size (default 10).
    #[serde(default = "default_pool_size")]
    pub pool_size: usize,
}
|
||||||
|
|
||||||
|
/// Embedding engine configuration
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingConfig {
    /// Directory containing `model.onnx` and `tokenizer.json`
    /// (default "models/all-MiniLM-L6-v2").
    #[serde(default = "default_model_path")]
    pub model_path: String,
    /// Expected embedding vector dimension (default 384).
    #[serde(default = "default_embedding_dim")]
    pub dimension: usize,
}
|
||||||
|
|
||||||
|
/// Authentication configuration
#[derive(Debug, Clone, Deserialize)]
pub struct AuthConfig {
    /// Whether API-key auth is enforced (default false).
    #[serde(default = "default_auth_enabled")]
    pub enabled: bool,
    /// Accepted API keys; deserialized from either a comma-separated
    /// string or a native list (see `deserialize_api_keys`).
    #[serde(default, deserialize_with = "deserialize_api_keys")]
    pub api_keys: Vec<String>,
}
|
||||||
|
|
||||||
|
/// Deserialize API keys from either a comma-separated string or a Vec<String>
|
||||||
|
fn deserialize_api_keys<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
// Try to deserialize as a string first, then as a Vec
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
enum StringOrVec {
|
||||||
|
String(String),
|
||||||
|
Vec(Vec<String>),
|
||||||
|
}
|
||||||
|
|
||||||
|
match Option::<StringOrVec>::deserialize(deserializer)? {
|
||||||
|
Some(StringOrVec::String(s)) => {
|
||||||
|
Ok(s.split(',')
|
||||||
|
.map(|k| k.trim().to_string())
|
||||||
|
.filter(|k| !k.is_empty())
|
||||||
|
.collect())
|
||||||
|
}
|
||||||
|
Some(StringOrVec::Vec(v)) => Ok(v),
|
||||||
|
None => Ok(Vec::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default value functions used by serde `default = "…"` attributes and by
// `Config::load` / `Default for Config`.
fn default_host() -> String {
    String::from("0.0.0.0")
}

fn default_port() -> u16 {
    3100
}

fn default_db_port() -> u16 {
    5432
}

fn default_pool_size() -> usize {
    10
}

fn default_model_path() -> String {
    String::from("models/all-MiniLM-L6-v2")
}

fn default_embedding_dim() -> usize {
    384
}

fn default_auth_enabled() -> bool {
    false
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
/// Load configuration from environment variables
|
||||||
|
pub fn load() -> Result<Self> {
|
||||||
|
// Load .env file if present
|
||||||
|
dotenvy::dotenv().ok();
|
||||||
|
|
||||||
|
let config = config::Config::builder()
|
||||||
|
// Server settings
|
||||||
|
.set_default("server.host", default_host())?
|
||||||
|
.set_default("server.port", default_port() as i64)?
|
||||||
|
// Database settings
|
||||||
|
.set_default("database.port", default_db_port() as i64)?
|
||||||
|
.set_default("database.pool_size", default_pool_size() as i64)?
|
||||||
|
// Embedding settings
|
||||||
|
.set_default("embedding.model_path", default_model_path())?
|
||||||
|
.set_default("embedding.dimension", default_embedding_dim() as i64)?
|
||||||
|
// Auth settings
|
||||||
|
.set_default("auth.enabled", default_auth_enabled())?
|
||||||
|
// Load from environment with OPENBRAIN_ prefix
|
||||||
|
.add_source(
|
||||||
|
config::Environment::with_prefix("OPENBRAIN")
|
||||||
|
.separator("__")
|
||||||
|
.try_parsing(true),
|
||||||
|
)
|
||||||
|
.build()?;
|
||||||
|
|
||||||
|
Ok(config.try_deserialize()?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Config {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
server: ServerConfig {
|
||||||
|
host: default_host(),
|
||||||
|
port: default_port(),
|
||||||
|
},
|
||||||
|
database: DatabaseConfig {
|
||||||
|
host: "localhost".to_string(),
|
||||||
|
port: default_db_port(),
|
||||||
|
name: "openbrain".to_string(),
|
||||||
|
user: "openbrain_svc".to_string(),
|
||||||
|
password: String::new(),
|
||||||
|
pool_size: default_pool_size(),
|
||||||
|
},
|
||||||
|
embedding: EmbeddingConfig {
|
||||||
|
model_path: default_model_path(),
|
||||||
|
dimension: default_embedding_dim(),
|
||||||
|
},
|
||||||
|
auth: AuthConfig {
|
||||||
|
enabled: default_auth_enabled(),
|
||||||
|
api_keys: Vec::new(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
176
src/db.rs
Normal file
176
src/db.rs
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
//! Database module for PostgreSQL with pgvector support
|
||||||
|
//!
|
||||||
|
//! Provides connection pooling and query helpers for vector operations.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use deadpool_postgres::{Config, Pool, Runtime};
|
||||||
|
use pgvector::Vector;
|
||||||
|
use tokio_postgres::NoTls;
|
||||||
|
use tracing::info;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::config::DatabaseConfig;
|
||||||
|
|
||||||
|
/// Database wrapper with connection pool
///
/// Cheap to clone: cloning shares the underlying deadpool pool.
#[derive(Clone)]
pub struct Database {
    // deadpool-postgres connection pool; connections are checked out per call.
    pool: Pool,
}
|
||||||
|
|
||||||
|
/// A memory record stored in the database
#[derive(Debug, Clone)]
pub struct MemoryRecord {
    /// Primary key (v4 UUID generated at insert time).
    pub id: Uuid,
    /// Owning agent; all reads and purges are scoped by this value.
    pub agent_id: String,
    /// Raw memory text.
    pub content: String,
    /// Embedding vector. Left empty by `query_memories`, which does not
    /// fetch the embedding column back.
    pub embedding: Vec<f32>,
    /// Keywords extracted from `content` at store time.
    pub keywords: Vec<String>,
    /// Arbitrary caller-supplied JSON metadata.
    pub metadata: serde_json::Value,
    /// Row creation timestamp (database default NOW()).
    pub created_at: chrono::DateTime<chrono::Utc>,
}
|
||||||
|
|
||||||
|
/// Query result with similarity score
#[derive(Debug, Clone)]
pub struct MemoryMatch {
    /// The matched record (embedding field left empty).
    pub record: MemoryRecord,
    /// Cosine similarity in [-1, 1], computed as 1 - cosine distance.
    pub similarity: f32,
}
|
||||||
|
|
||||||
|
impl Database {
    /// Create a new database connection pool
    ///
    /// Builds a deadpool-postgres pool from `config` and verifies
    /// connectivity with a `SELECT 1` before returning.
    pub async fn new(config: &DatabaseConfig) -> Result<Self> {
        let mut cfg = Config::new();
        cfg.host = Some(config.host.clone());
        cfg.port = Some(config.port);
        cfg.dbname = Some(config.name.clone());
        cfg.user = Some(config.user.clone());
        cfg.password = Some(config.password.clone());
        // NOTE(review): config.pool_size is logged below but never applied to
        // `cfg`, so deadpool's default pool size is in effect — confirm
        // whether it should be wired into the pool config.

        let pool = cfg
            .create_pool(Some(Runtime::Tokio1), NoTls)
            .context("Failed to create database pool")?;

        // Test connection
        let client = pool.get().await.context("Failed to get database connection")?;
        client
            .simple_query("SELECT 1")
            .await
            .context("Failed to execute test query")?;

        info!("Database connection pool created with {} connections", config.pool_size);

        Ok(Self { pool })
    }

    /// Store a memory record
    ///
    /// Generates a fresh v4 UUID for the row and returns it. `embedding`
    /// must match the `vector(N)` column width declared in the migration.
    pub async fn store_memory(
        &self,
        agent_id: &str,
        content: &str,
        embedding: &[f32],
        keywords: &[String],
        metadata: serde_json::Value,
    ) -> Result<Uuid> {
        let client = self.pool.get().await?;
        let id = Uuid::new_v4();
        let vector = Vector::from(embedding.to_vec());

        client
            .execute(
                r#"
                INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata)
                VALUES ($1, $2, $3, $4, $5, $6)
                "#,
                &[&id, &agent_id, &content, &vector, &keywords, &metadata],
            )
            .await
            .context("Failed to store memory")?;

        Ok(id)
    }

    /// Query memories by vector similarity
    ///
    /// Returns rows for `agent_id` whose cosine similarity to `embedding`
    /// (computed as `1 - cosine_distance`, the `<=>` operator) is at least
    /// `threshold`, ordered closest-first and capped at `limit`.
    pub async fn query_memories(
        &self,
        agent_id: &str,
        embedding: &[f32],
        limit: i64,
        threshold: f32,
    ) -> Result<Vec<MemoryMatch>> {
        let client = self.pool.get().await?;
        let vector = Vector::from(embedding.to_vec());
        // Widen to f64 for the SQL comparison against the distance expression.
        let threshold_f64 = threshold as f64;

        let rows = client
            .query(
                r#"
                SELECT
                    id, agent_id, content, keywords, metadata, created_at,
                    (1 - (embedding <=> $1))::real AS similarity
                FROM memories
                WHERE agent_id = $2
                  AND (1 - (embedding <=> $1)) >= $3
                ORDER BY embedding <=> $1
                LIMIT $4
                "#,
                &[&vector, &agent_id, &threshold_f64, &limit],
            )
            .await
            .context("Failed to query memories")?;

        let matches = rows
            .iter()
            .map(|row| MemoryMatch {
                record: MemoryRecord {
                    id: row.get("id"),
                    agent_id: row.get("agent_id"),
                    content: row.get("content"),
                    // Query responses do not include raw embedding payloads.
                    embedding: Vec::new(),
                    keywords: row.get("keywords"),
                    metadata: row.get("metadata"),
                    created_at: row.get("created_at"),
                },
                similarity: row.get("similarity"),
            })
            .collect();

        Ok(matches)
    }

    /// Delete memories by agent_id and optional filters
    ///
    /// When `before` is given, only rows created strictly earlier are
    /// removed. Returns the number of deleted rows.
    pub async fn purge_memories(
        &self,
        agent_id: &str,
        before: Option<chrono::DateTime<chrono::Utc>>,
    ) -> Result<u64> {
        let client = self.pool.get().await?;

        let count = if let Some(before_ts) = before {
            client
                .execute(
                    "DELETE FROM memories WHERE agent_id = $1 AND created_at < $2",
                    &[&agent_id, &before_ts],
                )
                .await?
        } else {
            client
                .execute("DELETE FROM memories WHERE agent_id = $1", &[&agent_id])
                .await?
        };

        Ok(count)
    }

    /// Get memory count for an agent
    pub async fn count_memories(&self, agent_id: &str) -> Result<i64> {
        let client = self.pool.get().await?;
        let row = client
            .query_one(
                "SELECT COUNT(*) as count FROM memories WHERE agent_id = $1",
                &[&agent_id],
            )
            .await?;
        Ok(row.get("count"))
    }
}
|
||||||
245
src/embedding.rs
Normal file
245
src/embedding.rs
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
//! Embedding engine using local ONNX models
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||||
|
use ort::value::Value;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::sync::Once;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
use crate::config::EmbeddingConfig;
|
||||||
|
|
||||||
|
static ORT_INIT: Once = Once::new();
|
||||||
|
|
||||||
|
/// Initialize ONNX Runtime synchronously (called inside spawn_blocking)
|
||||||
|
fn init_ort_sync(dylib_path: &str) -> Result<()> {
|
||||||
|
info!("Initializing ONNX Runtime from: {}", dylib_path);
|
||||||
|
|
||||||
|
let mut init_error: Option<String> = None;
|
||||||
|
|
||||||
|
ORT_INIT.call_once(|| {
|
||||||
|
info!("ORT_INIT.call_once - starting initialization");
|
||||||
|
match ort::init_from(dylib_path) {
|
||||||
|
Ok(builder) => {
|
||||||
|
info!("ort::init_from succeeded, calling commit()");
|
||||||
|
let committed = builder.commit();
|
||||||
|
info!("commit() returned: {}", committed);
|
||||||
|
if !committed {
|
||||||
|
init_error = Some("ONNX Runtime commit returned false".to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let err_msg = format!("ONNX Runtime init_from failed: {:?}", e);
|
||||||
|
info!("{}", err_msg);
|
||||||
|
init_error = Some(err_msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
info!("ORT_INIT.call_once - finished");
|
||||||
|
});
|
||||||
|
|
||||||
|
// Note: init_error won't be set if ORT_INIT was already called
|
||||||
|
// This is fine - we only initialize once
|
||||||
|
if let Some(err) = init_error {
|
||||||
|
return Err(anyhow::anyhow!("{}", err));
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("ONNX Runtime initialization complete");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve ONNX Runtime dylib path from env var or common local install locations.
|
||||||
|
fn resolve_ort_dylib_path() -> Result<String> {
|
||||||
|
if let Ok(path) = std::env::var("ORT_DYLIB_PATH") {
|
||||||
|
if Path::new(&path).exists() {
|
||||||
|
return Ok(path);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"ORT_DYLIB_PATH is set but file does not exist: {}",
|
||||||
|
path
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let candidates = [
|
||||||
|
"/opt/homebrew/opt/onnxruntime/lib/libonnxruntime.dylib",
|
||||||
|
"/usr/local/opt/onnxruntime/lib/libonnxruntime.dylib",
|
||||||
|
];
|
||||||
|
|
||||||
|
for candidate in candidates {
|
||||||
|
if Path::new(candidate).exists() {
|
||||||
|
return Ok(candidate.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(anyhow::anyhow!(
|
||||||
|
"ORT_DYLIB_PATH environment variable not set and ONNX Runtime dylib not found. \
|
||||||
|
Set ORT_DYLIB_PATH to your libonnxruntime.dylib path (for example: /opt/homebrew/opt/onnxruntime/lib/libonnxruntime.dylib)."
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Local sentence-embedding engine backed by an ONNX Runtime session.
pub struct EmbeddingEngine {
    // Session::run takes &mut self, so the session is guarded by a Mutex
    // to keep the engine shareable behind an Arc.
    session: std::sync::Mutex<Session>,
    // HuggingFace tokenizer loaded from tokenizer.json alongside the model.
    tokenizer: Tokenizer,
    // Expected output embedding dimension (taken from config).
    dimension: usize,
}
|
||||||
|
|
||||||
|
impl EmbeddingEngine {
    /// Create a new embedding engine
    ///
    /// Resolves the ONNX Runtime dylib, then loads `model.onnx` and
    /// `tokenizer.json` from `config.model_path` inside `spawn_blocking`
    /// (model loading is CPU/IO heavy and must not block the async runtime).
    pub async fn new(config: &EmbeddingConfig) -> Result<Self> {
        let dylib_path = resolve_ort_dylib_path()?;

        let model_path = PathBuf::from(&config.model_path);
        let dimension = config.dimension;

        info!("Loading ONNX model from {:?}", model_path.join("model.onnx"));

        // Use spawn_blocking to avoid blocking the async runtime
        let (session, tokenizer) = tokio::task::spawn_blocking(move || -> Result<(Session, Tokenizer)> {
            // Initialize ONNX Runtime first
            init_ort_sync(&dylib_path)?;

            info!("Creating ONNX session...");

            // Load ONNX model with ort 2.0 API
            let session = Session::builder()
                .map_err(|e| anyhow::anyhow!("Failed to create session builder: {:?}", e))?
                .with_optimization_level(GraphOptimizationLevel::Level3)
                .map_err(|e| anyhow::anyhow!("Failed to set optimization level: {:?}", e))?
                .with_intra_threads(4)
                .map_err(|e| anyhow::anyhow!("Failed to set intra threads: {:?}", e))?
                .commit_from_file(model_path.join("model.onnx"))
                .map_err(|e| anyhow::anyhow!("Failed to load ONNX model: {:?}", e))?;

            info!("ONNX model loaded, loading tokenizer...");

            // Load tokenizer
            let tokenizer = Tokenizer::from_file(model_path.join("tokenizer.json"))
                .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;

            info!("Tokenizer loaded successfully");
            Ok((session, tokenizer))
        }).await
        // Outer ? is the JoinError, inner ? the loading Result.
        .map_err(|e| anyhow::anyhow!("Spawn blocking failed: {:?}", e))??;

        info!(
            "Embedding engine initialized: model={}, dimension={}",
            config.model_path, dimension
        );

        Ok(Self {
            session: std::sync::Mutex::new(session),
            tokenizer,
            dimension,
        })
    }

    /// Generate embedding for a single text
    ///
    /// Tokenizes `text`, runs the model, mean-pools the last hidden state,
    /// and L2-normalizes the result so cosine similarity can be computed as
    /// a dot product.
    pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let encoding = self.tokenizer
            .encode(text, true)
            .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;

        // The model expects i64 inputs; tokenizers produces u32.
        let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
        let attention_mask: Vec<i64> = encoding.get_attention_mask().iter().map(|&x| x as i64).collect();
        let token_type_ids: Vec<i64> = encoding.get_type_ids().iter().map(|&x| x as i64).collect();

        let seq_len = input_ids.len();

        // Create input tensors with ort 2.0 API (batch size fixed at 1)
        let input_ids_tensor = Value::from_array(([1, seq_len], input_ids))?;
        let attention_mask_tensor = Value::from_array(([1, seq_len], attention_mask))?;
        let token_type_ids_tensor = Value::from_array(([1, seq_len], token_type_ids))?;

        // Run inference
        let inputs = ort::inputs![
            "input_ids" => input_ids_tensor,
            "attention_mask" => attention_mask_tensor,
            "token_type_ids" => token_type_ids_tensor,
        ];

        // Session::run requires &mut, so access is serialized via the Mutex.
        let mut session_guard = self.session.lock()
            .map_err(|e| anyhow::anyhow!("Session lock poisoned: {}", e))?;
        let outputs = session_guard.run(inputs)?;

        // Extract output
        let output = outputs.get("last_hidden_state")
            .ok_or_else(|| anyhow::anyhow!("Missing last_hidden_state output"))?;

        // Get the tensor data as a flat f32 buffer plus its shape
        let (shape, data) = output.try_extract_tensor::<f32>()?;

        // Mean pooling over sequence dimension. Every position is weighted
        // equally; the attention mask is fed to the model but NOT used to
        // exclude positions here. NOTE(review): fine for single-text
        // encodes with no padding — revisit if batching/padding is added.
        let hidden_size = *shape.last().unwrap_or(&384) as usize;
        let seq_len = data.len() / hidden_size;

        let mut embedding = vec![0.0f32; hidden_size];
        for i in 0..seq_len {
            for j in 0..hidden_size {
                embedding[j] += data[i * hidden_size + j];
            }
        }
        for val in &mut embedding {
            *val /= seq_len as f32;
        }

        // L2 normalize (guarded against the all-zero vector)
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for val in &mut embedding {
                *val /= norm;
            }
        }

        Ok(embedding)
    }

    /// Generate embeddings for multiple texts
    ///
    /// Sequentially embeds each text; fails fast on the first error.
    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        texts.iter().map(|text| self.embed(text)).collect()
    }

    /// Get the embedding dimension
    pub fn dimension(&self) -> usize {
        self.dimension
    }
}
|
||||||
|
|
||||||
|
/// Extract keywords from text using simple frequency analysis
///
/// Words are lowercased, stripped of non-alphanumeric characters, filtered
/// against a small English stop-word list, and anything of length <= 2 is
/// dropped. The top `limit` words are returned ordered by descending
/// frequency, with alphabetical order as a deterministic tie-break (the
/// previous count-only sort left tied words in HashMap iteration order,
/// making the output nondeterministic across runs).
pub fn extract_keywords(text: &str, limit: usize) -> Vec<String> {
    use std::collections::HashMap;

    let stop_words: std::collections::HashSet<&str> = [
        "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
        "of", "with", "by", "from", "is", "are", "was", "were", "be", "been",
        "being", "have", "has", "had", "do", "does", "did", "will", "would",
        "could", "should", "may", "might", "must", "shall", "can", "this",
        "that", "these", "those", "i", "you", "he", "she", "it", "we", "they",
        "what", "which", "who", "whom", "whose", "where", "when", "why", "how",
        "all", "each", "every", "both", "few", "more", "most", "other", "some",
        "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too",
        "very", "just", "also", "now", "here", "there", "then", "once", "if",
    ].iter().cloned().collect();

    let mut word_counts: HashMap<String, usize> = HashMap::new();

    for word in text.split_whitespace() {
        // Keep only alphanumeric characters, then lowercase.
        let clean: String = word
            .chars()
            .filter(|c| c.is_alphanumeric())
            .collect::<String>()
            .to_lowercase();

        if clean.len() > 2 && !stop_words.contains(clean.as_str()) {
            *word_counts.entry(clean).or_insert(0) += 1;
        }
    }

    let mut sorted: Vec<_> = word_counts.into_iter().collect();
    // Count descending, then word ascending for a deterministic tie-break.
    sorted.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));

    sorted.into_iter()
        .take(limit)
        .map(|(word, _)| word)
        .collect()
}
|
||||||
150
src/lib.rs
Normal file
150
src/lib.rs
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
//! OpenBrain MCP - High-performance vector memory for AI agents
|
||||||
|
|
||||||
|
pub mod auth;
|
||||||
|
pub mod config;
|
||||||
|
pub mod db;
|
||||||
|
pub mod embedding;
|
||||||
|
pub mod migrations;
|
||||||
|
pub mod tools;
|
||||||
|
pub mod transport;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use axum::{Router, Json, http::StatusCode, middleware};
|
||||||
|
use serde_json::json;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
use tower_http::cors::{Any, CorsLayer};
|
||||||
|
use tower_http::trace::TraceLayer;
|
||||||
|
use tracing::{info, error};
|
||||||
|
|
||||||
|
use crate::auth::auth_middleware;
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::db::Database;
|
||||||
|
use crate::embedding::EmbeddingEngine;
|
||||||
|
use crate::transport::McpState;
|
||||||
|
|
||||||
|
/// Service readiness state
///
/// Tracks the background embedding-engine initialization; reported by the
/// `/ready` endpoint.
#[derive(Clone, Debug, PartialEq)]
pub enum ReadinessState {
    /// Embedding engine is still loading in the background.
    Initializing,
    /// Embedding engine loaded; the service can handle MCP requests.
    Ready,
    /// Initialization gave up after retries; carries the error text.
    Failed(String),
}
|
||||||
|
|
||||||
|
/// Shared application state
///
/// Wrapped in an `Arc` and shared between HTTP handlers, the auth
/// middleware, and the background initialization task.
pub struct AppState {
    /// Database connection pool handle.
    pub db: Database,
    /// Embedding engine; `None` until background initialization finishes.
    pub embedding: tokio::sync::RwLock<Option<Arc<EmbeddingEngine>>>,
    /// Loaded configuration snapshot.
    pub config: Config,
    /// Current readiness state, reported by `/ready`.
    pub readiness: tokio::sync::RwLock<ReadinessState>,
}
|
||||||
|
|
||||||
|
impl AppState {
|
||||||
|
/// Get embedding engine, returns None if not ready
|
||||||
|
pub async fn get_embedding(&self) -> Option<Arc<EmbeddingEngine>> {
|
||||||
|
self.embedding.read().await.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Health check endpoint - always returns OK if server is running
|
||||||
|
async fn health_handler() -> Json<serde_json::Value> {
|
||||||
|
Json(json!({"status": "ok"}))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Readiness endpoint - returns 503 if embedding not ready
|
||||||
|
async fn readiness_handler(
|
||||||
|
state: axum::extract::State<Arc<AppState>>,
|
||||||
|
) -> (StatusCode, Json<serde_json::Value>) {
|
||||||
|
let readiness = state.readiness.read().await.clone();
|
||||||
|
match readiness {
|
||||||
|
ReadinessState::Ready => (
|
||||||
|
StatusCode::OK,
|
||||||
|
Json(json!({"status": "ready", "embedding": true}))
|
||||||
|
),
|
||||||
|
ReadinessState::Initializing => (
|
||||||
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
Json(json!({"status": "initializing", "embedding": false}))
|
||||||
|
),
|
||||||
|
ReadinessState::Failed(err) => (
|
||||||
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
Json(json!({"status": "failed", "error": err}))
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the MCP server
///
/// Binds and serves immediately: `/health` and `/ready` are available while
/// the embedding engine loads in a background task with retry; MCP routes
/// are nested under `/mcp` behind the auth middleware.
pub async fn run_server(config: Config, db: Database) -> Result<()> {
    // Create state with None embedding (will init in background)
    let state = Arc::new(AppState {
        db,
        embedding: tokio::sync::RwLock::new(None),
        config: config.clone(),
        readiness: tokio::sync::RwLock::new(ReadinessState::Initializing),
    });

    // Spawn background task to initialize embedding with retry
    let state_clone = state.clone();
    let embedding_config = config.embedding.clone();
    tokio::spawn(async move {
        let max_retries = 3;
        let mut attempt = 0;

        loop {
            attempt += 1;
            info!("Initializing embedding engine (attempt {}/{})", attempt, max_retries);

            match EmbeddingEngine::new(&embedding_config).await {
                Ok(engine) => {
                    let engine = Arc::new(engine);
                    // Publish the engine first, then flip readiness.
                    *state_clone.embedding.write().await = Some(engine);
                    *state_clone.readiness.write().await = ReadinessState::Ready;
                    info!("Embedding engine initialized successfully");
                    break;
                }
                Err(e) => {
                    error!("Failed to init embedding (attempt {}): {:?}", attempt, e);
                    if attempt >= max_retries {
                        // Give up: /ready will report 503 with this error.
                        let err_msg = format!("Failed after {} attempts: {:?}", max_retries, e);
                        *state_clone.readiness.write().await = ReadinessState::Failed(err_msg);
                        break;
                    }
                    // Exponential backoff: 2s, 4s, 8s...
                    tokio::time::sleep(tokio::time::Duration::from_secs(2u64.pow(attempt))).await;
                }
            }
        }
    });

    // Create MCP state for SSE transport
    let mcp_state = McpState::new(state.clone());

    // Build router with health/readiness endpoints (no auth required)
    let health_router = Router::new()
        .route("/health", axum::routing::get(health_handler))
        .route("/ready", axum::routing::get(readiness_handler))
        .with_state(state.clone());

    // Build MCP router with auth middleware
    let mcp_router = transport::mcp_router(mcp_state)
        .layer(middleware::from_fn_with_state(state.clone(), auth_middleware));

    // Permissive CORS: any origin/method/header may call the server.
    let app = Router::new()
        .merge(health_router)
        .nest("/mcp", mcp_router)
        .layer(TraceLayer::new_for_http())
        .layer(
            CorsLayer::new()
                .allow_origin(Any)
                .allow_methods(Any)
                .allow_headers(Any),
        );

    // Start server immediately
    let bind_addr = format!("{}:{}", config.server.host, config.server.port);
    let listener = TcpListener::bind(&bind_addr).await?;
    info!("Server listening on {}", bind_addr);

    axum::serve(listener, app).await?;

    Ok(())
}
|
||||||
46
src/main.rs
Normal file
46
src/main.rs
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
//! OpenBrain MCP Server - High-performance vector memory for AI agents
|
||||||
|
//!
|
||||||
|
//! This is the main entry point for the OpenBrain MCP server.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use std::env;
|
||||||
|
use tracing::info;
|
||||||
|
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
||||||
|
|
||||||
|
use openbrain_mcp::{config::Config, db::Database, migrations, run_server};
|
||||||
|
|
||||||
|
#[tokio::main]
async fn main() -> Result<()> {
    // Initialize tracing: honor RUST_LOG when set, otherwise default to `info`.
    tracing_subscriber::registry()
        .with(EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()))
        .with(tracing_subscriber::fmt::layer())
        .init();

    info!("Starting OpenBrain MCP Server v{}", env!("CARGO_PKG_VERSION"));

    // Load configuration from environment variables (OPENBRAIN__* keys).
    let config = Config::load()?;
    info!("Configuration loaded from environment");

    // Optional subcommand: `migrate` applies pending schema migrations and
    // exits without starting the server; any other argument is rejected.
    match env::args().nth(1).as_deref() {
        Some("migrate") => {
            migrations::run(&config.database).await?;
            info!("Database migrations completed successfully");
            return Ok(());
        }
        Some(arg) => {
            anyhow::bail!("Unknown command: {arg}. Supported commands: migrate");
        }
        None => {}
    }

    // Initialize database connection pool
    let db = Database::new(&config.database).await?;
    info!("Database connection pool initialized");

    // Run the MCP server; blocks until the server terminates.
    run_server(config, db).await?;

    Ok(())
}
|
||||||
50
src/migrations.rs
Normal file
50
src/migrations.rs
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
//! Database migrations using refinery.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use refinery::embed_migrations;
|
||||||
|
use tokio_postgres::NoTls;
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
use crate::config::DatabaseConfig;
|
||||||
|
|
||||||
|
embed_migrations!("migrations");
|
||||||
|
|
||||||
|
/// Apply all pending database migrations.
|
||||||
|
pub async fn run(config: &DatabaseConfig) -> Result<()> {
|
||||||
|
let mut pg_config = tokio_postgres::Config::new();
|
||||||
|
pg_config.host(&config.host);
|
||||||
|
pg_config.port(config.port);
|
||||||
|
pg_config.dbname(&config.name);
|
||||||
|
pg_config.user(&config.user);
|
||||||
|
pg_config.password(&config.password);
|
||||||
|
|
||||||
|
let (mut client, connection) = pg_config
|
||||||
|
.connect(NoTls)
|
||||||
|
.await
|
||||||
|
.context("Failed to connect to database for migrations")?;
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = connection.await {
|
||||||
|
tracing::error!("Database migration connection error: {}", e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let report = migrations::runner()
|
||||||
|
.run_async(&mut client)
|
||||||
|
.await
|
||||||
|
.context("Failed to apply database migrations")?;
|
||||||
|
|
||||||
|
if report.applied_migrations().is_empty() {
|
||||||
|
info!("No database migrations to apply");
|
||||||
|
} else {
|
||||||
|
for migration in report.applied_migrations() {
|
||||||
|
info!(
|
||||||
|
version = migration.version(),
|
||||||
|
name = migration.name(),
|
||||||
|
"Applied database migration"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
106
src/tools/mod.rs
Normal file
106
src/tools/mod.rs
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
//! MCP Tools for OpenBrain
|
||||||
|
//!
|
||||||
|
//! Provides the core tools for memory storage and retrieval:
|
||||||
|
//! - `store`: Store a memory with automatic embedding generation
|
||||||
|
//! - `query`: Query memories by semantic similarity
|
||||||
|
//! - `purge`: Delete memories by agent_id or time range
|
||||||
|
|
||||||
|
pub mod query;
|
||||||
|
pub mod store;
|
||||||
|
pub mod purge;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::AppState;
|
||||||
|
|
||||||
|
/// Get all tool definitions for MCP tools/list
|
||||||
|
pub fn get_tool_definitions() -> Vec<Value> {
|
||||||
|
vec![
|
||||||
|
json!({
|
||||||
|
"name": "store",
|
||||||
|
"description": "Store a memory with automatic embedding generation and keyword extraction. The memory will be associated with the agent_id for isolated retrieval.",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"content": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The text content to store as a memory"
|
||||||
|
},
|
||||||
|
"agent_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Unique identifier for the agent storing the memory (default: 'default')"
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "Optional metadata to attach to the memory"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["content"]
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
json!({
|
||||||
|
"name": "query",
|
||||||
|
"description": "Query stored memories using semantic similarity search. Returns the most relevant memories based on the query text.",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The search query text"
|
||||||
|
},
|
||||||
|
"agent_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Agent ID to search within (default: 'default')"
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Maximum number of results to return (default: 10)"
|
||||||
|
},
|
||||||
|
"threshold": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "Minimum similarity threshold 0.0-1.0 (default: 0.5)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
json!({
|
||||||
|
"name": "purge",
|
||||||
|
"description": "Delete memories for an agent. Can delete all memories or those before a specific timestamp.",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"agent_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Agent ID whose memories to delete (required)"
|
||||||
|
},
|
||||||
|
"before": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional ISO8601 timestamp - delete memories created before this time"
|
||||||
|
},
|
||||||
|
"confirm": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Must be true to confirm deletion"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["agent_id", "confirm"]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute a tool by name with given arguments
|
||||||
|
pub async fn execute_tool(
|
||||||
|
state: &Arc<AppState>,
|
||||||
|
tool_name: &str,
|
||||||
|
arguments: Value,
|
||||||
|
) -> Result<String> {
|
||||||
|
match tool_name {
|
||||||
|
"store" => store::execute(state, arguments).await,
|
||||||
|
"query" => query::execute(state, arguments).await,
|
||||||
|
"purge" => purge::execute(state, arguments).await,
|
||||||
|
_ => anyhow::bail!("Unknown tool: {}", tool_name),
|
||||||
|
}
|
||||||
|
}
|
||||||
79
src/tools/purge.rs
Normal file
79
src/tools/purge.rs
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
//! Purge Tool - Delete memories by agent_id or time range
|
||||||
|
|
||||||
|
use anyhow::{bail, Context, Result};
|
||||||
|
use chrono::DateTime;
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tracing::{info, warn};
|
||||||
|
|
||||||
|
use crate::AppState;
|
||||||
|
|
||||||
|
/// Execute the purge tool
|
||||||
|
pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String> {
|
||||||
|
// Extract parameters
|
||||||
|
let agent_id = arguments
|
||||||
|
.get("agent_id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.context("Missing required parameter: agent_id")?;
|
||||||
|
|
||||||
|
let confirm = arguments
|
||||||
|
.get("confirm")
|
||||||
|
.and_then(|v| v.as_bool())
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
if !confirm {
|
||||||
|
bail!("Purge operation requires 'confirm: true' to proceed");
|
||||||
|
}
|
||||||
|
|
||||||
|
let before = arguments
|
||||||
|
.get("before")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| DateTime::parse_from_rfc3339(s))
|
||||||
|
.transpose()
|
||||||
|
.context("Invalid 'before' timestamp format - use ISO8601/RFC3339")?
|
||||||
|
.map(|dt| dt.with_timezone(&chrono::Utc));
|
||||||
|
|
||||||
|
// Get current count before purge
|
||||||
|
let count_before = state
|
||||||
|
.db
|
||||||
|
.count_memories(agent_id)
|
||||||
|
.await
|
||||||
|
.context("Failed to count memories")?;
|
||||||
|
|
||||||
|
if count_before == 0 {
|
||||||
|
info!("No memories found for agent '{}'", agent_id);
|
||||||
|
return Ok(serde_json::json!({
|
||||||
|
"success": true,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"deleted": 0,
|
||||||
|
"message": "No memories found to purge"
|
||||||
|
})
|
||||||
|
.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
warn!(
|
||||||
|
"Purging memories for agent '{}' (before={:?})",
|
||||||
|
agent_id, before
|
||||||
|
);
|
||||||
|
|
||||||
|
// Execute purge
|
||||||
|
let deleted = state
|
||||||
|
.db
|
||||||
|
.purge_memories(agent_id, before)
|
||||||
|
.await
|
||||||
|
.context("Failed to purge memories")?;
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Purged {} memories for agent '{}'",
|
||||||
|
deleted, agent_id
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(serde_json::json!({
|
||||||
|
"success": true,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"deleted": deleted,
|
||||||
|
"had_before_filter": before.is_some(),
|
||||||
|
"message": format!("Successfully purged {} memories", deleted)
|
||||||
|
})
|
||||||
|
.to_string())
|
||||||
|
}
|
||||||
81
src/tools/query.rs
Normal file
81
src/tools/query.rs
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
//! Query Tool - Search memories by semantic similarity
|
||||||
|
|
||||||
|
use anyhow::{Context, Result, anyhow};
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
use crate::AppState;
|
||||||
|
|
||||||
|
/// Execute the query tool
|
||||||
|
pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String> {
|
||||||
|
// Get embedding engine, return error if not ready
|
||||||
|
let embedding_engine = state
|
||||||
|
.get_embedding()
|
||||||
|
.await
|
||||||
|
.ok_or_else(|| anyhow!("Embedding engine not ready - service is still initializing"))?;
|
||||||
|
|
||||||
|
// Extract parameters
|
||||||
|
let query_text = arguments
|
||||||
|
.get("query")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.context("Missing required parameter: query")?;
|
||||||
|
|
||||||
|
let agent_id = arguments
|
||||||
|
.get("agent_id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("default");
|
||||||
|
|
||||||
|
let limit = arguments
|
||||||
|
.get("limit")
|
||||||
|
.and_then(|v| v.as_i64())
|
||||||
|
.unwrap_or(10);
|
||||||
|
|
||||||
|
let threshold = arguments
|
||||||
|
.get("threshold")
|
||||||
|
.and_then(|v| v.as_f64())
|
||||||
|
.unwrap_or(0.5) as f32;
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Querying memories for agent '{}': '{}' (limit={}, threshold={})",
|
||||||
|
agent_id, query_text, limit, threshold
|
||||||
|
);
|
||||||
|
|
||||||
|
// Generate embedding for query using Arc<EmbeddingEngine>
|
||||||
|
let query_embedding = embedding_engine
|
||||||
|
.embed(query_text)
|
||||||
|
.context("Failed to generate query embedding")?;
|
||||||
|
|
||||||
|
// Search database
|
||||||
|
let matches = state
|
||||||
|
.db
|
||||||
|
.query_memories(agent_id, &query_embedding, limit, threshold)
|
||||||
|
.await
|
||||||
|
.context("Failed to query memories")?;
|
||||||
|
|
||||||
|
info!("Found {} matching memories", matches.len());
|
||||||
|
|
||||||
|
// Format results
|
||||||
|
let results: Vec<Value> = matches
|
||||||
|
.iter()
|
||||||
|
.map(|m| {
|
||||||
|
serde_json::json!({
|
||||||
|
"id": m.record.id.to_string(),
|
||||||
|
"content": m.record.content,
|
||||||
|
"similarity": m.similarity,
|
||||||
|
"keywords": m.record.keywords,
|
||||||
|
"metadata": m.record.metadata,
|
||||||
|
"created_at": m.record.created_at.to_rfc3339()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(serde_json::json!({
|
||||||
|
"success": true,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"query": query_text,
|
||||||
|
"count": results.len(),
|
||||||
|
"results": results
|
||||||
|
})
|
||||||
|
.to_string())
|
||||||
|
}
|
||||||
66
src/tools/store.rs
Normal file
66
src/tools/store.rs
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
//! Store Tool - Store memories with automatic embeddings
|
||||||
|
|
||||||
|
use anyhow::{Context, Result, anyhow};
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
use crate::embedding::extract_keywords;
|
||||||
|
use crate::AppState;
|
||||||
|
|
||||||
|
/// Execute the store tool
|
||||||
|
pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String> {
|
||||||
|
// Get embedding engine, return error if not ready
|
||||||
|
let embedding_engine = state
|
||||||
|
.get_embedding()
|
||||||
|
.await
|
||||||
|
.ok_or_else(|| anyhow!("Embedding engine not ready - service is still initializing"))?;
|
||||||
|
|
||||||
|
// Extract parameters
|
||||||
|
let content = arguments
|
||||||
|
.get("content")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.context("Missing required parameter: content")?;
|
||||||
|
|
||||||
|
let agent_id = arguments
|
||||||
|
.get("agent_id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("default");
|
||||||
|
|
||||||
|
let metadata = arguments
|
||||||
|
.get("metadata")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or(serde_json::json!({}));
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Storing memory for agent '{}': {} chars",
|
||||||
|
agent_id,
|
||||||
|
content.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
// Generate embedding using Arc<EmbeddingEngine>
|
||||||
|
let embedding = embedding_engine
|
||||||
|
.embed(content)
|
||||||
|
.context("Failed to generate embedding")?;
|
||||||
|
|
||||||
|
// Extract keywords
|
||||||
|
let keywords = extract_keywords(content, 10);
|
||||||
|
|
||||||
|
// Store in database
|
||||||
|
let id = state
|
||||||
|
.db
|
||||||
|
.store_memory(agent_id, content, &embedding, &keywords, metadata)
|
||||||
|
.await
|
||||||
|
.context("Failed to store memory")?;
|
||||||
|
|
||||||
|
info!("Memory stored with ID: {}", id);
|
||||||
|
|
||||||
|
Ok(serde_json::json!({
|
||||||
|
"success": true,
|
||||||
|
"id": id.to_string(),
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"keywords": keywords,
|
||||||
|
"embedding_dimension": embedding.len()
|
||||||
|
})
|
||||||
|
.to_string())
|
||||||
|
}
|
||||||
531
src/transport.rs
Normal file
531
src/transport.rs
Normal file
@@ -0,0 +1,531 @@
|
|||||||
|
//! SSE Transport for MCP Protocol
|
||||||
|
//!
|
||||||
|
//! Implements Server-Sent Events transport for the Model Context Protocol.
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::{Query, State},
|
||||||
|
http::{HeaderMap, StatusCode},
|
||||||
|
response::{
|
||||||
|
IntoResponse, Response,
|
||||||
|
sse::{Event, KeepAlive, Sse},
|
||||||
|
},
|
||||||
|
routing::{get, post},
|
||||||
|
Json, Router,
|
||||||
|
};
|
||||||
|
use futures::stream::Stream;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::{
|
||||||
|
collections::HashMap,
|
||||||
|
convert::Infallible,
|
||||||
|
sync::Arc,
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
use tokio::sync::{RwLock, broadcast, mpsc};
|
||||||
|
use tracing::{error, info, warn};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::{AppState, auth, tools};
|
||||||
|
|
||||||
|
type SessionStore = RwLock<HashMap<String, mpsc::Sender<serde_json::Value>>>;
|
||||||
|
|
||||||
|
/// MCP Server State
///
/// Shared by the SSE and message handlers: wraps the application state,
/// a broadcast channel for server-wide events, and the per-session mpsc
/// senders used to route JSON-RPC responses back to SSE streams.
pub struct McpState {
    pub app: Arc<AppState>,
    pub event_tx: broadcast::Sender<McpEvent>,
    // session_id -> sender feeding that session's SSE stream.
    sessions: SessionStore,
}
|
||||||
|
|
||||||
|
impl McpState {
    /// Create the shared MCP state with an empty session table and a
    /// 100-event broadcast buffer (the initial receiver is discarded;
    /// each SSE connection subscribes its own).
    pub fn new(app: Arc<AppState>) -> Arc<Self> {
        let (event_tx, _) = broadcast::channel(100);
        Arc::new(Self {
            app,
            event_tx,
            sessions: RwLock::new(HashMap::new()),
        })
    }

    /// Register a session's response sender under its id.
    async fn insert_session(
        &self,
        session_id: String,
        tx: mpsc::Sender<serde_json::Value>,
    ) {
        self.sessions.write().await.insert(session_id, tx);
    }

    /// Drop a session's sender; its SSE stream observes channel closure.
    async fn remove_session(&self, session_id: &str) {
        self.sessions.write().await.remove(session_id);
    }

    /// Whether a session with this id is currently registered.
    async fn has_session(&self, session_id: &str) -> bool {
        self.sessions.read().await.contains_key(session_id)
    }

    /// Deliver a JSON-RPC response to a session's SSE stream.
    ///
    /// The sender is cloned out inside a block so the read lock is released
    /// before awaiting `send`. If the receiver is gone, the stale session
    /// entry is pruned and `Closed` is returned.
    async fn send_to_session(
        &self,
        session_id: &str,
        response: &JsonRpcResponse,
    ) -> Result<(), SessionSendError> {
        let tx = {
            let sessions = self.sessions.read().await;
            sessions
                .get(session_id)
                .cloned()
                .ok_or(SessionSendError::NotFound)?
        };

        // JsonRpcResponse derives Serialize with plain JSON types, so
        // serialization cannot fail in practice.
        let payload =
            serde_json::to_value(response).expect("serializing JSON-RPC response should succeed")

        if tx.send(payload).await.is_err() {
            self.remove_session(session_id).await;
            return Err(SessionSendError::Closed);
        }

        Ok(())
    }
}
|
||||||
|
|
||||||
|
/// Why delivering a response to a session's SSE stream failed.
enum SessionSendError {
    // No session registered under the given id.
    NotFound,
    // The session existed but its receiver was dropped; the entry has
    // been removed.
    Closed,
}
|
||||||
|
|
||||||
|
/// RAII guard that unregisters an SSE session when its stream is dropped.
struct SessionGuard {
    state: Arc<McpState>,
    session_id: String,
}

impl SessionGuard {
    fn new(state: Arc<McpState>, session_id: String) -> Self {
        Self { state, session_id }
    }
}

impl Drop for SessionGuard {
    // Removing the session requires async access to the session table, but
    // Drop is synchronous — so the removal is spawned onto the current
    // Tokio runtime when one exists. Outside a runtime the entry is not
    // removed here; send_to_session also prunes dead sessions on failed
    // sends, so the entry is cleaned up lazily in that case.
    fn drop(&mut self) {
        let state = self.state.clone();
        let session_id = self.session_id.clone();
        if let Ok(handle) = tokio::runtime::Handle::try_current() {
            handle.spawn(async move {
                state.remove_session(&session_id).await;
            });
        }
    }
}
|
||||||
|
|
||||||
|
/// MCP Event for SSE streaming
#[derive(Clone, Debug, Serialize)]
pub struct McpEvent {
    // SSE event id (`id:` field).
    pub id: String,
    // SSE event name (`event:` field).
    pub event_type: String,
    // JSON payload emitted as the event data.
    pub data: serde_json::Value,
}

/// MCP JSON-RPC Request
#[derive(Debug, Deserialize)]
pub struct JsonRpcRequest {
    pub jsonrpc: String,
    // Absent for notifications; present for calls expecting a response.
    #[serde(default)]
    pub id: Option<serde_json::Value>,
    pub method: String,
    #[serde(default)]
    pub params: serde_json::Value,
}

/// MCP JSON-RPC Response
#[derive(Debug, Serialize)]
pub struct JsonRpcResponse {
    pub jsonrpc: String,
    pub id: serde_json::Value,
    // Exactly one of `result` / `error` is serialized per JSON-RPC 2.0.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
}

/// JSON-RPC 2.0 error object (code / message / optional data).
#[derive(Debug, Serialize)]
pub struct JsonRpcError {
    pub code: i32,
    pub message: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}

/// Query string accepted by the POST message endpoint
/// (`?sessionId=...`, camelCase per the advertised SSE endpoint).
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "camelCase")]
struct PostMessageQuery {
    #[serde(default)]
    session_id: Option<String>,
}
|
||||||
|
|
||||||
|
/// Create the MCP router
///
/// Routes: GET /sse (event stream), POST /message (JSON-RPC messages),
/// GET /health (transport-level liveness).
pub fn mcp_router(state: Arc<McpState>) -> Router {
    Router::new()
        .route("/sse", get(sse_handler))
        .route("/message", post(message_handler))
        .route("/health", get(health_handler))
        .with_state(state)
}
|
||||||
|
|
||||||
|
/// SSE endpoint for streaming events
///
/// Registers a fresh session, emits the MCP `endpoint` event that points
/// the client at this session's message URL, pushes an initial tools list,
/// then multiplexes per-session JSON-RPC responses and server-wide
/// broadcast events until either channel closes. A SessionGuard removes
/// the session entry when the stream is dropped.
async fn sse_handler(
    State(state): State<Arc<McpState>>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let mut broadcast_rx = state.event_tx.subscribe();
    let (session_tx, mut session_rx) = mpsc::channel(32);
    let session_id = Uuid::new_v4().to_string();
    let endpoint = session_message_endpoint(&session_id);

    state.insert_session(session_id.clone(), session_tx).await;

    let stream = async_stream::stream! {
        // Unregisters the session when this stream is dropped.
        let _session_guard = SessionGuard::new(state.clone(), session_id.clone());

        // Send endpoint event (required by MCP SSE protocol)
        // This tells the client where to POST JSON-RPC messages for this session.
        yield Ok(Event::default()
            .event("endpoint")
            .data(endpoint));

        // Send initial tools list so Agent Zero knows what's available
        let tools_response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: serde_json::json!("initial-tools"),
            result: Some(serde_json::json!({
                "tools": tools::get_tool_definitions()
            })),
            error: None,
        };

        yield Ok(Event::default()
            .event("message")
            .json_data(&tools_response)
            .unwrap());

        loop {
            tokio::select! {
                // Responses routed to this specific session by message_handler.
                maybe_message = session_rx.recv() => {
                    match maybe_message {
                        Some(message) => {
                            yield Ok(Event::default()
                                .event("message")
                                .json_data(&message)
                                .unwrap());
                        }
                        // Sender dropped (session removed) -> end the stream.
                        None => break,
                    }
                }
                // Server-wide broadcast events.
                event = broadcast_rx.recv() => {
                    match event {
                        Ok(event) => {
                            yield Ok(Event::default()
                                .event(&event.event_type)
                                .id(&event.id)
                                .json_data(&event.data)
                                .unwrap());
                        }
                        // Slow consumer: some events were dropped; keep going.
                        Err(broadcast::error::RecvError::Lagged(n)) => {
                            warn!("SSE client lagged, missed {} events", n);
                        }
                        Err(broadcast::error::RecvError::Closed) => {
                            break;
                        }
                    }
                }
            }
        }
    };

    Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))
}
|
||||||
|
|
||||||
|
/// Message endpoint for JSON-RPC requests
///
/// With `?sessionId=...` the response is delivered over that session's SSE
/// stream (HTTP 202 on success); without one, the response body is returned
/// inline. Unknown sessions yield 404. An agent id derived from auth
/// headers is injected into tools/call arguments before dispatch.
async fn message_handler(
    State(state): State<Arc<McpState>>,
    Query(query): Query<PostMessageQuery>,
    headers: HeaderMap,
    Json(request): Json<JsonRpcRequest>,
) -> Response {
    info!("Received MCP request: {}", request.method)

    // Reject messages addressed to sessions we no longer track.
    if let Some(session_id) = query.session_id.as_deref() {
        if !state.has_session(session_id).await {
            return StatusCode::NOT_FOUND.into_response();
        }
    }

    let request = apply_request_context(request, &headers);
    let response = dispatch_request(&state, &request).await;

    match query.session_id.as_deref() {
        Some(session_id) => route_session_response(&state, session_id, response).await,
        // No session: respond inline, or 202 for notifications (no response).
        None => match response {
            Some(response) => Json(response).into_response(),
            None => StatusCode::ACCEPTED.into_response(),
        },
    }
}
|
||||||
|
|
||||||
|
/// Health check endpoint
|
||||||
|
async fn health_handler() -> Json<serde_json::Value> {
|
||||||
|
Json(serde_json::json!({
|
||||||
|
"status": "healthy",
|
||||||
|
"server": "openbrain-mcp",
|
||||||
|
"version": env!("CARGO_PKG_VERSION")
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build the per-session message endpoint advertised over SSE.
/// Uses the camelCase `sessionId` query parameter.
fn session_message_endpoint(session_id: &str) -> String {
    let mut endpoint = String::from("/mcp/message?sessionId=");
    endpoint.push_str(session_id);
    endpoint
}
|
||||||
|
|
||||||
|
/// Forward a dispatch result to a session's SSE stream.
///
/// Returns 202 when delivered (or when there is nothing to deliver, i.e. a
/// notification), 404 when the session is unknown, and 410 when the
/// session's stream has gone away.
async fn route_session_response(
    state: &Arc<McpState>,
    session_id: &str,
    response: Option<JsonRpcResponse>,
) -> Response {
    match response {
        Some(response) => match state.send_to_session(session_id, &response).await {
            Ok(()) => StatusCode::ACCEPTED.into_response(),
            Err(SessionSendError::NotFound) => StatusCode::NOT_FOUND.into_response(),
            Err(SessionSendError::Closed) => StatusCode::GONE.into_response(),
        },
        None => StatusCode::ACCEPTED.into_response(),
    }
}
|
||||||
|
|
||||||
|
/// Enrich an incoming request with context derived from HTTP headers.
///
/// Currently: when the auth layer exposes an agent id in the headers,
/// inject it into tools/call arguments (without overriding an explicit
/// agent_id supplied by the caller — see inject_agent_id).
fn apply_request_context(
    mut request: JsonRpcRequest,
    headers: &HeaderMap,
) -> JsonRpcRequest {
    if let Some(agent_id) = auth::get_optional_agent_id(headers) {
        inject_agent_id(&mut request, &agent_id);
    }

    request
}
|
||||||
|
|
||||||
|
fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
|
||||||
|
if request.method != "tools/call" {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if !request.params.is_object() {
|
||||||
|
request.params = serde_json::json!({});
|
||||||
|
}
|
||||||
|
|
||||||
|
let params = request
|
||||||
|
.params
|
||||||
|
.as_object_mut()
|
||||||
|
.expect("params should be an object");
|
||||||
|
let arguments = params
|
||||||
|
.entry("arguments")
|
||||||
|
.or_insert_with(|| serde_json::json!({}));
|
||||||
|
|
||||||
|
if !arguments.is_object() {
|
||||||
|
*arguments = serde_json::json!({});
|
||||||
|
}
|
||||||
|
|
||||||
|
arguments
|
||||||
|
.as_object_mut()
|
||||||
|
.expect("arguments should be an object")
|
||||||
|
.entry("agent_id".to_string())
|
||||||
|
.or_insert_with(|| serde_json::json!(agent_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Dispatch a JSON-RPC request to its handler.
///
/// Returns None for notifications (requests without an id) and for methods
/// that produce no response; otherwise a success or error envelope.
/// Unrecognized methods map to JSON-RPC -32601 (method not found).
async fn dispatch_request(
    state: &Arc<McpState>,
    request: &JsonRpcRequest,
) -> Option<JsonRpcResponse> {
    match request.method.as_str() {
        "initialize" => request
            .id
            .clone()
            .map(|id| success_response(id, initialize_result())),
        "ping" => request
            .id
            .clone()
            .map(|id| success_response(id, serde_json::json!({}))),
        "tools/list" => request.id.clone().map(|id| {
            success_response(
                id,
                serde_json::json!({
                    "tools": tools::get_tool_definitions()
                }),
            )
        }),
        "tools/call" => handle_tools_call(state, request).await,
        // Client acknowledgement after initialize; no response expected.
        "notifications/initialized" => {
            info!("Received MCP initialized notification");
            None
        }
        _ => request.id.clone().map(|id| {
            error_response(
                id,
                -32601,
                format!("Method not found: {}", request.method),
                None,
            )
        }),
    }
}
|
||||||
|
|
||||||
|
/// Result payload for the MCP `initialize` handshake: protocol version,
/// server identity, and advertised capabilities (static tool list —
/// `listChanged: false`).
fn initialize_result() -> serde_json::Value {
    serde_json::json!({
        "protocolVersion": "2024-11-05",
        "serverInfo": {
            "name": "openbrain-mcp",
            "version": env!("CARGO_PKG_VERSION")
        },
        "capabilities": {
            "tools": {
                "listChanged": false
            }
        }
    })
}
|
||||||
|
|
||||||
|
fn success_response(
|
||||||
|
id: serde_json::Value,
|
||||||
|
result: serde_json::Value,
|
||||||
|
) -> JsonRpcResponse {
|
||||||
|
JsonRpcResponse {
|
||||||
|
jsonrpc: "2.0".to_string(),
|
||||||
|
id,
|
||||||
|
result: Some(result),
|
||||||
|
error: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn error_response(
|
||||||
|
id: serde_json::Value,
|
||||||
|
code: i32,
|
||||||
|
message: String,
|
||||||
|
data: Option<serde_json::Value>,
|
||||||
|
) -> JsonRpcResponse {
|
||||||
|
JsonRpcResponse {
|
||||||
|
jsonrpc: "2.0".to_string(),
|
||||||
|
id,
|
||||||
|
result: None,
|
||||||
|
error: Some(JsonRpcError {
|
||||||
|
code,
|
||||||
|
message,
|
||||||
|
data,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle tools/call request
///
/// Executes the named tool and wraps its string output in the MCP `content`
/// shape. Tool failures become JSON-RPC error -32000 with the full anyhow
/// context chain as the message. For notifications (no id), results are
/// discarded and failures are only logged.
async fn handle_tools_call(
    state: &Arc<McpState>,
    request: &JsonRpcRequest,
) -> Option<JsonRpcResponse> {
    let params = &request.params;
    let tool_name = params.get("name").and_then(|v| v.as_str()).unwrap_or("");
    let arguments = params
        .get("arguments")
        .cloned()
        .unwrap_or(serde_json::json!({}));

    let result = match tools::execute_tool(&state.app, tool_name, arguments).await {
        Ok(result) => Ok(serde_json::json!({
            "content": [{
                "type": "text",
                "text": result
            }]
        })),
        Err(e) => {
            // {:#} renders the full anyhow context chain on one line.
            let full_error = format!("{:#}", e);
            error!("Tool execution error: {}", full_error);
            Err(JsonRpcError {
                code: -32000,
                message: full_error,
                data: None,
            })
        }
    };

    match (request.id.clone(), result) {
        (Some(id), Ok(result)) => Some(success_response(id, result)),
        (Some(id), Err(error)) => Some(JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id,
            result: None,
            error: Some(error),
        }),
        (None, Ok(_)) => None,
        (None, Err(error)) => {
            error!(
                "Tool execution failed for notification '{}': {}",
                request.method, error.message
            );
            None
        }
    }
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    // A header-derived agent id is injected when the call's arguments
    // lack one.
    #[test]
    fn injects_agent_id_when_missing_from_tool_arguments() {
        let mut request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(serde_json::json!("1")),
            method: "tools/call".to_string(),
            params: serde_json::json!({
                "name": "query",
                "arguments": {
                    "query": "editor preferences"
                }
            }),
        };

        inject_agent_id(&mut request, "agent-from-header");

        assert_eq!(
            request
                .params
                .get("arguments")
                .and_then(|value| value.get("agent_id"))
                .and_then(|value| value.as_str()),
            Some("agent-from-header")
        );
    }

    // An agent_id supplied explicitly by the caller wins over the
    // header-derived value.
    #[test]
    fn preserves_explicit_agent_id() {
        let mut request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(serde_json::json!("1")),
            method: "tools/call".to_string(),
            params: serde_json::json!({
                "name": "query",
                "arguments": {
                    "agent_id": "explicit-agent"
                }
            }),
        };

        inject_agent_id(&mut request, "agent-from-header");

        assert_eq!(
            request
                .params
                .get("arguments")
                .and_then(|value| value.get("agent_id"))
                .and_then(|value| value.as_str()),
            Some("explicit-agent")
        );
    }

    // The endpoint advertised to SSE clients must use camelCase
    // `sessionId`, matching PostMessageQuery's rename_all.
    #[test]
    fn session_endpoint_uses_camel_case_query_param() {
        assert_eq!(
            session_message_endpoint("abc123"),
            "/mcp/message?sessionId=abc123"
        );
    }
}
|
||||||
873
tests/e2e_mcp.rs
Normal file
873
tests/e2e_mcp.rs
Normal file
@@ -0,0 +1,873 @@
|
|||||||
|
use serde_json::{json, Value};
|
||||||
|
use std::process::{Command, Stdio};
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio_postgres::NoTls;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
/// Base URL of the server under test; overridable via `OPENBRAIN_E2E_BASE_URL`.
fn base_url() -> String {
    match std::env::var("OPENBRAIN_E2E_BASE_URL") {
        Ok(url) => url,
        Err(_) => String::from("http://127.0.0.1:3100"),
    }
}
|
||||||
|
|
||||||
|
/// Resolves the API key for authenticated requests.
///
/// Prefers `OPENBRAIN_E2E_API_KEY`; otherwise takes the first entry of the
/// comma-separated `OPENBRAIN__AUTH__API_KEYS`. Returns `None` when neither
/// yields a non-empty key.
fn api_key() -> Option<String> {
    let raw = std::env::var("OPENBRAIN_E2E_API_KEY")
        .or_else(|_| std::env::var("OPENBRAIN__AUTH__API_KEYS"))
        .ok()?;
    let first = raw.split(',').next().unwrap_or("").trim();
    if first.is_empty() {
        None
    } else {
        Some(first.to_string())
    }
}
|
||||||
|
|
||||||
|
/// Builds a libpq-style connection string from the same `OPENBRAIN__DATABASE__*`
/// environment variables the server reads, falling back to local defaults.
fn db_url() -> String {
    // Small helper: env var value or a fixed default.
    let env_or =
        |key: &str, default: &str| std::env::var(key).unwrap_or_else(|_| default.to_string());

    let host = env_or("OPENBRAIN__DATABASE__HOST", "localhost");
    let port = env_or("OPENBRAIN__DATABASE__PORT", "5432");
    let name = env_or("OPENBRAIN__DATABASE__NAME", "openbrain");
    let user = env_or("OPENBRAIN__DATABASE__USER", "openbrain_svc");
    let password = env_or("OPENBRAIN__DATABASE__PASSWORD", "your_secure_password_here");

    format!("host={host} port={port} dbname={name} user={user} password={password}")
}
|
||||||
|
|
||||||
|
/// Ensures the pgvector extension and the `memories` table/indexes exist
/// before any e2e test touches the database.
///
/// All DDL is idempotent (`IF NOT EXISTS`), so this is safe to call from
/// every test. Panics (failing the test) if Postgres is unreachable or
/// pgvector is not installed for the active PostgreSQL major version.
async fn ensure_schema() {
    let (client, connection) = tokio_postgres::connect(&db_url(), NoTls)
        .await
        .expect("connect to postgres for e2e schema setup");

    // tokio_postgres splits client and connection; the connection future
    // must be driven concurrently for queries on `client` to make progress.
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("postgres connection error: {e}");
        }
    });

    // `to_regtype('vector')` is NULL when the vector type (i.e. the pgvector
    // extension) is not installed in this database.
    let vector_exists = client
        .query_one("SELECT to_regtype('vector')::text", &[])
        .await
        .expect("query vector type availability")
        .get::<_, Option<String>>(0)
        .is_some();

    if !vector_exists {
        // Try to create the extension; failure here means the server-side
        // pgvector binaries are missing, which the tests cannot fix.
        if let Err(e) = client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[]).await {
            panic!(
                "pgvector extension is not available for this PostgreSQL instance: {e}. \
                 Install pgvector for your active PostgreSQL major version, then run: CREATE EXTENSION vector;"
            );
        }
    }

    // Schema mirrors the server's migrations: 384-dim embeddings (MiniLM)
    // with an HNSW cosine index for similarity search.
    client
        .batch_execute(
            r#"
            CREATE TABLE IF NOT EXISTS memories (
                id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                agent_id VARCHAR(255) NOT NULL,
                content TEXT NOT NULL,
                embedding vector(384) NOT NULL,
                keywords TEXT[] DEFAULT '{}',
                metadata JSONB DEFAULT '{}',
                created_at TIMESTAMPTZ DEFAULT NOW()
            );
            CREATE INDEX IF NOT EXISTS idx_memories_agent ON memories(agent_id);
            CREATE INDEX IF NOT EXISTS idx_memories_embedding ON memories
            USING hnsw (embedding vector_cosine_ops);
            "#,
        )
        .await
        .expect("create memories table/indexes for e2e");
}
|
||||||
|
|
||||||
|
async fn wait_until_ready(client: &reqwest::Client, base: &str) {
|
||||||
|
for _ in 0..60 {
|
||||||
|
let resp = client.get(format!("{base}/ready")).send().await;
|
||||||
|
if let Ok(resp) = resp {
|
||||||
|
if resp.status().is_success() {
|
||||||
|
let body: Value = resp.json().await.expect("/ready JSON response");
|
||||||
|
if body.get("status").and_then(Value::as_str) == Some("ready") {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
panic!("Server did not become ready at {base}/ready within timeout");
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn call_jsonrpc(client: &reqwest::Client, base: &str, request: Value) -> Value {
|
||||||
|
let mut req_builder = client
|
||||||
|
.post(format!("{base}/mcp/message"))
|
||||||
|
.json(&request);
|
||||||
|
|
||||||
|
// Add API key header if available
|
||||||
|
if let Some(key) = api_key() {
|
||||||
|
req_builder = req_builder.header("X-API-Key", key);
|
||||||
|
}
|
||||||
|
|
||||||
|
req_builder
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("JSON-RPC HTTP request")
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.expect("JSON-RPC response body")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Make an authenticated GET request to an MCP endpoint
|
||||||
|
async fn get_mcp_endpoint(client: &reqwest::Client, base: &str, path: &str) -> reqwest::Response {
|
||||||
|
let mut req_builder = client.get(format!("{base}{path}"));
|
||||||
|
|
||||||
|
if let Some(key) = api_key() {
|
||||||
|
req_builder = req_builder.header("X-API-Key", key);
|
||||||
|
}
|
||||||
|
|
||||||
|
req_builder.send().await.expect(&format!("GET {path}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reads the next Server-Sent Event from `response`, carrying partial data
/// across calls in `buffer`.
///
/// Returns `Some((event_type, data))` — `event_type` being the value of the
/// `event:` field (if any) and `data` the `data:` lines joined with `\n` —
/// or `None` when the stream ends before another complete event arrives.
async fn read_sse_event(
    response: &mut reqwest::Response,
    buffer: &mut String,
) -> Option<(Option<String>, String)> {
    loop {
        // Normalize CRLF so the blank-line event delimiter is found
        // regardless of the server's line-ending convention.
        *buffer = buffer.replace("\r\n", "\n");
        if let Some(idx) = buffer.find("\n\n") {
            let raw_event = buffer[..idx].to_string();
            // Keep everything after the delimiter for subsequent events.
            *buffer = buffer[idx + 2..].to_string();

            let mut event_type = None;
            let mut data_lines = Vec::new();
            for line in raw_event.lines() {
                if let Some(value) = line.strip_prefix("event:") {
                    event_type = Some(value.trim().to_string());
                } else if let Some(value) = line.strip_prefix("data:") {
                    // NOTE(review): the SSE spec strips at most one leading
                    // space after "data:"; trim_start is a pragmatic
                    // superset that is fine for these tests.
                    data_lines.push(value.trim_start().to_string());
                }
            }

            return Some((event_type, data_lines.join("\n")));
        }

        // No complete event buffered yet: pull the next chunk. A `None`
        // chunk means the stream closed, which `?` propagates as `None`.
        let chunk = response
            .chunk()
            .await
            .expect("read SSE chunk")?;
        buffer.push_str(std::str::from_utf8(&chunk).expect("SSE chunk should be valid UTF-8"));
    }
}
|
||||||
|
|
||||||
|
async fn call_tool(
|
||||||
|
client: &reqwest::Client,
|
||||||
|
base: &str,
|
||||||
|
tool_name: &str,
|
||||||
|
arguments: Value,
|
||||||
|
) -> Value {
|
||||||
|
let request = json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": Uuid::new_v4().to_string(),
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {
|
||||||
|
"name": tool_name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let response = call_jsonrpc(client, base, request).await;
|
||||||
|
|
||||||
|
if let Some(error) = response.get("error") {
|
||||||
|
panic!("tools/call for '{tool_name}' failed: {error}");
|
||||||
|
}
|
||||||
|
|
||||||
|
let text_payload = response
|
||||||
|
.get("result")
|
||||||
|
.and_then(|r| r.get("content"))
|
||||||
|
.and_then(Value::as_array)
|
||||||
|
.and_then(|arr| arr.first())
|
||||||
|
.and_then(|item| item.get("text"))
|
||||||
|
.and_then(Value::as_str)
|
||||||
|
.expect("result.content[0].text payload");
|
||||||
|
|
||||||
|
serde_json::from_str(text_payload).expect("tool text payload to be valid JSON")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Full happy-path lifecycle against a running server: purge for a clean
/// slate, store a unique memory, query it back semantically, purge again,
/// and confirm the memory is gone.
#[tokio::test]
async fn e2e_store_query_purge_roundtrip() {
    let base = base_url();
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(20))
        .build()
        .expect("reqwest client");

    ensure_schema().await;
    wait_until_ready(&client, &base).await;

    // Unique agent id and content so concurrent/previous runs cannot collide.
    let agent_id = format!("e2e-agent-{}", Uuid::new_v4());
    let memory_text = format!(
        "E2E memory {}: user prefers dark theme and vim bindings",
        Uuid::new_v4()
    );

    // Ensure clean slate for this test agent.
    let _ = call_tool(
        &client,
        &base,
        "purge",
        json!({ "agent_id": agent_id, "confirm": true }),
    )
    .await;

    let store_result = call_tool(
        &client,
        &base,
        "store",
        json!({
            "agent_id": agent_id,
            "content": memory_text,
            "metadata": { "source": "e2e-test", "suite": "store-query-purge" }
        }),
    )
    .await;

    assert_eq!(
        store_result.get("success").and_then(Value::as_bool),
        Some(true),
        "store should succeed"
    );

    // threshold 0.0 disables similarity filtering, so recall depends only on
    // agent scoping, not on embedding quality.
    let query_result = call_tool(
        &client,
        &base,
        "query",
        json!({
            "agent_id": agent_id,
            "query": "What are the user's editor preferences?",
            "limit": 5,
            "threshold": 0.0
        }),
    )
    .await;

    let count = query_result
        .get("count")
        .and_then(Value::as_u64)
        .expect("query.count");
    assert!(count >= 1, "query should return at least one stored memory");

    // Match on exact content, not just count, to prove it is *our* memory.
    let results = query_result
        .get("results")
        .and_then(Value::as_array)
        .expect("query.results");
    let found_stored_content = results.iter().any(|item| {
        item.get("content")
            .and_then(Value::as_str)
            .map(|content| content == memory_text)
            .unwrap_or(false)
    });
    assert!(
        found_stored_content,
        "query results should include the content stored by this test"
    );

    let purge_result = call_tool(
        &client,
        &base,
        "purge",
        json!({ "agent_id": agent_id, "confirm": true }),
    )
    .await;

    let deleted = purge_result
        .get("deleted")
        .and_then(Value::as_u64)
        .expect("purge.deleted");
    assert!(deleted >= 1, "purge should delete at least one memory");

    // Post-purge query must come back empty for this agent.
    let query_after_purge = call_tool(
        &client,
        &base,
        "query",
        json!({
            "agent_id": agent_id,
            "query": "dark theme vim bindings",
            "limit": 5,
            "threshold": 0.0
        }),
    )
    .await;

    assert_eq!(
        query_after_purge.get("count").and_then(Value::as_u64),
        Some(0),
        "query after purge should return no memories for this agent"
    );
}
|
||||||
|
|
||||||
|
/// `tools/list` must advertise the core tools (store/query/purge), and an
/// unknown method must produce JSON-RPC error -32601 (Method Not Found).
#[tokio::test]
async fn e2e_transport_tools_list_and_unknown_method() {
    let base = base_url();
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(20))
        .build()
        .expect("reqwest client");

    wait_until_ready(&client, &base).await;

    let list_response = call_jsonrpc(
        &client,
        &base,
        json!({
            "jsonrpc": "2.0",
            "id": "tools-list-1",
            "method": "tools/list",
            "params": {}
        }),
    )
    .await;

    let tools = list_response
        .get("result")
        .and_then(|r| r.get("tools"))
        .and_then(Value::as_array)
        .expect("tools/list result.tools");

    let tool_names: Vec<&str> = tools
        .iter()
        .filter_map(|t| t.get("name").and_then(Value::as_str))
        .collect();

    assert!(tool_names.contains(&"store"), "tools/list should include store");
    assert!(tool_names.contains(&"query"), "tools/list should include query");
    assert!(tool_names.contains(&"purge"), "tools/list should include purge");

    // A syntactically valid request for a nonexistent method must fail with
    // the standard JSON-RPC "Method not found" code, not a transport error.
    let unknown_response = call_jsonrpc(
        &client,
        &base,
        json!({
            "jsonrpc": "2.0",
            "id": "unknown-1",
            "method": "not/a/real/method",
            "params": {}
        }),
    )
    .await;

    assert_eq!(
        unknown_response
            .get("error")
            .and_then(|e| e.get("code"))
            .and_then(Value::as_i64),
        Some(-32601),
        "unknown method should return Method Not Found"
    );
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn e2e_purge_requires_confirm_flag() {
|
||||||
|
let base = base_url();
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(20))
|
||||||
|
.build()
|
||||||
|
.expect("reqwest client");
|
||||||
|
|
||||||
|
wait_until_ready(&client, &base).await;
|
||||||
|
|
||||||
|
let response = call_jsonrpc(
|
||||||
|
&client,
|
||||||
|
&base,
|
||||||
|
json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": "purge-confirm-1",
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {
|
||||||
|
"name": "purge",
|
||||||
|
"arguments": {
|
||||||
|
"agent_id": format!("e2e-agent-{}", Uuid::new_v4()),
|
||||||
|
"confirm": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let error_message = response
|
||||||
|
.get("error")
|
||||||
|
.and_then(|e| e.get("message"))
|
||||||
|
.and_then(Value::as_str)
|
||||||
|
.expect("purge without confirm should return JSON-RPC error");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
error_message.contains("confirm: true") || error_message.contains("confirm"),
|
||||||
|
"purge error should explain confirmation requirement"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Memories must be partitioned by `agent_id`: a query issued as agent A
/// must never surface content stored by agent B.
#[tokio::test]
async fn e2e_query_isolated_by_agent_id() {
    let base = base_url();
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(20))
        .build()
        .expect("reqwest client");

    ensure_schema().await;
    wait_until_ready(&client, &base).await;

    // Unique agents and contents so concurrent/previous runs cannot collide.
    let agent_a = format!("e2e-agent-a-{}", Uuid::new_v4());
    let agent_b = format!("e2e-agent-b-{}", Uuid::new_v4());
    let a_text = format!("A {} prefers dark mode", Uuid::new_v4());
    let b_text = format!("B {} prefers light mode", Uuid::new_v4());

    // Start from a clean slate for both agents.
    let _ = call_tool(&client, &base, "purge", json!({ "agent_id": agent_a, "confirm": true })).await;
    let _ = call_tool(&client, &base, "purge", json!({ "agent_id": agent_b, "confirm": true })).await;

    let _ = call_tool(
        &client,
        &base,
        "store",
        json!({ "agent_id": agent_a, "content": a_text, "metadata": {"suite": "agent-isolation"} }),
    )
    .await;

    let _ = call_tool(
        &client,
        &base,
        "store",
        json!({ "agent_id": agent_b, "content": b_text, "metadata": {"suite": "agent-isolation"} }),
    )
    .await;

    // Query as agent A with a permissive threshold so similarity scoring
    // cannot accidentally mask a cross-agent leak.
    let query_a = call_tool(
        &client,
        &base,
        "query",
        json!({
            "agent_id": agent_a,
            "query": "mode preference",
            "limit": 10,
            "threshold": 0.0
        }),
    )
    .await;

    let results = query_a
        .get("results")
        .and_then(Value::as_array)
        .expect("query results");

    let has_a = results.iter().any(|item| {
        item.get("content")
            .and_then(Value::as_str)
            .map(|s| s == a_text)
            .unwrap_or(false)
    });
    let has_b = results.iter().any(|item| {
        item.get("content")
            .and_then(Value::as_str)
            .map(|s| s == b_text)
            .unwrap_or(false)
    });

    assert!(has_a, "agent A query should include agent A memory");
    assert!(!has_b, "agent A query must not include agent B memory");

    // Best-effort cleanup so later runs start clean.
    let _ = call_tool(&client, &base, "purge", json!({ "agent_id": agent_a, "confirm": true })).await;
    let _ = call_tool(&client, &base, "purge", json!({ "agent_id": agent_b, "confirm": true })).await;
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn e2e_initialize_contract() {
|
||||||
|
let base = base_url();
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(20))
|
||||||
|
.build()
|
||||||
|
.expect("reqwest client");
|
||||||
|
|
||||||
|
let response = call_jsonrpc(
|
||||||
|
&client,
|
||||||
|
&base,
|
||||||
|
json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": "init-1",
|
||||||
|
"method": "initialize",
|
||||||
|
"params": {}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let result = response.get("result").expect("initialize result");
|
||||||
|
assert_eq!(
|
||||||
|
result.get("protocolVersion").and_then(Value::as_str),
|
||||||
|
Some("2024-11-05")
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
result
|
||||||
|
.get("serverInfo")
|
||||||
|
.and_then(|v| v.get("name"))
|
||||||
|
.and_then(Value::as_str),
|
||||||
|
Some("openbrain-mcp")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn e2e_initialized_notification_is_accepted() {
|
||||||
|
let base = base_url();
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(20))
|
||||||
|
.build()
|
||||||
|
.expect("reqwest client");
|
||||||
|
|
||||||
|
wait_until_ready(&client, &base).await;
|
||||||
|
|
||||||
|
let mut request = client
|
||||||
|
.post(format!("{base}/mcp/message"))
|
||||||
|
.json(&json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "notifications/initialized",
|
||||||
|
"params": {}
|
||||||
|
}));
|
||||||
|
|
||||||
|
if let Some(key) = api_key() {
|
||||||
|
request = request.header("X-API-Key", key);
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = request.send().await.expect("initialized notification request");
|
||||||
|
assert_eq!(
|
||||||
|
response.status(),
|
||||||
|
reqwest::StatusCode::ACCEPTED,
|
||||||
|
"notifications/initialized should be accepted without a JSON-RPC response body"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Exercises the SSE transport end to end: opening `/mcp/sse` must emit an
/// `endpoint` event advertising a session-bound POST URL, and a JSON-RPC
/// request POSTed to that URL must come back as a `message` event on the
/// same SSE stream.
#[tokio::test]
async fn e2e_sse_session_routes_posted_response() {
    let base = base_url();
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(20))
        .build()
        .expect("reqwest client");

    wait_until_ready(&client, &base).await;

    let mut sse_request = client
        .get(format!("{base}/mcp/sse"))
        .header("Accept", "text/event-stream");

    if let Some(key) = api_key() {
        sse_request = sse_request.header("X-API-Key", key);
    }

    let mut sse_response = sse_request.send().await.expect("GET /mcp/sse");
    assert_eq!(sse_response.status(), reqwest::StatusCode::OK);
    assert!(
        sse_response
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok())
            .map(|value| value.starts_with("text/event-stream"))
            .unwrap_or(false),
        "SSE endpoint should return text/event-stream"
    );

    // First event must be `endpoint`, carrying the session message URL.
    let mut buffer = String::new();
    let (event_type, endpoint) = tokio::time::timeout(
        Duration::from_secs(10),
        read_sse_event(&mut sse_response, &mut buffer),
    )
    .await
    .expect("timed out waiting for SSE endpoint event")
    .expect("SSE endpoint event");

    assert_eq!(event_type.as_deref(), Some("endpoint"));
    assert!(
        endpoint.contains("/mcp/message?sessionId="),
        "endpoint event should advertise a session-specific message URL"
    );

    // The advertised endpoint may be absolute or base-relative.
    let post_url = if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
        endpoint
    } else {
        format!("{base}{endpoint}")
    };

    let mut post_request = client
        .post(post_url)
        .json(&json!({
            "jsonrpc": "2.0",
            "id": "sse-tools-list-1",
            "method": "tools/list",
            "params": {}
        }));

    if let Some(key) = api_key() {
        post_request = post_request.header("X-API-Key", key);
    }

    // The POST itself only returns 202; the actual JSON-RPC response is
    // routed back over the SSE stream.
    let post_response = post_request.send().await.expect("POST session message");
    assert_eq!(
        post_response.status(),
        reqwest::StatusCode::ACCEPTED,
        "session-bound POST should be accepted and routed over SSE"
    );

    let (event_type, payload) = tokio::time::timeout(
        Duration::from_secs(10),
        read_sse_event(&mut sse_response, &mut buffer),
    )
    .await
    .expect("timed out waiting for SSE message event")
    .expect("SSE message event");

    assert_eq!(event_type.as_deref(), Some("message"));

    // The routed payload must be the response to *our* request (matching id)
    // and a real tools/list result.
    let message: Value = serde_json::from_str(&payload).expect("SSE payload should be valid JSON");
    assert_eq!(
        message.get("id").and_then(Value::as_str),
        Some("sse-tools-list-1")
    );
    assert!(
        message
            .get("result")
            .and_then(|value| value.get("tools"))
            .and_then(Value::as_array)
            .map(|tools| !tools.is_empty())
            .unwrap_or(false),
        "SSE-routed tools/list response should include tool definitions"
    );
}
|
||||||
|
|
||||||
|
/// `/health` (public liveness, no auth) and `/mcp/health` (transport health,
/// behind auth when enabled) must both return their expected payloads.
#[tokio::test]
async fn e2e_health_endpoints() {
    let base = base_url();
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(20))
        .build()
        .expect("reqwest client");

    // Root health endpoint - no auth required
    let root_health: Value = client
        .get(format!("{base}/health"))
        .send()
        .await
        .expect("GET /health")
        .json()
        .await
        .expect("/health JSON");

    assert_eq!(
        root_health.get("status").and_then(Value::as_str),
        Some("ok"),
        "/health should report server liveness"
    );

    // MCP health endpoint - requires auth if enabled
    let mcp_health: Value = get_mcp_endpoint(&client, &base, "/mcp/health")
        .await
        .json()
        .await
        .expect("/mcp/health JSON");

    assert_eq!(
        mcp_health.get("status").and_then(Value::as_str),
        Some("healthy"),
        "/mcp/health should report MCP transport health"
    );
    assert_eq!(
        mcp_health.get("server").and_then(Value::as_str),
        Some("openbrain-mcp")
    );
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn e2e_store_requires_content() {
|
||||||
|
let base = base_url();
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(20))
|
||||||
|
.build()
|
||||||
|
.expect("reqwest client");
|
||||||
|
|
||||||
|
wait_until_ready(&client, &base).await;
|
||||||
|
|
||||||
|
let response = call_jsonrpc(
|
||||||
|
&client,
|
||||||
|
&base,
|
||||||
|
json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": "store-missing-content-1",
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {
|
||||||
|
"name": "store",
|
||||||
|
"arguments": {
|
||||||
|
"agent_id": format!("e2e-agent-{}", Uuid::new_v4()),
|
||||||
|
"metadata": {"suite": "validation"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let message = response
|
||||||
|
.get("error")
|
||||||
|
.and_then(|e| e.get("message"))
|
||||||
|
.and_then(Value::as_str)
|
||||||
|
.expect("store missing content should return an error message");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
message.contains("Missing required parameter: content"),
|
||||||
|
"store validation should mention missing content"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn e2e_auth_rejection_without_key() {
|
||||||
|
// This test only runs when auth is expected to be enabled
|
||||||
|
let auth_enabled = std::env::var("OPENBRAIN__AUTH__ENABLED")
|
||||||
|
.map(|v| v == "true")
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
if !auth_enabled {
|
||||||
|
println!("Skipping auth rejection test - OPENBRAIN__AUTH__ENABLED is not true");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let base = base_url();
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(20))
|
||||||
|
.build()
|
||||||
|
.expect("reqwest client");
|
||||||
|
|
||||||
|
// Make request WITHOUT API key
|
||||||
|
let response = client
|
||||||
|
.post(format!("{base}/mcp/message"))
|
||||||
|
.json(&json!({
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": "auth-test-1",
|
||||||
|
"method": "tools/list",
|
||||||
|
"params": {}
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("HTTP request");
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
response.status().as_u16(),
|
||||||
|
401,
|
||||||
|
"Request without API key should return 401 Unauthorized"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Binds an ephemeral localhost port, reads its number, and releases it by
/// dropping the listener so a child process can bind it.
fn pick_free_port() -> u16 {
    let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("bind ephemeral port");
    let addr = listener.local_addr().expect("local addr");
    addr.port()
}
|
||||||
|
|
||||||
|
async fn wait_for_status(url: &str, expected_status: reqwest::StatusCode) {
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(2))
|
||||||
|
.build()
|
||||||
|
.expect("reqwest client");
|
||||||
|
|
||||||
|
for _ in 0..80 {
|
||||||
|
if let Ok(resp) = client.get(url).send().await {
|
||||||
|
if resp.status() == expected_status {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tokio::time::sleep(Duration::from_millis(250)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
panic!("Timed out waiting for {url} to return status {expected_status}");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Spawns a dedicated server instance with auth enabled and verifies that
/// (1) requests without a key get 401, and (2) both `X-API-Key` and
/// `Authorization: Bearer` forms of the configured key are accepted.
#[tokio::test]
async fn e2e_auth_enabled_accepts_test_key() {
    ensure_schema().await;

    // Private server on an ephemeral port so this test cannot interfere
    // with the shared instance used by the other e2e tests.
    let port = pick_free_port();
    let base = format!("http://127.0.0.1:{port}");
    let test_key = "e2e-test-key-123";

    let mut server = Command::new(env!("CARGO_BIN_EXE_openbrain-mcp"))
        .current_dir(env!("CARGO_MANIFEST_DIR"))
        .env("OPENBRAIN__SERVER__PORT", port.to_string())
        .env("OPENBRAIN__AUTH__ENABLED", "true")
        .env("OPENBRAIN__AUTH__API_KEYS", test_key)
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .spawn()
        .expect("spawn openbrain-mcp for auth-enabled e2e test");

    // NOTE(review): if any assertion below panics, the spawned server is
    // never killed and leaks past the test — consider a kill-on-drop guard.
    wait_for_status(&format!("{base}/ready"), reqwest::StatusCode::OK).await;

    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(20))
        .build()
        .expect("reqwest client");

    let request = json!({
        "jsonrpc": "2.0",
        "id": "auth-enabled-1",
        "method": "tools/list",
        "params": {}
    });

    // No key at all: must be rejected outright.
    let unauthorized = client
        .post(format!("{base}/mcp/message"))
        .json(&request)
        .send()
        .await
        .expect("unauthorized request");
    assert_eq!(
        unauthorized.status(),
        reqwest::StatusCode::UNAUTHORIZED,
        "request without key should be rejected when auth is enabled"
    );

    // X-API-Key header form.
    let authorized: Value = client
        .post(format!("{base}/mcp/message"))
        .header("X-API-Key", test_key)
        .json(&request)
        .send()
        .await
        .expect("authorized request")
        .json()
        .await
        .expect("authorized JSON response");

    assert!(authorized.get("error").is_none(), "valid key should not return JSON-RPC error");
    assert!(
        authorized
            .get("result")
            .and_then(|r| r.get("tools"))
            .and_then(Value::as_array)
            .map(|tools| !tools.is_empty())
            .unwrap_or(false),
        "authorized tools/list should return tool definitions"
    );

    // Authorization: Bearer header form with the same key.
    let bearer_authorized: Value = client
        .post(format!("{base}/mcp/message"))
        .header("Authorization", format!("Bearer {test_key}"))
        .json(&request)
        .send()
        .await
        .expect("bearer-authorized request")
        .json()
        .await
        .expect("bearer-authorized JSON response");

    assert!(
        bearer_authorized.get("error").is_none(),
        "valid bearer token should not return JSON-RPC error"
    );
    assert!(
        bearer_authorized
            .get("result")
            .and_then(|r| r.get("tools"))
            .and_then(Value::as_array)
            .map(|tools| !tools.is_empty())
            .unwrap_or(false),
        "authorized bearer tools/list should return tool definitions"
    );

    // Tear the private server down; errors here are non-fatal cleanup.
    let _ = server.kill();
    let _ = server.wait();
}
|
||||||
Reference in New Issue
Block a user