ssalb committed · Commit 7c0d92c · Parent(s): 16746e5
Update space with latest code and dependencies on Mon Jan 6 09:01:27 UTC 2025
story_beam_search/beam_search.py
CHANGED
@@ -8,8 +8,8 @@ from story_beam_search.scoring import StoryScorer
 
 @dataclass
 class BeamSearchConfig:
-    num_beams: int =
-    num_return_sequences: int =
+    num_beams: int = 4
+    num_return_sequences: int = 2
     max_length: int = 100
     no_repeat_ngram_size: int = 2
     temperature: float = 0.8
@@ -100,12 +100,16 @@ class BeamSearchGenerator:
         attention_mask_batch = torch.stack(padded_attention_masks).to(self.device)
 
         # Calculate continuation length
+        # We want this length, times num_iterations, to be roughly the max_length set by the user.
         continuation_length = (
             max_length + self.config.max_length // self.config.num_iterations
         )
 
         # Generate all continuations in one pass
        with torch.no_grad():
+            # Technically speaking, this generation also uses beam search at the token level.
+            # In this case, though, I'm using it to generate multiple sequences at once and evaluate them
+            # not by token probability, but by my custom metrics.
             outputs = self.model.generate(
                 input_ids=input_ids_batch,
                 attention_mask=attention_mask_batch,
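For context on what the new defaults and comments imply: token-level beam search inside model.generate now explores 4 beams and returns 2 candidate continuations per input, and those candidates are re-ranked by the Space's own scoring rather than by token probability. The snippet below is a minimal, illustrative sketch of that pattern using the Hugging Face transformers generate API; the model name, prompt, and score_story function are placeholders and not part of this repository, and num_iterations = 4 is an assumption used only to mirror the continuation-length comment above.

# Illustrative sketch only: gpt2, the prompt, and score_story are placeholders,
# not the actual Space code. It mirrors the pattern in the diff above:
# token-level beam search produces several candidates, which are then ranked
# by a custom metric instead of token probability.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Once upon a time"
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        # Continuation budget per call, assuming config.max_length = 100
        # spread over an assumed num_iterations = 4.
        max_new_tokens=100 // 4,
        num_beams=4,              # new default in BeamSearchConfig
        num_return_sequences=2,   # new default in BeamSearchConfig
        no_repeat_ngram_size=2,
        early_stopping=True,
        pad_token_id=tokenizer.eos_token_id,
    )

candidates = [tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs]

def score_story(text: str) -> float:
    # Placeholder metric standing in for the Space's StoryScorer.
    words = text.split()
    return len(set(words)) / max(len(words), 1)

best = max(candidates, key=score_story)
print(best)

Re-ranking the returned beams externally keeps the per-call continuation budget small while letting story-level metrics, not log-likelihood, decide which candidate survives each iteration.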