# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from abc import ABC, abstractmethod
from collections import UserDict
from typing import Optional, Tuple

import torch

from .file_utils import add_start_docstrings

PROCESS_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using any class inheriting from :class:`~transformers.PreTrainedTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        next_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
            Current scores of the top :obj:`2 * num_beams` non-finished beam hypotheses.
        next_tokens (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
            :obj:`input_ids` of the tokens corresponding to the top :obj:`2 * num_beams` non-finished beam hypotheses.
        next_indices (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
            Beam indices indicating to which beam hypothesis the :obj:`next_tokens` correspond.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token.

    Return:
        :obj:`UserDict`: A dictionary composed of the fields as defined above:

            - **next_beam_scores** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Updated
              scores of all non-finished beams.
            - **next_beam_tokens** (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams)`) -- Next tokens
              to be added to the non-finished beam hypotheses.
            - **next_beam_indices** (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams)`) -- Beam indices
              indicating to which beam the next tokens shall be added.

"""

FINALIZE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using any class inheriting from :class:`~transformers.PreTrainedTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        final_beam_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`):
            The final scores of all non-finished beams.
        final_beam_tokens (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams)`):
            The last tokens to be added to the non-finished beam hypotheses.
        final_beam_indices (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams)`):
            The beam indices indicating to which beam the :obj:`final_beam_tokens` shall be added.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token.

    Return:
        :obj:`UserDict`: A dictionary composed of the following fields:

            - **sequences** (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences,
              sequence_length)`) -- The generated sequences. The second dimension (sequence_length) is either equal
              to :obj:`max_length` or shorter if all batches finished early due to the :obj:`eos_token_id`.
            - **sequence_scores** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_return_sequences)`) --
              The final scores of the generated sequences.

"""

class BeamScorer(ABC):
    """
    Abstract base class for all beam scorers that are used for :meth:`~transformers.PreTrainedModel.beam_search` and
    :meth:`~transformers.PreTrainedModel.beam_sample`.
    """

    @abstractmethod
    @add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        **kwargs
    ) -> Tuple[torch.Tensor]:
        raise NotImplementedError("This is an abstract method.")

    @abstractmethod
    @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
    def finalize(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        max_length: int,
        **kwargs
    ) -> torch.LongTensor:
        raise NotImplementedError("This is an abstract method.")
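
# A custom scorer only has to subclass `BeamScorer` and implement `process` and
# `finalize` with the signatures documented above, e.g. (hypothetical sketch,
# not part of the library):
#
#     class MyScorer(BeamScorer):
#         def process(self, input_ids, next_scores, next_tokens, next_indices, **kwargs):
#             ...  # return a UserDict with next_beam_scores/tokens/indices
#
#         def finalize(self, input_ids, final_beam_scores, final_beam_tokens, final_beam_indices, max_length, **kwargs):
#             ...  # return the finished, padded sequences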

class BeamSearchScorer(BeamScorer):
    r"""
    :class:`transformers.BeamScorer` implementing standard beam search decoding.

    Adapted in part from `Facebook's XLM beam search code
    <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.

    Reference for the diverse beam search algorithm and implementation: `Ashwin Kalyan's DBS implementation
    <https://github.com/ashwinkalyan/dbs/blob/master/dbs/beam_utils.lua>`__

    Args:
        batch_size (:obj:`int`):
            Batch size of :obj:`input_ids` for which standard beam search decoding is run in parallel.
        num_beams (:obj:`int`):
            Number of beams for beam search.
        device (:obj:`torch.device`):
            Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of
            :obj:`BeamSearchScorer` will be allocated.
        length_penalty (:obj:`float`, `optional`, defaults to 1.0):
            Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
            model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer
            sequences.
        do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
        num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1):
            The number of beam hypotheses that shall be returned upon calling
            :meth:`~transformers.BeamSearchScorer.finalize`.
        num_beam_groups (:obj:`int`, `optional`, defaults to 1):
            Number of groups to divide :obj:`num_beams` into in order to ensure diversity among different groups of
            beams. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__ for more details.
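
    Example (a minimal usage sketch; the checkpoint name and generation settings are illustrative)::

        >>> import torch
        >>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, BeamSearchScorer

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
        >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

        >>> encoder_input_ids = tokenizer(
        ...     "translate English to German: How old are you?", return_tensors="pt"
        ... ).input_ids

        >>> # run beam search with 3 beams: the decoder input and encoder outputs
        >>> # are replicated once per beam
        >>> num_beams = 3
        >>> input_ids = torch.ones((num_beams, 1), dtype=torch.long) * model.config.decoder_start_token_id
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }

        >>> beam_scorer = BeamSearchScorer(batch_size=1, num_beams=num_beams, device=model.device)
        >>> outputs = model.beam_search(input_ids, beam_scorer, **model_kwargs)
        >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True))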
""" | |

    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[bool] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
        num_beam_groups: Optional[int] = 1,
        **kwargs,
    ):
        self.num_beams = num_beams
        self.device = device
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
        self.num_beam_groups = num_beam_groups
        self.group_size = self.num_beams // self.num_beam_groups

        self._is_init = False
        # one BeamHypotheses container per batch element, each keeping the best `num_beams` finished hypotheses
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.num_beams,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
            )
            for _ in range(batch_size)
        ]
        self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)

        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1, "
                "one should make use of `greedy_search` instead."
            )

        if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
            raise ValueError(
                f"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be "
                f"divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
            )

        if "max_length" in kwargs:
            warnings.warn(
                "Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`, "
                "or `group_beam_search(...)`."
            )

    @property
    def is_done(self) -> bool:
        return self._done.all()

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> Tuple[torch.Tensor]:
        cur_len = input_ids.shape[-1]
        batch_size = len(self._beam_hyps)
        assert batch_size == (input_ids.shape[0] // self.group_size)

        device = input_ids.device
        next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
        next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
        next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                assert (
                    len(beam_hyp) >= self.num_beams
                ), f"Batch can only be done if at least {self.num_beams} beams have been generated"
                assert (
                    eos_token_id is not None and pad_token_id is not None
                ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                # pad the batch
                next_beam_scores[batch_idx, :] = 0
                next_beam_tokens[batch_idx, :] = pad_token_id
                next_beam_indices[batch_idx, :] = 0
                continue

            # next tokens for this sentence
            beam_idx = 0
            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
                zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
            ):
                batch_beam_idx = batch_idx * self.group_size + next_index
                # add to generated hypotheses if end of sentence
                if (eos_token_id is not None) and (next_token.item() == eos_token_id):
                    # if beam_token does not belong to top num_beams tokens, it should not be added
                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
                    if is_beam_token_worse_than_top_num_beams:
                        continue
                    beam_hyp.add(
                        input_ids[batch_beam_idx].clone(),
                        next_score.item(),
                    )
                else:
                    # add next predicted token since it is not eos_token
                    next_beam_scores[batch_idx, beam_idx] = next_score
                    next_beam_tokens[batch_idx, beam_idx] = next_token
                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
                    beam_idx += 1

                # once the beam for next step is full, don't add more tokens to it.
                if beam_idx == self.group_size:
                    break

            if beam_idx < self.group_size:
                raise ValueError(
                    f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: "
                    f"{eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
                )

            # Check if we are done so that we can save a pad step if all(done)
            self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
                next_scores[batch_idx].max().item(), cur_len
            )

        return UserDict(
            {
                "next_beam_scores": next_beam_scores.view(-1),
                "next_beam_tokens": next_beam_tokens.view(-1),
                "next_beam_indices": next_beam_indices.view(-1),
            }
        )

    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        max_length: int,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> UserDict:
        batch_size = len(self._beam_hyps)

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                continue

            # all open beam hypotheses are added to the beam hypothesis
            # beam hypothesis class automatically keeps the best beams
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_hyp.add(final_tokens, final_score)

        # select the best hypotheses
        sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
        best = []
        best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

        # retrieve best hypotheses
        for i, beam_hyp in enumerate(self._beam_hyps):
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_hyp_tuple = sorted_hyps.pop()
                best_score = best_hyp_tuple[0]
                best_hyp = best_hyp_tuple[1]
                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

                # append to lists
                best.append(best_hyp)
                best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

        # prepare for adding eos
        sent_max_len = min(sent_lengths.max().item() + 1, max_length)
        decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)

        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill_(pad_token_id)

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, hypo in enumerate(best):
            decoded[i, : sent_lengths[i]] = hypo
            if sent_lengths[i] < max_length:
                decoded[i, sent_lengths[i]] = eos_token_id

        return UserDict(
            {
                "sequences": decoded,
                "sequence_scores": best_scores,
            }
        )

class BeamHypotheses:
    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
        """
        Initialize n-best list of hypotheses.
        """
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(self, hyp: torch.LongTensor, sum_logprobs: float):
        """
        Add a new hypothesis to the list.
        """
        # length-normalized score: sum of log-probs divided by length ** length_penalty
        score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                # drop the current worst hypothesis to keep the list at `num_beams` entries
                sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_next_scores[0][1]]
                self.worst_score = sorted_next_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
        """
        If there are enough hypotheses and none of the hypotheses being generated can become better than the worst
        one in the heap, then we are done with this sentence.
        """
        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            # best achievable normalized score of any still-open beam at the current length
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            ret = self.worst_score >= cur_score
            return ret
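

if __name__ == "__main__":
    # Minimal smoke-test sketch of `BeamHypotheses` (illustrative only, not part
    # of the library): scores are length-normalized, so with length_penalty=1.0 a
    # short hypothesis with sum log-prob -1.0 outranks a longer one with -4.0.
    hyps = BeamHypotheses(num_beams=2, length_penalty=1.0, early_stopping=False)
    hyps.add(torch.tensor([0, 5, 7]), sum_logprobs=-4.0)  # score -4.0 / 3 ~ -1.33
    hyps.add(torch.tensor([0, 5]), sum_logprobs=-1.0)  # score -1.0 / 2 = -0.50
    hyps.add(torch.tensor([0, 9, 9, 9]), sum_logprobs=-8.0)  # score -2.0: rejected, worse than current worst

    for score, hyp in sorted(hyps.beams, key=lambda x: x[0], reverse=True):
        print(f"score={score:.3f} tokens={hyp.tolist()}")

    # once `num_beams` hypotheses are kept and no open beam can beat the worst
    # kept score at the current length, the sentence is done
    print("done:", hyps.is_done(best_sum_logprobs=-10.0, cur_len=4))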