Add hybrid text plus vector memory search

This commit is contained in:
Agent Zero
2026-03-22 22:38:14 +00:00
parent 347805cc29
commit 0ba37f8573
6 changed files with 179 additions and 14 deletions

View File

@@ -19,6 +19,13 @@ OPENBRAIN__DATABASE__POOL_SIZE=10
OPENBRAIN__EMBEDDING__MODEL_PATH=models/all-MiniLM-L6-v2
OPENBRAIN__EMBEDDING__DIMENSION=384
# Hybrid query scoring
OPENBRAIN__QUERY__VECTOR_WEIGHT=0.6
OPENBRAIN__QUERY__TEXT_WEIGHT=0.4
# Backward-compatible plain env aliases
# VECTOR_WEIGHT=0.6
# TEXT_WEIGHT=0.4
# Authentication (optional)
OPENBRAIN__AUTH__ENABLED=false
# Comma-separated list of API keys

View File

@@ -0,0 +1,36 @@
-- Migration: add a weighted full-text search vector (tsv) to `memories`
-- for hybrid text+vector retrieval. Order matters:
--   1. add the column, 2. define the trigger function, 3. backfill existing
--   rows, 4. install the trigger, 5. index the column.
ALTER TABLE memories
ADD COLUMN IF NOT EXISTS tsv tsvector;
-- Trigger function: recompute tsv on every write from content (weight 'A',
-- highest rank) and the keywords array flattened to text (weight 'B').
-- COALESCE guards against NULL content/keywords producing a NULL tsv.
CREATE OR REPLACE FUNCTION memories_tsv_trigger()
RETURNS trigger
LANGUAGE plpgsql
AS $$
BEGIN
NEW.tsv :=
setweight(to_tsvector('pg_catalog.english', COALESCE(NEW.content, '')), 'A') ||
setweight(
to_tsvector('pg_catalog.english', COALESCE(array_to_string(NEW.keywords, ' '), '')),
'B'
);
RETURN NEW;
END;
$$;
-- Backfill rows created before this migration. The trigger is not installed
-- yet, so tsv is computed explicitly with the same expression as above.
UPDATE memories
SET tsv =
setweight(to_tsvector('pg_catalog.english', COALESCE(content, '')), 'A') ||
setweight(
to_tsvector('pg_catalog.english', COALESCE(array_to_string(keywords, ' '), '')),
'B'
)
WHERE tsv IS NULL;
-- Recreate the trigger idempotently; fires only when content or keywords
-- change, so unrelated UPDATEs skip the tsvector recomputation.
DROP TRIGGER IF EXISTS memories_tsv_update ON memories;
CREATE TRIGGER memories_tsv_update
BEFORE INSERT OR UPDATE OF content, keywords ON memories
FOR EACH ROW
EXECUTE FUNCTION memories_tsv_trigger();
-- GIN is the standard index type for tsvector @@ tsquery lookups.
CREATE INDEX IF NOT EXISTS idx_memories_tsv
ON memories USING GIN (tsv);

View File

