//! Embedding engine using local ONNX models use anyhow::Result; use ort::session::{builder::GraphOptimizationLevel, Session}; use ort::value::Value; use std::path::{Path, PathBuf}; use std::sync::Once; use tokenizers::Tokenizer; use tracing::info; use crate::config::EmbeddingConfig; static ORT_INIT: Once = Once::new(); /// Initialize ONNX Runtime synchronously (called inside spawn_blocking) fn init_ort_sync(dylib_path: &str) -> Result<()> { info!("Initializing ONNX Runtime from: {}", dylib_path); let mut init_error: Option = None; ORT_INIT.call_once(|| { info!("ORT_INIT.call_once - starting initialization"); match ort::init_from(dylib_path) { Ok(builder) => { info!("ort::init_from succeeded, calling commit()"); let committed = builder.commit(); info!("commit() returned: {}", committed); if !committed { init_error = Some("ONNX Runtime commit returned false".to_string()); } } Err(e) => { let err_msg = format!("ONNX Runtime init_from failed: {:?}", e); info!("{}", err_msg); init_error = Some(err_msg); } } info!("ORT_INIT.call_once - finished"); }); // Note: init_error won't be set if ORT_INIT was already called // This is fine - we only initialize once if let Some(err) = init_error { return Err(anyhow::anyhow!("{}", err)); } info!("ONNX Runtime initialization complete"); Ok(()) } /// Resolve ONNX Runtime dylib path from env var or common local install locations. fn resolve_ort_dylib_path() -> Result { if let Ok(path) = std::env::var("ORT_DYLIB_PATH") { if Path::new(&path).exists() { return Ok(path); } return Err(anyhow::anyhow!( "ORT_DYLIB_PATH is set but file does not exist: {}", path )); } let candidates = [ "/opt/homebrew/opt/onnxruntime/lib/libonnxruntime.dylib", "/usr/local/opt/onnxruntime/lib/libonnxruntime.dylib", ]; for candidate in candidates { if Path::new(candidate).exists() { return Ok(candidate.to_string()); } } Err(anyhow::anyhow!( "ORT_DYLIB_PATH environment variable not set and ONNX Runtime dylib not found. \ Set ORT_DYLIB_PATH to your libonnxruntime.dylib path (for example: /opt/homebrew/opt/onnxruntime/lib/libonnxruntime.dylib)." )) } pub struct EmbeddingEngine { session: std::sync::Mutex, tokenizer: Tokenizer, dimension: usize, } impl EmbeddingEngine { /// Create a new embedding engine pub async fn new(config: &EmbeddingConfig) -> Result { let dylib_path = resolve_ort_dylib_path()?; let model_path = PathBuf::from(&config.model_path); let dimension = config.dimension; info!( "Loading ONNX model from {:?}", model_path.join("model.onnx") ); // Use spawn_blocking to avoid blocking the async runtime let (session, tokenizer) = tokio::task::spawn_blocking(move || -> Result<(Session, Tokenizer)> { // Initialize ONNX Runtime first init_ort_sync(&dylib_path)?; info!("Creating ONNX session..."); // Load ONNX model with ort 2.0 API let session = Session::builder() .map_err(|e| anyhow::anyhow!("Failed to create session builder: {:?}", e))? .with_optimization_level(GraphOptimizationLevel::Level3) .map_err(|e| anyhow::anyhow!("Failed to set optimization level: {:?}", e))? .with_intra_threads(4) .map_err(|e| anyhow::anyhow!("Failed to set intra threads: {:?}", e))? .commit_from_file(model_path.join("model.onnx")) .map_err(|e| anyhow::anyhow!("Failed to load ONNX model: {:?}", e))?; info!("ONNX model loaded, loading tokenizer..."); // Load tokenizer let tokenizer = Tokenizer::from_file(model_path.join("tokenizer.json")) .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?; info!("Tokenizer loaded successfully"); Ok((session, tokenizer)) }) .await .map_err(|e| anyhow::anyhow!("Spawn blocking failed: {:?}", e))??; info!( "Embedding engine initialized: model={}, dimension={}", config.model_path, dimension ); Ok(Self { session: std::sync::Mutex::new(session), tokenizer, dimension, }) } /// Generate embedding for a single text pub fn embed(&self, text: &str) -> Result> { let encoding = self .tokenizer .encode(text, true) .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?; let input_ids: Vec = encoding.get_ids().iter().map(|&x| x as i64).collect(); let attention_mask: Vec = encoding .get_attention_mask() .iter() .map(|&x| x as i64) .collect(); let token_type_ids: Vec = encoding.get_type_ids().iter().map(|&x| x as i64).collect(); let seq_len = input_ids.len(); // Create input tensors with ort 2.0 API let input_ids_tensor = Value::from_array(([1, seq_len], input_ids))?; let attention_mask_tensor = Value::from_array(([1, seq_len], attention_mask))?; let token_type_ids_tensor = Value::from_array(([1, seq_len], token_type_ids))?; // Run inference let inputs = ort::inputs![ "input_ids" => input_ids_tensor, "attention_mask" => attention_mask_tensor, "token_type_ids" => token_type_ids_tensor, ]; let mut session_guard = self .session .lock() .map_err(|e| anyhow::anyhow!("Session lock poisoned: {}", e))?; let outputs = session_guard.run(inputs)?; // Extract output let output = outputs .get("last_hidden_state") .ok_or_else(|| anyhow::anyhow!("Missing last_hidden_state output"))?; // Get the tensor data let (shape, data) = output.try_extract_tensor::()?; // Mean pooling over sequence dimension let hidden_size = *shape.last().unwrap_or(&384) as usize; let seq_len = data.len() / hidden_size; let mut embedding = vec![0.0f32; hidden_size]; for i in 0..seq_len { for j in 0..hidden_size { embedding[j] += data[i * hidden_size + j]; } } for val in &mut embedding { *val /= seq_len as f32; } // L2 normalize let norm: f32 = embedding.iter().map(|x| x * x).sum::().sqrt(); if norm > 0.0 { for val in &mut embedding { *val /= norm; } } Ok(embedding) } /// Generate embeddings for multiple texts pub fn embed_batch(&self, texts: &[&str]) -> Result>> { texts.iter().map(|text| self.embed(text)).collect() } /// Get the embedding dimension pub fn dimension(&self) -> usize { self.dimension } } /// Extract keywords from text using simple frequency analysis pub fn extract_keywords(text: &str, limit: usize) -> Vec { use std::collections::HashMap; let stop_words: std::collections::HashSet<&str> = [ "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must", "shall", "can", "this", "that", "these", "those", "i", "you", "he", "she", "it", "we", "they", "what", "which", "who", "whom", "whose", "where", "when", "why", "how", "all", "each", "every", "both", "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "just", "also", "now", "here", "there", "then", "once", "if", ] .iter() .cloned() .collect(); let mut word_counts: HashMap = HashMap::new(); for word in text.split_whitespace() { let clean: String = word .chars() .filter(|c| c.is_alphanumeric()) .collect::() .to_lowercase(); if clean.len() > 2 && !stop_words.contains(clean.as_str()) { *word_counts.entry(clean).or_insert(0) += 1; } } let mut sorted: Vec<_> = word_counts.into_iter().collect(); sorted.sort_by(|a, b| b.1.cmp(&a.1)); sorted .into_iter() .take(limit) .map(|(word, _)| word) .collect() }