Files
openbrain-mcp/src/config.rs

352 lines
11 KiB
Rust
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
//! Configuration management for OpenBrain MCP
//!
//! Loads configuration from environment variables with sensible defaults.
use anyhow::Result;
use serde::{Deserialize, Deserializer};
/// Main configuration structure.
///
/// Groups every configuration section into a single value. An instance can be
/// produced by `Config::load` (built-in defaults overlaid with environment
/// variables) or by `Config::default` (built-in values only).
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
/// Server settings (`host`, `port`).
pub server: ServerConfig,
/// Database connection settings (host, port, name, credentials, pool size).
pub database: DatabaseConfig,
/// Embedding engine settings (model path and vector dimension).
pub embedding: EmbeddingConfig,
/// Query scoring weights.
pub query: QueryConfig,
/// Deduplication threshold.
pub dedup: DedupConfig,
/// TTL / expiry cleanup interval.
pub ttl: TtlConfig,
/// Authentication switch and API keys.
pub auth: AuthConfig,
/// Truth scoring engine settings.
pub truth: TruthConfig,
}
/// Server configuration
///
/// Both fields are optional when deserializing; missing values fall back to
/// the `default_*` functions named in the serde attributes.
#[derive(Debug, Clone, Deserialize)]
pub struct ServerConfig {
/// Server host; defaults to `"0.0.0.0"` (see `default_host`).
#[serde(default = "default_host")]
pub host: String,
/// Server port; defaults to `3100` (see `default_port`).
#[serde(default = "default_port")]
pub port: u16,
}
/// Database configuration
///
/// `host`, `name`, `user` and `password` carry no serde default, so
/// deserialization fails unless a value is supplied for each of them.
#[derive(Debug, Clone, Deserialize)]
pub struct DatabaseConfig {
/// Database host (required, no default).
pub host: String,
/// Database port; defaults to `5432` (see `default_db_port`).
#[serde(default = "default_db_port")]
pub port: u16,
/// Database name (required, no default).
pub name: String,
/// Database user (required, no default).
pub user: String,
/// Database password (required, no default).
pub password: String,
/// Connection pool size; defaults to `10` (see `default_pool_size`).
#[serde(default = "default_pool_size")]
pub pool_size: usize,
}
/// Embedding engine configuration
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingConfig {
/// Path to the embedding model; defaults to `"models/all-MiniLM-L6-v2"`
/// (see `default_model_path`).
#[serde(default = "default_model_path")]
pub model_path: String,
/// Embedding dimension; defaults to `384` (see `default_embedding_dim`).
// NOTE(review): presumably this must match the output size of the model at
// `model_path` -- confirm against the embedding engine.
#[serde(default = "default_embedding_dim")]
pub dimension: usize,
}
/// Query scoring configuration
///
/// `Config::load` additionally lets the plain `VECTOR_WEIGHT` and
/// `TEXT_WEIGHT` environment variables override these two fields.
// NOTE(review): nothing in this file checks that the two weights sum to 1.0;
// confirm whether the scoring code relies on that.
#[derive(Debug, Clone, Deserialize)]
pub struct QueryConfig {
/// Weight of the vector score; defaults to `0.6` (see `default_vector_weight`).
#[serde(default = "default_vector_weight")]
pub vector_weight: f32,
/// Weight of the text score; defaults to `0.4` (see `default_text_weight`).
#[serde(default = "default_text_weight")]
pub text_weight: f32,
}
/// Deduplication configuration
///
/// `Config::load` additionally lets the plain `DEDUP_THRESHOLD` environment
/// variable override `threshold`.
#[derive(Debug, Clone, Deserialize)]
pub struct DedupConfig {
/// Deduplication threshold; defaults to `0.90` (see `default_dedup_threshold`).
// NOTE(review): presumably a similarity score above which two entries count
// as duplicates -- confirm against the dedup implementation.
#[serde(default = "default_dedup_threshold")]
pub threshold: f32,
}
/// TTL / expiry configuration
#[derive(Debug, Clone, Deserialize)]
pub struct TtlConfig {
/// Cleanup interval in seconds; defaults to `300`
/// (see `default_cleanup_interval_seconds`).
#[serde(default = "default_cleanup_interval_seconds")]
pub cleanup_interval_seconds: u64,
}
/// Truth scoring engine configuration
///
/// Every field is optional when deserializing; a missing value falls back to
/// the `default_*` function named in its serde attribute.
#[derive(Debug, Clone, Deserialize)]
pub struct TruthConfig {
/// Enable truth scoring background worker (default: `false`)
#[serde(default = "default_truth_enabled")]
pub enabled: bool,
/// Seconds between scoring cycles (default: `300`)
#[serde(default = "default_scoring_interval_seconds")]
pub scoring_interval_seconds: u64,
/// Number of memories to score per cycle (default: `50`)
#[serde(default = "default_truth_batch_size")]
pub batch_size: usize,
/// Seconds before a scored memory is re-evaluated (default: `86400`)
#[serde(default = "default_rescore_after_seconds")]
pub rescore_after_seconds: u64,
/// Base confidence for PLN deduction chains (default: `0.85`)
#[serde(default = "default_pln_base_confidence")]
pub pln_base_confidence: f32,
/// ECAN STI decay rate per cycle (0.0 to 1.0; default: `0.95`)
#[serde(default = "default_ecan_decay_rate")]
pub ecan_decay_rate: f32,
/// ECAN attention spread factor (default: `0.05`)
#[serde(default = "default_ecan_spread_factor")]
pub ecan_spread_factor: f32,
/// Similarity threshold above which conflicting memories are contradictions
/// (default: `0.85`)
#[serde(default = "default_contradiction_threshold")]
pub contradiction_threshold: f32,
/// Truth value threshold for "verified" categorization (default: `0.8`)
#[serde(default = "default_verification_threshold")]
pub verification_threshold: f32,
/// Max related memories to cross-reference per scoring (default: `10`)
#[serde(default = "default_cross_ref_limit")]
pub cross_ref_limit: i64,
}
/// Authentication configuration
#[derive(Debug, Clone, Deserialize)]
pub struct AuthConfig {
/// Whether authentication is enabled; defaults to `false`
/// (see `default_auth_enabled`).
#[serde(default = "default_auth_enabled")]
pub enabled: bool,
/// API keys. Parsed by `deserialize_api_keys`, which accepts either a
/// comma-separated string or a list of strings; an absent value yields an
/// empty list.
#[serde(default, deserialize_with = "deserialize_api_keys")]
pub api_keys: Vec<String>,
}
/// Deserialize API keys from either a comma-separated string or a list of strings.
///
/// Accepted inputs:
/// * a single string such as `"key-a, key-b"`, which is split on `,`;
/// * a sequence of strings such as `["key-a", "key-b"]`;
/// * a null / missing value, which produces an empty list.
///
/// In both the string and the sequence form every entry is trimmed and blank
/// entries are dropped, so a stray `""` or `"  "` can never end up as an
/// accepted API key.
///
/// # Errors
///
/// Returns the deserializer's error when the value is neither a string, a
/// sequence of strings, nor null.
fn deserialize_api_keys<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
    D: Deserializer<'de>,
{
    // Untagged: serde tries the variants in order, so a plain string is
    // matched before a sequence of strings.
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum StringOrVec {
        String(String),
        Vec(Vec<String>),
    }

    /// Trim every candidate key and discard the ones that end up empty.
    fn normalize<'a>(candidates: impl Iterator<Item = &'a str>) -> Vec<String> {
        candidates
            .map(str::trim)
            .filter(|key| !key.is_empty())
            .map(str::to_owned)
            .collect()
    }

    let keys = match Option::<StringOrVec>::deserialize(deserializer)? {
        Some(StringOrVec::String(joined)) => normalize(joined.split(',')),
        // Apply the same clean-up as the comma-separated form so both input
        // shapes yield identical key lists.
        Some(StringOrVec::Vec(list)) => normalize(list.iter().map(String::as_str)),
        None => Vec::new(),
    };
    Ok(keys)
}
// Default value functions (server, database, embedding)

