Scope memories by API token and add shared-token e2e coverage

This commit is contained in:
Agent Zero
2026-04-01 23:30:58 -04:00
parent 98baa27c90
commit 026ae27366
17 changed files with 1096 additions and 428 deletions

103
src/db.rs
View File

@@ -9,9 +9,9 @@ use tokio_postgres::NoTls;
use tracing::info;
use uuid::Uuid;
use crate::config::DatabaseConfig;
use serde::Serialize;
use serde_json::{Map, Value};
use crate::config::DatabaseConfig;
/// Database wrapper with connection pool
#[derive(Clone)]
@@ -75,6 +75,7 @@ fn merge_metadata(existing: &Value, incoming: &Value) -> Value {
async fn find_dedup_match<C>(
client: &C,
auth_scope: &str,
agent_id: &str,
embedding: &Vector,
threshold: f64,
@@ -87,13 +88,14 @@ where
r#"
SELECT id, metadata, expires_at
FROM memories
WHERE agent_id = $1
WHERE auth_scope = $1
AND agent_id = $2
AND (expires_at IS NULL OR expires_at > NOW())
AND (1 - (embedding <=> $2)) >= $3
ORDER BY (1 - (embedding <=> $2)) DESC, created_at DESC
AND (1 - (embedding <=> $3)) >= $4
ORDER BY (1 - (embedding <=> $3)) DESC, created_at DESC
LIMIT 1
"#,
&[&agent_id, embedding, &threshold],
&[&auth_scope, &agent_id, embedding, &threshold],
)
.await
.context("Failed to check for duplicate memory")?;
@@ -120,13 +122,19 @@ impl Database {
.context("Failed to create database pool")?;
// Test connection
let client = pool.get().await.context("Failed to get database connection")?;
let client = pool
.get()
.await
.context("Failed to get database connection")?;
client
.simple_query("SELECT 1")
.await
.context("Failed to execute test query")?;
info!("Database connection pool created with {} connections", config.pool_size);
info!(
"Database connection pool created with {} connections",
config.pool_size
);
Ok(Self { pool })
}
@@ -134,6 +142,7 @@ impl Database {
/// Store a memory record
pub async fn store_memory(
&self,
auth_scope: &str,
agent_id: &str,
content: &str,
embedding: &[f32],
@@ -146,7 +155,9 @@ impl Database {
let vector = Vector::from(embedding.to_vec());
let dedup_threshold = dedup_threshold as f64;
if let Some(existing) = find_dedup_match(&client, agent_id, &vector, dedup_threshold).await? {
if let Some(existing) =
find_dedup_match(&client, auth_scope, agent_id, &vector, dedup_threshold).await?
{
let merged_metadata = merge_metadata(&existing.metadata, &metadata);
let refreshed_expires_at = expires_at.or(existing.expires_at);
@@ -176,10 +187,10 @@ impl Database {
client
.execute(
r#"
INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
INSERT INTO memories (id, auth_scope, agent_id, content, embedding, keywords, metadata, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
"#,
&[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
&[&id, &auth_scope, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
)
.await
.context("Failed to store memory")?;
@@ -194,7 +205,8 @@ impl Database {
/// Query memories by vector similarity
pub async fn query_memories(
&self,
agent_id: &str,
auth_scope: &str,
source_agent_id: Option<&str>,
query_text: &str,
embedding: &[f32],
limit: i64,
@@ -230,7 +242,8 @@ impl Database {
END AS text_score
FROM memories
CROSS JOIN search_query
WHERE memories.agent_id = $3
WHERE memories.auth_scope = $3
AND ($4::text IS NULL OR memories.agent_id = $4)
AND (memories.expires_at IS NULL OR memories.expires_at > NOW())
),
ranked AS (
@@ -251,18 +264,19 @@ impl Database {
text_score,
CASE
WHEN has_text_match = 1
THEN (($5 * vector_score) + ($6 * text_score))::real
THEN (($6 * vector_score) + ($7 * text_score))::real
ELSE vector_score
END AS hybrid_score
FROM ranked
WHERE vector_score >= $4 OR text_score > 0
WHERE vector_score >= $5 OR text_score > 0
ORDER BY hybrid_score DESC, vector_score DESC
LIMIT $7
LIMIT $8
"#,
&[
&vector,
&query_text,
&agent_id,
&auth_scope,
&source_agent_id,
&threshold,
&vector_weight,
&text_weight,
@@ -296,37 +310,47 @@ impl Database {
Ok(matches)
}
/// Delete memories by agent_id and optional filters
/// Delete memories visible to an auth scope with an optional provenance filter
pub async fn purge_memories(
&self,
agent_id: &str,
auth_scope: &str,
source_agent_id: Option<&str>,
before: Option<chrono::DateTime<chrono::Utc>>,
) -> Result<u64> {
let client = self.pool.get().await?;
let count = if let Some(before_ts) = before {
client
.execute(
"DELETE FROM memories WHERE agent_id = $1 AND created_at < $2",
&[&agent_id, &before_ts],
)
.await?
} else {
client
.execute("DELETE FROM memories WHERE agent_id = $1", &[&agent_id])
.await?
};
let count = client
.execute(
r#"
DELETE FROM memories
WHERE auth_scope = $1
AND ($2::text IS NULL OR agent_id = $2)
AND ($3::timestamptz IS NULL OR created_at < $3)
"#,
&[&auth_scope, &source_agent_id, &before],
)
.await?;
Ok(count)
}
/// Get memory count for an agent
pub async fn count_memories(&self, agent_id: &str) -> Result<i64> {
/// Get memory count for a token-visible scope and optional provenance filter
pub async fn count_memories(
&self,
auth_scope: &str,
source_agent_id: Option<&str>,
) -> Result<i64> {
let client = self.pool.get().await?;
let row = client
.query_one(
"SELECT COUNT(*) as count FROM memories WHERE agent_id = $1 AND (expires_at IS NULL OR expires_at > NOW())",
&[&agent_id],
r#"
SELECT COUNT(*) as count
FROM memories
WHERE auth_scope = $1
AND ($2::text IS NULL OR agent_id = $2)
AND (expires_at IS NULL OR expires_at > NOW())
"#,
&[&auth_scope, &source_agent_id],
)
.await?;
Ok(row.get("count"))
@@ -346,7 +370,6 @@ impl Database {
}
}
/// Result for a single batch entry
#[derive(Debug, Clone, Serialize)]
pub struct BatchStoreResult {
@@ -360,6 +383,7 @@ impl Database {
/// Store multiple memories in a single transaction
pub async fn batch_store_memories(
&self,
auth_scope: &str,
agent_id: &str,
entries: Vec<(
String,
@@ -378,7 +402,8 @@ impl Database {
for (content, metadata, embedding, keywords, expires_at) in entries {
let vector = Vector::from(embedding);
if let Some(existing) =
find_dedup_match(&transaction, agent_id, &vector, dedup_threshold).await?
find_dedup_match(&transaction, auth_scope, agent_id, &vector, dedup_threshold)
.await?
{
let merged_metadata = merge_metadata(&existing.metadata, &metadata);
let refreshed_expires_at = expires_at.or(existing.expires_at);
@@ -404,8 +429,8 @@ impl Database {
} else {
let id = Uuid::new_v4();
transaction.execute(
r#"INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7)"#,
&[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
r#"INSERT INTO memories (id, auth_scope, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#,
&[&id, &auth_scope, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
).await?;
results.push(BatchStoreResult {
id: id.to_string(),