File size: 10,144 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 |
"""Parallel beam search module for online simulation."""
import logging
from pathlib import Path
from typing import List
import yaml
import torch
from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.beam_search import Hypothesis
from espnet.nets.e2e_asr_common import end_detect
class BatchBeamSearchOnlineSim(BatchBeamSearch):
    """Online beam search implementation.

    This simulates streaming decoding.
    It requires encoded features of the entire utterance and
    extracts them block by block, as it should be done
    in streaming processing.
    This is based on Tsunoo et al, "STREAMING TRANSFORMER ASR
    WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH"
    (https://arxiv.org/abs/2006.14941).

    The streaming parameters ``block_size``, ``hop_size`` and ``look_ahead``
    are set with :meth:`set_streaming_config` or the individual setters.
    While any of them is unset (``None``) or zero, :meth:`forward` exposes
    the whole encoder output at once, i.e. it behaves like batch beam search.
    """

    # Class-level defaults so forward() also works when no streaming
    # configuration was supplied: unset values make it decode the whole
    # utterance at once instead of raising AttributeError.
    block_size = None
    hop_size = None
    look_ahead = None

    def set_streaming_config(self, asr_config: str):
        """Set config file for streaming decoding.

        Reads ``block_size``, ``hop_size`` and ``look_ahead`` from the
        ``encoder_conf`` section of the ASR training config. If that config
        has no ``encoder_conf`` but names another config file under the
        ``config`` key, the values are read from that file's ``encoder_conf``
        instead. Parameters that are not found stay ``None``.

        Args:
            asr_config (str): The config file for asr training

        """
        self.block_size = None
        self.hop_size = None
        self.look_ahead = None
        config = None
        with Path(asr_config).open("r", encoding="utf-8") as f:
            # An empty YAML document loads as None; treat it as an empty mapping.
            args = yaml.safe_load(f) or {}
        if "encoder_conf" in args:
            self._apply_encoder_conf(args["encoder_conf"])
        elif "config" in args:
            config = args["config"]

        if config is None:
            # Only announce the fallback when the streaming parameters are
            # actually incomplete (it used to be logged even when
            # encoder_conf provided all of them).
            if any(
                value is None
                for value in (self.block_size, self.hop_size, self.look_ahead)
            ):
                logging.info(
                    "Cannot find config file for streaming decoding: "
                    + "apply batch beam search instead."
                )
            return

        # Reaching this point means the training config had no encoder_conf,
        # so all three parameters are still None: consult the referenced file.
        with Path(config).open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f) or {}
        # A missing or null encoder_conf leaves the parameters unset instead
        # of raising (previously ``enc_args`` could be unbound here).
        self._apply_encoder_conf(args.get("encoder_conf"))

    def _apply_encoder_conf(self, enc_args) -> None:
        """Copy the streaming parameters present in an ``encoder_conf`` mapping.

        Args:
            enc_args: The ``encoder_conf`` section of a config file. ``None``
                or an empty mapping leaves the current values unchanged.

        """
        if not enc_args:
            return
        for key in ("block_size", "hop_size", "look_ahead"):
            if key in enc_args:
                setattr(self, key, enc_args[key])

    def set_block_size(self, block_size: int):
        """Set block size for streaming decoding.

        Args:
            block_size (int): The block size of encoder

        """
        self.block_size = block_size

    def set_hop_size(self, hop_size: int):
        """Set hop size for streaming decoding.

        Args:
            hop_size (int): The hop size of encoder

        """
        self.hop_size = hop_size

    def set_look_ahead(self, look_ahead: int):
        """Set look ahead size for streaming decoding.

        Args:
            look_ahead (int): The look ahead size of encoder

        """
        self.look_ahead = look_ahead

    def forward(
        self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
    ) -> List[Hypothesis]:
        """Perform blockwise synchronous beam search.

        Only a growing prefix of ``x`` (the first ``cur_end_frame`` frames) is
        shown to the scorers. The prefix is extended when a hypothesis emits
        <eos>, or repeats a token it already contains, while frames remain.
        Ended hypotheses are only collected once the whole of ``x`` is visible.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), the max output length is T and an
                end-detect function is used to stop decoding automatically.
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            list[Hypothesis]: N-best decoding results, sorted by descending
                score. Empty if no hypothesis ended even after the retries.

        """
        self.conservative = True  # always true

        # cur_end_frame: number of leading encoder frames currently visible.
        # NOTE(review): this is a truthiness check, so look_ahead == 0 (as well
        # as an unset parameter) makes the search decode the whole utterance
        # at once -- confirm that look_ahead == 0 is meant to disable streaming.
        if self.block_size and self.hop_size and self.look_ahead:
            cur_end_frame = int(self.block_size - self.look_ahead)
        else:
            cur_end_frame = x.shape[0]
        process_idx = 0
        if cur_end_frame < x.shape[0]:
            h = x.narrow(0, 0, cur_end_frame)
        else:
            h = x

        # set length bounds
        if maxlenratio == 0:
            maxlen = x.shape[0]
        else:
            maxlen = max(1, int(maxlenratio * x.size(0)))
        # NOTE(review): minlen is only logged in this method; it does not
        # constrain the search performed here.
        minlen = int(minlenratio * x.size(0))
        logging.info("decoder input length: " + str(x.shape[0]))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # main loop of prefix search
        running_hyps = self.init_hyp(h)
        prev_hyps = []
        ended_hyps = []
        prev_repeat = False

        continue_decode = True

        # Outer loop: one pass per visible prefix of x. Inner loop: one step
        # per output position, left early on a block change, on end
        # detection, or when no running hypotheses remain.
        while continue_decode:
            move_to_next_block = False
            if cur_end_frame < x.shape[0]:
                h = x.narrow(0, 0, cur_end_frame)
            else:
                h = x

            # extend states for ctc
            self.extend(h, running_hyps)

            while process_idx < maxlen:
                logging.debug("position " + str(process_idx))
                best = self.search(running_hyps, h)

                if process_idx == maxlen - 1:
                    # end decoding
                    # NOTE(review): unless the block changes, post_process is
                    # called again on the same ``best`` further below in this
                    # step -- check whether that records ended hypotheses twice.
                    running_hyps = self.post_process(
                        process_idx, maxlen, maxlenratio, best, ended_hyps
                    )
                n_batch = best.yseq.shape[0]
                local_ended_hyps = []
                # True for hypotheses whose last token is <eos>.
                is_local_eos = (
                    best.yseq[torch.arange(n_batch), best.length - 1] == self.eos
                )
                for i in range(is_local_eos.shape[0]):
                    if is_local_eos[i]:
                        hyp = self._select(best, i)
                        local_ended_hyps.append(hyp)
                    # NOTE(tsunoo): check repetitions here
                    # This is a implicit implementation of
                    # Eq (11) in https://arxiv.org/abs/2006.14941
                    # A flag prev_repeat is used instead of using set
                    elif (
                        not prev_repeat
                        and best.yseq[i, -1] in best.yseq[i, :-1]
                        and cur_end_frame < x.shape[0]
                    ):
                        move_to_next_block = True
                        prev_repeat = True
                if maxlenratio == 0.0 and end_detect(
                    [lh.asdict() for lh in local_ended_hyps], process_idx
                ):
                    logging.info(f"end detected at {process_idx}")
                    continue_decode = False
                    break
                # An <eos> while frames are still hidden does not end the
                # hypothesis; instead more of x is exposed and the step rerun.
                if len(local_ended_hyps) > 0 and cur_end_frame < x.shape[0]:
                    move_to_next_block = True

                if move_to_next_block:
                    # Advance by hop_size while cur_end_frame + hop_size +
                    # look_ahead stays below T; otherwise expose all T frames.
                    if (
                        self.hop_size
                        and cur_end_frame + int(self.hop_size) + int(self.look_ahead)
                        < x.shape[0]
                    ):
                        cur_end_frame += int(self.hop_size)
                    else:
                        cur_end_frame = x.shape[0]
                    logging.debug("Going to next block: %d", cur_end_frame)
                    # The current step is rerun with the longer input; in
                    # conservative mode the previous step is discarded too and
                    # decoding resumes from the hypotheses saved before it.
                    if process_idx > 1 and len(prev_hyps) > 0 and self.conservative:
                        running_hyps = prev_hyps
                        process_idx -= 1
                        prev_hyps = []
                    break

                prev_repeat = False
                prev_hyps = running_hyps
                running_hyps = self.post_process(
                    process_idx, maxlen, maxlenratio, best, ended_hyps
                )

                # Locally ended hypotheses are only kept once all of x is
                # visible. NOTE(review): post_process also receives ended_hyps;
                # confirm it does not already record these.
                if cur_end_frame >= x.shape[0]:
                    for hyp in local_ended_hyps:
                        ended_hyps.append(hyp)

                if len(running_hyps) == 0:
                    logging.info("no hypothesis. Finish decoding.")
                    continue_decode = False
                    break
                else:
                    logging.debug(f"remained hypotheses: {len(running_hyps)}")
                # increment number
                process_idx += 1

        nbest_hyps = sorted(ended_hyps, key=lambda hyp: hyp.score, reverse=True)
        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
            )
            # Retry with minlenratio lowered by 0.1; give up once it is < 0.1.
            return (
                []
                if minlenratio < 0.1
                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
            )

        # report the best result
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: "
                + "".join([self.token_list[tok] for tok in best.yseq[1:-1]])
                + "\n"
            )
        return nbest_hyps

    def extend(self, x: torch.Tensor, hyps: Hypothesis) -> None:
        """Extend probabilities and states with more encoded chunks.

        Every scorer that defines ``extend_prob`` is given the longer encoder
        output. For every scorer that defines ``extend_state``, its entry in
        ``hyps.states`` is replaced by the extended state.

        Args:
            x (torch.Tensor): The extended encoder output feature
            hyps (Hypothesis): Current running hypotheses; their ``states``
                mapping is updated in place.

        Returns:
            None: ``hyps`` is modified in place.

        """
        for k, d in self.scorers.items():
            if hasattr(d, "extend_prob"):
                d.extend_prob(x)
            if hasattr(d, "extend_state"):
                hyps.states[k] = d.extend_state(hyps.states[k])
|