ssalb committed
Commit 16746e5 · 1 Parent(s): 9db3ef0

Update space with latest code and dependencies on Sat Jan 4 09:51:32 UTC 2025

Files changed (1)
  1. story_beam_search/scoring.py +7 -4
story_beam_search/scoring.py CHANGED
@@ -137,10 +137,8 @@ class FluencyScorer(StoryScorer):
         # For each story in the batch
         for j in range(len(batch_stories)):
             story_scores = []
-            input_ids = batch_inputs.input_ids[j : j + 1]  # Keep batch dimension
-            attention_mask = batch_inputs.attention_mask[
-                j : j + 1
-            ]  # Get attention mask
+            input_ids = batch_inputs.input_ids[j : j + 1]
+            attention_mask = batch_inputs.attention_mask[j : j + 1]
 
             # Only process tokens that aren't padding
             valid_tokens = attention_mask[0].sum().item()
@@ -150,6 +148,11 @@ class FluencyScorer(StoryScorer):
                 masked_input_ids = input_ids.clone()
                 masked_input_ids[0, k] = mask_token_id
 
+                # Ensure token is within vocab range
+                masked_input_ids = masked_input_ids.clamp(
+                    0, self.tokenizer.vocab_size - 1
+                )
+
                 with torch.no_grad():
                     outputs = self.model(
                         input_ids=masked_input_ids, attention_mask=attention_mask
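
In short, the commit collapses the per-story slicing of input_ids and attention_mask onto single lines and adds a clamp so every token id fed to the masked LM stays inside the tokenizer's vocabulary. For context, below is a minimal, self-contained sketch of the masked-token (pseudo-log-likelihood) fluency scoring pattern this method appears to implement; the model choice, function name, and averaging step are illustrative assumptions, not code from this repository.

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Illustrative model choice; the Space may use a different masked LM.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
model.eval()

def pseudo_log_likelihood(text: str) -> float:
    """Average log-probability of each token when it is masked out."""
    enc = tokenizer(text, return_tensors="pt")
    input_ids = enc.input_ids              # shape (1, seq_len)
    attention_mask = enc.attention_mask
    scores = []
    # Skip the [CLS]/[SEP] special tokens at the ends
    for k in range(1, input_ids.shape[1] - 1):
        masked = input_ids.clone()
        masked[0, k] = tokenizer.mask_token_id
        # Same guard the commit adds: keep ids inside the MLM head's range
        # (added special tokens can carry ids >= tokenizer.vocab_size)
        masked = masked.clamp(0, tokenizer.vocab_size - 1)
        with torch.no_grad():
            logits = model(input_ids=masked, attention_mask=attention_mask).logits
        log_probs = logits[0, k].log_softmax(dim=-1)
        scores.append(log_probs[input_ids[0, k]].item())
    return sum(scores) / max(len(scores), 1)

print(pseudo_log_likelihood("The cat sat on the mat."))

The clamp reads as a defensive guard: tokenizers that carry added special tokens can emit ids at or above tokenizer.vocab_size, and passing those to a model whose embedding table was sized to the base vocabulary raises an index error, so clamping trades an occasional slightly-wrong token for a crash-free forward pass.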