File size: 4,948 Bytes
bfb4432
 
 
 
 
 
 
 
 
 
7c0d92c
 
bfb4432
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f75feb2
bfb4432
 
 
 
 
 
 
f75feb2
bfb4432
 
 
d2e0b39
bfb4432
 
 
f75feb2
 
 
 
 
bfb4432
 
 
d2e0b39
 
 
bfb4432
 
 
f75feb2
bfb4432
f75feb2
bfb4432
f75feb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfb4432
f75feb2
7c0d92c
f75feb2
 
bfb4432
 
f75feb2
bfb4432
7c0d92c
 
 
bfb4432
f75feb2
 
 
bfb4432
 
 
 
 
 
 
 
b990ec8
bfb4432
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from pydantic.dataclasses import dataclass
from typing import Optional
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer

from story_beam_search.scoring import StoryScorer


@dataclass
class BeamSearchConfig:
    """Hyper-parameters for ``BeamSearchGenerator``."""

    # Beam width passed to model.generate(); also how many candidate
    # stories survive each ranking pass in generate_iterations().
    num_beams: int = 4
    # Sequences returned per prompt by each model.generate() call.
    num_return_sequences: int = 2
    # Overall token budget; each iteration extends by roughly
    # max_length // num_iterations tokens (see _generate_batch).
    max_length: int = 100
    # Forbid repeating any n-gram of this size during generation.
    no_repeat_ngram_size: int = 2
    # Sampling temperature for model.generate() (do_sample=True).
    temperature: float = 0.8
    # Top-k sampling cutoff.
    top_k: int = 8
    # Nucleus (top-p) sampling cutoff.
    top_p: float = 0.95
    # Number of extend-and-rerank rounds in generate_iterations().
    num_iterations: int = 3
    # NOTE(review): not read anywhere in this file — presumably the
    # per-iteration continuation length; confirm against callers.
    continuation_length: int = 10


class BeamSearchGenerator:
    """Generate story continuations with a causal LM, re-ranking the
    beam-search candidates between iterations with a caller-supplied scorer."""

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        device: torch.device,
        config: Optional[BeamSearchConfig] = None,
    ):
        """
        Args:
            model: Causal language model used for generation.
            tokenizer: Tokenizer matching ``model``.
            device: Device the input batch is moved to before generation.
            config: Search hyper-parameters; a default ``BeamSearchConfig``
                is used when omitted.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.config = config or BeamSearchConfig()

    def generate_iterations(
        self, prompt: str, genre: str, evaluator: StoryScorer
    ) -> list[str]:
        """
        Generate story continuations using parallel beam search iterations.

        Each round extends every surviving story in one batched call, scores
        the extensions with ``evaluator``, and keeps the top
        ``config.num_beams`` of them.

        Args:
            prompt: Story text to continue.
            genre: Genre name interpolated into the instruction preamble.
            evaluator: Scorer whose ``evaluate_multiple`` returns
                (story, score) pairs ordered best-first.

        Returns:
            Up to ``config.num_beams`` stories, best first, with the
            instruction preamble stripped.
        """
        instructions = (
            f"Continue the following story in the {genre} genre, "
            "ensuring coherence with the tone, characters, and narrative established so far:\n"
        )
        # NOTE(review): the preamble is stripped by *character* count below,
        # which assumes the tokenize/decode round trip reproduces it
        # verbatim — confirm for the tokenizer in use.
        instructions_len = len(instructions)

        stories = self._rank_and_prune(
            self._generate_batch([instructions + prompt]), instructions_len, evaluator
        )

        for _ in range(self.config.num_iterations):
            # Guard: if every candidate was pruned away, a further batch
            # would be empty (and previously crashed in _generate_batch).
            if not stories:
                break
            # Extend all surviving stories in a single batched call.
            all_stories = self._generate_batch(
                [instructions + story for story in stories]
            )
            stories = self._rank_and_prune(all_stories, instructions_len, evaluator)

        return stories

    def _rank_and_prune(
        self, stories: list[str], instructions_len: int, evaluator: StoryScorer
    ) -> list[str]:
        """Strip the instruction preamble, score, and keep the best
        ``config.num_beams`` stories (best first)."""
        ranked = evaluator.evaluate_multiple(
            [story[instructions_len:] for story in stories]
        )
        return [story for story, _ in ranked[: self.config.num_beams]]

    def _pad_token_id(self) -> int:
        """Token id used to right-pad shorter prompts in a batch.

        Prefers the tokenizer's pad token, falls back to its EOS token, and
        finally to 0 (the previous hard-coded value) if neither is set.
        """
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        if pad_id is None:
            pad_id = 0
        return pad_id

    def _generate_batch(self, prompts: list[str]) -> list[str]:
        """
        Generate multiple continuations for multiple prompts in a single batch.

        Args:
            prompts: Full prompt texts (instructions + story so far).

        Returns:
            ``len(prompts) * config.num_return_sequences`` decoded texts,
            or ``[]`` when ``prompts`` is empty.
        """
        # Empty batch: previously max() below raised ValueError.
        if not prompts:
            return []

        # Tokenize each prompt, then pad all to the longest length.
        tokenized = [self.tokenizer(prompt, return_tensors="pt") for prompt in prompts]
        max_length = max(inputs["input_ids"].size(1) for inputs in tokenized)

        # Bug fix: padding input_ids with literal 0 assumed id 0 is a pad
        # token; use the tokenizer's real pad/EOS id instead. The attention
        # mask still zeroes the padded positions either way.
        # NOTE(review): this right-pads; decoder-only models usually prefer
        # left padding for batched generation — confirm for this model.
        pad_id = self._pad_token_id()

        padded_input_ids = []
        padded_attention_masks = []
        for inputs in tokenized:
            input_ids = inputs["input_ids"][0]
            attention_mask = inputs["attention_mask"][0]

            padding_length = max_length - input_ids.size(0)
            if padding_length > 0:
                input_ids = torch.cat(
                    [
                        input_ids,
                        torch.full((padding_length,), pad_id, dtype=torch.long),
                    ]
                )
                attention_mask = torch.cat(
                    [attention_mask, torch.zeros(padding_length, dtype=torch.long)]
                )

            padded_input_ids.append(input_ids)
            padded_attention_masks.append(attention_mask)

        # Stack into batches on the target device.
        input_ids_batch = torch.stack(padded_input_ids).to(self.device)
        attention_mask_batch = torch.stack(padded_attention_masks).to(self.device)

        # Target length for this call: current prompt length plus one
        # iteration's share of the user-configured max_length, so that
        # num_iterations rounds roughly consume the whole budget.
        continuation_length = (
            max_length + self.config.max_length // self.config.num_iterations
        )

        # Generate all continuations in one pass.
        with torch.no_grad():
            # Token-level beam search here only produces several candidate
            # sequences at once; cross-iteration ranking is done by the
            # caller's custom metrics, not token probability.
            outputs = self.model.generate(
                input_ids=input_ids_batch,
                attention_mask=attention_mask_batch,
                max_length=continuation_length,
                num_beams=self.config.num_beams,
                num_return_sequences=self.config.num_return_sequences,
                early_stopping=True,
                no_repeat_ngram_size=self.config.no_repeat_ngram_size,
                temperature=self.config.temperature,
                top_k=self.config.top_k,
                top_p=self.config.top_p,
                do_sample=True,
            )
            # (The previous trailing .to(self.device) was a no-op: generate()
            # already returns tensors on the model's device.)

        return [
            self.tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]