mirror of
https://gitea.ingwaz.work/Ingwaz/openbrain-mcp.git
synced 2026-03-31 14:49:06 +00:00
Add TTL expiry for transient facts
This commit is contained in:
@@ -12,6 +12,7 @@ pub struct Config {
|
||||
pub database: DatabaseConfig,
|
||||
pub embedding: EmbeddingConfig,
|
||||
pub query: QueryConfig,
|
||||
pub ttl: TtlConfig,
|
||||
pub auth: AuthConfig,
|
||||
}
|
||||
|
||||
@@ -55,6 +56,13 @@ pub struct QueryConfig {
|
||||
pub text_weight: f32,
|
||||
}
|
||||
|
||||
/// TTL / expiry configuration
#[derive(Debug, Clone, Deserialize)]
pub struct TtlConfig {
    /// How often, in seconds, the background task purges expired memories.
    /// Defaults to 300 via `default_cleanup_interval_seconds`; a value of 0
    /// disables the periodic cleanup task entirely.
    #[serde(default = "default_cleanup_interval_seconds")]
    pub cleanup_interval_seconds: u64,
}
|
||||
|
||||
/// Authentication configuration
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct AuthConfig {
|
||||
@@ -98,6 +106,7 @@ fn default_model_path() -> String { "models/all-MiniLM-L6-v2".to_string() }
|
||||
fn default_embedding_dim() -> usize { 384 }
|
||||
fn default_vector_weight() -> f32 { 0.6 }
|
||||
fn default_text_weight() -> f32 { 0.4 }
|
||||
fn default_cleanup_interval_seconds() -> u64 { 300 }
|
||||
fn default_auth_enabled() -> bool { false }
|
||||
|
||||
impl Config {
|
||||
@@ -119,6 +128,11 @@ impl Config {
|
||||
// Query settings
|
||||
.set_default("query.vector_weight", default_vector_weight() as f64)?
|
||||
.set_default("query.text_weight", default_text_weight() as f64)?
|
||||
// TTL settings
|
||||
.set_default(
|
||||
"ttl.cleanup_interval_seconds",
|
||||
default_cleanup_interval_seconds() as i64,
|
||||
)?
|
||||
// Auth settings
|
||||
.set_default("auth.enabled", default_auth_enabled())?
|
||||
// Load from environment with OPENBRAIN_ prefix
|
||||
@@ -170,6 +184,9 @@ impl Default for Config {
|
||||
vector_weight: default_vector_weight(),
|
||||
text_weight: default_text_weight(),
|
||||
},
|
||||
ttl: TtlConfig {
|
||||
cleanup_interval_seconds: default_cleanup_interval_seconds(),
|
||||
},
|
||||
auth: AuthConfig {
|
||||
enabled: default_auth_enabled(),
|
||||
api_keys: Vec::new(),
|
||||
|
||||
48
src/db.rs
48
src/db.rs
@@ -29,6 +29,7 @@ pub struct MemoryRecord {
|
||||
pub keywords: Vec<String>,
|
||||
pub metadata: serde_json::Value,
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
pub expires_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
}
|
||||
|
||||
/// Query result with similarity score
|
||||
@@ -75,6 +76,7 @@ impl Database {
|
||||
embedding: &[f32],
|
||||
keywords: &[String],
|
||||
metadata: serde_json::Value,
|
||||
expires_at: Option<chrono::DateTime<chrono::Utc>>,
|
||||
) -> Result<Uuid> {
|
||||
let client = self.pool.get().await?;
|
||||
let id = Uuid::new_v4();
|
||||
@@ -83,10 +85,10 @@ impl Database {
|
||||
client
|
||||
.execute(
|
||||
r#"
|
||||
INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
"#,
|
||||
&[&id, &agent_id, &content, &vector, &keywords, &metadata],
|
||||
&[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
|
||||
)
|
||||
.await
|
||||
.context("Failed to store memory")?;
|
||||
@@ -123,6 +125,7 @@ impl Database {
|
||||
keywords,
|
||||
metadata,
|
||||
created_at,
|
||||
expires_at,
|
||||
(1 - (embedding <=> $1))::real AS vector_score,
|
||||
CASE
|
||||
WHEN search_query.query_text IS NULL THEN 0::real
|
||||
@@ -133,6 +136,7 @@ impl Database {
|
||||
FROM memories
|
||||
CROSS JOIN search_query
|
||||
WHERE memories.agent_id = $3
|
||||
AND (memories.expires_at IS NULL OR memories.expires_at > NOW())
|
||||
),
|
||||
ranked AS (
|
||||
SELECT
|
||||
@@ -147,6 +151,7 @@ impl Database {
|
||||
keywords,
|
||||
metadata,
|
||||
created_at,
|
||||
expires_at,
|
||||
vector_score,
|
||||
text_score,
|
||||
CASE
|
||||
@@ -184,6 +189,7 @@ impl Database {
|
||||
keywords: row.get("keywords"),
|
||||
metadata: row.get("metadata"),
|
||||
created_at: row.get("created_at"),
|
||||
expires_at: row.get("expires_at"),
|
||||
},
|
||||
similarity: row.get("hybrid_score"),
|
||||
vector_score: row.get("vector_score"),
|
||||
@@ -224,12 +230,25 @@ impl Database {
|
||||
let client = self.pool.get().await?;
|
||||
let row = client
|
||||
.query_one(
|
||||
"SELECT COUNT(*) as count FROM memories WHERE agent_id = $1",
|
||||
"SELECT COUNT(*) as count FROM memories WHERE agent_id = $1 AND (expires_at IS NULL OR expires_at > NOW())",
|
||||
&[&agent_id],
|
||||
)
|
||||
.await?;
|
||||
Ok(row.get("count"))
|
||||
}
|
||||
|
||||
/// Delete expired memories across all agents
|
||||
pub async fn cleanup_expired_memories(&self) -> Result<u64> {
|
||||
let client = self.pool.get().await?;
|
||||
let deleted = client
|
||||
.execute(
|
||||
"DELETE FROM memories WHERE expires_at IS NOT NULL AND expires_at <= NOW()",
|
||||
&[],
|
||||
)
|
||||
.await
|
||||
.context("Failed to cleanup expired memories")?;
|
||||
Ok(deleted)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -238,6 +257,7 @@ impl Database {
|
||||
pub struct BatchStoreResult {
|
||||
pub id: String,
|
||||
pub status: String,
|
||||
pub expires_at: Option<String>,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
@@ -245,20 +265,30 @@ impl Database {
|
||||
pub async fn batch_store_memories(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
entries: Vec<(String, Value, Vec<f32>, Vec<String>)>,
|
||||
entries: Vec<(
|
||||
String,
|
||||
Value,
|
||||
Vec<f32>,
|
||||
Vec<String>,
|
||||
Option<chrono::DateTime<chrono::Utc>>,
|
||||
)>,
|
||||
) -> Result<Vec<BatchStoreResult>> {
|
||||
let mut client = self.pool.get().await?;
|
||||
let transaction = client.transaction().await?;
|
||||
let mut results = Vec::with_capacity(entries.len());
|
||||
|
||||
for (content, metadata, embedding, keywords) in entries {
|
||||
for (content, metadata, embedding, keywords, expires_at) in entries {
|
||||
let id = Uuid::new_v4();
|
||||
let vector = Vector::from(embedding);
|
||||
transaction.execute(
|
||||
r#"INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata) VALUES ($1, $2, $3, $4, $5, $6)"#,
|
||||
&[&id, &agent_id, &content, &vector, &keywords, &metadata],
|
||||
r#"INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7)"#,
|
||||
&[&id, &agent_id, &content, &vector, &keywords, &metadata, &expires_at],
|
||||
).await?;
|
||||
results.push(BatchStoreResult { id: id.to_string(), status: "stored".to_string() });
|
||||
results.push(BatchStoreResult {
|
||||
id: id.to_string(),
|
||||
status: "stored".to_string(),
|
||||
expires_at: expires_at.map(|ts| ts.to_rfc3339()),
|
||||
});
|
||||
}
|
||||
transaction.commit().await?;
|
||||
Ok(results)
|
||||
|
||||
25
src/lib.rs
25
src/lib.rs
@@ -5,6 +5,7 @@ pub mod config;
|
||||
pub mod db;
|
||||
pub mod embedding;
|
||||
pub mod migrations;
|
||||
pub mod ttl;
|
||||
pub mod tools;
|
||||
pub mod transport;
|
||||
|
||||
@@ -115,6 +116,30 @@ pub async fn run_server(config: Config, db: Database) -> Result<()> {
|
||||
}
|
||||
});
|
||||
|
||||
if config.ttl.cleanup_interval_seconds > 0 {
|
||||
let cleanup_state = state.clone();
|
||||
let cleanup_interval_seconds = config.ttl.cleanup_interval_seconds;
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(
|
||||
cleanup_interval_seconds,
|
||||
));
|
||||
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
match cleanup_state.db.cleanup_expired_memories().await {
|
||||
Ok(deleted) if deleted > 0 => {
|
||||
info!("Cleaned up {} expired memories", deleted);
|
||||
}
|
||||
Ok(_) => {}
|
||||
Err(err) => {
|
||||
error!("Failed to cleanup expired memories: {:?}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Create MCP state for SSE transport
|
||||
let mcp_state = McpState::new(state.clone());
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ use std::sync::Arc;
|
||||
use tracing::info;
|
||||
|
||||
use crate::embedding::extract_keywords;
|
||||
use crate::ttl::expires_at_from_ttl;
|
||||
use crate::AppState;
|
||||
|
||||
/// Maximum number of entries allowed per batch store call
|
||||
@@ -18,6 +19,7 @@ const MAX_BATCH_SIZE: usize = 50;
|
||||
///
|
||||
/// Accepts:
|
||||
/// - `agent_id`: Optional agent identifier (defaults to "default")
|
||||
/// - `ttl`: Optional default TTL string applied to entries without their own ttl
|
||||
/// - `entries`: Array of 1-50 entries, each with `content` (required) and `metadata` (optional)
|
||||
///
|
||||
/// Returns:
|
||||
@@ -41,6 +43,7 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
.get("entries")
|
||||
.and_then(|v| v.as_array())
|
||||
.context("Missing required parameter: entries")?;
|
||||
let default_ttl = arguments.get("ttl").and_then(|v| v.as_str());
|
||||
|
||||
// 3. Validate batch size
|
||||
if entries.is_empty() {
|
||||
@@ -79,6 +82,12 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
.get("metadata")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::json!({}));
|
||||
let ttl = entry
|
||||
.get("ttl")
|
||||
.and_then(|v| v.as_str())
|
||||
.or(default_ttl);
|
||||
let expires_at = expires_at_from_ttl(ttl)
|
||||
.with_context(|| format!("Invalid ttl for entry at index {}", idx))?;
|
||||
|
||||
// Generate embedding for this entry
|
||||
let embedding = embedding_engine
|
||||
@@ -88,7 +97,7 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
// Extract keywords
|
||||
let keywords = extract_keywords(content, 10);
|
||||
|
||||
processed_entries.push((content.to_string(), metadata, embedding, keywords));
|
||||
processed_entries.push((content.to_string(), metadata, embedding, keywords, expires_at));
|
||||
}
|
||||
|
||||
// 5. Batch DB insert (single transaction for atomicity)
|
||||
|
||||
@@ -30,6 +30,10 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"description": "Optional metadata to attach to the memory"
|
||||
},
|
||||
"ttl": {
|
||||
"type": "string",
|
||||
"description": "Optional time-to-live for transient facts, like 30s, 15m, 1h, 7d, or 2w"
|
||||
}
|
||||
},
|
||||
"required": ["content"]
|
||||
@@ -45,6 +49,10 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
||||
"type": "string",
|
||||
"description": "Unique identifier for the agent storing the memories (default: 'default')"
|
||||
},
|
||||
"ttl": {
|
||||
"type": "string",
|
||||
"description": "Optional default time-to-live applied to entries without their own ttl"
|
||||
},
|
||||
"entries": {
|
||||
"type": "array",
|
||||
"description": "Array of 1-50 memory entries to store atomically",
|
||||
@@ -58,6 +66,10 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"description": "Optional metadata to attach to the memory"
|
||||
},
|
||||
"ttl": {
|
||||
"type": "string",
|
||||
"description": "Optional per-entry time-to-live override like 30s, 15m, 1h, 7d, or 2w"
|
||||
}
|
||||
},
|
||||
"required": ["content"]
|
||||
|
||||
@@ -81,7 +81,8 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
"hybrid_score": m.hybrid_score,
|
||||
"keywords": m.record.keywords,
|
||||
"metadata": m.record.metadata,
|
||||
"created_at": m.record.created_at.to_rfc3339()
|
||||
"created_at": m.record.created_at.to_rfc3339(),
|
||||
"expires_at": m.record.expires_at.as_ref().map(|ts| ts.to_rfc3339())
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::sync::Arc;
|
||||
use tracing::info;
|
||||
|
||||
use crate::embedding::extract_keywords;
|
||||
use crate::ttl::expires_at_from_ttl;
|
||||
use crate::AppState;
|
||||
|
||||
/// Execute the store tool
|
||||
@@ -32,6 +33,9 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::json!({}));
|
||||
|
||||
let ttl = arguments.get("ttl").and_then(|v| v.as_str());
|
||||
let expires_at = expires_at_from_ttl(ttl).context("Invalid ttl")?;
|
||||
|
||||
info!(
|
||||
"Storing memory for agent '{}': {} chars",
|
||||
agent_id,
|
||||
@@ -49,7 +53,14 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
// Store in database
|
||||
let id = state
|
||||
.db
|
||||
.store_memory(agent_id, content, &embedding, &keywords, metadata)
|
||||
.store_memory(
|
||||
agent_id,
|
||||
content,
|
||||
&embedding,
|
||||
&keywords,
|
||||
metadata,
|
||||
expires_at.clone(),
|
||||
)
|
||||
.await
|
||||
.context("Failed to store memory")?;
|
||||
|
||||
@@ -60,7 +71,9 @@ pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String>
|
||||
"id": id.to_string(),
|
||||
"agent_id": agent_id,
|
||||
"keywords": keywords,
|
||||
"embedding_dimension": embedding.len()
|
||||
"embedding_dimension": embedding.len(),
|
||||
"ttl": ttl,
|
||||
"expires_at": expires_at.as_ref().map(|ts| ts.to_rfc3339())
|
||||
})
|
||||
.to_string())
|
||||
}
|
||||
|
||||
49
src/ttl.rs
Normal file
49
src/ttl.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use anyhow::{Result, anyhow};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
|
||||
pub fn parse_ttl_spec(ttl: &str) -> Result<Duration> {
|
||||
let ttl = ttl.trim();
|
||||
if ttl.is_empty() {
|
||||
return Err(anyhow!("ttl must not be empty"));
|
||||
}
|
||||
|
||||
let (value, multiplier_seconds) = match ttl.chars().last() {
|
||||
Some('s') | Some('S') => (&ttl[..ttl.len() - 1], 1i64),
|
||||
Some('m') | Some('M') => (&ttl[..ttl.len() - 1], 60i64),
|
||||
Some('h') | Some('H') => (&ttl[..ttl.len() - 1], 60i64 * 60),
|
||||
Some('d') | Some('D') => (&ttl[..ttl.len() - 1], 60i64 * 60 * 24),
|
||||
Some('w') | Some('W') => (&ttl[..ttl.len() - 1], 60i64 * 60 * 24 * 7),
|
||||
_ => {
|
||||
return Err(anyhow!(
|
||||
"invalid ttl '{ttl}'. Use a positive duration like 30s, 15m, 1h, 7d, or 2w"
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let value: i64 = value
|
||||
.trim()
|
||||
.parse()
|
||||
.map_err(|_| anyhow!("invalid ttl '{ttl}'. Duration value must be a positive integer"))?;
|
||||
if value <= 0 {
|
||||
return Err(anyhow!("invalid ttl '{ttl}'. Duration value must be greater than zero"));
|
||||
}
|
||||
|
||||
let total_seconds = value
|
||||
.checked_mul(multiplier_seconds)
|
||||
.ok_or_else(|| anyhow!("invalid ttl '{ttl}'. Duration is too large"))?;
|
||||
|
||||
Ok(Duration::seconds(total_seconds))
|
||||
}
|
||||
|
||||
pub fn expires_at_from_ttl(ttl: Option<&str>) -> Result<Option<DateTime<Utc>>> {
|
||||
match ttl {
|
||||
Some(ttl) => {
|
||||
let duration = parse_ttl_spec(ttl)?;
|
||||
Utc::now()
|
||||
.checked_add_signed(duration)
|
||||
.map(Some)
|
||||
.ok_or_else(|| anyhow!("ttl '{ttl}' overflows supported timestamp range"))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user