Initial public release

This commit is contained in:
Agent Zero
2026-03-07 13:41:36 -05:00
commit 774982dc5a
22 changed files with 3517 additions and 0 deletions

125
src/auth.rs Normal file
View File

@@ -0,0 +1,125 @@
//! Authentication module for OpenBrain MCP
//!
//! Provides API key-based authentication for securing the MCP endpoints.
use axum::{
extract::{Request, State},
http::{HeaderMap, StatusCode, header::AUTHORIZATION},
middleware::Next,
response::Response,
};
use sha2::{Digest, Sha256};
use std::sync::Arc;
use tracing::warn;
use crate::AppState;
/// Compute the hex-encoded SHA-256 digest of an API key.
///
/// Used so stored keys and presented keys can be compared by digest.
pub fn hash_api_key(key: &str) -> String {
    // One-shot digest: equivalent to new() + update() + finalize().
    hex::encode(Sha256::digest(key.as_bytes()))
}
/// Axum middleware enforcing API-key / bearer-token authentication.
///
/// When `auth.enabled` is false the request passes straight through.
/// Otherwise the credential from `X-API-Key` or `Authorization: Bearer`
/// must hash-match one of the configured keys; anything else is 401.
pub async fn auth_middleware(
    State(state): State<Arc<AppState>>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // Authentication can be switched off entirely via configuration.
    if !state.config.auth.enabled {
        return Ok(next.run(request).await);
    }
    // Guard clause: no credential at all is an immediate 401.
    let Some(presented) = extract_api_key(request.headers()) else {
        warn!("Missing API key or bearer token in request");
        return Err(StatusCode::UNAUTHORIZED);
    };
    // Compare SHA-256 digests of the presented key against each configured key.
    let presented_hash = hash_api_key(&presented);
    let is_known = state
        .config
        .auth
        .api_keys
        .iter()
        .any(|configured| hash_api_key(configured) == presented_hash);
    if is_known {
        Ok(next.run(request).await)
    } else {
        warn!("Invalid API key or bearer token attempted");
        Err(StatusCode::UNAUTHORIZED)
    }
}
/// Pull the client credential from the `X-API-Key` header, falling back
/// to an `Authorization: Bearer <token>` header. Blank values (after
/// trimming) are treated as absent.
fn extract_api_key(headers: &HeaderMap) -> Option<String> {
    // Preferred: dedicated X-API-Key header.
    if let Some(raw) = headers.get("X-API-Key").and_then(|v| v.to_str().ok()) {
        let trimmed = raw.trim();
        if !trimmed.is_empty() {
            return Some(trimmed.to_owned());
        }
    }
    // Fallback: Authorization header with a (case-insensitive) Bearer scheme.
    let raw = headers.get(AUTHORIZATION)?.to_str().ok()?;
    let (scheme, token) = raw.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("bearer") {
        return None;
    }
    let token = token.trim();
    (!token.is_empty()).then(|| token.to_owned())
}
/// Read the optional `X-Agent-ID` header; blank or non-UTF-8 values
/// are treated as absent.
pub fn get_optional_agent_id(headers: &HeaderMap) -> Option<String> {
    let raw = headers.get("X-Agent-ID")?.to_str().ok()?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}
/// Extract the agent ID from request headers, falling back to "default".
pub fn get_agent_id(request: &Request) -> String {
    match get_optional_agent_id(request.headers()) {
        Some(id) => id,
        None => "default".to_string(),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use axum::http::{HeaderValue, header::AUTHORIZATION};
// Bearer tokens supplied via the Authorization header are accepted
// and returned without the "Bearer " scheme prefix.
#[test]
fn extracts_api_key_from_bearer_header() {
let mut headers = HeaderMap::new();
headers.insert(
AUTHORIZATION,
HeaderValue::from_static("Bearer test-token"),
);
assert_eq!(extract_api_key(&headers).as_deref(), Some("test-token"));
}
// Agent identity is carried in the custom X-Agent-ID header.
#[test]
fn extracts_agent_id_from_header() {
let mut headers = HeaderMap::new();
headers.insert("X-Agent-ID", HeaderValue::from_static("agent-zero"));
assert_eq!(
get_optional_agent_id(&headers).as_deref(),
Some("agent-zero")
);
}
}

146
src/config.rs Normal file
View File

@@ -0,0 +1,146 @@
//! Configuration management for OpenBrain MCP
//!
//! Loads configuration from environment variables with sensible defaults.
use anyhow::Result;
use serde::{Deserialize, Deserializer};
/// Main configuration structure
///
/// Top-level settings, populated from environment variables by
/// `Config::load` (with `.env` support) or by `Default` for development.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
/// HTTP server bind settings.
pub server: ServerConfig,
/// PostgreSQL connection settings.
pub database: DatabaseConfig,
/// Local ONNX embedding model settings.
pub embedding: EmbeddingConfig,
/// API-key authentication settings.
pub auth: AuthConfig,
}
/// Server configuration
#[derive(Debug, Clone, Deserialize)]
pub struct ServerConfig {
/// Bind address; defaults to "0.0.0.0" (all interfaces).
#[serde(default = "default_host")]
pub host: String,
/// Bind port; defaults to 3100.
#[serde(default = "default_port")]
pub port: u16,
}
/// Database configuration
#[derive(Debug, Clone, Deserialize)]
pub struct DatabaseConfig {
/// PostgreSQL host (required; no default).
pub host: String,
/// PostgreSQL port; defaults to 5432.
#[serde(default = "default_db_port")]
pub port: u16,
/// Database name (required).
pub name: String,
/// Database user (required).
pub user: String,
/// Database password (required).
pub password: String,
/// Connection pool size; defaults to 10.
#[serde(default = "default_pool_size")]
pub pool_size: usize,
}
/// Embedding engine configuration
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingConfig {
/// Directory containing `model.onnx` and `tokenizer.json`;
/// defaults to "models/all-MiniLM-L6-v2".
#[serde(default = "default_model_path")]
pub model_path: String,
/// Embedding vector dimension; defaults to 384.
#[serde(default = "default_embedding_dim")]
pub dimension: usize,
}
/// Authentication configuration
#[derive(Debug, Clone, Deserialize)]
pub struct AuthConfig {
/// Whether API-key auth is enforced; defaults to false (auth off).
#[serde(default = "default_auth_enabled")]
pub enabled: bool,
/// Accepted API keys; may be supplied as a list or as a single
/// comma-separated string (see `deserialize_api_keys`).
#[serde(default, deserialize_with = "deserialize_api_keys")]
pub api_keys: Vec<String>,
}
/// Deserialize API keys from either a comma-separated string or a Vec<String>
///
/// Keys are trimmed and empty entries dropped regardless of input form, so
/// `"a, b,"` and `["a", " b", ""]` both yield `["a", "b"]`. (Previously the
/// list form was returned verbatim, letting blank/padded entries through.)
fn deserialize_api_keys<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
    D: Deserializer<'de>,
{
    // Try to deserialize as a string first, then as a Vec
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum StringOrVec {
        String(String),
        Vec(Vec<String>),
    }
    // Normalize raw keys: trim whitespace, drop empties.
    fn clean(keys: impl IntoIterator<Item = String>) -> Vec<String> {
        keys.into_iter()
            .map(|k| k.trim().to_string())
            .filter(|k| !k.is_empty())
            .collect()
    }
    match Option::<StringOrVec>::deserialize(deserializer)? {
        // "key1, key2" -> split on commas, then normalize.
        Some(StringOrVec::String(s)) => Ok(clean(s.split(',').map(String::from))),
        // Fix: normalize list input the same way as the string form.
        Some(StringOrVec::Vec(v)) => Ok(clean(v)),
        None => Ok(Vec::new()),
    }
}
// --- Default value helpers (referenced by serde attributes and Config::load) ---

