mirror of
https://gitea.ingwaz.work/Ingwaz/openbrain-mcp.git
synced 2026-03-31 14:49:06 +00:00
Merge pull request 'Add hybrid text plus vector memory search' (#21) from codex/issue-17-hybrid-search into main
Reviewed-on: Ingwaz/openbrain-mcp#21
This commit is contained in:
@@ -19,6 +19,13 @@ OPENBRAIN__DATABASE__POOL_SIZE=10
|
|||||||
OPENBRAIN__EMBEDDING__MODEL_PATH=models/all-MiniLM-L6-v2
|
OPENBRAIN__EMBEDDING__MODEL_PATH=models/all-MiniLM-L6-v2
|
||||||
OPENBRAIN__EMBEDDING__DIMENSION=384
|
OPENBRAIN__EMBEDDING__DIMENSION=384
|
||||||
|
|
||||||
|
# Hybrid query scoring
|
||||||
|
OPENBRAIN__QUERY__VECTOR_WEIGHT=0.6
|
||||||
|
OPENBRAIN__QUERY__TEXT_WEIGHT=0.4
|
||||||
|
# Backward-compatible plain env aliases
|
||||||
|
# VECTOR_WEIGHT=0.6
|
||||||
|
# TEXT_WEIGHT=0.4
|
||||||
|
|
||||||
# Authentication (optional)
|
# Authentication (optional)
|
||||||
OPENBRAIN__AUTH__ENABLED=false
|
OPENBRAIN__AUTH__ENABLED=false
|
||||||
# Comma-separated list of API keys
|
# Comma-separated list of API keys
|
||||||
|
|||||||
36
migrations/V2__hybrid_search.sql
Normal file
36
migrations/V2__hybrid_search.sql
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
-- Add the full-text search document column used by hybrid search.
-- IF NOT EXISTS lets this statement be re-run without failing when the column is already present.
ALTER TABLE memories
|
||||||
|
ADD COLUMN IF NOT EXISTS tsv tsvector;
|
||||||
|
|
||||||
|
-- Trigger function that rebuilds NEW.tsv for the row being written.
-- The document is the English-parsed content (weight 'A') concatenated with the
-- English-parsed keywords joined by spaces (weight 'B'); NULL content or keywords
-- are treated as empty text so the result is never NULL.
CREATE OR REPLACE FUNCTION memories_tsv_trigger()
|
||||||
|
RETURNS trigger
|
||||||
|
LANGUAGE plpgsql
|
||||||
|
AS $$
|
||||||
|
BEGIN
|
||||||
|
NEW.tsv :=
|
||||||
|
setweight(to_tsvector('pg_catalog.english', COALESCE(NEW.content, '')), 'A') ||
|
||||||
|
setweight(
|
||||||
|
to_tsvector('pg_catalog.english', COALESCE(array_to_string(NEW.keywords, ' '), '')),
|
||||||
|
'B'
|
||||||
|
);
|
||||||
|
RETURN NEW;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
-- Backfill: populate tsv for rows that do not have one yet, using the same
-- weighted content/keywords expression as memories_tsv_trigger().
-- Rows whose tsv is already set are left untouched.
UPDATE memories
|
||||||
|
SET tsv =
|
||||||
|
setweight(to_tsvector('pg_catalog.english', COALESCE(content, '')), 'A') ||
|
||||||
|
setweight(
|
||||||
|
to_tsvector('pg_catalog.english', COALESCE(array_to_string(keywords, ' '), '')),
|
||||||
|
'B'
|
||||||
|
)
|
||||||
|
WHERE tsv IS NULL;
|
||||||
|
|
||||||
|
-- Recreate the trigger (drop first so the migration can be re-applied).
-- It runs memories_tsv_trigger() per row before every INSERT, and before an
-- UPDATE only when content or keywords is among the updated columns.
DROP TRIGGER IF EXISTS memories_tsv_update ON memories;
|
||||||
|
|
||||||
|
CREATE TRIGGER memories_tsv_update
|
||||||
|
BEFORE INSERT OR UPDATE OF content, keywords ON memories
|
||||||
|
FOR EACH ROW
|
||||||
|
EXECUTE FUNCTION memories_tsv_trigger();
|
||||||
|
|
||||||
|
-- GIN index over the tsv column for full-text matching against it.
-- IF NOT EXISTS keeps the statement safe to re-run.
CREATE INDEX IF NOT EXISTS idx_memories_tsv
|
||||||
|
ON memories USING GIN (tsv);
|
||||||
@@ -11,6 +11,7 @@ pub struct Config {
|
|||||||
pub server: ServerConfig,
|
pub server: ServerConfig,
|
||||||
pub database: DatabaseConfig,
|
pub database: DatabaseConfig,
|
||||||
pub embedding: EmbeddingConfig,
|
pub embedding: EmbeddingConfig,
|
||||||
|
pub query: QueryConfig,
|
||||||
pub auth: AuthConfig,
|
pub auth: AuthConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -45,6 +46,15 @@ pub struct EmbeddingConfig {
|
|||||||
pub dimension: usize,
|
pub dimension: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Query scoring configuration for hybrid (vector + text) memory search.
///
/// Both weights fall back to their `default_*` functions when absent from the
/// configuration source.
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct QueryConfig {
|
||||||
|
/// Weight for the vector-similarity score; defaults to `default_vector_weight()`.
#[serde(default = "default_vector_weight")]
|
||||||
|
pub vector_weight: f32,
|
||||||
|
/// Weight for the full-text rank score; defaults to `default_text_weight()`.
#[serde(default = "default_text_weight")]
|
||||||
|
pub text_weight: f32,
|
||||||
|
}
|
||||||
|
|
||||||
/// Authentication configuration
|
/// Authentication configuration
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
pub struct AuthConfig {
|
pub struct AuthConfig {
|
||||||
@@ -86,6 +96,8 @@ fn default_db_port() -> u16 { 5432 }
|
|||||||
fn default_pool_size() -> usize { 10 }
|
fn default_pool_size() -> usize { 10 }
|
||||||
fn default_model_path() -> String { "models/all-MiniLM-L6-v2".to_string() }
|
fn default_model_path() -> String { "models/all-MiniLM-L6-v2".to_string() }
|
||||||
fn default_embedding_dim() -> usize { 384 }
|
fn default_embedding_dim() -> usize { 384 }
|
||||||
|
/// Default weight given to the vector-similarity score in hybrid query scoring.
fn default_vector_weight() -> f32 {
    const VECTOR_WEIGHT: f32 = 0.6;
    VECTOR_WEIGHT
}
|
||||||
|
/// Default weight given to the full-text rank score in hybrid query scoring.
fn default_text_weight() -> f32 {
    const TEXT_WEIGHT: f32 = 0.4;
    TEXT_WEIGHT
}
|
||||||
fn default_auth_enabled() -> bool { false }
|
fn default_auth_enabled() -> bool { false }
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
@@ -104,6 +116,9 @@ impl Config {
|
|||||||
// Embedding settings
|
// Embedding settings
|
||||||
.set_default("embedding.model_path", default_model_path())?
|
.set_default("embedding.model_path", default_model_path())?
|
||||||
.set_default("embedding.dimension", default_embedding_dim() as i64)?
|
.set_default("embedding.dimension", default_embedding_dim() as i64)?
|
||||||
|
// Query settings
|
||||||
|
.set_default("query.vector_weight", default_vector_weight() as f64)?
|
||||||
|
.set_default("query.text_weight", default_text_weight() as f64)?
|
||||||
// Auth settings
|
// Auth settings
|
||||||
.set_default("auth.enabled", default_auth_enabled())?
|
.set_default("auth.enabled", default_auth_enabled())?
|
||||||
// Load from environment with OPENBRAIN_ prefix
|
// Load from environment with OPENBRAIN_ prefix
|
||||||
@@ -114,7 +129,21 @@ impl Config {
|
|||||||
)
|
)
|
||||||
.build()?;
|
.build()?;
|
||||||
|
|
||||||
Ok(config.try_deserialize()?)
|
let mut config: Self = config.try_deserialize()?;
|
||||||
|
|
||||||
|
// Keep compatibility with plain env names proposed in issue #17.
|
||||||
|
if let Ok(vector_weight) = std::env::var("VECTOR_WEIGHT") {
|
||||||
|
if let Ok(parsed) = vector_weight.parse::<f32>() {
|
||||||
|
config.query.vector_weight = parsed;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Ok(text_weight) = std::env::var("TEXT_WEIGHT") {
|
||||||
|
if let Ok(parsed) = text_weight.parse::<f32>() {
|
||||||
|
config.query.text_weight = parsed;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(config)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,6 +166,10 @@ impl Default for Config {
|
|||||||
model_path: default_model_path(),
|
model_path: default_model_path(),
|
||||||
dimension: default_embedding_dim(),
|
dimension: default_embedding_dim(),
|
||||||
},
|
},
|
||||||
|
query: QueryConfig {
|
||||||
|
vector_weight: default_vector_weight(),
|
||||||
|
text_weight: default_text_weight(),
|
||||||
|
},
|
||||||
auth: AuthConfig {
|
auth: AuthConfig {
|
||||||
enabled: default_auth_enabled(),
|
enabled: default_auth_enabled(),
|
||||||
api_keys: Vec::new(),
|
api_keys: Vec::new(),
|
||||||
|
|||||||
77
src/db.rs
77
src/db.rs
@@ -36,6 +36,9 @@ pub struct MemoryRecord {
|
|||||||
pub struct MemoryMatch {
|
pub struct MemoryMatch {
|
||||||
pub record: MemoryRecord,
|
pub record: MemoryRecord,
|
||||||
pub similarity: f32,
|
pub similarity: f32,
|
||||||
|
pub vector_score: f32,
|
||||||
|
pub text_score: f32,
|
||||||
|
pub hybrid_score: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Database {
|
impl Database {
|
||||||
@@ -95,27 +98,80 @@ impl Database {
|
|||||||
pub async fn query_memories(
|
pub async fn query_memories(
|
||||||
&self,
|
&self,
|
||||||
agent_id: &str,
|
agent_id: &str,
|
||||||
|
query_text: &str,
|
||||||
embedding: &[f32],
|
embedding: &[f32],
|
||||||
limit: i64,
|
limit: i64,
|
||||||
threshold: f32,
|
threshold: f32,
|
||||||
|
vector_weight: f32,
|
||||||
|
text_weight: f32,
|
||||||
) -> Result<Vec<MemoryMatch>> {
|
) -> Result<Vec<MemoryMatch>> {
|
||||||
let client = self.pool.get().await?;
|
let client = self.pool.get().await?;
|
||||||
let vector = Vector::from(embedding.to_vec());
|
let vector = Vector::from(embedding.to_vec());
|
||||||
let threshold_f64 = threshold as f64;
|
let threshold_f64 = threshold as f64;
|
||||||
|
let vector_weight_f64 = vector_weight as f64;
|
||||||
|
let text_weight_f64 = text_weight as f64;
|
||||||
|
|
||||||
let rows = client
|
let rows = client
|
||||||
.query(
|
.query(
|
||||||
r#"
|
r#"
|
||||||
|
WITH search_query AS (
|
||||||
|
SELECT
|
||||||
|
NULLIF(plainto_tsquery('pg_catalog.english', $2)::text, '') AS query_text,
|
||||||
|
plainto_tsquery('pg_catalog.english', $2) AS ts_query
|
||||||
|
),
|
||||||
|
scored AS (
|
||||||
|
SELECT
|
||||||
|
id,
|
||||||
|
agent_id,
|
||||||
|
content,
|
||||||
|
keywords,
|
||||||
|
metadata,
|
||||||
|
created_at,
|
||||||
|
(1 - (embedding <=> $1))::real AS vector_score,
|
||||||
|
CASE
|
||||||
|
WHEN search_query.query_text IS NULL THEN 0::real
|
||||||
|
WHEN memories.tsv @@ search_query.ts_query
|
||||||
|
THEN ts_rank(memories.tsv, search_query.ts_query, 32)::real
|
||||||
|
ELSE 0::real
|
||||||
|
END AS text_score
|
||||||
|
FROM memories
|
||||||
|
CROSS JOIN search_query
|
||||||
|
WHERE memories.agent_id = $3
|
||||||
|
),
|
||||||
|
ranked AS (
|
||||||
|
SELECT
|
||||||
|
*,
|
||||||
|
MAX(CASE WHEN text_score > 0 THEN 1 ELSE 0 END) OVER () AS has_text_match
|
||||||
|
FROM scored
|
||||||
|
)
|
||||||
SELECT
|
SELECT
|
||||||
id, agent_id, content, keywords, metadata, created_at,
|
id,
|
||||||
(1 - (embedding <=> $1))::real AS similarity
|
agent_id,
|
||||||
FROM memories
|
content,
|
||||||
WHERE agent_id = $2
|
keywords,
|
||||||
AND (1 - (embedding <=> $1)) >= $3
|
metadata,
|
||||||
ORDER BY embedding <=> $1
|
created_at,
|
||||||
LIMIT $4
|
vector_score,
|
||||||
|
text_score,
|
||||||
|
CASE
|
||||||
|
WHEN has_text_match = 1
|
||||||
|
THEN (($5 * vector_score) + ($6 * text_score))::real
|
||||||
|
ELSE vector_score
|
||||||
|
END AS hybrid_score
|
||||||
|
FROM ranked
|
||||||
|
WHERE vector_score >= $4 OR text_score > 0
|
||||||
|
ORDER BY hybrid_score DESC, vector_score DESC
|
||||||
|
LIMIT $7
|
||||||
"#,
|
"#,
|
||||||
&[&vector, &agent_id, &threshold_f64, &limit],
|
&[
|
||||||
|
&vector,
|
||||||
|
&query_text,
|
||||||
|
&agent_id,
|
||||||
|
&threshold_f64,
|
||||||
|
&vector_weight_f64,
|
||||||
|
&text_weight_f64,
|
||||||
|
&limit,
|
||||||
|
],
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.context("Failed to query memories")?;
|
.context("Failed to query memories")?;
|
||||||
@@ -133,7 +189,10 @@ impl Database {
|
|||||||
metadata: row.get("metadata"),
|
metadata: row.get("metadata"),
|
||||||
created_at: row.get("created_at"),
|
created_at: row.get("created_at"),
|
||||||
},
|
},
|
||||||
similarity: row.get("similarity"),
|
similarity: row.get("hybrid_score"),
|
||||||
|
vector_score: row.get("vector_score"),
|
||||||
|
text_score: row.get("text_score"),
|
||||||
|
hybrid_score: row.get("hybrid_score"),
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
|||||||
}),
|
}),
|
||||||
json!({
|
json!({
|
||||||
"name": "query",
|
"name": "query",
|
||||||
"description": "Query stored memories using semantic similarity search. Returns the most relevant memories based on the query text.",
|
"description": "Query stored memories using hybrid semantic plus keyword search. Returns the most relevant memories based on vector similarity and exact-text ranking.",
|
||||||
"inputSchema": {
|
"inputSchema": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|||||||
@@ -36,9 +36,14 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
.and_then(|v| v.as_f64())
|
.and_then(|v| v.as_f64())
|
||||||
.unwrap_or(0.5) as f32;
|
.unwrap_or(0.5) as f32;
|
||||||
|
|
||||||
|
let (vector_weight, text_weight) = normalized_weights(
|
||||||
|
state.config.query.vector_weight,
|
||||||
|
state.config.query.text_weight,
|
||||||
|
);
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Querying memories for agent '{}': '{}' (limit={}, threshold={})",
|
"Querying memories for agent '{}': '{}' (limit={}, threshold={}, vector_weight={}, text_weight={})",
|
||||||
agent_id, query_text, limit, threshold
|
agent_id, query_text, limit, threshold, vector_weight, text_weight
|
||||||
);
|
);
|
||||||
|
|
||||||
// Generate embedding for query using Arc<EmbeddingEngine>
|
// Generate embedding for query using Arc<EmbeddingEngine>
|
||||||
@@ -49,7 +54,15 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
// Search database
|
// Search database
|
||||||
let matches = state
|
let matches = state
|
||||||
.db
|
.db
|
||||||
.query_memories(agent_id, &query_embedding, limit, threshold)
|
.query_memories(
|
||||||
|
agent_id,
|
||||||
|
query_text,
|
||||||
|
&query_embedding,
|
||||||
|
limit,
|
||||||
|
threshold,
|
||||||
|
vector_weight,
|
||||||
|
text_weight,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.context("Failed to query memories")?;
|
.context("Failed to query memories")?;
|
||||||
|
|
||||||
@@ -63,6 +76,9 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
"id": m.record.id.to_string(),
|
"id": m.record.id.to_string(),
|
||||||
"content": m.record.content,
|
"content": m.record.content,
|
||||||
"similarity": m.similarity,
|
"similarity": m.similarity,
|
||||||
|
"vector_score": m.vector_score,
|
||||||
|
"text_score": m.text_score,
|
||||||
|
"hybrid_score": m.hybrid_score,
|
||||||
"keywords": m.record.keywords,
|
"keywords": m.record.keywords,
|
||||||
"metadata": m.record.metadata,
|
"metadata": m.record.metadata,
|
||||||
"created_at": m.record.created_at.to_rfc3339()
|
"created_at": m.record.created_at.to_rfc3339()
|
||||||
@@ -74,8 +90,22 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
|||||||
"success": true,
|
"success": true,
|
||||||
"agent_id": agent_id,
|
"agent_id": agent_id,
|
||||||
"query": query_text,
|
"query": query_text,
|
||||||
|
"vector_weight": vector_weight,
|
||||||
|
"text_weight": text_weight,
|
||||||
"count": results.len(),
|
"count": results.len(),
|
||||||
"results": results
|
"results": results
|
||||||
})
|
})
|
||||||
.to_string())
|
.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Normalizes the hybrid-search weights so both are non-negative and sum to 1.
///
/// * Negative weights are clamped to `0.0` before normalizing.
/// * If either weight is NaN or infinite the pair is treated as misconfigured
///   and the fallback `(0.6, 0.4)` is returned; dividing by an infinite total
///   would otherwise yield NaN weights.
/// * If the clamped weights sum to (effectively) zero, the fallback
///   `(0.6, 0.4)` is returned as well.
///
/// Returns `(vector_weight, text_weight)`; both values are always finite.
fn normalized_weights(vector_weight: f32, text_weight: f32) -> (f32, f32) {
    const FALLBACK: (f32, f32) = (0.6, 0.4);

    // `str::parse::<f32>` accepts "inf" and "NaN", so weights read from the
    // environment can be non-finite; never let those reach the score formula.
    if !(vector_weight.is_finite() && text_weight.is_finite()) {
        return FALLBACK;
    }

    // Sum in f64 so two large-but-finite f32 weights cannot overflow to infinity.
    let vector = f64::from(vector_weight.max(0.0));
    let text = f64::from(text_weight.max(0.0));
    let total = vector + text;

    if total <= f64::from(f32::EPSILON) {
        FALLBACK
    } else {
        // Each ratio lies in [0, 1], so narrowing back to f32 cannot overflow.
        ((vector / total) as f32, (text / total) as f32)
    }
}
|
||||||
|
|||||||
Reference in New Issue
Block a user