From 774982dc5a9a9c46e473fd8210e85eef54a7e549 Mon Sep 17 00:00:00 2001 From: Agent Zero Date: Sat, 7 Mar 2026 13:41:36 -0500 Subject: [PATCH] Initial public release --- .env.example | 28 + .gitea/deploy.sh | 213 +++++++ .gitea/download-model.sh | 92 +++ .gitea/openbrain.service | 31 + .gitea/workflows/ci-cd.yaml | 188 ++++++ .gitignore | 27 + Cargo.toml | 65 ++ README.md | 161 +++++ migrations/V1__baseline_memories.sql | 38 ++ src/auth.rs | 125 ++++ src/config.rs | 146 +++++ src/db.rs | 176 ++++++ src/embedding.rs | 245 ++++++++ src/lib.rs | 150 +++++ src/main.rs | 46 ++ src/migrations.rs | 50 ++ src/tools/mod.rs | 106 ++++ src/tools/purge.rs | 79 +++ src/tools/query.rs | 81 +++ src/tools/store.rs | 66 ++ src/transport.rs | 531 ++++++++++++++++ tests/e2e_mcp.rs | 873 +++++++++++++++++++++++++++ 22 files changed, 3517 insertions(+) create mode 100644 .env.example create mode 100755 .gitea/deploy.sh create mode 100755 .gitea/download-model.sh create mode 100644 .gitea/openbrain.service create mode 100644 .gitea/workflows/ci-cd.yaml create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 README.md create mode 100644 migrations/V1__baseline_memories.sql create mode 100644 src/auth.rs create mode 100644 src/config.rs create mode 100644 src/db.rs create mode 100644 src/embedding.rs create mode 100644 src/lib.rs create mode 100644 src/main.rs create mode 100644 src/migrations.rs create mode 100644 src/tools/mod.rs create mode 100644 src/tools/purge.rs create mode 100644 src/tools/query.rs create mode 100644 src/tools/store.rs create mode 100644 src/transport.rs create mode 100644 tests/e2e_mcp.rs diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..90ab331 --- /dev/null +++ b/.env.example @@ -0,0 +1,28 @@ +# OpenBrain MCP Server Configuration +# Copy this to .env and fill in your values + +# Server Configuration +OPENBRAIN__SERVER__HOST=0.0.0.0 +OPENBRAIN__SERVER__PORT=3100 + +# Database Configuration (PostgreSQL with 
pgvector) +# This role should own the OpenBrain database objects that migrations manage. +OPENBRAIN__DATABASE__HOST=localhost +OPENBRAIN__DATABASE__PORT=5432 +OPENBRAIN__DATABASE__NAME=openbrain +OPENBRAIN__DATABASE__USER=openbrain_svc +OPENBRAIN__DATABASE__PASSWORD=your_secure_password_here +OPENBRAIN__DATABASE__POOL_SIZE=10 + +# Embedding Configuration +# Path to ONNX model directory (all-MiniLM-L6-v2) +OPENBRAIN__EMBEDDING__MODEL_PATH=models/all-MiniLM-L6-v2 +OPENBRAIN__EMBEDDING__DIMENSION=384 + +# Authentication (optional) +OPENBRAIN__AUTH__ENABLED=false +# Comma-separated list of API keys +# OPENBRAIN__AUTH__API_KEYS=key1,key2,key3 + +# Logging +RUST_LOG=info,openbrain_mcp=debug diff --git a/.gitea/deploy.sh b/.gitea/deploy.sh new file mode 100755 index 0000000..227c040 --- /dev/null +++ b/.gitea/deploy.sh @@ -0,0 +1,213 @@ +#!/bin/bash +# +# OpenBrain MCP Deployment Script +# Deploys the OpenBrain MCP server to the VPS +# +# Usage: ./deploy.sh [options] +# Options: +# --build-local Build on local machine (requires cross-compilation) +# --build-remote Build on VPS (default) +# --skip-model Skip model download +# --restart-only Only restart the service +# + +set -euo pipefail + +# Configuration +VPS_HOST="${VPS_HOST:-}" +VPS_USER="${VPS_USER:-root}" +DEPLOY_DIR="/opt/openbrain-mcp" +SERVICE_NAME="openbrain-mcp" +SSH_KEY="${SSH_KEY:-/tmp/id_ed25519}" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +log_info() { echo -e "${GREEN}[INFO]${NC} $1"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } +log_error() { echo -e "${RED}[ERROR]${NC} $1"; } + +# Parse arguments +BUILD_REMOTE=true +SKIP_MODEL=false +RESTART_ONLY=false + +for arg in "$@"; do + case $arg in + --build-local) BUILD_REMOTE=false ;; + --build-remote) BUILD_REMOTE=true ;; + --skip-model) SKIP_MODEL=true ;; + --restart-only) RESTART_ONLY=true ;; + *) log_error "Unknown argument: $arg"; exit 1 ;; + esac +done + +# Get script directory 
(where .gitea folder is) +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" + +log_info "Project root: $PROJECT_ROOT" +if [ -z "$VPS_HOST" ]; then + log_error "VPS_HOST is required. Export VPS_HOST before running deploy.sh" + exit 1 +fi +log_info "Deploying to: $VPS_USER@$VPS_HOST:$DEPLOY_DIR" + +# SSH command helper +ssh_cmd() { + ssh -i "$SSH_KEY" -o StrictHostKeyChecking=no "$VPS_USER@$VPS_HOST" "$@" +} + +scp_cmd() { + scp -i "$SSH_KEY" -o StrictHostKeyChecking=no "$@" +} + +# Restart only mode +if [ "$RESTART_ONLY" = true ]; then + log_info "Restarting service only..." + ssh_cmd "systemctl restart $SERVICE_NAME" + ssh_cmd "systemctl status $SERVICE_NAME --no-pager" + exit 0 +fi + +# Step 1: Create deployment directory on VPS +log_info "Creating deployment directory on VPS..." +ssh_cmd "mkdir -p $DEPLOY_DIR/{src,models,logs,lib,.gitea}" + +# Step 2: Sync source code to VPS +log_info "Syncing source code to VPS..." +rsync -avz --delete \ + -e "ssh -i $SSH_KEY -o StrictHostKeyChecking=no" \ + --exclude 'target/' \ + --exclude '.git/' \ + --exclude '.a0proj/' \ + --exclude 'models/' \ + --exclude '*.md' \ + "$PROJECT_ROOT/" \ + "$VPS_USER@$VPS_HOST:$DEPLOY_DIR/" + +# Step 3: Copy .env if it doesn't exist on VPS +if ! ssh_cmd "test -f $DEPLOY_DIR/.env"; then + log_warn ".env not found on VPS. Copying .env.example..." + ssh_cmd "cp $DEPLOY_DIR/.env.example $DEPLOY_DIR/.env" + log_warn "Please edit $DEPLOY_DIR/.env on VPS with actual credentials!" +fi + +# Step 4: Download model if needed +if [ "$SKIP_MODEL" = false ]; then + log_info "Checking/downloading embedding model..." + ssh_cmd "bash $DEPLOY_DIR/.gitea/download-model.sh" +fi + +# Step 5: Build on VPS +if [ "$BUILD_REMOTE" = true ]; then + log_info "Building on VPS (this may take a while on first run)..." 
+ ssh_cmd "cd $DEPLOY_DIR && \ + source ~/.cargo/env 2>/dev/null || true && \ + cargo build --release 2>&1" +else + log_error "Local cross-compilation not yet implemented" + exit 1 +fi + +# Step 5b: Install the built binary where systemd expects it +log_info "Installing built binary..." +ssh_cmd "cp $DEPLOY_DIR/target/release/openbrain-mcp $DEPLOY_DIR/openbrain-mcp && chmod +x $DEPLOY_DIR/openbrain-mcp" + +# Step 5c: Bootstrap runtime prerequisites +log_info "Bootstrapping runtime prerequisites..." +ssh -i "$SSH_KEY" -o StrictHostKeyChecking=no "$VPS_USER@$VPS_HOST" \ + "DEPLOY_DIR=$DEPLOY_DIR SERVICE_USER=openbrain SERVICE_GROUP=openbrain ORT_VERSION=1.24.3 bash -s" <<'EOS' +set -euo pipefail + +DEPLOY_DIR="${DEPLOY_DIR:-/opt/openbrain-mcp}" +SERVICE_USER="${SERVICE_USER:-openbrain}" +SERVICE_GROUP="${SERVICE_GROUP:-openbrain}" +ORT_VERSION="${ORT_VERSION:-1.24.3}" + +if command -v apt-get >/dev/null 2>&1; then + export DEBIAN_FRONTEND=noninteractive + apt-get update + apt-get install -y --no-install-recommends ca-certificates curl tar libssl3 +fi + +if ! getent group "$SERVICE_GROUP" >/dev/null 2>&1; then + groupadd --system "$SERVICE_GROUP" +fi + +if ! id -u "$SERVICE_USER" >/dev/null 2>&1; then + useradd --system --gid "$SERVICE_GROUP" --home "$DEPLOY_DIR" --shell /usr/sbin/nologin "$SERVICE_USER" +fi + +install -d -m 0755 "$DEPLOY_DIR" "$DEPLOY_DIR/models" "$DEPLOY_DIR/logs" "$DEPLOY_DIR/lib" + +ARCH="$(uname -m)" +case "$ARCH" in + x86_64) ORT_ARCH="x64" ;; + aarch64|arm64) ORT_ARCH="aarch64" ;; + *) echo "Unsupported arch: $ARCH"; exit 1 ;; +esac + +if [[ ! 
-f "$DEPLOY_DIR/lib/libonnxruntime.so" ]]; then + TMP_DIR="$(mktemp -d)" + ORT_TGZ="onnxruntime-linux-${ORT_ARCH}-${ORT_VERSION}.tgz" + ORT_URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/${ORT_TGZ}" + curl -fL "$ORT_URL" -o "$TMP_DIR/$ORT_TGZ" + tar -xzf "$TMP_DIR/$ORT_TGZ" -C "$TMP_DIR" + ORT_ROOT="$TMP_DIR/onnxruntime-linux-${ORT_ARCH}-${ORT_VERSION}" + cp "$ORT_ROOT/lib/libonnxruntime.so" "$DEPLOY_DIR/lib/libonnxruntime.so" + cp "$ORT_ROOT/lib/libonnxruntime.so.${ORT_VERSION}" "$DEPLOY_DIR/lib/libonnxruntime.so.${ORT_VERSION}" || true + rm -rf "$TMP_DIR" +fi + +ENV_FILE="$DEPLOY_DIR/.env" +if [[ ! -f "$ENV_FILE" ]]; then + if [[ -f "$DEPLOY_DIR/.env.example" ]]; then + cp "$DEPLOY_DIR/.env.example" "$ENV_FILE" + else + touch "$ENV_FILE" + fi +fi + +upsert_env() { + local key="$1" + local value="$2" + if grep -qE "^${key}=" "$ENV_FILE"; then + sed -i "s|^${key}=.*|${key}=${value}|" "$ENV_FILE" + else + printf '%s=%s\n' "$key" "$value" >> "$ENV_FILE" + fi +} + +upsert_env "OPENBRAIN__EMBEDDING__MODEL_PATH" "$DEPLOY_DIR/models/all-MiniLM-L6-v2" +upsert_env "ORT_DYLIB_PATH" "$DEPLOY_DIR/lib/libonnxruntime.so" +upsert_env "OPENBRAIN__SERVER__HOST" "0.0.0.0" + +chmod +x "$DEPLOY_DIR/openbrain-mcp" "$DEPLOY_DIR/.gitea/download-model.sh" +chown -R "$SERVICE_USER:$SERVICE_GROUP" "$DEPLOY_DIR" +EOS + +# Step 5d: Run database migrations with the newly deployed binary +log_info "Running database migrations..." +ssh_cmd "cd $DEPLOY_DIR && ./openbrain-mcp migrate" + +# Step 6: Install systemd service +log_info "Installing systemd service..." +scp_cmd "$SCRIPT_DIR/openbrain.service" "$VPS_USER@$VPS_HOST:/etc/systemd/system/$SERVICE_NAME.service" +ssh_cmd "systemctl daemon-reload" +ssh_cmd "systemctl enable $SERVICE_NAME" + +# Step 7: Restart service +log_info "Restarting service..." +ssh_cmd "systemctl restart $SERVICE_NAME" +sleep 2 + +# Step 8: Check status +log_info "Checking service status..." 
+ssh_cmd "systemctl status $SERVICE_NAME --no-pager" || true + +log_info "Deployment complete!" +log_info "Service URL: http://$VPS_HOST:3100/mcp/health" diff --git a/.gitea/download-model.sh b/.gitea/download-model.sh new file mode 100755 index 0000000..d535686 --- /dev/null +++ b/.gitea/download-model.sh @@ -0,0 +1,92 @@ +#!/bin/bash +# +# Download ONNX embedding model for OpenBrain MCP +# Downloads all-MiniLM-L6-v2 from Hugging Face +# + +set -euo pipefail + +DEPLOY_DIR="${DEPLOY_DIR:-/opt/openbrain-mcp}" +MODEL_DIR="$DEPLOY_DIR/models/all-MiniLM-L6-v2" +MODEL_NAME="sentence-transformers/all-MiniLM-L6-v2" + +# Colors +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +log_info() { echo -e "${GREEN}[INFO]${NC} $1"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } + +# Check if model already exists +if [ -f "$MODEL_DIR/model.onnx" ] && [ -f "$MODEL_DIR/tokenizer.json" ]; then + log_info "Model already exists at $MODEL_DIR" + exit 0 +fi + +log_info "Downloading embedding model to $MODEL_DIR..." +mkdir -p "$MODEL_DIR" + +# Method 1: Try using huggingface-cli if available +if command -v huggingface-cli &> /dev/null; then + log_info "Using huggingface-cli to download model..." + huggingface-cli download "$MODEL_NAME" \ + --local-dir "$MODEL_DIR" \ + --include "*.onnx" "*.json" "*.txt" \ + --exclude "*.bin" "*.safetensors" "*.h5" +else + # Method 2: Direct download from Hugging Face + log_info "Downloading directly from Hugging Face..." + + BASE_URL="https://huggingface.co/$MODEL_NAME/resolve/main" + + # Download ONNX model (we need the optimized one) + # First try the onnx directory + ONNX_URL="https://huggingface.co/$MODEL_NAME/resolve/main/onnx/model.onnx" + + log_info "Downloading model.onnx..." + if ! curl -fSL "$ONNX_URL" -o "$MODEL_DIR/model.onnx" 2>/dev/null; then + # Fallback: convert from pytorch (requires python) + log_warn "ONNX model not found, will need to convert from PyTorch..." + log_warn "Installing optimum for ONNX export..." 
+ pip install --quiet optimum[exporters] onnx onnxruntime + + python3 << PYEOF +from optimum.onnxruntime import ORTModelForFeatureExtraction +from transformers import AutoTokenizer + +model = ORTModelForFeatureExtraction.from_pretrained("$MODEL_NAME", export=True) +tokenizer = AutoTokenizer.from_pretrained("$MODEL_NAME") + +model.save_pretrained("$MODEL_DIR") +tokenizer.save_pretrained("$MODEL_DIR") +print("Model exported to ONNX successfully!") +PYEOF + fi + + # Download tokenizer files + log_info "Downloading tokenizer.json..." + curl -fSL "$BASE_URL/tokenizer.json" -o "$MODEL_DIR/tokenizer.json" 2>/dev/null || true + + log_info "Downloading tokenizer_config.json..." + curl -fSL "$BASE_URL/tokenizer_config.json" -o "$MODEL_DIR/tokenizer_config.json" 2>/dev/null || true + + log_info "Downloading config.json..." + curl -fSL "$BASE_URL/config.json" -o "$MODEL_DIR/config.json" 2>/dev/null || true + + log_info "Downloading vocab.txt..." + curl -fSL "$BASE_URL/vocab.txt" -o "$MODEL_DIR/vocab.txt" 2>/dev/null || true + + log_info "Downloading special_tokens_map.json..." + curl -fSL "$BASE_URL/special_tokens_map.json" -o "$MODEL_DIR/special_tokens_map.json" 2>/dev/null || true +fi + +# Verify download +if [ -f "$MODEL_DIR/model.onnx" ]; then + MODEL_SIZE=$(du -h "$MODEL_DIR/model.onnx" | cut -f1) + log_info "Model downloaded successfully! 
Size: $MODEL_SIZE" + ls -la "$MODEL_DIR/" +else + log_warn "Warning: model.onnx not found after download" + exit 1 +fi diff --git a/.gitea/openbrain.service b/.gitea/openbrain.service new file mode 100644 index 0000000..36fd4e3 --- /dev/null +++ b/.gitea/openbrain.service @@ -0,0 +1,31 @@ +[Unit] +Description=OpenBrain MCP Server - Vector Memory for AI Agents +After=network-online.target postgresql.service +Wants=network-online.target postgresql.service + +[Service] +Type=simple +User=openbrain +Group=openbrain +WorkingDirectory=/opt/openbrain-mcp +EnvironmentFile=/opt/openbrain-mcp/.env +ExecStart=/opt/openbrain-mcp/openbrain-mcp +Restart=on-failure +RestartSec=5 +StandardOutput=journal +StandardError=journal +SyslogIdentifier=openbrain-mcp + +# Security hardening +NoNewPrivileges=true +PrivateTmp=true +ProtectSystem=strict +ProtectHome=true +ReadWritePaths=/opt/openbrain-mcp /opt/openbrain-mcp/logs /opt/openbrain-mcp/models /opt/openbrain-mcp/lib + +# Resource limits +LimitNOFILE=65535 +MemoryMax=1G + +[Install] +WantedBy=multi-user.target diff --git a/.gitea/workflows/ci-cd.yaml b/.gitea/workflows/ci-cd.yaml new file mode 100644 index 0000000..f657ee8 --- /dev/null +++ b/.gitea/workflows/ci-cd.yaml @@ -0,0 +1,188 @@ +name: OpenBrain MCP Build and Deploy + +on: + push: + branches: + - '**' + +jobs: + build-and-deploy: + runs-on: ubuntu-latest + env: + CARGO_TERM_COLOR: always + VPS_HOST: ${{ vars.VPS_HOST }} + VPS_USER: ${{ vars.VPS_USER }} + DEPLOY_DIR: /opt/openbrain-mcp + SERVICE_NAME: openbrain-mcp + steps: + - name: Install prerequisites + run: | + set -euxo pipefail + if command -v apt-get >/dev/null 2>&1; then + if command -v sudo >/dev/null 2>&1; then SUDO=sudo; else SUDO=; fi + $SUDO apt-get update + $SUDO apt-get install -y --no-install-recommends git ca-certificates curl + elif command -v apk >/dev/null 2>&1; then + apk add --no-cache git ca-certificates curl + fi + + - name: Checkout repository + run: | + set -euxo pipefail + git clone --depth 1 
"https://${{ github.token }}@gitea.ingwaz.work/${{ github.repository }}.git" . + git fetch origin "${{ github.ref }}" + git checkout FETCH_HEAD + + - name: Install build dependencies + run: | + set -euxo pipefail + if command -v sudo >/dev/null 2>&1; then SUDO=sudo; else SUDO=; fi + $SUDO apt-get install -y --no-install-recommends build-essential pkg-config libssl-dev openssh-client + + - name: Install Rust toolchain + run: | + set -euxo pipefail + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable + . "$HOME/.cargo/env" + rustc --version + cargo --version + + - name: CI checks + run: | + set -euxo pipefail + . "$HOME/.cargo/env" + cargo check + cargo test --no-run + + - name: Build release + run: | + set -euxo pipefail + . "$HOME/.cargo/env" + cargo build --release + test -x target/release/openbrain-mcp + + - name: Setup SSH auth + if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master' + run: | + set -euxo pipefail + : "${VPS_HOST:?Set repository variable VPS_HOST}" + : "${VPS_USER:=root}" + install -d -m 700 "$HOME/.ssh" + printf '%s\n' "${{ secrets.VPS_SSH_KEY }}" > "$HOME/.ssh/deploy_key" + chmod 600 "$HOME/.ssh/deploy_key" + ssh-keyscan -H "$VPS_HOST" >> "$HOME/.ssh/known_hosts" + + - name: Deploy artifacts + if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master' + run: | + set -euxo pipefail + : "${VPS_HOST:?Set repository variable VPS_HOST}" + : "${VPS_USER:=root}" + SSH="ssh -i $HOME/.ssh/deploy_key -o IdentitiesOnly=yes" + SCP="scp -i $HOME/.ssh/deploy_key -o IdentitiesOnly=yes" + + # Stop service before deploying to avoid "Text file busy" error + $SSH "$VPS_USER@$VPS_HOST" "systemctl stop $SERVICE_NAME 2>/dev/null || true" + + $SSH "$VPS_USER@$VPS_HOST" "mkdir -p $DEPLOY_DIR/.gitea $DEPLOY_DIR/models $DEPLOY_DIR/logs $DEPLOY_DIR/lib" + + $SCP target/release/openbrain-mcp "$VPS_USER@$VPS_HOST:$DEPLOY_DIR/openbrain-mcp" + $SCP .gitea/openbrain.service 
"$VPS_USER@$VPS_HOST:/etc/systemd/system/$SERVICE_NAME.service" + $SCP .gitea/download-model.sh "$VPS_USER@$VPS_HOST:$DEPLOY_DIR/.gitea/download-model.sh" + + - name: Bootstrap VPS and restart service + if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master' + run: | + set -euxo pipefail + : "${VPS_HOST:?Set repository variable VPS_HOST}" + : "${VPS_USER:=root}" + SSH="ssh -i $HOME/.ssh/deploy_key -o IdentitiesOnly=yes" + + $SSH "$VPS_USER@$VPS_HOST" "DEPLOY_DIR=$DEPLOY_DIR SERVICE_USER=openbrain SERVICE_GROUP=openbrain ORT_VERSION=1.24.3 bash -s" <<'EOS' + set -euo pipefail + DEPLOY_DIR="${DEPLOY_DIR:-/opt/openbrain-mcp}" + SERVICE_USER="${SERVICE_USER:-openbrain}" + SERVICE_GROUP="${SERVICE_GROUP:-openbrain}" + ORT_VERSION="${ORT_VERSION:-1.24.3}" + + if command -v apt-get >/dev/null 2>&1; then + export DEBIAN_FRONTEND=noninteractive + apt-get update + apt-get install -y --no-install-recommends ca-certificates curl tar libssl3 + fi + + if ! getent group "$SERVICE_GROUP" >/dev/null 2>&1; then + groupadd --system "$SERVICE_GROUP" + fi + if ! id -u "$SERVICE_USER" >/dev/null 2>&1; then + useradd --system --gid "$SERVICE_GROUP" --home "$DEPLOY_DIR" --shell /usr/sbin/nologin "$SERVICE_USER" + fi + + install -d -m 0755 "$DEPLOY_DIR" "$DEPLOY_DIR/models" "$DEPLOY_DIR/logs" "$DEPLOY_DIR/lib" + + ARCH="$(uname -m)" + case "$ARCH" in + x86_64) ORT_ARCH="x64" ;; + aarch64|arm64) ORT_ARCH="aarch64" ;; + *) echo "Unsupported arch: $ARCH"; exit 1 ;; + esac + + if [[ ! 
-f "$DEPLOY_DIR/lib/libonnxruntime.so" ]]; then + TMP_DIR="$(mktemp -d)" + ORT_TGZ="onnxruntime-linux-${ORT_ARCH}-${ORT_VERSION}.tgz" + ORT_URL="https://github.com/microsoft/onnxruntime/releases/download/v${ORT_VERSION}/${ORT_TGZ}" + curl -fL "$ORT_URL" -o "$TMP_DIR/$ORT_TGZ" + tar -xzf "$TMP_DIR/$ORT_TGZ" -C "$TMP_DIR" + ORT_ROOT="$TMP_DIR/onnxruntime-linux-${ORT_ARCH}-${ORT_VERSION}" + cp "$ORT_ROOT/lib/libonnxruntime.so" "$DEPLOY_DIR/lib/libonnxruntime.so" + cp "$ORT_ROOT/lib/libonnxruntime.so.${ORT_VERSION}" "$DEPLOY_DIR/lib/libonnxruntime.so.${ORT_VERSION}" || true + rm -rf "$TMP_DIR" + fi + + ENV_FILE="$DEPLOY_DIR/.env" + if [[ ! -f "$ENV_FILE" ]]; then + if [[ -f "$DEPLOY_DIR/.env.example" ]]; then cp "$DEPLOY_DIR/.env.example" "$ENV_FILE"; else touch "$ENV_FILE"; fi + fi + + upsert_env() { + local key="$1" + local value="$2" + if grep -qE "^${key}=" "$ENV_FILE"; then + sed -i "s|^${key}=.*|${key}=${value}|" "$ENV_FILE" + else + printf '%s=%s\n' "$key" "$value" >> "$ENV_FILE" + fi + } + + upsert_env "OPENBRAIN__EMBEDDING__MODEL_PATH" "$DEPLOY_DIR/models/all-MiniLM-L6-v2" + upsert_env "ORT_DYLIB_PATH" "$DEPLOY_DIR/lib/libonnxruntime.so" + upsert_env "OPENBRAIN__SERVER__HOST" "0.0.0.0" + + chmod +x "$DEPLOY_DIR/openbrain-mcp" "$DEPLOY_DIR/.gitea/download-model.sh" + chown -R "$SERVICE_USER:$SERVICE_GROUP" "$DEPLOY_DIR" + EOS + + $SSH "$VPS_USER@$VPS_HOST" "DEPLOY_DIR=$DEPLOY_DIR bash $DEPLOY_DIR/.gitea/download-model.sh" + $SSH "$VPS_USER@$VPS_HOST" "cd $DEPLOY_DIR && ./openbrain-mcp migrate" + + $SSH "$VPS_USER@$VPS_HOST" "systemctl daemon-reload" + $SSH "$VPS_USER@$VPS_HOST" "systemctl enable $SERVICE_NAME" + $SSH "$VPS_USER@$VPS_HOST" "systemctl restart $SERVICE_NAME" + + - name: Verify deployment + if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/master' + run: | + set -euxo pipefail + : "${VPS_HOST:?Set repository variable VPS_HOST}" + : "${VPS_USER:=root}" + SSH="ssh -i $HOME/.ssh/deploy_key -o IdentitiesOnly=yes" + + sleep 5 + $SSH 
"$VPS_USER@$VPS_HOST" "systemctl status $SERVICE_NAME --no-pager || journalctl -u $SERVICE_NAME --no-pager -n 80" + curl -fsS "http://$VPS_HOST:3100/health" + curl -fsS "http://$VPS_HOST:3100/ready" + + - name: Cleanup SSH key + if: always() + run: | + rm -f "$HOME/.ssh/deploy_key" diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e824fcb --- /dev/null +++ b/.gitignore @@ -0,0 +1,27 @@ +# Rust build artifacts +/target/ +**/*.rs.bk +Cargo.lock + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Environment +.env +.env.local + +# Logs +logs/*.log + +# Models (downloaded at runtime) +models/ + +# OS +.DS_Store +Thumbs.db + +.a0proj +CREDENTIALS.md diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..2a23700 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,65 @@ +[package] +name = "openbrain-mcp" +version = "0.1.0" +edition = "2021" +authors = ["Ingwaz "] +description = "High-performance vector memory MCP server for AI agents" +license = "MIT" +repository = "https://gitea.ingwaz.work/Ingwaz/openbrain-mcp" + +[dependencies] +# MCP Framework +rmcp = { version = "0.1", features = ["server", "transport-sse"] } + +# ONNX Runtime for local embeddings +ort = { version = "2.0.0-rc.12", features = ["load-dynamic"] } +ndarray = "0.16" +tokenizers = "0.21" + +# Database +tokio-postgres = { version = "0.7", features = ["with-uuid-1", "with-chrono-0_4", "with-serde_json-1"] } +deadpool-postgres = "0.14" +pgvector = { version = "0.4", features = ["postgres"] } +refinery = { version = "0.9", features = ["tokio-postgres"] } + +# HTTP Server +axum = { version = "0.8", features = ["macros"] } +axum-extra = { version = "0.10", features = ["typed-header"] } +tower = "0.5" +tower-http = { version = "0.6", features = ["cors", "trace"] } + +# Async Runtime +tokio = { version = "1.44", features = ["full"] } +futures = "0.3" +async-stream = "0.3" + +# Serialization +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" + +# Security +sha2 = "0.10" 
+hex = "0.4" + +# Utilities +uuid = { version = "1.16", features = ["v4", "serde"] } +chrono = { version = "0.4", features = ["serde"] } +thiserror = "2.0" +anyhow = "1.0" + +# Logging +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } + +# Configuration +config = "0.15" +dotenvy = "0.15" + +[profile.release] +lto = true +codegen-units = 1 +opt-level = 3 +strip = true + +[dev-dependencies] +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } diff --git a/README.md b/README.md new file mode 100644 index 0000000..5903a42 --- /dev/null +++ b/README.md @@ -0,0 +1,161 @@ +# OpenBrain MCP Server + +**High-performance vector memory for AI agents** + +OpenBrain is a Model Context Protocol (MCP) server that provides AI agents with a persistent, semantic memory system. It uses local ONNX-based embeddings and PostgreSQL with pgvector for efficient similarity search. + +## Features + +- 🧠 **Semantic Memory**: Store and retrieve memories using vector similarity search +- 🏠 **Local Embeddings**: No external API calls - uses ONNX runtime with all-MiniLM-L6-v2 +- 🐘 **PostgreSQL + pgvector**: Production-grade vector storage with HNSW indexing +- 🔌 **MCP Protocol**: Standard Model Context Protocol over SSE transport +- 🔐 **Multi-Agent Support**: Isolated memory namespaces per agent +- ⚡ **High Performance**: Rust implementation with async I/O + +## MCP Tools + +| Tool | Description | +|------|-------------| +| `store` | Store a memory with automatic embedding generation and keyword extraction | +| `query` | Search memories by semantic similarity | +| `purge` | Delete memories by agent ID or time range | + +## Quick Start + +### Prerequisites + +- Rust 1.75+ +- PostgreSQL 14+ with pgvector extension +- ONNX model files (all-MiniLM-L6-v2) + +### Database Setup + +```sql +CREATE ROLE openbrain_svc LOGIN PASSWORD 'change-me'; +CREATE DATABASE openbrain OWNER openbrain_svc; +\c openbrain +CREATE EXTENSION 
IF NOT EXISTS vector; +``` + +Use the same PostgreSQL role for the app and for migrations. Do not create the +`memories` table manually as `postgres` or another owner and then run +OpenBrain as `openbrain_svc`, because later `ALTER TABLE` migrations will fail +with `must be owner of table memories`. + +### Configuration + +```bash +cp .env.example .env +# Edit .env with your database credentials +``` + +### Build & Run + +```bash +cargo build --release +./target/release/openbrain-mcp migrate +./target/release/openbrain-mcp +``` + +### Database Migrations + +This project uses `refinery` with embedded SQL migrations in `migrations/`. + +Run pending migrations explicitly before starting or restarting the service: + +```bash +./target/release/openbrain-mcp migrate +``` + +If you use the deploy script or CI workflow in [`.gitea/deploy.sh`](.gitea/deploy.sh) and [`.gitea/workflows/ci-cd.yaml`](.gitea/workflows/ci-cd.yaml), they already run this for you.
+ +## MCP Integration + +Connect to the server using SSE transport: + +``` +SSE Endpoint: http://localhost:3100/mcp/sse +Message Endpoint: http://localhost:3100/mcp/message +Health Check: http://localhost:3100/mcp/health +``` + +### Example: Store a Memory + +```json +{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "store", + "arguments": { + "content": "The user prefers dark mode and uses vim keybindings", + "agent_id": "assistant-1", + "metadata": {"source": "preferences"} + } + } +} +``` + +### Example: Query Memories + +```json +{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "query", + "arguments": { + "query": "What are the user's editor preferences?", + "agent_id": "assistant-1", + "limit": 5, + "threshold": 0.6 + } + } +} +``` + +## Architecture + +``` +┌─────────────────────────────────────────────────────────┐ +│ AI Agent │ +└─────────────────────┬───────────────────────────────────┘ + │ MCP Protocol (SSE) +┌─────────────────────▼───────────────────────────────────┐ +│ OpenBrain MCP Server │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +│ │ store │ │ query │ │ purge │ │ +│ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │ +│ │ │ │ │ +│ ┌──────▼────────────────▼────────────────▼──────┐ │ +│ │ Embedding Engine (ONNX) │ │ +│ │ all-MiniLM-L6-v2 (384d) │ │ +│ └──────────────────────┬────────────────────────┘ │ +│ │ │ +│ ┌──────────────────────▼────────────────────────┐ │ +│ │ PostgreSQL + pgvector │ │ +│ │ HNSW Index for fast search │ │ +│ └────────────────────────────────────────────────┘ │ +└─────────────────────────────────────────────────────────┘ +``` + +## Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `OPENBRAIN__SERVER__HOST` | `0.0.0.0` | Server bind address | +| `OPENBRAIN__SERVER__PORT` | `3100` | Server port | +| `OPENBRAIN__DATABASE__HOST` | `localhost` | PostgreSQL host | +| `OPENBRAIN__DATABASE__PORT` | `5432` | 
PostgreSQL port | +| `OPENBRAIN__DATABASE__NAME` | `openbrain` | Database name | +| `OPENBRAIN__DATABASE__USER` | - | Database user | +| `OPENBRAIN__DATABASE__PASSWORD` | - | Database password | +| `OPENBRAIN__EMBEDDING__MODEL_PATH` | `models/all-MiniLM-L6-v2` | ONNX model path | +| `OPENBRAIN__AUTH__ENABLED` | `false` | Enable API key auth | + +## License + +MIT diff --git a/migrations/V1__baseline_memories.sql b/migrations/V1__baseline_memories.sql new file mode 100644 index 0000000..625cac7 --- /dev/null +++ b/migrations/V1__baseline_memories.sql @@ -0,0 +1,38 @@ +-- Run OpenBrain migrations as the same database role that owns these objects. +-- Existing installs with a differently owned memories table must transfer +-- ownership before this baseline migration can apply ALTER TABLE changes. + +CREATE TABLE IF NOT EXISTS memories ( + id UUID PRIMARY KEY, + agent_id VARCHAR(255) NOT NULL, + content TEXT NOT NULL, + embedding vector(384) NOT NULL, + keywords TEXT[] DEFAULT '{}', + metadata JSONB DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ DEFAULT NOW() +); + +ALTER TABLE memories + ALTER COLUMN agent_id TYPE VARCHAR(255); + +ALTER TABLE memories + ADD COLUMN IF NOT EXISTS keywords TEXT[] DEFAULT '{}'; + +ALTER TABLE memories + ADD COLUMN IF NOT EXISTS metadata JSONB DEFAULT '{}'::jsonb; + +ALTER TABLE memories + ADD COLUMN IF NOT EXISTS created_at TIMESTAMPTZ DEFAULT NOW(); + +ALTER TABLE memories + ALTER COLUMN keywords SET DEFAULT '{}'; + +ALTER TABLE memories + ALTER COLUMN metadata SET DEFAULT '{}'::jsonb; + +ALTER TABLE memories + ALTER COLUMN created_at SET DEFAULT NOW(); + +CREATE INDEX IF NOT EXISTS idx_memories_agent ON memories(agent_id); +CREATE INDEX IF NOT EXISTS idx_memories_embedding ON memories + USING hnsw (embedding vector_cosine_ops); diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..ffba822 --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,125 @@ +//! Authentication module for OpenBrain MCP +//! +//! 
Provides API key-based authentication for securing the MCP endpoints. + +use axum::{ + extract::{Request, State}, + http::{HeaderMap, StatusCode, header::AUTHORIZATION}, + middleware::Next, + response::Response, +}; +use sha2::{Digest, Sha256}; +use std::sync::Arc; +use tracing::warn; + +use crate::AppState; + +/// Hash an API key for secure comparison +pub fn hash_api_key(key: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(key.as_bytes()); + hex::encode(hasher.finalize()) +} + +/// Middleware for API key authentication +pub async fn auth_middleware( + State(state): State>, + request: Request, + next: Next, +) -> Result { + // Skip auth if disabled + if !state.config.auth.enabled { + return Ok(next.run(request).await); + } + + let api_key = extract_api_key(request.headers()); + + match api_key { + Some(key) => { + // Check if key is valid + let key_hash = hash_api_key(&key); + let valid = state + .config + .auth + .api_keys + .iter() + .any(|k| hash_api_key(k) == key_hash); + + if valid { + Ok(next.run(request).await) + } else { + warn!("Invalid API key or bearer token attempted"); + Err(StatusCode::UNAUTHORIZED) + } + } + None => { + warn!("Missing API key or bearer token in request"); + Err(StatusCode::UNAUTHORIZED) + } + } +} + +fn extract_api_key(headers: &HeaderMap) -> Option { + headers + .get("X-API-Key") + .and_then(|v| v.to_str().ok()) + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(ToOwned::to_owned) + .or_else(|| { + headers + .get(AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .and_then(|value| { + let (scheme, token) = value.split_once(' ')?; + scheme + .eq_ignore_ascii_case("bearer") + .then_some(token.trim()) + }) + .filter(|value| !value.is_empty()) + .map(ToOwned::to_owned) + }) +} + +pub fn get_optional_agent_id(headers: &HeaderMap) -> Option { + headers + .get("X-Agent-ID") + .and_then(|v| v.to_str().ok()) + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(ToOwned::to_owned) +} + +/// Extract agent 
ID from request headers or default +pub fn get_agent_id(request: &Request) -> String { + get_optional_agent_id(request.headers()) + .unwrap_or_else(|| "default".to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + use axum::http::{HeaderValue, header::AUTHORIZATION}; + + #[test] + fn extracts_api_key_from_bearer_header() { + let mut headers = HeaderMap::new(); + headers.insert( + AUTHORIZATION, + HeaderValue::from_static("Bearer test-token"), + ); + + assert_eq!(extract_api_key(&headers).as_deref(), Some("test-token")); + } + + #[test] + fn extracts_agent_id_from_header() { + let mut headers = HeaderMap::new(); + headers.insert("X-Agent-ID", HeaderValue::from_static("agent-zero")); + + assert_eq!( + get_optional_agent_id(&headers).as_deref(), + Some("agent-zero") + ); + } +} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..449be4f --- /dev/null +++ b/src/config.rs @@ -0,0 +1,146 @@ +//! Configuration management for OpenBrain MCP +//! +//! Loads configuration from environment variables with sensible defaults. 
+ +use anyhow::Result; +use serde::{Deserialize, Deserializer}; + +/// Main configuration structure +#[derive(Debug, Clone, Deserialize)] +pub struct Config { + pub server: ServerConfig, + pub database: DatabaseConfig, + pub embedding: EmbeddingConfig, + pub auth: AuthConfig, +} + +/// Server configuration +#[derive(Debug, Clone, Deserialize)] +pub struct ServerConfig { + #[serde(default = "default_host")] + pub host: String, + #[serde(default = "default_port")] + pub port: u16, +} + +/// Database configuration +#[derive(Debug, Clone, Deserialize)] +pub struct DatabaseConfig { + pub host: String, + #[serde(default = "default_db_port")] + pub port: u16, + pub name: String, + pub user: String, + pub password: String, + #[serde(default = "default_pool_size")] + pub pool_size: usize, +} + +/// Embedding engine configuration +#[derive(Debug, Clone, Deserialize)] +pub struct EmbeddingConfig { + #[serde(default = "default_model_path")] + pub model_path: String, + #[serde(default = "default_embedding_dim")] + pub dimension: usize, +} + +/// Authentication configuration +#[derive(Debug, Clone, Deserialize)] +pub struct AuthConfig { + #[serde(default = "default_auth_enabled")] + pub enabled: bool, + #[serde(default, deserialize_with = "deserialize_api_keys")] + pub api_keys: Vec, +} + +/// Deserialize API keys from either a comma-separated string or a Vec +fn deserialize_api_keys<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + // Try to deserialize as a string first, then as a Vec + #[derive(Deserialize)] + #[serde(untagged)] + enum StringOrVec { + String(String), + Vec(Vec), + } + + match Option::::deserialize(deserializer)? 
{ + Some(StringOrVec::String(s)) => { + Ok(s.split(',') + .map(|k| k.trim().to_string()) + .filter(|k| !k.is_empty()) + .collect()) + } + Some(StringOrVec::Vec(v)) => Ok(v), + None => Ok(Vec::new()), + } +} + +// Default value functions +fn default_host() -> String { "0.0.0.0".to_string() } +fn default_port() -> u16 { 3100 } +fn default_db_port() -> u16 { 5432 } +fn default_pool_size() -> usize { 10 } +fn default_model_path() -> String { "models/all-MiniLM-L6-v2".to_string() } +fn default_embedding_dim() -> usize { 384 } +fn default_auth_enabled() -> bool { false } + +impl Config { + /// Load configuration from environment variables + pub fn load() -> Result { + // Load .env file if present + dotenvy::dotenv().ok(); + + let config = config::Config::builder() + // Server settings + .set_default("server.host", default_host())? + .set_default("server.port", default_port() as i64)? + // Database settings + .set_default("database.port", default_db_port() as i64)? + .set_default("database.pool_size", default_pool_size() as i64)? + // Embedding settings + .set_default("embedding.model_path", default_model_path())? + .set_default("embedding.dimension", default_embedding_dim() as i64)? + // Auth settings + .set_default("auth.enabled", default_auth_enabled())? + // Load from environment with OPENBRAIN_ prefix + .add_source( + config::Environment::with_prefix("OPENBRAIN") + .separator("__") + .try_parsing(true), + ) + .build()?; + + Ok(config.try_deserialize()?) 
+ } +} + +impl Default for Config { + fn default() -> Self { + Self { + server: ServerConfig { + host: default_host(), + port: default_port(), + }, + database: DatabaseConfig { + host: "localhost".to_string(), + port: default_db_port(), + name: "openbrain".to_string(), + user: "openbrain_svc".to_string(), + password: String::new(), + pool_size: default_pool_size(), + }, + embedding: EmbeddingConfig { + model_path: default_model_path(), + dimension: default_embedding_dim(), + }, + auth: AuthConfig { + enabled: default_auth_enabled(), + api_keys: Vec::new(), + }, + } + } +} diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..e1ef1f7 --- /dev/null +++ b/src/db.rs @@ -0,0 +1,176 @@ +//! Database module for PostgreSQL with pgvector support +//! +//! Provides connection pooling and query helpers for vector operations. + +use anyhow::{Context, Result}; +use deadpool_postgres::{Config, Pool, Runtime}; +use pgvector::Vector; +use tokio_postgres::NoTls; +use tracing::info; +use uuid::Uuid; + +use crate::config::DatabaseConfig; + +/// Database wrapper with connection pool +#[derive(Clone)] +pub struct Database { + pool: Pool, +} + +/// A memory record stored in the database +#[derive(Debug, Clone)] +pub struct MemoryRecord { + pub id: Uuid, + pub agent_id: String, + pub content: String, + pub embedding: Vec, + pub keywords: Vec, + pub metadata: serde_json::Value, + pub created_at: chrono::DateTime, +} + +/// Query result with similarity score +#[derive(Debug, Clone)] +pub struct MemoryMatch { + pub record: MemoryRecord, + pub similarity: f32, +} + +impl Database { + /// Create a new database connection pool + pub async fn new(config: &DatabaseConfig) -> Result { + let mut cfg = Config::new(); + cfg.host = Some(config.host.clone()); + cfg.port = Some(config.port); + cfg.dbname = Some(config.name.clone()); + cfg.user = Some(config.user.clone()); + cfg.password = Some(config.password.clone()); + + let pool = cfg + .create_pool(Some(Runtime::Tokio1), NoTls) + 
.context("Failed to create database pool")?; + + // Test connection + let client = pool.get().await.context("Failed to get database connection")?; + client + .simple_query("SELECT 1") + .await + .context("Failed to execute test query")?; + + info!("Database connection pool created with {} connections", config.pool_size); + + Ok(Self { pool }) + } + + /// Store a memory record + pub async fn store_memory( + &self, + agent_id: &str, + content: &str, + embedding: &[f32], + keywords: &[String], + metadata: serde_json::Value, + ) -> Result { + let client = self.pool.get().await?; + let id = Uuid::new_v4(); + let vector = Vector::from(embedding.to_vec()); + + client + .execute( + r#" + INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata) + VALUES ($1, $2, $3, $4, $5, $6) + "#, + &[&id, &agent_id, &content, &vector, &keywords, &metadata], + ) + .await + .context("Failed to store memory")?; + + Ok(id) + } + + /// Query memories by vector similarity + pub async fn query_memories( + &self, + agent_id: &str, + embedding: &[f32], + limit: i64, + threshold: f32, + ) -> Result> { + let client = self.pool.get().await?; + let vector = Vector::from(embedding.to_vec()); + let threshold_f64 = threshold as f64; + + let rows = client + .query( + r#" + SELECT + id, agent_id, content, keywords, metadata, created_at, + (1 - (embedding <=> $1))::real AS similarity + FROM memories + WHERE agent_id = $2 + AND (1 - (embedding <=> $1)) >= $3 + ORDER BY embedding <=> $1 + LIMIT $4 + "#, + &[&vector, &agent_id, &threshold_f64, &limit], + ) + .await + .context("Failed to query memories")?; + + let matches = rows + .iter() + .map(|row| MemoryMatch { + record: MemoryRecord { + id: row.get("id"), + agent_id: row.get("agent_id"), + content: row.get("content"), + // Query responses do not include raw embedding payloads. 
+ embedding: Vec::new(), + keywords: row.get("keywords"), + metadata: row.get("metadata"), + created_at: row.get("created_at"), + }, + similarity: row.get("similarity"), + }) + .collect(); + + Ok(matches) + } + + /// Delete memories by agent_id and optional filters + pub async fn purge_memories( + &self, + agent_id: &str, + before: Option>, + ) -> Result { + let client = self.pool.get().await?; + + let count = if let Some(before_ts) = before { + client + .execute( + "DELETE FROM memories WHERE agent_id = $1 AND created_at < $2", + &[&agent_id, &before_ts], + ) + .await? + } else { + client + .execute("DELETE FROM memories WHERE agent_id = $1", &[&agent_id]) + .await? + }; + + Ok(count) + } + + /// Get memory count for an agent + pub async fn count_memories(&self, agent_id: &str) -> Result { + let client = self.pool.get().await?; + let row = client + .query_one( + "SELECT COUNT(*) as count FROM memories WHERE agent_id = $1", + &[&agent_id], + ) + .await?; + Ok(row.get("count")) + } +} diff --git a/src/embedding.rs b/src/embedding.rs new file mode 100644 index 0000000..ddd0d9f --- /dev/null +++ b/src/embedding.rs @@ -0,0 +1,245 @@ +//! 
Embedding engine using local ONNX models

