Aku Rouhe
Okay now
b57ac94
raw
history blame contribute delete
793 Bytes
import torch
import speechbrain as sb
class FeatureScaler(torch.nn.Module):
    """Multiply features by a constant scale along the last dimension.

    Arguments
    ---------
    num_in : int
        Length of the scale vector; the last dimension of the input to
        ``forward`` must broadcast against it.
    scale : float
        Constant every entry of the scale vector is set to.
    """

    def __init__(self, num_in, scale):
        super().__init__()
        # Registered as a buffer (not a plain attribute) so the vector
        # follows the module through .to() / .cuda() / .half().
        # persistent=False keeps it out of state_dict(), so the set of
        # saved keys is the same as when this was a plain attribute.
        self.register_buffer(
            "scaler", torch.ones((num_in,)) * scale, persistent=False
        )

    def forward(self, x):
        """Return ``x`` multiplied elementwise (with broadcasting) by the scale vector."""
        return x * self.scaler
class CustomInterface(sb.pretrained.interfaces.Pretrained):
    """Pretrained interface that turns audio into normalized, scaled features.

    The pipeline is: ``hparams.feature_extractor`` -> ``mods.normalizer``
    -> ``mods.feature_scaler``.
    """

    # Both modules are accessed through ``self.mods`` in feats_from_audio,
    # so both are declared here.
    MODULES_NEEDED = ["normalizer", "feature_scaler"]
    HPARAMS_NEEDED = ["feature_extractor"]

    def feats_from_audio(self, audio, lengths=None):
        """Compute features for a batch of audio.

        Arguments
        ---------
        audio : torch.Tensor
            Batched audio with the batch on dimension 0 (feats_from_file
            adds that dimension with ``unsqueeze(0)``).
        lengths : torch.Tensor, optional
            Passed to the normalizer together with the features. When
            omitted, a vector of ones with one entry per batch item is
            used -- presumably relative lengths where 1.0 means "full
            length"; confirm against the normalizer in use.

        Returns
        -------
        torch.Tensor
            Output of the feature scaler applied to the normalized features.
        """
        if lengths is None:
            # Built per call instead of as a shared default tensor, and
            # sized to the batch rather than fixed to a single utterance.
            lengths = torch.ones(audio.shape[0], device=audio.device)
        feats = self.hparams.feature_extractor(audio)
        normalized = self.mods.normalizer(feats, lengths)
        return self.mods.feature_scaler(normalized)

    def feats_from_file(self, path):
        """Load one audio file and return its features without the batch dimension."""
        audio = self.load_audio(path)
        # Add a batch dimension for the pipeline, then strip it again.
        return self.feats_from_audio(audio.unsqueeze(0)).squeeze(0)