Spaces:
Runtime error
Runtime error
# This module is from [WeNet](https://github.com/wenet-e2e/wenet). | |
# ## Citations | |
# ```bibtex | |
# @inproceedings{yao2021wenet, | |
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit}, | |
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin}, | |
# booktitle={Proc. Interspeech}, | |
# year={2021}, | |
# address={Brno, Czech Republic }, | |
# organization={IEEE} | |
# } | |
# @article{zhang2022wenet, | |
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit}, | |
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei}, | |
# journal={arXiv preprint arXiv:2203.15455}, | |
# year={2022} | |
# }
# ```
#
import json | |
import math | |
import numpy as np | |
def _load_json_cmvn(json_cmvn_file):
    """Read accumulated CMVN statistics from a JSON file and normalize them.

    Args:
        json_cmvn_file: path to a JSON file providing the keys
            "mean_stat", "var_stat" and "frame_num".

    Returns:
        A numpy array built from two rows: the per-dimension means and
        the per-dimension inverse standard deviations, i.e.
        1 / sqrt(variance) with the variance floored at 1.0e-20.
    """
    with open(json_cmvn_file) as stats_file:
        stats = json.load(stats_file)

    mean_row = stats["mean_stat"]
    scale_row = stats["var_stat"]
    frame_count = stats["frame_num"]

    # The two stat lists are overwritten in place, one dimension at a time:
    # sums become means, sums of squares become inverse standard deviations.
    for dim, feature_sum in enumerate(mean_row):
        mean = feature_sum / frame_count
        spread = scale_row[dim] / frame_count - mean * mean
        mean_row[dim] = mean
        # Floor the variance so the reciprocal square root stays finite.
        scale_row[dim] = 1.0 / math.sqrt(max(spread, 1.0e-20))

    return np.array([mean_row, scale_row])
def _load_kaldi_cmvn(kaldi_cmvn_file):
    """Load the kaldi format cmvn stats file and calculate cmvn

    Args:
        kaldi_cmvn_file: kaldi text style global cmvn file, which
            is generated by:
            compute-cmvn-stats --binary=false scp:feats.scp global_cmvn

    Returns:
        a numpy array of two rows: the per-dimension means and the
        per-dimension inverse standard deviations, i.e.
        1 / sqrt(variance) with the variance floored at 1.0e-20.

    Raises:
        SystemExit: with code 1 if the file starts with the kaldi binary
            marker, after an error has been logged.
        AssertionError: if the text does not start with "[" or does not
            end with "0 ]".
    """
    # Imported at function scope: the binary-file branch below needs both
    # modules, and the file's import header provides neither.
    import logging
    import sys

    # Probe the header in byte mode. Decoding a real binary stats file as
    # text can fail before the marker is ever compared.
    with open(kaldi_cmvn_file, "rb") as fid:
        # kaldi binary file start with '\0B'
        is_binary = fid.read(2) == b"\0B"
    if is_binary:
        logging.error(
            "kaldi cmvn binary file is not supported, please "
            "recompute it by: compute-cmvn-stats --binary=false "
            " scp:feats.scp global_cmvn"
        )
        sys.exit(1)

    with open(kaldi_cmvn_file, "r") as fid:
        arr = fid.read().split()

    # Expected token layout:
    #   [ sum_1 .. sum_D count sqsum_1 .. sqsum_D 0 ]
    assert arr[0] == "["
    assert arr[-2] == "0"
    assert arr[-1] == "]"
    # Drop the two brackets, the count and the trailing "0"; what remains
    # is split evenly between the two rows.
    feat_dim = int((len(arr) - 2 - 2) / 2)

    feature_sums = [float(token) for token in arr[1 : feat_dim + 1]]
    count = float(arr[feat_dim + 1])
    square_sums = [float(token) for token in arr[feat_dim + 2 : 2 * feat_dim + 2]]

    means = []
    inv_stds = []
    for feature_sum, square_sum in zip(feature_sums, square_sums):
        mean = feature_sum / count
        variance = square_sum / count - mean * mean
        means.append(mean)
        # Floor the variance so the reciprocal square root stays finite.
        inv_stds.append(1.0 / math.sqrt(max(variance, 1.0e-20)))

    return np.array([means, inv_stds])
def load_cmvn(cmvn_file, is_json):
    """Load global CMVN statistics from a JSON or kaldi-text stats file.

    Args:
        cmvn_file: path to the stats file.
        is_json: truthy to parse the file with the JSON loader, falsy to
            parse it with the kaldi text loader.

    Returns:
        A tuple of the first and second rows of the array returned by
        the selected loader.
    """
    loader = _load_json_cmvn if is_json else _load_kaldi_cmvn
    stats = loader(cmvn_file)
    return stats[0], stats[1]