mirror of
https://gitea.ingwaz.work/Ingwaz/openbrain-mcp.git
synced 2026-03-31 06:39:06 +00:00
Merge pull request 'Add hybrid text plus vector memory search' (#21) from codex/issue-17-hybrid-search into main
Reviewed-on: Ingwaz/openbrain-mcp#21
This commit is contained in:
@@ -19,6 +19,13 @@ OPENBRAIN__DATABASE__POOL_SIZE=10
|
||||
OPENBRAIN__EMBEDDING__MODEL_PATH=models/all-MiniLM-L6-v2
|
||||
OPENBRAIN__EMBEDDING__DIMENSION=384
|
||||
|
||||
# Hybrid query scoring
|
||||
OPENBRAIN__QUERY__VECTOR_WEIGHT=0.6
|
||||
OPENBRAIN__QUERY__TEXT_WEIGHT=0.4
|
||||
# Backward-compatible plain env aliases
|
||||
# VECTOR_WEIGHT=0.6
|
||||
# TEXT_WEIGHT=0.4
|
||||
|
||||
# Authentication (optional)
|
||||
OPENBRAIN__AUTH__ENABLED=false
|
||||
# Comma-separated list of API keys
|
||||
|
||||
36
migrations/V2__hybrid_search.sql
Normal file
36
migrations/V2__hybrid_search.sql
Normal file
@@ -0,0 +1,36 @@
|
||||
ALTER TABLE memories
|
||||
ADD COLUMN IF NOT EXISTS tsv tsvector;
|
||||
|
||||
CREATE OR REPLACE FUNCTION memories_tsv_trigger()
|
||||
RETURNS trigger
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
NEW.tsv :=
|
||||
setweight(to_tsvector('pg_catalog.english', COALESCE(NEW.content, '')), 'A') ||
|
||||
setweight(
|
||||
to_tsvector('pg_catalog.english', COALESCE(array_to_string(NEW.keywords, ' '), '')),
|
||||
'B'
|
||||
);
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$;
|
||||
|
||||
UPDATE memories
|
||||
SET tsv =
|
||||
setweight(to_tsvector('pg_catalog.english', COALESCE(content, '')), 'A') ||
|
||||
setweight(
|
||||
to_tsvector('pg_catalog.english', COALESCE(array_to_string(keywords, ' '), '')),
|
||||
'B'
|
||||
)
|
||||
WHERE tsv IS NULL;
|
||||
|
||||
DROP TRIGGER IF EXISTS memories_tsv_update ON memories;
|
||||
|
||||
CREATE TRIGGER memories_tsv_update
|
||||
BEFORE INSERT OR UPDATE OF content, keywords ON memories
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION memories_tsv_trigger();
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_tsv
|
||||
ON memories USING GIN (tsv);
|
||||
@@ -11,6 +11,7 @@ pub struct Config {
|
||||
pub server: ServerConfig,
|
||||
pub database: DatabaseConfig,
|
||||
pub embedding: EmbeddingConfig,
|
||||
pub query: QueryConfig,
|
||||
pub auth: AuthConfig,
|
||||
}
|
||||
|
||||
@@ -45,6 +46,15 @@ pub struct EmbeddingConfig {
|
||||
pub dimension: usize,
|
||||
}
|
||||
|
||||
/// Query scoring configuration
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct QueryConfig {
|
||||
#[serde(default = "default_vector_weight")]
|
||||
pub vector_weight: f32,
|
||||
#[serde(default = "default_text_weight")]
|
||||
pub text_weight: f32,
|
||||
}
|
||||
|
||||
/// Authentication configuration
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct AuthConfig {
|
||||
@@ -86,6 +96,8 @@ fn default_db_port() -> u16 { 5432 }
|
||||
fn default_pool_size() -> usize { 10 }
|
||||
fn default_model_path() -> String { "models/all-MiniLM-L6-v2".to_string() }
|
||||
fn default_embedding_dim() -> usize { 384 }
|
||||
fn default_vector_weight() -> f32 { 0.6 }
|
||||
fn default_text_weight() -> f32 { 0.4 }
|
||||
fn default_auth_enabled() -> bool { false }
|
||||
|
||||
impl Config {
|
||||
@@ -104,6 +116,9 @@ impl Config {
|
||||
// Embedding settings
|
||||
.set_default("embedding.model_path", default_model_path())?
|
||||
.set_default("embedding.dimension", default_embedding_dim() as i64)?
|
||||
// Query settings
|
||||
.set_default("query.vector_weight", default_vector_weight() as f64)?
|
||||
.set_default("query.text_weight", default_text_weight() as f64)?
|
||||
// Auth settings
|
||||
.set_default("auth.enabled", default_auth_enabled())?
|
||||
// Load from environment with OPENBRAIN_ prefix
|
||||
@@ -114,7 +129,21 @@ impl Config {
|
||||
)
|
||||
.build()?;
|
||||
|
||||
Ok(config.try_deserialize()?)
|
||||
let mut config: Self = config.try_deserialize()?;
|
||||
|
||||
// Keep compatibility with plain env names proposed in issue #17.
|
||||
if let Ok(vector_weight) = std::env::var("VECTOR_WEIGHT") {
|
||||
if let Ok(parsed) = vector_weight.parse::<f32>() {
|
||||
config.query.vector_weight = parsed;
|
||||
}
|
||||
}
|
||||
if let Ok(text_weight) = std::env::var("TEXT_WEIGHT") {
|
||||
if let Ok(parsed) = text_weight.parse::<f32>() {
|
||||
config.query.text_weight = parsed;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,6 +166,10 @@ impl Default for Config {
|
||||
model_path: default_model_path(),
|
||||
dimension: default_embedding_dim(),
|
||||
},
|
||||
query: QueryConfig {
|
||||
vector_weight: default_vector_weight(),
|
||||
text_weight: default_text_weight(),
|
||||
},
|
||||
auth: AuthConfig {
|
||||
enabled: default_auth_enabled(),
|
||||
api_keys: Vec::new(),
|
||||
|
||||
77
src/db.rs
77
src/db.rs
@@ -36,6 +36,9 @@ pub struct MemoryRecord {
|
||||
pub struct MemoryMatch {
|
||||
pub record: MemoryRecord,
|
||||
pub similarity: f32,
|
||||
pub vector_score: f32,
|
||||
pub text_score: f32,
|
||||
pub hybrid_score: f32,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
@@ -95,27 +98,80 @@ impl Database {
|
||||
pub async fn query_memories(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
query_text: &str,
|
||||
embedding: &[f32],
|
||||
limit: i64,
|
||||
threshold: f32,
|
||||
vector_weight: f32,
|
||||
text_weight: f32,
|
||||
) -> Result<Vec<MemoryMatch>> {
|
||||
let client = self.pool.get().await?;
|
||||
let vector = Vector::from(embedding.to_vec());
|
||||
let threshold_f64 = threshold as f64;
|
||||
let vector_weight_f64 = vector_weight as f64;
|
||||
let text_weight_f64 = text_weight as f64;
|
||||
|
||||
let rows = client
|
||||
.query(
|
||||
r#"
|
||||
WITH search_query AS (
|
||||
SELECT
|
||||
NULLIF(plainto_tsquery('pg_catalog.english', $2)::text, '') AS query_text,
|
||||
plainto_tsquery('pg_catalog.english', $2) AS ts_query
|
||||
),
|
||||
scored AS (
|
||||
SELECT
|
||||
id,
|
||||
agent_id,
|
||||
content,
|
||||
keywords,
|
||||
metadata,
|
||||
created_at,
|
||||
(1 - (embedding <=> $1))::real AS vector_score,
|
||||
CASE
|
||||
WHEN search_query.query_text IS NULL THEN 0::real
|
||||
WHEN memories.tsv @@ search_query.ts_query
|
||||
THEN ts_rank(memories.tsv, search_query.ts_query, 32)::real
|
||||
ELSE 0::real
|
||||
END AS text_score
|
||||
FROM memories
|
||||
CROSS JOIN search_query
|
||||
WHERE memories.agent_id = $3
|
||||
),
|
||||
ranked AS (
|
||||
SELECT
|
||||
*,
|
||||
MAX(CASE WHEN text_score > 0 THEN 1 ELSE 0 END) OVER () AS has_text_match
|
||||
FROM scored
|
||||
)
|
||||
SELECT
|
||||
id, agent_id, content, keywords, metadata, created_at,
|
||||
(1 - (embedding <=> $1))::real AS similarity
|
||||
FROM memories
|
||||
WHERE agent_id = $2
|
||||
AND (1 - (embedding <=> $1)) >= $3
|
||||
ORDER BY embedding <=> $1
|
||||
LIMIT $4
|
||||
id,
|
||||
agent_id,
|
||||
content,
|
||||
keywords,
|
||||
metadata,
|
||||
created_at,
|
||||
vector_score,
|
||||
text_score,
|
||||
CASE
|
||||
WHEN has_text_match = 1
|
||||
THEN (($5 * vector_score) + ($6 * text_score))::real
|
||||
ELSE vector_score
|
||||
END AS hybrid_score
|
||||
FROM ranked
|
||||
WHERE vector_score >= $4 OR text_score > 0
|
||||
ORDER BY hybrid_score DESC, vector_score DESC
|
||||
LIMIT $7
|
||||
"#,
|
||||
&[&vector, &agent_id, &threshold_f64, &limit],
|
||||
&[
|
||||
&vector,
|
||||
&query_text,
|
||||
&agent_id,
|
||||
&threshold_f64,
|
||||
&vector_weight_f64,
|
||||
&text_weight_f64,
|
||||
&limit,
|
||||
],
|
||||
)
|
||||
.await
|
||||
.context("Failed to query memories")?;
|
||||
@@ -133,7 +189,10 @@ impl Database {
|
||||
metadata: row.get("metadata"),
|
||||
created_at: row.get("created_at"),
|
||||
},
|
||||
similarity: row.get("similarity"),
|
||||
similarity: row.get("hybrid_score"),
|
||||
vector_score: row.get("vector_score"),
|
||||
text_score: row.get("text_score"),
|
||||
hybrid_score: row.get("hybrid_score"),
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
||||
}),
|
||||
json!({
|
||||
"name": "query",
|
||||
"description": "Query stored memories using semantic similarity search. Returns the most relevant memories based on the query text.",
|
||||
"description": "Query stored memories using hybrid semantic plus keyword search. Returns the most relevant memories based on vector similarity and exact-text ranking.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
@@ -36,9 +36,14 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
.and_then(|v| v.as_f64())
|
||||
.unwrap_or(0.5) as f32;
|
||||
|
||||
let (vector_weight, text_weight) = normalized_weights(
|
||||
state.config.query.vector_weight,
|
||||
state.config.query.text_weight,
|
||||
);
|
||||
|
||||
info!(
|
||||
"Querying memories for agent '{}': '{}' (limit={}, threshold={})",
|
||||
agent_id, query_text, limit, threshold
|
||||
"Querying memories for agent '{}': '{}' (limit={}, threshold={}, vector_weight={}, text_weight={})",
|
||||
agent_id, query_text, limit, threshold, vector_weight, text_weight
|
||||
);
|
||||
|
||||
// Generate embedding for query using Arc<EmbeddingEngine>
|
||||
@@ -49,7 +54,15 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
// Search database
|
||||
let matches = state
|
||||
.db
|
||||
.query_memories(agent_id, &query_embedding, limit, threshold)
|
||||
.query_memories(
|
||||
agent_id,
|
||||
query_text,
|
||||
&query_embedding,
|
||||
limit,
|
||||
threshold,
|
||||
vector_weight,
|
||||
text_weight,
|
||||
)
|
||||
.await
|
||||
.context("Failed to query memories")?;
|
||||
|
||||
@@ -63,6 +76,9 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
"id": m.record.id.to_string(),
|
||||
"content": m.record.content,
|
||||
"similarity": m.similarity,
|
||||
"vector_score": m.vector_score,
|
||||
"text_score": m.text_score,
|
||||
"hybrid_score": m.hybrid_score,
|
||||
"keywords": m.record.keywords,
|
||||
"metadata": m.record.metadata,
|
||||
"created_at": m.record.created_at.to_rfc3339()
|
||||
@@ -74,8 +90,22 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
"success": true,
|
||||
"agent_id": agent_id,
|
||||
"query": query_text,
|
||||
"vector_weight": vector_weight,
|
||||
"text_weight": text_weight,
|
||||
"count": results.len(),
|
||||
"results": results
|
||||
})
|
||||
.to_string())
|
||||
}
|
||||
|
||||
fn normalized_weights(vector_weight: f32, text_weight: f32) -> (f32, f32) {
|
||||
let vector_weight = vector_weight.max(0.0);
|
||||
let text_weight = text_weight.max(0.0);
|
||||
let total = vector_weight + text_weight;
|
||||
|
||||
if total <= f32::EPSILON {
|
||||
(0.6, 0.4)
|
||||
} else {
|
||||
(vector_weight / total, text_weight / total)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user