// diff --git a/src/db.rs b/src/db.rs
// index ae3447b..eaa4499 100644
// --- a/src/db.rs
// +++ b/src/db.rs
// @@ -384,6 +384,253 @@ impl Database {
//
// (Patch header kept for provenance. The two closing braces below end the
// preceding method and `impl Database` block, which start outside this chunk.)
    }
}

// ---------------------------------------------------------------------------
// Truth scoring database helpers (Issue #35)
// ---------------------------------------------------------------------------

/// Statement used by both the single-row and the batch truth score update.
///
/// `$1` is the memory id; `$2`..`$6` are the new truth value, confidence,
/// category, `ecan_sti` and `ecan_lti`. `truth_evaluated_at` is stamped with
/// the database clock (`NOW()`), not a client-supplied time.
const UPDATE_TRUTH_SCORE_SQL: &str = r#"
    UPDATE memories
    SET truth_value = $2,
        truth_confidence = $3,
        truth_category = $4,
        truth_evaluated_at = NOW(),
        ecan_sti = $5,
        ecan_lti = $6
    WHERE id = $1
"#;

/// Parameters for updating truth scores on a memory.
#[derive(Debug, Clone)]
pub struct TruthScoreUpdate {
    /// Primary key of the memory row to update.
    pub id: Uuid,
    /// New value for the `truth_value` column.
    pub truth_value: f32,
    /// New value for the `truth_confidence` column.
    pub truth_confidence: f32,
    /// New value for the `truth_category` column. The statistics query
    /// counts the values `verified`, `plausible`, `unverified` and
    /// `contradicted`.
    pub truth_category: String,
    /// New value for the `ecan_sti` column.
    pub ecan_sti: f32,
    /// New value for the `ecan_lti` column.
    pub ecan_lti: f32,
}

/// Aggregated truth scoring statistics.
///
/// All counts cover non-expired memories only.
#[derive(Debug, Clone, Serialize)]
pub struct TruthStats {
    /// Number of non-expired memories.
    pub total_memories: i64,
    /// Memories with a non-NULL `truth_evaluated_at`.
    pub scored_memories: i64,
    /// `total_memories - scored_memories`, computed by the database.
    pub unscored_memories: i64,
    /// Memories whose category is `verified`.
    pub category_verified: i64,
    /// Memories whose category is `plausible`.
    pub category_plausible: i64,
    /// Memories whose category is `unverified`.
    pub category_unverified: i64,
    /// Memories whose category is `contradicted`.
    pub category_contradicted: i64,
    /// Mean of the non-NULL `truth_value`s; `None` when there are none.
    pub avg_truth_value: Option<f64>,
    /// Mean of the non-NULL `truth_confidence`s; `None` when there are none.
    pub avg_confidence: Option<f64>,
    /// `scored_memories / total_memories * 100`, or `0.0` when there are no
    /// memories at all.
    pub coverage_pct: f64,
}

/// A lightweight memory record for the truth scoring worker.
/// Contains only the fields needed for scoring (avoids fetching full embeddings
/// unless cross-referencing requires them).
#[derive(Debug, Clone)]
pub struct ScoringCandidate {
    pub id: Uuid,
    pub content: String,
    /// Embedding converted from the stored `Vector` via `to_vec()`.
    pub embedding: Vec<f32>,
    pub metadata: serde_json::Value,
    // NOTE(review): the timezone parameter was missing in the pasted source and
    // is reconstructed as `Utc`; confirm `created_at` is a TIMESTAMPTZ column.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Existing truth value, if previously scored.
    pub truth_value: Option<f32>,
    /// Existing truth confidence, if previously scored.
    pub truth_confidence: Option<f32>,
    /// Existing `ecan_sti`, if set.
    pub ecan_sti: Option<f32>,
    /// Existing `ecan_lti`, if set.
    pub ecan_lti: Option<f32>,
}

