import time
import warnings
from abc import ABC
from copy import deepcopy
from typing import Optional

import torch

from .file_utils import add_start_docstrings


STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.BertTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be scores for each vocabulary token before
            SoftMax or scores for each vocabulary token after SoftMax.
        kwargs:
            Additional stopping criteria specific kwargs.

    Return:
        :obj:`bool`. :obj:`False` indicates we should continue, :obj:`True` indicates we should stop.
"""


class StoppingCriteria(ABC):
    """Abstract base class for all stopping criteria that can be applied during generation."""

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        raise NotImplementedError("StoppingCriteria needs to be subclassed")


class MaxLengthCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the full generated number of tokens exceeds :obj:`max_length`.
    Keep in mind for decoder-only type of transformers, this will include the initial prompted tokens.

    Args:
        max_length (:obj:`int`):
            The maximum length that the output sequence can have in number of tokens.
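
    Example (a minimal sketch; the tensors are dummies standing in for real generation state)::

        >>> import torch
        >>> criteria = MaxLengthCriteria(max_length=5)
        >>> dummy_ids = torch.zeros((1, 5), dtype=torch.long)  # batch of 1, 5 tokens (prompt included)
        >>> criteria(dummy_ids, scores=torch.empty(1, 10))
        True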
""" | |
def __init__(self, max_length: int): | |
self.max_length = max_length | |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: | |
return input_ids.shape[-1] >= self.max_length | |


class MaxNewTokensCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the generated number of tokens exceeds :obj:`max_new_tokens`.
    Keep in mind for decoder-only type of transformers, this will **not** include the initial prompted tokens. This is
    very close to :obj:`MaxLengthCriteria` but ignores the number of initial tokens.

    Args:
        start_length (:obj:`int`):
            The number of initial tokens.
        max_new_tokens (:obj:`int`):
            The maximum number of tokens to generate.
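
    Example (a minimal sketch with dummy tensors)::

        >>> import torch
        >>> criteria = MaxNewTokensCriteria(start_length=3, max_new_tokens=2)
        >>> dummy_ids = torch.zeros((1, 5), dtype=torch.long)  # 3 prompt tokens + 2 newly generated tokens
        >>> criteria(dummy_ids, scores=torch.empty(1, 10))
        True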
""" | |
def __init__(self, start_length: int, max_new_tokens: int): | |
self.start_length = start_length | |
self.max_new_tokens = max_new_tokens | |
self.max_length = start_length + max_new_tokens | |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: | |
return input_ids.shape[-1] >= self.max_length | |


class MaxTimeCriteria(StoppingCriteria):
    """
    This class can be used to stop generation whenever the full generation exceeds some amount of time. By default,
    the time will start being counted when you initialize this object. You can override this by passing an
    :obj:`initial_timestamp`.

    Args:
        max_time (:obj:`float`):
            The maximum allowed time in seconds for the generation.
        initial_timestamp (:obj:`float`, `optional`, defaults to :obj:`time.time()`):
            The start of the generation allowed time.
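
    Example (a minimal sketch; the backdated timestamp simulates one second of elapsed generation time)::

        >>> import time
        >>> import torch
        >>> criteria = MaxTimeCriteria(max_time=0.5, initial_timestamp=time.time() - 1.0)
        >>> criteria(torch.zeros((1, 1), dtype=torch.long), scores=torch.empty(1, 10))
        True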
""" | |
def __init__(self, max_time: float, initial_timestamp: Optional[float] = None): | |
self.max_time = max_time | |
self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp | |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: | |
return time.time() - self.initial_timestamp > self.max_time | |


class StoppingCriteriaList(list):
    """
    A list of :class:`StoppingCriteria` instances that signals generation to stop as soon as any criterion in the
    list returns :obj:`True`.
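
    Example (a minimal sketch; the first criterion fires because the dummy sequence is already 5 tokens long)::

        >>> import torch
        >>> criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=5), MaxTimeCriteria(max_time=60.0)])
        >>> criteria(torch.zeros((1, 5), dtype=torch.long), scores=torch.empty(1, 10))
        True
    """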

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return any(criteria(input_ids, scores) for criteria in self)

    @property
    def max_length(self) -> Optional[int]:
        for stopping_criterium in self:
            if isinstance(stopping_criterium, MaxLengthCriteria):
                return stopping_criterium.max_length
            elif isinstance(stopping_criterium, MaxNewTokensCriteria):
                return stopping_criterium.max_length
        return None


def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
    """Return a copy of :obj:`stopping_criteria` that is guaranteed to enforce :obj:`max_length`, warning if the
    list already carries a conflicting maximum length."""
    stopping_max_length = stopping_criteria.max_length
    new_stopping_criteria = deepcopy(stopping_criteria)
    if stopping_max_length is not None and stopping_max_length != max_length:
        warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
    elif stopping_max_length is None:
        new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
    return new_stopping_criteria
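

if __name__ == "__main__":
    # Illustrative smoke test (a sketch, not part of the library module): exercises
    # the criteria above on dummy tensors that stand in for real generation state.
    dummy_ids = torch.zeros((1, 10), dtype=torch.long)
    dummy_scores = torch.empty(1, 100)

    criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=10), MaxTimeCriteria(max_time=5.0)])
    assert criteria(dummy_ids, dummy_scores)  # 10 tokens >= max_length of 10 -> stop

    # An empty list gains a MaxLengthCriteria so that `max_length` is always enforced.
    validated = validate_stopping_criteria(StoppingCriteriaList(), max_length=20)
    assert validated.max_length == 20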