use anyhow::Result;
use ort::session::{Session, builder::GraphOptimizationLevel};
use ort::value::Value;
use std::path::{Path, PathBuf};
use std::sync::OnceLock;
use tokenizers::Tokenizer;
use tracing::info;

use crate::config::EmbeddingConfig;

/// Memoized outcome of the one-time ONNX Runtime initialization.
///
/// `Ok(())` means the runtime committed successfully; `Err(msg)` records why
/// the single allowed initialization attempt failed, so every later caller
/// observes the same error rather than a silent success.
static ORT_INIT: OnceLock<Result<(), String>> = OnceLock::new();

/// Initialize ONNX Runtime synchronously (called inside spawn_blocking).
///
/// The runtime may only be initialized once per process, so the result of
/// the first attempt is cached in `ORT_INIT`.
///
/// Fixes a defect in the previous `Once`-based implementation: when the
/// first attempt failed, later calls found the `Once` already completed and
/// returned `Ok(())`, letting callers proceed with an uninitialized runtime
/// and fail later with an unrelated session-creation error. With the cached
/// `Result`, every caller sees the true outcome. (Note: a cached failure is
/// not retried — callers such as the startup retry loop will re-observe the
/// same error, which is the honest behavior since init is once-per-process.)
///
/// # Errors
/// Returns an error if `ort::init_from` fails or `commit()` reports failure.
fn init_ort_sync(dylib_path: &str) -> Result<()> {
    let outcome = ORT_INIT.get_or_init(|| {
        info!("Initializing ONNX Runtime from: {}", dylib_path);
        match ort::init_from(dylib_path) {
            Ok(builder) => {
                info!("ort::init_from succeeded, calling commit()");
                let committed = builder.commit();
                info!("commit() returned: {}", committed);
                if committed {
                    Ok(())
                } else {
                    Err("ONNX Runtime commit returned false".to_string())
                }
            }
            Err(e) => {
                let err_msg = format!("ONNX Runtime init_from failed: {:?}", e);
                info!("{}", err_msg);
                Err(err_msg)
            }
        }
    });

    match outcome {
        Ok(()) => {
            info!("ONNX Runtime initialization complete");
            Ok(())
        }
        // Report the cached failure to every caller, not just the first.
        Err(msg) => Err(anyhow::anyhow!("{}", msg)),
    }
}

