|
"""Scorer interface module.""" |
|
|
|
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple
import warnings

import torch
|
|
|
|
|
class ScorerInterface:
    """Scorer interface for beam search.

    The scorer performs scoring of all the tokens in the vocabulary.

    Examples:
        * Search heuristics
            * :class:`espnet.nets.scorers.length_bonus.LengthBonus`
        * Decoder networks of the sequence-to-sequence models
            * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder`
            * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
        * Neural language models
            * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
            * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
            * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`

    """

    def init_state(self, x: torch.Tensor) -> Any:
        """Get an initial state for decoding (optional).

        The default implementation keeps no state and returns ``None``.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns:
            Any: initial state (``None`` for this default implementation)

        """
        return None

    def select_state(
        self, state: Any, i: int, new_id: Optional[int] = None
    ) -> Any:
        """Select state with relative ids in the main beam search.

        Args:
            state: Decoder state for prefix tokens. ``None`` denotes a
                stateless scorer.
            i (int): Index to select a state in the main beam search
            new_id (Optional[int]): New label index to select a state if
                necessary. Unused by this default implementation.

        Returns:
            state: pruned state, i.e. ``state[i]``, or ``None`` when
            ``state`` is ``None``.

        """
        # A stateless scorer propagates None; otherwise pick the i-th entry.
        return None if state is None else state[i]

    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token (required).

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                scores for next token that has a shape of `(n_vocab)`
                and next state for ys

        Raises:
            NotImplementedError: always; subclasses must override this.

        """
        raise NotImplementedError

    def final_score(self, state: Any) -> float:
        """Score eos (optional).

        Args:
            state: Scorer state for prefix tokens

        Returns:
            float: final score (``0.0`` for this default implementation)

        """
        return 0.0
|
|
|
|
|
class BatchScorerInterface(ScorerInterface):
    """Batch scorer interface.

    The fallback ``batch_score`` below calls ``score`` once per hypothesis
    and emits a warning that the computation is not parallelized.
    """

    def batch_init_state(self, x: torch.Tensor) -> Any:
        """Get an initial state for decoding (optional).

        Delegates to :meth:`init_state`.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns:
            Any: initial state

        """
        initial = self.init_state(x)
        return initial

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.

        """
        scorer_name = self.__class__.__name__
        warnings.warn(
            "{} batch score is implemented through for loop not parallelized".format(
                scorer_name
            )
        )
        # Score each hypothesis independently, in batch order.
        per_hyp = [self.score(y, st, x) for y, st, x in zip(ys, states, xs)]
        next_states = [hyp_state for _, hyp_state in per_hyp]
        # Concatenate the per-hypothesis scores, then reshape to one row per batch entry.
        flat = torch.cat([hyp_score for hyp_score, _ in per_hyp], 0)
        return flat.view(ys.shape[0], -1), next_states
|
|
|
|
|
class PartialScorerInterface(ScorerInterface):
    """Partial scorer interface for beam search.

    The partial scorer performs scoring after the non-partial scorers have
    finished, and receives pre-pruned next tokens to score because scoring
    all the tokens would be too heavy.

    Examples:
        * Prefix search for connectionist-temporal-classification models
            * :class:`espnet.nets.scorers.ctc.CTCPrefixScorer`

    """

    def score_partial(
        self, y: torch.Tensor, next_tokens: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token (required).

        Args:
            y (torch.Tensor): 1D prefix tokens
            next_tokens (torch.Tensor): torch.int64 next tokens to score
            state: Scorer state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
                and next state for ys

        Raises:
            NotImplementedError: always; subclasses must override this.

        """
        raise NotImplementedError
|
|
|
|
|
class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface):
    """Batch partial scorer interface for beam search.

    Combines the batched and partial scorer interfaces; implementers must
    provide :meth:`batch_score_partial`.
    """

    def batch_score_partial(
        self,
        ys: torch.Tensor,
        next_tokens: torch.Tensor,
        states: List[Any],
        xs: torch.Tensor,
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            next_tokens (torch.Tensor): torch.int64 tokens to score (n_batch, n_token).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)`
                and next states for ys

        Raises:
            NotImplementedError: always; subclasses must override this.

        """
        raise NotImplementedError()
|
|