Spaces:
Paused
Paused
ssalb
commited on
Commit
·
16746e5
1
Parent(s):
9db3ef0
Update space with latest code and dependencies on Sat Jan 4 09:51:32 UTC 2025
Browse files
story_beam_search/scoring.py
CHANGED
@@ -137,10 +137,8 @@ class FluencyScorer(StoryScorer):
|
|
137 |
# For each story in the batch
|
138 |
for j in range(len(batch_stories)):
|
139 |
story_scores = []
|
140 |
-
input_ids = batch_inputs.input_ids[j : j + 1]
|
141 |
-
attention_mask = batch_inputs.attention_mask[
|
142 |
-
j : j + 1
|
143 |
-
] # Get attention mask
|
144 |
|
145 |
# Only process tokens that aren't padding
|
146 |
valid_tokens = attention_mask[0].sum().item()
|
@@ -150,6 +148,11 @@ class FluencyScorer(StoryScorer):
|
|
150 |
masked_input_ids = input_ids.clone()
|
151 |
masked_input_ids[0, k] = mask_token_id
|
152 |
|
|
|
|
|
|
|
|
|
|
|
153 |
with torch.no_grad():
|
154 |
outputs = self.model(
|
155 |
input_ids=masked_input_ids, attention_mask=attention_mask
|
|
|
137 |
# For each story in the batch
|
138 |
for j in range(len(batch_stories)):
|
139 |
story_scores = []
|
140 |
+
input_ids = batch_inputs.input_ids[j : j + 1]
|
141 |
+
attention_mask = batch_inputs.attention_mask[j : j + 1]
|
|
|
|
|
142 |
|
143 |
# Only process tokens that aren't padding
|
144 |
valid_tokens = attention_mask[0].sum().item()
|
|
|
148 |
masked_input_ids = input_ids.clone()
|
149 |
masked_input_ids[0, k] = mask_token_id
|
150 |
|
151 |
+
# Ensure token is within vocab range
|
152 |
+
masked_input_ids = masked_input_ids.clamp(
|
153 |
+
0, self.tokenizer.vocab_size - 1
|
154 |
+
)
|
155 |
+
|
156 |
with torch.no_grad():
|
157 |
outputs = self.model(
|
158 |
input_ids=masked_input_ids, attention_mask=attention_mask
|