/// Resolve ONNX Runtime dylib path from env var or common local install locations.
+fn resolve_ort_dylib_path() -> Result { + if let Ok(path) = std::env::var("ORT_DYLIB_PATH") { + if Path::new(&path).exists() { + return Ok(path); + } + + return Err(anyhow::anyhow!( + "ORT_DYLIB_PATH is set but file does not exist: {}", + path + )); + } + + let candidates = [ + "/opt/homebrew/opt/onnxruntime/lib/libonnxruntime.dylib", + "/usr/local/opt/onnxruntime/lib/libonnxruntime.dylib", + ]; + + for candidate in candidates { + if Path::new(candidate).exists() { + return Ok(candidate.to_string()); + } + } + + Err(anyhow::anyhow!( + "ORT_DYLIB_PATH environment variable not set and ONNX Runtime dylib not found. \ +Set ORT_DYLIB_PATH to your libonnxruntime.dylib path (for example: /opt/homebrew/opt/onnxruntime/lib/libonnxruntime.dylib)." + )) +} + +pub struct EmbeddingEngine { + session: std::sync::Mutex, + tokenizer: Tokenizer, + dimension: usize, +} + +impl EmbeddingEngine { + /// Create a new embedding engine + pub async fn new(config: &EmbeddingConfig) -> Result { + let dylib_path = resolve_ort_dylib_path()?; + + let model_path = PathBuf::from(&config.model_path); + let dimension = config.dimension; + + info!("Loading ONNX model from {:?}", model_path.join("model.onnx")); + + // Use spawn_blocking to avoid blocking the async runtime + let (session, tokenizer) = tokio::task::spawn_blocking(move || -> Result<(Session, Tokenizer)> { + // Initialize ONNX Runtime first + init_ort_sync(&dylib_path)?; + + info!("Creating ONNX session..."); + + // Load ONNX model with ort 2.0 API + let session = Session::builder() + .map_err(|e| anyhow::anyhow!("Failed to create session builder: {:?}", e))? + .with_optimization_level(GraphOptimizationLevel::Level3) + .map_err(|e| anyhow::anyhow!("Failed to set optimization level: {:?}", e))? + .with_intra_threads(4) + .map_err(|e| anyhow::anyhow!("Failed to set intra threads: {:?}", e))? 
+ .commit_from_file(model_path.join("model.onnx")) + .map_err(|e| anyhow::anyhow!("Failed to load ONNX model: {:?}", e))?; + + info!("ONNX model loaded, loading tokenizer..."); + + // Load tokenizer + let tokenizer = Tokenizer::from_file(model_path.join("tokenizer.json")) + .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?; + + info!("Tokenizer loaded successfully"); + Ok((session, tokenizer)) + }).await + .map_err(|e| anyhow::anyhow!("Spawn blocking failed: {:?}", e))??; + + info!( + "Embedding engine initialized: model={}, dimension={}", + config.model_path, dimension + ); + + Ok(Self { + session: std::sync::Mutex::new(session), + tokenizer, + dimension, + }) + } + + /// Generate embedding for a single text + pub fn embed(&self, text: &str) -> Result> { + let encoding = self.tokenizer + .encode(text, true) + .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?; + + let input_ids: Vec = encoding.get_ids().iter().map(|&x| x as i64).collect(); + let attention_mask: Vec = encoding.get_attention_mask().iter().map(|&x| x as i64).collect(); + let token_type_ids: Vec = encoding.get_type_ids().iter().map(|&x| x as i64).collect(); + + let seq_len = input_ids.len(); + + // Create input tensors with ort 2.0 API + let input_ids_tensor = Value::from_array(([1, seq_len], input_ids))?; + let attention_mask_tensor = Value::from_array(([1, seq_len], attention_mask))?; + let token_type_ids_tensor = Value::from_array(([1, seq_len], token_type_ids))?; + + // Run inference + let inputs = ort::inputs![ + "input_ids" => input_ids_tensor, + "attention_mask" => attention_mask_tensor, + "token_type_ids" => token_type_ids_tensor, + ]; + + let mut session_guard = self.session.lock() + .map_err(|e| anyhow::anyhow!("Session lock poisoned: {}", e))?; + let outputs = session_guard.run(inputs)?; + + // Extract output + let output = outputs.get("last_hidden_state") + .ok_or_else(|| anyhow::anyhow!("Missing last_hidden_state output"))?; + + // Get the tensor data + let 
(shape, data) = output.try_extract_tensor::()?; + + // Mean pooling over sequence dimension + let hidden_size = *shape.last().unwrap_or(&384) as usize; + let seq_len = data.len() / hidden_size; + + let mut embedding = vec![0.0f32; hidden_size]; + for i in 0..seq_len { + for j in 0..hidden_size { + embedding[j] += data[i * hidden_size + j]; + } + } + for val in &mut embedding { + *val /= seq_len as f32; + } + + // L2 normalize + let norm: f32 = embedding.iter().map(|x| x * x).sum::().sqrt(); + if norm > 0.0 { + for val in &mut embedding { + *val /= norm; + } + } + + Ok(embedding) + } + + /// Generate embeddings for multiple texts + pub fn embed_batch(&self, texts: &[&str]) -> Result>> { + texts.iter().map(|text| self.embed(text)).collect() + } + + /// Get the embedding dimension + pub fn dimension(&self) -> usize { + self.dimension + } +} + +/// Extract keywords from text using simple frequency analysis +pub fn extract_keywords(text: &str, limit: usize) -> Vec { + use std::collections::HashMap; + + let stop_words: std::collections::HashSet<&str> = [ + "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", + "of", "with", "by", "from", "is", "are", "was", "were", "be", "been", + "being", "have", "has", "had", "do", "does", "did", "will", "would", + "could", "should", "may", "might", "must", "shall", "can", "this", + "that", "these", "those", "i", "you", "he", "she", "it", "we", "they", + "what", "which", "who", "whom", "whose", "where", "when", "why", "how", + "all", "each", "every", "both", "few", "more", "most", "other", "some", + "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", + "very", "just", "also", "now", "here", "there", "then", "once", "if", + ].iter().cloned().collect(); + + let mut word_counts: HashMap = HashMap::new(); + + for word in text.split_whitespace() { + let clean: String = word + .chars() + .filter(|c| c.is_alphanumeric()) + .collect::() + .to_lowercase(); + + if clean.len() > 2 && 
!stop_words.contains(clean.as_str()) { + *word_counts.entry(clean).or_insert(0) += 1; + } + } + + let mut sorted: Vec<_> = word_counts.into_iter().collect(); + sorted.sort_by(|a, b| b.1.cmp(&a.1)); + + sorted.into_iter() + .take(limit) + .map(|(word, _)| word) + .collect() +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..26e8e90 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,150 @@ +//! OpenBrain MCP - High-performance vector memory for AI agents + +pub mod auth; +pub mod config; +pub mod db; +pub mod embedding; +pub mod migrations; +pub mod tools; +pub mod transport; + +use anyhow::Result; +use axum::{Router, Json, http::StatusCode, middleware}; +use serde_json::json; +use std::sync::Arc; +use tokio::net::TcpListener; +use tower_http::cors::{Any, CorsLayer}; +use tower_http::trace::TraceLayer; +use tracing::{info, error}; + +use crate::auth::auth_middleware; +use crate::config::Config; +use crate::db::Database; +use crate::embedding::EmbeddingEngine; +use crate::transport::McpState; + +/// Service readiness state +#[derive(Clone, Debug, PartialEq)] +pub enum ReadinessState { + Initializing, + Ready, + Failed(String), +} + +/// Shared application state +pub struct AppState { + pub db: Database, + pub embedding: tokio::sync::RwLock>>, + pub config: Config, + pub readiness: tokio::sync::RwLock, +} + +impl AppState { + /// Get embedding engine, returns None if not ready + pub async fn get_embedding(&self) -> Option> { + self.embedding.read().await.clone() + } +} + +/// Health check endpoint - always returns OK if server is running +async fn health_handler() -> Json { + Json(json!({"status": "ok"})) +} + +/// Readiness endpoint - returns 503 if embedding not ready +async fn readiness_handler( + state: axum::extract::State>, +) -> (StatusCode, Json) { + let readiness = state.readiness.read().await.clone(); + match readiness { + ReadinessState::Ready => ( + StatusCode::OK, + Json(json!({"status": "ready", "embedding": true})) + ), + 
ReadinessState::Initializing => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({"status": "initializing", "embedding": false})) + ), + ReadinessState::Failed(err) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({"status": "failed", "error": err})) + ), + } +} + +/// Run the MCP server +pub async fn run_server(config: Config, db: Database) -> Result<()> { + // Create state with None embedding (will init in background) + let state = Arc::new(AppState { + db, + embedding: tokio::sync::RwLock::new(None), + config: config.clone(), + readiness: tokio::sync::RwLock::new(ReadinessState::Initializing), + }); + + // Spawn background task to initialize embedding with retry + let state_clone = state.clone(); + let embedding_config = config.embedding.clone(); + tokio::spawn(async move { + let max_retries = 3; + let mut attempt = 0; + + loop { + attempt += 1; + info!("Initializing embedding engine (attempt {}/{})", attempt, max_retries); + + match EmbeddingEngine::new(&embedding_config).await { + Ok(engine) => { + let engine = Arc::new(engine); + *state_clone.embedding.write().await = Some(engine); + *state_clone.readiness.write().await = ReadinessState::Ready; + info!("Embedding engine initialized successfully"); + break; + } + Err(e) => { + error!("Failed to init embedding (attempt {}): {:?}", attempt, e); + if attempt >= max_retries { + let err_msg = format!("Failed after {} attempts: {:?}", max_retries, e); + *state_clone.readiness.write().await = ReadinessState::Failed(err_msg); + break; + } + // Exponential backoff: 2s, 4s, 8s... 
+ tokio::time::sleep(tokio::time::Duration::from_secs(2u64.pow(attempt))).await; + } + } + } + }); + + // Create MCP state for SSE transport + let mcp_state = McpState::new(state.clone()); + + // Build router with health/readiness endpoints (no auth required) + let health_router = Router::new() + .route("/health", axum::routing::get(health_handler)) + .route("/ready", axum::routing::get(readiness_handler)) + .with_state(state.clone()); + + // Build MCP router with auth middleware + let mcp_router = transport::mcp_router(mcp_state) + .layer(middleware::from_fn_with_state(state.clone(), auth_middleware)); + + let app = Router::new() + .merge(health_router) + .nest("/mcp", mcp_router) + .layer(TraceLayer::new_for_http()) + .layer( + CorsLayer::new() + .allow_origin(Any) + .allow_methods(Any) + .allow_headers(Any), + ); + + // Start server immediately + let bind_addr = format!("{}:{}", config.server.host, config.server.port); + let listener = TcpListener::bind(&bind_addr).await?; + info!("Server listening on {}", bind_addr); + + axum::serve(listener, app).await?; + + Ok(()) +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..28e18f6 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,46 @@ +//! OpenBrain MCP Server - High-performance vector memory for AI agents +//! +//! This is the main entry point for the OpenBrain MCP server. 
+ +use anyhow::Result; +use std::env; +use tracing::info; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; + +use openbrain_mcp::{config::Config, db::Database, migrations, run_server}; + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize tracing + tracing_subscriber::registry() + .with(EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into())) + .with(tracing_subscriber::fmt::layer()) + .init(); + + info!("Starting OpenBrain MCP Server v{}", env!("CARGO_PKG_VERSION")); + + // Load configuration + let config = Config::load()?; + info!("Configuration loaded from environment"); + + match env::args().nth(1).as_deref() { + Some("migrate") => { + migrations::run(&config.database).await?; + info!("Database migrations completed successfully"); + return Ok(()); + } + Some(arg) => { + anyhow::bail!("Unknown command: {arg}. Supported commands: migrate"); + } + None => {} + } + + // Initialize database connection pool + let db = Database::new(&config.database).await?; + info!("Database connection pool initialized"); + + // Run the MCP server + run_server(config, db).await?; + + Ok(()) +} diff --git a/src/migrations.rs b/src/migrations.rs new file mode 100644 index 0000000..46de87e --- /dev/null +++ b/src/migrations.rs @@ -0,0 +1,50 @@ +//! Database migrations using refinery. + +use anyhow::{Context, Result}; +use refinery::embed_migrations; +use tokio_postgres::NoTls; +use tracing::info; + +use crate::config::DatabaseConfig; + +embed_migrations!("migrations"); + +/// Apply all pending database migrations. 
+pub async fn run(config: &DatabaseConfig) -> Result<()> { + let mut pg_config = tokio_postgres::Config::new(); + pg_config.host(&config.host); + pg_config.port(config.port); + pg_config.dbname(&config.name); + pg_config.user(&config.user); + pg_config.password(&config.password); + + let (mut client, connection) = pg_config + .connect(NoTls) + .await + .context("Failed to connect to database for migrations")?; + + tokio::spawn(async move { + if let Err(e) = connection.await { + tracing::error!("Database migration connection error: {}", e); + } + }); + + let report = migrations::runner() + .run_async(&mut client) + .await + .context("Failed to apply database migrations")?; + + if report.applied_migrations().is_empty() { + info!("No database migrations to apply"); + } else { + for migration in report.applied_migrations() { + info!( + version = migration.version(), + name = migration.name(), + "Applied database migration" + ); + } + } + + Ok(()) +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs new file mode 100644 index 0000000..2fd9c18 --- /dev/null +++ b/src/tools/mod.rs @@ -0,0 +1,106 @@ +//! MCP Tools for OpenBrain +//! +//! Provides the core tools for memory storage and retrieval: +//! - `store`: Store a memory with automatic embedding generation +//! - `query`: Query memories by semantic similarity +//! - `purge`: Delete memories by agent_id or time range + +pub mod query; +pub mod store; +pub mod purge; + +use anyhow::Result; +use serde_json::{json, Value}; +use std::sync::Arc; + +use crate::AppState; + +/// Get all tool definitions for MCP tools/list +pub fn get_tool_definitions() -> Vec { + vec![ + json!({ + "name": "store", + "description": "Store a memory with automatic embedding generation and keyword extraction. 
The memory will be associated with the agent_id for isolated retrieval.", + "inputSchema": { + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "The text content to store as a memory" + }, + "agent_id": { + "type": "string", + "description": "Unique identifier for the agent storing the memory (default: 'default')" + }, + "metadata": { + "type": "object", + "description": "Optional metadata to attach to the memory" + } + }, + "required": ["content"] + } + }), + json!({ + "name": "query", + "description": "Query stored memories using semantic similarity search. Returns the most relevant memories based on the query text.", + "inputSchema": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query text" + }, + "agent_id": { + "type": "string", + "description": "Agent ID to search within (default: 'default')" + }, + "limit": { + "type": "integer", + "description": "Maximum number of results to return (default: 10)" + }, + "threshold": { + "type": "number", + "description": "Minimum similarity threshold 0.0-1.0 (default: 0.5)" + } + }, + "required": ["query"] + } + }), + json!({ + "name": "purge", + "description": "Delete memories for an agent. 
Can delete all memories or those before a specific timestamp.", + "inputSchema": { + "type": "object", + "properties": { + "agent_id": { + "type": "string", + "description": "Agent ID whose memories to delete (required)" + }, + "before": { + "type": "string", + "description": "Optional ISO8601 timestamp - delete memories created before this time" + }, + "confirm": { + "type": "boolean", + "description": "Must be true to confirm deletion" + } + }, + "required": ["agent_id", "confirm"] + } + }) + ] +} + +/// Execute a tool by name with given arguments +pub async fn execute_tool( + state: &Arc, + tool_name: &str, + arguments: Value, +) -> Result { + match tool_name { + "store" => store::execute(state, arguments).await, + "query" => query::execute(state, arguments).await, + "purge" => purge::execute(state, arguments).await, + _ => anyhow::bail!("Unknown tool: {}", tool_name), + } +} diff --git a/src/tools/purge.rs b/src/tools/purge.rs new file mode 100644 index 0000000..d71017b --- /dev/null +++ b/src/tools/purge.rs @@ -0,0 +1,79 @@ +//! Purge Tool - Delete memories by agent_id or time range + +use anyhow::{bail, Context, Result}; +use chrono::DateTime; +use serde_json::Value; +use std::sync::Arc; +use tracing::{info, warn}; + +use crate::AppState; + +/// Execute the purge tool +pub async fn execute(state: &Arc, arguments: Value) -> Result { + // Extract parameters + let agent_id = arguments + .get("agent_id") + .and_then(|v| v.as_str()) + .context("Missing required parameter: agent_id")?; + + let confirm = arguments + .get("confirm") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + if !confirm { + bail!("Purge operation requires 'confirm: true' to proceed"); + } + + let before = arguments + .get("before") + .and_then(|v| v.as_str()) + .map(|s| DateTime::parse_from_rfc3339(s)) + .transpose() + .context("Invalid 'before' timestamp format - use ISO8601/RFC3339")? 
+ .map(|dt| dt.with_timezone(&chrono::Utc)); + + // Get current count before purge + let count_before = state + .db + .count_memories(agent_id) + .await + .context("Failed to count memories")?; + + if count_before == 0 { + info!("No memories found for agent '{}'", agent_id); + return Ok(serde_json::json!({ + "success": true, + "agent_id": agent_id, + "deleted": 0, + "message": "No memories found to purge" + }) + .to_string()); + } + + warn!( + "Purging memories for agent '{}' (before={:?})", + agent_id, before + ); + + // Execute purge + let deleted = state + .db + .purge_memories(agent_id, before) + .await + .context("Failed to purge memories")?; + + info!( + "Purged {} memories for agent '{}'", + deleted, agent_id + ); + + Ok(serde_json::json!({ + "success": true, + "agent_id": agent_id, + "deleted": deleted, + "had_before_filter": before.is_some(), + "message": format!("Successfully purged {} memories", deleted) + }) + .to_string()) +} diff --git a/src/tools/query.rs b/src/tools/query.rs new file mode 100644 index 0000000..c9a1a5a --- /dev/null +++ b/src/tools/query.rs @@ -0,0 +1,81 @@ +//! 
Query Tool - Search memories by semantic similarity

use anyhow::{Context, Result, anyhow};
use serde_json::Value;
use std::sync::Arc;
use tracing::info;

use crate::AppState;

/// Execute the query tool.
///
/// Expected arguments (JSON object):
/// - `query` (string, required): the search text.
/// - `agent_id` (string, default `"default"`): which agent's memories to search.
/// - `limit` (integer, default 10): maximum number of matches to return.
/// - `threshold` (number, default 0.5): minimum similarity for a match.
///
/// Returns a JSON string describing the matches, or an error when the
/// embedding engine is still initializing, `query` is missing, or the
/// embedding/database calls fail.
pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String> {
    // The embedding engine loads in the background at startup; refuse
    // queries until it is available.
    let engine = state
        .get_embedding()
        .await
        .ok_or_else(|| anyhow!("Embedding engine not ready - service is still initializing"))?;

    let text = arguments
        .get("query")
        .and_then(Value::as_str)
        .context("Missing required parameter: query")?;

    let agent = arguments
        .get("agent_id")
        .and_then(Value::as_str)
        .unwrap_or("default");

    let max_results = arguments.get("limit").and_then(Value::as_i64).unwrap_or(10);

    let min_similarity = arguments
        .get("threshold")
        .and_then(Value::as_f64)
        .unwrap_or(0.5) as f32;

    info!(
        "Querying memories for agent '{}': '{}' (limit={}, threshold={})",
        agent, text, max_results, min_similarity
    );

    // Embed the query text, then run the vector similarity search.
    let query_vector = engine
        .embed(text)
        .context("Failed to generate query embedding")?;

    let matches = state
        .db
        .query_memories(agent, &query_vector, max_results, min_similarity)
        .await
        .context("Failed to query memories")?;

    info!("Found {} matching memories", matches.len());

    // Shape each match for the client response.
    let mut results = Vec::with_capacity(matches.len());
    for m in &matches {
        results.push(serde_json::json!({
            "id": m.record.id.to_string(),
            "content": m.record.content,
            "similarity": m.similarity,
            "keywords": m.record.keywords,
            "metadata": m.record.metadata,
            "created_at": m.record.created_at.to_rfc3339()
        }));
    }

    Ok(serde_json::json!({
        "success": true,
        "agent_id": agent,
        "query": text,
        "count": results.len(),
        "results": results
    })
    .to_string())
}
diff --git a/src/tools/store.rs b/src/tools/store.rs
new file mode 100644
index 0000000..2b8f5c5 --- /dev/null +++ b/src/tools/store.rs @@ -0,0 +1,66 @@ +//! Store Tool - Store memories with automatic embeddings + +use anyhow::{Context, Result, anyhow}; +use serde_json::Value; +use std::sync::Arc; +use tracing::info; + +use crate::embedding::extract_keywords; +use crate::AppState; + +/// Execute the store tool +pub async fn execute(state: &Arc, arguments: Value) -> Result { + // Get embedding engine, return error if not ready + let embedding_engine = state + .get_embedding() + .await + .ok_or_else(|| anyhow!("Embedding engine not ready - service is still initializing"))?; + + // Extract parameters + let content = arguments + .get("content") + .and_then(|v| v.as_str()) + .context("Missing required parameter: content")?; + + let agent_id = arguments + .get("agent_id") + .and_then(|v| v.as_str()) + .unwrap_or("default"); + + let metadata = arguments + .get("metadata") + .cloned() + .unwrap_or(serde_json::json!({})); + + info!( + "Storing memory for agent '{}': {} chars", + agent_id, + content.len() + ); + + // Generate embedding using Arc + let embedding = embedding_engine + .embed(content) + .context("Failed to generate embedding")?; + + // Extract keywords + let keywords = extract_keywords(content, 10); + + // Store in database + let id = state + .db + .store_memory(agent_id, content, &embedding, &keywords, metadata) + .await + .context("Failed to store memory")?; + + info!("Memory stored with ID: {}", id); + + Ok(serde_json::json!({ + "success": true, + "id": id.to_string(), + "agent_id": agent_id, + "keywords": keywords, + "embedding_dimension": embedding.len() + }) + .to_string()) +} diff --git a/src/transport.rs b/src/transport.rs new file mode 100644 index 0000000..a88b7f0 --- /dev/null +++ b/src/transport.rs @@ -0,0 +1,531 @@ +//! SSE Transport for MCP Protocol +//! +//! Implements Server-Sent Events transport for the Model Context Protocol. 
+ +use axum::{ + extract::{Query, State}, + http::{HeaderMap, StatusCode}, + response::{ + IntoResponse, Response, + sse::{Event, KeepAlive, Sse}, + }, + routing::{get, post}, + Json, Router, +}; +use futures::stream::Stream; +use serde::{Deserialize, Serialize}; +use std::{ + collections::HashMap, + convert::Infallible, + sync::Arc, + time::Duration, +}; +use tokio::sync::{RwLock, broadcast, mpsc}; +use tracing::{error, info, warn}; +use uuid::Uuid; + +use crate::{AppState, auth, tools}; + +type SessionStore = RwLock>>; + +/// MCP Server State +pub struct McpState { + pub app: Arc, + pub event_tx: broadcast::Sender, + sessions: SessionStore, +} + +impl McpState { + pub fn new(app: Arc) -> Arc { + let (event_tx, _) = broadcast::channel(100); + Arc::new(Self { + app, + event_tx, + sessions: RwLock::new(HashMap::new()), + }) + } + + async fn insert_session( + &self, + session_id: String, + tx: mpsc::Sender, + ) { + self.sessions.write().await.insert(session_id, tx); + } + + async fn remove_session(&self, session_id: &str) { + self.sessions.write().await.remove(session_id); + } + + async fn has_session(&self, session_id: &str) -> bool { + self.sessions.read().await.contains_key(session_id) + } + + async fn send_to_session( + &self, + session_id: &str, + response: &JsonRpcResponse, + ) -> Result<(), SessionSendError> { + let tx = { + let sessions = self.sessions.read().await; + sessions + .get(session_id) + .cloned() + .ok_or(SessionSendError::NotFound)? 
+ }; + + let payload = + serde_json::to_value(response).expect("serializing JSON-RPC response should succeed"); + + if tx.send(payload).await.is_err() { + self.remove_session(session_id).await; + return Err(SessionSendError::Closed); + } + + Ok(()) + } +} + +enum SessionSendError { + NotFound, + Closed, +} + +struct SessionGuard { + state: Arc, + session_id: String, +} + +impl SessionGuard { + fn new(state: Arc, session_id: String) -> Self { + Self { state, session_id } + } +} + +impl Drop for SessionGuard { + fn drop(&mut self) { + let state = self.state.clone(); + let session_id = self.session_id.clone(); + if let Ok(handle) = tokio::runtime::Handle::try_current() { + handle.spawn(async move { + state.remove_session(&session_id).await; + }); + } + } +} + +/// MCP Event for SSE streaming +#[derive(Clone, Debug, Serialize)] +pub struct McpEvent { + pub id: String, + pub event_type: String, + pub data: serde_json::Value, +} + +/// MCP JSON-RPC Request +#[derive(Debug, Deserialize)] +pub struct JsonRpcRequest { + pub jsonrpc: String, + #[serde(default)] + pub id: Option, + pub method: String, + #[serde(default)] + pub params: serde_json::Value, +} + +/// MCP JSON-RPC Response +#[derive(Debug, Serialize)] +pub struct JsonRpcResponse { + pub jsonrpc: String, + pub id: serde_json::Value, + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +#[derive(Debug, Serialize)] +pub struct JsonRpcError { + pub code: i32, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +#[derive(Debug, Default, Deserialize)] +#[serde(rename_all = "camelCase")] +struct PostMessageQuery { + #[serde(default)] + session_id: Option, +} + +/// Create the MCP router +pub fn mcp_router(state: Arc) -> Router { + Router::new() + .route("/sse", get(sse_handler)) + .route("/message", post(message_handler)) + .route("/health", get(health_handler)) + 
.with_state(state) +} + +/// SSE endpoint for streaming events +async fn sse_handler( + State(state): State>, +) -> Sse>> { + let mut broadcast_rx = state.event_tx.subscribe(); + let (session_tx, mut session_rx) = mpsc::channel(32); + let session_id = Uuid::new_v4().to_string(); + let endpoint = session_message_endpoint(&session_id); + + state.insert_session(session_id.clone(), session_tx).await; + + let stream = async_stream::stream! { + let _session_guard = SessionGuard::new(state.clone(), session_id.clone()); + + // Send endpoint event (required by MCP SSE protocol) + // This tells the client where to POST JSON-RPC messages for this session. + yield Ok(Event::default() + .event("endpoint") + .data(endpoint)); + + // Send initial tools list so Agent Zero knows what's available + let tools_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("initial-tools"), + result: Some(serde_json::json!({ + "tools": tools::get_tool_definitions() + })), + error: None, + }; + + yield Ok(Event::default() + .event("message") + .json_data(&tools_response) + .unwrap()); + + loop { + tokio::select! 
{ + maybe_message = session_rx.recv() => { + match maybe_message { + Some(message) => { + yield Ok(Event::default() + .event("message") + .json_data(&message) + .unwrap()); + } + None => break, + } + } + event = broadcast_rx.recv() => { + match event { + Ok(event) => { + yield Ok(Event::default() + .event(&event.event_type) + .id(&event.id) + .json_data(&event.data) + .unwrap()); + } + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!("SSE client lagged, missed {} events", n); + } + Err(broadcast::error::RecvError::Closed) => { + break; + } + } + } + } + } + }; + + Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15))) +} + +/// Message endpoint for JSON-RPC requests +async fn message_handler( + State(state): State>, + Query(query): Query, + headers: HeaderMap, + Json(request): Json, +) -> Response { + info!("Received MCP request: {}", request.method); + + if let Some(session_id) = query.session_id.as_deref() { + if !state.has_session(session_id).await { + return StatusCode::NOT_FOUND.into_response(); + } + } + + let request = apply_request_context(request, &headers); + let response = dispatch_request(&state, &request).await; + + match query.session_id.as_deref() { + Some(session_id) => route_session_response(&state, session_id, response).await, + None => match response { + Some(response) => Json(response).into_response(), + None => StatusCode::ACCEPTED.into_response(), + }, + } +} + +/// Health check endpoint +async fn health_handler() -> Json { + Json(serde_json::json!({ + "status": "healthy", + "server": "openbrain-mcp", + "version": env!("CARGO_PKG_VERSION") + })) +} + +fn session_message_endpoint(session_id: &str) -> String { + format!("/mcp/message?sessionId={session_id}") +} + +async fn route_session_response( + state: &Arc, + session_id: &str, + response: Option, +) -> Response { + match response { + Some(response) => match state.send_to_session(session_id, &response).await { + Ok(()) => StatusCode::ACCEPTED.into_response(), 
+ Err(SessionSendError::NotFound) => StatusCode::NOT_FOUND.into_response(), + Err(SessionSendError::Closed) => StatusCode::GONE.into_response(), + }, + None => StatusCode::ACCEPTED.into_response(), + } +} + +fn apply_request_context( + mut request: JsonRpcRequest, + headers: &HeaderMap, +) -> JsonRpcRequest { + if let Some(agent_id) = auth::get_optional_agent_id(headers) { + inject_agent_id(&mut request, &agent_id); + } + + request +} + +fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) { + if request.method != "tools/call" { + return; + } + + if !request.params.is_object() { + request.params = serde_json::json!({}); + } + + let params = request + .params + .as_object_mut() + .expect("params should be an object"); + let arguments = params + .entry("arguments") + .or_insert_with(|| serde_json::json!({})); + + if !arguments.is_object() { + *arguments = serde_json::json!({}); + } + + arguments + .as_object_mut() + .expect("arguments should be an object") + .entry("agent_id".to_string()) + .or_insert_with(|| serde_json::json!(agent_id)); +} + +async fn dispatch_request( + state: &Arc, + request: &JsonRpcRequest, +) -> Option { + match request.method.as_str() { + "initialize" => request + .id + .clone() + .map(|id| success_response(id, initialize_result())), + "ping" => request + .id + .clone() + .map(|id| success_response(id, serde_json::json!({}))), + "tools/list" => request.id.clone().map(|id| { + success_response( + id, + serde_json::json!({ + "tools": tools::get_tool_definitions() + }), + ) + }), + "tools/call" => handle_tools_call(state, request).await, + "notifications/initialized" => { + info!("Received MCP initialized notification"); + None + } + _ => request.id.clone().map(|id| { + error_response( + id, + -32601, + format!("Method not found: {}", request.method), + None, + ) + }), + } +} + +fn initialize_result() -> serde_json::Value { + serde_json::json!({ + "protocolVersion": "2024-11-05", + "serverInfo": { + "name": "openbrain-mcp", + 
"version": env!("CARGO_PKG_VERSION") + }, + "capabilities": { + "tools": { + "listChanged": false + } + } + }) +} + +fn success_response( + id: serde_json::Value, + result: serde_json::Value, +) -> JsonRpcResponse { + JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: Some(result), + error: None, + } +} + +fn error_response( + id: serde_json::Value, + code: i32, + message: String, + data: Option, +) -> JsonRpcResponse { + JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: None, + error: Some(JsonRpcError { + code, + message, + data, + }), + } +} + +/// Handle tools/call request +async fn handle_tools_call( + state: &Arc, + request: &JsonRpcRequest, +) -> Option { + let params = &request.params; + let tool_name = params.get("name").and_then(|v| v.as_str()).unwrap_or(""); + let arguments = params + .get("arguments") + .cloned() + .unwrap_or(serde_json::json!({})); + + let result = match tools::execute_tool(&state.app, tool_name, arguments).await { + Ok(result) => Ok(serde_json::json!({ + "content": [{ + "type": "text", + "text": result + }] + })), + Err(e) => { + let full_error = format!("{:#}", e); + error!("Tool execution error: {}", full_error); + Err(JsonRpcError { + code: -32000, + message: full_error, + data: None, + }) + } + }; + + match (request.id.clone(), result) { + (Some(id), Ok(result)) => Some(success_response(id, result)), + (Some(id), Err(error)) => Some(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: None, + error: Some(error), + }), + (None, Ok(_)) => None, + (None, Err(error)) => { + error!( + "Tool execution failed for notification '{}': {}", + request.method, error.message + ); + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn injects_agent_id_when_missing_from_tool_arguments() { + let mut request = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: Some(serde_json::json!("1")), + method: "tools/call".to_string(), + params: serde_json::json!({ + "name": "query", + 
"arguments": { + "query": "editor preferences" + } + }), + }; + + inject_agent_id(&mut request, "agent-from-header"); + + assert_eq!( + request + .params + .get("arguments") + .and_then(|value| value.get("agent_id")) + .and_then(|value| value.as_str()), + Some("agent-from-header") + ); + } + + #[test] + fn preserves_explicit_agent_id() { + let mut request = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: Some(serde_json::json!("1")), + method: "tools/call".to_string(), + params: serde_json::json!({ + "name": "query", + "arguments": { + "agent_id": "explicit-agent" + } + }), + }; + + inject_agent_id(&mut request, "agent-from-header"); + + assert_eq!( + request + .params + .get("arguments") + .and_then(|value| value.get("agent_id")) + .and_then(|value| value.as_str()), + Some("explicit-agent") + ); + } + + #[test] + fn session_endpoint_uses_camel_case_query_param() { + assert_eq!( + session_message_endpoint("abc123"), + "/mcp/message?sessionId=abc123" + ); + } +} diff --git a/tests/e2e_mcp.rs b/tests/e2e_mcp.rs new file mode 100644 index 0000000..11ac037 --- /dev/null +++ b/tests/e2e_mcp.rs @@ -0,0 +1,873 @@ +use serde_json::{json, Value}; +use std::process::{Command, Stdio}; +use std::time::Duration; +use tokio_postgres::NoTls; +use uuid::Uuid; + +fn base_url() -> String { + std::env::var("OPENBRAIN_E2E_BASE_URL").unwrap_or_else(|_| "http://127.0.0.1:3100".to_string()) +} + +fn api_key() -> Option { + std::env::var("OPENBRAIN_E2E_API_KEY").ok() + .or_else(|| std::env::var("OPENBRAIN__AUTH__API_KEYS").ok()) + .map(|keys| keys.split(',').next().unwrap_or("").trim().to_string()) + .filter(|k| !k.is_empty()) +} + +fn db_url() -> String { + let host = std::env::var("OPENBRAIN__DATABASE__HOST").unwrap_or_else(|_| "localhost".to_string()); + let port = std::env::var("OPENBRAIN__DATABASE__PORT").unwrap_or_else(|_| "5432".to_string()); + let name = std::env::var("OPENBRAIN__DATABASE__NAME").unwrap_or_else(|_| "openbrain".to_string()); + let user = 
std::env::var("OPENBRAIN__DATABASE__USER").unwrap_or_else(|_| "openbrain_svc".to_string()); + let password = std::env::var("OPENBRAIN__DATABASE__PASSWORD") + .unwrap_or_else(|_| "your_secure_password_here".to_string()); + + format!("host={host} port={port} dbname={name} user={user} password={password}") +} + +async fn ensure_schema() { + let (client, connection) = tokio_postgres::connect(&db_url(), NoTls) + .await + .expect("connect to postgres for e2e schema setup"); + + tokio::spawn(async move { + if let Err(e) = connection.await { + eprintln!("postgres connection error: {e}"); + } + }); + + let vector_exists = client + .query_one("SELECT to_regtype('vector')::text", &[]) + .await + .expect("query vector type availability") + .get::<_, Option>(0) + .is_some(); + + if !vector_exists { + if let Err(e) = client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[]).await { + panic!( + "pgvector extension is not available for this PostgreSQL instance: {e}. \ +Install pgvector for your active PostgreSQL major version, then run: CREATE EXTENSION vector;" + ); + } + } + + client + .batch_execute( + r#" + CREATE TABLE IF NOT EXISTS memories ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + agent_id VARCHAR(255) NOT NULL, + content TEXT NOT NULL, + embedding vector(384) NOT NULL, + keywords TEXT[] DEFAULT '{}', + metadata JSONB DEFAULT '{}', + created_at TIMESTAMPTZ DEFAULT NOW() + ); + CREATE INDEX IF NOT EXISTS idx_memories_agent ON memories(agent_id); + CREATE INDEX IF NOT EXISTS idx_memories_embedding ON memories + USING hnsw (embedding vector_cosine_ops); + "#, + ) + .await + .expect("create memories table/indexes for e2e"); +} + +async fn wait_until_ready(client: &reqwest::Client, base: &str) { + for _ in 0..60 { + let resp = client.get(format!("{base}/ready")).send().await; + if let Ok(resp) = resp { + if resp.status().is_success() { + let body: Value = resp.json().await.expect("/ready JSON response"); + if body.get("status").and_then(Value::as_str) == 
Some("ready") {
                    return;
                }
            }
        }
        tokio::time::sleep(Duration::from_millis(500)).await;
    }

    panic!("Server did not become ready at {base}/ready within timeout");
}

