"""Search algorithms for transducer models."""
from typing import List
from typing import Union
import numpy as np
import torch
from espnet.nets.pytorch_backend.transducer.utils import create_lm_batch_state
from espnet.nets.pytorch_backend.transducer.utils import init_lm_state
from espnet.nets.pytorch_backend.transducer.utils import is_prefix
from espnet.nets.pytorch_backend.transducer.utils import recombine_hyps
from espnet.nets.pytorch_backend.transducer.utils import select_lm_state
from espnet.nets.pytorch_backend.transducer.utils import substract
from espnet.nets.transducer_decoder_interface import Hypothesis
from espnet.nets.transducer_decoder_interface import NSCHypothesis
from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface
class BeamSearchTransducer:
"""Beam search implementation for transducer."""
def __init__(
self,
decoder: Union[TransducerDecoderInterface, torch.nn.Module],
joint_network: torch.nn.Module,
beam_size: int,
lm: torch.nn.Module = None,
lm_weight: float = 0.1,
search_type: str = "default",
max_sym_exp: int = 2,
u_max: int = 50,
nstep: int = 1,
prefix_alpha: int = 1,
score_norm: bool = True,
nbest: int = 1,
):
"""Initialize transducer beam search.
Args:
decoder: Decoder class to use
joint_network: Joint Network class
beam_size: Number of hypotheses kept during search
lm: LM class to use
lm_weight: lm weight for soft fusion
search_type: type of algorithm to use for search
max_sym_exp: number of maximum symbol expansions at each time step ("tsd")
u_max: maximum output sequence length ("alsd")
nstep: number of maximum expansion steps at each time step ("nsc")
prefix_alpha: maximum prefix length in prefix search ("nsc")
score_norm: normalize final scores by length ("default")
nbest: number of returned final hypothesis
"""
self.decoder = decoder
self.joint_network = joint_network
self.beam_size = beam_size
self.hidden_size = decoder.dunits
self.vocab_size = decoder.odim
self.blank = decoder.blank
if self.beam_size <= 1:
self.search_algorithm = self.greedy_search
elif search_type == "default":
self.search_algorithm = self.default_beam_search
elif search_type == "tsd":
self.search_algorithm = self.time_sync_decoding
elif search_type == "alsd":
self.search_algorithm = self.align_length_sync_decoding
elif search_type == "nsc":
self.search_algorithm = self.nsc_beam_search
else:
            raise NotImplementedError(f"unsupported search type: {search_type}")
self.lm = lm
self.lm_weight = lm_weight
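        # Shallow fusion setup: a word-level LM exposes its underlying model
        # through a `wordlm` attribute, which is unwrapped here so that state
        # creation and layer counting operate on the actual RNN predictor.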
if lm is not None:
self.use_lm = True
            self.is_wordlm = hasattr(lm.predictor, "wordlm")
self.lm_predictor = lm.predictor.wordlm if self.is_wordlm else lm.predictor
self.lm_layers = len(self.lm_predictor.rnn)
else:
self.use_lm = False
self.max_sym_exp = max_sym_exp
self.u_max = u_max
self.nstep = nstep
self.prefix_alpha = prefix_alpha
self.score_norm = score_norm
self.nbest = nbest
def __call__(self, h: torch.Tensor) -> Union[List[Hypothesis], List[NSCHypothesis]]:
"""Perform beam search.
Args:
h: Encoded speech features (T_max, D_enc)
Returns:
nbest_hyps: N-best decoding results
"""
self.decoder.set_device(h.device)
if not hasattr(self.decoder, "decoders"):
self.decoder.set_data_type(h.dtype)
nbest_hyps = self.search_algorithm(h)
return nbest_hyps
def sort_nbest(
self, hyps: Union[List[Hypothesis], List[NSCHypothesis]]
) -> Union[List[Hypothesis], List[NSCHypothesis]]:
"""Sort hypotheses by score or score given sequence length.
Args:
hyps: list of hypotheses
Return:
hyps: sorted list of hypotheses
"""
if self.score_norm:
hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
else:
hyps.sort(key=lambda x: x.score, reverse=True)
return hyps[: self.nbest]
def greedy_search(self, h: torch.Tensor) -> List[Hypothesis]:
"""Greedy search implementation for transformer-transducer.
Args:
h: Encoded speech features (T_max, D_enc)
Returns:
hyp: 1-best decoding results
"""
dec_state = self.decoder.init_state(1)
hyp = Hypothesis(score=0.0, yseq=[self.blank], dec_state=dec_state)
cache = {}
y, state, _ = self.decoder.score(hyp, cache)
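        # For each encoder frame, emit the single most likely output of the
        # joint network; the decoder state is advanced only when that output
        # is a non-blank label.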
for i, hi in enumerate(h):
ytu = torch.log_softmax(self.joint_network(hi, y), dim=-1)
logp, pred = torch.max(ytu, dim=-1)
if pred != self.blank:
hyp.yseq.append(int(pred))
hyp.score += float(logp)
hyp.dec_state = state
y, state, _ = self.decoder.score(hyp, cache)
return [hyp]
def default_beam_search(self, h: torch.Tensor) -> List[Hypothesis]:
"""Beam search implementation.
Args:
            h: Encoded speech features (T_max, D_enc)
Returns:
nbest_hyps: N-best decoding results
"""
beam = min(self.beam_size, self.vocab_size)
beam_k = min(beam, (self.vocab_size - 1))
dec_state = self.decoder.init_state(1)
kept_hyps = [Hypothesis(score=0.0, yseq=[self.blank], dec_state=dec_state)]
cache = {}
for hi in h:
hyps = kept_hyps
kept_hyps = []
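            # Beam search over one encoder frame (cf. Graves, 2012): repeatedly
            # pop the best pending hypothesis, keep its blank extension and
            # push its top-k non-blank extensions back onto the pending set.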
while True:
max_hyp = max(hyps, key=lambda x: x.score)
hyps.remove(max_hyp)
y, state, lm_tokens = self.decoder.score(max_hyp, cache)
ytu = torch.log_softmax(self.joint_network(hi, y), dim=-1)
top_k = ytu[1:].topk(beam_k, dim=-1)
kept_hyps.append(
Hypothesis(
score=(max_hyp.score + float(ytu[0:1])),
yseq=max_hyp.yseq[:],
dec_state=max_hyp.dec_state,
lm_state=max_hyp.lm_state,
)
)
if self.use_lm:
lm_state, lm_scores = self.lm.predict(max_hyp.lm_state, lm_tokens)
else:
lm_state = max_hyp.lm_state
for logp, k in zip(*top_k):
score = max_hyp.score + float(logp)
if self.use_lm:
score += self.lm_weight * lm_scores[0][k + 1]
hyps.append(
Hypothesis(
score=score,
yseq=max_hyp.yseq[:] + [int(k + 1)],
dec_state=state,
lm_state=lm_state,
)
)
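                # Stop expanding this frame once at least `beam` blank-ended
                # hypotheses outscore the best still-pending hypothesis.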
hyps_max = float(max(hyps, key=lambda x: x.score).score)
kept_most_prob = sorted(
[hyp for hyp in kept_hyps if hyp.score > hyps_max],
key=lambda x: x.score,
)
if len(kept_most_prob) >= beam:
kept_hyps = kept_most_prob
break
return self.sort_nbest(kept_hyps)
def time_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
"""Time synchronous beam search implementation.
Based on https://ieeexplore.ieee.org/document/9053040
Args:
h: Encoded speech features (T_max, D_enc)
Returns:
nbest_hyps: N-best decoding results
"""
beam = min(self.beam_size, self.vocab_size)
beam_state = self.decoder.init_state(beam)
B = [
Hypothesis(
yseq=[self.blank],
score=0.0,
dec_state=self.decoder.select_state(beam_state, 0),
)
]
cache = {}
if self.use_lm and not self.is_wordlm:
B[0].lm_state = init_lm_state(self.lm_predictor)
for hi in h:
A = []
C = B
h_enc = hi.unsqueeze(0)
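            # Up to max_sym_exp symbol expansions per frame: A accumulates
            # blank-extended hypotheses (duplicates merged via logaddexp),
            # D collects the non-blank expansions used for the next pass.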
for v in range(self.max_sym_exp):
D = []
beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
C,
beam_state,
cache,
self.use_lm,
)
beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1)
beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)
                seq_A = [hyp.yseq for hyp in A]
for i, hyp in enumerate(C):
if hyp.yseq not in seq_A:
A.append(
Hypothesis(
score=(hyp.score + float(beam_logp[i, 0])),
yseq=hyp.yseq[:],
dec_state=hyp.dec_state,
lm_state=hyp.lm_state,
)
)
else:
dict_pos = seq_A.index(hyp.yseq)
A[dict_pos].score = np.logaddexp(
A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
)
if v < (self.max_sym_exp - 1):
if self.use_lm:
beam_lm_states = create_lm_batch_state(
[c.lm_state for c in C], self.lm_layers, self.is_wordlm
)
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
beam_lm_states, beam_lm_tokens, len(C)
)
for i, hyp in enumerate(C):
for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
new_hyp = Hypothesis(
score=(hyp.score + float(logp)),
yseq=(hyp.yseq + [int(k)]),
dec_state=self.decoder.select_state(beam_state, i),
lm_state=hyp.lm_state,
)
if self.use_lm:
new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
new_hyp.lm_state = select_lm_state(
beam_lm_states, i, self.lm_layers, self.is_wordlm
)
D.append(new_hyp)
C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]
B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
return self.sort_nbest(B)
def align_length_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]:
"""Alignment-length synchronous beam search implementation.
Based on https://ieeexplore.ieee.org/document/9053040
Args:
h: Encoded speech features (T_max, D_enc)
Returns:
nbest_hyps: N-best decoding results
"""
beam = min(self.beam_size, self.vocab_size)
h_length = int(h.size(0))
u_max = min(self.u_max, (h_length - 1))
beam_state = self.decoder.init_state(beam)
B = [
Hypothesis(
yseq=[self.blank],
score=0.0,
dec_state=self.decoder.select_state(beam_state, 0),
)
]
final = []
cache = {}
if self.use_lm and not self.is_wordlm:
B[0].lm_state = init_lm_state(self.lm_predictor)
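        # Each outer iteration batches hypotheses sharing the same value of
        # frame index plus output length; hypotheses pointing past the last
        # frame are dropped, and blank extensions taken at the last frame
        # are collected as final hypotheses.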
for i in range(h_length + u_max):
A = []
B_ = []
h_states = []
for hyp in B:
u = len(hyp.yseq) - 1
t = i - u + 1
if t > (h_length - 1):
continue
B_.append(hyp)
h_states.append((t, h[t]))
if B_:
beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
B_,
beam_state,
cache,
self.use_lm,
)
                h_enc = torch.stack([hs[1] for hs in h_states])
beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1)
beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)
if self.use_lm:
beam_lm_states = create_lm_batch_state(
[b.lm_state for b in B_], self.lm_layers, self.is_wordlm
)
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
beam_lm_states, beam_lm_tokens, len(B_)
)
                for j, hyp in enumerate(B_):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(beam_logp[j, 0])),
                        yseq=hyp.yseq[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                    )
                    A.append(new_hyp)
                    if h_states[j][0] == (h_length - 1):
                        final.append(new_hyp)
                    for logp, k in zip(beam_topk[0][j], beam_topk[1][j] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq[:] + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, j),
                            lm_state=hyp.lm_state,
                        )
                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[j, k]
                            new_hyp.lm_state = select_lm_state(
                                beam_lm_states, j, self.lm_layers, self.is_wordlm
                            )
                        A.append(new_hyp)
B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
B = recombine_hyps(B)
if final:
return self.sort_nbest(final)
else:
return B
def nsc_beam_search(self, h: torch.Tensor) -> List[NSCHypothesis]:
"""N-step constrained beam search implementation.
        Based on and modified from https://arxiv.org/pdf/2002.03577.pdf.
Please reference ESPnet (b-flo, PR #2444) for any usage outside ESPnet
until further modifications.
        Note: the algorithm is not yet in its "complete" form but works
        almost as intended.
Args:
h: Encoded speech features (T_max, D_enc)
Returns:
nbest_hyps: N-best decoding results
"""
beam = min(self.beam_size, self.vocab_size)
beam_k = min(beam, (self.vocab_size - 1))
beam_state = self.decoder.init_state(beam)
init_tokens = [
NSCHypothesis(
yseq=[self.blank],
score=0.0,
dec_state=self.decoder.select_state(beam_state, 0),
)
]
cache = {}
beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
init_tokens,
beam_state,
cache,
self.use_lm,
)
state = self.decoder.select_state(beam_state, 0)
if self.use_lm:
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
None, beam_lm_tokens, 1
)
lm_state = select_lm_state(
beam_lm_states, 0, self.lm_layers, self.is_wordlm
)
lm_scores = beam_lm_scores[0]
else:
lm_state = None
lm_scores = None
kept_hyps = [
NSCHypothesis(
yseq=[self.blank],
score=0.0,
dec_state=state,
y=[beam_y[0]],
lm_state=lm_state,
lm_scores=lm_scores,
)
]
for hi in h:
hyps = sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True)
kept_hyps = []
h_enc = hi.unsqueeze(0)
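            # Prefix search: when a shorter hypothesis is a prefix of a longer
            # one (within prefix_alpha labels), add the probability of reaching
            # the longer sequence through it to the longer hypothesis's score.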
for j, hyp_j in enumerate(hyps[:-1]):
for hyp_i in hyps[(j + 1) :]:
curr_id = len(hyp_j.yseq)
next_id = len(hyp_i.yseq)
if (
is_prefix(hyp_j.yseq, hyp_i.yseq)
and (curr_id - next_id) <= self.prefix_alpha
):
ytu = torch.log_softmax(
self.joint_network(hi, hyp_i.y[-1]), dim=-1
)
curr_score = hyp_i.score + float(ytu[hyp_j.yseq[next_id]])
for k in range(next_id, (curr_id - 1)):
ytu = torch.log_softmax(
self.joint_network(hi, hyp_j.y[k]), dim=-1
)
curr_score += float(ytu[hyp_j.yseq[k + 1]])
hyp_j.score = np.logaddexp(hyp_j.score, curr_score)
S = []
V = []
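            # Up to nstep expansions for this frame: S collects blank-ended
            # extensions, V the best non-blank extensions carried into the
            # next expansion step.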
for n in range(self.nstep):
beam_y = torch.stack([hyp.y[-1] for hyp in hyps])
beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1)
beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)
for i, hyp in enumerate(hyps):
S.append(
NSCHypothesis(
yseq=hyp.yseq[:],
score=hyp.score + float(beam_logp[i, 0:1]),
y=hyp.y[:],
dec_state=hyp.dec_state,
lm_state=hyp.lm_state,
lm_scores=hyp.lm_scores,
)
)
for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
score = hyp.score + float(logp)
if self.use_lm:
score += self.lm_weight * float(hyp.lm_scores[k])
V.append(
NSCHypothesis(
yseq=hyp.yseq[:] + [int(k)],
score=score,
y=hyp.y[:],
dec_state=hyp.dec_state,
lm_state=hyp.lm_state,
lm_scores=hyp.lm_scores,
)
)
V.sort(key=lambda x: x.score, reverse=True)
V = substract(V, hyps)[:beam]
beam_state = self.decoder.create_batch_states(
beam_state,
[v.dec_state for v in V],
[v.yseq for v in V],
)
beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score(
V,
beam_state,
cache,
self.use_lm,
)
if self.use_lm:
beam_lm_states = create_lm_batch_state(
[v.lm_state for v in V], self.lm_layers, self.is_wordlm
)
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
beam_lm_states, beam_lm_tokens, len(V)
)
if n < (self.nstep - 1):
for i, v in enumerate(V):
v.y.append(beam_y[i])
v.dec_state = self.decoder.select_state(beam_state, i)
if self.use_lm:
v.lm_state = select_lm_state(
beam_lm_states, i, self.lm_layers, self.is_wordlm
)
v.lm_scores = beam_lm_scores[i]
hyps = V[:]
else:
beam_logp = torch.log_softmax(
self.joint_network(h_enc, beam_y), dim=-1
)
for i, v in enumerate(V):
if self.nstep != 1:
v.score += float(beam_logp[i, 0])
v.y.append(beam_y[i])
v.dec_state = self.decoder.select_state(beam_state, i)
if self.use_lm:
v.lm_state = select_lm_state(
beam_lm_states, i, self.lm_layers, self.is_wordlm
)
v.lm_scores = beam_lm_scores[i]
kept_hyps = sorted((S + V), key=lambda x: x.score, reverse=True)[:beam]
return self.sort_nbest(kept_hyps)