from typing import List, Tuple

import torch

from modules.wenet_extractor.utils.common import log_add


class Sequence:
    __slots__ = ("hyp", "score", "cache")

    def __init__(
        self,
        hyp: List[int],
        score: float,
        cache: List[torch.Tensor],
    ):
        self.hyp = hyp
        self.score = score
        self.cache = cache


class PrefixBeamSearch:
    def __init__(self, encoder, predictor, joint, ctc, blank):
        self.encoder = encoder
        self.predictor = predictor
        self.joint = joint
        self.ctc = ctc
        self.blank = blank

    def forward_decoder_one_step(
        self, encoder_x: torch.Tensor, pre_t: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        padding = torch.zeros(pre_t.size(0), 1, device=encoder_x.device)
        pre_t, new_cache = self.predictor.forward_step(
            pre_t.unsqueeze(-1), padding, cache
        )
        x = self.joint(encoder_x, pre_t)  # [beam, 1, 1, vocab]
        x = x.log_softmax(dim=-1)
        return x, new_cache

    def prefix_beam_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        beam_size: int = 5,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        ctc_weight: float = 0.3,
        transducer_weight: float = 0.7,
    ):
        """Prefix beam search for the transducer.

        See also wenet.transducer.transducer.beam_search.
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        device = speech.device
        batch_size = speech.shape[0]
        assert batch_size == 1

        # 1. Encoder
        encoder_out, _ = self.encoder(
            speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks
        )  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
        ctc_probs = self.ctc.log_softmax(encoder_out).squeeze(0)

        beam_init: List[Sequence] = []

        # 2. Init the beam, using Sequence objects to hold each beam unit.
        cache = self.predictor.init_state(1, method="zero", device=device)
        beam_init.append(Sequence(hyp=[self.blank], score=0.0, cache=cache))

        # 3. Start decoding (note: this is a breadth-first search).
        # In this decoding method a single frame does not emit multiple units.
        # Experiments show that this restriction has little impact.
        for i in range(maxlen):
            # 3.1 Build the input: the decoder takes the last token of each
            # hypothesis to predict the next token.
            input_hyp = [s.hyp[-1] for s in beam_init]
            input_hyp_tensor = torch.tensor(input_hyp, dtype=torch.int, device=device)

            # Build the batched predictor state from the beam.
            cache_batch = self.predictor.cache_to_batch([s.cache for s in beam_init])

            # Build the score tensor for torch.add().
            scores = torch.tensor([s.score for s in beam_init], device=device)

            # 3.2 Forward the decoder one step.
            logp, new_cache = self.forward_decoder_one_step(
                encoder_out[:, i, :].unsqueeze(1),
                input_hyp_tensor,
                cache_batch,
            )  # logp: (N, 1, 1, vocab_size)
            logp = logp.squeeze(1).squeeze(1)  # logp: (N, vocab_size)
            new_cache = self.predictor.batch_to_cache(new_cache)

            # 3.3 Shallow fusion of the transducer score and the CTC score;
            # an LM score could be added here as well.
            logp = torch.log(
                torch.add(
                    transducer_weight * torch.exp(logp),
                    ctc_weight * torch.exp(ctc_probs[i].unsqueeze(0)),
                )
            )

            # 3.4 First beam prune.
            top_k_logp, top_k_index = logp.topk(beam_size)  # (N, beam_size)
            scores = torch.add(scores.unsqueeze(1), top_k_logp)

            # 3.5 Generate the new beam (N * beam_size candidates).
            beam_A = []
            for j in range(len(beam_init)):
                base_seq = beam_init[j]
                for t in range(beam_size):
                    # blank: only the score is updated.
                    if top_k_index[j, t] == self.blank:
                        new_seq = Sequence(
                            hyp=base_seq.hyp.copy(),
                            score=scores[j, t].item(),
                            cache=base_seq.cache,
                        )
                        beam_A.append(new_seq)
                    # non-blank unit: update hyp, score, and predictor state.
                    else:
                        hyp_new = base_seq.hyp.copy()
                        hyp_new.append(top_k_index[j, t].item())
                        new_seq = Sequence(
                            hyp=hyp_new, score=scores[j, t].item(), cache=new_cache[j]
                        )
                        beam_A.append(new_seq)

            # 3.6 Prefix fusion: merge candidates with identical hypotheses
            # by log-adding their scores.
            fusion_A = [beam_A[0]]
            for j in range(1, len(beam_A)):
                s1 = beam_A[j]
                if_do_append = True
                for t in range(len(fusion_A)):
                    if s1.hyp == fusion_A[t].hyp:
                        fusion_A[t].score = log_add([fusion_A[t].score, s1.score])
                        if_do_append = False
                        break
                if if_do_append:
                    fusion_A.append(s1)

            # 4. Second prune: keep the top beam_size hypotheses.
            fusion_A.sort(key=lambda x: x.score, reverse=True)
            beam_init = fusion_A[:beam_size]

        return beam_init, encoder_out
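
# A minimal usage sketch (not part of this module). It assumes a WeNet-style
# transducer `model` exposing `encoder`, `predictor`, `joint`, `ctc`, and
# `blank` attributes with the interfaces used above (forward_step, init_state,
# cache_to_batch, batch_to_cache); the names and shapes below are illustrative
# assumptions, not a fixed API.
#
#     searcher = PrefixBeamSearch(
#         model.encoder, model.predictor, model.joint, model.ctc, model.blank
#     )
#     feats = torch.randn(1, 100, 80)      # (batch=1, frames, feat_dim)
#     feats_lengths = torch.tensor([100])  # valid frames per utterance
#     beam, encoder_out = searcher.prefix_beam_search(
#         feats,
#         feats_lengths,
#         beam_size=5,
#         ctc_weight=0.3,
#         transducer_weight=0.7,
#     )
#     best = beam[0]        # beam is sorted by score, best hypothesis first
#     tokens = best.hyp[1:] # drop the leading blank used to seed the beam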