from typing import Optional

import torch
from pydantic.dataclasses import dataclass
from transformers import PreTrainedModel, PreTrainedTokenizer

from story_beam_search.scoring import StoryScorer


@dataclass
class BeamSearchConfig:
    num_beams: int = 4  # candidates kept after each scoring pass
    num_return_sequences: int = 2  # sequences returned per prompt by generate()
    max_length: int = 100  # total new-token budget across all iterations
    no_repeat_ngram_size: int = 2
    temperature: float = 0.8
    top_k: int = 8
    top_p: float = 0.95
    num_iterations: int = 3  # number of generate-and-rescore rounds
    continuation_length: int = 10  # currently unused; per-round length is derived from max_length


class BeamSearchGenerator:
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        device: torch.device,
        config: Optional[BeamSearchConfig] = None,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.config = config or BeamSearchConfig()

    def generate_iterations(
        self, prompt: str, genre: str, evaluator: StoryScorer
    ) -> list[str]:
        """
        Generate story continuations by alternating batched generation with
        re-ranking: each round extends the current top candidates and keeps
        the best `num_beams` of them according to the evaluator.
        """
        instructions = (
            f"Continue the following story in the {genre} genre, "
            "ensuring coherence with the tone, characters, and narrative established so far:\n"
        )
        # The instruction prefix is stripped from decoded outputs by character
        # offset; this assumes the tokenizer round-trips the prefix verbatim.
        instructions_len = len(instructions)
        stories = self._generate_batch([instructions + prompt])
        ranked_stories = evaluator.evaluate_multiple(
            [story[instructions_len:] for story in stories]
        )
        stories = [story for story, _ in ranked_stories[: self.config.num_beams]]
        if stories:
            for _ in range(self.config.num_iterations):
                # Extend every surviving candidate in a single batch.
                all_prompts = [instructions + story for story in stories]
                all_stories = self._generate_batch(all_prompts)
                # Re-rank the extended stories and keep the top beams.
                ranked_stories = evaluator.evaluate_multiple(
                    [story[instructions_len:] for story in all_stories]
                )
                stories = [
                    story for story, _ in ranked_stories[: self.config.num_beams]
                ]
        return stories

    def _generate_batch(self, prompts: list[str]) -> list[str]:
        """
        Generate multiple continuations for multiple prompts in a single batch.
        """
        # Tokenize each prompt separately, then pad them to a common length.
        tokenized = [self.tokenizer(prompt, return_tensors="pt") for prompt in prompts]
        max_prompt_length = max(inputs["input_ids"].size(1) for inputs in tokenized)
        # Pad with the tokenizer's pad token (falling back to EOS for models
        # like GPT-2 that define none). Decoder-only models expect left
        # padding for generation, so the padding is prepended.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        padded_input_ids = []
        padded_attention_masks = []
        for inputs in tokenized:
            input_ids = inputs["input_ids"][0]
            attention_mask = inputs["attention_mask"][0]
            padding_length = max_prompt_length - input_ids.size(0)
            if padding_length > 0:
                input_ids = torch.cat(
                    [torch.full((padding_length,), pad_id, dtype=torch.long), input_ids]
                )
                attention_mask = torch.cat(
                    [torch.zeros(padding_length, dtype=torch.long), attention_mask]
                )
            padded_input_ids.append(input_ids)
            padded_attention_masks.append(attention_mask)
        # Stack into batch tensors and move them to the target device.
        input_ids_batch = torch.stack(padded_input_ids).to(self.device)
        attention_mask_batch = torch.stack(padded_attention_masks).to(self.device)
        # Budget for this round: the current prompt length plus an equal share
        # of the user's max_length, so that num_iterations rounds add roughly
        # max_length new tokens in total.
        continuation_length = max_prompt_length + (
            self.config.max_length // self.config.num_iterations
        )
        # Generate all continuations in one pass. generate() performs beam
        # search at the token level here; the sequences it returns are then
        # re-ranked by custom story-level metrics, not by token probability.
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=input_ids_batch,
                attention_mask=attention_mask_batch,
                max_length=continuation_length,
                num_beams=self.config.num_beams,
                num_return_sequences=self.config.num_return_sequences,
                early_stopping=True,
                no_repeat_ngram_size=self.config.no_repeat_ngram_size,
                temperature=self.config.temperature,
                top_k=self.config.top_k,
                top_p=self.config.top_p,
                do_sample=True,
                pad_token_id=pad_id,
            )
        # Decode every returned sequence back to text.
        return [
            self.tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]
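

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module's API): assumes
# a causal-LM checkpoint such as "gpt2" and that StoryScorer can be built
# with no arguments; adapt the model name, device, and scorer setup to your
# project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed checkpoint
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

    generator = BeamSearchGenerator(
        model,
        tokenizer,
        device,
        BeamSearchConfig(num_beams=4, num_return_sequences=2, max_length=120),
    )

    scorer = StoryScorer()  # assumed constructor; see story_beam_search.scoring
    stories = generator.generate_iterations(
        prompt="The lighthouse keeper found the door unlocked.",
        genre="mystery",
        evaluator=scorer,
    )
    for i, story in enumerate(stories, 1):
        print(f"--- Candidate {i} ---\n{story}\n")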