# Spaces: Runtime error (appears to be a page-status banner captured with this file; not part of the code)
from typing import List | |
import torch | |
def basic_greedy_search(
    model: torch.nn.Module,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    n_steps: int = 64,
) -> List[List[int]]:
    """Greedily decode a single utterance with a predictor/joint model.

    The encoder output is walked one frame at a time. For each frame the
    joint network is queried with the current predictor output and the
    arg-max token is taken. A non-blank token is appended to the hypothesis
    and fed back into the predictor while staying on the same frame; a blank
    token (or reaching ``n_steps`` non-blank tokens on one frame) advances to
    the next frame.

    Args:
        model: Object exposing ``blank``, ``predictor.init_state``,
            ``predictor.forward_step`` and ``joint``.
        encoder_out: Encoder output for one utterance; it is sliced as
            ``encoder_out[:, t:t + 1, :]`` (``[1, 1, E]`` per the original
            annotations).
        encoder_out_lens: Number of frames to decode. It is converted with
            ``int(...)``, so it must hold exactly one element.
        n_steps: Maximum number of non-blank tokens emitted for one frame
            before decoding is forced on to the next frame.

    Returns:
        A one-element list containing the decoded token ids.
    """
    device = encoder_out.device
    blank = model.blank

    # Dummy padding passed to the predictor's single-step interface.
    padding = torch.zeros(1, 1, device=device)
    # The blank id doubles as the start-of-sequence input. It is created on
    # the same device as ``padding`` and the predictor cache so the first
    # ``forward_step`` call does not receive tensors on mixed devices.
    pred_input_step = torch.tensor([blank], device=device).reshape(1, 1)
    cache = model.predictor.init_state(1, method="zero", device=device)
    new_cache: List[torch.Tensor] = []

    # Read the length once instead of comparing against a tensor every loop.
    num_frames = int(encoder_out_lens)

    hyps: List[int] = []
    prev_out_nblk = True
    pred_out_step = None
    per_frame_noblk = 0
    t = 0
    while t < num_frames:
        encoder_out_step = encoder_out[:, t : t + 1, :]  # [1, 1, E]

        # The predictor only needs to be re-run after a non-blank emission;
        # after a blank its previous output is still valid.
        if prev_out_nblk:
            step_outs = model.predictor.forward_step(
                pred_input_step, padding, cache
            )  # [1, 1, P]
            pred_out_step, new_cache = step_outs[0], step_outs[1]

        joint_out_step = model.joint(encoder_out_step, pred_out_step)  # [1, 1, V]
        joint_out_probs = joint_out_step.log_softmax(dim=-1)
        joint_out_max = joint_out_probs.argmax(dim=-1).squeeze()  # []

        # Pull the scalar out once; it is reused for every comparison below.
        token = joint_out_max.item()
        is_blank = token == blank

        if not is_blank:
            hyps.append(token)
            prev_out_nblk = True
            per_frame_noblk += 1
            pred_input_step = joint_out_max.reshape(1, 1)
            # Commit the predictor state only when a token was emitted.
            cache = new_cache

        if is_blank or per_frame_noblk >= n_steps:
            if is_blank:
                prev_out_nblk = False
            # TODO(Mddct): make t advance in chunks for streaming, or make
            # sure t does not stay on one frame too long predicting
            # non-blank tokens.
            t += 1
            per_frame_noblk = 0

    return [hyps]