Spaces:
Paused
Paused
# coding=utf-8 | |
# Copyright 2020 The HuggingFace Team Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import unittest | |
from transformers import is_torch_available | |
from transformers.testing_utils import require_torch | |
if is_torch_available(): | |
import torch | |
from transformers.generation import DisjunctiveConstraint | |
@require_torch
class ConstraintTest(unittest.TestCase):
    """Tests for ``DisjunctiveConstraint`` input validation and token-by-token progression.

    ``torch`` and ``DisjunctiveConstraint`` are only imported when torch is available, so the
    class is decorated with ``require_torch`` to skip (rather than fail with ``NameError``)
    in environments without torch.
    """

    def _assert_step(self, dc, token_id, expected_seq, expected_completed):
        """Feed ``token_id`` to ``dc`` and check the returned flags and the resulting state.

        Every step driven through this helper is expected to advance the constraint, so the
        returned ``stepped`` flag must be ``True`` and ``reset`` must be ``False``. The returned
        ``completed`` flag and ``dc.completed`` must both match ``expected_completed``, and
        ``dc.current_seq`` must equal ``expected_seq``.
        """
        stepped, completed, reset = dc.update(token_id)
        self.assertIs(stepped, True)
        self.assertIs(completed, expected_completed)
        self.assertIs(reset, False)
        self.assertEqual(bool(dc.completed), expected_completed)
        self.assertEqual(dc.current_seq, expected_seq)

    def test_input_types(self):
        # For consistency across the different places DisjunctiveConstraint is used,
        # dc.token_ids is a list of integers, and it may only be initialized from integers.
        cset = [[1, 2, 4], [1, 2, 3, 4]]
        dc = DisjunctiveConstraint(cset)
        self.assertIsInstance(dc.token_ids, list)

        # Tensors are rejected, whether passed as one 2-D tensor or as a list of 1-D tensors.
        with self.assertRaises(ValueError):
            DisjunctiveConstraint(torch.LongTensor([[1, 2, 4], [1, 2, 3]]))

        with self.assertRaises(ValueError):
            DisjunctiveConstraint([torch.LongTensor([1, 2, 4]), torch.LongTensor([1, 2, 3, 4, 5])])

    def test_check_illegal_input(self):
        # A constraint may not be a complete subset (here: a prefix) of another. That would lead to
        # a perverse interpretation of "constraint fulfillment": does generating [1,2,3] fulfill the
        # constraint? It would mean that it generated [1,2], which fulfills it, but it is in the
        # middle of potentially fulfilling [1,2,3,4]. If we believe that [1,2,3] does fulfill the
        # constraint, then the algorithm will necessarily never reach [1,2,3,4], giving users a false
        # sense of control, so it is better to not allow such constraint sets at all.
        cset = [[1, 2], [1, 2, 3, 4]]

        with self.assertRaises(ValueError):
            DisjunctiveConstraint(cset)  # fails here

    def test_example_progression(self):
        cset = [[1, 2, 3], [1, 2, 4]]
        dc = DisjunctiveConstraint(cset)

        self._assert_step(dc, 1, [1], expected_completed=False)
        self._assert_step(dc, 2, [1, 2], expected_completed=False)
        self._assert_step(dc, 3, [1, 2, 3], expected_completed=True)  # Completed!

    def test_example_progression_unequal_three_mid_and_reset(self):
        cset = [[1, 2, 3], [1, 2, 4, 5], [1, 2, 5]]
        dc = DisjunctiveConstraint(cset)

        # First pass: follow the branch [1, 2, 4, 5] to completion.
        self._assert_step(dc, 1, [1], expected_completed=False)
        self._assert_step(dc, 2, [1, 2], expected_completed=False)
        self._assert_step(dc, 4, [1, 2, 4], expected_completed=False)
        self._assert_step(dc, 5, [1, 2, 4, 5], expected_completed=True)  # Completed!

        # Second pass after reset(): follow the branch [1, 2, 5], also checking remaining().
        dc.reset()

        self._assert_step(dc, 1, [1], expected_completed=False)
        self.assertEqual(dc.remaining(), 3)

        self._assert_step(dc, 2, [1, 2], expected_completed=False)
        self.assertEqual(dc.remaining(), 2)

        self._assert_step(dc, 5, [1, 2, 5], expected_completed=True)  # Completed!
        self.assertEqual(dc.remaining(), 0)