Initial public release

This commit is contained in:
Agent Zero
2026-03-07 13:41:36 -05:00
commit 774982dc5a
22 changed files with 3517 additions and 0 deletions

245
src/embedding.rs Normal file
View File

@@ -0,0 +1,245 @@
//! Embedding engine using local ONNX models
use anyhow::Result;
use ort::session::{Session, builder::GraphOptimizationLevel};
use ort::value::Value;
use std::path::{Path, PathBuf};
use std::sync::Once;
use tokenizers::Tokenizer;
use tracing::info;
use crate::config::EmbeddingConfig;
static ORT_INIT: Once = Once::new();
/// Initialize ONNX Runtime synchronously (called inside spawn_blocking)
fn init_ort_sync(dylib_path: &str) -> Result<()> {
info!("Initializing ONNX Runtime from: {}", dylib_path);
let mut init_error: Option<String> = None;
ORT_INIT.call_once(|| {
info!("ORT_INIT.call_once - starting initialization");
match ort::init_from(dylib_path) {
Ok(builder) => {
info!("ort::init_from succeeded, calling commit()");
let committed = builder.commit();
info!("commit() returned: {}", committed);
if !committed {
init_error = Some("ONNX Runtime commit returned false".to_string());
}
}
Err(e) => {
let err_msg = format!("ONNX Runtime init_from failed: {:?}", e);
info!("{}", err_msg);
init_error = Some(err_msg);
}
}
info!("ORT_INIT.call_once - finished");
});
// Note: init_error won't be set if ORT_INIT was already called
// This is fine - we only initialize once
if let Some(err) = init_error {
return Err(anyhow::anyhow!("{}", err));
}
info!("ONNX Runtime initialization complete");
Ok(())
}
/// Resolve ONNX Runtime dylib path from env var or common local install locations.
fn resolve_ort_dylib_path() -> Result<String> {
if let Ok(path) = std::env::var("ORT_DYLIB_PATH") {
if Path::new(&path).exists() {
return Ok(path);
}
return Err(anyhow::anyhow!(
"ORT_DYLIB_PATH is set but file does not exist: {}",
path
));
}
let candidates = [
"/opt/homebrew/opt/onnxruntime/lib/libonnxruntime.dylib",
"/usr/local/opt/onnxruntime/lib/libonnxruntime.dylib",
];
for candidate in candidates {
if Path::new(candidate).exists() {
return Ok(candidate.to_string());
}
}
Err(anyhow::anyhow!(
"ORT_DYLIB_PATH environment variable not set and ONNX Runtime dylib not found. \
Set ORT_DYLIB_PATH to your libonnxruntime.dylib path (for example: /opt/homebrew/opt/onnxruntime/lib/libonnxruntime.dylib)."
))
}
pub struct EmbeddingEngine {
session: std::sync::Mutex<Session>,
tokenizer: Tokenizer,
dimension: usize,
}
impl EmbeddingEngine {
/// Create a new embedding engine
pub async fn new(config: &EmbeddingConfig) -> Result<Self> {
let dylib_path = resolve_ort_dylib_path()?;
let model_path = PathBuf::from(&config.model_path);
let dimension = config.dimension;
info!("Loading ONNX model from {:?}", model_path.join("model.onnx"));
// Use spawn_blocking to avoid blocking the async runtime
let (session, tokenizer) = tokio::task::spawn_blocking(move || -> Result<(Session, Tokenizer)> {
// Initialize ONNX Runtime first
init_ort_sync(&dylib_path)?;
info!("Creating ONNX session...");
// Load ONNX model with ort 2.0 API
let session = Session::builder()
.map_err(|e| anyhow::anyhow!("Failed to create session builder: {:?}", e))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| anyhow::anyhow!("Failed to set optimization level: {:?}", e))?
.with_intra_threads(4)
.map_err(|e| anyhow::anyhow!("Failed to set intra threads: {:?}", e))?
.commit_from_file(model_path.join("model.onnx"))
.map_err(|e| anyhow::anyhow!("Failed to load ONNX model: {:?}", e))?;
info!("ONNX model loaded, loading tokenizer...");
// Load tokenizer
let tokenizer = Tokenizer::from_file(model_path.join("tokenizer.json"))
.map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
info!("Tokenizer loaded successfully");
Ok((session, tokenizer))
}).await
.map_err(|e| anyhow::anyhow!("Spawn blocking failed: {:?}", e))??;
info!(
"Embedding engine initialized: model={}, dimension={}",
config.model_path, dimension
);
Ok(Self {
session: std::sync::Mutex::new(session),
tokenizer,
dimension,
})
}
/// Generate embedding for a single text
pub fn embed(&self, text: &str) -> Result<Vec<f32>> {
let encoding = self.tokenizer
.encode(text, true)
.map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
let attention_mask: Vec<i64> = encoding.get_attention_mask().iter().map(|&x| x as i64).collect();
let token_type_ids: Vec<i64> = encoding.get_type_ids().iter().map(|&x| x as i64).collect();
let seq_len = input_ids.len();
// Create input tensors with ort 2.0 API
let input_ids_tensor = Value::from_array(([1, seq_len], input_ids))?;
let attention_mask_tensor = Value::from_array(([1, seq_len], attention_mask))?;
let token_type_ids_tensor = Value::from_array(([1, seq_len], token_type_ids))?;
// Run inference
let inputs = ort::inputs![
"input_ids" => input_ids_tensor,
"attention_mask" => attention_mask_tensor,
"token_type_ids" => token_type_ids_tensor,
];
let mut session_guard = self.session.lock()
.map_err(|e| anyhow::anyhow!("Session lock poisoned: {}", e))?;
let outputs = session_guard.run(inputs)?;
// Extract output
let output = outputs.get("last_hidden_state")
.ok_or_else(|| anyhow::anyhow!("Missing last_hidden_state output"))?;
// Get the tensor data
let (shape, data) = output.try_extract_tensor::<f32>()?;
// Mean pooling over sequence dimension
let hidden_size = *shape.last().unwrap_or(&384) as usize;
let seq_len = data.len() / hidden_size;
let mut embedding = vec![0.0f32; hidden_size];
for i in 0..seq_len {
for j in 0..hidden_size {
embedding[j] += data[i * hidden_size + j];
}
}
for val in &mut embedding {
*val /= seq_len as f32;
}
// L2 normalize
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for val in &mut embedding {
*val /= norm;
}
}
Ok(embedding)
}
/// Generate embeddings for multiple texts
pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
texts.iter().map(|text| self.embed(text)).collect()
}
/// Get the embedding dimension
pub fn dimension(&self) -> usize {
self.dimension
}
}
/// Extract keywords from text using simple frequency analysis
pub fn extract_keywords(text: &str, limit: usize) -> Vec<String> {
use std::collections::HashMap;
let stop_words: std::collections::HashSet<&str> = [
"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
"of", "with", "by", "from", "is", "are", "was", "were", "be", "been",
"being", "have", "has", "had", "do", "does", "did", "will", "would",
"could", "should", "may", "might", "must", "shall", "can", "this",
"that", "these", "those", "i", "you", "he", "she", "it", "we", "they",
"what", "which", "who", "whom", "whose", "where", "when", "why", "how",
"all", "each", "every", "both", "few", "more", "most", "other", "some",
"such", "no", "nor", "not", "only", "own", "same", "so", "than", "too",
"very", "just", "also", "now", "here", "there", "then", "once", "if",
].iter().cloned().collect();
let mut word_counts: HashMap<String, usize> = HashMap::new();
for word in text.split_whitespace() {
let clean: String = word
.chars()
.filter(|c| c.is_alphanumeric())
.collect::<String>()
.to_lowercase();
if clean.len() > 2 && !stop_words.contains(clean.as_str()) {
*word_counts.entry(clean).or_insert(0) += 1;
}
}
let mut sorted: Vec<_> = word_counts.into_iter().collect();
sorted.sort_by(|a, b| b.1.cmp(&a.1));
sorted.into_iter()
.take(limit)
.map(|(word, _)| word)
.collect()
}