mirror of
https://gitea.ingwaz.work/Ingwaz/openbrain-mcp.git
synced 2026-03-31 14:49:06 +00:00
Initial public release
This commit is contained in:
176
src/db.rs
Normal file
176
src/db.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
//! Database module for PostgreSQL with pgvector support
|
||||
//!
|
||||
//! Provides connection pooling and query helpers for vector operations.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use deadpool_postgres::{Config, Pool, Runtime};
|
||||
use pgvector::Vector;
|
||||
use tokio_postgres::NoTls;
|
||||
use tracing::info;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::config::DatabaseConfig;
|
||||
|
||||
/// Database wrapper with connection pool
|
||||
#[derive(Clone)]
|
||||
pub struct Database {
|
||||
pool: Pool,
|
||||
}
|
||||
|
||||
/// A memory record stored in the database
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MemoryRecord {
|
||||
pub id: Uuid,
|
||||
pub agent_id: String,
|
||||
pub content: String,
|
||||
pub embedding: Vec<f32>,
|
||||
pub keywords: Vec<String>,
|
||||
pub metadata: serde_json::Value,
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
/// Query result with similarity score
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MemoryMatch {
|
||||
pub record: MemoryRecord,
|
||||
pub similarity: f32,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
/// Create a new database connection pool
|
||||
pub async fn new(config: &DatabaseConfig) -> Result<Self> {
|
||||
let mut cfg = Config::new();
|
||||
cfg.host = Some(config.host.clone());
|
||||
cfg.port = Some(config.port);
|
||||
cfg.dbname = Some(config.name.clone());
|
||||
cfg.user = Some(config.user.clone());
|
||||
cfg.password = Some(config.password.clone());
|
||||
|
||||
let pool = cfg
|
||||
.create_pool(Some(Runtime::Tokio1), NoTls)
|
||||
.context("Failed to create database pool")?;
|
||||
|
||||
// Test connection
|
||||
let client = pool.get().await.context("Failed to get database connection")?;
|
||||
client
|
||||
.simple_query("SELECT 1")
|
||||
.await
|
||||
.context("Failed to execute test query")?;
|
||||
|
||||
info!("Database connection pool created with {} connections", config.pool_size);
|
||||
|
||||
Ok(Self { pool })
|
||||
}
|
||||
|
||||
/// Store a memory record
|
||||
pub async fn store_memory(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
content: &str,
|
||||
embedding: &[f32],
|
||||
keywords: &[String],
|
||||
metadata: serde_json::Value,
|
||||
) -> Result<Uuid> {
|
||||
let client = self.pool.get().await?;
|
||||
let id = Uuid::new_v4();
|
||||
let vector = Vector::from(embedding.to_vec());
|
||||
|
||||
client
|
||||
.execute(
|
||||
r#"
|
||||
INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
"#,
|
||||
&[&id, &agent_id, &content, &vector, &keywords, &metadata],
|
||||
)
|
||||
.await
|
||||
.context("Failed to store memory")?;
|
||||
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Query memories by vector similarity
|
||||
pub async fn query_memories(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
embedding: &[f32],
|
||||
limit: i64,
|
||||
threshold: f32,
|
||||
) -> Result<Vec<MemoryMatch>> {
|
||||
let client = self.pool.get().await?;
|
||||
let vector = Vector::from(embedding.to_vec());
|
||||
let threshold_f64 = threshold as f64;
|
||||
|
||||
let rows = client
|
||||
.query(
|
||||
r#"
|
||||
SELECT
|
||||
id, agent_id, content, keywords, metadata, created_at,
|
||||
(1 - (embedding <=> $1))::real AS similarity
|
||||
FROM memories
|
||||
WHERE agent_id = $2
|
||||
AND (1 - (embedding <=> $1)) >= $3
|
||||
ORDER BY embedding <=> $1
|
||||
LIMIT $4
|
||||
"#,
|
||||
&[&vector, &agent_id, &threshold_f64, &limit],
|
||||
)
|
||||
.await
|
||||
.context("Failed to query memories")?;
|
||||
|
||||
let matches = rows
|
||||
.iter()
|
||||
.map(|row| MemoryMatch {
|
||||
record: MemoryRecord {
|
||||
id: row.get("id"),
|
||||
agent_id: row.get("agent_id"),
|
||||
content: row.get("content"),
|
||||
// Query responses do not include raw embedding payloads.
|
||||
embedding: Vec::new(),
|
||||
keywords: row.get("keywords"),
|
||||
metadata: row.get("metadata"),
|
||||
created_at: row.get("created_at"),
|
||||
},
|
||||
similarity: row.get("similarity"),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(matches)
|
||||
}
|
||||
|
||||
/// Delete memories by agent_id and optional filters
|
||||
pub async fn purge_memories(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
before: Option<chrono::DateTime<chrono::Utc>>,
|
||||
) -> Result<u64> {
|
||||
let client = self.pool.get().await?;
|
||||
|
||||
let count = if let Some(before_ts) = before {
|
||||
client
|
||||
.execute(
|
||||
"DELETE FROM memories WHERE agent_id = $1 AND created_at < $2",
|
||||
&[&agent_id, &before_ts],
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
client
|
||||
.execute("DELETE FROM memories WHERE agent_id = $1", &[&agent_id])
|
||||
.await?
|
||||
};
|
||||
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
/// Get memory count for an agent
|
||||
pub async fn count_memories(&self, agent_id: &str) -> Result<i64> {
|
||||
let client = self.pool.get().await?;
|
||||
let row = client
|
||||
.query_one(
|
||||
"SELECT COUNT(*) as count FROM memories WHERE agent_id = $1",
|
||||
&[&agent_id],
|
||||
)
|
||||
.await?;
|
||||
Ok(row.get("count"))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user