# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest

import timeout_decorator

from ..file_utils import cached_property, is_torch_available
from ..testing_utils import require_torch


if is_torch_available():
    import torch

    from ..models.marian import MarianConfig, MarianMTModel


@require_torch
class GenerationUtilsTest(unittest.TestCase):
    # The tests below access self.config and self.model as attributes,
    # so both are exposed as cached properties.
    @cached_property
    def config(self):
        config = MarianConfig.from_pretrained("sshleifer/tiny-marian-en-de")
        return config

    @cached_property
    def model(self):
        return MarianMTModel(self.config)

    def test_postprocess_next_token_scores(self):
        config = self.config
        model = self.model
        # Initialize an input id tensor with batch size 8 and sequence length 12
        input_ids = torch.arange(0, 96, 1).view((8, 12))
        eos = config.eos_token_id
        bad_words_ids_test_cases = [[[299]], [[23, 24], [54]], [[config.eos_token_id]], []]
        # Expected (batch_index, token_id) positions that must be masked to -inf
        # for each of the bad_words_ids test cases above.
        masked_scores = [
            [(0, 299), (1, 299), (2, 299), (3, 299), (4, 299), (5, 299), (6, 299), (7, 299)],
            [(1, 24), (0, 54), (1, 54), (2, 54), (3, 54), (4, 54), (5, 54), (6, 54), (7, 54)],
            [(0, eos), (1, eos), (2, eos), (3, eos), (4, eos), (5, eos), (6, eos), (7, eos)],
            [],
        ]

        for test_case_index, bad_words_ids in enumerate(bad_words_ids_test_cases):
            # Initialize a scores tensor with batch size 8 and vocabulary size 300
            scores = torch.rand((8, 300))
            output = model.postprocess_next_token_scores(
                scores,
                input_ids,
                0,
                bad_words_ids,
                13,
                15,
                config.max_length,
                config.eos_token_id,
                config.repetition_penalty,
                32,
                5,
            )

            for masked_score in masked_scores[test_case_index]:
                self.assertTrue(output[masked_score[0], masked_score[1]] == -float("inf"))

    @timeout_decorator.timeout(10)
    def test_postprocess_next_token_scores_large_bad_words_list(self):
        config = self.config
        model = self.model
        # Initialize an input id tensor with batch size 8 and sequence length 12
        input_ids = torch.arange(0, 96, 1).view((8, 12))

        # Build 100 random bad-word sequences of length 1 to 4 drawn from the vocabulary
        bad_words_ids = []
        for _ in range(100):
            length_bad_word = random.randint(1, 4)
            bad_words_ids.append(random.sample(range(1, 300), length_bad_word))

        scores = torch.rand((8, 300))
        _ = model.postprocess_next_token_scores(
            scores,
            input_ids,
            0,
            bad_words_ids,
            13,
            15,
            config.max_length,
            config.eos_token_id,
            config.repetition_penalty,
            32,
            5,
        )