From 0ba37f85739b7d4f88d4e49faa93a2d437529d8b Mon Sep 17 00:00:00 2001
From: Agent Zero
Date: Sun, 22 Mar 2026 22:38:14 +0000
Subject: [PATCH] Add hybrid text plus vector memory search

Blend vector similarity with PostgreSQL full-text rank when querying
memories.

* The V2 migration adds a trigger-maintained `tsv` column (content
  weighted A, keywords weighted B), backfills existing rows and indexes
  the column with GIN.
* `query.vector_weight` / `query.text_weight` default to 0.6 / 0.4. The
  plain VECTOR_WEIGHT / TEXT_WEIGHT environment variables override them
  when they parse as f32.
* Weights are clamped to >= 0 and normalised to sum to 1. A zero or
  non-finite sum falls back to the default split so NaN never reaches
  the query.
* The float query parameters are cast to float8 in SQL. Without the
  casts PostgreSQL infers float4 from the `real` columns they are
  combined with, which does not match the f64 values bound from Rust.
* Weighted scoring is only applied when at least one row has a text
  match; otherwise the hybrid score is the vector score unchanged.
---
 .env.example                     |  7 +++
 migrations/V2__hybrid_search.sql | 36 +++++++++++++++
 src/config.rs                    | 35 ++++++++++++++-
 src/db.rs                        | 77 ++++++++++++++++++++++++++++----
 src/tools/mod.rs                 |  2 +-
 src/tools/query.rs               | 36 +++++++++++++--
 6 files changed, 179 insertions(+), 14 deletions(-)
 create mode 100644 migrations/V2__hybrid_search.sql

diff --git a/.env.example b/.env.example
index 90ab331..4f273f9 100644
--- a/.env.example
+++ b/.env.example
@@ -19,6 +19,13 @@ OPENBRAIN__DATABASE__POOL_SIZE=10
 OPENBRAIN__EMBEDDING__MODEL_PATH=models/all-MiniLM-L6-v2
 OPENBRAIN__EMBEDDING__DIMENSION=384
 
+# Hybrid query scoring
+OPENBRAIN__QUERY__VECTOR_WEIGHT=0.6
+OPENBRAIN__QUERY__TEXT_WEIGHT=0.4
+# Backward-compatible plain env aliases
+# VECTOR_WEIGHT=0.6
+# TEXT_WEIGHT=0.4
+
 # Authentication (optional)
 OPENBRAIN__AUTH__ENABLED=false
 # Comma-separated list of API keys
diff --git a/migrations/V2__hybrid_search.sql b/migrations/V2__hybrid_search.sql
new file mode 100644
index 0000000..0ef5a64
--- /dev/null
+++ b/migrations/V2__hybrid_search.sql
@@ -0,0 +1,36 @@
+ALTER TABLE memories
+    ADD COLUMN IF NOT EXISTS tsv tsvector;
+
+CREATE OR REPLACE FUNCTION memories_tsv_trigger()
+RETURNS trigger
+LANGUAGE plpgsql
+AS $$
+BEGIN
+    NEW.tsv :=
+        setweight(to_tsvector('pg_catalog.english', COALESCE(NEW.content, '')), 'A') ||
+        setweight(
+            to_tsvector('pg_catalog.english', COALESCE(array_to_string(NEW.keywords, ' '), '')),
+            'B'
+        );
+    RETURN NEW;
+END;
+$$;
+
+UPDATE memories
+SET tsv =
+    setweight(to_tsvector('pg_catalog.english', COALESCE(content, '')), 'A') ||
+    setweight(
+        to_tsvector('pg_catalog.english', COALESCE(array_to_string(keywords, ' '), '')),
+        'B'
+    )
+WHERE tsv IS NULL;
+
+DROP TRIGGER IF EXISTS memories_tsv_update ON memories;
+
+CREATE TRIGGER memories_tsv_update
+BEFORE INSERT OR UPDATE OF content, keywords ON memories
+FOR EACH ROW
+EXECUTE FUNCTION memories_tsv_trigger();
+
+CREATE INDEX IF NOT EXISTS idx_memories_tsv
+    ON memories USING GIN (tsv);
diff --git a/src/config.rs b/src/config.rs
index 449be4f..f73a685 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -11,6 +11,7 @@ pub struct Config {
     pub server: ServerConfig,
     pub database: DatabaseConfig,
     pub embedding: EmbeddingConfig,
+    pub query: QueryConfig,
     pub auth: AuthConfig,
 }
 
@@ -45,6 +46,15 @@ pub struct EmbeddingConfig {
     pub dimension: usize,
 }
 
