from typing import Optional

import torch
from pydantic.dataclasses import dataclass
from transformers import PreTrainedModel, PreTrainedTokenizer

from story_beam_search.scoring import StoryScorer


@dataclass
class BeamSearchConfig:
    num_beams: int = 4
    num_return_sequences: int = 2
    max_length: int = 100
    no_repeat_ngram_size: int = 2
    temperature: float = 0.8
    top_k: int = 8
    top_p: float = 0.95
    num_iterations: int = 3
    continuation_length: int = 10
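
# The defaults above can be overridden per run; for example (values here are
# illustrative only), a wider search that produces shorter stories:
#
#     config = BeamSearchConfig(num_beams=8, max_length=60)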


class BeamSearchGenerator:
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        device: torch.device,
        config: Optional[BeamSearchConfig] = None,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.config = config or BeamSearchConfig()

    def generate_iterations(
        self, prompt: str, genre: str, evaluator: StoryScorer
    ) -> list[str]:
        """
        Generate story continuations using parallel beam-search iterations.
        """
        instructions = (
            f"Continue the following story in the {genre} genre, "
            "ensuring coherence with the tone, characters, and narrative "
            "established so far:\n"
        )
        instructions_len = len(instructions)
        # Seed the search: generate candidates from the bare prompt, rank them
        # with the evaluator, and keep the top `num_beams` as the live beams.
        stories = self._generate_batch([instructions + prompt])
        ranked_stories = evaluator.evaluate_multiple(
            [story[instructions_len:] for story in stories]
        )
        stories = [story for story, _ in ranked_stories[: self.config.num_beams]]
        if stories:
            for _ in range(self.config.num_iterations):
                # Prepare all prompts for batch processing
                all_prompts = [instructions + story for story in stories]
                # Generate all continuations in one batch
                all_stories = self._generate_batch(all_prompts)
                # Re-rank the extended stories and keep only the best beams
                ranked_stories = evaluator.evaluate_multiple(
                    [story[instructions_len:] for story in all_stories]
                )
                stories = [
                    story for story, _ in ranked_stories[: self.config.num_beams]
                ]
        return stories

    def _generate_batch(self, prompts: list[str]) -> list[str]:
        """
        Generate multiple continuations for multiple prompts in a single batch.
        """
        # Tokenize all prompts
        tokenized = [self.tokenizer(prompt, return_tensors="pt") for prompt in prompts]
        # Pad input_ids and attention masks to a common length so the prompts
        # can be stacked into one batch. Pad with the tokenizer's pad token
        # (falling back to EOS) rather than token id 0, which may be a real
        # vocabulary item.
        max_length = max(inputs["input_ids"].size(1) for inputs in tokenized)
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        padded_input_ids = []
        padded_attention_masks = []
        for inputs in tokenized:
            input_ids = inputs["input_ids"][0]
            attention_mask = inputs["attention_mask"][0]
            # Pad to max_length; padded positions are masked out below
            padding_length = max_length - input_ids.size(0)
            if padding_length > 0:
                input_ids = torch.cat(
                    [
                        input_ids,
                        torch.full((padding_length,), pad_id, dtype=torch.long),
                    ]
                )
                attention_mask = torch.cat(
                    [attention_mask, torch.zeros(padding_length, dtype=torch.long)]
                )
            padded_input_ids.append(input_ids)
            padded_attention_masks.append(attention_mask)
        # Stack into batches and move to the target device
        input_ids_batch = torch.stack(padded_input_ids).to(self.device)
        attention_mask_batch = torch.stack(padded_attention_masks).to(self.device)
        # Budget the generation length: generate()'s max_length counts prompt
        # tokens, so start from the padded prompt length and add this
        # iteration's share of the user-configured max_length, so that
        # num_iterations rounds add up to roughly that total.
        continuation_length = max_length + (
            self.config.max_length // self.config.num_iterations
        )
        # Generate all continuations in one pass
        with torch.no_grad():
            # generate() runs token-level beam search internally; here it is
            # used to produce several candidate sequences per prompt at once,
            # which are then ranked not by token probability but by the custom
            # metrics in StoryScorer.
            outputs = self.model.generate(
                input_ids=input_ids_batch,
                attention_mask=attention_mask_batch,
                max_length=continuation_length,
                num_beams=self.config.num_beams,
                num_return_sequences=self.config.num_return_sequences,
                early_stopping=True,
                no_repeat_ngram_size=self.config.no_repeat_ngram_size,
                temperature=self.config.temperature,
                top_k=self.config.top_k,
                top_p=self.config.top_p,
                do_sample=True,
                pad_token_id=pad_id,
            )
        # Decode every returned sequence back into text
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
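

# The sketch below is one illustrative way to wire the generator up end to
# end; it is an assumption, not part of the original module. The "gpt2"
# checkpoint and the stand-in scorer are placeholders: the real StoryScorer
# in story_beam_search.scoring may expose a different constructor, so only
# the evaluate_multiple() interface used above is relied on here.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    class _LengthScorer:
        """Stand-in scorer (hypothetical): returns (story, score) pairs
        sorted best-first, matching how evaluate_multiple() is used above."""

        def evaluate_multiple(self, stories: list[str]) -> list[tuple[str, float]]:
            scored = [(story, float(len(story))) for story in stories]
            return sorted(scored, key=lambda pair: pair[1], reverse=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

    generator = BeamSearchGenerator(model, tokenizer, device)
    stories = generator.generate_iterations(
        prompt="The lighthouse keeper found a door at low tide.",
        genre="mystery",
        evaluator=_LengthScorer(),
    )
    for story in stories:
        print(story, "\n")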