mirror of
https://gitea.ingwaz.work/Ingwaz/openbrain-mcp.git
synced 2026-03-31 14:49:06 +00:00
Add hybrid text plus vector memory search
This commit is contained in:
77
src/db.rs
77
src/db.rs
@@ -36,6 +36,9 @@ pub struct MemoryRecord {
|
||||
/// One result row of `Database::query_memories`: a stored memory plus the
/// scores the search query computed for it.
pub struct MemoryMatch {
|
||||
/// The memory row that matched.
pub record: MemoryRecord,
|
||||
/// Score kept under the original field name. The row mapping shown in this
/// diff reads it from both the `similarity` and the `hybrid_score` columns;
/// NOTE(review): presumably the `hybrid_score` read replaces the older one,
/// so existing callers now see the hybrid score here — confirm.
pub similarity: f32,
|
||||
/// `1 - (embedding <=> query_vector)` cast to `real`, as computed in the SQL.
/// NOTE(review): `<=>` is presumably pgvector's cosine-distance operator — confirm.
pub vector_score: f32,
|
||||
/// `ts_rank(tsv, ts_query, 32)` when the row's `tsv` matches the parsed query
/// text; `0` when it does not match or the parsed query is empty.
pub text_score: f32,
|
||||
/// `vector_weight * vector_score + text_weight * text_score` when at least one
/// of the agent's rows has a non-zero `text_score`; otherwise equal to
/// `vector_score`. Results are ordered by this value, descending.
pub hybrid_score: f32,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
@@ -95,27 +98,80 @@ impl Database {
|
||||
pub async fn query_memories(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
query_text: &str,
|
||||
embedding: &[f32],
|
||||
limit: i64,
|
||||
threshold: f32,
|
||||
vector_weight: f32,
|
||||
text_weight: f32,
|
||||
) -> Result<Vec<MemoryMatch>> {
|
||||
let client = self.pool.get().await?;
|
||||
let vector = Vector::from(embedding.to_vec());
|
||||
let threshold_f64 = threshold as f64;
|
||||
let vector_weight_f64 = vector_weight as f64;
|
||||
let text_weight_f64 = text_weight as f64;
|
||||
|
||||
let rows = client
|
||||
.query(
|
||||
r#"
|
||||
WITH search_query AS (
|
||||
SELECT
|
||||
NULLIF(plainto_tsquery('pg_catalog.english', $2)::text, '') AS query_text,
|
||||
plainto_tsquery('pg_catalog.english', $2) AS ts_query
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
id,
|
||||
agent_id,
|
||||
content,
|
||||
keywords,
|
||||
metadata,
|
||||
created_at,
|
||||
(1 - (embedding <=> $1))::real AS vector_score,
|
||||
CASE
|
||||
WHEN search_query.query_text IS NULL THEN 0::real
|
||||
WHEN memories.tsv @@ search_query.ts_query
|
||||
THEN ts_rank(memories.tsv, search_query.ts_query, 32)::real
|
||||
ELSE 0::real
|
||||
END AS text_score
|
||||
FROM memories
|
||||
CROSS JOIN search_query
|
||||
WHERE memories.agent_id = $3
|
||||
),
|
||||
ranked AS (
|
||||
SELECT
|
||||
*,
|
||||
MAX(CASE WHEN text_score > 0 THEN 1 ELSE 0 END) OVER () AS has_text_match
|
||||
FROM scored
|
||||
)
|
||||
SELECT
|
||||
id, agent_id, content, keywords, metadata, created_at,
|
||||
(1 - (embedding <=> $1))::real AS similarity
|
||||
FROM memories
|
||||
WHERE agent_id = $2
|
||||
AND (1 - (embedding <=> $1)) >= $3
|
||||
ORDER BY embedding <=> $1
|
||||
LIMIT $4
|
||||
id,
|
||||
agent_id,
|
||||
content,
|
||||
keywords,
|
||||
metadata,
|
||||
created_at,
|
||||
vector_score,
|
||||
text_score,
|
||||
CASE
|
||||
WHEN has_text_match = 1
|
||||
THEN (($5 * vector_score) + ($6 * text_score))::real
|
||||
ELSE vector_score
|
||||
END AS hybrid_score
|
||||
FROM ranked
|
||||
WHERE vector_score >= $4 OR text_score > 0
|
||||
ORDER BY hybrid_score DESC, vector_score DESC
|
||||
LIMIT $7
|
||||
"#,
|
||||
&[&vector, &agent_id, &threshold_f64, &limit],
|
||||
&[
|
||||
&vector,
|
||||
&query_text,
|
||||
&agent_id,
|
||||
&threshold_f64,
|
||||
&vector_weight_f64,
|
||||
&text_weight_f64,
|
||||
&limit,
|
||||
],
|
||||
)
|
||||
.await
|
||||
.context("Failed to query memories")?;
|
||||
@@ -133,7 +189,10 @@ impl Database {
|
||||
metadata: row.get("metadata"),
|
||||
created_at: row.get("created_at"),
|
||||
},
|
||||
similarity: row.get("similarity"),
|
||||
similarity: row.get("hybrid_score"),
|
||||
vector_score: row.get("vector_score"),
|
||||
text_score: row.get("text_score"),
|
||||
hybrid_score: row.get("hybrid_score"),
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user