| |
| |
| |
|
|
| use crate::{Error, Result}; |
| use ndarray::{Array2, IxDyn}; |
| use std::collections::HashMap; |
| use std::path::Path; |
|
|
| use crate::model::OnnxSession; |
| use super::{Vocoder, snake_activation_vec}; |
|
|
| |
/// Hyper-parameters of a BigVGAN vocoder.
///
/// Only `sample_rate`, `num_mels` and `upsample_rates` are read by the fallback
/// synthesiser in this module; the remaining fields describe the network architecture.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BigVGANConfig {
    /// Output sample rate in Hz.
    pub sample_rate: u32,
    /// Number of mel bins read from each spectrogram frame.
    pub num_mels: usize,
    /// Per-stage upsampling strides; their product is the hop length.
    pub upsample_rates: Vec<usize>,
    /// Kernel size of each upsampling stage (presumably one per entry of
    /// `upsample_rates` — TODO confirm against the model definition).
    pub upsample_kernel_sizes: Vec<usize>,
    /// Residual-block kernel sizes (architecture metadata; not read in this module).
    pub resblock_kernel_sizes: Vec<usize>,
    /// Dilations for each residual block, one inner list per kernel size
    /// (architecture metadata; not read in this module).
    pub resblock_dilation_sizes: Vec<Vec<usize>>,
    /// Channel count entering the first upsampling stage
    /// (architecture metadata; not read in this module).
    pub upsample_initial_channel: usize,
    /// Whether the model uses anti-aliased activations
    /// (architecture metadata; not read in this module).
    pub use_anti_alias: bool,
}
|
|
| impl Default for BigVGANConfig { |
| fn default() -> Self { |
| Self { |
| sample_rate: 22050, |
| num_mels: 80, |
| upsample_rates: vec![8, 8, 2, 2], |
| upsample_kernel_sizes: vec![16, 16, 4, 4], |
| resblock_kernel_sizes: vec![3, 7, 11], |
| resblock_dilation_sizes: vec![vec![1, 3, 5], vec![1, 3, 5], vec![1, 3, 5]], |
| upsample_initial_channel: 512, |
| use_anti_alias: true, |
| } |
| } |
| } |
|
|
| impl BigVGANConfig { |
| |
| pub fn total_upsample_factor(&self) -> usize { |
| self.upsample_rates.iter().product() |
| } |
|
|
| |
| pub fn hop_length(&self) -> usize { |
| self.total_upsample_factor() |
| } |
| } |
|
|
| |
| pub struct BigVGAN { |
| session: Option<OnnxSession>, |
| config: BigVGANConfig, |
| } |
|
|
| impl BigVGAN { |
| |
| pub fn load<P: AsRef<Path>>(path: P, config: BigVGANConfig) -> Result<Self> { |
| let session = OnnxSession::load(path)?; |
| Ok(Self { |
| session: Some(session), |
| config, |
| }) |
| } |
|
|
| |
| pub fn new_fallback(config: BigVGANConfig) -> Self { |
| Self { |
| session: None, |
| config, |
| } |
| } |
|
|
| |
| pub fn config(&self) -> &BigVGANConfig { |
| &self.config |
| } |
|
|
| |
| fn synthesize_fallback(&self, mel: &Array2<f32>) -> Result<Vec<f32>> { |
| |
| let num_frames = mel.ncols(); |
| let hop_length = self.config.hop_length(); |
| let frame_size = hop_length * 4; |
|
|
| let output_length = (num_frames - 1) * hop_length + frame_size; |
| let mut output = vec![0.0f32; output_length]; |
| let mut window_sum = vec![0.0f32; output_length]; |
|
|
| |
| let window: Vec<f32> = (0..frame_size) |
| .map(|n| { |
| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * n as f32 / frame_size as f32).cos()) |
| }) |
| .collect(); |
|
|
| |
| for frame_idx in 0..num_frames { |
| let start = frame_idx * hop_length; |
|
|
| |
| let mel_frame: Vec<f32> = (0..self.config.num_mels) |
| .map(|i| mel[[i, frame_idx]]) |
| .collect(); |
|
|
| |
| let frame = self.generate_frame(&mel_frame, frame_size); |
|
|
| |
| for i in 0..frame_size { |
| if start + i < output_length { |
| output[start + i] += frame[i] * window[i]; |
| window_sum[start + i] += window[i] * window[i]; |
| } |
| } |
| } |
|
|
| |
| for i in 0..output_length { |
| if window_sum[i] > 1e-8 { |
| output[i] /= window_sum[i]; |
| } |
| } |
|
|
| |
| let output = snake_activation_vec(&output, 0.3); |
|
|
| Ok(output) |
| } |
|
|
| |
| fn generate_frame(&self, mel: &[f32], frame_size: usize) -> Vec<f32> { |
| use rand::Rng; |
| let mut rng = rand::thread_rng(); |
|
|
| |
| let energy: f32 = mel.iter().map(|x| x.exp()).sum::<f32>() / mel.len() as f32; |
| let energy = energy.sqrt().min(2.0); |
|
|
| |
| let mut frame = vec![0.0f32; frame_size]; |
|
|
| |
| for (freq_idx, &mel_val) in mel.iter().enumerate() { |
| let freq = (freq_idx as f32 / mel.len() as f32) * (self.config.sample_rate as f32 / 2.0); |
| let amplitude = mel_val.exp().min(1.0) * 0.1; |
|
|
| |
| for i in 0..frame_size { |
| let t = i as f32 / self.config.sample_rate as f32; |
| frame[i] += amplitude * (2.0 * std::f32::consts::PI * freq * t).sin(); |
| } |
| } |
|
|
| |
| for i in 0..frame_size { |
| frame[i] += rng.gen_range(-0.1..0.1) * energy * 0.1; |
| } |
|
|
| |
| let max_abs = frame.iter().map(|x| x.abs()).fold(0.0f32, f32::max); |
| if max_abs > 1.0 { |
| for v in frame.iter_mut() { |
| *v /= max_abs; |
| } |
| } |
|
|
| frame |
| } |
|
|
| |
| pub fn post_process(&self, audio: &[f32]) -> Vec<f32> { |
| use crate::audio::{normalize_audio, apply_fade}; |
|
|
| let normalized = normalize_audio(audio); |
|
|
| |
| let fade_samples = (self.config.sample_rate as f32 * 0.01) as usize; |
| apply_fade(&normalized, fade_samples, fade_samples) |
| } |
| } |
|
|
| impl Vocoder for BigVGAN { |
| fn synthesize(&self, mel: &Array2<f32>) -> Result<Vec<f32>> { |
| if let Some(ref session) = self.session { |
| |
| let input = mel.clone().into_shape(IxDyn(&[1, mel.nrows(), mel.ncols()]))?; |
|
|
| let mut inputs = HashMap::new(); |
| inputs.insert("mel".to_string(), input); |
|
|
| let outputs = session.run(inputs)?; |
|
|
| let audio = outputs |
| .get("audio") |
| .ok_or_else(|| Error::Vocoder("Missing audio output".into()))?; |
|
|
| |
| let samples: Vec<f32> = audio.iter().cloned().collect(); |
|
|
| Ok(self.post_process(&samples)) |
| } else { |
| |
| let audio = self.synthesize_fallback(mel)?; |
| Ok(self.post_process(&audio)) |
| } |
| } |
|
|
| fn sample_rate(&self) -> u32 { |
| self.config.sample_rate |
| } |
|
|
| fn hop_length(&self) -> usize { |
| self.config.hop_length() |
| } |
| } |
|
|
| |
| pub fn create_bigvgan_22k() -> BigVGAN { |
| let config = BigVGANConfig { |
| sample_rate: 22050, |
| ..Default::default() |
| }; |
| BigVGAN::new_fallback(config) |
| } |
|
|
| |
| pub fn create_bigvgan_24k() -> BigVGAN { |
| let config = BigVGANConfig { |
| sample_rate: 24000, |
| upsample_rates: vec![12, 10, 2, 2], |
| ..Default::default() |
| }; |
| BigVGAN::new_fallback(config) |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bigvgan_config() {
        let config = BigVGANConfig::default();
        // 8 * 8 * 2 * 2
        assert_eq!(config.total_upsample_factor(), 256);
        assert_eq!(config.hop_length(), 256);
    }

    #[test]
    fn test_bigvgan_24k_hop_length() {
        let vocoder = create_bigvgan_24k();
        assert_eq!(vocoder.sample_rate(), 24000);
        // 12 * 10 * 2 * 2
        assert_eq!(vocoder.hop_length(), 480);
    }

    #[test]
    fn test_bigvgan_fallback() {
        let vocoder = create_bigvgan_22k();
        assert_eq!(vocoder.sample_rate(), 22050);

        // 80 mel bins x 10 frames of zeros.
        let mel = Array2::zeros((80, 10));
        let audio = vocoder
            .synthesize(&mel)
            .expect("fallback synthesis should succeed for an 80-bin mel");
        assert!(!audio.is_empty());
        // (frames - 1) * hop + 4 * hop samples; post-processing keeps the length.
        assert_eq!(audio.len(), 9 * 256 + 4 * 256);
    }

    #[test]
    fn test_generate_frame() {
        let vocoder = create_bigvgan_22k();
        let mel = vec![0.0f32; 80];
        let frame = vocoder.generate_frame(&mel, 256);
        assert_eq!(frame.len(), 256);
    }

    #[test]
    fn test_post_process() {
        let vocoder = create_bigvgan_22k();
        let audio = vec![0.5f32; 1000];
        let processed = vocoder.post_process(&audio);
        assert_eq!(processed.len(), audio.len());
        // The fade-in should attenuate the very first sample.
        assert!(processed[0].abs() < 0.1);
    }
}
|
|