from dataclasses import make_dataclass

import torch
import torchaudio
from torch import nn

from .usad_modules import ConformerEncoder

MAX_MEL_LENGTH = 3000  # 30 s of audio at a 10 ms frame shift (100 frames/s)


@torch.no_grad()
def wav_to_fbank(
    wavs: torch.Tensor,
    mel_dim: int = 128,
    norm_mean: float = -4.268,
    norm_std: float = 4.569,
) -> torch.Tensor:
    """Convert waveform to fbank features.

    Args:
        wavs (torch.Tensor): (B, T_wav) waveform tensor.
        mel_dim (int, optional): mel dimension. Defaults to 128.
        norm_mean (float, optional):
            mean for normalization. Defaults to -4.268.
        norm_std (float, optional):
            std for normalization. Defaults to 4.569.

    Returns:
        torch.Tensor: (B, T_mel, mel_dim) fbank features.
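
    Example (illustrative; shapes follow the default Kaldi fbank settings):
        >>> wavs = torch.randn(2, 16000)  # two 1 s clips at 16 kHz
        >>> wav_to_fbank(wavs).shape  # 98 frames at a 10 ms frame shift
        torch.Size([2, 98, 128])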
| """ |
| |
| dtype = wavs.dtype |
| wavs = wavs.to(torch.float32) |
| wavs = wavs - wavs.mean(dim=-1, keepdim=True) |
| feats = [ |
| torchaudio.compliance.kaldi.fbank( |
| wavs[i : i + 1], |
| htk_compat=True, |
| sample_frequency=16000, |
| use_energy=False, |
| window_type="hanning", |
| num_mel_bins=mel_dim, |
| dither=0.0, |
| frame_shift=10, |
| ).to(dtype=dtype) |
| for i in range(wavs.shape[0]) |
| ] |
|
|
| mels = torch.stack(feats, dim=0) |
| mels = (mels - norm_mean) / (norm_std * 2) |
|
|
| return mels |


class UsadModel(nn.Module):
    def __init__(self, cfg) -> None:
        """Initialize the UsadModel.

        Args:
            cfg: Configuration object with at least ``input_dim``,
                ``encoder_dim``, and ``num_layers``.
        """
        super().__init__()

        self.cfg = cfg
        self.encoder = ConformerEncoder(cfg)
        self.max_mel_length = MAX_MEL_LENGTH

    @property
    def sample_rate(self) -> int:
        return 16000

    @property
    def encoder_frame_rate(self) -> int:
        return 50

    @property
    def mel_dim(self) -> int:
        return self.cfg.input_dim

    @property
    def encoder_dim(self) -> int:
        return self.cfg.encoder_dim

    @property
    def num_layers(self) -> int:
        return self.cfg.num_layers

    @property
    def scene_embedding_size(self) -> int:
        return self.cfg.encoder_dim * self.cfg.num_layers

    @property
    def timestamp_embedding_size(self) -> int:
        return self.cfg.encoder_dim * self.cfg.num_layers

    @property
    def device(self) -> torch.device:
        """Get the device on which the model is located."""
        return next(self.parameters()).device

    def set_audio_chunk_size(self, seconds: float = 30.0) -> None:
        """Set the maximum chunk size for feature extraction.

        Args:
            seconds (float, optional): Chunk size in seconds. Defaults to 30.0.
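
        Example (illustrative; ``model`` is an existing UsadModel instance):
            >>> model.set_audio_chunk_size(10.0)  # 1000 frames at 100 frames/s
            >>> model.max_mel_length
            1000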
| """ |
| assert ( |
| seconds >= 0.1 |
| ), f"Chunk size must be greater than 0.1s, got {seconds} seconds." |
| self.max_mel_length = int(seconds * 100) |

    def load_audio(self, audio_path: str) -> torch.Tensor:
        """Load an audio file and return a mono waveform tensor.

        Audio is resampled to 16 kHz and downmixed to mono if needed.

        Args:
            audio_path (str): Path to the audio file.

        Returns:
            torch.Tensor: Waveform tensor of shape (wav_len,) on the
                model's device.
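
        Example (illustrative; ``example.wav`` is a hypothetical file):
            >>> wav = model.load_audio("example.wav")
            >>> wav.ndim
            1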
| """ |
|
|
| waveform, sr = torchaudio.load(audio_path) |
| if sr != self.sample_rate: |
| waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate) |
| if waveform.shape[0] > 1: |
| |
| waveform = waveform.mean(dim=0, keepdim=True) |
|
|
| waveform = waveform.squeeze(0) |
| return waveform.to(self.device) |
|
|
    def forward(
        self,
        wavs: torch.Tensor,
        norm_mean: float = -4.268,
        norm_std: float = 4.569,
    ) -> dict:
        """Forward pass for the model.

        Args:
            wavs (torch.Tensor):
                Input waveform tensor of shape (batch_size, wav_len).
            norm_mean (float, optional):
                Mean for normalization. Defaults to -4.268.
            norm_std (float, optional):
                Standard deviation for normalization. Defaults to 4.569.

        Returns:
            dict: Model outputs with keys ``"x"`` (encoder output),
                ``"mel"`` (fbank features), ``"hidden_states"`` (per-layer
                encoder states), and ``"ffn"`` (outputs of each layer's
                first feed-forward module).
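
        Example (illustrative; ``model`` is an existing UsadModel instance):
            >>> wavs = torch.randn(1, 16000)  # 1 s of audio at 16 kHz
            >>> out = model(wavs)
            >>> sorted(out.keys())
            ['ffn', 'hidden_states', 'mel', 'x']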
| """ |
| |
|
|
| mel = wav_to_fbank(wavs, norm_mean=norm_mean, norm_std=norm_std) |
| mel = mel[:, : mel.shape[1] - mel.shape[1] % 2] |
| if mel.shape[1] <= self.max_mel_length: |
| x, x_len, layer_results = self.encoder(mel, return_hidden=True) |
|
|
| result = { |
| "x": x, |
| "mel": mel, |
| "hidden_states": layer_results["hidden_states"], |
| "ffn": layer_results["ffn_1"], |
| } |
| return result |
|
|
| result = { |
| "x": [], |
| "mel": mel, |
| "hidden_states": [[] for _ in range(self.cfg.num_layers)], |
| "ffn": [[] for _ in range(self.cfg.num_layers)], |
| } |
| for i in range(0, mel.shape[1], self.max_mel_length): |
| if mel.shape[1] - i < 10: |
| break |

            x, x_len, layer_results = self.encoder(
                mel[:, i : i + self.max_mel_length], return_hidden=True
            )
            result["x"].append(x)
            for j in range(self.cfg.num_layers):
                result["hidden_states"][j].append(layer_results["hidden_states"][j])
                result["ffn"][j].append(layer_results["ffn_1"][j])

        # Concatenate chunk outputs along the time dimension.
        result["x"] = torch.cat(result["x"], dim=1)
        for j in range(self.cfg.num_layers):
            result["hidden_states"][j] = torch.cat(result["hidden_states"][j], dim=1)
            result["ffn"][j] = torch.cat(result["ffn"][j], dim=1)

        return result

    @classmethod
    def load_from_fairseq_ckpt(cls, ckpt_path: str):
        """Build a UsadModel from a fairseq checkpoint.

        Args:
            ckpt_path (str): Path to the fairseq checkpoint file.

        Returns:
            UsadModel: Model with the checkpoint's config and encoder weights.
        """
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        config = checkpoint["cfg"]["model"]
        # Rebuild a lightweight config object from the stored mapping.
        config = make_dataclass("Config", config.keys())(**config)
        model = cls(config)
        state_dict = checkpoint["model"]
        # Keep only encoder weights; drop everything outside the encoder.
        for k in list(state_dict.keys()):
            if not k.startswith("encoder."):
                del state_dict[k]
        model.load_state_dict(state_dict, strict=True)
        return model
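

# Minimal usage sketch (illustrative): the checkpoint and audio paths below
# are hypothetical placeholders, not files shipped with this module.
if __name__ == "__main__":
    model = UsadModel.load_from_fairseq_ckpt("usad_base.pt")
    model.eval()
    wav = model.load_audio("example.wav")  # resampled to 16 kHz, mono
    with torch.no_grad():
        out = model(wav.unsqueeze(0))
    # Encoder output at the 50 Hz frame rate: (batch, frames, encoder_dim).
    print(out["x"].shape)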