Hecheng0625's picture
Upload 409 files
c968fc3 verified
raw
history blame
5.95 kB
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
# ## Citations
# ```bibtex
# @inproceedings{yao2021wenet,
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
# booktitle={Proc. Interspeech},
# year={2021},
# address={Brno, Czech Republic },
# organization={IEEE}
# }
# @article{zhang2022wenet,
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
# journal={arXiv preprint arXiv:2203.15455},
# year={2022}
# }
# ```
"""ScorerInterface implementation for CTC."""
import numpy as np
import torch
from modules.wenet_extractor.paraformer.search.ctc_prefix_score import CTCPrefixScore
from modules.wenet_extractor.paraformer.search.ctc_prefix_score import CTCPrefixScoreTH
from modules.wenet_extractor.paraformer.search.scorer_interface import (
BatchPartialScorerInterface,
)
class CTCPrefixScorer(BatchPartialScorerInterface):
    """Decoder interface wrapper for CTCPrefixScore.

    The wrapper keeps a reference to a CTC module and lazily builds the
    actual prefix-scoring implementation (``self.impl``) when one of the
    ``*init_state`` methods is called:

    * :meth:`init_state` builds a NumPy-backed ``CTCPrefixScore``.
    * :meth:`batch_init_state` builds a torch-backed ``CTCPrefixScoreTH``.

    All scoring / extension methods delegate to ``self.impl`` and therefore
    require that one of the init methods has been called first.
    """

    def __init__(self, ctc: torch.nn.Module, eos: int):
        """Initialize class.

        Args:
            ctc (torch.nn.Module): The CTC implementation.
                For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`
            eos (int): The end-of-sequence id.

        """
        self.ctc = ctc
        self.eos = eos
        # Created by init_state() / batch_init_state(); None until then.
        self.impl = None

    def init_state(self, x: torch.Tensor):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state, a ``(score, impl_state)`` pair where the
            score starts at 0.

        """
        # A leading axis is added for the CTC module and dropped again
        # afterwards; the result is moved to CPU as a NumPy array because
        # CTCPrefixScore is driven with the NumPy backend here.
        logp = self.ctc.log_softmax(x.unsqueeze(0)).detach().squeeze(0).cpu().numpy()
        # TODO(karita): use CTCPrefixScoreTH
        # NOTE(review): the literal 0 is presumably the blank id - confirm
        # against CTCPrefixScore's signature.
        self.impl = CTCPrefixScore(logp, 0, self.eos, np)
        return 0, self.impl.initial_state()

    def select_state(self, state, i, new_id=None):
        """Select state with relative ids in the main beam search.

        Args:
            state: Decoder state for prefix tokens
            i (int): Index to select a state in the main beam search
            new_id (int): New label id to select a state if necessary

        Returns:
            state: pruned state

        """
        # isinstance (rather than an exact type comparison) so that tuple
        # subclasses are handled the same way as plain tuples.
        if isinstance(state, tuple):
            if len(state) == 2:  # for CTCPrefixScore
                sc, st = state
                return sc[i], st[i]
            else:  # for CTCPrefixScoreTH (need new_id > 0)
                r, log_psi, f_min, f_max, scoring_idmap = state
                s = log_psi[i, new_id].expand(log_psi.size(1))
                if scoring_idmap is not None:
                    # Translate the label id through the id map before
                    # indexing the last axis of ``r``.
                    return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max
                else:
                    return r[:, :, i, new_id], s, f_min, f_max
        return None if state is None else state[i]

    def score_partial(self, y, ids, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D prefix token
            ids (torch.Tensor): torch.int64 next token to score
            state: decoder state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for y that has a shape
                `(len(ids),)` and next state for ys

        """
        prev_score, state = state
        presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
        # The returned score is the increment over the previous prefix
        # score, placed on the same device / dtype as the encoder feature.
        tscore = torch.as_tensor(
            presub_score - prev_score, device=x.device, dtype=x.dtype
        )
        return tscore, (presub_score, new_st)

    def batch_init_state(self, x: torch.Tensor):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state (always ``None``; the scorer itself is
            stored in ``self.impl``)

        """
        logp = self.ctc.log_softmax(x.unsqueeze(0))  # assuming batch_size = 1
        xlen = torch.tensor([logp.size(1)])
        self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos)
        return None

    def batch_score_partial(self, y, ids, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D prefix token
            ids (torch.Tensor): torch.int64 next token to score
            state: decoder state for prefix tokens; a sequence with one
                entry per hypothesis, whose first entry is ``None`` when
                there is no previous state
            x (torch.Tensor): 2D encoder feature that generates ys
                (not used by this method)

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for y that has a shape
                `(len(ids),)` and next state for ys

        """
        # Merge the per-hypothesis states into one batched state. The first
        # two components are stacked across hypotheses; the last two are
        # taken from the first hypothesis only.
        batch_state = (
            (
                torch.stack([s[0] for s in state], dim=2),
                torch.stack([s[1] for s in state]),
                state[0][2],
                state[0][3],
            )
            if state[0] is not None
            else None
        )
        return self.impl(y, batch_state, ids)

    def extend_prob(self, x: torch.Tensor):
        """Extend probs for decoding.

        This extension is for streaming decoding
        as in Eq (14) in https://arxiv.org/abs/2006.14941

        Args:
            x (torch.Tensor): The encoded feature tensor

        """
        logp = self.ctc.log_softmax(x.unsqueeze(0))
        self.impl.extend_prob(logp)

    def extend_state(self, state):
        """Extend state for decoding.

        This extension is for streaming decoding
        as in Eq (14) in https://arxiv.org/abs/2006.14941

        Args:
            state: The states of hyps

        Returns: extended state

        """
        return [self.impl.extend_state(s) for s in state]