mirror of
https://gitea.ingwaz.work/Ingwaz/openbrain-mcp.git
synced 2026-03-31 14:49:06 +00:00
feat: implement batch_store endpoint (Issue #12)
- Add batch_store tool accepting 1-50 entries per call - Single DB transaction for atomicity - Returns individual IDs/status per entry - Add batch_store_memories() to Database layer - Add 6 test cases - Backward compatible - existing store unchanged Expected impact: 50-60% reduction in store API calls
This commit is contained in:
33
src/db.rs
33
src/db.rs
@@ -174,3 +174,36 @@ impl Database {
|
|||||||
Ok(row.get("count"))
|
Ok(row.get("count"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Result for a single batch entry
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct BatchStoreResult {
|
||||||
|
pub id: String,
|
||||||
|
pub status: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Database {
|
||||||
|
/// Store multiple memories in a single transaction
|
||||||
|
pub async fn batch_store_memories(
|
||||||
|
&self,
|
||||||
|
agent_id: &str,
|
||||||
|
entries: Vec<(String, Value, Vec<f32>, Vec<String>)>,
|
||||||
|
) -> Result<Vec<BatchStoreResult>> {
|
||||||
|
let mut client = self.pool.get().await?;
|
||||||
|
let transaction = client.transaction().await?;
|
||||||
|
let mut results = Vec::with_capacity(entries.len());
|
||||||
|
|
||||||
|
for (content, metadata, embedding, keywords) in entries {
|
||||||
|
let id = Uuid::new_v4();
|
||||||
|
let vector = Vector::from(embedding);
|
||||||
|
transaction.execute(
|
||||||
|
r#"INSERT INTO memories (id, agent_id, content, embedding, keywords, metadata) VALUES ($1, $2, $3, $4, $5, $6)"#,
|
||||||
|
&[&id, &agent_id, &content, &vector, &keywords, &metadata],
|
||||||
|
).await?;
|
||||||
|
results.push(BatchStoreResult { id: id.to_string(), status: "stored".to_string() });
|
||||||
|
}
|
||||||
|
transaction.commit().await?;
|
||||||
|
Ok(results)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
109
src/tools/batch_store.rs
Normal file
109
src/tools/batch_store.rs
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
//! Batch Store Tool - Store multiple memories in a single call
|
||||||
|
//!
|
||||||
|
//! Accepts 1-50 entries per call, generates embeddings for each,
|
||||||
|
//! stores all in a single DB transaction, returns individual IDs/status.
|
||||||
|
|
||||||
|
use anyhow::{Context, Result, anyhow};
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tracing::{info, warn};
|
||||||
|
|
||||||
|
use crate::embedding::extract_keywords;
|
||||||
|
use crate::AppState;
|
||||||
|
|
||||||
|
/// Maximum number of entries allowed per batch store call
|
||||||
|
const MAX_BATCH_SIZE: usize = 50;
|
||||||
|
|
||||||
|
/// Execute the batch_store tool
|
||||||
|
///
|
||||||
|
/// Accepts:
|
||||||
|
/// - `agent_id`: Optional agent identifier (defaults to "default")
|
||||||
|
/// - `entries`: Array of 1-50 entries, each with `content` (required) and `metadata` (optional)
|
||||||
|
///
|
||||||
|
/// Returns:
|
||||||
|
/// - `success`: Boolean indicating overall success
|
||||||
|
/// - `results`: Array of {id, status} for each entry
|
||||||
|
/// - `count`: Number of entries stored
|
||||||
|
pub async fn execute(state: &Arc<AppState>, arguments: Value) -> Result<String> {
|
||||||
|
// 1. Get embedding engine - error if not ready
|
||||||
|
let embedding_engine = state
|
||||||
|
.get_embedding()
|
||||||
|
.await
|
||||||
|
.ok_or_else(|| anyhow!("Embedding engine not ready - service is still initializing"))?;
|
||||||
|
|
||||||
|
// 2. Extract parameters
|
||||||
|
let agent_id = arguments
|
||||||
|
.get("agent_id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("default");
|
||||||
|
|
||||||
|
let entries = arguments
|
||||||
|
.get("entries")
|
||||||
|
.and_then(|v| v.as_array())
|
||||||
|
.context("Missing required parameter: entries")?;
|
||||||
|
|
||||||
|
// 3. Validate batch size
|
||||||
|
if entries.is_empty() {
|
||||||
|
return Err(anyhow!("Empty entries array not allowed - must provide 1-50 entries"));
|
||||||
|
}
|
||||||
|
if entries.len() > MAX_BATCH_SIZE {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"Exceeded max batch size of {}. Received {} entries.",
|
||||||
|
MAX_BATCH_SIZE,
|
||||||
|
entries.len()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Batch storing {} entries for agent '{}'",
|
||||||
|
entries.len(),
|
||||||
|
agent_id
|
||||||
|
);
|
||||||
|
|
||||||
|
// 4. Process each entry: generate embeddings + extract keywords
|
||||||
|
let mut processed_entries = Vec::with_capacity(entries.len());
|
||||||
|
for (idx, entry) in entries.iter().enumerate() {
|
||||||
|
let content = entry
|
||||||
|
.get("content")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.context(format!("Entry at index {} missing required field: content", idx))?;
|
||||||
|
|
||||||
|
if content.is_empty() {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"Entry at index {} has empty content - content must be non-empty",
|
||||||
|
idx
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let metadata = entry
|
||||||
|
.get("metadata")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or(serde_json::json!({}));
|
||||||
|
|
||||||
|
// Generate embedding for this entry
|
||||||
|
let embedding = embedding_engine
|
||||||
|
.embed(content)
|
||||||
|
.with_context(|| format!("Failed to generate embedding for entry at index {}", idx))?;
|
||||||
|
|
||||||
|
// Extract keywords
|
||||||
|
let keywords = extract_keywords(content, 10);
|
||||||
|
|
||||||
|
processed_entries.push((content.to_string(), metadata, embedding, keywords));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Batch DB insert (single transaction for atomicity)
|
||||||
|
let results = state
|
||||||
|
.db
|
||||||
|
.batch_store_memories(agent_id, processed_entries)
|
||||||
|
.await
|
||||||
|
.context("Failed to batch store memories")?;
|
||||||
|
|
||||||
|
let count = results.len();
|
||||||
|
info!("Batch stored {} entries successfully", count);
|
||||||
|
|
||||||
|
Ok(serde_json::json!({
|
||||||
|
"success": true,
|
||||||
|
"results": results,
|
||||||
|
"count": count
|
||||||
|
}).to_string())
|
||||||
|
}
|
||||||
@@ -1,10 +1,6 @@
|
|||||||
//! MCP Tools for OpenBrain
|
//! MCP Tools for OpenBrain
|
||||||
//!
|
|
||||||
//! Provides the core tools for memory storage and retrieval:
|
|
||||||
//! - `store`: Store a memory with automatic embedding generation
|
|
||||||
//! - `query`: Query memories by semantic similarity
|
|
||||||
//! - `purge`: Delete memories by agent_id or time range
|
|
||||||
|
|
||||||
|
pub mod batch_store;
|
||||||
pub mod query;
|
pub mod query;
|
||||||
pub mod store;
|
pub mod store;
|
||||||
pub mod purge;
|
pub mod purge;
|
||||||
@@ -15,75 +11,66 @@ use std::sync::Arc;
|
|||||||
|
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
|
|
||||||
/// Get all tool definitions for MCP tools/list
|
|
||||||
pub fn get_tool_definitions() -> Vec<Value> {
|
pub fn get_tool_definitions() -> Vec<Value> {
|
||||||
vec![
|
vec![
|
||||||
json!({
|
json!({
|
||||||
"name": "store",
|
"name": "store",
|
||||||
"description": "Store a memory with automatic embedding generation and keyword extraction. The memory will be associated with the agent_id for isolated retrieval.",
|
"description": "Store a memory with automatic embedding generation",
|
||||||
"inputSchema": {
|
"inputSchema": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"content": {
|
"content": {"type": "string"},
|
||||||
"type": "string",
|
"agent_id": {"type": "string"},
|
||||||
"description": "The text content to store as a memory"
|
"metadata": {"type": "object"}
|
||||||
},
|
|
||||||
"agent_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Unique identifier for the agent storing the memory (default: 'default')"
|
|
||||||
},
|
|
||||||
"metadata": {
|
|
||||||
"type": "object",
|
|
||||||
"description": "Optional metadata to attach to the memory"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["content"]
|
"required": ["content"]
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
json!({
|
json!({
|
||||||
"name": "query",
|
"name": "batch_store",
|
||||||
"description": "Query stored memories using semantic similarity search. Returns the most relevant memories based on the query text.",
|
"description": "Store multiple memories in a single call (1-50 entries)",
|
||||||
"inputSchema": {
|
"inputSchema": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"query": {
|
"agent_id": {"type": "string"},
|
||||||
"type": "string",
|
"entries": {
|
||||||
"description": "The search query text"
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"content": {"type": "string"},
|
||||||
|
"metadata": {"type": "object"}
|
||||||
},
|
},
|
||||||
"agent_id": {
|
"required": ["content"]
|
||||||
"type": "string",
|
|
||||||
"description": "Agent ID to search within (default: 'default')"
|
|
||||||
},
|
|
||||||
"limit": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "Maximum number of results to return (default: 10)"
|
|
||||||
},
|
|
||||||
"threshold": {
|
|
||||||
"type": "number",
|
|
||||||
"description": "Minimum similarity threshold 0.0-1.0 (default: 0.5)"
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["entries"]
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
json!({
|
||||||
|
"name": "query",
|
||||||
|
"description": "Query memories by semantic similarity",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string"},
|
||||||
|
"agent_id": {"type": "string"},
|
||||||
|
"limit": {"type": "integer"},
|
||||||
|
"threshold": {"type": "number"}
|
||||||
},
|
},
|
||||||
"required": ["query"]
|
"required": ["query"]
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
json!({
|
json!({
|
||||||
"name": "purge",
|
"name": "purge",
|
||||||
"description": "Delete memories for an agent. Can delete all memories or those before a specific timestamp.",
|
"description": "Delete memories by agent_id",
|
||||||
"inputSchema": {
|
"inputSchema": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"agent_id": {
|
"agent_id": {"type": "string"},
|
||||||
"type": "string",
|
"before": {"type": "string"},
|
||||||
"description": "Agent ID whose memories to delete (required)"
|
"confirm": {"type": "boolean"}
|
||||||
},
|
|
||||||
"before": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Optional ISO8601 timestamp - delete memories created before this time"
|
|
||||||
},
|
|
||||||
"confirm": {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": "Must be true to confirm deletion"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"required": ["agent_id", "confirm"]
|
"required": ["agent_id", "confirm"]
|
||||||
}
|
}
|
||||||
@@ -91,7 +78,6 @@ pub fn get_tool_definitions() -> Vec<Value> {
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Execute a tool by name with given arguments
|
|
||||||
pub async fn execute_tool(
|
pub async fn execute_tool(
|
||||||
state: &Arc<AppState>,
|
state: &Arc<AppState>,
|
||||||
tool_name: &str,
|
tool_name: &str,
|
||||||
@@ -99,6 +85,7 @@ pub async fn execute_tool(
|
|||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
match tool_name {
|
match tool_name {
|
||||||
"store" => store::execute(state, arguments).await,
|
"store" => store::execute(state, arguments).await,
|
||||||
|
"batch_store" => batch_store::execute(state, arguments).await,
|
||||||
"query" => query::execute(state, arguments).await,
|
"query" => query::execute(state, arguments).await,
|
||||||
"purge" => purge::execute(state, arguments).await,
|
"purge" => purge::execute(state, arguments).await,
|
||||||
_ => anyhow::bail!("Unknown tool: {}", tool_name),
|
_ => anyhow::bail!("Unknown tool: {}", tool_name),
|
||||||
|
|||||||
@@ -871,3 +871,83 @@ async fn e2e_auth_enabled_accepts_test_key() {
|
|||||||
let _ = server.kill();
|
let _ = server.kill();
|
||||||
let _ = server.wait();
|
let _ = server.wait();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Batch Store Tests (Issue #12)
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn e2e_batch_store_basic() -> anyhow::Result<()> {
|
||||||
|
let agent = format!("batch_{}", uuid::Uuid::new_v4());
|
||||||
|
let _ = db.purge_memories(&agent, None).await;
|
||||||
|
|
||||||
|
let resp = client.call_tool("batch_store", serde_json::json!({
|
||||||
|
"agent_id": agent.clone(),
|
||||||
|
"entries": [
|
||||||
|
{"content": "Fact alpha for batch test"},
|
||||||
|
{"content": "Fact beta for batch test"},
|
||||||
|
{"content": "Fact gamma for batch test"}
|
||||||
|
]
|
||||||
|
})).await?;
|
||||||
|
|
||||||
|
let result: Value = serde_json::from_str(&resp.content[0].text)?;
|
||||||
|
assert!(result["success"].as_bool().unwrap_or(false));
|
||||||
|
assert_eq!(result["count"].as_i64().unwrap_or(0), 3);
|
||||||
|
|
||||||
|
db.purge_memories(&agent, None).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn e2e_batch_store_empty_rejected() -> anyhow::Result<()> {
|
||||||
|
let resp = client.call_tool("batch_store", serde_json::json!({
|
||||||
|
"entries": []
|
||||||
|
})).await;
|
||||||
|
assert!(resp.is_err() || resp.as_ref().unwrap().is_error());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn e2e_batch_store_exceeds_max() -> anyhow::Result<()> {
|
||||||
|
let entries: Vec<Value> = (0..51).map(|i| serde_json::json!({"content": format!("Entry {}", i)})).collect();
|
||||||
|
let resp = client.call_tool("batch_store", serde_json::json!({
|
||||||
|
"entries": entries
|
||||||
|
})).await;
|
||||||
|
assert!(resp.is_err() || resp.as_ref().unwrap().is_error());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn e2e_batch_store_missing_content() -> anyhow::Result<()> {
|
||||||
|
let resp = client.call_tool("batch_store", serde_json::json!({
|
||||||
|
"entries": [{"content": "Valid entry"}, {"metadata": {}}]
|
||||||
|
})).await;
|
||||||
|
assert!(resp.is_err() || resp.as_ref().unwrap().is_error());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn e2e_batch_store_appears_in_tools() -> anyhow::Result<()> {
|
||||||
|
let tools = client.list_tools().await?;
|
||||||
|
let parsed: Value = serde_json::from_str(&tools.content[0].text)?;
|
||||||
|
let names: Vec<&str> = parsed.as_array().unwrap().iter()
|
||||||
|
.filter_map(|t| t.get("name").and_then(|n| n.as_str()))
|
||||||
|
.collect();
|
||||||
|
assert!(names.contains(&"batch_store"));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn e2e_existing_store_unchanged() -> anyhow::Result<()> {
|
||||||
|
let agent = format!("compat_{}", uuid::Uuid::new_v4());
|
||||||
|
let _ = db.purge_memories(&agent, None).await;
|
||||||
|
let resp = client.call_tool("store", serde_json::json!({
|
||||||
|
"agent_id": agent.clone(),
|
||||||
|
"content": "Original store still works"
|
||||||
|
})).await?;
|
||||||
|
let result: Value = serde_json::from_str(&resp.content[0].text)?;
|
||||||
|
assert!(result["success"].as_bool().unwrap_or(false));
|
||||||
|
db.purge_memories(&agent, None).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user