/// Default HTTP bind address (all interfaces).
fn default_host() -> String {
    "0.0.0.0".to_string()
}

/// Default HTTP port for the MCP server.
fn default_port() -> u16 {
    3100
}

/// Default PostgreSQL port.
fn default_db_port() -> u16 {
    5432
}

/// Default database connection-pool size.
fn default_pool_size() -> usize {
    10
}

/// Default on-disk location of the embedding model directory.
fn default_model_path() -> String {
    "models/all-MiniLM-L6-v2".to_string()
}

/// Default embedding vector dimension.
fn default_embedding_dim() -> usize {
    384
}

/// Authentication is opt-in; disabled by default.
fn default_auth_enabled() -> bool {
    false
}
impl Config {
/// Load configuration from environment variables
///
/// Reads an optional `.env` file, seeds defaults for every optional
/// setting, then overlays environment variables with the `OPENBRAIN`
/// prefix and `__` as the nested-key separator (e.g. a variable for
/// `database.host` — NOTE(review): confirm the exact prefix separator
/// the `config` crate uses between `OPENBRAIN` and the key).
///
/// # Errors
/// Fails if a required setting without a default (database host/name/
/// user/password) is missing, or a value cannot be parsed.
pub fn load() -> Result<Self> {
// Load .env file if present
dotenvy::dotenv().ok();
let config = config::Config::builder()
// Server settings
.set_default("server.host", default_host())?
.set_default("server.port", default_port() as i64)?
// Database settings
.set_default("database.port", default_db_port() as i64)?
.set_default("database.pool_size", default_pool_size() as i64)?
// Embedding settings
.set_default("embedding.model_path", default_model_path())?
.set_default("embedding.dimension", default_embedding_dim() as i64)?
// Auth settings
.set_default("auth.enabled", default_auth_enabled())?
// Load from environment with OPENBRAIN_ prefix
.add_source(
config::Environment::with_prefix("OPENBRAIN")
.separator("__")
// Parse env strings into numbers/bools where fields require it.
.try_parsing(true),
)
.build()?;
Ok(config.try_deserialize()?)
}
}
impl Default for Config {
    /// Development-friendly defaults: local database, empty password,
    /// authentication disabled, no API keys.
    fn default() -> Self {
        let server = ServerConfig {
            host: default_host(),
            port: default_port(),
        };
        let database = DatabaseConfig {
            host: "localhost".to_string(),
            port: default_db_port(),
            name: "openbrain".to_string(),
            user: "openbrain_svc".to_string(),
            password: String::new(),
            pool_size: default_pool_size(),
        };
        let embedding = EmbeddingConfig {
            model_path: default_model_path(),
            dimension: default_embedding_dim(),
        };
        let auth = AuthConfig {
            enabled: default_auth_enabled(),
            api_keys: Vec::new(),
        };
        Self {
            server,
            database,
            embedding,
            auth,
        }
    }
}

176
src/db.rs Normal file
View File

@@ -0,0 +1,176 @@
//! Database module for PostgreSQL with pgvector support
//!
//! Provides connection pooling and query helpers for vector operations.
use anyhow::{Context, Result};
use deadpool_postgres::{Config, Pool, Runtime};
use pgvector::Vector;
use tokio_postgres::NoTls;
use tracing::info;
use uuid::Uuid;
use crate::config::DatabaseConfig;
/// Database wrapper with connection pool
///
/// `Clone` is derived so handlers can share the same underlying pool.
#[derive(Clone)]
pub struct Database {
// deadpool-postgres pool; a connection is checked out per operation.
pool: Pool,
}
/// A memory record stored in the database
#[derive(Debug, Clone)]
pub struct MemoryRecord {
/// Primary key (UUID v4, generated at insert time).
pub id: Uuid,
/// Owning agent; memories are stored and queried per agent.
pub agent_id: String,
/// Raw text content of the memory.
pub content: String,
/// Embedding vector. Left empty by `query_memories`, which does not
/// fetch raw embedding payloads back from the database.
pub embedding: Vec<f32>,
/// Keywords extracted from the content at store time.
pub keywords: Vec<String>,
/// Arbitrary caller-supplied metadata.
pub metadata: serde_json::Value,
/// Row creation timestamp (UTC).
pub created_at: chrono::DateTime<chrono::Utc>,
}
/// Query result with similarity score
#[derive(Debug, Clone)]
pub struct MemoryMatch {
/// The matched memory (its `embedding` field is left empty by
/// `query_memories`).
pub record: MemoryRecord,
/// Similarity computed in SQL as `1 - (embedding <=> query)`; with
/// pgvector's `<=>` cosine-distance operator this is cosine similarity.
pub similarity: f32,
}
impl Database {
/// Create a new database connection pool
pub async fn new(config: &DatabaseConfig) -> Result<Self> {
let mut cfg = Config::new();
cfg.host = Some(config.host.clone());
cfg.port = Some(config.port);
cfg.dbname = Some(config.name.clone());
cfg.user = Some(config.user.clone());
cfg.password = Some(config.password.clone());
let pool = cfg
.create_pool(Some(Runtime::Tokio1), NoTls)
.context("Failed to create database pool")?;
// Test connection
let client = pool.get().await.context("Failed to get database connection")?;
client
.simple_query("SELECT 1")
.await
.context("Failed to execute test query")?;
info!("Database connection pool created with {} connections", config.pool_size);
Ok(Self { pool })
}
/// Store a memory record
pub async fn store_memory(
&self,
agent_id: &str,
content: &str,
embedding: &[f32],
keywords: &[String],
metadata: serde_json::Value,
) -> Result<Uuid> {
let client = self.pool.get().await?;
let id = Uuid::new_v4();
let vector = Vector::from(embedding.to_vec());
client
.execute(
r#"
INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata)
VALUES ($1, $2, $3, $4, $5, $6)
"#,
&[&id, &agent_id, &content, &vector, &keywords, &metadata],
)
.await
.context("Failed to store memory")?;
Ok(id)
}
/// Query memories by vector similarity
pub async fn query_memories(
&self,
agent_id: &str,
embedding: &[f32],
limit: i64,
threshold: f32,
) -> Result<Vec<MemoryMatch>> {
let client = self.pool.get().await?;
let vector = Vector::from(embedding.to_vec());
let threshold_f64 = threshold as f64;
let rows = client
.query(
r#"
SELECT
id, agent_id, content, keywords, metadata, created_at,
(1 - (embedding <=> $1))::real AS similarity
FROM memories
WHERE agent_id = $2
AND (1 - (embedding <=> $1)) >= $3
ORDER BY embedding <=> $1
LIMIT $4
"#,
&[&vector, &agent_id, &threshold_f64, &limit],
)
.await
.context("Failed to query memories")?;
let matches = rows
.iter()
.map(|row| MemoryMatch {
record: MemoryRecord {
id: row.get("id"),
agent_id: row.get("agent_id"),
content: row.get("content"),
// Query responses do not include raw embedding payloads.
embedding: Vec::new(),
keywords: row.get("keywords"),
metadata: row.get("metadata"),
created_at: row.get("created_at"),
},
similarity: row.get("similarity"),
})
.collect();
Ok(matches)
}
/// Delete memories by agent_id and optional filters
pub async fn purge_memories(
&self,
agent_id: &str,
before: Option<chrono::DateTime<chrono::Utc>>,
) -> Result<u64> {
let client = self.pool.get().await?;
let count = if let Some(before_ts) = before {
client
.execute(
"DELETE FROM memories WHERE agent_id = $1 AND created_at < $2",
&[&agent_id, &before_ts],
)
.await?
} else {
client
.execute("DELETE FROM memories WHERE agent_id = $1", &[&agent_id])
.await?
};
Ok(count)
}
/// Get memory count for an agent
pub async fn count_memories(&self, agent_id: &str) -> Result<i64> {
let client = self.pool.get().await?;
let row = client
.query_one(
"SELECT COUNT(*) as count FROM memories WHERE agent_id = $1",
&[&agent_id],
)
.await?;
Ok(row.get("count"))
}
}

