| |
| |
| |
|
|
| use crate::{Error, Result}; |
| use ndarray::{Array1, Array2, Axis}; |
| use num_complex::Complex; |
| use realfft::RealFftPlanner; |
| use std::f32::consts::PI; |
|
|
| use super::AudioConfig; |
|
|
| |
| #[derive(Debug, Clone)] |
| pub struct MelFilterbank { |
| |
| pub filters: Array2<f32>, |
| |
| pub sample_rate: u32, |
| |
| pub n_mels: usize, |
| |
| pub n_fft: usize, |
| } |
|
|
| impl MelFilterbank { |
| |
| pub fn new(sample_rate: u32, n_fft: usize, n_mels: usize, fmin: f32, fmax: f32) -> Self { |
| let filters = create_mel_filterbank(sample_rate, n_fft, n_mels, fmin, fmax); |
| Self { |
| filters, |
| sample_rate, |
| n_mels, |
| n_fft, |
| } |
| } |
|
|
| |
| pub fn apply(&self, spectrogram: &Array2<f32>) -> Array2<f32> { |
| |
| |
| |
| self.filters.dot(spectrogram) |
| } |
| } |
|
|
| |
/// Converts a frequency in Hz to mels using `2595 * log10(1 + hz / 700)`.
pub fn hz_to_mel(hz: f32) -> f32 {
    let ratio = 1.0 + hz / 700.0;
    2595.0 * ratio.log10()
}
|
|
| |
/// Converts a mel value back to Hz using `700 * (10^(mel / 2595) - 1)`,
/// the inverse of the `2595 * log10(1 + hz / 700)` mapping.
pub fn mel_to_hz(mel: f32) -> f32 {
    let exponent = mel / 2595.0;
    700.0 * (10.0_f32.powf(exponent) - 1.0)
}
|
|
| |
| fn create_mel_filterbank( |
| sample_rate: u32, |
| n_fft: usize, |
| n_mels: usize, |
| fmin: f32, |
| fmax: f32, |
| ) -> Array2<f32> { |
| let n_freqs = n_fft / 2 + 1; |
|
|
| |
| let mel_min = hz_to_mel(fmin); |
| let mel_max = hz_to_mel(fmax); |
|
|
| |
| let mel_points: Vec<f32> = (0..=n_mels + 1) |
| .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32) |
| .collect(); |
|
|
| |
| let hz_points: Vec<f32> = mel_points.iter().map(|&m| mel_to_hz(m)).collect(); |
|
|
| |
| let bin_points: Vec<usize> = hz_points |
| .iter() |
| .map(|&hz| ((n_fft as f32 + 1.0) * hz / sample_rate as f32).floor() as usize) |
| .collect(); |
|
|
| |
| let mut filters = Array2::zeros((n_mels, n_freqs)); |
|
|
| for m in 0..n_mels { |
| let f_left = bin_points[m]; |
| let f_center = bin_points[m + 1]; |
| let f_right = bin_points[m + 2]; |
|
|
| |
| for k in f_left..f_center { |
| if k < n_freqs { |
| filters[[m, k]] = (k - f_left) as f32 / (f_center - f_left).max(1) as f32; |
| } |
| } |
|
|
| |
| for k in f_center..f_right { |
| if k < n_freqs { |
| filters[[m, k]] = (f_right - k) as f32 / (f_right - f_center).max(1) as f32; |
| } |
| } |
| } |
|
|
| filters |
| } |
|
|
| |
/// Returns a Hann window of `size` samples, `0.5 * (1 - cos(2π·n / size))`.
///
/// The denominator is `size` (not `size - 1`), so the first sample is 0 and
/// the peak of 1 sits at `n = size / 2` for even sizes. A `size` of 0 yields
/// an empty vector.
fn hann_window(size: usize) -> Vec<f32> {
    let mut window = Vec::with_capacity(size);
    for n in 0..size {
        let phase = 2.0 * PI * n as f32 / size as f32;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| pub fn stft( |
| signal: &[f32], |
| n_fft: usize, |
| hop_length: usize, |
| win_length: usize, |
| ) -> Result<Array2<Complex<f32>>> { |
| if signal.is_empty() { |
| return Err(Error::Audio("Empty signal".into())); |
| } |
|
|
| |
| let window = hann_window(win_length); |
|
|
| |
| let pad_length = n_fft / 2; |
| let mut padded = vec![0.0f32; pad_length]; |
| padded.extend_from_slice(signal); |
| padded.extend(vec![0.0f32; pad_length]); |
|
|
| |
| let num_frames = (padded.len() - n_fft) / hop_length + 1; |
| let n_freqs = n_fft / 2 + 1; |
|
|
| |
| let mut planner = RealFftPlanner::<f32>::new(); |
| let fft = planner.plan_fft_forward(n_fft); |
|
|
| |
| let mut stft_matrix = Array2::zeros((n_freqs, num_frames)); |
|
|
| |
| let mut input_buffer = vec![0.0f32; n_fft]; |
| let mut output_buffer = vec![Complex::new(0.0f32, 0.0f32); n_freqs]; |
|
|
| for (frame_idx, start) in (0..padded.len() - n_fft + 1) |
| .step_by(hop_length) |
| .enumerate() |
| { |
| if frame_idx >= num_frames { |
| break; |
| } |
|
|
| |
| for i in 0..win_length { |
| input_buffer[i] = padded[start + i] * window[i]; |
| } |
| |
| for i in win_length..n_fft { |
| input_buffer[i] = 0.0; |
| } |
|
|
| |
| fft.process(&mut input_buffer, &mut output_buffer) |
| .map_err(|e| Error::Audio(format!("FFT failed: {}", e)))?; |
|
|
| |
| for (freq_idx, &val) in output_buffer.iter().enumerate() { |
| stft_matrix[[freq_idx, frame_idx]] = val; |
| } |
| } |
|
|
| Ok(stft_matrix) |
| } |
|
|
| |
| pub fn magnitude_spectrogram(stft_matrix: &Array2<Complex<f32>>) -> Array2<f32> { |
| stft_matrix.mapv(|c| c.norm()) |
| } |
|
|
| |
| pub fn power_spectrogram(stft_matrix: &Array2<Complex<f32>>) -> Array2<f32> { |
| stft_matrix.mapv(|c| c.norm_sqr()) |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| pub fn mel_spectrogram(signal: &[f32], config: &AudioConfig) -> Result<Array2<f32>> { |
| |
| let stft_matrix = stft(signal, config.n_fft, config.hop_length, config.win_length)?; |
|
|
| |
| let power_spec = power_spectrogram(&stft_matrix); |
|
|
| |
| let mel_fb = MelFilterbank::new( |
| config.sample_rate, |
| config.n_fft, |
| config.n_mels, |
| config.fmin, |
| config.fmax, |
| ); |
|
|
| |
| let mel_spec = mel_fb.apply(&power_spec); |
|
|
| |
| let log_mel_spec = mel_spec.mapv(|x| (x.max(1e-10)).ln()); |
|
|
| Ok(log_mel_spec) |
| } |
|
|
| |
| pub fn mel_spectrogram_normalized( |
| signal: &[f32], |
| config: &AudioConfig, |
| mean: Option<f32>, |
| std: Option<f32>, |
| ) -> Result<Array2<f32>> { |
| let mut mel_spec = mel_spectrogram(signal, config)?; |
|
|
| |
| if let (Some(m), Some(s)) = (mean, std) { |
| mel_spec.mapv_inplace(|x| (x - m) / s); |
| } else { |
| |
| let m = mel_spec.mean().unwrap_or(0.0); |
| let s = mel_spec.std(0.0); |
| if s > 1e-8 { |
| mel_spec.mapv_inplace(|x| (x - m) / s); |
| } |
| } |
|
|
| Ok(mel_spec) |
| } |
|
|
| |
| pub fn mel_to_linear(mel_spec: &Array2<f32>, mel_fb: &MelFilterbank) -> Array2<f32> { |
| |
| let filters_t = mel_fb.filters.t(); |
| let gram = mel_fb.filters.dot(&filters_t); |
|
|
| |
| filters_t.dot(mel_spec) |
| } |
|
|
| |
| pub fn frame_energy(mel_spec: &Array2<f32>) -> Array1<f32> { |
| mel_spec.sum_axis(Axis(0)) |
| } |
|
|
| |
| pub fn voice_activity_detection(mel_spec: &Array2<f32>, threshold_db: f32) -> Vec<bool> { |
| let energy = frame_energy(mel_spec); |
| let max_energy = energy.iter().cloned().fold(f32::NEG_INFINITY, f32::max); |
| let threshold = max_energy + threshold_db; |
|
|
| energy.iter().map(|&e| e > threshold).collect() |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;

    // 0 Hz maps to 0 mel, and 1000 Hz lands close to 1000 mel.
    #[test]
    fn test_hz_to_mel() {
        let at_zero = hz_to_mel(0.0);
        let at_1k = hz_to_mel(1000.0);
        assert!((at_zero - 0.0).abs() < 1e-6);
        assert!((at_1k - 1000.0).abs() < 50.0);
    }

    // Hz -> mel -> Hz round trip returns the starting frequency.
    #[test]
    fn test_mel_to_hz() {
        let hz = 440.0;
        let hz_back = mel_to_hz(hz_to_mel(hz));
        assert!((hz - hz_back).abs() < 1e-4);
    }

    // The filterbank has the expected shape and is not entirely zero.
    #[test]
    fn test_mel_filterbank_creation() {
        let fb = MelFilterbank::new(22050, 1024, 80, 0.0, 8000.0);
        assert_eq!(fb.filters.dim(), (80, 513));

        let total_sum = fb.filters.iter().sum::<f32>();
        assert!(total_sum > 0.0, "Filterbank should have some non-zero values");
    }

    // The window starts at zero and peaks at one in the middle.
    #[test]
    fn test_hann_window() {
        let window = hann_window(1024);
        assert_eq!(window.len(), 1024);

        let first = window[0];
        let middle = window[512];
        assert!(first.abs() < 1e-6);
        assert!((middle - 1.0).abs() < 1e-4);
    }

    // STFT of a 0.1 s, 440 Hz sine has 513 frequency rows and at least one frame.
    #[test]
    fn test_stft_basic() {
        let sr = 22050;
        let freq = 440.0;
        let duration = 0.1;
        let num_samples = (sr as f32 * duration) as usize;

        let mut signal: Vec<f32> = Vec::with_capacity(num_samples);
        for i in 0..num_samples {
            signal.push((2.0 * PI * freq * i as f32 / sr as f32).sin());
        }

        let result = stft(&signal, 1024, 256, 1024);
        assert!(result.is_ok());

        let (n_freqs, n_frames) = result.unwrap().dim();
        assert_eq!(n_freqs, 513);
        assert!(n_frames > 0);
    }

    // The mel spectrogram has one row per configured mel band and at least one frame.
    #[test]
    fn test_mel_spectrogram() {
        let config = AudioConfig::default();
        let num_samples = (config.sample_rate as f32 * 0.1) as usize;

        let mut signal: Vec<f32> = Vec::with_capacity(num_samples);
        for i in 0..num_samples {
            signal.push((i as f32 * 0.01).sin());
        }

        let result = mel_spectrogram(&signal, &config);
        assert!(result.is_ok());

        let (n_bands, n_frames) = result.unwrap().dim();
        assert_eq!(n_bands, config.n_mels);
        assert!(n_frames > 0);
    }
}
|
|