//! GPT-style autoregressive token generation.
//!
//! [`GptModel`] runs an ONNX session; [`SimpleGptModel`] is a tiny pure-Rust
//! stand-in used mainly for tests.
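//!
//! Illustrative usage of the ONNX-backed path (the file name, the speaker
//! embedding length, and the sampling settings below are placeholders and
//! depend on the exported model):
//!
//! ```ignore
//! use ndarray::Array1;
//!
//! let model = GptModel::load("gpt.onnx", GptConfig::default())?;
//! // Length must match what the ONNX graph expects for `speaker_embedding`.
//! let speaker = Array1::<f32>::zeros(256);
//! let semantic_tokens: Vec<i64> = vec![12, 57, 103];
//! let tokens = model.generate(
//!     &semantic_tokens,
//!     &speaker,
//!     model.config().max_seq_len,
//!     &SamplingStrategy::Greedy,
//!     1.1, // repetition penalty
//! )?;
//! ```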
|
|
use crate::{Error, Result};
use ndarray::{Array, Array1, Array2, IxDyn};
use std::collections::HashMap;
use std::path::Path;

use super::{OnnxSession, SamplingStrategy, sample_from_logits, apply_repetition_penalty};

/// Configuration for the GPT model.
#[derive(Debug, Clone)]
pub struct GptConfig {
    /// Number of transformer layers.
    pub num_layers: usize,
    /// Hidden (model) dimension.
    pub hidden_size: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Maximum sequence length (size of the position table).
    pub max_seq_len: usize,
    /// Vocabulary size, including the start and stop tokens.
    pub vocab_size: usize,
    /// Token id that ends generation.
    pub stop_token: usize,
    /// Token id that begins generation.
    pub start_token: usize,
}
|
|
impl Default for GptConfig {
    fn default() -> Self {
        Self {
            num_layers: 8,
            hidden_size: 512,
            num_heads: 8,
            max_seq_len: 250,
            vocab_size: 8194,
            stop_token: 8193,
            start_token: 8192,
        }
    }
}
|
|
/// GPT model backed by an ONNX session.
pub struct GptModel {
    session: OnnxSession,
    config: GptConfig,
}

impl GptModel {
    /// Loads the model from an ONNX file together with its configuration.
    pub fn load<P: AsRef<Path>>(path: P, config: GptConfig) -> Result<Self> {
        let session = OnnxSession::load(path)?;
        Ok(Self { session, config })
    }
|
|
    /// Autoregressively generates tokens conditioned on semantic tokens and a
    /// speaker embedding.
    ///
    /// The returned sequence starts with the configured start token. Generation
    /// stops at the stop token, after `max_length` steps, or once the sequence
    /// reaches `max_seq_len`.
    pub fn generate(
        &self,
        semantic_tokens: &[i64],
        speaker_embedding: &Array1<f32>,
        max_length: usize,
        strategy: &SamplingStrategy,
        repetition_penalty: f32,
    ) -> Result<Vec<i64>> {
        let mut generated_tokens = vec![self.config.start_token as i64];
        let mut past_tokens = Vec::new();

        for _ in 0..max_length {
            // Never run past the model's maximum context length.
            if generated_tokens.len() >= self.config.max_seq_len {
                break;
            }

            // Re-encode the full prefix every step (no KV cache).
            let input_tokens = Array::from_shape_vec(
                IxDyn(&[1, generated_tokens.len()]),
                generated_tokens.clone(),
            )?;

            let speaker_emb = speaker_embedding
                .clone()
                .into_shape(IxDyn(&[1, speaker_embedding.len()]))?;

            let semantic_input = Array::from_shape_vec(
                IxDyn(&[1, semantic_tokens.len()]),
                semantic_tokens.to_vec(),
            )?;

            // Assemble the named inputs; token ids are cast to f32.
            let mut inputs = HashMap::new();
            inputs.insert("input_ids".to_string(), input_tokens.mapv(|x| x as f32));
            inputs.insert("speaker_embedding".to_string(), speaker_emb);
            inputs.insert("semantic_tokens".to_string(), semantic_input.mapv(|x| x as f32));

            // Run a single forward pass.
            let outputs = self.session.run(inputs)?;

            // Fetch the logits tensor, shaped [batch, seq_len, vocab].
            let logits = outputs
                .get("logits")
                .ok_or_else(|| Error::Model("Missing logits output".into()))?;

            // Keep only the logits for the last position.
            let seq_len = logits.shape()[1];
            let vocab_size = logits.shape()[2];
            let last_logits: Vec<f32> = (0..vocab_size)
                .map(|i| logits[[0, seq_len - 1, i]])
                .collect();

            // Penalize tokens that have already been generated.
            let mut logits_vec = last_logits;
            let past_usize: Vec<usize> = past_tokens.iter().map(|&x| x as usize).collect();
            apply_repetition_penalty(&mut logits_vec, &past_usize, repetition_penalty);

            // Sample the next token.
            let next_token = sample_from_logits(&logits_vec, strategy) as i64;

            // Stop on the end-of-sequence token.
            if next_token == self.config.stop_token as i64 {
                break;
            }

            generated_tokens.push(next_token);
            past_tokens.push(next_token);
        }

        Ok(generated_tokens)
    }
|
|
    /// Variant of [`generate`](Self::generate) meant to reuse cached key/value
    /// states between steps.
    ///
    /// A KV cache is not wired up yet, so this currently falls back to the
    /// uncached [`generate`](Self::generate) path.
    pub fn generate_with_cache(
        &self,
        semantic_tokens: &[i64],
        speaker_embedding: &Array1<f32>,
        max_length: usize,
        strategy: &SamplingStrategy,
        repetition_penalty: f32,
    ) -> Result<Vec<i64>> {
        self.generate(
            semantic_tokens,
            speaker_embedding,
            max_length,
            strategy,
            repetition_penalty,
        )
    }
|
|
    /// Returns the model configuration.
    pub fn config(&self) -> &GptConfig {
        &self.config
    }

    /// Rough estimate of the weight memory in megabytes.
    ///
    /// Each layer is approximated as four `hidden_size x hidden_size` f32
    /// matrices; embeddings and biases are ignored. For the default
    /// configuration this comes to roughly 33.5 MB.
    pub fn estimate_memory_mb(&self) -> f32 {
        let params = self.config.num_layers
            * self.config.hidden_size
            * self.config.hidden_size
            * 4;
        (params * 4) as f32 / 1_000_000.0
    }
}
|
|
/// Minimal CPU-only GPT-like model with randomly initialized weights.
///
/// This is not a real transformer: [`forward`](Self::forward) pools token and
/// position embeddings into a single vector and projects it to vocabulary
/// logits. It serves as a lightweight stand-in, e.g. in tests.
pub struct SimpleGptModel {
    config: GptConfig,
    /// Token embedding table, `vocab_size x hidden_size`.
    token_embeddings: Array2<f32>,
    /// Position embedding table, `max_seq_len x hidden_size`.
    position_embeddings: Array2<f32>,
    /// Output projection, `hidden_size x vocab_size`.
    output_projection: Array2<f32>,
}
|
|
impl SimpleGptModel {
    /// Creates a model with weights drawn uniformly from `[-0.1, 0.1)`.
    pub fn new_random(config: GptConfig) -> Self {
        use rand::Rng;
        let mut rng = rand::thread_rng();

        let token_embeddings = Array2::from_shape_fn(
            (config.vocab_size, config.hidden_size),
            |_| rng.gen_range(-0.1..0.1),
        );

        let position_embeddings = Array2::from_shape_fn(
            (config.max_seq_len, config.hidden_size),
            |_| rng.gen_range(-0.1..0.1),
        );

        let output_projection = Array2::from_shape_fn(
            (config.hidden_size, config.vocab_size),
            |_| rng.gen_range(-0.1..0.1),
        );

        Self {
            config,
            token_embeddings,
            position_embeddings,
            output_projection,
        }
    }
|
|
    /// Computes vocabulary logits for a token sequence.
    ///
    /// Token and position embeddings are summed into one pooled hidden vector,
    /// L2-normalized, and projected to the vocabulary.
    pub fn forward(&self, tokens: &[i64]) -> Vec<f32> {
        let mut hidden = vec![0.0f32; self.config.hidden_size];

        // Pool token + position embeddings over the (truncated) sequence.
        for (pos, &token) in tokens.iter().enumerate().take(self.config.max_seq_len) {
            let token_idx = (token as usize).min(self.config.vocab_size - 1);

            for i in 0..self.config.hidden_size {
                hidden[i] += self.token_embeddings[[token_idx, i]]
                    + self.position_embeddings[[pos, i]];
            }
        }

        // L2-normalize the pooled hidden state.
        let norm: f32 = hidden.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-8 {
            for h in hidden.iter_mut() {
                *h /= norm;
            }
        }

        // Project to vocabulary logits.
        let mut logits = vec![0.0f32; self.config.vocab_size];
        for (i, logit) in logits.iter_mut().enumerate() {
            for j in 0..self.config.hidden_size {
                *logit += hidden[j] * self.output_projection[[j, i]];
            }
        }

        logits
    }
|
|
    /// Autoregressive generation starting from `prompt` with the given
    /// sampling strategy.
    ///
    /// Stops at the stop token, after `max_length` new tokens, or once the
    /// sequence reaches `max_seq_len`.
    pub fn generate(
        &self,
        prompt: &[i64],
        max_length: usize,
        strategy: &SamplingStrategy,
    ) -> Vec<i64> {
        let mut tokens = prompt.to_vec();

        for _ in 0..max_length {
            let logits = self.forward(&tokens);
            let next_token = sample_from_logits(&logits, strategy) as i64;

            if next_token == self.config.stop_token as i64 {
                break;
            }

            tokens.push(next_token);

            if tokens.len() >= self.config.max_seq_len {
                break;
            }
        }

        tokens
    }
}
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpt_config_default() {
        let config = GptConfig::default();
        assert_eq!(config.num_layers, 8);
        assert_eq!(config.hidden_size, 512);
    }
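
    // The special tokens must be valid indices into the vocabulary.
    #[test]
    fn test_default_special_tokens_in_vocab() {
        let config = GptConfig::default();
        assert!(config.start_token < config.vocab_size);
        assert!(config.stop_token < config.vocab_size);
        assert_ne!(config.start_token, config.stop_token);
    }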
|
|
    #[test]
    fn test_simple_gpt_forward() {
        let config = GptConfig {
            vocab_size: 100,
            hidden_size: 32,
            max_seq_len: 10,
            ..Default::default()
        };

        let model = SimpleGptModel::new_random(config);
        let tokens = vec![1i64, 2, 3];
        let logits = model.forward(&tokens);

        assert_eq!(logits.len(), 100);
    }
|
|
    #[test]
    fn test_simple_gpt_generate() {
        let config = GptConfig {
            vocab_size: 100,
            hidden_size: 32,
            max_seq_len: 20,
            stop_token: 99,
            ..Default::default()
        };

        let model = SimpleGptModel::new_random(config);
        let prompt = vec![1i64, 2, 3];
        let generated = model.generate(&prompt, 10, &SamplingStrategy::Greedy);

        assert!(generated.len() >= 3);
        assert!(generated.len() <= 20);
    }
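
    // Generation should preserve the prompt prefix and never grow past
    // max_seq_len, even when max_length would allow more tokens. The stop
    // token is left at its (out-of-vocabulary) default so it cannot fire.
    #[test]
    fn test_simple_gpt_generate_caps_at_max_seq_len() {
        let config = GptConfig {
            vocab_size: 100,
            hidden_size: 16,
            max_seq_len: 8,
            ..Default::default()
        };

        let model = SimpleGptModel::new_random(config);
        let prompt = vec![1i64, 2, 3];
        let generated = model.generate(&prompt, 50, &SamplingStrategy::Greedy);

        assert_eq!(&generated[..3], &[1i64, 2, 3]);
        assert!(generated.len() <= 8);
    }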
}
|
|