diff --git a/src/db.rs b/src/db.rs
index e1ef1f7..748ee4e 100644
--- a/src/db.rs
+++ b/src/db.rs
@@ -174,3 +174,36 @@ impl Database {
         Ok(row.get("count"))
     }
 }
+
+
+/// Result for a single batch entry
+#[derive(Debug, Clone)]
+pub struct BatchStoreResult {
+    pub id: String,
+    pub status: String,
+}
+
+impl Database {
+    /// Store multiple memories in a single transaction
+    pub async fn batch_store_memories(
+        &self,
+        agent_id: &str,
+        entries: Vec<(String, Value, Vec<f32>, Vec<String>)>,
+    ) -> Result<Vec<BatchStoreResult>> {
+        let mut client = self.pool.get().await?;
+        let transaction = client.transaction().await?;
+        let mut results = Vec::with_capacity(entries.len());
+
+        for (content, metadata, embedding, keywords) in entries {
+            let id = Uuid::new_v4();
+            let vector = Vector::from(embedding);
+            transaction.execute(
+                r#"INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata) VALUES ($1, $2, $3, $4, $5, $6)"#,
+                &[&id, &agent_id, &content, &vector, &keywords, &metadata],
+            ).await?;
+            results.push(BatchStoreResult { id: id.to_string(), status: "stored".to_string() });
+        }
+        transaction.commit().await?;
+        Ok(results)
+    }
+}
diff --git a/src/tools/batch_store.rs b/src/tools/batch_store.rs
new file mode 100644
index 0000000..1397387
--- /dev/null
+++ b/src/tools/batch_store.rs
@@ -0,0 +1,109 @@
+//! Batch Store Tool - Store multiple memories in a single call
+//!
+//! Accepts 1-50 entries per call, generates embeddings for each,
+//! stores all in a single DB transaction, returns individual IDs/status. 
+
+use anyhow::{Context, Result, anyhow};
+use serde_json::Value;
+use std::sync::Arc;
+use tracing::{info, warn};
+
+use crate::embedding::extract_keywords;
+use crate::AppState;
+
+/// Maximum number of entries allowed per batch store call
+const MAX_BATCH_SIZE: usize = 50;
+
+/// Execute the batch_store tool
+///
+/// Accepts:
+/// - `agent_id`: Optional agent identifier (defaults to "default")
+/// - `entries`: Array of 1-50 entries, each with `content` (required) and `metadata` (optional)
+///
+/// Returns:
+/// - `success`: Boolean indicating overall success
+/// - `results`: Array of {id, status} for each entry
+/// - `count`: Number of entries stored
+pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String> {
+    // 1. Get embedding engine - error if not ready
+    let embedding_engine = state
+        .get_embedding()
+        .await
+        .ok_or_else(|| anyhow!("Embedding engine not ready - service is still initializing"))?;
+
+    // 2. Extract parameters
+    let agent_id = arguments
+        .get("agent_id")
+        .and_then(|v| v.as_str())
+        .unwrap_or("default");
+
+    let entries = arguments
+        .get("entries")
+        .and_then(|v| v.as_array())
+        .context("Missing required parameter: entries")?;
+
+    // 3. Validate batch size
+    if entries.is_empty() {
+        return Err(anyhow!("Empty entries array not allowed - must provide 1-50 entries"));
+    }
+    if entries.len() > MAX_BATCH_SIZE {
+        return Err(anyhow!(
+            "Exceeded max batch size of {}. Received {} entries.",
+            MAX_BATCH_SIZE,
+            entries.len()
+        ));
+    }
+
+    info!(
+        "Batch storing {} entries for agent '{}'",
+        entries.len(),
+        agent_id
+    );
+
+    // 4. 
Process each entry: generate embeddings + extract keywords + let mut processed_entries = Vec::with_capacity(entries.len()); + for (idx, entry) in entries.iter().enumerate() { + let content = entry + .get("content") + .and_then(|v| v.as_str()) + .context(format!("Entry at index {} missing required field: content", idx))?; + + if content.is_empty() { + return Err(anyhow!( + "Entry at index {} has empty content - content must be non-empty", + idx + )); + } + + let metadata = entry + .get("metadata") + .cloned() + .unwrap_or(serde_json::json!({})); + + // Generate embedding for this entry + let embedding = embedding_engine + .embed(content) + .with_context(|| format!("Failed to generate embedding for entry at index {}", idx))?; + + // Extract keywords + let keywords = extract_keywords(content, 10); + + processed_entries.push((content.to_string(), metadata, embedding, keywords)); + } + + // 5. Batch DB insert (single transaction for atomicity) + let results = state + .db + .batch_store_memories(agent_id, processed_entries) + .await + .context("Failed to batch store memories")?; + + let count = results.len(); + info!("Batch stored {} entries successfully", count); + + Ok(serde_json::json!({ + "success": true, + "results": results, + "count": count + }).to_string()) +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 2fd9c18..51eab2d 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -1,10 +1,6 @@ //! MCP Tools for OpenBrain -//! -//! Provides the core tools for memory storage and retrieval: -//! - `store`: Store a memory with automatic embedding generation -//! - `query`: Query memories by semantic similarity -//! 
- `purge`: Delete memories by agent_id or time range
+pub mod batch_store;
 pub mod query;
 pub mod store;
 pub mod purge;
@@ -15,75 +11,66 @@ use std::sync::Arc;
 
 use crate::AppState;
 
-/// Get all tool definitions for MCP tools/list
 pub fn get_tool_definitions() -> Vec<Value> {
     vec![
         json!({
             "name": "store",
-            "description": "Store a memory with automatic embedding generation and keyword extraction. The memory will be associated with the agent_id for isolated retrieval.",
+            "description": "Store a memory with automatic embedding generation",
             "inputSchema": {
                 "type": "object",
                 "properties": {
-                    "content": {
-                        "type": "string",
-                        "description": "The text content to store as a memory"
-                    },
-                    "agent_id": {
-                        "type": "string",
-                        "description": "Unique identifier for the agent storing the memory (default: 'default')"
-                    },
-                    "metadata": {
-                        "type": "object",
-                        "description": "Optional metadata to attach to the memory"
-                    }
+                    "content": {"type": "string"},
+                    "agent_id": {"type": "string"},
+                    "metadata": {"type": "object"}
                 },
                 "required": ["content"]
             }
        }),
        json!({
-            "name": "query",
-            "description": "Query stored memories using semantic similarity search. 
Returns the most relevant memories based on the query text.", + "name": "batch_store", + "description": "Store multiple memories in a single call (1-50 entries)", "inputSchema": { "type": "object", "properties": { - "query": { - "type": "string", - "description": "The search query text" - }, - "agent_id": { - "type": "string", - "description": "Agent ID to search within (default: 'default')" - }, - "limit": { - "type": "integer", - "description": "Maximum number of results to return (default: 10)" - }, - "threshold": { - "type": "number", - "description": "Minimum similarity threshold 0.0-1.0 (default: 0.5)" + "agent_id": {"type": "string"}, + "entries": { + "type": "array", + "items": { + "type": "object", + "properties": { + "content": {"type": "string"}, + "metadata": {"type": "object"} + }, + "required": ["content"] + } } }, + "required": ["entries"] + } + }), + json!({ + "name": "query", + "description": "Query memories by semantic similarity", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string"}, + "agent_id": {"type": "string"}, + "limit": {"type": "integer"}, + "threshold": {"type": "number"} + }, "required": ["query"] } }), json!({ "name": "purge", - "description": "Delete memories for an agent. 
Can delete all memories or those before a specific timestamp.",
+            "description": "Delete memories by agent_id",
             "inputSchema": {
                 "type": "object",
                 "properties": {
-                    "agent_id": {
-                        "type": "string",
-                        "description": "Agent ID whose memories to delete (required)"
-                    },
-                    "before": {
-                        "type": "string",
-                        "description": "Optional ISO8601 timestamp - delete memories created before this time"
-                    },
-                    "confirm": {
-                        "type": "boolean",
-                        "description": "Must be true to confirm deletion"
-                    }
+                    "agent_id": {"type": "string"},
+                    "before": {"type": "string"},
+                    "confirm": {"type": "boolean"}
                 },
                 "required": ["agent_id", "confirm"]
            }
        })
@@ -91,7 +78,6 @@ pub fn get_tool_definitions() -> Vec<Value> {
     ]
 }
 