/// POST a JSON-RPC request to the stateless message endpoint and decode the
/// JSON body, attaching the e2e API key header when one is configured.
async fn call_jsonrpc(client: &reqwest::Client, base: &str, request: Value) -> Value {
    let mut req_builder = client
        .post(format!("{base}/mcp/message"))
        .json(&request);

    // Add API key header if available
    if let Some(key) = api_key() {
        req_builder = req_builder.header("X-API-Key", key);
    }

    req_builder
        .send()
        .await
        .expect("JSON-RPC HTTP request")
        .json()
        .await
        .expect("JSON-RPC response body")
}

/// Make an authenticated GET request to an MCP endpoint
async fn get_mcp_endpoint(client: &reqwest::Client, base: &str, path: &str) -> reqwest::Response {
    let mut req_builder = client.get(format!("{base}{path}"));

    if let Some(key) = api_key() {
        req_builder = req_builder.header("X-API-Key", key);
    }

    // FIX: `expect(&format!(..))` built the panic message even on success
    // (clippy::expect_fun_call); panic lazily instead.
    req_builder
        .send()
        .await
        .unwrap_or_else(|err| panic!("GET {path}: {err}"))
}

/// Incrementally parse one SSE event (event name + joined data lines) from a
/// streaming response, buffering partial frames across chunks in `buffer`.
/// Returns `None` when the stream ends.
async fn read_sse_event(
    response: &mut reqwest::Response,
    buffer: &mut String,
) -> Option<(Option<String>, String)> {
    loop {
        // Normalize CRLF so the blank-line frame delimiter is uniform.
        *buffer = buffer.replace("\r\n", "\n");
        if let Some(idx) = buffer.find("\n\n") {
            let raw_event = buffer[..idx].to_string();
            *buffer = buffer[idx + 2..].to_string();

            let mut event_type = None;
            let mut data_lines = Vec::new();
            for line in raw_event.lines() {
                if let Some(value) = line.strip_prefix("event:") {
                    event_type = Some(value.trim().to_string());
                } else if let Some(value) = line.strip_prefix("data:") {
                    data_lines.push(value.trim_start().to_string());
                }
            }

            return Some((event_type, data_lines.join("\n")));
        }

        let chunk = response
            .chunk()
            .await
            .expect("read SSE chunk")?;
        buffer.push_str(std::str::from_utf8(&chunk).expect("SSE chunk should be valid UTF-8"));
    }
}

