ssalb committed on
Commit
7c0d92c
·
1 Parent(s): 16746e5

Update space with latest code and dependencies on Mon Jan 6 09:01:27 UTC 2025

Browse files
Files changed (1) hide show
  1. story_beam_search/beam_search.py +6 -2
story_beam_search/beam_search.py CHANGED
@@ -8,8 +8,8 @@ from story_beam_search.scoring import StoryScorer
8
 
9
  @dataclass
10
  class BeamSearchConfig:
11
- num_beams: int = 3
12
- num_return_sequences: int = 3
13
  max_length: int = 100
14
  no_repeat_ngram_size: int = 2
15
  temperature: float = 0.8
@@ -100,12 +100,16 @@ class BeamSearchGenerator:
100
  attention_mask_batch = torch.stack(padded_attention_masks).to(self.device)
101
 
102
  # Calculate continuation length
 
103
  continuation_length = (
104
  max_length + self.config.max_length // self.config.num_iterations
105
  )
106
 
107
  # Generate all continuations in one pass
108
  with torch.no_grad():
 
 
 
109
  outputs = self.model.generate(
110
  input_ids=input_ids_batch,
111
  attention_mask=attention_mask_batch,
 
8
 
9
  @dataclass
10
  class BeamSearchConfig:
11
+ num_beams: int = 4
12
+ num_return_sequences: int = 2
13
  max_length: int = 100
14
  no_repeat_ngram_size: int = 2
15
  temperature: float = 0.8
 
100
  attention_mask_batch = torch.stack(padded_attention_masks).to(self.device)
101
 
102
  # Calculate continuation length
103
+ # we want this length, times the num_iterations, to be roughly the max_length set by the user.
104
  continuation_length = (
105
  max_length + self.config.max_length // self.config.num_iterations
106
  )
107
 
108
  # Generate all continuations in one pass
109
  with torch.no_grad():
110
+ # Technically speaking, this generation is also using beam search at the token level
111
+ # in this case though, I'm using it to generate multiple sequences at once and evaluate them
112
+ # not by token probability, but by my custom metrics.
113
  outputs = self.model.generate(
114
  input_ids=input_ids_batch,
115
  attention_mask=attention_mask_batch,