245
src/embedding.rs Normal file
View File

@@ -0,0 +1,245 @@
//! Embedding engine using local ONNX models
use anyhow::Result;
use ort::session::{Session, builder::GraphOptimizationLevel};
use ort::value::Value;
use std::path::{Path, PathBuf};
use std::sync::Once;
use tokenizers::Tokenizer;
use tracing::info;
use crate::config::EmbeddingConfig;
/// Outcome of the one-time ONNX Runtime initialization: `None` on success,
/// `Some(message)` on failure. Stored (rather than a bare `Once`) so every
/// caller observes the real result — previously a failed first attempt was
/// silently reported as success to all later callers.
static ORT_INIT: std::sync::OnceLock<Option<String>> = std::sync::OnceLock::new();

/// Initialize ONNX Runtime synchronously (called inside spawn_blocking)
///
/// Idempotent: the runtime dylib is loaded at most once per process and the
/// cached outcome (success or failure) is returned on subsequent calls.
fn init_ort_sync(dylib_path: &str) -> Result<()> {
    info!("Initializing ONNX Runtime from: {}", dylib_path);
    let outcome = ORT_INIT.get_or_init(|| {
        info!("ORT_INIT.call_once - starting initialization");
        let result = match ort::init_from(dylib_path) {
            Ok(builder) => {
                info!("ort::init_from succeeded, calling commit()");
                let committed = builder.commit();
                info!("commit() returned: {}", committed);
                if committed {
                    None
                } else {
                    Some("ONNX Runtime commit returned false".to_string())
                }
            }
            Err(e) => {
                let err_msg = format!("ONNX Runtime init_from failed: {:?}", e);
                info!("{}", err_msg);
                Some(err_msg)
            }
        };
        info!("ORT_INIT.call_once - finished");
        result
    });
    // Fix: a cached failure now errors on every call, not just the first.
    if let Some(err) = outcome {
        return Err(anyhow::anyhow!("{}", err));
    }
    info!("ONNX Runtime initialization complete");
    Ok(())
}
/// Resolve ONNX Runtime dylib path from env var or common local install locations.
fn resolve_ort_dylib_path() -> Result<String> {
    // An explicit ORT_DYLIB_PATH always wins, but must point at a real file.
    if let Ok(path) = std::env::var("ORT_DYLIB_PATH") {
        return if Path::new(&path).exists() {
            Ok(path)
        } else {
            Err(anyhow::anyhow!(
                "ORT_DYLIB_PATH is set but file does not exist: {}",
                path
            ))
        };
    }
    // Fall back to well-known Homebrew install locations.
    let found = [
        "/opt/homebrew/opt/onnxruntime/lib/libonnxruntime.dylib",
        "/usr/local/opt/onnxruntime/lib/libonnxruntime.dylib",
    ]
    .into_iter()
    .find(|candidate| Path::new(candidate).exists());
    match found {
        Some(path) => Ok(path.to_string()),
        None => Err(anyhow::anyhow!(
            "ORT_DYLIB_PATH environment variable not set and ONNX Runtime dylib not found. \
            Set ORT_DYLIB_PATH to your libonnxruntime.dylib path (for example: /opt/homebrew/opt/onnxruntime/lib/libonnxruntime.dylib)."
        )),
    }
}
/// Local ONNX-based sentence embedding engine.
pub struct EmbeddingEngine {
// ONNX session; `run` takes `&mut self`, so the Mutex serializes inference.
session: std::sync::Mutex<Session>,
// HuggingFace tokenizer loaded from tokenizer.json.
tokenizer: Tokenizer,
// Output vector dimension, taken from configuration.
dimension: usize,
}
impl EmbeddingEngine {
/// Create a new embedding engine
///
/// Resolves the ONNX Runtime dylib, then loads `model.onnx` and
/// `tokenizer.json` from `config.model_path` on a blocking thread so
/// model loading does not stall the async runtime.
///
/// # Errors
/// Fails if the runtime cannot be initialized or the model/tokenizer
/// files cannot be loaded.
pub async fn new(config: &EmbeddingConfig) -> Result<Self> {
let dylib_path = resolve_ort_dylib_path()?;
let model_path = PathBuf::from(&config.model_path);
let dimension = config.dimension;
info!("Loading ONNX model from {:?}", model_path.join("model.onnx"));
// Use spawn_blocking to avoid blocking the async runtime
let (session, tokenizer) = tokio::task::spawn_blocking(move || -> Result<(Session, Tokenizer)> {
// Initialize ONNX Runtime first
init_ort_sync(&dylib_path)?;
info!("Creating ONNX session...");
// Load ONNX model with ort 2.0 API
let session = Session::builder()
.map_err(|e| anyhow::anyhow!("Failed to create session builder: {:?}", e))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| anyhow::anyhow!("Failed to set optimization level: {:?}", e))?
.with_intra_threads(4)
.map_err(|e| anyhow::anyhow!("Failed to set intra threads: {:?}", e))?
.commit_from_file(model_path.join("model.onnx"))
.map_err(|e| anyhow::anyhow!("Failed to load ONNX model: {:?}", e))?;
info!("ONNX model loaded, loading tokenizer...");
// Load tokenizer
let tokenizer = Tokenizer::from_file(model_path.join("tokenizer.json"))
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
info!("Tokenizer loaded successfully");
Ok((session, tokenizer))
}).await
// Outer `?` is the JoinError; inner `?` is the load Result.
.map_err(|e| anyhow::anyhow!("Spawn blocking failed: {:?}", e))??;
info!(
"Embedding engine initialized: model={}, dimension={}",
config.model_path, dimension
);
Ok(Self {
session: std::sync::Mutex::new(session),
tokenizer,
dimension,
})
}
/// Generate embedding for a single text
///
/// Tokenizes `text`, runs the model, mean-pools the token vectors and
/// L2-normalizes the result.
///
/// NOTE(review): pooling averages over *all* positions rather than
/// weighting by the attention mask — fine for single unpadded inputs,
/// but confirm the tokenizer never pads here.
pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
let encoding = self.tokenizer
.encode(text, true)
.map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
// Widen tokenizer u32 outputs to the i64 the model expects.
let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
let attention_mask: Vec<i64> = encoding.get_attention_mask().iter().map(|&x| x as i64).collect();
let token_type_ids: Vec<i64> = encoding.get_type_ids().iter().map(|&x| x as i64).collect();
let seq_len = input_ids.len();
// Create input tensors with ort 2.0 API (batch size fixed at 1)
let input_ids_tensor = Value::from_array(([1, seq_len], input_ids))?;
let attention_mask_tensor = Value::from_array(([1, seq_len], attention_mask))?;
let token_type_ids_tensor = Value::from_array(([1, seq_len], token_type_ids))?;
// Run inference
let inputs = ort::inputs![
"input_ids" => input_ids_tensor,
"attention_mask" => attention_mask_tensor,
"token_type_ids" => token_type_ids_tensor,
];
// `run` needs &mut Session, so inference is serialized via the Mutex.
let mut session_guard = self.session.lock()
.map_err(|e| anyhow::anyhow!("Session lock poisoned: {}", e))?;
let outputs = session_guard.run(inputs)?;
// Extract output
let output = outputs.get("last_hidden_state")
.ok_or_else(|| anyhow::anyhow!("Missing last_hidden_state output"))?;
// Get the tensor data
let (shape, data) = output.try_extract_tensor::<f32>()?;
// Mean pooling over sequence dimension
// (hidden size comes from the output shape; 384 is only a fallback)
let hidden_size = *shape.last().unwrap_or(&384) as usize;
let seq_len = data.len() / hidden_size;
let mut embedding = vec![0.0f32; hidden_size];
for i in 0..seq_len {
for j in 0..hidden_size {
embedding[j] += data[i * hidden_size + j];
}
}
for val in &mut embedding {
*val /= seq_len as f32;
}
// L2 normalize (guard avoids division by zero for an all-zero vector)
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for val in &mut embedding {
*val /= norm;
}
}
Ok(embedding)
}
/// Generate embeddings for multiple texts
///
/// Embeds sequentially; short-circuits on the first error.
pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
texts.iter().map(|text| self.embed(text)).collect()
}
/// Get the embedding dimension
pub fn dimension(&self) -> usize {
self.dimension
}
}
/// Extract keywords from text using simple frequency analysis
///
/// Words are lower-cased, stripped of non-alphanumeric characters, and
/// counted; English stop words and words of two characters or fewer are
/// ignored. The `limit` most frequent words are returned, most frequent
/// first. Ties are broken alphabetically so output is deterministic
/// (previously tie order depended on HashMap iteration order).
pub fn extract_keywords(text: &str, limit: usize) -> Vec<String> {
    use std::collections::HashMap;
    let stop_words: std::collections::HashSet<&str> = [
        "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
        "of", "with", "by", "from", "is", "are", "was", "were", "be", "been",
        "being", "have", "has", "had", "do", "does", "did", "will", "would",
        "could", "should", "may", "might", "must", "shall", "can", "this",
        "that", "these", "those", "i", "you", "he", "she", "it", "we", "they",
        "what", "which", "who", "whom", "whose", "where", "when", "why", "how",
        "all", "each", "every", "both", "few", "more", "most", "other", "some",
        "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too",
        "very", "just", "also", "now", "here", "there", "then", "once", "if",
    ].iter().cloned().collect();
    let mut word_counts: HashMap<String, usize> = HashMap::new();
    for word in text.split_whitespace() {
        let clean: String = word
            .chars()
            .filter(|c| c.is_alphanumeric())
            .collect::<String>()
            .to_lowercase();
        // len() counts bytes; for ASCII input this equals character count.
        if clean.len() > 2 && !stop_words.contains(clean.as_str()) {
            *word_counts.entry(clean).or_insert(0) += 1;
        }
    }
    let mut sorted: Vec<_> = word_counts.into_iter().collect();
    // Fix: sort by descending count, then ascending word, so equal-count
    // keywords come out in a stable, deterministic order.
    sorted.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    sorted.into_iter()
        .take(limit)
        .map(|(word, _)| word)
        .collect()
}