/// Invoke an MCP tool via `tools/call` and decode the JSON payload embedded
/// in `result.content[0].text`.  Panics on JSON-RPC errors.
async fn call_tool(
    client: &reqwest::Client,
    base: &str,
    tool_name: &str,
    arguments: Value,
) -> Value {
    let request = json!({
        "jsonrpc":
        "2.0",
        "id": Uuid::new_v4().to_string(),
        "method": "tools/call",
        "params": {
            "name": tool_name,
            "arguments": arguments
        }
    });

    let response = call_jsonrpc(client, base, request).await;

    // Tool failures arrive as a JSON-RPC `error`; fail fast with full details.
    if let Some(error) = response.get("error") {
        panic!("tools/call for '{tool_name}' failed: {error}");
    }

    // Tool output is a JSON document embedded as text in result.content[0].
    let text_payload = response
        .get("result")
        .and_then(|r| r.get("content"))
        .and_then(Value::as_array)
        .and_then(|arr| arr.first())
        .and_then(|item| item.get("text"))
        .and_then(Value::as_str)
        .expect("result.content[0].text payload");

    serde_json::from_str(text_payload).expect("tool text payload to be valid JSON")
}

// Happy-path lifecycle: store a memory, find it via query, purge it, then
// confirm the query returns nothing for that agent.
#[tokio::test]
async fn e2e_store_query_purge_roundtrip() {
    let base = base_url();
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(20))
        .build()
        .expect("reqwest client");

    ensure_schema().await;
    wait_until_ready(&client, &base).await;

    // Unique agent/content per run keep tests independent and re-runnable.
    let agent_id = format!("e2e-agent-{}", Uuid::new_v4());
    let memory_text = format!(
        "E2E memory {}: user prefers dark theme and vim bindings",
        Uuid::new_v4()
    );

    // Ensure clean slate for this test agent.
    let _ = call_tool(
        &client,
        &base,
        "purge",
        json!({ "agent_id": agent_id, "confirm": true }),
    )
    .await;

    let store_result = call_tool(
        &client,
        &base,
        "store",
        json!({
            "agent_id": agent_id,
            "content": memory_text,
            "metadata": { "source": "e2e-test", "suite": "store-query-purge" }
        }),
    )
    .await;

    assert_eq!(
        store_result.get("success").and_then(Value::as_bool),
        Some(true),
        "store should succeed"
    );

    // threshold 0.0 disables similarity filtering so the stored row must match.
    let query_result = call_tool(
        &client,
        &base,
        "query",
        json!({
            "agent_id": agent_id,
            "query": "What are the user's editor preferences?",
            "limit": 5,
            "threshold": 0.0
        }),
    )
    .await;

    let count = query_result
        .get("count")
        .and_then(Value::as_u64)
        .expect("query.count");
    assert!(count >= 1, "query should return at least one stored memory");

    let results = query_result
        .get("results")
        .and_then(Value::as_array)
        .expect("query.results");
    let found_stored_content = results.iter().any(|item| {
        item.get("content")
            .and_then(Value::as_str)
            .map(|content| content == memory_text)
            .unwrap_or(false)
    });
    assert!(
        found_stored_content,
        "query results should include the content stored by this test"
    );

    let purge_result = call_tool(
        &client,
        &base,
        "purge",
        json!({ "agent_id": agent_id, "confirm": true }),
    )
    .await;

    let deleted = purge_result
        .get("deleted")
        .and_then(Value::as_u64)
        .expect("purge.deleted");
    assert!(deleted >= 1, "purge should delete at least one memory");

    let query_after_purge = call_tool(
        &client,
        &base,
        "query",
        json!({
            "agent_id": agent_id,
            "query": "dark theme vim bindings",
            "limit": 5,
            "threshold": 0.0
        }),
    )
    .await;

    assert_eq!(
        query_after_purge.get("count").and_then(Value::as_u64),
        Some(0),
        "query after purge should return no memories for this agent"
    );
}

