File size: 46,609 Bytes
2010c83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 |
"""
This is a self-contained and flexible beam search implementation adapted from
AllenNLP's beam search: https://github.com/allenai/allennlp/blob/main/allennlp/nn/beam_search.py
"""
import copy
import warnings
from abc import abstractmethod
from inspect import signature
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast
import torch
# Public API of this module: the samplers, the final-sequence scorers, the
# constraints, and the `BeamSearch` class itself.
__all__ = [
    "Sampler",
    "DeterministicSampler",
    "MultinomialSampler",
    "TopKSampler",
    "TopPSampler",
    "GumbelSampler",
    "FinalSequenceScorer",
    "SequenceLogProbabilityScorer",
    "LengthNormalizedSequenceLogProbabilityScorer",
    "Constraint",
    "RepeatedNGramBlockingConstraint",
    "BeamSearch",
]
# State passed to and returned from the step function (and used for sampler
# state): a mapping from name to tensor.
StateType = Dict[str, torch.Tensor]

# Step function that receives (last predictions, state, timestep).
StepFunctionTypeWithTimestep = Callable[[torch.Tensor, StateType, int], Tuple[torch.Tensor, StateType]]

# Step function that receives only (last predictions, state).
StepFunctionTypeNoTimestep = Callable[[torch.Tensor, StateType], Tuple[torch.Tensor, StateType]]

StepFunctionType = TypeVar("StepFunctionType", StepFunctionTypeWithTimestep, StepFunctionTypeNoTimestep)
"""
The type of step function that can be passed to [`BeamSearch.search`](#search).
This can either be [`StepFunctionTypeWithTimestep`](#stepfunctiontypewithtimestep)
or [`StepFunctionTypeNoTimestep`](#stepfunctiontypenotimestep).
"""

