| import torch |
| import speechbrain as sb |
|
|
class FeatureScaler(torch.nn.Module):
    """Multiply input features elementwise by a fixed per-feature scale.

    Arguments
    ---------
    num_in : int
        Number of input features (size of the last dimension of ``x``).
    scale : float
        Constant scale applied uniformly to every feature.
    """

    def __init__(self, num_in, scale):
        super().__init__()
        # Register as a buffer rather than a plain attribute so the tensor
        # follows .to(device)/.cuda() moves and appears in state_dict().
        # (The old `self.scaler = torch.ones(...) * scale` tensor stayed on
        # CPU and was silently dropped when saving/loading the module.)
        self.register_buffer("scaler", torch.ones((num_in,)) * scale)

    def forward(self, x):
        """Return ``x`` scaled elementwise by the stored scale vector.

        ``x`` is expected to broadcast against shape ``(num_in,)`` on its
        last dimension.
        """
        return x * self.scaler
|
|
class CustomInterface(sb.pretrained.interfaces.Pretrained):
    """Pretrained wrapper turning audio into normalized, scaled features.

    Pipeline: ``hparams.feature_extractor`` -> ``mods.normalizer`` ->
    ``mods.feature_scaler``.
    """

    # "feature_scaler" was missing here even though feats_from_audio calls
    # self.mods.feature_scaler; Pretrained validates MODULES_NEEDED, so
    # loading would not guarantee the scaler module is present.
    MODULES_NEEDED = ["normalizer", "feature_scaler"]
    HPARAMS_NEEDED = ["feature_extractor"]

    def feats_from_audio(self, audio, lengths=None):
        """Extract, normalize, and scale features for a batch of audio.

        Arguments
        ---------
        audio : torch.Tensor
            Batched audio, presumably shape (batch, time) — confirm against
            the configured feature_extractor.
        lengths : torch.Tensor, optional
            Relative lengths in [0, 1] for each batch item. Defaults to
            all-ones (every signal treated as full length).

        Returns
        -------
        torch.Tensor
            Scaled, normalized features.
        """
        if lengths is None:
            # Replaces the old mutable default `lengths=torch.tensor([1.0])`:
            # a tensor default is built once at definition time, and a
            # single-element lengths vector is wrong for batch sizes > 1.
            # For batch size 1 this produces the identical value [1.0].
            lengths = torch.ones(audio.shape[0], device=audio.device)
        feats = self.hparams.feature_extractor(audio)
        normalized = self.mods.normalizer(feats, lengths)
        return self.mods.feature_scaler(normalized)

    def feats_from_file(self, path):
        """Load one audio file and return its features without a batch dim.

        Arguments
        ---------
        path : str
            Path (or source recognized by ``load_audio``) of the audio file.

        Returns
        -------
        torch.Tensor
            Features for the single file, batch dimension removed.
        """
        audio = self.load_audio(path)
        # Add the batch dim for the batched pipeline, then strip it again.
        return self.feats_from_audio(audio.unsqueeze(0)).squeeze(0)
|
|