// Transport contract: tools/list exposes the three tools and an unknown
// method yields JSON-RPC -32601 (Method not found).
#[tokio::test]
async fn e2e_transport_tools_list_and_unknown_method() {
    let base = base_url();
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(20))
        .build()
        .expect("reqwest client");

    wait_until_ready(&client, &base).await;

    let list_response = call_jsonrpc(
        &client,
        &base,
        json!({
            "jsonrpc": "2.0",
            "id": "tools-list-1",
            "method": "tools/list",
            "params": {}
        }),
    )
    .await;

    let tools = list_response
        .get("result")
        .and_then(|r| r.get("tools"))
        .and_then(Value::as_array)
        .expect("tools/list result.tools");

    let tool_names: Vec<&str> = tools
        .iter()
        .filter_map(|t| t.get("name").and_then(Value::as_str))
        .collect();

    assert!(tool_names.contains(&"store"), "tools/list should include store");
    assert!(tool_names.contains(&"query"), "tools/list should include query");
    assert!(tool_names.contains(&"purge"), "tools/list should include purge");

    let unknown_response = call_jsonrpc(
        &client,
        &base,
        json!({
            "jsonrpc": "2.0",
            "id": "unknown-1",
            "method": "not/a/real/method",
            "params": {}
        }),
    )
    .await;

    assert_eq!(
        unknown_response
            .get("error")
            .and_then(|e| e.get("code"))
            .and_then(Value::as_i64),
        Some(-32601),
        "unknown method should return Method Not Found"
    );
}