impl Database {
    /// Fetch memories that have never been truth-scored.
    ///
    /// Returns up to `limit` memories ordered by creation time (oldest first),
    /// so the worker processes memories in FIFO order. Expired memories are
    /// skipped.
    ///
    /// # Errors
    ///
    /// Fails if a connection cannot be obtained, the query fails, or a row
    /// cannot be decoded (for example a NULL in a column read as non-optional).
    pub async fn get_unscored_memories(&self, limit: i64) -> Result<Vec<ScoringCandidate>> {
        let client = self.pool.get().await?;
        let rows = client
            .query(
                r#"
                SELECT id, content, embedding, metadata, created_at,
                       truth_value, truth_confidence, ecan_sti, ecan_lti
                FROM memories
                WHERE truth_evaluated_at IS NULL
                  AND (expires_at IS NULL OR expires_at > NOW())
                ORDER BY created_at ASC
                LIMIT $1
                "#,
                &[&limit],
            )
            .await
            .context("Failed to fetch unscored memories")?;

        // `try_get` reports an undecodable column as an error; `get` would
        // panic and take the calling task down with it.
        rows.iter()
            .map(|row| -> Result<ScoringCandidate> {
                let embedding: Vector = row.try_get("embedding")?;
                Ok(ScoringCandidate {
                    id: row.try_get("id")?,
                    content: row.try_get("content")?,
                    embedding: embedding.to_vec(),
                    metadata: row.try_get("metadata")?,
                    created_at: row.try_get("created_at")?,
                    truth_value: row.try_get("truth_value")?,
                    truth_confidence: row.try_get("truth_confidence")?,
                    ecan_sti: row.try_get("ecan_sti")?,
                    ecan_lti: row.try_get("ecan_lti")?,
                })
            })
            .collect::<Result<Vec<_>>>()
            .context("Failed to decode unscored memory row")
    }

    /// Fetch memories whose truth score is stale (evaluated more than
    /// `older_than_seconds` ago).
    ///
    /// Returns up to `limit` non-expired memories, least recently evaluated
    /// first.
    ///
    /// # Errors
    ///
    /// Fails if a connection cannot be obtained, the query fails, or a row
    /// cannot be decoded.
    pub async fn get_stale_memories(
        &self,
        older_than_seconds: i64,
        limit: i64,
    ) -> Result<Vec<ScoringCandidate>> {
        let client = self.pool.get().await?;
        let rows = client
            .query(
                r#"
                SELECT id, content, embedding, metadata, created_at,
                       truth_value, truth_confidence, ecan_sti, ecan_lti
                FROM memories
                WHERE truth_evaluated_at IS NOT NULL
                  AND truth_evaluated_at < NOW() - ($1 || ' seconds')::interval
                  AND (expires_at IS NULL OR expires_at > NOW())
                ORDER BY truth_evaluated_at ASC
                LIMIT $2
                "#,
                // The query concatenates `$1` with ' seconds' to build the
                // interval, so the age is bound as text rather than as an integer.
                &[&older_than_seconds.to_string(), &limit],
            )
            .await
            .context("Failed to fetch stale memories")?;

        // `try_get` reports an undecodable column as an error; `get` would
        // panic and take the calling task down with it.
        rows.iter()
            .map(|row| -> Result<ScoringCandidate> {
                let embedding: Vector = row.try_get("embedding")?;
                Ok(ScoringCandidate {
                    id: row.try_get("id")?,
                    content: row.try_get("content")?,
                    embedding: embedding.to_vec(),
                    metadata: row.try_get("metadata")?,
                    created_at: row.try_get("created_at")?,
                    truth_value: row.try_get("truth_value")?,
                    truth_confidence: row.try_get("truth_confidence")?,
                    ecan_sti: row.try_get("ecan_sti")?,
                    ecan_lti: row.try_get("ecan_lti")?,
                })
            })
            .collect::<Result<Vec<_>>>()
            .context("Failed to decode stale memory row")
    }

    /// Update truth scores for a single memory.
    ///
    /// Also stamps `truth_evaluated_at` with the database's current time. An
    /// `id` that matches no row is not treated as an error.
    ///
    /// # Errors
    ///
    /// Fails if a connection cannot be obtained or the statement fails.
    pub async fn update_truth_score(&self, update: &TruthScoreUpdate) -> Result<()> {
        let client = self.pool.get().await?;
        client
            .execute(
                UPDATE_TRUTH_SCORE_SQL,
                &[
                    &update.id,
                    &update.truth_value,
                    &update.truth_confidence,
                    &update.truth_category,
                    &update.ecan_sti,
                    &update.ecan_lti,
                ],
            )
            .await
            .context("Failed to update truth score")?;
        Ok(())
    }

