Initial public release

This commit is contained in:
Agent Zero
2026-03-07 13:41:36 -05:00
commit 774982dc5a
22 changed files with 3517 additions and 0 deletions

176
src/db.rs Normal file
View File

@@ -0,0 +1,176 @@
//! Database module for PostgreSQL with pgvector support
//!
//! Provides connection pooling and query helpers for vector operations.
use anyhow::{Context, Result};
use deadpool_postgres::{Config, Pool, Runtime};
use pgvector::Vector;
use tokio_postgres::NoTls;
use tracing::info;
use uuid::Uuid;
use crate::config::DatabaseConfig;
/// Database wrapper with connection pool
#[derive(Clone)]
pub struct Database {
pool: Pool,
}
/// A memory record stored in the database
#[derive(Debug, Clone)]
pub struct MemoryRecord {
pub id: Uuid,
pub agent_id: String,
pub content: String,
pub embedding: Vec<f32>,
pub keywords: Vec<String>,
pub metadata: serde_json::Value,
pub created_at: chrono::DateTime<chrono::Utc>,
}
/// Query result with similarity score
#[derive(Debug, Clone)]
pub struct MemoryMatch {
pub record: MemoryRecord,
pub similarity: f32,
}
impl Database {
/// Create a new database connection pool
pub async fn new(config: &DatabaseConfig) -> Result<Self> {
let mut cfg = Config::new();
cfg.host = Some(config.host.clone());
cfg.port = Some(config.port);
cfg.dbname = Some(config.name.clone());
cfg.user = Some(config.user.clone());
cfg.password = Some(config.password.clone());
let pool = cfg
.create_pool(Some(Runtime::Tokio1), NoTls)
.context("Failed to create database pool")?;
// Test connection
let client = pool.get().await.context("Failed to get database connection")?;
client
.simple_query("SELECT 1")
.await
.context("Failed to execute test query")?;
info!("Database connection pool created with {} connections", config.pool_size);
Ok(Self { pool })
}
/// Store a memory record
pub async fn store_memory(
&self,
agent_id: &str,
content: &str,
embedding: &[f32],
keywords: &[String],
metadata: serde_json::Value,
) -> Result<Uuid> {
let client = self.pool.get().await?;
let id = Uuid::new_v4();
let vector = Vector::from(embedding.to_vec());
client
.execute(
r#"
INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata)
VALUES ($1, $2, $3, $4, $5, $6)
"#,
&[&id, &agent_id, &content, &vector, &keywords, &metadata],
)
.await
.context("Failed to store memory")?;
Ok(id)
}
/// Query memories by vector similarity
pub async fn query_memories(
&self,
agent_id: &str,
embedding: &[f32],
limit: i64,
threshold: f32,
) -> Result<Vec<MemoryMatch>> {
let client = self.pool.get().await?;
let vector = Vector::from(embedding.to_vec());
let threshold_f64 = threshold as f64;
let rows = client
.query(
r#"
SELECT
id, agent_id, content, keywords, metadata, created_at,
(1 - (embedding <=> $1))::real AS similarity
FROM memories
WHERE agent_id = $2
AND (1 - (embedding <=> $1)) >= $3
ORDER BY embedding <=> $1
LIMIT $4
"#,
&[&vector, &agent_id, &threshold_f64, &limit],
)
.await
.context("Failed to query memories")?;
let matches = rows
.iter()
.map(|row| MemoryMatch {
record: MemoryRecord {
id: row.get("id"),
agent_id: row.get("agent_id"),
content: row.get("content"),
// Query responses do not include raw embedding payloads.
embedding: Vec::new(),
keywords: row.get("keywords"),
metadata: row.get("metadata"),
created_at: row.get("created_at"),
},
similarity: row.get("similarity"),
})
.collect();
Ok(matches)
}
/// Delete memories by agent_id and optional filters
pub async fn purge_memories(
&self,
agent_id: &str,
before: Option<chrono::DateTime<chrono::Utc>>,
) -> Result<u64> {
let client = self.pool.get().await?;
let count = if let Some(before_ts) = before {
client
.execute(
"DELETE FROM memories WHERE agent_id = $1 AND created_at < $2",
&[&agent_id, &before_ts],
)
.await?
} else {
client
.execute("DELETE FROM memories WHERE agent_id = $1", &[&agent_id])
.await?
};
Ok(count)
}
/// Get memory count for an agent
pub async fn count_memories(&self, agent_id: &str) -> Result<i64> {
let client = self.pool.get().await?;
let row = client
.query_one(
"SELECT COUNT(*) as count FROM memories WHERE agent_id = $1",
&[&agent_id],
)
.await?;
Ok(row.get("count"))
}
}