// Destructive purge must be explicitly confirmed; without confirm:true the
// server must refuse with an explanatory error.
#[tokio::test]
async fn e2e_purge_requires_confirm_flag() {
    let base = base_url();
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(20))
        .build()
        .expect("reqwest client");

    wait_until_ready(&client, &base).await;

    let response = call_jsonrpc(
        &client,
        &base,
        json!({
            "jsonrpc": "2.0",
            "id": "purge-confirm-1",
            "method": "tools/call",
            "params": {
                "name": "purge",
                "arguments": {
                    "agent_id": format!("e2e-agent-{}", Uuid::new_v4()),
                    "confirm": false
                }
            }
        }),
    )
    .await;

    let error_message = response
        .get("error")
        .and_then(|e| e.get("message"))
        .and_then(Value::as_str)
        .expect("purge without confirm should return JSON-RPC error");

    assert!(
        error_message.contains("confirm: true") || error_message.contains("confirm"),
        "purge error should explain confirmation requirement"
    );
}

// Memories must be partitioned per agent_id: agent A's query must never see
// agent B's memories.
#[tokio::test]
async fn e2e_query_isolated_by_agent_id() {
    let base = base_url();
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(20))
        .build()
        .expect("reqwest client");

    ensure_schema().await;
    wait_until_ready(&client, &base).await;

    let agent_a = format!("e2e-agent-a-{}", Uuid::new_v4());
    let agent_b = format!("e2e-agent-b-{}", Uuid::new_v4());
    let a_text = format!("A {} prefers dark mode", Uuid::new_v4());
    let b_text = format!("B {} prefers light mode", Uuid::new_v4());

    let _ = call_tool(&client, &base, "purge", json!({ "agent_id": agent_a, "confirm": true })).await;
    let _ = call_tool(&client, &base, "purge", json!({ "agent_id": agent_b, "confirm": true })).await;

    let _ = call_tool(
        &client,
        &base,
        "store",
        json!({ "agent_id": agent_a, "content": a_text, "metadata": {"suite": "agent-isolation"} }),
    )
    .await;

    let _ = call_tool(
        &client,
        &base,
        "store",
        json!({ "agent_id": agent_b, "content": b_text, "metadata": {"suite": "agent-isolation"} }),
    )
    .await;

    let query_a = call_tool(
        &client,
        &base,
        "query",
        json!({
            "agent_id": agent_a,
            "query": "mode preference",
            "limit": 10,
            "threshold": 0.0
        }),
    )
    .await;

    let results = query_a
        .get("results")
        .and_then(Value::as_array)
        .expect("query results");

    let has_a = results.iter().any(|item| {
        item.get("content")
            .and_then(Value::as_str)
            .map(|s| s == a_text)
            .unwrap_or(false)
    });
    let has_b = results.iter().any(|item| {
        item.get("content")
            .and_then(Value::as_str)
            .map(|s| s == b_text)
            .unwrap_or(false)
    });

    assert!(has_a, "agent A query should include agent A memory");
    assert!(!has_b, "agent A query must not include agent B memory");

    // Cleanup both agents regardless of assertion outcomes above.
    let _ = call_tool(&client, &base, "purge", json!({ "agent_id": agent_a, "confirm": true })).await;
    let _ = call_tool(&client, &base, "purge", json!({ "agent_id": agent_b, "confirm":
true })).await;
}

