| |
|
|
| use crate::{Error, Result}; |
| use serde::{Deserialize, Serialize}; |
| use std::path::{Path, PathBuf}; |
|
|
| |
/// Top-level configuration aggregating every component of the pipeline.
///
/// Serializable to/from YAML and JSON (see `Config::load` / `Config::save`
/// / `Config::load_json`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Autoregressive GPT model settings (layers, dims, token ids).
    pub gpt: GptConfig,
    /// Vocoder settings (mel-spectrogram to waveform stage).
    pub vocoder: VocoderConfig,
    /// Semantic-to-mel model settings, including audio preprocessing.
    pub s2mel: S2MelConfig,
    /// Tokenizer/dataset settings (BPE model, vocabulary size).
    pub dataset: DatasetConfig,
    /// Emotion-conditioning settings.
    pub emotions: EmotionConfig,
    /// Runtime inference settings (device, sampling parameters).
    pub inference: InferenceConfig,
    /// Base directory where model files are expected to live.
    pub model_dir: PathBuf,
}
|
|
| |
/// Configuration of the autoregressive GPT model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GptConfig {
    /// Number of transformer layers.
    pub layers: usize,
    /// Model (embedding) dimension; must be divisible by `heads`
    /// (enforced by `Config::validate`).
    pub model_dim: usize,
    /// Number of attention heads.
    pub heads: usize,
    /// Maximum number of text tokens per input sequence.
    pub max_text_tokens: usize,
    /// Maximum number of mel tokens generated per sequence.
    pub max_mel_tokens: usize,
    /// Token id that terminates mel generation.
    pub stop_mel_token: usize,
    /// Token id prepended to the text token stream.
    pub start_text_token: usize,
    /// Token id prepended to the mel token stream.
    pub start_mel_token: usize,
    /// Size of the mel codebook (number of distinct mel codes,
    /// including special tokens — presumably; verify against the model).
    pub num_mel_codes: usize,
    /// Size of the text token vocabulary (matches
    /// `DatasetConfig::vocab_size` in the defaults).
    pub num_text_tokens: usize,
}
|
|
| |
/// Configuration of the vocoder stage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VocoderConfig {
    /// Vocoder model identifier (e.g. a BigVGAN variant name).
    pub name: String,
    /// Optional explicit checkpoint path; `None` means the default
    /// location/weights for `name` are used (resolved elsewhere).
    pub checkpoint: Option<PathBuf>,
    /// Whether to run the vocoder in half precision.
    pub use_fp16: bool,
    /// Whether to use DeepSpeed acceleration.
    pub use_deepspeed: bool,
}
|
|
| |
/// Configuration of the semantic-to-mel model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct S2MelConfig {
    /// Path to the model checkpoint (ONNX by default).
    pub checkpoint: PathBuf,
    /// Audio/spectrogram preprocessing parameters for this model.
    pub preprocess: PreprocessConfig,
}
|
|
| |
/// Audio preprocessing / mel-spectrogram extraction parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreprocessConfig {
    /// Sample rate in Hz.
    pub sr: u32,
    /// FFT size in samples.
    pub n_fft: usize,
    /// Hop length between STFT frames, in samples.
    pub hop_length: usize,
    /// Analysis window length in samples.
    pub win_length: usize,
    /// Number of mel filterbank bands.
    pub n_mels: usize,
    /// Lower frequency bound of the mel filterbank, in Hz.
    pub fmin: f32,
    /// Upper frequency bound of the mel filterbank, in Hz.
    pub fmax: f32,
}
|
|
| |
/// Tokenizer / dataset configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetConfig {
    /// Path to the BPE tokenizer model file.
    pub bpe_model: PathBuf,
    /// Tokenizer vocabulary size (matches `GptConfig::num_text_tokens`
    /// in the defaults).
    pub vocab_size: usize,
}
|
|
| |
/// Emotion-conditioning configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmotionConfig {
    /// Number of emotion dimensions/categories.
    pub num_dims: usize,
    /// Per-dimension counts — presumably the number of discrete levels
    /// per emotion dimension (the default has `num_dims` entries);
    /// TODO confirm against the consumer of this field.
    pub num: Vec<usize>,
    /// Optional path to a precomputed emotion embedding matrix.
    pub matrix_path: Option<PathBuf>,
}
|
|
| |
/// Runtime inference settings, including autoregressive sampling parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceConfig {
    /// Compute device identifier (e.g. "cpu").
    pub device: String,
    /// Whether to run inference in half precision.
    pub use_fp16: bool,
    /// Batch size for inference.
    pub batch_size: usize,
    /// Top-k sampling cutoff.
    pub top_k: usize,
    /// Nucleus (top-p) sampling threshold; must lie in (0, 1]
    /// (enforced by `Config::validate`).
    pub top_p: f32,
    /// Softmax temperature; must be > 0 (enforced by `Config::validate`).
    pub temperature: f32,
    /// Penalty applied to already-generated tokens.
    pub repetition_penalty: f32,
    /// Penalty/bonus applied to sequence length during decoding.
    pub length_penalty: f32,
}
|
|
| impl Default for Config { |
| fn default() -> Self { |
| Self { |
| gpt: GptConfig::default(), |
| vocoder: VocoderConfig::default(), |
| s2mel: S2MelConfig::default(), |
| dataset: DatasetConfig::default(), |
| emotions: EmotionConfig::default(), |
| inference: InferenceConfig::default(), |
| model_dir: PathBuf::from("models"), |
| } |
| } |
| } |
|
|
| impl Default for GptConfig { |
| fn default() -> Self { |
| Self { |
| layers: 8, |
| model_dim: 512, |
| heads: 8, |
| max_text_tokens: 120, |
| max_mel_tokens: 250, |
| stop_mel_token: 8193, |
| start_text_token: 8192, |
| start_mel_token: 8192, |
| num_mel_codes: 8194, |
| num_text_tokens: 6681, |
| } |
| } |
| } |
|
|
| impl Default for VocoderConfig { |
| fn default() -> Self { |
| Self { |
| name: "bigvgan_v2_22khz_80band_256x".into(), |
| checkpoint: None, |
| use_fp16: true, |
| use_deepspeed: false, |
| } |
| } |
| } |
|
|
| impl Default for S2MelConfig { |
| fn default() -> Self { |
| Self { |
| checkpoint: PathBuf::from("models/s2mel.onnx"), |
| preprocess: PreprocessConfig::default(), |
| } |
| } |
| } |
|
|
| impl Default for PreprocessConfig { |
| fn default() -> Self { |
| Self { |
| sr: 22050, |
| n_fft: 1024, |
| hop_length: 256, |
| win_length: 1024, |
| n_mels: 80, |
| fmin: 0.0, |
| fmax: 8000.0, |
| } |
| } |
| } |
|
|
| impl Default for DatasetConfig { |
| fn default() -> Self { |
| Self { |
| bpe_model: PathBuf::from("models/bpe.model"), |
| vocab_size: 6681, |
| } |
| } |
| } |
|
|
| impl Default for EmotionConfig { |
| fn default() -> Self { |
| Self { |
| num_dims: 8, |
| num: vec![5, 6, 8, 6, 5, 4, 7, 6], |
| matrix_path: Some(PathBuf::from("models/emotion_matrix.safetensors")), |
| } |
| } |
| } |
|
|
| impl Default for InferenceConfig { |
| fn default() -> Self { |
| Self { |
| device: "cpu".into(), |
| use_fp16: false, |
| batch_size: 1, |
| top_k: 50, |
| top_p: 0.95, |
| temperature: 1.0, |
| repetition_penalty: 1.0, |
| length_penalty: 1.0, |
| } |
| } |
| } |
|
|
| impl Config { |
| |
| pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> { |
| let path = path.as_ref(); |
| if !path.exists() { |
| return Err(Error::FileNotFound(path.display().to_string())); |
| } |
|
|
| let content = std::fs::read_to_string(path)?; |
| let config: Config = serde_yaml::from_str(&content)?; |
| Ok(config) |
| } |
|
|
| |
| pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> { |
| let content = serde_yaml::to_string(self) |
| .map_err(|e| Error::Config(format!("Failed to serialize config: {}", e)))?; |
| std::fs::write(path, content)?; |
| Ok(()) |
| } |
|
|
| |
| pub fn load_json<P: AsRef<Path>>(path: P) -> Result<Self> { |
| let path = path.as_ref(); |
| if !path.exists() { |
| return Err(Error::FileNotFound(path.display().to_string())); |
| } |
|
|
| let content = std::fs::read_to_string(path)?; |
| let config: Config = serde_json::from_str(&content)?; |
| Ok(config) |
| } |
|
|
| |
| pub fn create_default<P: AsRef<Path>>(path: P) -> Result<Self> { |
| let config = Config::default(); |
| config.save(path)?; |
| Ok(config) |
| } |
|
|
| |
| pub fn validate(&self) -> Result<()> { |
| |
| if !self.model_dir.exists() { |
| log::warn!( |
| "Model directory does not exist: {}", |
| self.model_dir.display() |
| ); |
| } |
|
|
| |
| if self.gpt.layers == 0 { |
| return Err(Error::Config("GPT layers must be > 0".into())); |
| } |
| if self.gpt.model_dim == 0 { |
| return Err(Error::Config("GPT model_dim must be > 0".into())); |
| } |
| if self.gpt.heads == 0 { |
| return Err(Error::Config("GPT heads must be > 0".into())); |
| } |
| if !self.gpt.model_dim.is_multiple_of(self.gpt.heads) { |
| return Err(Error::Config( |
| "GPT model_dim must be divisible by heads".into(), |
| )); |
| } |
|
|
| |
| if self.s2mel.preprocess.sr == 0 { |
| return Err(Error::Config("Sample rate must be > 0".into())); |
| } |
| if self.s2mel.preprocess.n_fft == 0 { |
| return Err(Error::Config("n_fft must be > 0".into())); |
| } |
| if self.s2mel.preprocess.hop_length == 0 { |
| return Err(Error::Config("hop_length must be > 0".into())); |
| } |
|
|
| |
| if self.inference.temperature <= 0.0 { |
| return Err(Error::Config("Temperature must be > 0".into())); |
| } |
| if self.inference.top_p <= 0.0 || self.inference.top_p > 1.0 { |
| return Err(Error::Config("top_p must be in (0, 1]".into())); |
| } |
|
|
| Ok(()) |
| } |
| } |
|
|