| |
|
|
| use crate::{Error, Result}; |
| use ndarray::{Array, IxDyn}; |
| use std::collections::HashMap; |
| use std::path::{Path, PathBuf}; |
| use std::sync::{Arc, RwLock}; |
|
|
| |
/// Handle to a loaded ONNX model.
///
/// Only the tensor names are stored; they are what
/// [`OnnxSession::input_names`] and [`OnnxSession::output_names`] expose.
#[derive(Debug)]
pub struct OnnxSession {
    /// Names of the tensors the model takes as input.
    input_names: Vec<String>,
    /// Names of the tensors the model produces; one result entry is emitted per name.
    output_names: Vec<String>,
}
|
|
| impl OnnxSession { |
| |
| pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> { |
| let path = path.as_ref(); |
| if !path.exists() { |
| return Err(Error::FileNotFound(path.display().to_string())); |
| } |
|
|
| |
| log::info!("Loading ONNX model from: {}", path.display()); |
|
|
| Ok(Self { |
| input_names: vec!["input".to_string()], |
| output_names: vec!["output".to_string()], |
| }) |
| } |
|
|
| |
| pub fn run( |
| &self, |
| _inputs: HashMap<String, Array<f32, IxDyn>>, |
| ) -> Result<HashMap<String, Array<f32, IxDyn>>> { |
| |
| let mut result = HashMap::new(); |
| for name in &self.output_names { |
| let dummy = Array::zeros(IxDyn(&[1, 1])); |
| result.insert(name.clone(), dummy); |
| } |
| Ok(result) |
| } |
|
|
| |
| pub fn run_i64( |
| &self, |
| _inputs: HashMap<String, Array<i64, IxDyn>>, |
| ) -> Result<HashMap<String, Array<f32, IxDyn>>> { |
| let mut result = HashMap::new(); |
| for name in &self.output_names { |
| let dummy = Array::zeros(IxDyn(&[1, 1])); |
| result.insert(name.clone(), dummy); |
| } |
| Ok(result) |
| } |
|
|
| pub fn input_names(&self) -> &[String] { |
| &self.input_names |
| } |
|
|
| pub fn output_names(&self) -> &[String] { |
| &self.output_names |
| } |
| } |
|
|
| |
| pub struct ModelCache { |
| sessions: RwLock<HashMap<String, Arc<OnnxSession>>>, |
| model_dir: PathBuf, |
| } |
|
|
| impl ModelCache { |
| pub fn new<P: AsRef<Path>>(model_dir: P) -> Self { |
| Self { |
| sessions: RwLock::new(HashMap::new()), |
| model_dir: model_dir.as_ref().to_path_buf(), |
| } |
| } |
|
|
| pub fn get_or_load(&self, name: &str) -> Result<Arc<OnnxSession>> { |
| { |
| let cache = self.sessions.read().unwrap(); |
| if let Some(session) = cache.get(name) { |
| return Ok(Arc::clone(session)); |
| } |
| } |
|
|
| let model_path = self.model_dir.join(format!("{}.onnx", name)); |
| let session = OnnxSession::load(&model_path)?; |
| let session = Arc::new(session); |
|
|
| { |
| let mut cache = self.sessions.write().unwrap(); |
| cache.insert(name.to_string(), Arc::clone(&session)); |
| } |
|
|
| Ok(session) |
| } |
|
|
| pub fn preload(&self, model_names: &[&str]) -> Result<()> { |
| for name in model_names { |
| self.get_or_load(name)?; |
| } |
| Ok(()) |
| } |
|
|
| pub fn clear(&self) { |
| let mut cache = self.sessions.write().unwrap(); |
| cache.clear(); |
| } |
|
|
| pub fn is_cached(&self, name: &str) -> bool { |
| let cache = self.sessions.read().unwrap(); |
| cache.contains_key(name) |
| } |
|
|
| pub fn cached_models(&self) -> Vec<String> { |
| let cache = self.sessions.read().unwrap(); |
| cache.keys().cloned().collect() |
| } |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed cache reports no cached models.
    #[test]
    fn test_model_cache_creation() {
        let cached = ModelCache::new("/tmp/models").cached_models();
        assert!(cached.is_empty());
    }
}
|
|