150
src/lib.rs Normal file
View File

@@ -0,0 +1,150 @@
//! OpenBrain MCP - High-performance vector memory for AI agents
pub mod auth;
pub mod config;
pub mod db;
pub mod embedding;
pub mod migrations;
pub mod tools;
pub mod transport;
use anyhow::Result;
use axum::{Router, Json, http::StatusCode, middleware};
use serde_json::json;
use std::sync::Arc;
use tokio::net::TcpListener;
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use tracing::{info, error};
use crate::auth::auth_middleware;
use crate::config::Config;
use crate::db::Database;
use crate::embedding::EmbeddingEngine;
use crate::transport::McpState;
/// Service readiness state
///
/// Tracked separately from liveness: `/health` stays OK while this moves
/// from Initializing to Ready (or Failed) as the embedding engine loads.
#[derive(Clone, Debug, PartialEq)]
pub enum ReadinessState {
/// Embedding engine still loading in the background.
Initializing,
/// Embedding engine loaded; service fully operational.
Ready,
/// Embedding initialization gave up; carries the error message.
Failed(String),
}
/// Shared application state
pub struct AppState {
/// PostgreSQL connection-pool wrapper.
pub db: Database,
/// Embedding engine; `None` until background initialization completes.
pub embedding: tokio::sync::RwLock<Option<Arc<EmbeddingEngine>>>,
/// Full service configuration (auth settings are read by middleware).
pub config: Config,
/// Current readiness, driven by the embedding-init background task.
pub readiness: tokio::sync::RwLock<ReadinessState>,
}
impl AppState {
    /// Snapshot the embedding engine handle; `None` while still initializing.
    pub async fn get_embedding(&self) -> Option<Arc<EmbeddingEngine>> {
        let guard = self.embedding.read().await;
        guard.as_ref().map(Arc::clone)
    }
}
/// Health check endpoint - always returns OK if server is running
async fn health_handler() -> Json<serde_json::Value> {
    let body = json!({"status": "ok"});
    Json(body)
}
/// Readiness endpoint - returns 503 if embedding not ready
async fn readiness_handler(
state: axum::extract::State<Arc<AppState>>,
) -> (StatusCode, Json<serde_json::Value>) {
let readiness = state.readiness.read().await.clone();
match readiness {
ReadinessState::Ready => (
StatusCode::OK,
Json(json!({"status": "ready", "embedding": true}))
),
ReadinessState::Initializing => (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({"status": "initializing", "embedding": false}))
),
ReadinessState::Failed(err) => (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({"status": "failed", "error": err}))
),
}
}
/// Run the MCP server
///
/// Binds and begins serving immediately; the embedding engine is loaded in
/// a background task with retries, and `/ready` reports 503 until it is up.
/// MCP routes are nested under `/mcp` behind the auth middleware, while
/// `/health` and `/ready` are unauthenticated.
pub async fn run_server(config: Config, db: Database) -> Result<()> {
// Create state with None embedding (will init in background)
let state = Arc::new(AppState {
db,
embedding: tokio::sync::RwLock::new(None),
config: config.clone(),
readiness: tokio::sync::RwLock::new(ReadinessState::Initializing),
});
// Spawn background task to initialize embedding with retry
let state_clone = state.clone();
let embedding_config = config.embedding.clone();
tokio::spawn(async move {
let max_retries = 3;
let mut attempt = 0;
loop {
attempt += 1;
info!("Initializing embedding engine (attempt {}/{})", attempt, max_retries);
match EmbeddingEngine::new(&embedding_config).await {
Ok(engine) => {
let engine = Arc::new(engine);
// Publish the engine first, then flip readiness to Ready.
*state_clone.embedding.write().await = Some(engine);
*state_clone.readiness.write().await = ReadinessState::Ready;
info!("Embedding engine initialized successfully");
break;
}
Err(e) => {
error!("Failed to init embedding (attempt {}): {:?}", attempt, e);
if attempt >= max_retries {
// Give up: /ready will report the failure from here on.
let err_msg = format!("Failed after {} attempts: {:?}", max_retries, e);
*state_clone.readiness.write().await = ReadinessState::Failed(err_msg);
break;
}
// Exponential backoff: 2s, 4s, 8s...
tokio::time::sleep(tokio::time::Duration::from_secs(2u64.pow(attempt))).await;
}
}
}
});
// Create MCP state for SSE transport
let mcp_state = McpState::new(state.clone());
// Build router with health/readiness endpoints (no auth required)
let health_router = Router::new()
.route("/health", axum::routing::get(health_handler))
.route("/ready", axum::routing::get(readiness_handler))
.with_state(state.clone());
// Build MCP router with auth middleware
let mcp_router = transport::mcp_router(mcp_state)
.layer(middleware::from_fn_with_state(state.clone(), auth_middleware));
// NOTE(review): CORS is fully open (any origin/method/header) — confirm
// this is intended for production deployments.
let app = Router::new()
.merge(health_router)
.nest("/mcp", mcp_router)
.layer(TraceLayer::new_for_http())
.layer(
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any),
);
// Start server immediately
let bind_addr = format!("{}:{}", config.server.host, config.server.port);
let listener = TcpListener::bind(&bind_addr).await?;
info!("Server listening on {}", bind_addr);
axum::serve(listener, app).await?;
Ok(())
}