@@ -11,6 +11,7 @@ pub struct Config {
pub server: ServerConfig,
pub database: DatabaseConfig,
pub embedding: EmbeddingConfig,
pub query: QueryConfig,
pub auth: AuthConfig,
}
@@ -45,6 +46,15 @@ pub struct EmbeddingConfig {
pub dimension: usize,
}
/// Hybrid-query scoring configuration: relative weights of the vector and
/// full-text components. Loaded from `OPENBRAIN__QUERY__*` env vars; each
/// field falls back to its serde default when absent.
#[derive(Debug, Clone, Deserialize)]
pub struct QueryConfig {
/// Weight of the vector-similarity component (default 0.6 via
/// `default_vector_weight`).
#[serde(default = "default_vector_weight")]
pub vector_weight: f32,
/// Weight of the full-text (tsquery) component (default 0.4 via
/// `default_text_weight`).
#[serde(default = "default_text_weight")]
pub text_weight: f32,
}
/// Authentication configuration
#[derive(Debug, Clone, Deserialize)]
pub struct AuthConfig {
@@ -86,6 +96,8 @@ fn default_db_port() -> u16 { 5432 }
/// Default database connection-pool size.
fn default_pool_size() -> usize {
    10
}

/// Default on-disk path of the embedding model.
fn default_model_path() -> String {
    String::from("models/all-MiniLM-L6-v2")
}

/// Default embedding dimensionality (matches the bundled MiniLM model).
fn default_embedding_dim() -> usize {
    384
}

/// Default hybrid-score weight for the vector-similarity component.
fn default_vector_weight() -> f32 {
    0.6
}

/// Default hybrid-score weight for the full-text component.
fn default_text_weight() -> f32 {
    0.4
}

/// Authentication is off unless explicitly enabled.
fn default_auth_enabled() -> bool {
    false
}
impl Config {
@@ -104,6 +116,9 @@ impl Config {
// Embedding settings
.set_default("embedding.model_path", default_model_path())?
.set_default("embedding.dimension", default_embedding_dim() as i64)?
// Query settings
.set_default("query.vector_weight", default_vector_weight() as f64)?
.set_default("query.text_weight", default_text_weight() as f64)?
// Auth settings
.set_default("auth.enabled", default_auth_enabled())?
// Load from environment with OPENBRAIN_ prefix
@@ -114,7 +129,21 @@ impl Config {
)
.build()?;
Ok(config.try_deserialize()?)
let mut config: Self = config.try_deserialize()?;
// Keep compatibility with plain env names proposed in issue #17.
if let Ok(vector_weight) = std::env::var("VECTOR_WEIGHT") {
if let Ok(parsed) = vector_weight.parse::<f32>() {
config.query.vector_weight = parsed;
}
}
if let Ok(text_weight) = std::env::var("TEXT_WEIGHT") {
if let Ok(parsed) = text_weight.parse::<f32>() {
config.query.text_weight = parsed;
}
}
Ok(config)
}
}
@@ -137,6 +166,10 @@ impl Default for Config {
model_path: default_model_path(),
dimension: default_embedding_dim(),
},
query: QueryConfig {
vector_weight: default_vector_weight(),
text_weight: default_text_weight(),
},
auth: AuthConfig {
enabled: default_auth_enabled(),
api_keys: Vec::new(),

View File

@@ -36,6 +36,9 @@ pub struct MemoryRecord {
/// One search hit from a hybrid memory query, pairing the stored record
/// with its score breakdown.
pub struct MemoryMatch {
/// The matched memory row.
pub record: MemoryRecord,
/// Kept for backward compatibility with pre-hybrid callers; populated
/// from the `hybrid_score` column in the row mapping, so it duplicates
/// `hybrid_score` rather than holding pure vector similarity.
pub similarity: f32,
/// Vector component: 1 - cosine distance between query and row embeddings.
pub vector_score: f32,
/// Full-text component: ts_rank of the row's tsvector against the query
/// (0 when the row does not match the text query).
pub text_score: f32,
/// Weighted combination of the two components used for final ordering.
pub hybrid_score: f32,
}
impl Database {
@@ -95,27 +98,80 @@ impl Database {
pub async fn query_memories(
&self,
agent_id: &str,
query_text: &str,
embedding: &[f32],
limit: i64,
threshold: f32,
vector_weight: f32,
text_weight: f32,
) -> Result<Vec<MemoryMatch>> {
let client = self.pool.get().await?;
let vector = Vector::from(embedding.to_vec());
let threshold_f64 = threshold as f64;
let vector_weight_f64 = vector_weight as f64;
let text_weight_f64 = text_weight as f64;
let rows = client
.query(
r#"
WITH search_query AS (
SELECT
id, agent_id, content, keywords, metadata, created_at,
(1 - (embedding <=> $1))::real AS similarity
NULLIF(plainto_tsquery('pg_catalog.english', $2)::text, '') AS query_text,
plainto_tsquery('pg_catalog.english', $2) AS ts_query
),
scored AS (
SELECT
id,
agent_id,
content,
keywords,
metadata,
created_at,
(1 - (embedding <=> $1))::real AS vector_score,
CASE
WHEN search_query.query_text IS NULL THEN 0::real
WHEN memories.tsv @@ search_query.ts_query
THEN ts_rank(memories.tsv, search_query.ts_query, 32)::real
ELSE 0::real
END AS text_score
FROM memories
WHERE agent_id = $2
AND (1 - (embedding <=> $1)) >= $3
ORDER BY embedding <=> $1
LIMIT $4
CROSS JOIN search_query
WHERE memories.agent_id = $3
),
ranked AS (
SELECT
*,
MAX(CASE WHEN text_score > 0 THEN 1 ELSE 0 END) OVER () AS has_text_match
FROM scored
)
SELECT
id,
agent_id,
content,
keywords,
metadata,
created_at,
vector_score,
text_score,
CASE
WHEN has_text_match = 1
THEN (($5 * vector_score) + ($6 * text_score))::real
ELSE vector_score
END AS hybrid_score
FROM ranked
WHERE vector_score >= $4 OR text_score > 0
ORDER BY hybrid_score DESC, vector_score DESC
LIMIT $7
"#,
&[&vector, &agent_id, &threshold_f64, &limit],
&[
&vector,
&query_text,
&agent_id,
&threshold_f64,
&vector_weight_f64,
&text_weight_f64,
&limit,
],
)
.await
.context("Failed to query memories")?;
@@ -133,7 +189,10 @@ impl Database {
metadata: row.get("metadata"),
created_at: row.get("created_at"),
},
similarity: row.get("similarity"),
similarity: row.get("hybrid_score"),
vector_score: row.get("vector_score"),
text_score: row.get("text_score"),
hybrid_score: row.get("hybrid_score"),
})
.collect();

View File

@@ -69,7 +69,7 @@ pub fn get_tool_definitions() -> Vec<Value> {
}),
json!({
"name": "query",
"description": "Query stored memories using semantic similarity search. Returns the most relevant memories based on the query text.",
"description": "Query stored memories using hybrid semantic plus keyword search. Returns the most relevant memories based on vector similarity and exact-text ranking.",
"inputSchema": {
"type": "object",
"properties": {

View File

@@ -36,9 +36,14 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
.and_then(|v| v.as_f64())
.unwrap_or(0.5) as f32;
let (vector_weight, text_weight) = normalized_weights(
state.config.query.vector_weight,
state.config.query.text_weight,
);
info!(
"Querying memories for agent '{}': '{}' (limit={}, threshold={})",
agent_id, query_text, limit, threshold
"Querying memories for agent '{}': '{}' (limit={}, threshold={}, vector_weight={}, text_weight={})",
agent_id, query_text, limit, threshold, vector_weight, text_weight
);
// Generate embedding for query using Arc<EmbeddingEngine>
@@ -49,7 +54,15 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
// Search database
let matches = state
.db
.query_memories(agent_id, &query_embedding, limit, threshold)
.query_memories(
agent_id,
query_text,
&query_embedding,
limit,
threshold,
vector_weight,
text_weight,
)
.await
.context("Failed to query memories")?;
@@ -63,6 +76,9 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
"id": m.record.id.to_string(),
"content": m.record.content,
"similarity": m.similarity,
"vector_score": m.vector_score,
"text_score": m.text_score,
"hybrid_score": m.hybrid_score,
"keywords": m.record.keywords,
"metadata": m.record.metadata,
"created_at": m.record.created_at.to_rfc3339()
@@ -74,8 +90,22 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
"success": true,
"agent_id": agent_id,
"query": query_text,
"vector_weight": vector_weight,
"text_weight": text_weight,
"count": results.len(),
"results": results
})
.to_string())
}
/// Normalize the two hybrid-scoring weights so they sum to 1.0.
///
/// Invalid inputs never poison the scores:
/// - negative weights are clamped to zero;
/// - non-finite weights (NaN or ±inf — reachable because `"inf".parse::<f32>()`
///   succeeds on env overrides) fall back to the defaults, since e.g.
///   `inf / inf` would otherwise yield NaN normalized weights;
/// - a (near-)zero total also falls back to the defaults.
fn normalized_weights(vector_weight: f32, text_weight: f32) -> (f32, f32) {
    // Mirrors default_vector_weight / default_text_weight in the config module.
    const FALLBACK: (f32, f32) = (0.6, 0.4);

    // Reject NaN/inf up front so the division below stays well-defined.
    if !vector_weight.is_finite() || !text_weight.is_finite() {
        return FALLBACK;
    }

    let vector_weight = vector_weight.max(0.0);
    let text_weight = text_weight.max(0.0);
    let total = vector_weight + text_weight;

    if total <= f32::EPSILON {
        // Both weights zero (or negative): nothing meaningful to normalize.
        FALLBACK
    } else {
        (vector_weight / total, text_weight / total)
    }
}