//! Database module for PostgreSQL with pgvector support //! //! Provides connection pooling and query helpers for vector operations. use anyhow::{Context, Result}; use deadpool_postgres::{Config, Pool, Runtime}; use pgvector::Vector; use tokio_postgres::NoTls; use tracing::info; use uuid::Uuid; use serde::Serialize; use serde_json::Value; use crate::config::DatabaseConfig; /// Database wrapper with connection pool #[derive(Clone)] pub struct Database { pool: Pool, } /// A memory record stored in the database #[derive(Debug, Clone)] pub struct MemoryRecord { pub id: Uuid, pub agent_id: String, pub content: String, pub embedding: Vec, pub keywords: Vec, pub metadata: serde_json::Value, pub created_at: chrono::DateTime, pub expires_at: Option>, } /// Query result with similarity score #[derive(Debug, Clone)] pub struct MemoryMatch { pub record: MemoryRecord, pub similarity: f32, pub vector_score: f32, pub text_score: f32, pub hybrid_score: f32, } impl Database { /// Create a new database connection pool pub async fn new(config: &DatabaseConfig) -> Result { let mut cfg = Config::new(); cfg.host = Some(config.host.clone()); cfg.port = Some(config.port); cfg.dbname = Some(config.name.clone()); cfg.user = Some(config.user.clone()); cfg.password = Some(config.password.clone()); let pool = cfg .create_pool(Some(Runtime::Tokio1), NoTls) .context("Failed to create database pool")?; // Test connection let client = pool.get().await.context("Failed to get database connection")?; client .simple_query("SELECT 1") .await .context("Failed to execute test query")?; info!("Database connection pool created with {} connections", config.pool_size); Ok(Self { pool }) } /// Store a memory record pub async fn store_memory( &self, agent_id: &str, content: &str, embedding: &[f32], keywords: &[String], metadata: serde_json::Value, expires_at: Option>, ) -> Result { let client = self.pool.get().await?; let id = Uuid::new_v4(); let vector = Vector::from(embedding.to_vec()); client .execute( r#" INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7) "#, &[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at], ) .await .context("Failed to store memory")?; Ok(id) } /// Query memories by vector similarity pub async fn query_memories( &self, agent_id: &str, query_text: &str, embedding: &[f32], limit: i64, threshold: f32, vector_weight: f32, text_weight: f32, ) -> Result> { let client = self.pool.get().await?; let vector = Vector::from(embedding.to_vec()); let rows = client .query( r#" WITH search_query AS ( SELECT NULLIF(plainto_tsquery('pg_catalog.english', $2)::text, '') AS query_text, plainto_tsquery('pg_catalog.english', $2) AS ts_query ), scored AS ( SELECT id, agent_id, content, keywords, metadata, created_at, expires_at, (1 - (embedding <=> $1))::real AS vector_score, CASE WHEN search_query.query_text IS NULL THEN 0::real WHEN memories.tsv @@ search_query.ts_query THEN ts_rank(memories.tsv, search_query.ts_query, 32)::real ELSE 0::real END AS text_score FROM memories CROSS JOIN search_query WHERE memories.agent_id = $3 AND (memories.expires_at IS NULL OR memories.expires_at > NOW()) ), ranked AS ( SELECT *, MAX(CASE WHEN text_score > 0 THEN 1 ELSE 0 END) OVER () AS has_text_match FROM scored ) SELECT id, agent_id, content, keywords, metadata, created_at, expires_at, vector_score, text_score, CASE WHEN has_text_match = 1 THEN (($5 * vector_score) + ($6 * text_score))::real ELSE vector_score END AS hybrid_score FROM ranked WHERE vector_score >= $4 OR text_score > 0 ORDER BY hybrid_score DESC, vector_score DESC LIMIT $7 "#, &[ &vector, &query_text, &agent_id, &threshold, &vector_weight, &text_weight, &limit, ], ) .await .context("Failed to query memories")?; let matches = rows .iter() .map(|row| MemoryMatch { record: MemoryRecord { id: row.get("id"), agent_id: row.get("agent_id"), content: row.get("content"), // Query responses do not include raw embedding payloads. embedding: Vec::new(), keywords: row.get("keywords"), metadata: row.get("metadata"), created_at: row.get("created_at"), expires_at: row.get("expires_at"), }, similarity: row.get("hybrid_score"), vector_score: row.get("vector_score"), text_score: row.get("text_score"), hybrid_score: row.get("hybrid_score"), }) .collect(); Ok(matches) } /// Delete memories by agent_id and optional filters pub async fn purge_memories( &self, agent_id: &str, before: Option>, ) -> Result { let client = self.pool.get().await?; let count = if let Some(before_ts) = before { client .execute( "DELETE FROM memories WHERE agent_id = $1 AND created_at < $2", &[&agent_id, &before_ts], ) .await? } else { client .execute("DELETE FROM memories WHERE agent_id = $1", &[&agent_id]) .await? }; Ok(count) } /// Get memory count for an agent pub async fn count_memories(&self, agent_id: &str) -> Result { let client = self.pool.get().await?; let row = client .query_one( "SELECT COUNT(*) as count FROM memories WHERE agent_id = $1 AND (expires_at IS NULL OR expires_at > NOW())", &[&agent_id], ) .await?; Ok(row.get("count")) } /// Delete expired memories across all agents pub async fn cleanup_expired_memories(&self) -> Result { let client = self.pool.get().await?; let deleted = client .execute( "DELETE FROM memories WHERE expires_at IS NOT NULL AND expires_at <= NOW()", &[], ) .await .context("Failed to cleanup expired memories")?; Ok(deleted) } } /// Result for a single batch entry #[derive(Debug, Clone, Serialize)] pub struct BatchStoreResult { pub id: String, pub status: String, pub expires_at: Option, } impl Database { /// Store multiple memories in a single transaction pub async fn batch_store_memories( &self, agent_id: &str, entries: Vec<( String, Value, Vec, Vec, Option>, )>, ) -> Result> { let mut client = self.pool.get().await?; let transaction = client.transaction().await?; let mut results = Vec::with_capacity(entries.len()); for (content, metadata, embedding, keywords, expires_at) in entries { let id = Uuid::new_v4(); let vector = Vector::from(embedding); transaction.execute( r#"INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7)"#, &[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at], ).await?; results.push(BatchStoreResult { id: id.to_string(), status: "stored".to_string(), expires_at: expires_at.map(|ts| ts.to_rfc3339()), }); } transaction.commit().await?; Ok(results) } }