Add hybrid text plus vector memory search

This commit is contained in:
Agent Zero
2026-03-22 22:38:14 +00:00
parent 347805cc29
commit 0ba37f8573
6 changed files with 179 additions and 14 deletions

View File

@@ -36,6 +36,9 @@ pub struct MemoryRecord {
pub struct MemoryMatch {
pub record: MemoryRecord,
pub similarity: f32,
pub vector_score: f32,
pub text_score: f32,
pub hybrid_score: f32,
}
impl Database {
@@ -95,27 +98,80 @@ impl Database {
pub async fn query_memories(
&self,
agent_id: &str,
query_text: &str,
embedding: &[f32],
limit: i64,
threshold: f32,
vector_weight: f32,
text_weight: f32,
) -> Result<Vec<MemoryMatch>> {
let client = self.pool.get().await?;
let vector = Vector::from(embedding.to_vec());
let threshold_f64 = threshold as f64;
let vector_weight_f64 = vector_weight as f64;
let text_weight_f64 = text_weight as f64;
let rows = client
.query(
r#"
WITH search_query AS (
SELECT
NULLIF(plainto_tsquery('pg_catalog.english', $2)::text, '') AS query_text,
plainto_tsquery('pg_catalog.english', $2) AS ts_query
),
scored AS (
SELECT
id,
agent_id,
content,
keywords,
metadata,
created_at,
(1 - (embedding <=> $1))::real AS vector_score,
CASE
WHEN search_query.query_text IS NULL THEN 0::real
WHEN memories.tsv @@ search_query.ts_query
THEN ts_rank(memories.tsv, search_query.ts_query, 32)::real
ELSE 0::real
END AS text_score
FROM memories
CROSS JOIN search_query
WHERE memories.agent_id = $3
),
ranked AS (
SELECT
*,
MAX(CASE WHEN text_score > 0 THEN 1 ELSE 0 END) OVER () AS has_text_match
FROM scored
)
SELECT
id, agent_id, content, keywords, metadata, created_at,
(1 - (embedding <=> $1))::real AS similarity
FROM memories
WHERE agent_id = $2
AND (1 - (embedding <=> $1)) >= $3
ORDER BY embedding <=> $1
LIMIT $4
id,
agent_id,
content,
keywords,
metadata,
created_at,
vector_score,
text_score,
CASE
WHEN has_text_match = 1
THEN (($5 * vector_score) + ($6 * text_score))::real
ELSE vector_score
END AS hybrid_score
FROM ranked
WHERE vector_score >= $4 OR text_score > 0
ORDER BY hybrid_score DESC, vector_score DESC
LIMIT $7
"#,
&[&vector, &agent_id, &threshold_f64, &limit],
&[
&vector,
&query_text,
&agent_id,
&threshold_f64,
&vector_weight_f64,
&text_weight_f64,
&limit,
],
)
.await
.context("Failed to query memories")?;
@@ -133,7 +189,10 @@ impl Database {
metadata: row.get("metadata"),
created_at: row.get("created_at"),
},
similarity: row.get("similarity"),
similarity: row.get("hybrid_score"),
vector_score: row.get("vector_score"),
text_score: row.get("text_score"),
hybrid_score: row.get("hybrid_score"),
})
.collect();