File size: 8,885 Bytes
f4a41d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
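"""This module provides a function to get all batch parameters (checkpoint, prompts, styles, batch count, clip skip and image size) from the user input."""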
from dataclasses import dataclass
from typing import Union, List, Tuple
import re
import os
from scripts.Logger import Logger
from scripts.Utils import Utils

import modules

import modules.shared as shared

@dataclass()
class BatchParams:
    """Dataclass to store the parameters for a batch.

    Args:
        checkpoint (str): the checkpoint name (only its basename is shown by ``repr``)
        prompt (str): the prompt
        hr_prompt (str): the hires. fix prompt
        neg_prompt (str): the negative prompt
        style (List[str]): the style (A1111 styles)
        batch_count (int, optional): the batch count. Defaults to -1. (don't overwrite the UI value)
        clip_skip (int, optional): the clip skip. Defaults to 1.
        width (int, optional): the width. Defaults to -1. (don't overwrite the UI value)
        height (int, optional): the height. Defaults to -1. (don't overwrite the UI value)
    """
    checkpoint: str
    prompt: str
    hr_prompt: str
    neg_prompt: str
    style: List[str]
    batch_count: int = -1
    clip_skip: int = 1
    width: int = -1
    height: int = -1

    def __repr__(self) -> str:
        """Return a human-readable, multi-line summary of the batch parameters.

        Each field is printed as ``name: value`` on its own line (lines are
        separated by ``",\\n"``); the checkpoint path is shortened to its file
        name. Because this method is defined on the class, ``@dataclass`` does
        not replace it with its generated ``__repr__``.
        """
        checkpoint_name: str = os.path.basename(self.checkpoint)
        lines = [
            f"BatchParams: {checkpoint_name}",
            f"prompt: {self.prompt}",
            f"hr_prompt: {self.hr_prompt}",
            f"style: {self.style}",
            f"neg_prompt: {self.neg_prompt}",
            f"batch_count: {self.batch_count}",
            f"clip_skip: {self.clip_skip}",
            f"size: {self.width}x{self.height}",
        ]
        return ",\n".join(lines)

# Module-level logger instance; get_all_batch_params and its nested helpers use it
# for debug output (logger.debug_log).
logger = Logger()

def get_all_batch_params(p: Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img], checkpoints_as_string: str, prompts_as_string: str) -> List[BatchParams]:
    """Get all batch parameters from the input.

    The prompts (separated by ``;``) are paired positionally with the
    checkpoints. Inline tags in every prompt (batch count, clip skip, styles,
    image size and negative prompt) are parsed out using the user-configurable
    regexes stored in ``shared.opts``; the UI prompt / negative prompt from
    ``p`` are then merged into each per-checkpoint prompt.

    Args:
        p (Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img]):
            the processing object; supplies the UI prompt, negative prompt,
            hires. fix prompt (txt2img only) and the default batch count (``n_iter``).
        checkpoints_as_string (str): the checkpoints as string
            (parsed by ``Utils.getCheckpointListFromInput``)
        prompts_as_string (str): the prompts as string, separated by ``;``

    Returns:
        List[BatchParams]: one entry per checkpoint/prompt pair, in input order.

    Raises:
        RuntimeError: if the numbers of prompts and checkpoints differ, if none
            are given, if a checkpoint is unknown, if a number regex from the
            options contains no number pattern, or if an image size tag can't
            be converted to integers.
    """

    # Parts of a user regex that must become capture groups: a (non-escaped)
    # character class or ``\d`` TOGETHER WITH its optional quantifier.
    # The quantifier has to sit inside the group: ``([0-9])+`` only keeps the
    # last repetition ("2" for "12"), whereas ``([0-9]+)`` captures "12".
    capture_target = re.compile(r"(?<!\\)(?:\[(?:\\.|[^\]\\])*\]|\\d)(?:[+*?]\??)?")

    def getRegexFromOpts(key: str, search_for_number: bool = True) -> Tuple[str, str]:
        """Get the regex from the options. As the user can change the regex,
        it is checked if it contains a number pattern when one is required.

        Args:
            key (str): the option key
            search_for_number (bool, optional): if True, the regex must contain
                ``[0-9]+`` or ``\\d+``. Defaults to True.

        Returns:
            Tuple[str, str]: the search pattern (every character class / ``\\d``
            plus its quantifier wrapped in a capture group) and the sub pattern
            (the raw option value, used to remove the tag from the prompt).

        Raises:
            RuntimeError: if a number pattern is required but missing.
        """
        sub_pattern = getattr(shared.opts, key)
        search_pattern = capture_target.sub(r"(\g<0>)", sub_pattern)

        if search_for_number and not re.search(r"\[0-9\]\+|\\d\+", sub_pattern):
            raise RuntimeError(f'Can\'t find a number with the regex for {key}: "{sub_pattern}"')

        return search_pattern, sub_pattern

    def get_batch_count_from_prompt(prompt: str) -> Tuple[int, str]:
        """Extracts the batch count from the prompt if specified, else uses the UI value (``p.n_iter``).

        Args:
            prompt (str): the prompt

        Returns:
            Tuple[int, str]: the batch count and the prompt with the tag removed
        """
        search_pattern, sub_pattern = getRegexFromOpts("batchCountRegex")
        number_match = re.search(search_pattern, prompt)
        if number_match and number_match.group(1):
            number = int(number_match.group(1))
            # Non-positive counts fall back to the UI value.
            number = p.n_iter if number < 1 else number
            prompt = re.sub(sub_pattern, '', prompt)
        else:
            number = p.n_iter

        return number, prompt

    def get_clip_skip_from_prompt(prompt: str) -> Tuple[int, str]:
        """Extracts the clip skip from the prompt if specified, else uses the
        ``CLIP_stop_at_last_layers`` option value.

        Args:
            prompt (str): the prompt

        Returns:
            Tuple[int, str]: the clip skip and the prompt with the tag removed
        """
        search_pattern, sub_pattern = getRegexFromOpts("clipSkipRegex")
        number_match = re.search(search_pattern, prompt)
        if number_match and number_match.group(1):
            number = int(number_match.group(1))
            # Non-positive values fall back to the global setting.
            number = shared.opts.data["CLIP_stop_at_last_layers"] if number < 1 else number
            prompt = re.sub(sub_pattern, '', prompt)
        else:
            number = shared.opts.data["CLIP_stop_at_last_layers"]

        return number, prompt

    def get_style_from_prompt(prompt: str) -> Tuple[List[str], str]:
        """Extracts the styles from the prompt if specified.

        All style tags are removed; the LAST one is replaced by the prompt
        placeholder (``promptRegex`` option) so the placeholder ends up at
        that position.

        Args:
            prompt (str): the prompt

        Returns:
            Tuple[List[str], str]: the styles and the prompt
        """
        search_pattern, sub_pattern = getRegexFromOpts("styleRegex", False)
        styles = re.findall(search_pattern, prompt)
        if styles:
            _, placeholder = getRegexFromOpts("promptRegex", False)
            last_index = len(styles) - 1
            for index in range(len(styles)):
                replacement = placeholder if index == last_index else ""
                # A callable replacement inserts the placeholder literally
                # (no backslash/group-reference processing), matching the
                # literal str.replace later used to substitute it.
                prompt = re.sub(sub_pattern, lambda _match: replacement, prompt, count=1)

            logger.debug_log(f"nr.: {last_index}, prompt: {prompt}", False)

        return styles, prompt

    def get_image_size_from_prompt(prompt: str) -> Tuple[int, int, str]:
        """Extracts the image size from the prompt if specified.

        Args:
            prompt (str): the prompt

        Returns:
            Tuple[int, int, str]: the width, height and the prompt.
            If the width and height are not specified, -1 is returned.

        Raises:
            RuntimeError: if the matched size can't be converted to two integers.
        """
        search_pattern, sub_pattern = getRegexFromOpts("widthHeightRegex", False)  # e.g. {{size:[0-9]+x[0-9]+}}
        number_matches = re.search(search_pattern, prompt)
        if not number_matches:
            return -1, -1, prompt

        try:
            width, height = map(int, number_matches.groups())
        except ValueError as err:
            raise RuntimeError(f"Can't convert the image size to an integer: {number_matches[0]}") from err

        return width, height, re.sub(sub_pattern, '', prompt)

    def split_positive_and_negative_prompt(prompt: str) -> Tuple[str, str]:
        """Splits the prompt into a positive and a negative prompt
        (using the ``negPromptRegex`` option as separator).

        Args:
            prompt (str): the prompt

        Returns:
            Tuple[str, str]: the positive and negative prompt; the negative
            prompt is "" if no separator is present. Only the text between the
            first and a possible second separator is used as negative prompt.
        """
        pattern = getattr(shared.opts, "negPromptRegex")
        parts = re.split(pattern, prompt)
        neg_prompt = parts[1] if len(parts) > 1 else ""
        return parts[0], neg_prompt

    utils = Utils()

    checkpoints: List[str] = utils.getCheckpointListFromInput(checkpoints_as_string)

    prompts: List[str] = utils.remove_index_from_string(prompts_as_string).split(";")
    # Drop empty / whitespace-only entries; newlines are removed (not replaced by spaces).
    prompts = [prompt.replace('\n', '').strip() for prompt in prompts if prompt.strip()]

    if len(prompts) != len(checkpoints):
        logger.debug_log(f"len prompt: {len(prompts)}, len checkpoints: {len(checkpoints)}")
        raise RuntimeError("amount of prompts don't match with amount of checkpoints")

    if not prompts:
        raise RuntimeError("can't run without a checkpoint and prompt")

    _, prompt_placeholder = getRegexFromOpts("promptRegex", False)
    # StableDiffusionProcessingImg2Img has no ``hr_prompt`` attribute (only the
    # txt2img processing object does), so fall back to an empty hires. prompt.
    ui_hr_prompt: str = getattr(p, "hr_prompt", "")

    all_batch_params: List[BatchParams] = []
    for checkpoint, raw_prompt in zip(checkpoints, prompts):
        if modules.sd_models.get_closet_checkpoint_match(checkpoint) is None:
            raise RuntimeError(f"Unknown checkpoint: {checkpoint}")

        batch_count, prompt = get_batch_count_from_prompt(raw_prompt)
        clip_skip, prompt = get_clip_skip_from_prompt(prompt)
        style, prompt = get_style_from_prompt(prompt)
        width, height, prompt = get_image_size_from_prompt(prompt)
        prompt_template, neg_prompt = split_positive_and_negative_prompt(prompt)

        batch_params = BatchParams(
            checkpoint,
            prompt_template.replace(prompt_placeholder, p.prompt),
            prompt_template.replace(prompt_placeholder, ui_hr_prompt),
            # NOTE(review): concatenated without a separator; this relies on the
            # per-checkpoint negative part carrying its own leading separator.
            p.negative_prompt + neg_prompt,
            style,
            batch_count,
            clip_skip,
            width,
            height,
        )
        all_batch_params.append(batch_params)

        logger.debug_log(f"batch_params: {batch_params}", False)

    return all_batch_params