    /// Batch update truth scores in a single transaction.
    ///
    /// Returns the number of updates submitted (`updates.len()`), which is not
    /// necessarily the number of rows changed: an `id` that matches no row
    /// still counts. An empty slice returns `Ok(0)` without touching the pool.
    ///
    /// # Errors
    ///
    /// Fails if a connection or transaction cannot be obtained, or if
    /// preparing, executing or committing fails. The function returns before
    /// `commit` on any such failure, so no update from the batch is committed.
    pub async fn batch_update_truth_scores(&self, updates: &[TruthScoreUpdate]) -> Result<usize> {
        if updates.is_empty() {
            return Ok(0);
        }
        let mut client = self.pool.get().await?;
        let transaction = client.transaction().await?;

        // Prepare once so the statement text is not re-sent for every row.
        let statement = transaction
            .prepare(UPDATE_TRUTH_SCORE_SQL)
            .await
            .context("Failed to prepare truth score update for batch")?;

        for update in updates {
            transaction
                .execute(
                    &statement,
                    &[
                        &update.id,
                        &update.truth_value,
                        &update.truth_confidence,
                        &update.truth_category,
                        &update.ecan_sti,
                        &update.ecan_lti,
                    ],
                )
                .await
                .context("Failed to update truth score in batch")?;
        }

        transaction.commit().await?;
        Ok(updates.len())
    }

    /// Get aggregated truth scoring statistics.
    ///
    /// Runs one aggregate query over all non-expired memories and derives
    /// `coverage_pct` from the scored and total counts.
    ///
    /// # Errors
    ///
    /// Fails if a connection cannot be obtained, the query fails, or a result
    /// column cannot be decoded into the expected type.
    pub async fn get_truth_stats(&self) -> Result<TruthStats> {
        let client = self.pool.get().await?;
        let row = client
            .query_one(
                r#"
                SELECT
                    COUNT(*) AS total,
                    COUNT(truth_evaluated_at) AS scored,
                    COUNT(*) - COUNT(truth_evaluated_at) AS unscored,
                    COUNT(*) FILTER (WHERE truth_category = 'verified') AS cat_verified,
                    COUNT(*) FILTER (WHERE truth_category = 'plausible') AS cat_plausible,
                    COUNT(*) FILTER (WHERE truth_category = 'unverified') AS cat_unverified,
                    COUNT(*) FILTER (WHERE truth_category = 'contradicted') AS cat_contradicted,
                    AVG(truth_value) FILTER (WHERE truth_value IS NOT NULL) AS avg_tv,
                    AVG(truth_confidence) FILTER (WHERE truth_confidence IS NOT NULL) AS avg_conf
                FROM memories
                WHERE expires_at IS NULL OR expires_at > NOW()
                "#,
                &[],
            )
            .await
            .context("Failed to get truth stats")?;

        let total: i64 = row.try_get("total")?;
        let scored: i64 = row.try_get("scored")?;
        // Guard the division: with no memories the coverage is reported as 0%.
        let coverage_pct = if total > 0 {
            (scored as f64 / total as f64) * 100.0
        } else {
            0.0
        };

        Ok(TruthStats {
            total_memories: total,
            scored_memories: scored,
            unscored_memories: row.try_get("unscored")?,
            category_verified: row.try_get("cat_verified")?,
            category_plausible: row.try_get("cat_plausible")?,
            category_unverified: row.try_get("cat_unverified")?,
            category_contradicted: row.try_get("cat_contradicted")?,
            avg_truth_value: row.try_get("avg_tv")?,
            avg_confidence: row.try_get("avg_conf")?,
            coverage_pct,
        })
    }
}

/// Result for a single batch entry
#[derive(Debug, Clone, Serialize)]
pub struct BatchStoreResult {