from pydantic.dataclasses import dataclass
from typing import Optional

import torch
from transformers import PreTrainedModel, PreTrainedTokenizer

from story_beam_search.scoring import StoryScorer


@dataclass
class BeamSearchConfig:
    num_beams: int = 4
    num_return_sequences: int = 2
    max_length: int = 100
    no_repeat_ngram_size: int = 2
    temperature: float = 0.8
    top_k: int = 8
    top_p: float = 0.95
    num_iterations: int = 3
    continuation_length: int = 10
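
# Illustrative override (these values are assumptions, not project defaults):
# a wider beam with fewer iterations trades per-pass diversity for latency, e.g.
#
#   config = BeamSearchConfig(num_beams=8, num_iterations=2, max_length=80)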


class BeamSearchGenerator:
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        device: torch.device,
        config: Optional[BeamSearchConfig] = None,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.config = config or BeamSearchConfig()

    def generate_iterations(
        self, prompt: str, genre: str, evaluator: StoryScorer
    ) -> list[str]:
        """
        Generate story continuations using parallel beam search iterations.
        """
        instructions = (
            f"Continue the following story in the {genre} genre, "
            "ensuring coherence with the tone, characters, and narrative established so far:\n"
        )
        instructions_len = len(instructions)

        # Seed the search with one batch, then keep only the top-ranked stories
        stories = self._generate_batch([instructions + prompt])
        ranked_stories = evaluator.evaluate_multiple(
            [story[instructions_len:] for story in stories]
        )
        stories = [story for story, _ in ranked_stories[: self.config.num_beams]]

        if stories:
            for _ in range(self.config.num_iterations):
                # Prepare all prompts for batch processing
                all_prompts = [instructions + story for story in stories]

                # Generate all continuations in one batch
                all_stories = self._generate_batch(all_prompts)
                ranked_stories = evaluator.evaluate_multiple(
                    [story[instructions_len:] for story in all_stories]
                )
                stories = [
                    story for story, _ in ranked_stories[: self.config.num_beams]
                ]
        return stories

    def _generate_batch(self, prompts: list[str]) -> list[str]:
        """
        Generate multiple continuations for multiple prompts in a single batch.
        """
        # Tokenize all prompts
        tokenized = [self.tokenizer(prompt, return_tensors="pt") for prompt in prompts]

        # Pad input_ids and attention_masks to the same length. Pad with the
        # tokenizer's pad token, falling back to EOS for tokenizers (such as
        # GPT-2's) that define none; the attention mask zeros these positions
        # out either way.
        max_length = max(inputs["input_ids"].size(1) for inputs in tokenized)
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        padded_input_ids = []
        padded_attention_masks = []
        for inputs in tokenized:
            input_ids = inputs["input_ids"][0]
            attention_mask = inputs["attention_mask"][0]

            # Pad to max_length
            padding_length = max_length - input_ids.size(0)
            if padding_length > 0:
                input_ids = torch.cat(
                    [
                        input_ids,
                        torch.full((padding_length,), pad_token_id, dtype=torch.long),
                    ]
                )
                attention_mask = torch.cat(
                    [attention_mask, torch.zeros(padding_length, dtype=torch.long)]
                )

            padded_input_ids.append(input_ids)
            padded_attention_masks.append(attention_mask)

        # Stack into batches
        input_ids_batch = torch.stack(padded_input_ids).to(self.device)
        attention_mask_batch = torch.stack(padded_attention_masks).to(self.device)

        # Budget for this pass: the (padded) prompt length plus one iteration's
        # share of the user-configured max_length, so that num_iterations
        # passes add up to roughly max_length newly generated tokens.
        continuation_length = max_length + (
            self.config.max_length // self.config.num_iterations
        )
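        # Worked example with the defaults (max_length=100, num_iterations=3):
        # a padded prompt of 40 tokens yields 40 + 100 // 3 = 73 total
        # positions, i.e. roughly 33 new tokens per pass.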

        # Generate all continuations in one pass
        with torch.no_grad():
            # Token-level beam search is technically in play here as well, but
            # it is used to produce several candidate sequences at once; the
            # ranking happens afterwards, not by token probability but by the
            # custom story metrics.
            outputs = self.model.generate(
                input_ids=input_ids_batch,
                attention_mask=attention_mask_batch,
                max_length=continuation_length,
                num_beams=self.config.num_beams,
                num_return_sequences=self.config.num_return_sequences,
                early_stopping=True,
                no_repeat_ngram_size=self.config.no_repeat_ngram_size,
                temperature=self.config.temperature,
                top_k=self.config.top_k,
                top_p=self.config.top_p,
                do_sample=True,
            )

        stories = []
        for output in outputs:
            text = self.tokenizer.decode(output, skip_special_tokens=True)
            stories.append(text)
        return stories
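

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes a
    # small causal LM such as GPT-2 and that StoryScorer can be constructed
    # with no arguments; the real scorer's constructor may differ.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

    generator = BeamSearchGenerator(model, tokenizer, device)
    evaluator = StoryScorer()  # hypothetical zero-arg construction
    for story in generator.generate_iterations(
        "The lighthouse keeper found the door ajar.", "mystery", evaluator
    ):
        print(story, end="\n---\n")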