46
src/main.rs Normal file
View File

@@ -0,0 +1,46 @@
//! OpenBrain MCP Server - High-performance vector memory for AI agents
//!
//! This is the main entry point for the OpenBrain MCP server.
use anyhow::Result;
use std::env;
use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
use openbrain_mcp::{config::Config, db::Database, migrations, run_server};
#[tokio::main]
async fn main() -> Result<()> {
    // Tracing: honour RUST_LOG-style env filter, defaulting to `info`.
    let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into());
    tracing_subscriber::registry()
        .with(filter)
        .with(tracing_subscriber::fmt::layer())
        .init();
    info!("Starting OpenBrain MCP Server v{}", env!("CARGO_PKG_VERSION"));
    // All settings come from the environment (plus an optional .env file).
    let config = Config::load()?;
    info!("Configuration loaded from environment");
    // Subcommand handling: `migrate` applies schema migrations and exits;
    // any other argument is an error; no argument means "run the server".
    if let Some(arg) = env::args().nth(1) {
        return match arg.as_str() {
            "migrate" => {
                migrations::run(&config.database).await?;
                info!("Database migrations completed successfully");
                Ok(())
            }
            arg => anyhow::bail!("Unknown command: {arg}. Supported commands: migrate"),
        };
    }
    // Normal startup: connect to the database, then serve.
    let db = Database::new(&config.database).await?;
    info!("Database connection pool initialized");
    run_server(config, db).await?;
    Ok(())
}

50
src/migrations.rs Normal file
View File

@@ -0,0 +1,50 @@
//! Database migrations using refinery.
use anyhow::{Context, Result};
use refinery::embed_migrations;
use tokio_postgres::NoTls;
use tracing::info;
use crate::config::DatabaseConfig;
embed_migrations!("migrations");
/// Apply all pending database migrations.
///
/// Opens a dedicated (non-pooled) connection, runs the embedded refinery
/// migrations against it, and logs each migration that was applied.
pub async fn run(config: &DatabaseConfig) -> Result<()> {
    let mut pg_config = tokio_postgres::Config::new();
    pg_config
        .host(&config.host)
        .port(config.port)
        .dbname(&config.name)
        .user(&config.user)
        .password(&config.password);
    let (mut client, connection) = pg_config
        .connect(NoTls)
        .await
        .context("Failed to connect to database for migrations")?;
    // The connection future must be polled for the client to make progress.
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            tracing::error!("Database migration connection error: {}", e);
        }
    });
    let report = migrations::runner()
        .run_async(&mut client)
        .await
        .context("Failed to apply database migrations")?;
    let applied = report.applied_migrations();
    if applied.is_empty() {
        info!("No database migrations to apply");
        return Ok(());
    }
    for migration in applied {
        info!(
            version = migration.version(),
            name = migration.name(),
            "Applied database migration"
        );
    }
    Ok(())
}

106
src/tools/mod.rs Normal file
View File

@@ -0,0 +1,106 @@
//! MCP Tools for OpenBrain
//!
//! Provides the core tools for memory storage and retrieval:
//! - `store`: Store a memory with automatic embedding generation
//! - `query`: Query memories by semantic similarity
//! - `purge`: Delete memories by agent_id or time range
pub mod query;
pub mod store;
pub mod purge;
use anyhow::Result;
use serde_json::{json, Value};
use std::sync::Arc;
use crate::AppState;
/// Get all tool definitions for MCP tools/list
///
/// Returns the JSON Schema descriptions of the three tools (`store`,
/// `query`, `purge`) exactly as advertised to MCP clients; keep these in
/// sync with the parameter handling in the matching `execute` functions.
pub fn get_tool_definitions() -> Vec<Value> {
vec![
// `store`: only `content` is required; agent_id defaults elsewhere.
json!({
"name": "store",
"description": "Store a memory with automatic embedding generation and keyword extraction. The memory will be associated with the agent_id for isolated retrieval.",
"inputSchema": {
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "The text content to store as a memory"
},
"agent_id": {
"type": "string",
"description": "Unique identifier for the agent storing the memory (default: 'default')"
},
"metadata": {
"type": "object",
"description": "Optional metadata to attach to the memory"
}
},
"required": ["content"]
}
}),
// `query`: semantic search with optional limit/threshold tuning.
json!({
"name": "query",
"description": "Query stored memories using semantic similarity search. Returns the most relevant memories based on the query text.",
"inputSchema": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query text"
},
"agent_id": {
"type": "string",
"description": "Agent ID to search within (default: 'default')"
},
"limit": {
"type": "integer",
"description": "Maximum number of results to return (default: 10)"
},
"threshold": {
"type": "number",
"description": "Minimum similarity threshold 0.0-1.0 (default: 0.5)"
}
},
"required": ["query"]
}
}),
// `purge`: destructive; requires both agent_id and confirm=true.
json!({
"name": "purge",
"description": "Delete memories for an agent. Can delete all memories or those before a specific timestamp.",
"inputSchema": {
"type": "object",
"properties": {
"agent_id": {
"type": "string",
"description": "Agent ID whose memories to delete (required)"
},
"before": {
"type": "string",
"description": "Optional ISO8601 timestamp - delete memories created before this time"
},
"confirm": {
"type": "boolean",
"description": "Must be true to confirm deletion"
}
},
"required": ["agent_id", "confirm"]
}
})
]
}
/// Execute a tool by name with given arguments
///
/// Dispatches to the matching tool module; each tool returns its result
/// serialized as a JSON string.
pub async fn execute_tool(
    state: &Arc<AppState>,
    tool_name: &str,
    arguments: Value,
) -> Result<String> {
    match tool_name {
        "query" => query::execute(state, arguments).await,
        "store" => store::execute(state, arguments).await,
        "purge" => purge::execute(state, arguments).await,
        unknown => anyhow::bail!("Unknown tool: {}", unknown),
    }
}

79
src/tools/purge.rs Normal file
View File

@@ -0,0 +1,79 @@
//! Purge Tool - Delete memories by agent_id or time range
use anyhow::{bail, Context, Result};
use chrono::DateTime;
use serde_json::Value;
use std::sync::Arc;
use tracing::{info, warn};
use crate::AppState;
/// Execute the purge tool
///
/// Deletes memories for `agent_id`, optionally only those created before
/// an RFC3339 `before` timestamp. The caller must pass `confirm: true`,
/// otherwise nothing is deleted and an error is returned. Returns a JSON
/// string summarizing the outcome.
pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String> {
// Extract parameters
let agent_id = arguments
.get("agent_id")
.and_then(|v| v.as_str())
.context("Missing required parameter: agent_id")?;
// Missing or non-boolean `confirm` is treated as false (refuse).
let confirm = arguments
.get("confirm")
.and_then(|v| v.as_bool())
.unwrap_or(false);
if !confirm {
bail!("Purge operation requires 'confirm: true' to proceed");
}
// `before` is optional, but if present an unparsable value is an
// error rather than being silently ignored.
let before = arguments
.get("before")
.and_then(|v| v.as_str())
.map(|s| DateTime::parse_from_rfc3339(s))
.transpose()
.context("Invalid 'before' timestamp format - use ISO8601/RFC3339")?
.map(|dt| dt.with_timezone(&chrono::Utc));
// Get current count before purge
// NOTE(review): this count and the delete below are separate statements,
// so `deleted` may differ from `count_before` under concurrent writes.
let count_before = state
.db
.count_memories(agent_id)
.await
.context("Failed to count memories")?;
if count_before == 0 {
info!("No memories found for agent '{}'", agent_id);
return Ok(serde_json::json!({
"success": true,
"agent_id": agent_id,
"deleted": 0,
"message": "No memories found to purge"
})
.to_string());
}
warn!(
"Purging memories for agent '{}' (before={:?})",
agent_id, before
);
// Execute purge
let deleted = state
.db
.purge_memories(agent_id, before)
.await
.context("Failed to purge memories")?;
info!(
"Purged {} memories for agent '{}'",
deleted, agent_id
);
Ok(serde_json::json!({
"success": true,
"agent_id": agent_id,
"deleted": deleted,
"had_before_filter": before.is_some(),
"message": format!("Successfully purged {} memories", deleted)
})
.to_string())
}

