mirror of
https://gitea.ingwaz.work/Ingwaz/openbrain-mcp.git
synced 2026-06-15 22:07:08 +00:00
Scope memories by API token and add shared-token e2e coverage
This commit is contained in:
130
src/embedding.rs
130
src/embedding.rs
@@ -1,7 +1,7 @@
|
||||
//! Embedding engine using local ONNX models
|
||||
|
||||
use anyhow::Result;
|
||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||
use ort::session::{builder::GraphOptimizationLevel, Session};
|
||||
use ort::value::Value;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Once;
|
||||
@@ -15,9 +15,9 @@ static ORT_INIT: Once = Once::new();
|
||||
/// Initialize ONNX Runtime synchronously (called inside spawn_blocking)
|
||||
fn init_ort_sync(dylib_path: &str) -> Result<()> {
|
||||
info!("Initializing ONNX Runtime from: {}", dylib_path);
|
||||
|
||||
|
||||
let mut init_error: Option<String> = None;
|
||||
|
||||
|
||||
ORT_INIT.call_once(|| {
|
||||
info!("ORT_INIT.call_once - starting initialization");
|
||||
match ort::init_from(dylib_path) {
|
||||
@@ -43,7 +43,7 @@ fn init_ort_sync(dylib_path: &str) -> Result<()> {
|
||||
if let Some(err) = init_error {
|
||||
return Err(anyhow::anyhow!("{}", err));
|
||||
}
|
||||
|
||||
|
||||
info!("ONNX Runtime initialization complete");
|
||||
Ok(())
|
||||
}
|
||||
@@ -91,35 +91,40 @@ impl EmbeddingEngine {
|
||||
|
||||
let model_path = PathBuf::from(&config.model_path);
|
||||
let dimension = config.dimension;
|
||||
|
||||
info!("Loading ONNX model from {:?}", model_path.join("model.onnx"));
|
||||
|
||||
|
||||
info!(
|
||||
"Loading ONNX model from {:?}",
|
||||
model_path.join("model.onnx")
|
||||
);
|
||||
|
||||
// Use spawn_blocking to avoid blocking the async runtime
|
||||
let (session, tokenizer) = tokio::task::spawn_blocking(move || -> Result<(Session, Tokenizer)> {
|
||||
// Initialize ONNX Runtime first
|
||||
init_ort_sync(&dylib_path)?;
|
||||
|
||||
info!("Creating ONNX session...");
|
||||
|
||||
// Load ONNX model with ort 2.0 API
|
||||
let session = Session::builder()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to create session builder: {:?}", e))?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to set optimization level: {:?}", e))?
|
||||
.with_intra_threads(4)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to set intra threads: {:?}", e))?
|
||||
.commit_from_file(model_path.join("model.onnx"))
|
||||
.map_err(|e| anyhow::anyhow!("Failed to load ONNX model: {:?}", e))?;
|
||||
let (session, tokenizer) =
|
||||
tokio::task::spawn_blocking(move || -> Result<(Session, Tokenizer)> {
|
||||
// Initialize ONNX Runtime first
|
||||
init_ort_sync(&dylib_path)?;
|
||||
|
||||
info!("ONNX model loaded, loading tokenizer...");
|
||||
|
||||
// Load tokenizer
|
||||
let tokenizer = Tokenizer::from_file(model_path.join("tokenizer.json"))
|
||||
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
|
||||
info!("Creating ONNX session...");
|
||||
|
||||
info!("Tokenizer loaded successfully");
|
||||
Ok((session, tokenizer))
|
||||
}).await
|
||||
// Load ONNX model with ort 2.0 API
|
||||
let session = Session::builder()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to create session builder: {:?}", e))?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to set optimization level: {:?}", e))?
|
||||
.with_intra_threads(4)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to set intra threads: {:?}", e))?
|
||||
.commit_from_file(model_path.join("model.onnx"))
|
||||
.map_err(|e| anyhow::anyhow!("Failed to load ONNX model: {:?}", e))?;
|
||||
|
||||
info!("ONNX model loaded, loading tokenizer...");
|
||||
|
||||
// Load tokenizer
|
||||
let tokenizer = Tokenizer::from_file(model_path.join("tokenizer.json"))
|
||||
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
|
||||
|
||||
info!("Tokenizer loaded successfully");
|
||||
Ok((session, tokenizer))
|
||||
})
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Spawn blocking failed: {:?}", e))??;
|
||||
|
||||
info!(
|
||||
@@ -136,12 +141,17 @@ impl EmbeddingEngine {
|
||||
|
||||
/// Generate embedding for a single text
|
||||
pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
|
||||
let encoding = self.tokenizer
|
||||
let encoding = self
|
||||
.tokenizer
|
||||
.encode(text, true)
|
||||
.map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
|
||||
|
||||
let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
|
||||
let attention_mask: Vec<i64> = encoding.get_attention_mask().iter().map(|&x| x as i64).collect();
|
||||
let attention_mask: Vec<i64> = encoding
|
||||
.get_attention_mask()
|
||||
.iter()
|
||||
.map(|&x| x as i64)
|
||||
.collect();
|
||||
let token_type_ids: Vec<i64> = encoding.get_type_ids().iter().map(|&x| x as i64).collect();
|
||||
|
||||
let seq_len = input_ids.len();
|
||||
@@ -157,22 +167,25 @@ impl EmbeddingEngine {
|
||||
"attention_mask" => attention_mask_tensor,
|
||||
"token_type_ids" => token_type_ids_tensor,
|
||||
];
|
||||
|
||||
let mut session_guard = self.session.lock()
|
||||
|
||||
let mut session_guard = self
|
||||
.session
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Session lock poisoned: {}", e))?;
|
||||
let outputs = session_guard.run(inputs)?;
|
||||
|
||||
// Extract output
|
||||
let output = outputs.get("last_hidden_state")
|
||||
let output = outputs
|
||||
.get("last_hidden_state")
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing last_hidden_state output"))?;
|
||||
|
||||
|
||||
// Get the tensor data
|
||||
let (shape, data) = output.try_extract_tensor::<f32>()?;
|
||||
|
||||
|
||||
// Mean pooling over sequence dimension
|
||||
let hidden_size = *shape.last().unwrap_or(&384) as usize;
|
||||
let seq_len = data.len() / hidden_size;
|
||||
|
||||
|
||||
let mut embedding = vec![0.0f32; hidden_size];
|
||||
for i in 0..seq_len {
|
||||
for j in 0..hidden_size {
|
||||
@@ -182,7 +195,7 @@ impl EmbeddingEngine {
|
||||
for val in &mut embedding {
|
||||
*val /= seq_len as f32;
|
||||
}
|
||||
|
||||
|
||||
// L2 normalize
|
||||
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 0.0 {
|
||||
@@ -190,7 +203,7 @@ impl EmbeddingEngine {
|
||||
*val /= norm;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
@@ -208,37 +221,40 @@ impl EmbeddingEngine {
|
||||
/// Extract keywords from text using simple frequency analysis
|
||||
pub fn extract_keywords(text: &str, limit: usize) -> Vec<String> {
|
||||
use std::collections::HashMap;
|
||||
|
||||
|
||||
let stop_words: std::collections::HashSet<&str> = [
|
||||
"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
|
||||
"of", "with", "by", "from", "is", "are", "was", "were", "be", "been",
|
||||
"being", "have", "has", "had", "do", "does", "did", "will", "would",
|
||||
"could", "should", "may", "might", "must", "shall", "can", "this",
|
||||
"that", "these", "those", "i", "you", "he", "she", "it", "we", "they",
|
||||
"what", "which", "who", "whom", "whose", "where", "when", "why", "how",
|
||||
"all", "each", "every", "both", "few", "more", "most", "other", "some",
|
||||
"such", "no", "nor", "not", "only", "own", "same", "so", "than", "too",
|
||||
"very", "just", "also", "now", "here", "there", "then", "once", "if",
|
||||
].iter().cloned().collect();
|
||||
|
||||
"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by",
|
||||
"from", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "do",
|
||||
"does", "did", "will", "would", "could", "should", "may", "might", "must", "shall", "can",
|
||||
"this", "that", "these", "those", "i", "you", "he", "she", "it", "we", "they", "what",
|
||||
"which", "who", "whom", "whose", "where", "when", "why", "how", "all", "each", "every",
|
||||
"both", "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own",
|
||||
"same", "so", "than", "too", "very", "just", "also", "now", "here", "there", "then",
|
||||
"once", "if",
|
||||
]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let mut word_counts: HashMap<String, usize> = HashMap::new();
|
||||
|
||||
|
||||
for word in text.split_whitespace() {
|
||||
let clean: String = word
|
||||
.chars()
|
||||
.filter(|c| c.is_alphanumeric())
|
||||
.collect::<String>()
|
||||
.to_lowercase();
|
||||
|
||||
|
||||
if clean.len() > 2 && !stop_words.contains(clean.as_str()) {
|
||||
*word_counts.entry(clean).or_insert(0) += 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let mut sorted: Vec<_> = word_counts.into_iter().collect();
|
||||
sorted.sort_by(|a, b| b.1.cmp(&a.1));
|
||||
|
||||
sorted.into_iter()
|
||||
|
||||
sorted
|
||||
.into_iter()
|
||||
.take(limit)
|
||||
.map(|(word, _)| word)
|
||||
.collect()
|
||||
|
||||
Reference in New Issue
Block a user