mirror of
https://gitea.ingwaz.work/Ingwaz/openbrain-mcp.git
synced 2026-06-15 22:07:08 +00:00
New structs: - TruthScoreUpdate: parameters for updating truth scores - TruthStats: aggregated truth scoring statistics - ScoringCandidate: lightweight record for the scoring worker New Database methods: - get_unscored_memories(): fetch unscored memories FIFO - get_stale_memories(): fetch memories due for re-evaluation - update_truth_score(): update single memory truth fields - batch_update_truth_scores(): transactional batch update - get_truth_stats(): aggregate stats with category breakdown Uses partial index idx_memories_truth_unevaluated for efficient unscored memory queries. Part of #29
708 lines
24 KiB
Rust
708 lines
24 KiB
Rust
//! Database module for PostgreSQL with pgvector support
|
|
//!
|
|
//! Provides connection pooling and query helpers for vector operations.
|
|
|
|
use anyhow::{Context, Result};
|
|
use deadpool_postgres::{Config, GenericClient, Pool, Runtime};
|
|
use pgvector::Vector;
|
|
use tokio_postgres::NoTls;
|
|
use tracing::info;
|
|
use uuid::Uuid;
|
|
|
|
use crate::config::DatabaseConfig;
|
|
use serde::Serialize;
|
|
use serde_json::{Map, Value};
|
|
|
|
/// Database wrapper with connection pool
///
/// Thin handle around a `deadpool_postgres` [`Pool`]. Every query helper on this
/// type checks a connection out of the pool for the duration of a single call.
/// Derives `Clone` so the handle can be passed around. NOTE(review): this relies
/// on `Pool` clones referring to one shared pool, not independent pools; confirm
/// against the deadpool docs if a caller depends on it.
#[derive(Clone)]
pub struct Database {
    /// Connection pool built in [`Database::new`].
    pool: Pool,
}
|
|
|
|
/// A memory record stored in the database
///
/// Corresponds to a row of the `memories` table. The row's `auth_scope` column
/// is not carried here. [`Database::query_memories`] builds these with an empty
/// `embedding` and with every truth/ECAN field set to `None`.
#[derive(Debug, Clone)]
pub struct MemoryRecord {
    /// Primary key; generated with `Uuid::new_v4()` when the row is inserted.
    pub id: Uuid,
    /// Agent that stored the memory (provenance filter used by queries/purges).
    pub agent_id: String,
    /// The stored text content.
    pub content: String,
    /// Embedding vector; left empty when the record comes from a similarity query.
    pub embedding: Vec<f32>,
    /// Keywords stored alongside the content.
    pub keywords: Vec<String>,
    /// Free-form JSON metadata; shallow-merged when a duplicate write refreshes the row.
    pub metadata: serde_json::Value,
    /// Row timestamp; reset to `NOW()` when a duplicate write refreshes the row.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Optional expiry; rows at or past it are skipped by queries and deleted by cleanup.
    pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
    // Truth scoring fields (populated by background worker)
    /// Truth value assigned by the scoring worker, if scored.
    pub truth_value: Option<f32>,
    /// Confidence in `truth_value`, if scored.
    pub truth_confidence: Option<f32>,
    /// Truth category label. Presumably one of "verified", "plausible",
    /// "unverified" or "contradicted" (the labels the stats query counts); confirm.
    pub truth_category: Option<String>,
    /// When the truth score was last written; `None` means never scored.
    pub truth_evaluated_at: Option<chrono::DateTime<chrono::Utc>>,
    /// NOTE(review): presumably ECAN short-term importance; confirm with the worker.
    pub ecan_sti: Option<f32>,
    /// NOTE(review): presumably ECAN long-term importance; confirm with the worker.
    pub ecan_lti: Option<f32>,
}
|
|
|
|
/// Query result with similarity score
///
/// One row returned by [`Database::query_memories`], with the component scores
/// that produced its ranking.
#[derive(Debug, Clone)]
pub struct MemoryMatch {
    /// The matched memory (embedding and truth fields are not populated).
    pub record: MemoryRecord,
    /// Filled from the same column as `hybrid_score`.
    pub similarity: f32,
    /// Vector similarity computed as `1 - (embedding <=> query_embedding)`.
    pub vector_score: f32,
    /// `ts_rank` full-text score, or 0 when the query text does not match.
    pub text_score: f32,
    /// Weighted sum of `vector_score` and `text_score` when any candidate matched
    /// the text query; otherwise equal to `vector_score`.
    pub hybrid_score: f32,
}
|
|
|
|
/// Outcome of [`Database::store_memory`].
#[derive(Debug, Clone)]
pub struct StoreMemoryResult {
    /// Id of the inserted row, or of the existing row that absorbed the write.
    pub id: Uuid,
    /// `true` when an existing near-duplicate was refreshed instead of inserting.
    pub deduplicated: bool,
    /// Expiry now stored on the row (the new value, or the kept existing one).
    pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
}
|
|
|
|
/// Existing memory found by `find_dedup_match` that a new write duplicates.
///
/// Holds only the columns needed to refresh that row.
#[derive(Debug, Clone)]
struct DedupMatch {
    /// Id of the existing row.
    id: Uuid,
    /// Existing metadata, merged with the incoming metadata on refresh.
    metadata: Value,
    /// Existing expiry, kept when the incoming write supplies none.
    expires_at: Option<chrono::DateTime<chrono::Utc>>,
}
|
|
|
|
fn merge_metadata(existing: &Value, incoming: &Value) -> Value {
|
|
match (existing, incoming) {
|
|
(Value::Object(existing), Value::Object(incoming)) => {
|
|
let mut merged = Map::with_capacity(existing.len() + incoming.len());
|
|
for (key, value) in existing {
|
|
merged.insert(key.clone(), value.clone());
|
|
}
|
|
for (key, value) in incoming {
|
|
merged.insert(key.clone(), value.clone());
|
|
}
|
|
Value::Object(merged)
|
|
}
|
|
(_, Value::Null) => existing.clone(),
|
|
_ => incoming.clone(),
|
|
}
|
|
}
|
|
|
|
async fn find_dedup_match<C>(
|
|
client: &C,
|
|
auth_scope: &str,
|
|
agent_id: &str,
|
|
embedding: &Vector,
|
|
threshold: f64,
|
|
) -> Result<Option<DedupMatch>>
|
|
where
|
|
C: GenericClient + Sync,
|
|
{
|
|
let row = client
|
|
.query_opt(
|
|
r#"
|
|
SELECT id, metadata, expires_at
|
|
FROM memories
|
|
WHERE auth_scope = $1
|
|
AND agent_id = $2
|
|
AND (expires_at IS NULL OR expires_at > NOW())
|
|
AND (1 - (embedding <=> $3)) >= $4
|
|
ORDER BY (1 - (embedding <=> $3)) DESC, created_at DESC
|
|
LIMIT 1
|
|
"#,
|
|
&[&auth_scope, &agent_id, embedding, &threshold],
|
|
)
|
|
.await
|
|
.context("Failed to check for duplicate memory")?;
|
|
|
|
Ok(row.map(|row| DedupMatch {
|
|
id: row.get("id"),
|
|
metadata: row.get("metadata"),
|
|
expires_at: row.get("expires_at"),
|
|
}))
|
|
}
|
|
|
|
impl Database {
|
|
/// Create a new database connection pool
|
|
pub async fn new(config: &DatabaseConfig) -> Result<Self> {
|
|
let mut cfg = Config::new();
|
|
cfg.host = Some(config.host.clone());
|
|
cfg.port = Some(config.port);
|
|
cfg.dbname = Some(config.name.clone());
|
|
cfg.user = Some(config.user.clone());
|
|
cfg.password = Some(config.password.clone());
|
|
|
|
let pool = cfg
|
|
.create_pool(Some(Runtime::Tokio1), NoTls)
|
|
.context("Failed to create database pool")?;
|
|
|
|
// Test connection
|
|
let client = pool
|
|
.get()
|
|
.await
|
|
.context("Failed to get database connection")?;
|
|
client
|
|
.simple_query("SELECT 1")
|
|
.await
|
|
.context("Failed to execute test query")?;
|
|
|
|
info!(
|
|
"Database connection pool created with {} connections",
|
|
config.pool_size
|
|
);
|
|
|
|
Ok(Self { pool })
|
|
}
|
|
|
|
/// Store a memory record
|
|
pub async fn store_memory(
|
|
&self,
|
|
auth_scope: &str,
|
|
agent_id: &str,
|
|
content: &str,
|
|
embedding: &[f32],
|
|
keywords: &[String],
|
|
metadata: serde_json::Value,
|
|
expires_at: Option<chrono::DateTime<chrono::Utc>>,
|
|
dedup_threshold: f32,
|
|
) -> Result<StoreMemoryResult> {
|
|
let client = self.pool.get().await?;
|
|
let vector = Vector::from(embedding.to_vec());
|
|
let dedup_threshold = dedup_threshold as f64;
|
|
|
|
if let Some(existing) =
|
|
find_dedup_match(&client, auth_scope, agent_id, &vector, dedup_threshold).await?
|
|
{
|
|
let merged_metadata = merge_metadata(&existing.metadata, &metadata);
|
|
let refreshed_expires_at = expires_at.or(existing.expires_at);
|
|
|
|
client
|
|
.execute(
|
|
r#"
|
|
UPDATE memories
|
|
SET metadata = $2,
|
|
created_at = NOW(),
|
|
expires_at = $3
|
|
WHERE id = $1
|
|
"#,
|
|
&[&existing.id, &merged_metadata, &refreshed_expires_at],
|
|
)
|
|
.await
|
|
.context("Failed to update deduplicated memory")?;
|
|
|
|
return Ok(StoreMemoryResult {
|
|
id: existing.id,
|
|
deduplicated: true,
|
|
expires_at: refreshed_expires_at,
|
|
});
|
|
}
|
|
|
|
let id = Uuid::new_v4();
|
|
|
|
client
|
|
.execute(
|
|
r#"
|
|
INSERT INTO memories (id, auth_scope, agent_id, content, embedding, keywords, metadata, expires_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
|
"#,
|
|
&[&id, &auth_scope, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
|
|
)
|
|
.await
|
|
.context("Failed to store memory")?;
|
|
|
|
Ok(StoreMemoryResult {
|
|
id,
|
|
deduplicated: false,
|
|
expires_at,
|
|
})
|
|
}
|
|
|
|
/// Query memories by vector similarity
|
|
pub async fn query_memories(
|
|
&self,
|
|
auth_scope: &str,
|
|
source_agent_id: Option<&str>,
|
|
query_text: &str,
|
|
embedding: &[f32],
|
|
limit: i64,
|
|
threshold: f32,
|
|
vector_weight: f32,
|
|
text_weight: f32,
|
|
) -> Result<Vec<MemoryMatch>> {
|
|
let client = self.pool.get().await?;
|
|
let vector = Vector::from(embedding.to_vec());
|
|
let rows = client
|
|
.query(
|
|
r#"
|
|
WITH search_query AS (
|
|
SELECT
|
|
NULLIF(plainto_tsquery('pg_catalog.english', $2)::text, '') AS query_text,
|
|
plainto_tsquery('pg_catalog.english', $2) AS ts_query
|
|
),
|
|
scored AS (
|
|
SELECT
|
|
id,
|
|
agent_id,
|
|
content,
|
|
keywords,
|
|
metadata,
|
|
created_at,
|
|
expires_at,
|
|
(1 - (embedding <=> $1))::real AS vector_score,
|
|
CASE
|
|
WHEN search_query.query_text IS NULL THEN 0::real
|
|
WHEN memories.tsv @@ search_query.ts_query
|
|
THEN ts_rank(memories.tsv, search_query.ts_query, 32)::real
|
|
ELSE 0::real
|
|
END AS text_score
|
|
FROM memories
|
|
CROSS JOIN search_query
|
|
WHERE memories.auth_scope = $3
|
|
AND ($4::text IS NULL OR memories.agent_id = $4)
|
|
AND (memories.expires_at IS NULL OR memories.expires_at > NOW())
|
|
),
|
|
ranked AS (
|
|
SELECT
|
|
*,
|
|
MAX(CASE WHEN text_score > 0 THEN 1 ELSE 0 END) OVER () AS has_text_match
|
|
FROM scored
|
|
)
|
|
SELECT
|
|
id,
|
|
agent_id,
|
|
content,
|
|
keywords,
|
|
metadata,
|
|
created_at,
|
|
expires_at,
|
|
vector_score,
|
|
text_score,
|
|
CASE
|
|
WHEN has_text_match = 1
|
|
THEN (($6 * vector_score) + ($7 * text_score))::real
|
|
ELSE vector_score
|
|
END AS hybrid_score
|
|
FROM ranked
|
|
WHERE vector_score >= $5 OR text_score > 0
|
|
ORDER BY hybrid_score DESC, vector_score DESC
|
|
LIMIT $8
|
|
"#,
|
|
&[
|
|
&vector,
|
|
&query_text,
|
|
&auth_scope,
|
|
&source_agent_id,
|
|
&threshold,
|
|
&vector_weight,
|
|
&text_weight,
|
|
&limit,
|
|
],
|
|
)
|
|
.await
|
|
.context("Failed to query memories")?;
|
|
|
|
let matches = rows
|
|
.iter()
|
|
.map(|row| MemoryMatch {
|
|
record: MemoryRecord {
|
|
id: row.get("id"),
|
|
agent_id: row.get("agent_id"),
|
|
content: row.get("content"),
|
|
// Query responses do not include raw embedding payloads.
|
|
embedding: Vec::new(),
|
|
keywords: row.get("keywords"),
|
|
metadata: row.get("metadata"),
|
|
created_at: row.get("created_at"),
|
|
expires_at: row.get("expires_at"),
|
|
// Truth fields will be populated by issue #39
|
|
truth_value: None,
|
|
truth_confidence: None,
|
|
truth_category: None,
|
|
truth_evaluated_at: None,
|
|
ecan_sti: None,
|
|
ecan_lti: None,
|
|
},
|
|
similarity: row.get("hybrid_score"),
|
|
vector_score: row.get("vector_score"),
|
|
text_score: row.get("text_score"),
|
|
hybrid_score: row.get("hybrid_score"),
|
|
})
|
|
.collect();
|
|
|
|
Ok(matches)
|
|
}
|
|
|
|
/// Delete memories visible to an auth scope with an optional provenance filter
|
|
pub async fn purge_memories(
|
|
&self,
|
|
auth_scope: &str,
|
|
source_agent_id: Option<&str>,
|
|
before: Option<chrono::DateTime<chrono::Utc>>,
|
|
) -> Result<u64> {
|
|
let client = self.pool.get().await?;
|
|
|
|
let count = client
|
|
.execute(
|
|
r#"
|
|
DELETE FROM memories
|
|
WHERE auth_scope = $1
|
|
AND ($2::text IS NULL OR agent_id = $2)
|
|
AND ($3::timestamptz IS NULL OR created_at < $3)
|
|
"#,
|
|
&[&auth_scope, &source_agent_id, &before],
|
|
)
|
|
.await?;
|
|
|
|
Ok(count)
|
|
}
|
|
|
|
/// Get memory count for a token-visible scope and optional provenance filter
|
|
pub async fn count_memories(
|
|
&self,
|
|
auth_scope: &str,
|
|
source_agent_id: Option<&str>,
|
|
) -> Result<i64> {
|
|
let client = self.pool.get().await?;
|
|
let row = client
|
|
.query_one(
|
|
r#"
|
|
SELECT COUNT(*) as count
|
|
FROM memories
|
|
WHERE auth_scope = $1
|
|
AND ($2::text IS NULL OR agent_id = $2)
|
|
AND (expires_at IS NULL OR expires_at > NOW())
|
|
"#,
|
|
&[&auth_scope, &source_agent_id],
|
|
)
|
|
.await?;
|
|
Ok(row.get("count"))
|
|
}
|
|
|
|
/// Delete expired memories across all agents
|
|
pub async fn cleanup_expired_memories(&self) -> Result<u64> {
|
|
let client = self.pool.get().await?;
|
|
let deleted = client
|
|
.execute(
|
|
"DELETE FROM memories WHERE expires_at IS NOT NULL AND expires_at <= NOW()",
|
|
&[],
|
|
)
|
|
.await
|
|
.context("Failed to cleanup expired memories")?;
|
|
Ok(deleted)
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Truth scoring database helpers (Issue #35)
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Parameters for updating truth scores on a memory.
///
/// Written by the truth-score update helpers. The database itself stamps
/// `truth_evaluated_at` with `NOW()`, so that field is not part of this struct.
#[derive(Debug, Clone)]
pub struct TruthScoreUpdate {
    /// Id of the memory row to update.
    pub id: Uuid,
    /// New truth value.
    pub truth_value: f32,
    /// Confidence in `truth_value`.
    pub truth_confidence: f32,
    /// Category label. Presumably one of "verified", "plausible", "unverified"
    /// or "contradicted" (the labels counted in [`TruthStats`]); confirm.
    pub truth_category: String,
    /// NOTE(review): presumably ECAN short-term importance; confirm.
    pub ecan_sti: f32,
    /// NOTE(review): presumably ECAN long-term importance; confirm.
    pub ecan_lti: f32,
}
|
|
|
|
/// Aggregated truth scoring statistics.
///
/// Computed over unexpired memories across all auth scopes.
#[derive(Debug, Clone, Serialize)]
pub struct TruthStats {
    /// Number of unexpired memories.
    pub total_memories: i64,
    /// Memories with a non-null `truth_evaluated_at`.
    pub scored_memories: i64,
    /// `total_memories - scored_memories`.
    pub unscored_memories: i64,
    /// Memories whose `truth_category` is exactly "verified".
    pub category_verified: i64,
    /// Memories whose `truth_category` is exactly "plausible".
    pub category_plausible: i64,
    /// Memories whose `truth_category` is exactly "unverified".
    pub category_unverified: i64,
    /// Memories whose `truth_category` is exactly "contradicted".
    pub category_contradicted: i64,
    /// Mean of non-null `truth_value`s; `None` when there are none.
    pub avg_truth_value: Option<f64>,
    /// Mean of non-null `truth_confidence`s; `None` when there are none.
    pub avg_confidence: Option<f64>,
    /// `scored_memories / total_memories * 100`, or `0.0` when there are no memories.
    pub coverage_pct: f64,
}
|
|
|
|
/// A lightweight memory record for the truth scoring worker.
///
/// Carries only the `memories` columns the scoring fetch queries select. It
/// omits:
/// - auth scope and agent;
/// - keywords;
/// - expiry;
/// - truth category.
///
/// It DOES include the full embedding, because both fetch queries always select
/// that column.
#[derive(Debug, Clone)]
pub struct ScoringCandidate {
    /// Id of the memory row.
    pub id: Uuid,
    /// Stored text content.
    pub content: String,
    /// Full embedding vector, converted from the pgvector column.
    pub embedding: Vec<f32>,
    /// Free-form JSON metadata.
    pub metadata: serde_json::Value,
    /// Row timestamp (reset when a duplicate write refreshes the memory).
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Existing truth value, if previously scored.
    pub truth_value: Option<f32>,
    /// Existing truth confidence, if previously scored.
    pub truth_confidence: Option<f32>,
    /// Existing `ecan_sti` value, if any.
    pub ecan_sti: Option<f32>,
    /// Existing `ecan_lti` value, if any.
    pub ecan_lti: Option<f32>,
}
|
|
|
|
impl Database {
|
|
/// Fetch memories that have never been truth-scored.
|
|
///
|
|
/// Returns up to `limit` memories ordered by creation time (oldest first),
|
|
/// so the worker processes memories in FIFO order.
|
|
pub async fn get_unscored_memories(&self, limit: i64) -> Result<Vec<ScoringCandidate>> {
|
|
let client = self.pool.get().await?;
|
|
let rows = client
|
|
.query(
|
|
r#"
|
|
SELECT id, content, embedding, metadata, created_at,
|
|
truth_value, truth_confidence, ecan_sti, ecan_lti
|
|
FROM memories
|
|
WHERE truth_evaluated_at IS NULL
|
|
AND (expires_at IS NULL OR expires_at > NOW())
|
|
ORDER BY created_at ASC
|
|
LIMIT $1
|
|
"#,
|
|
&[&limit],
|
|
)
|
|
.await
|
|
.context("Failed to fetch unscored memories")?;
|
|
|
|
Ok(rows
|
|
.iter()
|
|
.map(|row| {
|
|
let pgvec: Vector = row.get("embedding");
|
|
ScoringCandidate {
|
|
id: row.get("id"),
|
|
content: row.get("content"),
|
|
embedding: pgvec.to_vec(),
|
|
metadata: row.get("metadata"),
|
|
created_at: row.get("created_at"),
|
|
truth_value: row.get("truth_value"),
|
|
truth_confidence: row.get("truth_confidence"),
|
|
ecan_sti: row.get("ecan_sti"),
|
|
ecan_lti: row.get("ecan_lti"),
|
|
}
|
|
})
|
|
.collect())
|
|
}
|
|
|
|
/// Fetch memories whose truth score is stale (evaluated more than
|
|
/// `older_than_seconds` ago).
|
|
pub async fn get_stale_memories(
|
|
&self,
|
|
older_than_seconds: i64,
|
|
limit: i64,
|
|
) -> Result<Vec<ScoringCandidate>> {
|
|
let client = self.pool.get().await?;
|
|
let rows = client
|
|
.query(
|
|
r#"
|
|
SELECT id, content, embedding, metadata, created_at,
|
|
truth_value, truth_confidence, ecan_sti, ecan_lti
|
|
FROM memories
|
|
WHERE truth_evaluated_at IS NOT NULL
|
|
AND truth_evaluated_at < NOW() - ($1 || ' seconds')::interval
|
|
AND (expires_at IS NULL OR expires_at > NOW())
|
|
ORDER BY truth_evaluated_at ASC
|
|
LIMIT $2
|
|
"#,
|
|
&[&older_than_seconds.to_string(), &limit],
|
|
)
|
|
.await
|
|
.context("Failed to fetch stale memories")?;
|
|
|
|
Ok(rows
|
|
.iter()
|
|
.map(|row| {
|
|
let pgvec: Vector = row.get("embedding");
|
|
ScoringCandidate {
|
|
id: row.get("id"),
|
|
content: row.get("content"),
|
|
embedding: pgvec.to_vec(),
|
|
metadata: row.get("metadata"),
|
|
created_at: row.get("created_at"),
|
|
truth_value: row.get("truth_value"),
|
|
truth_confidence: row.get("truth_confidence"),
|
|
ecan_sti: row.get("ecan_sti"),
|
|
ecan_lti: row.get("ecan_lti"),
|
|
}
|
|
})
|
|
.collect())
|
|
}
|
|
|
|
/// Update truth scores for a single memory.
|
|
pub async fn update_truth_score(&self, update: &TruthScoreUpdate) -> Result<()> {
|
|
let client = self.pool.get().await?;
|
|
client
|
|
.execute(
|
|
r#"
|
|
UPDATE memories
|
|
SET truth_value = $2,
|
|
truth_confidence = $3,
|
|
truth_category = $4,
|
|
truth_evaluated_at = NOW(),
|
|
ecan_sti = $5,
|
|
ecan_lti = $6
|
|
WHERE id = $1
|
|
"#,
|
|
&[
|
|
&update.id,
|
|
&update.truth_value,
|
|
&update.truth_confidence,
|
|
&update.truth_category,
|
|
&update.ecan_sti,
|
|
&update.ecan_lti,
|
|
],
|
|
)
|
|
.await
|
|
.context("Failed to update truth score")?;
|
|
Ok(())
|
|
}
|
|
|
|
/// Batch update truth scores in a single transaction.
|
|
pub async fn batch_update_truth_scores(&self, updates: &[TruthScoreUpdate]) -> Result<usize> {
|
|
if updates.is_empty() {
|
|
return Ok(0);
|
|
}
|
|
let mut client = self.pool.get().await?;
|
|
let transaction = client.transaction().await?;
|
|
|
|
for update in updates {
|
|
transaction
|
|
.execute(
|
|
r#"
|
|
UPDATE memories
|
|
SET truth_value = $2,
|
|
truth_confidence = $3,
|
|
truth_category = $4,
|
|
truth_evaluated_at = NOW(),
|
|
ecan_sti = $5,
|
|
ecan_lti = $6
|
|
WHERE id = $1
|
|
"#,
|
|
&[
|
|
&update.id,
|
|
&update.truth_value,
|
|
&update.truth_confidence,
|
|
&update.truth_category,
|
|
&update.ecan_sti,
|
|
&update.ecan_lti,
|
|
],
|
|
)
|
|
.await
|
|
.context("Failed to update truth score in batch")?;
|
|
}
|
|
|
|
transaction.commit().await?;
|
|
Ok(updates.len())
|
|
}
|
|
|
|
/// Get aggregated truth scoring statistics.
|
|
pub async fn get_truth_stats(&self) -> Result<TruthStats> {
|
|
let client = self.pool.get().await?;
|
|
let row = client
|
|
.query_one(
|
|
r#"
|
|
SELECT
|
|
COUNT(*) AS total,
|
|
COUNT(truth_evaluated_at) AS scored,
|
|
COUNT(*) - COUNT(truth_evaluated_at) AS unscored,
|
|
COUNT(*) FILTER (WHERE truth_category = 'verified') AS cat_verified,
|
|
COUNT(*) FILTER (WHERE truth_category = 'plausible') AS cat_plausible,
|
|
COUNT(*) FILTER (WHERE truth_category = 'unverified') AS cat_unverified,
|
|
COUNT(*) FILTER (WHERE truth_category = 'contradicted') AS cat_contradicted,
|
|
AVG(truth_value) FILTER (WHERE truth_value IS NOT NULL) AS avg_tv,
|
|
AVG(truth_confidence) FILTER (WHERE truth_confidence IS NOT NULL) AS avg_conf
|
|
FROM memories
|
|
WHERE expires_at IS NULL OR expires_at > NOW()
|
|
"#,
|
|
&[],
|
|
)
|
|
.await
|
|
.context("Failed to get truth stats")?;
|
|
|
|
let total: i64 = row.get("total");
|
|
let scored: i64 = row.get("scored");
|
|
let coverage_pct = if total > 0 {
|
|
(scored as f64 / total as f64) * 100.0
|
|
} else {
|
|
0.0
|
|
};
|
|
|
|
Ok(TruthStats {
|
|
total_memories: total,
|
|
scored_memories: scored,
|
|
unscored_memories: row.get("unscored"),
|
|
category_verified: row.get("cat_verified"),
|
|
category_plausible: row.get("cat_plausible"),
|
|
category_unverified: row.get("cat_unverified"),
|
|
category_contradicted: row.get("cat_contradicted"),
|
|
avg_truth_value: row.get("avg_tv"),
|
|
avg_confidence: row.get("avg_conf"),
|
|
coverage_pct,
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Result for a single batch entry
///
/// One per input entry of `batch_store_memories`, in input order.
#[derive(Debug, Clone, Serialize)]
pub struct BatchStoreResult {
    /// Id (stringified UUID) of the inserted or refreshed memory.
    pub id: String,
    /// Either "stored" (new row) or "deduplicated" (existing row refreshed).
    pub status: String,
    /// `true` when `status` is "deduplicated".
    pub deduplicated: bool,
    /// Expiry now stored on the row, formatted with `to_rfc3339()`; `None` if none.
    pub expires_at: Option<String>,
}
|
|
|
|
impl Database {
|
|
/// Store multiple memories in a single transaction
|
|
pub async fn batch_store_memories(
|
|
&self,
|
|
auth_scope: &str,
|
|
agent_id: &str,
|
|
entries: Vec<(
|
|
String,
|
|
Value,
|
|
Vec<f32>,
|
|
Vec<String>,
|
|
Option<chrono::DateTime<chrono::Utc>>,
|
|
)>,
|
|
dedup_threshold: f32,
|
|
) -> Result<Vec<BatchStoreResult>> {
|
|
let mut client = self.pool.get().await?;
|
|
let transaction = client.transaction().await?;
|
|
let mut results = Vec::with_capacity(entries.len());
|
|
let dedup_threshold = dedup_threshold as f64;
|
|
|
|
for (content, metadata, embedding, keywords, expires_at) in entries {
|
|
let vector = Vector::from(embedding);
|
|
if let Some(existing) =
|
|
find_dedup_match(&transaction, auth_scope, agent_id, &vector, dedup_threshold)
|
|
.await?
|
|
{
|
|
let merged_metadata = merge_metadata(&existing.metadata, &metadata);
|
|
let refreshed_expires_at = expires_at.or(existing.expires_at);
|
|
transaction
|
|
.execute(
|
|
r#"
|
|
UPDATE memories
|
|
SET metadata = $2,
|
|
created_at = NOW(),
|
|
expires_at = $3
|
|
WHERE id = $1
|
|
"#,
|
|
&[&existing.id, &merged_metadata, &refreshed_expires_at],
|
|
)
|
|
.await
|
|
.context("Failed to update deduplicated batch memory")?;
|
|
results.push(BatchStoreResult {
|
|
id: existing.id.to_string(),
|
|
status: "deduplicated".to_string(),
|
|
deduplicated: true,
|
|
expires_at: refreshed_expires_at.map(|ts| ts.to_rfc3339()),
|
|
});
|
|
} else {
|
|
let id = Uuid::new_v4();
|
|
transaction.execute(
|
|
r#"INSERT INTO memories (id, auth_scope, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"#,
|
|
&[&id, &auth_scope, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
|
|
).await?;
|
|
results.push(BatchStoreResult {
|
|
id: id.to_string(),
|
|
status: "stored".to_string(),
|
|
deduplicated: false,
|
|
expires_at: expires_at.map(|ts| ts.to_rfc3339()),
|
|
});
|
|
}
|
|
}
|
|
transaction.commit().await?;
|
|
Ok(results)
|
|
}
|
|
}
|