# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
#
# ## Citations
#
# ```bibtex
# @inproceedings{yao2021wenet,
#   title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
#   author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
#   booktitle={Proc. Interspeech},
#   year={2021},
#   address={Brno, Czech Republic},
#   organization={IEEE}
# }
#
# @article{zhang2022wenet,
#   title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
#   author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
#   journal={arXiv preprint arXiv:2203.15455},
#   year={2022}
# }
# ```
from typing import Optional

import numpy as np
import torch
def sequence_mask(
    lengths: torch.Tensor,
    maxlen: Optional[int] = None,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Create a padding mask from a batch of sequence lengths.

    :param lengths: (B,) tensor of valid lengths per sequence
    :param maxlen: width of the mask; defaults to ``lengths.max()``
    :param dtype: dtype of the returned mask
    :param device: optional device to move the mask to
    :return: (B, maxlen) mask whose entry (b, t) is non-zero iff t < lengths[b]
    """
    if maxlen is None:
        maxlen = lengths.max()
    # Broadcast a (maxlen,) position index against the (B, 1) lengths so that
    # each row marks the valid (non-padded) positions of one sequence.
    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = row_vector < matrix
    mask = mask.detach()
    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
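

# Usage note (illustrative example, not from the upstream module): for
# lengths = torch.tensor([3, 1, 2]) and maxlen=4, sequence_mask(lengths, maxlen=4)
# returns
#     [[1., 1., 1., 0.],
#      [1., 0., 0., 0.],
#      [1., 1., 0., 0.]]
# i.e. the usual padding mask for a batch of variable-length sequences, which
# can be multiplied into padded features or attention scores to zero out padding.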
def end_detect(ended_hyps, i, M=3, d_end=np.log(1 * np.exp(-10))):
    """End detection for beam search.

    Described in Eq. (50) of S. Watanabe et al.,
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition".

    :param ended_hyps: list of ended hypotheses, each a dict with "yseq" and "score"
    :param i: current decoding step (output length)
    :param M: number of most recent hypothesis lengths to inspect
    :param d_end: score margin in the log domain (default -10); a length counts as
        stalled when its best score trails the overall best by more than this margin
    :return: True if decoding can be stopped, False otherwise
    """
    if len(ended_hyps) == 0:
        return False
    count = 0
    best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0]
    for m in range(M):
        # collect ended hypotheses whose length is i - m
        hyp_length = i - m
        hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
        if len(hyps_same_length) > 0:
            best_hyp_same_length = sorted(
                hyps_same_length, key=lambda x: x["score"], reverse=True
            )[0]
            if best_hyp_same_length["score"] - best_hyp["score"] < d_end:
                count += 1

    if count == M:
        return True
    else:
        return False
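

# Minimal self-check sketch (not part of the upstream WeNet module; the toy
# hypotheses below are invented for illustration and only use the "yseq" and
# "score" keys that end_detect() actually reads). With i=5 and M=3, the best
# score belongs to a short hypothesis and every recent length trails it by more
# than |d_end| = 10, so decoding is declared finished.
if __name__ == "__main__":
    ended_hyps = [
        {"yseq": [1, 2], "score": -5.0},            # overall best, length 2
        {"yseq": [1, 2, 3], "score": -20.0},        # length 3: margin -15 < -10
        {"yseq": [1, 2, 3, 4], "score": -30.0},     # length 4: margin -25 < -10
        {"yseq": [1, 2, 3, 4, 5], "score": -40.0},  # length 5: margin -35 < -10
    ]
    print(end_detect(ended_hyps, i=5))  # expected: True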