/// Default server host.
fn default_host() -> String {
    String::from("0.0.0.0")
}

/// Default server port.
fn default_port() -> u16 {
    3_100
}

/// Default database port.
fn default_db_port() -> u16 {
    5_432
}

/// Default database connection pool size.
fn default_pool_size() -> usize {
    10
}

/// Default path of the embedding model.
fn default_model_path() -> String {
    String::from("models/all-MiniLM-L6-v2")
}

/// Default embedding dimension.
fn default_embedding_dim() -> usize {
    384
}
/// Default weight of the vector score in query scoring.
fn default_vector_weight() -> f32 {
    0.6_f32
}

/// Default weight of the text score in query scoring.
fn default_text_weight() -> f32 {
    0.4_f32
}

/// Default deduplication threshold.
fn default_dedup_threshold() -> f32 {
    0.9_f32
}

/// Default TTL cleanup interval, in seconds (five minutes).
fn default_cleanup_interval_seconds() -> u64 {
    5 * 60
}

/// Authentication is disabled unless explicitly enabled.
fn default_auth_enabled() -> bool {
    false
}
// Truth engine defaults

/// The truth scoring worker is disabled unless explicitly enabled.
fn default_truth_enabled() -> bool {
    false
}

/// Default number of seconds between scoring cycles (five minutes).
fn default_scoring_interval_seconds() -> u64 {
    5 * 60
}