/// `initialize` handshake returns the pinned protocol version and server name.
#[tokio::test]
async fn e2e_initialize_contract() {
    let base = base_url();
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(20))
        .build()
        .expect("reqwest client");

    // FIX: every other e2e test gates on /ready before issuing requests; the
    // original omitted this and could race server startup, failing with a
    // connection error rather than a contract violation.
    wait_until_ready(&client, &base).await;

    let response = call_jsonrpc(
        &client,
        &base,
        json!({
            "jsonrpc": "2.0",
            "id": "init-1",
            "method": "initialize",
            "params": {}
        }),
    )
    .await;

    let result = response.get("result").expect("initialize result");
    assert_eq!(
        result.get("protocolVersion").and_then(Value::as_str),
        Some("2024-11-05")
    );
    assert_eq!(
        result
            .get("serverInfo")
            .and_then(|v| v.get("name"))
            .and_then(Value::as_str),
        Some("openbrain-mcp")
    );
}

/// `notifications/initialized` is a JSON-RPC notification: the server must
/// accept it with 202 and no response body.
#[tokio::test]
async fn e2e_initialized_notification_is_accepted() {
    let base = base_url();
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(20))
        .build()
        .expect("reqwest client");

    wait_until_ready(&client, &base).await;

    let mut request = client
        .post(format!("{base}/mcp/message"))
        .json(&json!({
            "jsonrpc": "2.0",
            "method": "notifications/initialized",
            "params": {}
        }));

    if let Some(key) = api_key() {
        request = request.header("X-API-Key", key);
    }

    let response = request.send().await.expect("initialized notification request");
    assert_eq!(
        response.status(),
        reqwest::StatusCode::ACCEPTED,
        "notifications/initialized should be accepted without a JSON-RPC response body"
    );
}

/// Full SSE session flow: connect, receive the endpoint event, POST a
/// request to the advertised session URL, read the response over SSE.
#[tokio::test]
async fn e2e_sse_session_routes_posted_response() {
    let base = base_url();
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(20))
        .build()
        .expect("reqwest client");

    wait_until_ready(&client, &base).await;

    let mut sse_request = client
        .get(format!("{base}/mcp/sse"))
        .header("Accept", "text/event-stream");

    if let Some(key) = api_key() {
        sse_request = sse_request.header("X-API-Key", key);
    }

    let mut sse_response = sse_request.send().await.expect("GET /mcp/sse");
assert_eq!(sse_response.status(), reqwest::StatusCode::OK); + assert!( + sse_response + .headers() + .get(reqwest::header::CONTENT_TYPE) + .and_then(|value| value.to_str().ok()) + .map(|value| value.starts_with("text/event-stream")) + .unwrap_or(false), + "SSE endpoint should return text/event-stream" + ); + + let mut buffer = String::new(); + let (event_type, endpoint) = tokio::time::timeout( + Duration::from_secs(10), + read_sse_event(&mut sse_response, &mut buffer), + ) + .await + .expect("timed out waiting for SSE endpoint event") + .expect("SSE endpoint event"); + + assert_eq!(event_type.as_deref(), Some("endpoint")); + assert!( + endpoint.contains("/mcp/message?sessionId="), + "endpoint event should advertise a session-specific message URL" + ); + + let post_url = if endpoint.starts_with("http://") || endpoint.starts_with("https://") { + endpoint + } else { + format!("{base}{endpoint}") + }; + + let mut post_request = client + .post(post_url) + .json(&json!({ + "jsonrpc": "2.0", + "id": "sse-tools-list-1", + "method": "tools/list", + "params": {} + })); + + if let Some(key) = api_key() { + post_request = post_request.header("X-API-Key", key); + } + + let post_response = post_request.send().await.expect("POST session message"); + assert_eq!( + post_response.status(), + reqwest::StatusCode::ACCEPTED, + "session-bound POST should be accepted and routed over SSE" + ); + + let (event_type, payload) = tokio::time::timeout( + Duration::from_secs(10), + read_sse_event(&mut sse_response, &mut buffer), + ) + .await + .expect("timed out waiting for SSE message event") + .expect("SSE message event"); + + assert_eq!(event_type.as_deref(), Some("message")); + + let message: Value = serde_json::from_str(&payload).expect("SSE payload should be valid JSON"); + assert_eq!( + message.get("id").and_then(Value::as_str), + Some("sse-tools-list-1") + ); + assert!( + message + .get("result") + .and_then(|value| value.get("tools")) + .and_then(Value::as_array) + .map(|tools| 
!tools.is_empty()) + .unwrap_or(false), + "SSE-routed tools/list response should include tool definitions" + ); +} + +#[tokio::test] +async fn e2e_health_endpoints() { + let base = base_url(); + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(20)) + .build() + .expect("reqwest client"); + + // Root health endpoint - no auth required + let root_health: Value = client + .get(format!("{base}/health")) + .send() + .await + .expect("GET /health") + .json() + .await + .expect("/health JSON"); + + assert_eq!( + root_health.get("status").and_then(Value::as_str), + Some("ok"), + "/health should report server liveness" + ); + + // MCP health endpoint - requires auth if enabled + let mcp_health: Value = get_mcp_endpoint(&client, &base, "/mcp/health") + .await + .json() + .await + .expect("/mcp/health JSON"); + + assert_eq!( + mcp_health.get("status").and_then(Value::as_str), + Some("healthy"), + "/mcp/health should report MCP transport health" + ); + assert_eq!( + mcp_health.get("server").and_then(Value::as_str), + Some("openbrain-mcp") + ); +} + +#[tokio::test] +async fn e2e_store_requires_content() { + let base = base_url(); + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(20)) + .build() + .expect("reqwest client"); + + wait_until_ready(&client, &base).await; + + let response = call_jsonrpc( + &client, + &base, + json!({ + "jsonrpc": "2.0", + "id": "store-missing-content-1", + "method": "tools/call", + "params": { + "name": "store", + "arguments": { + "agent_id": format!("e2e-agent-{}", Uuid::new_v4()), + "metadata": {"suite": "validation"} + } + } + }), + ) + .await; + + let message = response + .get("error") + .and_then(|e| e.get("message")) + .and_then(Value::as_str) + .expect("store missing content should return an error message"); + + assert!( + message.contains("Missing required parameter: content"), + "store validation should mention missing content" + ); +} + +#[tokio::test] +async fn e2e_auth_rejection_without_key() { 
+ // This test only runs when auth is expected to be enabled + let auth_enabled = std::env::var("OPENBRAIN__AUTH__ENABLED") + .map(|v| v == "true") + .unwrap_or(false); + + if !auth_enabled { + println!("Skipping auth rejection test - OPENBRAIN__AUTH__ENABLED is not true"); + return; + } + + let base = base_url(); + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(20)) + .build() + .expect("reqwest client"); + + // Make request WITHOUT API key + let response = client + .post(format!("{base}/mcp/message")) + .json(&json!({ + "jsonrpc": "2.0", + "id": "auth-test-1", + "method": "tools/list", + "params": {} + })) + .send() + .await + .expect("HTTP request"); + + assert_eq!( + response.status().as_u16(), + 401, + "Request without API key should return 401 Unauthorized" + ); +} + +fn pick_free_port() -> u16 { + std::net::TcpListener::bind("127.0.0.1:0") + .expect("bind ephemeral port") + .local_addr() + .expect("local addr") + .port() +} + +async fn wait_for_status(url: &str, expected_status: reqwest::StatusCode) { + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(2)) + .build() + .expect("reqwest client"); + + for _ in 0..80 { + if let Ok(resp) = client.get(url).send().await { + if resp.status() == expected_status { + return; + } + } + tokio::time::sleep(Duration::from_millis(250)).await; + } + + panic!("Timed out waiting for {url} to return status {expected_status}"); +} + +#[tokio::test] +async fn e2e_auth_enabled_accepts_test_key() { + ensure_schema().await; + + let port = pick_free_port(); + let base = format!("http://127.0.0.1:{port}"); + let test_key = "e2e-test-key-123"; + + let mut server = Command::new(env!("CARGO_BIN_EXE_openbrain-mcp")) + .current_dir(env!("CARGO_MANIFEST_DIR")) + .env("OPENBRAIN__SERVER__PORT", port.to_string()) + .env("OPENBRAIN__AUTH__ENABLED", "true") + .env("OPENBRAIN__AUTH__API_KEYS", test_key) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .spawn() + .expect("spawn openbrain-mcp for 
auth-enabled e2e test"); + + wait_for_status(&format!("{base}/ready"), reqwest::StatusCode::OK).await; + + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(20)) + .build() + .expect("reqwest client"); + + let request = json!({ + "jsonrpc": "2.0", + "id": "auth-enabled-1", + "method": "tools/list", + "params": {} + }); + + let unauthorized = client + .post(format!("{base}/mcp/message")) + .json(&request) + .send() + .await + .expect("unauthorized request"); + assert_eq!( + unauthorized.status(), + reqwest::StatusCode::UNAUTHORIZED, + "request without key should be rejected when auth is enabled" + ); + + let authorized: Value = client + .post(format!("{base}/mcp/message")) + .header("X-API-Key", test_key) + .json(&request) + .send() + .await + .expect("authorized request") + .json() + .await + .expect("authorized JSON response"); + + assert!(authorized.get("error").is_none(), "valid key should not return JSON-RPC error"); + assert!( + authorized + .get("result") + .and_then(|r| r.get("tools")) + .and_then(Value::as_array) + .map(|tools| !tools.is_empty()) + .unwrap_or(false), + "authorized tools/list should return tool definitions" + ); + + let bearer_authorized: Value = client + .post(format!("{base}/mcp/message")) + .header("Authorization", format!("Bearer {test_key}")) + .json(&request) + .send() + .await + .expect("bearer-authorized request") + .json() + .await + .expect("bearer-authorized JSON response"); + + assert!( + bearer_authorized.get("error").is_none(), + "valid bearer token should not return JSON-RPC error" + ); + assert!( + bearer_authorized + .get("result") + .and_then(|r| r.get("tools")) + .and_then(Value::as_array) + .map(|tools| !tools.is_empty()) + .unwrap_or(false), + "authorized bearer tools/list should return tool definitions" + ); + + let _ = server.kill(); + let _ = server.wait(); +}