|
import torch |
|
import speechbrain as sb |
|
|
|
class FeatureScaler(torch.nn.Module):
    """Element-wise feature scaling by a fixed per-feature constant.

    Arguments
    ---------
    num_in : int
        Number of input features (size of the last dimension of ``x``).
    scale : float
        Constant multiplier applied to every feature.
    """

    def __init__(self, num_in, scale):
        super().__init__()
        # Register as a (non-persistent) buffer so the tensor follows the
        # module across .to(device) / dtype casts. persistent=False keeps it
        # out of the state_dict, so existing checkpoints still load cleanly.
        self.register_buffer("scaler", torch.ones(num_in) * scale, persistent=False)

    def forward(self, x):
        """Return ``x`` multiplied element-wise by the stored scale vector."""
        return x * self.scaler
|
|
|
class CustomInterface(sb.pretrained.interfaces.Pretrained):
    """Pretrained interface that turns audio into normalized, scaled features.

    The pipeline is: feature extraction (hparams) -> normalization (module)
    -> feature scaling (module).
    """

    # feats_from_audio uses both self.mods.normalizer and
    # self.mods.feature_scaler, so both must be declared here for the
    # Pretrained dependency check to guard them.
    MODULES_NEEDED = ["normalizer", "feature_scaler"]

    HPARAMS_NEEDED = ["feature_extractor"]

    def feats_from_audio(self, audio, lengths=None):
        """Compute normalized, scaled features from a batch of audio.

        Arguments
        ---------
        audio : torch.Tensor
            Batch of audio signals (first dimension is the batch).
        lengths : torch.Tensor, optional
            Relative lengths of each batch element in [0, 1]; defaults to
            full length for every element. (A ``None`` sentinel avoids the
            mutable-default-argument pitfall of a module-load-time tensor.)

        Returns
        -------
        torch.Tensor
            Scaled, normalized features.
        """
        if lengths is None:
            lengths = torch.tensor([1.0])
        feats = self.hparams.feature_extractor(audio)
        normalized = self.mods.normalizer(feats, lengths)
        scaled = self.mods.feature_scaler(normalized)
        return scaled

    def feats_from_file(self, path):
        """Compute features for a single audio file.

        Arguments
        ---------
        path : str
            Path to the audio file; loaded via the Pretrained audio loader.

        Returns
        -------
        torch.Tensor
            Features with the batch dimension squeezed away.
        """
        audio = self.load_audio(path)
        # Add a batch dim for the pipeline, then remove it for the caller.
        return self.feats_from_audio(audio.unsqueeze(0)).squeeze(0)
|
|