+/// Query scoring configuration
+#[derive(Debug, Clone, Deserialize)]
+pub struct QueryConfig {
+    #[serde(default = "default_vector_weight")]
+    pub vector_weight: f32,
+    #[serde(default = "default_text_weight")]
+    pub text_weight: f32,
+}
+
 /// Authentication configuration
 #[derive(Debug, Clone, Deserialize)]
 pub struct AuthConfig {
@@ -86,6 +96,8 @@ fn default_db_port() -> u16 { 5432 }
 fn default_pool_size() -> usize { 10 }
 fn default_model_path() -> String { "models/all-MiniLM-L6-v2".to_string() }
 fn default_embedding_dim() -> usize { 384 }
+fn default_vector_weight() -> f32 { 0.6 }
+fn default_text_weight() -> f32 { 0.4 }
 fn default_auth_enabled() -> bool { false }
 
 impl Config {
@@ -104,6 +116,9 @@ impl Config {
             // Embedding settings
             .set_default("embedding.model_path", default_model_path())?
             .set_default("embedding.dimension", default_embedding_dim() as i64)?
+            // Query settings
+            .set_default("query.vector_weight", default_vector_weight() as f64)?
+            .set_default("query.text_weight", default_text_weight() as f64)?
             // Auth settings
             .set_default("auth.enabled", default_auth_enabled())?
             // Load from environment with OPENBRAIN_ prefix
@@ -114,7 +129,21 @@ impl Config {
             )
             .build()?;
 
-        Ok(config.try_deserialize()?)
+        let mut config: Self = config.try_deserialize()?;
+
+        // Keep compatibility with plain env names proposed in issue #17.
+        if let Ok(vector_weight) = std::env::var("VECTOR_WEIGHT") {
+            if let Ok(parsed) = vector_weight.parse::<f32>() {
+                config.query.vector_weight = parsed;
+            }
+        }
+        if let Ok(text_weight) = std::env::var("TEXT_WEIGHT") {
+            if let Ok(parsed) = text_weight.parse::<f32>() {
+                config.query.text_weight = parsed;
+            }
+        }
+
+        Ok(config)
     }
 }
 
@@ -137,6 +166,10 @@ impl Default for Config {
                 model_path: default_model_path(),
                 dimension: default_embedding_dim(),
             },
+            query: QueryConfig {
+                vector_weight: default_vector_weight(),
+                text_weight: default_text_weight(),
+            },
             auth: AuthConfig {
                 enabled: default_auth_enabled(),
                 api_keys: Vec::new(),
diff --git a/src/db.rs b/src/db.rs
index 8e810bf..d4424bb 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -36,6 +36,9 @@ pub struct MemoryRecord {
 pub struct MemoryMatch {
     pub record: MemoryRecord,
     pub similarity: f32,
+    pub vector_score: f32,
+    pub text_score: f32,
+    pub hybrid_score: f32,
 }
 
 impl Database {
@@ -95,27 +98,80 @@ impl Database {
     pub async fn query_memories(
         &self,
         agent_id: &str,
+        query_text: &str,
         embedding: &[f32],
         limit: i64,
         threshold: f32,
+        vector_weight: f32,
+        text_weight: f32,
     ) -> Result<Vec<MemoryMatch>> {
         let client = self.pool.get().await?;
         let vector = Vector::from(embedding.to_vec());
         let threshold_f64 = threshold as f64;
+        let vector_weight_f64 = vector_weight as f64;
+        let text_weight_f64 = text_weight as f64;
 
         let rows = client
             .query(
                 r#"
+                WITH search_query AS (
+                    SELECT
+                        NULLIF(plainto_tsquery('pg_catalog.english', $2)::text, '') AS query_text,
+                        plainto_tsquery('pg_catalog.english', $2) AS ts_query
+                ),
+                scored AS (
+                    SELECT
+                        id,
+                        agent_id,
+                        content,
+                        keywords,
+                        metadata,
+                        created_at,
+                        (1 - (embedding <=> $1))::real AS vector_score,
+                        CASE
+                            WHEN search_query.query_text IS NULL THEN 0::real
+                            WHEN memories.tsv @@ search_query.ts_query
+                                THEN ts_rank(memories.tsv, search_query.ts_query, 32)::real
+                            ELSE 0::real
+                        END AS text_score
+                    FROM memories
+                    CROSS JOIN search_query
+                    WHERE memories.agent_id = $3
+                ),
+                ranked AS (
+                    SELECT
+                        *,
+                        MAX(CASE WHEN text_score > 0 THEN 1 ELSE 0 END) OVER () AS has_text_match
+                    FROM scored
+                )
                 SELECT
-                    id, agent_id, content, keywords, metadata, created_at,
-                    (1 - (embedding <=> $1))::real AS similarity
-                FROM memories
-                WHERE agent_id = $2
-                  AND (1 - (embedding <=> $1)) >= $3
-                ORDER BY embedding <=> $1
-                LIMIT $4
+                    id,
+                    agent_id,
+                    content,
+                    keywords,
+                    metadata,
+                    created_at,
+                    vector_score,
+                    text_score,
+                    CASE
+                        WHEN has_text_match = 1
+                            THEN (($5::float8 * vector_score) + ($6::float8 * text_score))::real
+                        ELSE vector_score
+                    END AS hybrid_score
+                FROM ranked
+                WHERE vector_score >= $4::float8 OR text_score > 0
+                ORDER BY hybrid_score DESC, vector_score DESC
+                LIMIT $7
                 "#,
-                &[&vector, &agent_id, &threshold_f64, &limit],
+                &[
+                    &vector,
+                    &query_text,
+                    &agent_id,
+                    &threshold_f64, // $4..$6 are f64; the query casts them to float8 to match
+                    &vector_weight_f64,
+                    &text_weight_f64,
+                    &limit,
+                ],
             )
             .await
             .context("Failed to query memories")?;
@@ -133,7 +189,10 @@ impl Database {
                     metadata: row.get("metadata"),
                     created_at: row.get("created_at"),
                 },
-                similarity: row.get("similarity"),
+                similarity: row.get("hybrid_score"),
+                vector_score: row.get("vector_score"),
+                text_score: row.get("text_score"),
+                hybrid_score: row.get("hybrid_score"),
             })
             .collect();
 
diff --git a/src/tools/mod.rs b/src/tools/mod.rs
index 22c644c..c88733f 100644
--- a/src/tools/mod.rs
+++ b/src/tools/mod.rs
@@ -69,7 +69,7 @@ pub fn get_tool_definitions() -> Vec {
         }),
         json!({
             "name": "query",
-            "description": "Query stored memories using semantic similarity search. Returns the most relevant memories based on the query text.",
+            "description": "Query stored memories using hybrid semantic plus keyword search. Returns the most relevant memories based on vector similarity and exact-text ranking.",
             "inputSchema": {
                 "type": "object",
                 "properties": {
diff --git a/src/tools/query.rs b/src/tools/query.rs
index c9a1a5a..5eb8a66 100644
--- a/src/tools/query.rs
+++ b/src/tools/query.rs
@@ -36,9 +36,14 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result
         .and_then(|v| v.as_f64())
         .unwrap_or(0.5) as f32;
 
+    let (vector_weight, text_weight) = normalized_weights(
+        state.config.query.vector_weight,
+        state.config.query.text_weight,
+    );
+
     info!(
-        "Querying memories for agent '{}': '{}' (limit={}, threshold={})",
-        agent_id, query_text, limit, threshold
+        "Querying memories for agent '{}': '{}' (limit={}, threshold={}, vector_weight={}, text_weight={})",
+        agent_id, query_text, limit, threshold, vector_weight, text_weight
     );
 
     // Generate embedding for query using Arc
@@ -49,7 +54,15 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result
     // Search database
     let matches = state
         .db
-        .query_memories(agent_id, &query_embedding, limit, threshold)
+        .query_memories(
+            agent_id,
+            query_text,
+            &query_embedding,
+            limit,
+            threshold,
+            vector_weight,
+            text_weight,
+        )
         .await
         .context("Failed to query memories")?;
 
@@ -63,6 +76,9 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result
                 "id": m.record.id.to_string(),
                 "content": m.record.content,
                 "similarity": m.similarity,
+                "vector_score": m.vector_score,
+                "text_score": m.text_score,
+                "hybrid_score": m.hybrid_score,
                 "keywords": m.record.keywords,
                 "metadata": m.record.metadata,
                 "created_at": m.record.created_at.to_rfc3339()
@@ -74,8 +90,22 @@ pub async fn execute(state: &Arc, arguments: Value) -> Result
         "success": true,
         "agent_id": agent_id,
         "query": query_text,
+        "vector_weight": vector_weight,
+        "text_weight": text_weight,
         "count": results.len(),
         "results": results
     })
     .to_string())
 }
+
+fn normalized_weights(vector_weight: f32, text_weight: f32) -> (f32, f32) {
+    let vector_weight = vector_weight.max(0.0);
+    let text_weight = text_weight.max(0.0);
+    let total = vector_weight + text_weight;
+
+    if !total.is_finite() || total <= f32::EPSILON {
+        (0.6, 0.4) // default split; dividing by a zero or infinite sum would yield NaN
+    } else {
+        (vector_weight / total, text_weight / total)
+    }
+}