mirror of
https://gitea.ingwaz.work/Ingwaz/openbrain-mcp.git
synced 2026-06-16 06:17:08 +00:00
Scope memories by API token and add shared-token e2e coverage
This commit is contained in:
103
src/db.rs
103
src/db.rs
@@ -9,9 +9,9 @@ use tokio_postgres::NoTls;
|
||||
use tracing::info;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::config::DatabaseConfig;
|
||||
use serde::Serialize;
|
||||
use serde_json::{Map, Value};
|
||||
use crate::config::DatabaseConfig;
|
||||
|
||||
/// Database wrapper with connection pool
|
||||
#[derive(Clone)]
|
||||
@@ -75,6 +75,7 @@ fn merge_metadata(existing: &Value, incoming: &Value) -> Value {
|
||||
|
||||
async fn find_dedup_match<C>(
|
||||
client: &C,
|
||||
auth_scope: &str,
|
||||
agent_id: &str,
|
||||
embedding: &Vector,
|
||||
threshold: f64,
|
||||
@@ -87,13 +88,14 @@ where
|
||||
r#"
|
||||
SELECT id, metadata, expires_at
|
||||
FROM memories
|
||||
WHERE agent_id = $1
|
||||
WHERE auth_scope = $1
|
||||
AND agent_id = $2
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
AND (1 - (embedding <=> $2)) >= $3
|
||||
ORDER BY (1 - (embedding <=> $2)) DESC, created_at DESC
|
||||
AND (1 - (embedding <=> $3)) >= $4
|
||||
ORDER BY (1 - (embedding <=> $3)) DESC, created_at DESC
|
||||
LIMIT 1
|
||||
"#,
|
||||
&[&agent_id, embedding, &threshold],
|
||||
&[&auth_scope, &agent_id, embedding, &threshold],
|
||||
)
|
||||
.await
|
||||
.context("Failed to check for duplicate memory")?;
|
||||
@@ -120,13 +122,19 @@ impl Database {
|
||||
.context("Failed to create database pool")?;
|
||||
|
||||
// Test connection
|
||||
let client = pool.get().await.context("Failed to get database connection")?;
|
||||
let client = pool
|
||||
.get()
|
||||
.await
|
||||
.context("Failed to get database connection")?;
|
||||
client
|
||||
.simple_query("SELECT 1")
|
||||
.await
|
||||
.context("Failed to execute test query")?;
|
||||
|
||||
info!("Database connection pool created with {} connections", config.pool_size);
|
||||
info!(
|
||||
"Database connection pool created with {} connections",
|
||||
config.pool_size
|
||||
);
|
||||
|
||||
Ok(Self { pool })
|
||||
}
|
||||
@@ -134,6 +142,7 @@ impl Database {
|
||||
/// Store a memory record
|
||||
pub async fn store_memory(
|
||||
&self,
|
||||
auth_scope: &str,
|
||||
agent_id: &str,
|
||||
content: &str,
|
||||
embedding: &[f32],
|
||||
@@ -146,7 +155,9 @@ impl Database {
|
||||
let vector = Vector::from(embedding.to_vec());
|
||||
let dedup_threshold = dedup_threshold as f64;
|
||||
|
||||
if let Some(existing) = find_dedup_match(&client, agent_id, &vector, dedup_threshold).await? {
|
||||
if let Some(existing) =
|
||||
find_dedup_match(&client, auth_scope, agent_id, &vector, dedup_threshold).await?
|
||||
{
|
||||
let merged_metadata = merge_metadata(&existing.metadata, &metadata);
|
||||
let refreshed_expires_at = expires_at.or(existing.expires_at);
|
||||
|
||||
@@ -176,10 +187,10 @@ impl Database {
|
||||
client
|
||||
.execute(
|
||||
r#"
|
||||
INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
INSERT INTO memories (id, auth_scope, agent_id, content, embedding, keywords, metadata, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
"#,
|
||||
&[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
|
||||
&[&id, &auth_scope, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
|
||||
)
|
||||
.await
|
||||
.context("Failed to store memory")?;
|
||||
@@ -194,7 +205,8 @@ impl Database {
|
||||
/// Query memories by vector similarity
|
||||
pub async fn query_memories(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
auth_scope: &str,
|
||||
source_agent_id: Option<&str>,
|
||||
query_text: &str,
|
||||
embedding: &[f32],
|
||||
limit: i64,
|
||||
@@ -230,7 +242,8 @@ impl Database {
|
||||
END AS text_score
|
||||
FROM memories
|
||||
CROSS JOIN search_query
|
||||
WHERE memories.agent_id = $3
|
||||
WHERE memories.auth_scope = $3
|
||||
AND ($4::text IS NULL OR memories.agent_id = $4)
|
||||
AND (memories.expires_at IS NULL OR memories.expires_at > NOW())
|
||||
),
|
||||
ranked AS (
|
||||
@@ -251,18 +264,19 @@ impl Database {
|
||||
text_score,
|
||||
CASE
|
||||
WHEN has_text_match = 1
|
||||
THEN (($5 * vector_score) + ($6 * text_score))::real
|
||||
THEN (($6 * vector_score) + ($7 * text_score))::real
|
||||
ELSE vector_score
|
||||
END AS hybrid_score
|
||||
FROM ranked
|
||||
WHERE vector_score >= $4 OR text_score > 0
|
||||
WHERE vector_score >= $5 OR text_score > 0
|
||||
ORDER BY hybrid_score DESC, vector_score DESC
|
||||
LIMIT $7
|
||||
LIMIT $8
|
||||
"#,
|
||||
&[
|
||||
&vector,
|
||||
&query_text,
|
||||
&agent_id,
|
||||
&auth_scope,
|
||||
&source_agent_id,
|
||||
&threshold,
|
||||
&vector_weight,
|
||||
&text_weight,
|
||||
@@ -296,37 +310,47 @@ impl Database {
|
||||
Ok(matches)
|
||||
}
|
||||
|
||||
/// Delete memories by agent_id and optional filters
|
||||
/// Delete memories visible to an auth scope with an optional provenance filter
|
||||
pub async fn purge_memories(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
auth_scope: &str,
|
||||
source_agent_id: Option<&str>,
|
||||
before: Option<chrono::DateTime<chrono::Utc>>,
|
||||
) -> Result<u64> {
|
||||
let client = self.pool.get().await?;
|
||||
|
||||
let count = if let Some(before_ts) = before {
|
||||
client
|
||||
.execute(
|
||||
"DELETE FROM memories WHERE agent_id = $1 AND created_at < $2",
|
||||
&[&agent_id, &before_ts],
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
client
|
||||
.execute("DELETE FROM memories WHERE agent_id = $1", &[&agent_id])
|
||||
.await?
|
||||
};
|
||||
let count = client
|
||||
.execute(
|
||||
r#"
|
||||
DELETE FROM memories
|
||||
WHERE auth_scope = $1
|
||||
AND ($2::text IS NULL OR agent_id = $2)
|
||||
AND ($3::timestamptz IS NULL OR created_at < $3)
|
||||
"#,
|
||||
&[&auth_scope, &source_agent_id, &before],
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
/// Get memory count for an agent
|
||||
pub async fn count_memories(&self, agent_id: &str) -> Result<i64> {
|
||||
/// Get memory count for a token-visible scope and optional provenance filter
|
||||
pub async fn count_memories(
|
||||
&self,
|
||||
auth_scope: &str,
|
||||
source_agent_id: Option<&str>,
|
||||
) -> Result<i64> {
|
||||
let client = self.pool.get().await?;
|
||||
let row = client
|
||||
.query_one(
|
||||
"SELECT COUNT(*) as count FROM memories WHERE agent_id = $1 AND (expires_at IS NULL OR expires_at > NOW())",
|
||||
&[&agent_id],
|
||||
r#"
|
||||
SELECT COUNT(*) as count
|
||||
FROM memories
|
||||
WHERE auth_scope = $1
|
||||
AND ($2::text IS NULL OR agent_id = $2)
|
||||
AND (expires_at IS NULL OR expires_at > NOW())
|
||||
"#,
|
||||
&[&auth_scope, &source_agent_id],
|
||||
)
|
||||
.await?;
|
||||
Ok(row.get("count"))
|
||||
@@ -346,7 +370,6 @@ impl Database {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Result for a single batch entry
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct BatchStoreResult {
|
||||
@@ -360,6 +383,7 @@ impl Database {
|
||||
/// Store multiple memories in a single transaction
|
||||
pub async fn batch_store_memories(
|
||||
&self,
|
||||
auth_scope: &str,
|
||||
agent_id: &str,
|
||||
entries: Vec<(
|
||||
String,
|
||||
@@ -378,7 +402,8 @@ impl Database {
|
||||
for (content, metadata, embedding, keywords, expires_at) in entries {
|
||||
let vector = Vector::from(embedding);
|
||||
if let Some(existing) =
|
||||
find_dedup_match(&transaction, agent_id, &vector, dedup_threshold).await?
|
||||
find_dedup_match(&transaction, auth_scope, agent_id, &vector, dedup_threshold)
|
||||
.await?
|
||||
{
|
||||
let merged_metadata = merge_metadata(&existing.metadata, &metadata);
|
||||
let refreshed_expires_at = expires_at.or(existing.expires_at);
|
||||
@@ -404,8 +429,8 @@ impl Database {
|
||||
} else {
|
||||
let id = Uuid::new_v4();
|
||||
transaction.execute(
|
||||
r#"INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7)"#,
|
||||
&[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
|
||||
r#"INSERT INTO memories (id, auth_scope, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#,
|
||||
&[&id, &auth_scope, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
|
||||
).await?;
|
||||
results.push(BatchStoreResult {
|
||||
id: id.to_string(),
|
||||
|
||||
Reference in New Issue
Block a user