81
src/tools/query.rs Normal file
View File

@@ -0,0 +1,81 @@
//! Query Tool - Search memories by semantic similarity
use anyhow::{Context, Result, anyhow};
use serde_json::Value;
use std::sync::Arc;
use tracing::info;
use crate::AppState;
/// Execute the query tool
///
/// Embeds the query text and performs a similarity search over the
/// agent's memories. `agent_id` defaults to "default", `limit` to 10 and
/// `threshold` to 0.5. Returns a JSON string with the matches.
///
/// NOTE(review): `limit` and `threshold` are passed through unvalidated
/// (e.g. negative limit, threshold outside 0..=1) — confirm the database
/// layer behaves sensibly for out-of-range values.
pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String> {
// Get embedding engine, return error if not ready
let embedding_engine = state
.get_embedding()
.await
.ok_or_else(|| anyhow!("Embedding engine not ready - service is still initializing"))?;
// Extract parameters
let query_text = arguments
.get("query")
.and_then(|v| v.as_str())
.context("Missing required parameter: query")?;
let agent_id = arguments
.get("agent_id")
.and_then(|v| v.as_str())
.unwrap_or("default");
let limit = arguments
.get("limit")
.and_then(|v| v.as_i64())
.unwrap_or(10);
let threshold = arguments
.get("threshold")
.and_then(|v| v.as_f64())
.unwrap_or(0.5) as f32;
info!(
"Querying memories for agent '{}': '{}' (limit={}, threshold={})",
agent_id, query_text, limit, threshold
);
// Generate embedding for query using Arc<EmbeddingEngine>
let query_embedding = embedding_engine
.embed(query_text)
.context("Failed to generate query embedding")?;
// Search database
let matches = state
.db
.query_memories(agent_id, &query_embedding, limit, threshold)
.await
.context("Failed to query memories")?;
info!("Found {} matching memories", matches.len());
// Format results (embeddings themselves are not returned to the client)
let results: Vec<Value> = matches
.iter()
.map(|m| {
serde_json::json!({
"id": m.record.id.to_string(),
"content": m.record.content,
"similarity": m.similarity,
"keywords": m.record.keywords,
"metadata": m.record.metadata,
"created_at": m.record.created_at.to_rfc3339()
})
})
.collect();
Ok(serde_json::json!({
"success": true,
"agent_id": agent_id,
"query": query_text,
"count": results.len(),
"results": results
})
.to_string())
}

66
src/tools/store.rs Normal file
View File

@@ -0,0 +1,66 @@
//! Store Tool - Store memories with automatic embeddings
use anyhow::{Context, Result, anyhow};
use serde_json::Value;
use std::sync::Arc;
use tracing::info;
use crate::embedding::extract_keywords;
use crate::AppState;
/// Execute the store tool
///
/// Embeds `content`, extracts up to 10 keywords from it, and persists the
/// memory for `agent_id` (default "default") together with optional
/// caller-supplied `metadata`. Returns a JSON string with the new ID.
///
/// NOTE(review): empty `content` is currently accepted and stored —
/// confirm whether that should be rejected instead.
pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String> {
// Get embedding engine, return error if not ready
let embedding_engine = state
.get_embedding()
.await
.ok_or_else(|| anyhow!("Embedding engine not ready - service is still initializing"))?;
// Extract parameters
let content = arguments
.get("content")
.and_then(|v| v.as_str())
.context("Missing required parameter: content")?;
let agent_id = arguments
.get("agent_id")
.and_then(|v| v.as_str())
.unwrap_or("default");
// Missing metadata becomes an empty JSON object.
let metadata = arguments
.get("metadata")
.cloned()
.unwrap_or(serde_json::json!({}));
info!(
"Storing memory for agent '{}': {} chars",
agent_id,
content.len()
);
// Generate embedding using Arc<EmbeddingEngine>
let embedding = embedding_engine
.embed(content)
.context("Failed to generate embedding")?;
// Extract keywords (top 10 by frequency)
let keywords = extract_keywords(content, 10);
// Store in database
let id = state
.db
.store_memory(agent_id, content, &embedding, &keywords, metadata)
.await
.context("Failed to store memory")?;
info!("Memory stored with ID: {}", id);
Ok(serde_json::json!({
"success": true,
"id": id.to_string(),
"agent_id": agent_id,
"keywords": keywords,
"embedding_dimension": embedding.len()
})
.to_string())
}

531
src/transport.rs Normal file
View File