-/// Execute a tool by name with given arguments
 pub async fn execute_tool(
     state: &Arc<AppState>,
     tool_name: &str,
@@ -99,6 +85,7 @@
 ) -> Result<String> {
     match tool_name {
         "store" => store::execute(state, arguments).await,
+        "batch_store" => batch_store::execute(state, arguments).await,
         "query" => query::execute(state, arguments).await,
         "purge" => purge::execute(state, arguments).await,
         _ => anyhow::bail!("Unknown tool: {}", tool_name),
diff --git a/tests/e2e_mcp.rs b/tests/e2e_mcp.rs
index 11ac037..ea57255 100644
--- a/tests/e2e_mcp.rs
+++ b/tests/e2e_mcp.rs
@@ -871,3 +871,83 @@ async fn e2e_auth_enabled_accepts_test_key() {
     let _ = server.kill();
     let _ = server.wait();
 }
+
+
+// =============================================================================
+// Batch Store Tests (Issue #12)
+// =============================================================================
+
+#[tokio::test]
+async fn e2e_batch_store_basic() -> anyhow::Result<()> {
+    let agent = format!("batch_{}", uuid::Uuid::new_v4());
+    let _ = db.purge_memories(&agent, None).await;
+
+    let resp = client.call_tool("batch_store", serde_json::json!({
+        "agent_id": agent.clone(),
+        "entries": [
+            {"content": "Fact alpha for batch test"},
+            {"content": "Fact beta for batch test"},
+            {"content": "Fact gamma for batch 
test"}
+        ]
+    })).await?;
+
+    let result: Value = serde_json::from_str(&resp.content[0].text)?;
+    assert!(result["success"].as_bool().unwrap_or(false));
+    assert_eq!(result["count"].as_i64().unwrap_or(0), 3);
+
+    db.purge_memories(&agent, None).await?;
+    Ok(())
+}
+
+#[tokio::test]
+async fn e2e_batch_store_empty_rejected() -> anyhow::Result<()> {
+    let resp = client.call_tool("batch_store", serde_json::json!({
+        "entries": []
+    })).await;
+    assert!(resp.is_err() || resp.as_ref().unwrap().is_error());
+    Ok(())
+}
+
+#[tokio::test]
+async fn e2e_batch_store_exceeds_max() -> anyhow::Result<()> {
+    let entries: Vec<Value> = (0..51).map(|i| serde_json::json!({"content": format!("Entry {}", i)})).collect();
+    let resp = client.call_tool("batch_store", serde_json::json!({
+        "entries": entries
+    })).await;
+    assert!(resp.is_err() || resp.as_ref().unwrap().is_error());
+    Ok(())
+}
+
+#[tokio::test]
+async fn e2e_batch_store_missing_content() -> anyhow::Result<()> {
+    let resp = client.call_tool("batch_store", serde_json::json!({
+        "entries": [{"content": "Valid entry"}, {"metadata": {}}]
+    })).await;
+    assert!(resp.is_err() || resp.as_ref().unwrap().is_error());
+    Ok(())
+}
+
+#[tokio::test]
+async fn e2e_batch_store_appears_in_tools() -> anyhow::Result<()> {
+    let tools = client.list_tools().await?;
+    let parsed: Value = serde_json::from_str(&tools.content[0].text)?;
+    let names: Vec<&str> = parsed.as_array().unwrap().iter()
+        .filter_map(|t| t.get("name").and_then(|n| n.as_str()))
+        .collect();
+    assert!(names.contains(&"batch_store"));
+    Ok(())
+}
+
+#[tokio::test]
+async fn e2e_existing_store_unchanged() -> anyhow::Result<()> {
+    let agent = format!("compat_{}", uuid::Uuid::new_v4());
+    let _ = db.purge_memories(&agent, None).await;
+    let resp = client.call_tool("store", serde_json::json!({
+        "agent_id": agent.clone(),
+        "content": "Original store still works"
+    })).await?;
+    let result: Value = serde_json::from_str(&resp.content[0].text)?;
+    
assert!(result["success"].as_bool().unwrap_or(false)); + db.purge_memories(&agent, None).await?; + Ok(()) +}