/// Default number of memories scored per cycle.
fn default_truth_batch_size() -> usize {
    50
}

/// Default number of seconds before a scored memory is re-evaluated (one day).
fn default_rescore_after_seconds() -> u64 {
    24 * 60 * 60
}

/// Default base confidence for PLN deduction chains.
fn default_pln_base_confidence() -> f32 {
    0.85_f32
}

/// Default ECAN STI decay rate per cycle.
fn default_ecan_decay_rate() -> f32 {
    0.95_f32
}

/// Default ECAN attention spread factor.
fn default_ecan_spread_factor() -> f32 {
    0.05_f32
}

/// Default similarity threshold for treating conflicting memories as contradictions.
fn default_contradiction_threshold() -> f32 {
    0.85_f32
}

/// Default truth value threshold for the "verified" categorization.
fn default_verification_threshold() -> f32 {
    0.8_f32
}

/// Default maximum number of related memories cross-referenced per scoring.
fn default_cross_ref_limit() -> i64 {
    10
}
impl Config {
    /// Build the configuration from built-in defaults and the environment.
    ///
    /// The steps run in this order:
    /// 1. A `.env` file is loaded if one is present; failure to load it is ignored.
    /// 2. The built-in default for every optional setting is registered.
    /// 3. Environment variables under the `OPENBRAIN` prefix are layered on
    ///    top, with `__` as the separator and value parsing enabled.
    /// 4. The plain `VECTOR_WEIGHT`, `TEXT_WEIGHT` and `DEDUP_THRESHOLD`
    ///    variables (kept for compatibility with issue #17) override the
    ///    matching fields when they are set and parse as `f32`; otherwise
    ///    they are silently skipped.
    ///
    /// # Errors
    ///
    /// Returns an error when a default cannot be registered, when the layered
    /// configuration cannot be built, or when it cannot be deserialized into
    /// `Config`.
    pub fn load() -> Result<Self> {
        /// Read `name` from the environment and parse it as `f32`.
        /// Yields `None` when the variable is unset, unreadable, or not a number.
        fn legacy_f32(name: &str) -> Option<f32> {
            std::env::var(name).ok()?.parse::<f32>().ok()
        }

        // A missing `.env` file is fine, so the outcome is deliberately discarded.
        let _ = dotenvy::dotenv();

        // Defaults grouped by value type; keys are distinct, so the order in
        // which they are registered does not matter.
        let integer_defaults = [
            ("server.port", i64::from(default_port())),
            ("database.port", i64::from(default_db_port())),
            ("database.pool_size", default_pool_size() as i64),
            ("embedding.dimension", default_embedding_dim() as i64),
            (
                "ttl.cleanup_interval_seconds",
                default_cleanup_interval_seconds() as i64,
            ),
            (
                "truth.scoring_interval_seconds",
                default_scoring_interval_seconds() as i64,
            ),
            ("truth.batch_size", default_truth_batch_size() as i64),
            (
                "truth.rescore_after_seconds",
                default_rescore_after_seconds() as i64,
            ),
            ("truth.cross_ref_limit", default_cross_ref_limit()),
        ];
        let float_defaults = [
            ("query.vector_weight", f64::from(default_vector_weight())),
            ("query.text_weight", f64::from(default_text_weight())),
            ("dedup.threshold", f64::from(default_dedup_threshold())),
            (
                "truth.pln_base_confidence",
                f64::from(default_pln_base_confidence()),
            ),
            (
                "truth.ecan_decay_rate",
                f64::from(default_ecan_decay_rate()),
            ),
            (
                "truth.ecan_spread_factor",
                f64::from(default_ecan_spread_factor()),
            ),
            (
                "truth.contradiction_threshold",
                f64::from(default_contradiction_threshold()),
            ),
            (
                "truth.verification_threshold",
                f64::from(default_verification_threshold()),
            ),
        ];
        let flag_defaults = [
            ("auth.enabled", default_auth_enabled()),
            ("truth.enabled", default_truth_enabled()),
        ];

        let mut builder = config::Config::builder()
            .set_default("server.host", default_host())?
            .set_default("embedding.model_path", default_model_path())?;
        for &(key, value) in &integer_defaults {
            builder = builder.set_default(key, value)?;
        }
        for &(key, value) in &float_defaults {
            builder = builder.set_default(key, value)?;
        }
        for &(key, value) in &flag_defaults {
            builder = builder.set_default(key, value)?;
        }

        // Load from environment with OPENBRAIN_ prefix
        let environment = config::Environment::with_prefix("OPENBRAIN")
            .separator("__")
            .try_parsing(true);
        let layered = builder.add_source(environment).build()?;
        let mut loaded: Self = layered.try_deserialize()?;

        // Keep compatibility with plain env names proposed in issue #17.
        if let Some(weight) = legacy_f32("VECTOR_WEIGHT") {
            loaded.query.vector_weight = weight;
        }
        if let Some(weight) = legacy_f32("TEXT_WEIGHT") {
            loaded.query.text_weight = weight;
        }
        if let Some(threshold) = legacy_f32("DEDUP_THRESHOLD") {
            loaded.dedup.threshold = threshold;
        }

        Ok(loaded)
    }
}
impl Default for Config {
    /// Build a configuration purely from built-in values, without reading
    /// the environment.
    ///
    /// Optional settings take the same values as the `default_*` functions.
    /// The database section, which has no serde defaults, is filled with
    /// `localhost`, database `openbrain`, user `openbrain_svc` and an empty
    /// password.
    fn default() -> Self {
        let server = ServerConfig {
            host: default_host(),
            port: default_port(),
        };

        let database = DatabaseConfig {
            host: String::from("localhost"),
            port: default_db_port(),
            name: String::from("openbrain"),
            user: String::from("openbrain_svc"),
            password: String::new(),
            pool_size: default_pool_size(),
        };

        let embedding = EmbeddingConfig {
            model_path: default_model_path(),
            dimension: default_embedding_dim(),
        };

        let query = QueryConfig {
            vector_weight: default_vector_weight(),
            text_weight: default_text_weight(),
        };

        let dedup = DedupConfig {
            threshold: default_dedup_threshold(),
        };

        let ttl = TtlConfig {
            cleanup_interval_seconds: default_cleanup_interval_seconds(),
        };

        let auth = AuthConfig {
            enabled: default_auth_enabled(),
            api_keys: Vec::new(),
        };

        let truth = TruthConfig {
            enabled: default_truth_enabled(),
            scoring_interval_seconds: default_scoring_interval_seconds(),
            batch_size: default_truth_batch_size(),
            rescore_after_seconds: default_rescore_after_seconds(),
            pln_base_confidence: default_pln_base_confidence(),
            ecan_decay_rate: default_ecan_decay_rate(),
            ecan_spread_factor: default_ecan_spread_factor(),
            contradiction_threshold: default_contradiction_threshold(),
            verification_threshold: default_verification_threshold(),
            cross_ref_limit: default_cross_ref_limit(),
        };

        Self {
            server,
            database,
            embedding,
            query,
            dedup,
            ttl,
            auth,
            truth,
        }
    }
}