@@ -0,0 +1,531 @@
//! SSE Transport for MCP Protocol
//!
//! Implements Server-Sent Events transport for the Model Context Protocol.
use axum::{
extract::{Query, State},
http::{HeaderMap, StatusCode},
response::{
IntoResponse, Response,
sse::{Event, KeepAlive, Sse},
},
routing::{get, post},
Json, Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
convert::Infallible,
sync::Arc,
time::Duration,
};
use tokio::sync::{RwLock, broadcast, mpsc};
use tracing::{error, info, warn};
use uuid::Uuid;
use crate::{AppState, auth, tools};
/// Maps SSE session IDs to the sender half of each session's response channel.
type SessionStore = RwLock<HashMap<String, mpsc::Sender<serde_json::Value>>>;
/// MCP Server State
pub struct McpState {
    /// Shared application state used by tool execution.
    pub app: Arc<AppState>,
    /// Broadcast channel for server-wide SSE events; every SSE client subscribes.
    pub event_tx: broadcast::Sender<McpEvent>,
    /// Active SSE sessions keyed by session ID.
    sessions: SessionStore,
}
impl McpState {
    /// Build a new MCP state wrapper around the shared application state.
    pub fn new(app: Arc<AppState>) -> Arc<Self> {
        let (event_tx, _rx) = broadcast::channel(100);
        let sessions = RwLock::new(HashMap::new());
        Arc::new(Self {
            app,
            event_tx,
            sessions,
        })
    }
    /// Register a session's response channel under its session ID.
    async fn insert_session(
        &self,
        session_id: String,
        tx: mpsc::Sender<serde_json::Value>,
    ) {
        let mut sessions = self.sessions.write().await;
        sessions.insert(session_id, tx);
    }
    /// Drop a session's response channel, if present.
    async fn remove_session(&self, session_id: &str) {
        let mut sessions = self.sessions.write().await;
        sessions.remove(session_id);
    }
    /// Whether a session with this ID is currently registered.
    async fn has_session(&self, session_id: &str) -> bool {
        self.sessions.read().await.contains_key(session_id)
    }
    /// Deliver a JSON-RPC response to a specific session's SSE stream.
    ///
    /// Returns `NotFound` when the session ID is unknown, and `Closed`
    /// (after unregistering the stale session) when the receiving end of
    /// the channel has gone away.
    async fn send_to_session(
        &self,
        session_id: &str,
        response: &JsonRpcResponse,
    ) -> Result<(), SessionSendError> {
        // The read guard is a statement-scoped temporary, so it is released
        // before the `.await` on the channel send below.
        let tx = self
            .sessions
            .read()
            .await
            .get(session_id)
            .cloned()
            .ok_or(SessionSendError::NotFound)?;
        let payload =
            serde_json::to_value(response).expect("serializing JSON-RPC response should succeed");
        match tx.send(payload).await {
            Ok(()) => Ok(()),
            Err(_) => {
                // Receiver dropped: clean up the dead session eagerly.
                self.remove_session(session_id).await;
                Err(SessionSendError::Closed)
            }
        }
    }
}
/// Outcome of failing to deliver a response to an SSE session.
enum SessionSendError {
    /// No session registered under the given ID.
    NotFound,
    /// The session's receiving end has been dropped (client disconnected).
    Closed,
}
/// Guard that unregisters an SSE session when the session's stream is dropped.
struct SessionGuard {
    state: Arc<McpState>,
    session_id: String,
}
impl SessionGuard {
    /// Tie the guard to the session it should clean up on drop.
    fn new(state: Arc<McpState>, session_id: String) -> Self {
        Self { state, session_id }
    }
}
impl Drop for SessionGuard {
    fn drop(&mut self) {
        let state = self.state.clone();
        let session_id = self.session_id.clone();
        // Session removal is async, so it must be spawned on the current Tokio
        // runtime. If no runtime is available (drop outside async context), the
        // entry is left behind; `send_to_session` removes it lazily when the
        // next send to the dead channel fails.
        if let Ok(handle) = tokio::runtime::Handle::try_current() {
            handle.spawn(async move {
                state.remove_session(&session_id).await;
            });
        }
    }
}
/// MCP Event for SSE streaming
#[derive(Clone, Debug, Serialize)]
pub struct McpEvent {
    /// Forwarded as the SSE `id:` field.
    pub id: String,
    /// Forwarded as the SSE `event:` name.
    pub event_type: String,
    /// JSON payload, serialized into the SSE `data:` field.
    pub data: serde_json::Value,
}
/// MCP JSON-RPC Request
#[derive(Debug, Deserialize)]
pub struct JsonRpcRequest {
    /// Protocol version string ("2.0"); accepted as-is, not validated here.
    pub jsonrpc: String,
    /// Request id; absent for JSON-RPC notifications, which receive no response.
    #[serde(default)]
    pub id: Option<serde_json::Value>,
    /// Method name, e.g. "initialize", "tools/list", "tools/call".
    pub method: String,
    /// Method parameters; defaults to `Value::Null` when omitted.
    #[serde(default)]
    pub params: serde_json::Value,
}
/// MCP JSON-RPC Response
#[derive(Debug, Serialize)]
pub struct JsonRpcResponse {
    pub jsonrpc: String,
    /// Echo of the request id this response answers.
    pub id: serde_json::Value,
    /// Success payload; exactly one of `result`/`error` is set by the helpers below.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<serde_json::Value>,
    /// Error payload; omitted from serialization on success.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<JsonRpcError>,
}
/// JSON-RPC error object (code, human-readable message, optional data).
#[derive(Debug, Serialize)]
pub struct JsonRpcError {
    pub code: i32,
    pub message: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<serde_json::Value>,
}
/// Query parameters accepted by the message endpoint.
///
/// `sessionId` (camelCase on the wire) links a POSTed JSON-RPC request to an
/// open SSE session; when absent the response is returned in the HTTP body.
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "camelCase")]
struct PostMessageQuery {
    #[serde(default)]
    session_id: Option<String>,
}
/// Create the MCP router
///
/// Exposes the SSE event stream, the JSON-RPC message endpoint, and a
/// health check, all sharing the given MCP state.
pub fn mcp_router(state: Arc<McpState>) -> Router {
    let routes = Router::new()
        .route("/health", get(health_handler))
        .route("/message", post(message_handler))
        .route("/sse", get(sse_handler));
    routes.with_state(state)
}
/// SSE endpoint for streaming events
///
/// Opens a new session: allocates a session ID, registers a per-session
/// response channel, then returns an SSE stream that interleaves responses
/// directed at this session with server-wide broadcast events.
async fn sse_handler(
    State(state): State<Arc<McpState>>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let mut broadcast_rx = state.event_tx.subscribe();
    let (session_tx, mut session_rx) = mpsc::channel(32);
    let session_id = Uuid::new_v4().to_string();
    let endpoint = session_message_endpoint(&session_id);
    state.insert_session(session_id.clone(), session_tx).await;
    let stream = async_stream::stream! {
        // Guard unregisters the session when the stream (i.e. the client) goes away.
        let _session_guard = SessionGuard::new(state.clone(), session_id.clone());
        // Send endpoint event (required by MCP SSE protocol)
        // This tells the client where to POST JSON-RPC messages for this session.
        yield Ok(Event::default()
            .event("endpoint")
            .data(endpoint));
        // Send initial tools list so Agent Zero knows what's available
        let tools_response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: serde_json::json!("initial-tools"),
            result: Some(serde_json::json!({
                "tools": tools::get_tool_definitions()
            })),
            error: None,
        };
        yield Ok(Event::default()
            .event("message")
            .json_data(&tools_response)
            .unwrap());
        loop {
            tokio::select! {
                // Responses routed to this specific session by the message endpoint.
                maybe_message = session_rx.recv() => {
                    match maybe_message {
                        Some(message) => {
                            yield Ok(Event::default()
                                .event("message")
                                .json_data(&message)
                                .unwrap());
                        }
                        // Sender dropped (session removed): end the stream.
                        None => break,
                    }
                }
                // Server-wide broadcast events shared by all connected SSE clients.
                event = broadcast_rx.recv() => {
                    match event {
                        Ok(event) => {
                            yield Ok(Event::default()
                                .event(&event.event_type)
                                .id(&event.id)
                                .json_data(&event.data)
                                .unwrap());
                        }
                        // Slow consumer: log how many events were skipped, keep streaming.
                        Err(broadcast::error::RecvError::Lagged(n)) => {
                            warn!("SSE client lagged, missed {} events", n);
                        }
                        Err(broadcast::error::RecvError::Closed) => {
                            break;
                        }
                    }
                }
            }
        }
    };
    Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))
}
/// Message endpoint for JSON-RPC requests
///
/// Validates the optional session, injects header-derived context, dispatches
/// the request, and then either routes the response to the session's SSE
/// stream or returns it directly in the HTTP body.
async fn message_handler(
    State(state): State<Arc<McpState>>,
    Query(query): Query<PostMessageQuery>,
    headers: HeaderMap,
    Json(request): Json<JsonRpcRequest>,
) -> Response {
    info!("Received MCP request: {}", request.method);
    let session_id = query.session_id.as_deref();
    // Reject up front when the caller names a session we don't know about,
    // before doing any work on the request.
    if let Some(id) = session_id {
        if !state.has_session(id).await {
            return StatusCode::NOT_FOUND.into_response();
        }
    }
    let request = apply_request_context(request, &headers);
    let response = dispatch_request(&state, &request).await;
    if let Some(id) = session_id {
        return route_session_response(&state, id, response).await;
    }
    // No session: reply inline; notifications get 202 Accepted with no body.
    match response {
        Some(body) => Json(body).into_response(),
        None => StatusCode::ACCEPTED.into_response(),
    }
}
/// Health check endpoint
///
/// Reports a static healthy status with the server name and crate version.
async fn health_handler() -> Json<serde_json::Value> {
    let body = serde_json::json!({
        "status": "healthy",
        "server": "openbrain-mcp",
        "version": env!("CARGO_PKG_VERSION")
    });
    Json(body)
}
/// Build the relative URL a client must POST JSON-RPC messages to for a session.
///
/// The query parameter is deliberately camelCase (`sessionId`) to match
/// the wire format expected by the message endpoint.
fn session_message_endpoint(session_id: &str) -> String {
    const PREFIX: &str = "/mcp/message?sessionId=";
    let mut endpoint = String::with_capacity(PREFIX.len() + session_id.len());
    endpoint.push_str(PREFIX);
    endpoint.push_str(session_id);
    endpoint
}
/// Forward a dispatch result to a session's SSE stream and map it to HTTP.
///
/// Notifications (`None`) and successful delivery both yield 202 Accepted;
/// an unknown session yields 404 Not Found and a disconnected one 410 Gone.
async fn route_session_response(
    state: &Arc<McpState>,
    session_id: &str,
    response: Option<JsonRpcResponse>,
) -> Response {
    let response = match response {
        Some(response) => response,
        // Nothing to deliver; just acknowledge receipt.
        None => return StatusCode::ACCEPTED.into_response(),
    };
    match state.send_to_session(session_id, &response).await {
        Ok(()) => StatusCode::ACCEPTED.into_response(),
        Err(SessionSendError::NotFound) => StatusCode::NOT_FOUND.into_response(),
        Err(SessionSendError::Closed) => StatusCode::GONE.into_response(),
    }
}
/// Enrich an incoming request with context derived from HTTP headers.
///
/// Currently this only injects the caller's agent id (when one is present
/// in the headers) into `tools/call` arguments.
fn apply_request_context(
    mut request: JsonRpcRequest,
    headers: &HeaderMap,
) -> JsonRpcRequest {
    if let Some(header_agent) = auth::get_optional_agent_id(headers) {
        inject_agent_id(&mut request, &header_agent);
    }
    request
}
/// Inject a header-derived `agent_id` into `tools/call` arguments.
///
/// Only `tools/call` requests are touched, and an `agent_id` the caller
/// supplied explicitly always wins over the header value.
fn inject_agent_id(request: &mut JsonRpcRequest, agent_id: &str) {
    if request.method != "tools/call" {
        return;
    }
    // Normalize params to an object so arguments can be attached to it.
    if !request.params.is_object() {
        request.params = serde_json::json!({});
    }
    let params = request
        .params
        .as_object_mut()
        .expect("params should be an object");
    // Normalize arguments to an object as well, creating it when missing.
    let arguments = params
        .entry("arguments")
        .or_insert_with(|| serde_json::json!({}));
    if !arguments.is_object() {
        *arguments = serde_json::json!({});
    }
    let arguments = arguments
        .as_object_mut()
        .expect("arguments should be an object");
    // Respect an explicitly supplied agent_id; only fill in when absent.
    if !arguments.contains_key("agent_id") {
        arguments.insert("agent_id".to_string(), serde_json::json!(agent_id));
    }
}
/// Route a JSON-RPC request to its handler.
///
/// Returns `None` for notifications (requests without an id) and for
/// methods that produce no response; `Some` otherwise. Unknown methods
/// yield the standard -32601 "method not found" error.
async fn dispatch_request(
    state: &Arc<McpState>,
    request: &JsonRpcRequest,
) -> Option<JsonRpcResponse> {
    // Only one match arm runs, so the cloned id is moved at most once.
    let id = request.id.clone();
    match request.method.as_str() {
        "initialize" => id.map(|id| success_response(id, initialize_result())),
        "ping" => id.map(|id| success_response(id, serde_json::json!({}))),
        "tools/list" => id.map(|id| {
            success_response(
                id,
                serde_json::json!({
                    "tools": tools::get_tool_definitions()
                }),
            )
        }),
        "tools/call" => handle_tools_call(state, request).await,
        "notifications/initialized" => {
            info!("Received MCP initialized notification");
            None
        }
        unknown => id.map(|id| {
            error_response(
                id,
                -32601,
                format!("Method not found: {}", unknown),
                None,
            )
        }),
    }
}
/// Result payload for the MCP `initialize` handshake.
///
/// Advertises protocol version 2024-11-05 and tool support with
/// `listChanged: false` (clients are not notified of tool-list changes).
fn initialize_result() -> serde_json::Value {
    serde_json::json!({
        "protocolVersion": "2024-11-05",
        "serverInfo": {
            "name": "openbrain-mcp",
            "version": env!("CARGO_PKG_VERSION")
        },
        "capabilities": {
            "tools": {
                "listChanged": false
            }
        }
    })
}
/// Build a JSON-RPC 2.0 success envelope around `result`.
fn success_response(
    id: serde_json::Value,
    result: serde_json::Value,
) -> JsonRpcResponse {
    JsonRpcResponse {
        jsonrpc: String::from("2.0"),
        id,
        error: None,
        result: Some(result),
    }
}
fn error_response(
id: serde_json::Value,
code: i32,
message: String,
data: Option<serde_json::Value>,
) -> JsonRpcResponse {
JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id,
result: None,
error: Some(JsonRpcError {
code,
message,
data,
}),
}
}
/// Handle tools/call request
///
/// Executes the named tool and wraps its output in the MCP "content" format.
/// Tool failures become JSON-RPC errors (code -32000) rather than transport
/// errors. For notifications (no request id) nothing is returned; a failed
/// notification is only logged.
async fn handle_tools_call(
    state: &Arc<McpState>,
    request: &JsonRpcRequest,
) -> Option<JsonRpcResponse> {
    let params = &request.params;
    // A missing/non-string name falls through as "" and fails inside execute_tool.
    let tool_name = params.get("name").and_then(|v| v.as_str()).unwrap_or("");
    // `unwrap_or_else` avoids building the empty object when arguments exist
    // (clippy `or_fun_call`: `unwrap_or` evaluates its argument eagerly).
    let arguments = params
        .get("arguments")
        .cloned()
        .unwrap_or_else(|| serde_json::json!({}));
    let result = match tools::execute_tool(&state.app, tool_name, arguments).await {
        Ok(result) => Ok(serde_json::json!({
            "content": [{
                "type": "text",
                "text": result
            }]
        })),
        Err(e) => {
            // "{:#}" renders the full anyhow context chain, not just the top error.
            let full_error = format!("{:#}", e);
            error!("Tool execution error: {}", full_error);
            Err(JsonRpcError {
                code: -32000,
                message: full_error,
                data: None,
            })
        }
    };
    match (request.id.clone(), result) {
        (Some(id), Ok(result)) => Some(success_response(id, result)),
        (Some(id), Err(error)) => Some(JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id,
            result: None,
            error: Some(error),
        }),
        // Notifications: success is silent, failure is logged but not returned.
        (None, Ok(_)) => None,
        (None, Err(error)) => {
            error!(
                "Tool execution failed for notification '{}': {}",
                request.method, error.message
            );
            None
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Header-derived agent_id is added when the tool arguments omit one.
    #[test]
    fn injects_agent_id_when_missing_from_tool_arguments() {
        let mut request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(serde_json::json!("1")),
            method: "tools/call".to_string(),
            params: serde_json::json!({
                "name": "query",
                "arguments": {
                    "query": "editor preferences"
                }
            }),
        };
        inject_agent_id(&mut request, "agent-from-header");
        assert_eq!(
            request
                .params
                .get("arguments")
                .and_then(|value| value.get("agent_id"))
                .and_then(|value| value.as_str()),
            Some("agent-from-header")
        );
    }
    // An agent_id set by the caller is never overwritten by the header value.
    #[test]
    fn preserves_explicit_agent_id() {
        let mut request = JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: Some(serde_json::json!("1")),
            method: "tools/call".to_string(),
            params: serde_json::json!({
                "name": "query",
                "arguments": {
                    "agent_id": "explicit-agent"
                }
            }),
        };
        inject_agent_id(&mut request, "agent-from-header");
        assert_eq!(
            request
                .params
                .get("arguments")
                .and_then(|value| value.get("agent_id"))
                .and_then(|value| value.as_str()),
            Some("explicit-agent")
        );
    }
    // The endpoint's query parameter must stay camelCase to match
    // PostMessageQuery's serde rename_all = "camelCase" attribute.
    #[test]
    fn session_endpoint_uses_camel_case_query_param() {
        assert_eq!(
            session_message_endpoint("abc123"),
            "/mcp/message?sessionId=abc123"
        );
    }
}