# Constraint state: the outer list is indexed by batch element, the inner list
# by beam, and each entry is a dictionary owned by the constraint.
ConstraintStateType = List[List[Dict[str, Any]]]
class Sampler:
    """
    Base class for the strategies `BeamSearch` uses to pick candidates, either
    nodes (next tokens for a single beam) or beams.

    A `Sampler` exposes three methods:

    `init_state(start_class_log_probabilities, batch_size, num_classes)`
        Receives the starting log probabilities with shape
        `(batch_size, num_classes)`, the batch size and the number of classes, and
        returns a state dictionary holding whatever tensors later calls to
        `sample_nodes()` / `sample_beams()` need. The base implementation returns
        an empty dictionary.

    `sample_nodes(log_probs, per_node_beam_size, state)` and
    `sample_beams(log_probs, beam_size, state)`
        Both receive:

        - a tensor of normalized log probabilities with shape
          `(batch_size, num_examples)`,
        - the number of samples to draw for each row, and
        - the sampler state dictionary.

        For `sample_nodes()`, `num_examples = num_classes`; for `sample_beams()`,
        `num_examples = beam_size * per_node_beam_size`.

        Both return a tuple of:

        - the log probabilities of the sampled examples, shape `(batch_size, num_samples)`,
        - the indices of the sampled examples, shape `(batch_size, num_samples)`,
        - the updated state dictionary.

    `sample_nodes()` must be provided by subclasses. `sample_beams()` has a
    default implementation that deterministically keeps the `beam_size` entries
    with the highest log probability.
    """

    def init_state(
        self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
    ) -> StateType:
        """Return the initial sampler state; the base sampler keeps no state."""
        return {}

    @abstractmethod
    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        """Pick `per_node_beam_size` candidate classes per row; subclasses implement this."""
        raise NotImplementedError

    def sample_beams(
        self, log_probs: torch.Tensor, beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        """Keep the `beam_size` highest-scoring entries of each row and return an empty state."""
        # The incoming state is ignored; a fresh empty state is returned.
        best = torch.topk(log_probs, beam_size, dim=-1)
        return best.values, best.indices, {}
class DeterministicSampler(Sampler):
    """
    A `Sampler` with no randomness: nodes (and, through the inherited
    `sample_beams()`, beams) are simply the `k` entries with the highest
    log probability.
    """

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        """Return the top `per_node_beam_size` log probabilities, their indices and an empty state."""
        # The incoming state is ignored; a fresh empty state is returned.
        best = torch.topk(log_probs, per_node_beam_size, dim=-1)
        return best.values, best.indices, {}
class MultinomialSampler(Sampler):
    """
    A `Sampler` which samples nodes from the given multinomial distribution. Beams are
    sampled in the default, deterministic way (the inherited top-k `sample_beams()`).

    :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a
        `temperature` above 1.0 produces a flatter probability distribution. A falsy value
        (`0` or `None`) is treated as 1.0, the same convention `TopKSampler` and `TopPSampler` use.
    :param with_replacement: Whether to sample with replacement.
    """

    def __init__(
        self,
        temperature: float = 1.0,
        with_replacement: bool = False,
    ) -> None:
        # Fall back to 1.0 for a falsy temperature so that `log_probs / temperature`
        # can never divide by zero (or by None) in `sample_nodes()`.
        self.temperature = temperature or 1.0
        self.with_replacement = with_replacement

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        """
        Draw `per_node_beam_size` class indices per row from the (temperature-scaled)
        distribution.

        Returns the *original* log probabilities of the drawn classes, the drawn indices
        (both of shape `(batch_size, per_node_beam_size)`) and the unchanged state.
        """
        if self.temperature != 1.0:
            # Re-normalize after scaling so the rows are valid distributions again.
            _probabilities = torch.nn.functional.softmax(log_probs / self.temperature, dim=-1)
        else:
            _probabilities = log_probs.exp()

        selected_indices = torch.multinomial(_probabilities, per_node_beam_size, replacement=self.with_replacement)

        # Scores are gathered from the unscaled `log_probs`, not from the tempered distribution.
        return torch.gather(log_probs, 1, selected_indices), selected_indices, state
class TopKSampler(Sampler):
    """
    A `Sampler` which redistributes the probability mass function for nodes among the
    top `k` choices, then samples from that subset after re-normalizing the probabilities.

    Beams are sampled in the default, deterministic way.

    :param k: The number of top choices to be selected from.
    :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a
        `temperature` above 1.0 produces a flatter probability distribution. A falsy value is
        treated as 1.0.
    :param with_replacement: If set to `True`, samples will be selected with replacement from the
        top k choices.
    """

    def __init__(
        self,
        k: int = 1,
        temperature: float = 1.0,
        with_replacement: bool = False,
    ):
        self.k = k
        self.temperature = temperature or 1.0
        self.with_replacement = with_replacement

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        """
        Sample `per_node_beam_size` classes per row from the `k` most likely classes.

        Returns the original log probabilities of the sampled classes, their indices into
        `log_probs` (both of shape `(batch_size, per_node_beam_size)`) and the unchanged state.

        :raises ValueError: If `k` is smaller than `per_node_beam_size` or larger than the
            vocabulary size (`log_probs.size()[1]`).
        """
        if not per_node_beam_size <= self.k <= log_probs.size()[1]:
            raise ValueError(
                "k must be a positive integer no less than per_node_beam_size and no greater than vocabulary size"
            )

        # shape (both): (batch_size, k)
        top_k_log_probs, top_k_indices = log_probs.topk(self.k, dim=-1)

        # Apply temperature if necessary.
        # shape: (batch_size, k)
        if self.temperature != 1.0:
            top_k_log_probs = top_k_log_probs / self.temperature

        # Re-normalize the subset.
        # shape: (batch_size, k)
        normalized_top_k_probs = torch.nn.functional.softmax(top_k_log_probs, dim=-1)

        # Sample from the re-normalized subset.
        # NOTE: These indices are not indices into `log_probs`, they are indices into `top_k_log_probs`.
        # shape: (batch_size, per_node_beam_size)
        sampled_indices = torch.multinomial(
            normalized_top_k_probs, per_node_beam_size, replacement=self.with_replacement
        )

        # Convert `sampled_indices` back to indices in the original `log_probs` tensor.
        # shape: (batch_size, per_node_beam_size)
        indices = top_k_indices.gather(-1, sampled_indices)

        return log_probs.gather(1, indices), indices, state
class TopPSampler(Sampler):
    """
    Nucleus-style `Sampler`: the node distribution is restricted to the most likely
    classes whose cumulative probability reaches at least `p`, re-normalized, and then
    sampled from.

    Beams are sampled in the default, deterministic way.

    :param p:
        The cumulative probability cutoff threshold. A higher value of `p` leaves more
        candidates to sample from. When `with_replacement` is `False`, the top
        `per_node_beam_size` classes are always kept so that sampling without
        replacement has enough candidates.
    :param temperature:
        A `temperature` below 1.0 produces a sharper probability distribution and a
        `temperature` above 1.0 produces a flatter one. A falsy value is treated as 1.0.
    :param with_replacement:
        If set to `True`, samples will be selected with replacement from the top choices.
    """

    def __init__(
        self,
        p: float = 0.9,
        temperature: float = 1.0,
        with_replacement: bool = False,
    ):
        if p > 1.0 or p < 0.0:
            raise ValueError("p must be a positive float no greater than 1.0")
        self.p = p
        self.with_replacement = with_replacement
        self.temperature = temperature or 1.0

    def sample_nodes(
        self, log_probs: torch.Tensor, per_node_beam_size: int, state: StateType
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        """
        Sample `per_node_beam_size` classes per row from the top-`p` subset.

        Returns the original log probabilities of the sampled classes, their indices into
        `log_probs` (both of shape `(batch_size, per_node_beam_size)`) and the unchanged state.

        :raises ValueError: If `per_node_beam_size` exceeds the vocabulary size.
        """
        vocab_size = log_probs.size()[1]
        if per_node_beam_size > vocab_size:
            raise ValueError("per_node_beam_size cannot be greater than vocabulary size")

        # Temperature scaling (re-normalized in log space) comes first.
        if self.temperature == 1.0:
            scaled_log_probs = log_probs
        else:
            scaled_log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)

        # Most likely classes first, so a running sum gives the cumulative mass.
        # shape (both): (batch_size, num_classes)
        sorted_log_probs, sorted_order = torch.sort(scaled_log_probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_log_probs.exp(), dim=-1)

        # True marks classes to drop. The mask is shifted right by one position so
        # that the first class at which the running sum reaches `p` is still kept.
        # shape: (batch_size, num_classes)
        drop = cumulative_probs >= self.p
        drop[..., 1:] = drop[..., :-1].clone()
        drop[..., 0] = False

        # Sampling without replacement needs at least `per_node_beam_size` candidates.
        if not self.with_replacement:
            drop[..., :per_node_beam_size] = False

        sorted_log_probs[drop] = torch.finfo(log_probs.dtype).min

        # Re-normalize what is left and sample from it.
        # NOTE: the sampled positions index into the *sorted* tensor.
        # shape: (batch_size, per_node_beam_size)
        renormalized = torch.nn.functional.softmax(sorted_log_probs, dim=-1)
        sampled_positions = torch.multinomial(
            renormalized, per_node_beam_size, replacement=self.with_replacement
        )

        # Map sorted positions back to class indices of the original `log_probs`.
        # shape: (batch_size, per_node_beam_size)
        selected_indices = sorted_order.gather(-1, sampled_positions)

        return torch.gather(log_probs, 1, selected_indices), selected_indices, state
class GumbelSampler(Sampler):
    """
    A `Sampler` which uses the Gumbel-Top-K trick to sample without replacement. See
    [*Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling
    Sequences Without Replacement*, W Kool, H Van Hoof and M Welling, 2019]
    (https://api.semanticscholar.org/CorpusID:76662039).

    The sampler state carries two entries: `"G_phi_S"` (the perturbed log probabilities)
    and `"phi_S"` (the true log probabilities of the currently selected beams).

    :param temperature: A `temperature` below 1.0 produces a sharper probability distribution and a `temperature`
        above 1.0 produces a flatter probability distribution.
    """

    def __init__(self, temperature: float = 1.0):
        self.temperature = temperature

    def init_state(
        self, start_class_log_probabilities: torch.Tensor, batch_size: int, num_classes: int
    ) -> StateType:
        """
        Build the initial state by perturbing the starting log probabilities with
        Gumbel noise conditioned on a maximum of zero.
        """
        # shape: (batch_size, num_classes)
        zeros = start_class_log_probabilities.new_zeros((batch_size, num_classes))

        # shape: (batch_size, num_classes)
        G_phi_S = self.gumbel_with_max(start_class_log_probabilities, zeros)

        return {"G_phi_S": G_phi_S}

    def sample_nodes(
        self,
        log_probs: torch.Tensor,
        per_node_beam_size: int,
        state: StateType,
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        """
        Pick, for every beam, the `per_node_beam_size` classes with the highest perturbed
        log probability. Returns their (unperturbed) log probabilities, their indices and
        a new state holding the perturbed values of the picked classes.

        NOTE(review): this reads `state["phi_S"]`, which within this class is only produced
        by `sample_beams()` (not by `init_state()`); presumably the caller always runs
        `sample_beams()` first. Verify against the caller.
        """
        # First apply temperature coefficient:
        # shape: (batch_size * beam_size, num_classes)
        if self.temperature != 1.0:
            _log_probs = torch.nn.functional.log_softmax(log_probs / self.temperature, dim=-1)
        else:
            _log_probs = log_probs

        # shape: (group_size,)
        phi_S = state["phi_S"]

        # shape: (group_size, num_classes)
        phi_S = phi_S.unsqueeze(-1).expand_as(_log_probs)

        # Log probability of each extended sequence: parent score plus token score.
        # shape: (group_size, num_classes)
        phi_S_new = phi_S + _log_probs

        # shape: (group_size, 1)
        G_phi_S = state["G_phi_S"].unsqueeze(-1)

        # Perturb the children, conditioned on their maximum equalling the parent's perturbed value.
        # shape: (group_size, num_classes)
        G_phi_S_new = self.gumbel_with_max(phi_S_new, G_phi_S)

        # NOTE: replacing NaNs with a very negative number is currently disabled:
        # shape: (group_size, num_classes)
        # G_phi_S_new[G_phi_S_new.isnan()] = torch.finfo(G_phi_S_new.dtype).min

        # shape (both): (group_size, per_node_beam_size)
        top_G_phi_S_new, top_indices = torch.topk(G_phi_S_new, per_node_beam_size, dim=-1)

        # The returned scores come from the original (un-tempered) `log_probs`.
        # shape: (group_size, per_node_beam_size)
        top_log_probs = log_probs.gather(1, top_indices)

        return top_log_probs, top_indices, {"G_phi_S": top_G_phi_S_new}

    def sample_beams(
        self,
        log_probs: torch.Tensor,
        beam_size: int,
        state: StateType,
    ) -> Tuple[torch.Tensor, torch.Tensor, StateType]:
        """
        Returns the beams with the highest perturbed log probabilities.

        The selected beams are re-sorted by their true log probability before being
        returned, and the new state stores both the perturbed values (`"G_phi_S"`) and
        the true log probabilities (`"phi_S"`) of the selected beams, flattened to
        `(batch_size * beam_size,)`.
        """
        # shape (log_probs): (batch_size, beam_size * per_node_beam_size)

        batch_size = log_probs.size()[0]

        # shape: (batch_size * beam_size, per_node_beam_size)
        G_phi_S = state["G_phi_S"]

        # shape: (batch_size, beam_size * per_node_beam_size)
        G_phi_S = G_phi_S.reshape_as(log_probs)

        # Selection is driven by the perturbed values, not by `log_probs`.
        # shape (both): (batch_size, beam_size)
        G_phi_S_new, selected_indices = torch.topk(G_phi_S, beam_size, dim=-1)

        # shape: (batch_size, beam_size)
        selected_log_probs = log_probs.gather(1, selected_indices)

        # Now sort the selected beams by their true log prob.
        # shape (all): (batch_size, beam_size)
        selected_log_probs, sort_indices = selected_log_probs.sort(dim=-1, descending=True)
        selected_indices = selected_indices.gather(1, sort_indices)
        G_phi_S_new = G_phi_S_new.gather(1, sort_indices)

        # shape: (batch_size * beam_size,)
        G_phi_S_new = G_phi_S_new.reshape(batch_size * beam_size)

        # shape: (batch_size * beam_size,)
        phi_S = selected_log_probs.reshape(batch_size * beam_size)

        return selected_log_probs, selected_indices, {"G_phi_S": G_phi_S_new, "phi_S": phi_S}

    def gumbel(self, phi) -> torch.Tensor:
        """
        Sample `Gumbel(phi)`.

        `phi` should have shape `(batch_size, num_classes)`.
        """
        # Standard Gumbel noise, -log(-log(U)) with U drawn by `torch.rand_like`, shifted by `phi`.
        return -torch.log(-torch.log(torch.rand_like(phi))) + phi

    def gumbel_with_max(self, phi, T) -> torch.Tensor:
        """
        Sample `Gumbel(phi)` conditioned on the maximum value being equal to `T`.

        `phi` should have shape `(batch_size, num_classes)` and `T` should have
        shape `(batch_size, 1)`.

        NOTE(review): `init_state()` passes a `T` of shape `(batch_size, num_classes)`;
        the arithmetic below relies only on `T` being broadcastable against `phi`.
        """
        # Shape: (batch_size, num_classes)
        G_phi = self.gumbel(phi)

        # Now we find the maximum from these samples.
        # Shape: (batch_size, )
        Z, _ = G_phi.max(dim=-1)

        # Shape: (batch_size, num_classes)
        v = T - G_phi + torch.log1p(-torch.exp(G_phi - Z.unsqueeze(-1)))

        # `relu(v) + log1p(exp(-|v|))` is an overflow-safe way of writing `log(1 + exp(v))`.
        # Shape: (batch_size, num_classes)
        return T - torch.nn.functional.relu(v) - torch.log1p(torch.exp(-v.abs()))
class FinalSequenceScorer:
    """
    Base class for scoring the finished sequences produced by beam search.

    Given the predicted sequences and their summed log probabilities, a scorer
    returns the final score of every sequence. `score()` must be provided by
    subclasses; this base class only defines the interface.
    """

    @abstractmethod
    def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
        """
        Score the final predictions found by beam search.

        Returns a tensor of the final sequence scores of shape `(batch_size, beam_size)`.

        :param predictions: The predicted sequences, with shape `(batch_size, beam_size, max_steps)`.
        :param log_probabilities: The log probability of each sequence, i.e. the sum of its per-token
            log probabilities, with shape `(batch_size, beam_size)`.
        :param end_index: The index of the end symbol.
        :raises NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()
class SequenceLogProbabilityScorer(FinalSequenceScorer):
    """
    A :class:`FinalSequenceScorer` whose score is simply the sequence log probability,
    i.e. the sum of the log probabilities of the sequence's tokens.
    """

    def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
        """Return `log_probabilities` unchanged; `predictions` and `end_index` are not used."""
        # The caller already hands us the summed log probabilities, which is the score.
        return log_probabilities
class LengthNormalizedSequenceLogProbabilityScorer(FinalSequenceScorer):
    """
    A :class:`FinalSequenceScorer` that scores a sequence by the average log probability
    of its tokens, optionally with a length penalty that promotes or demotes sequences
    according to their length. The final score is
    `(sequence_log_probability) / (sequence_length ** length_penalty)`, where the
    sequence length includes the end token.

    :param length_penalty: The length penalty to use. A value of 1.0 means no length penalty is used.
        A value > 1.0 favors longer sequences, and < 1.0 favors shorter sequences.
    """

    def __init__(self, length_penalty: float = 1.0):
        super().__init__()
        self.length_penalty = length_penalty

    def score(self, predictions: torch.Tensor, log_probabilities: torch.Tensor, end_index: int) -> torch.Tensor:
        """
        Divide each sequence log probability by `length ** length_penalty`.

        Returns a tensor of shape `(batch_size, beam_size)`.
        """
        # Number of non-end tokens in every sequence.
        # shape: (batch_size, beam_size)
        non_end_token_counts = (predictions != end_index).long().sum(dim=2)

        # A sequence that finished during the search has the transition to the end
        # token included in its log probability, so that token counts towards the length.
        # shape: (batch_size, beam_size)
        finished = predictions[:, :, -1] == end_index
        lengths = non_end_token_counts + finished.long()

        # shape: (batch_size, beam_size)
        return log_probabilities / (lengths**self.length_penalty)
class Constraint:
    """
    An abstract class that can be used to enforce constraints on the output predictions
    by manipulating the class log probabilities during beam search.

    Subclasses need to implement three methods: `init_state()`, `apply()` and `_update_state()`.

    `init_state()` takes one argument:

    - the batch size, an int

    It returns a constraint state, which is a nested list of dictionaries, with any state needed for
    subsequent calls to `apply()` and `update_state()`. The length of the outer list should be equal
    to `batch_size`. Each inner list should be of length 1.

    `apply()` takes two arguments:

    - the constraint state, which is a nested list of dictionaries. The length of the outer list is
      `batch_size` and the length of each inner list is `beam_size`, except on the first time
      `apply()` is called, when it is 1.
    - `class_log_probabilities`, a tensor of shape `(batch_size, beam_size, num_classes)` that contains
      the log probabilities for the classes during search. The first time `apply()` is called,
      `beam_size = 1`.

    The `apply()` method should return new `class_log_probabilities` that enforce the constraint
    for this step of beam search. For instance, it may prevent a specific class from being selected by
    setting the corresponding log probability to a negligible value such as `float("-inf")` or
    `torch.finfo(class_log_probabilities.dtype).min`.

    `_update_state()` takes two arguments:

    - the copied parent constraint state, which is a nested list of dictionaries. `state[i][j]`
      contains the copied state for the parent of `last_prediction[i, j]`. It is unique to that batch
      and beam, so it can be directly edited in-place without affecting the others.
    - `last_prediction`, a tensor of shape `(batch_size, beam_size)` containing the predictions from
      the last step of beam search.

    The `_update_state()` function should return a new constraint state, a nested list of dictionaries
    of length `batch_size` and inner list of length `beam_size`, one for each of the predictions in
    `last_prediction`.
    """

    @abstractmethod
    def init_state(
        self,
        batch_size: int,
    ) -> ConstraintStateType:
        """Return the initial constraint state for a batch of `batch_size` elements."""
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        state: ConstraintStateType,
        class_log_probabilities: torch.Tensor,
    ) -> torch.Tensor:
        """Return `class_log_probabilities` with the constraint enforced for this step."""
        raise NotImplementedError

    @staticmethod
    def _copy_state(
        state: ConstraintStateType,
        batch_size: int,
        beam_size: int,
        last_backpointer: Optional[torch.Tensor] = None,
    ) -> ConstraintStateType:
        """
        Build a `(batch_size, beam_size)` nested list in which entry `[i][j]` is an independent
        copy of the parent state `state[i][last_backpointer[i, j]]`.

        When `last_backpointer` is `None` (the first prediction), every beam is copied from
        `state[i][0]`.

        The data is copied with `copy.deepcopy()`. If this is not appropriate for your
        constraint, you will need to implement the copying yourself.
        """
        # Convert the backpointers to plain Python ints once, instead of calling
        # `.item()` (one tensor-to-host transfer) for every batch/beam pair.
        parents = None if last_backpointer is None else last_backpointer.tolist()

        new_state = []
        for i in range(batch_size):
            batch_state = []
            for j in range(beam_size):
                # On the first prediction there is a single parent, at index 0.
                backpointer = 0 if parents is None else parents[i][j]
                batch_state.append(copy.deepcopy(state[i][backpointer]))
            new_state.append(batch_state)
        return new_state

    def update_state(
        self,
        state: ConstraintStateType,
        last_prediction: torch.Tensor,
        last_backpointer: Optional[torch.Tensor] = None,
    ) -> ConstraintStateType:
        """
        Copy each beam's parent state (following `last_backpointer`) and let
        `_update_state()` fold `last_prediction` into the copies.

        :param state: The constraint state from the previous step.
        :param last_prediction: Tensor of shape `(batch_size, beam_size)` with the latest predictions.
        :param last_backpointer: Optional tensor of shape `(batch_size, beam_size)` with the parent
            beam index of every prediction; `None` on the first step.
        """
        batch_size, beam_size = last_prediction.size()
        new_state = self._copy_state(state, batch_size, beam_size, last_backpointer)
        return self._update_state(new_state, last_prediction)

    @abstractmethod
    def _update_state(
        self,
        state: ConstraintStateType,
        last_prediction: torch.Tensor,
    ) -> ConstraintStateType:
        """Fold `last_prediction` into the (already copied) `state` and return it."""
        raise NotImplementedError
class RepeatedNGramBlockingConstraint(Constraint):
    """
    A `Constraint` that prevents a beam from producing the same n-gram twice.

    Each beam's state is a dictionary with two entries:

    - `"seen_ngrams"`: maps every `(ngram_size - 1)`-token prefix seen so far to the list of
      tokens that have followed it;
    - `"current_prefix"`: the most recent (at most `ngram_size - 1`) tokens of the beam.

    :param ngram_size: The length `n` of the n-grams that must not repeat.
    """

    def __init__(self, ngram_size: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.ngram_size = ngram_size

    def init_state(
        self,
        batch_size: int,
    ) -> ConstraintStateType:
        """Return one empty beam state per batch element."""
        return [[{"seen_ngrams": {}, "current_prefix": []}] for _ in range(batch_size)]

    def apply(
        self,
        state: ConstraintStateType,
        class_log_probabilities: torch.Tensor,
    ) -> torch.Tensor:
        """
        Block, in place, every class that would complete an already-seen n-gram by setting
        its log probability to the smallest finite value of the tensor's dtype.

        Returns the (modified) `class_log_probabilities` tensor.
        """
        blocked_value = torch.finfo(class_log_probabilities.dtype).min
        for i, batch in enumerate(state):
            for j, beam in enumerate(batch):
                # Tokens that have previously followed the current prefix, if any.
                disallowed_indices = beam["seen_ngrams"].get(tuple(beam["current_prefix"]))
                if disallowed_indices:
                    class_log_probabilities[i, j, disallowed_indices] = blocked_value
        return class_log_probabilities

    def _update_state(
        self,
        state: ConstraintStateType,
        last_prediction: torch.Tensor,
    ) -> ConstraintStateType:
        """
        Record the n-gram completed by each beam's latest prediction and slide the beam's
        prefix window forward by one token. `state` is edited in place and returned.
        """
        # One conversion to Python ints instead of an `.item()` call per beam.
        predictions = last_prediction.tolist()
        for i, batch in enumerate(state):
            for j, beam in enumerate(batch):
                prediction = predictions[i][j]
                prefix = beam["current_prefix"]

                if len(prefix) == self.ngram_size - 1:
                    # The prefix plus this prediction is a complete n-gram to remember.
                    beam["seen_ngrams"].setdefault(tuple(prefix), []).append(prediction)

                # Create the new prefix, removing the oldest index if the prefix
                # is too long.
                prefix.append(prediction)
                if len(prefix) == self.ngram_size:
                    prefix.pop(0)
        return state
class BeamSearch:
"""
Implements the beam search algorithm for decoding the most likely sequences.
:param end_index: The index of the "stop" or "end" token in the vocabulary. Usually the EOS token ID.
:param max_steps: The maximum number of decoding steps to take, i.e. the maximum length
of the predicted sequences.
:param beam_size: The width of the beam used.
:param per_node_beam_size: The maximum number of candidates to consider per node, at each step in the search.
If not given, this just defaults to `beam_size`. Setting this parameter
to a number smaller than `beam_size` may give better results, as it can introduce
more diversity into the search. See
[*Beam Search Strategies for Neural Machine Translation*, Freitag and Al-Onaizan, 2017]
(https://api.semanticscholar.org/CorpusID:2229477).
:param sampler: An optional `Sampler` which is used to pick next candidate nodes and beams.
If not specified, `DeterministicSampler` will be used, which just takes the
`per_node_beam_size` most likely nodes and the `beam_size` most likely beams.
Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you
[Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).
:param min_steps: The minimum number of decoding steps to take, i.e. the minimum length of
the predicted sequences. This does not include the start or end tokens. If `None`,
no minimum is enforced.
:param final_sequence_scorer: An optional `FinalSequenceScorer` which is used to score the final generated sequences.
The output from this module is what is returned by the `search` method. If not
specified, `SequenceLogProbabilityScorer` will be used, which scores the sequences
by the sum of the token log probabilities.
:param constraints: An optional list of `Constraint`s which should be applied during beam search. If not
provided, no constraints will be enforced.
"""
def __init__(
    self,
    end_index: int,
    *,
    max_steps: int = 50,
    beam_size: int = 10,
    per_node_beam_size: Optional[int] = None,
    sampler: Optional[Sampler] = None,
    min_steps: Optional[int] = None,
    final_sequence_scorer: Optional[FinalSequenceScorer] = None,
    constraints: Optional[List[Constraint]] = None,
) -> None:
    """
    Validate the search settings and store them, filling in defaults for every
    option that was not given.

    :raises ValueError: If `max_steps`, `beam_size` or `per_node_beam_size` is not positive,
        or if `min_steps` is negative or greater than `max_steps`.
    """
    # --- validation ---
    if not (max_steps > 0):
        raise ValueError("max_steps must be positive")
    if not (beam_size > 0):
        raise ValueError("beam_size must be positive")
    if per_node_beam_size is not None and not (per_node_beam_size > 0):
        raise ValueError("per_node_beam_size must be positive")
    if min_steps is not None and not (min_steps >= 0):
        raise ValueError("min_steps must be non-negative")
    if min_steps is not None and not (min_steps <= max_steps):
        raise ValueError("min_steps must be less than or equal to max_steps")

    # --- required / plain settings ---
    self._end_index = end_index
    self.max_steps = max_steps
    self.beam_size = beam_size

    # --- optional settings with their fallbacks ---
    self.min_steps = min_steps or 0
    self.per_node_beam_size = per_node_beam_size or beam_size
    self.constraints = constraints or []
    self.sampler = sampler or DeterministicSampler()
    self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer()
@staticmethod
def _reconstruct_sequences(predictions, backpointers):
# Reconstruct the sequences.
# shape: [(batch_size, beam_size, 1)]
reconstructed_predictions = [predictions[-1].unsqueeze(2)]
if not backpointers:
return reconstructed_predictions
# shape: (batch_size, beam_size)
cur_backpointers = backpointers[-1]
for timestep in range(len(predictions) - 2, 0, -1):
# shape: (batch_size, beam_size, 1)
cur_preds = predictions[timestep].gather(1, cur_backpointers).unsqueeze(2)
reconstructed_predictions.append(cur_preds)
# shape: (batch_size, beam_size)
cur_backpointers = backpointers[timestep - 1].gather(1, cur_backpointers)
# shape: (batch_size, beam_size, 1)
final_preds = predictions[0].gather(1, cur_backpointers).unsqueeze(2)
reconstructed_predictions.append(final_preds)
return reconstructed_predictions
def search(
    self,
    start_predictions: torch.Tensor,
    start_state: StateType,
    step: StepFunctionType,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run beam search from a starting state, using `step` to score the next tokens,
    and return the most likely target sequences.

    Returns a tuple of `(predictions, final_scores)`, where `predictions`
    has shape `(batch_size, beam_size, max_steps)` and `final_scores`
    has shape `(batch_size, beam_size)`.

    .. note::
        If your step function returns `-inf` for some log probabilities
        (for example when using a masked log-softmax), some of the "best"
        sequences returned may also have `-inf` log probability. This happens
        when the beam size is larger than the number of actions with finite
        log probability (non-zero probability) returned by the step function.
        If you are using a mask you may therefore want to check the results of
        `search` and discard sequences with non-finite log probability.

    :param start_predictions: A tensor containing the initial predictions with shape `(batch_size,)`.
        Usually the initial predictions are just the index of the "start" token
        in the target vocabulary.
    :param start_state: The initial state passed to the `step` function. Each value of the state dict
        should be a tensor of shape `(batch_size, *)`, where `*` means any other
        number of dimensions.
    :param step: A function that computes the log probabilities of the next tokens
        given the predictions from the last time step and the current state.
        It should accept two or three arguments:

        - a tensor of shape `(group_size,)` holding the indices of the tokens
          predicted at the last time step,
        - the current state, a `StateType`, and
        - optionally, the timestep, an `int`.

        The `group_size` will be `batch_size * beam_size`, except in the initial
        step, for which it will just be `batch_size`.

        The function is expected to return a tuple whose first element is a tensor
        of shape `(group_size, vocab_size)` containing the log probabilities of the
        tokens for the next step, and whose second element is the updated state.
        The tensors in the state should have shape `(group_size, *)`, where `*`
        means any other number of dimensions.
    """
    takes_timestep = len(signature(step).parameters) >= 3
    if takes_timestep:
        return self._search(start_predictions, start_state, cast(StepFunctionTypeWithTimestep, step))

    # The step function does not take the time step argument, so adapt it to
    # the three-argument form that `_search` calls.
    two_argument_step = cast(StepFunctionTypeNoTimestep, step)

    def new_step(last_predictions: torch.Tensor, state: Dict[str, torch.Tensor], time_step: int):
        del time_step
        return two_argument_step(last_predictions, state)

    return self._search(start_predictions, start_state, new_step)
def _search(
    self,
    start_predictions: torch.Tensor,
    start_state: StateType,
    step: StepFunctionTypeWithTimestep,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run beam search with a step function that takes the timestep as its third argument.

    The first step is taken from `start_predictions` / `start_state` with timestep `0`.
    Up to `max_steps - 1` further steps follow; the loop stops early once every
    beam's most recent prediction is the end index.

    Note that the state dict returned by `step` is updated in place (beam expansion and
    beam reordering), and the log probability tensors may be written to in place when
    the end index is masked because of `min_steps`.

    :param start_predictions: Initial predictions; the size of its first dimension is used
        as the batch size.
    :param start_state: State passed to `step` for the first timestep.
    :param step: Callable invoked as `step(last_predictions, state, timestep)` that returns
        a `(class_log_probabilities, state)` tuple.
    :return: A tuple `(predictions, scores)`. The scores are produced by
        `final_sequence_scorer` and sorted in descending order along dimension 1, with
        the predictions reordered to match. When `beam_size == 1` and every first
        prediction is the end index, the first-step predictions (with a trailing
        dimension added) and their log probabilities are returned directly instead.
    :raises ValueError: If `per_node_beam_size` is larger than the number of classes
        returned by `step`.
    """
    batch_size = start_predictions.size()[0]
    # List of (batch_size, beam_size) tensors. One for each time step. Does not
    # include the start symbols, which are implicit.
    predictions: List[torch.Tensor] = []
    # List of (batch_size, beam_size) tensors. One for each time step except
    # the first. Stores the index n for the parent prediction, i.e.
    # predictions[t-1][i][n], that it came from.
    backpointers: List[torch.Tensor] = []
    # One state object per constraint, kept in the same order as `self.constraints`.
    constraint_states = [constraint.init_state(batch_size) for constraint in self.constraints]
    # Calculate the first timestep. This is done outside the main loop
    # because we are going from a single decoder input (the output from the
    # encoder) to the top `beam_size` decoder outputs. On the other hand,
    # within the main loop we are going from the `beam_size` elements of the
    # beam to `beam_size * per_node_beam_size` candidates from which we will
    # select the top `beam_size` elements for the next iteration.
    # shape: (batch_size, num_classes)
    start_class_log_probabilities, state = step(start_predictions, start_state, 0)
    num_classes = start_class_log_probabilities.size()[1]
    # Make sure `per_node_beam_size` is not larger than `num_classes`.
    if self.per_node_beam_size > num_classes:
        raise ValueError(
            f"Vocab size ({num_classes:d}) too small "
            f"relative to per_node_beam_size ({self.per_node_beam_size:d}).\n"
            f"Please decrease beam_size or per_node_beam_size."
        )
    # The sampler state is initialised from the unconstrained first-step log probabilities.
    sampler_state = self.sampler.init_state(start_class_log_probabilities, batch_size, num_classes)
    # Apply all constraints.
    if self.constraints:
        # Constraints are applied on a tensor with an explicit beam dimension, so add
        # a singleton one for the first step and remove it again afterwards.
        # shape: (batch_size, 1, num_classes)
        expanded_start_class_log_probabilities = start_class_log_probabilities.unsqueeze(1)
        for constraint, constraint_state in zip(self.constraints, constraint_states):
            expanded_start_class_log_probabilities = constraint.apply(
                constraint_state, expanded_start_class_log_probabilities
            )
        start_class_log_probabilities = expanded_start_class_log_probabilities.squeeze(1)
    # Prevent selecting the end symbol if there is any min_steps constraint.
    # The smallest finite value of the dtype is used rather than `-inf`.
    if self.min_steps >= 1:
        start_class_log_probabilities[:, self._end_index] = torch.finfo(
            start_class_log_probabilities.dtype
        ).min
    # Get the initial predicted classes and their log probabilities.
    # shape: (batch_size, beam_size), (batch_size, beam_size)
    (
        start_top_log_probabilities,
        start_predicted_classes,
        sampler_state,
    ) = self.sampler.sample_beams(start_class_log_probabilities, self.beam_size, sampler_state)
    # With a single beam that already ended everywhere there is nothing left to search:
    # warn and return the first-step result as is (the final scorer is not applied here).
    if self.beam_size == 1 and (start_predicted_classes == self._end_index).all():
        warnings.warn(
            "Empty sequences predicted. You may want to increase the beam size or ensure "
            "your step function is working properly.",
            RuntimeWarning,
        )
        return start_predicted_classes.unsqueeze(-1), start_top_log_probabilities
    # The log probabilities for the last time step.
    # shape: (batch_size, beam_size)
    last_log_probabilities = start_top_log_probabilities
    # shape: [(batch_size, beam_size)]
    predictions.append(start_predicted_classes)
    # Log probability tensor that mandates that the end token is selected:
    # every class gets the dtype's smallest finite value except the end index, which gets 0.
    # shape: (batch_size * beam_size, num_classes)
    log_probs_after_end = start_class_log_probabilities.new_full(
        (batch_size * self.beam_size, num_classes),
        torch.finfo(start_class_log_probabilities.dtype).min,
    )
    log_probs_after_end[:, self._end_index] = 0.0
    # Set the same state for each element in the beam (modifies `state` in place).
    self._update_initial_state(state, batch_size)
    for i, constraint in enumerate(self.constraints):
        constraint_states[i] = constraint.update_state(constraint_states[i], start_predicted_classes)
    for timestep in range(self.max_steps - 1):
        # shape: (batch_size * beam_size,)
        last_predictions = predictions[-1].reshape(batch_size * self.beam_size)
        # If every predicted token from the last step is `self._end_index`,
        # then we can stop early.
        if (last_predictions == self._end_index).all():
            break
        # Take a step. This gets the predicted log probs of the next classes
        # and updates the state. The first step used timestep 0, hence `timestep + 1`.
        # shape: (batch_size * beam_size, num_classes)
        class_log_probabilities, state = step(last_predictions, state, timestep + 1)
        # Apply all constraints.
        if self.constraints:
            # shape: (batch_size, beam_size, num_classes)
            reshaped_class_log_probabilities = class_log_probabilities.view(batch_size, self.beam_size, -1)
            for constraint, constraint_state in zip(self.constraints, constraint_states):
                reshaped_class_log_probabilities = constraint.apply(
                    constraint_state, reshaped_class_log_probabilities
                )
            # shape: (batch_size * beam_size, num_classes)
            class_log_probabilities = reshaped_class_log_probabilities.view(batch_size * self.beam_size, -1)
        # The `timestep`-th iteration of the for loop is generating the `timestep + 2`-th token
        # of the sequence (because `timestep` is 0-indexed and we generated the first token
        # before the for loop). Here we block the end index if the search is not allowed to
        # terminate on this iteration.
        if timestep + 2 <= self.min_steps:
            class_log_probabilities[:, self._end_index] = torch.finfo(class_log_probabilities.dtype).min
        # shape: (batch_size * beam_size, num_classes)
        last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
            batch_size * self.beam_size, num_classes
        )
        # Here we are finding any beams where we predicted the end token in
        # the previous timestep and replacing the distribution with a
        # one-hot distribution, forcing the beam to predict the end token
        # this timestep as well.
        # shape: (batch_size * beam_size, num_classes)
        cleaned_log_probabilities = torch.where(
            last_predictions_expanded == self._end_index,
            log_probs_after_end,
            class_log_probabilities,
        )
        # shape (both): (batch_size * beam_size, per_node_beam_size)
        top_log_probabilities, predicted_classes, sampler_state = self.sampler.sample_nodes(
            cleaned_log_probabilities, self.per_node_beam_size, sampler_state
        )
        # Here we expand the last log probabilities to (batch_size * beam_size, per_node_beam_size)
        # so that we can add them to the current log probs for this timestep.
        # This lets us maintain the log probability of each element on the beam.
        # shape: (batch_size * beam_size, per_node_beam_size)
        expanded_last_log_probabilities = (
            last_log_probabilities.unsqueeze(2)
            .expand(batch_size, self.beam_size, self.per_node_beam_size)
            .reshape(batch_size * self.beam_size, self.per_node_beam_size)
        )
        # shape: (batch_size * beam_size, per_node_beam_size)
        summed_top_log_probabilities = top_log_probabilities + expanded_last_log_probabilities
        # Flatten the candidates of all beams of one batch element into a single axis so
        # that the best `beam_size` of them can be chosen across beams.
        # shape: (batch_size, beam_size * per_node_beam_size)
        reshaped_summed = summed_top_log_probabilities.reshape(
            batch_size, self.beam_size * self.per_node_beam_size
        )
        # shape: (batch_size, beam_size * per_node_beam_size)
        reshaped_predicted_classes = predicted_classes.reshape(
            batch_size, self.beam_size * self.per_node_beam_size
        )
        # Keep only the top `beam_size` beam indices.
        # shape (both): (batch_size, beam_size)
        (
            restricted_beam_log_probs,
            restricted_beam_indices,
            sampler_state,
        ) = self.sampler.sample_beams(reshaped_summed, self.beam_size, sampler_state)
        # Use the beam indices to extract the corresponding classes.
        # shape: (batch_size, beam_size)
        restricted_predicted_classes = reshaped_predicted_classes.gather(1, restricted_beam_indices)
        predictions.append(restricted_predicted_classes)
        # shape: (batch_size, beam_size)
        last_log_probabilities = restricted_beam_log_probs
        # The beam indices come from a `beam_size * per_node_beam_size` dimension where the
        # indices with a common ancestor are grouped together. Hence
        # dividing by per_node_beam_size gives the ancestor. (Truncating division is
        # requested explicitly via `rounding_mode="trunc"`.)
        # shape: (batch_size, beam_size)
        backpointer = torch.divide(restricted_beam_indices, self.per_node_beam_size, rounding_mode="trunc")
        backpointers.append(backpointer)
        # Keep only the pieces of the state tensors corresponding to the
        # ancestors created this iteration (modifies `state` in place).
        self._update_state(state, backpointer)
        for i, constraint in enumerate(self.constraints):
            constraint_states[i] = constraint.update_state(
                constraint_states[i], restricted_predicted_classes, last_backpointer=backpointer
            )
    # Warn about "-inf" log probabilities if not using any constraints (negligible
    # log probabilities are expected when using constraints). Both non-finite values
    # and the dtype's smallest finite value (used above for masking) trigger the warning.
    if not self.constraints and (
        not torch.isfinite(last_log_probabilities).all()
        or (last_log_probabilities == torch.finfo(last_log_probabilities.dtype).min).any()
    ):
        warnings.warn(
            "Negligible log probabilities encountered ('-inf' or equivalent). "
            "Some final sequences may not make sense. "
            "This can happen when the beam size is larger than the number of valid (non-zero "
            "probability) transitions that the step function produces.",
            RuntimeWarning,
        )
    # NOTE(review): `_reconstruct_sequences` is defined elsewhere; the `reversed` below
    # suggests it returns the per-step predictions last step first - confirm.
    reconstructed_predictions = self._reconstruct_sequences(predictions, backpointers)
    # shape: (batch_size, beam_size, max_steps)
    # (presumably fewer steps along the last dimension when the loop stopped early)
    all_predictions = torch.cat(list(reversed(reconstructed_predictions)), 2)
    # Calculate the final sequence scores
    # shape: (batch_size, beam_size)
    final_scores = self.final_sequence_scorer.score(all_predictions, last_log_probabilities, self._end_index)
    # Sort the sequences based on the final scores so the best scoring
    # sequence is at index 0
    sorted_final_scores, sorted_indices = torch.sort(final_scores, dim=1, descending=True)
    # Reorder the sequences along the beam dimension with the same permutation as the scores.
    sorted_all_predictions = torch.gather(
        all_predictions, 1, sorted_indices.unsqueeze(-1).expand_as(all_predictions)
    )
    return sorted_all_predictions, sorted_final_scores
def _update_initial_state(self, state: StateType, batch_size: int):
    """
    Tile every tensor of `state` across the beam, updating the dictionary in place.

    Each entry of shape `(batch_size, *)` is replaced by a tensor of shape
    `(batch_size * beam_size, *)` in which every batch row appears `beam_size`
    times in a row. Entries that are `None` are left untouched.

    :param state: The state dictionary to expand; its values are overwritten.
    :param batch_size: Size of the leading dimension of the state tensors.
    """
    for name, tensor in state.items():
        if tensor is not None:
            _leading, *trailing_dims = tensor.size()
            # Insert a beam axis after the batch axis and broadcast along it.
            # shape: (batch_size, beam_size, *)
            per_beam = tensor.unsqueeze(1).expand(batch_size, self.beam_size, *trailing_dims)
            # Fold the beam axis into the batch axis.
            # shape: (batch_size * beam_size, *)
            state[name] = per_beam.reshape(batch_size * self.beam_size, *trailing_dims)
def _update_state(self, state: StateType, backpointer: torch.Tensor):
    """
    Reorder every tensor of `state` along the beam according to `backpointer`, in place.

    Each entry of shape `(batch_size * beam_size, *)` is viewed as
    `(batch_size, beam_size, *)`; for every batch element, new beam slot `k` receives
    the slice of old beam `backpointer[b, k]`. Entries that are `None` are left untouched.

    :param state: The state dictionary to reorder; its values are overwritten.
    :param backpointer: Tensor of shape `(batch_size, beam_size)` holding, for each new
        beam slot, the index of the beam it descends from.
    """
    num_batches = backpointer.shape[0]
    for name, tensor in state.items():
        if tensor is not None:
            _leading, *trailing_dims = tensor.size()
            # Give the backpointer one singleton axis per trailing state dimension ...
            # shape: (batch_size, beam_size, 1, ..., 1)
            gather_index = backpointer.view(num_batches, self.beam_size, *[1 for _ in trailing_dims])
            # ... and broadcast it so it can index complete state slices.
            # shape: (batch_size, beam_size, *)
            gather_index = gather_index.expand(num_batches, self.beam_size, *trailing_dims)
            # shape: (batch_size, beam_size, *)
            per_beam = tensor.reshape(num_batches, self.beam_size, *trailing_dims)
            survivors = per_beam.gather(1, gather_index)
            # shape: (batch_size * beam_size, *)
            state[name] = survivors.reshape(num_batches * self.beam_size, *trailing_dims)
|