#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
# Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)

import time
import torch
import numpy as np
from collections import OrderedDict
from contextlib import contextmanager
from distutils.version import LooseVersion

from funasr_detach.register import tables
from funasr_detach.models.campplus.utils import extract_feature
from funasr_detach.utils.load_utils import load_audio_text_image_video
from funasr_detach.models.campplus.components import (
    DenseLayer,
    StatsPool,
    TDNNLayer,
    CAMDenseTDNNBlock,
    TransitLayer,
    get_nonlinear,
    FCM,
)

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Fallback for torch<1.6.0: AMP is unavailable, so provide a no-op
    # context manager (the @contextmanager decorator is required for
    # `with autocast():` to work).
    @contextmanager
    def autocast(enabled=True):
        yield

@tables.register("model_classes", "CAMPPlus")
class CAMPPlus(torch.nn.Module):
    def __init__(
        self,
        feat_dim=80,
        embedding_size=192,
        growth_rate=32,
        bn_size=4,
        init_channels=128,
        config_str="batchnorm-relu",
        memory_efficient=True,
        output_level="segment",
        **kwargs,
    ):
        super().__init__()
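        # FCM front-end: a 2-D convolutional stem over the fbank features;
        # its output channel count sets the input width of the TDNN trunk.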
        self.head = FCM(feat_dim=feat_dim)
        channels = self.head.out_channels
        self.output_level = output_level

        self.xvector = torch.nn.Sequential(
            OrderedDict(
                [
                    (
                        "tdnn",
                        TDNNLayer(
                            channels,
                            init_channels,
                            5,
                            stride=2,
                            dilation=1,
                            padding=-1,
                            config_str=config_str,
                        ),
                    ),
                ]
            )
        )
        channels = init_channels
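
        # Three CAM-Dense TDNN stages: each stacks `num_layers` dense layers
        # (channels grow by `growth_rate` per layer), then a TransitLayer
        # halves the channel count before the next stage.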
        for i, (num_layers, kernel_size, dilation) in enumerate(
            zip((12, 24, 16), (3, 3, 3), (1, 2, 2))
        ):
            block = CAMDenseTDNNBlock(
                num_layers=num_layers,
                in_channels=channels,
                out_channels=growth_rate,
                bn_channels=bn_size * growth_rate,
                kernel_size=kernel_size,
                dilation=dilation,
                config_str=config_str,
                memory_efficient=memory_efficient,
            )
            self.xvector.add_module("block%d" % (i + 1), block)
            channels = channels + num_layers * growth_rate
            self.xvector.add_module(
                "transit%d" % (i + 1),
                TransitLayer(channels, channels // 2, bias=False, config_str=config_str),
            )
            channels //= 2

        self.xvector.add_module("out_nonlinear", get_nonlinear(config_str, channels))

        if self.output_level == "segment":
            self.xvector.add_module("stats", StatsPool())
            self.xvector.add_module(
                "dense",
                DenseLayer(channels * 2, embedding_size, config_str="batchnorm_"),
            )
        else:
            assert (
                self.output_level == "frame"
            ), "`output_level` should be set to 'segment' or 'frame'."
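
        # Kaiming-normal initialization for every Conv1d/Linear weight;
        # biases (when present) start at zero.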
        for m in self.modules():
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
                torch.nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B, T, F) -> (B, F, T)
        x = self.head(x)
        x = self.xvector(x)
        if self.output_level == "frame":
            x = x.transpose(1, 2)
        return x

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        # extract fbank feats
        meta_data = {}
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound"
        )
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths, speech_times = extract_feature(audio_sample_list)
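        # extract_feature is assumed to return a padded fbank batch
        # (batch, frames, feat_dim) plus per-utterance lengths and sample
        # counts; the sample counts feed the duration metadata below.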
        speech = speech.to(device=kwargs["device"])
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = np.array(speech_times).sum().item() / 16000.0
        results = [{"spk_embedding": self.forward(speech.to(torch.float32))}]
        return results, meta_data
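

if __name__ == "__main__":
    # Minimal shape-check sketch (not part of the upstream file): feed random
    # fbank-shaped input through CAMPPlus and print the embedding shape. The
    # (B, T, F) layout matches the permute at the top of forward(); the dummy
    # sizes below are illustrative, and memory_efficient=False simply skips
    # gradient checkpointing, which has no benefit under torch.no_grad().
    model = CAMPPlus(feat_dim=80, embedding_size=192, memory_efficient=False)
    model.eval()
    dummy = torch.randn(2, 200, 80)  # 2 utterances, 200 frames, 80 fbank bins
    with torch.no_grad():
        emb = model(dummy)
    print(emb.shape)  # expected: torch.Size([2, 192]) with output_level="segment"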