"""This module provides a function to get all batch parameters from the input."""
|
from dataclasses import dataclass |
|
from typing import Union, List, Tuple |
|
import re |
|
import os |
|
from scripts.Logger import Logger |
|
from scripts.Utils import Utils |
|
|
|
import modules |
|
|
|
import modules.shared as shared |
|
|
|
@dataclass
class BatchParams:
    """Settings for a single batch, i.e. one checkpoint with its prompt.

    Attributes:
        checkpoint (str): the checkpoint name
        prompt (str): the prompt
        hr_prompt (str): the hires. fix prompt
        neg_prompt (str): the negative prompt
        style (List[str]): the style (A1111 styles)
        batch_count (int, optional): the batch count. Defaults to -1. (don't overwrite the UI value)
        clip_skip (int, optional): the clip skip. Defaults to 1.
        width (int, optional): the width. Defaults to -1. (don't overwrite the UI value)
        height (int, optional): the height. Defaults to -1. (don't overwrite the UI value)
    """

    checkpoint: str
    prompt: str
    hr_prompt: str
    neg_prompt: str
    style: List[str]
    batch_count: int = -1
    clip_skip: int = 1
    width: int = -1
    height: int = -1

    def __repr__(self) -> str:
        """Return a multi-line, human-readable summary of the batch.

        Only the file name of the checkpoint path is shown; ``hr_prompt`` is
        not part of the summary.
        """
        checkpoint_name = os.path.basename(self.checkpoint)
        pieces = (
            f"BatchParams: {checkpoint_name},\n ",
            f"prompt: {self.prompt},\n",
            f"style: {self.style},\n",
            f"neg_prompt: {self.neg_prompt},\n ",
            f"batch_count: {self.batch_count},\n ",
            f"clip_skip: {self.clip_skip}\n",
            f"size: {self.width}x{self.height}",
        )
        return "".join(pieces)
|
|
|
# Module-level logger used for the debug output of this module.
logger: Logger = Logger()
|
|
|
def get_all_batch_params(p: Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img], checkpoints_as_string: str, prompts_as_string: str) -> List[BatchParams]:
    """Get all batch parameters from the input

    The checkpoint input and the ``;``-separated prompt input are parsed into
    two lists that must have the same length. Each prompt template is scanned
    for inline tags (batch count, clip skip, styles, image size, negative
    prompt) whose regexes are read from ``shared.opts``; the tags are removed
    from the template and the prompt placeholder is replaced by the UI prompt.

    Args:
        p (Union[modules.processing.StableDiffusionProcessingTxt2Img, modules.processing.StableDiffusionProcessingImg2Img]): the processing object
        checkpoints_as_string (str): the checkpoints as string
        prompts_as_string (str): the prompts as string

    Returns:
        List[BatchParams]: the batch parameters

    Raises:
        RuntimeError: if the amount of prompts and checkpoints differ, if both
            are empty, if a checkpoint is unknown, if a number regex from the
            options contains no number expression, or if the image size can't
            be converted to integers.
    """

    def getRegexFromOpts(key: str, search_for_number: bool = True) -> Tuple[str, str]:
        """Get the regex from the options. As the user can change the regex,
        it is checked if the regex is valid.

        Args:
            key (str): the key
            search_for_number (bool, optional): If true checks if the regex is valid. Defaults to True.

        Returns:
            Tuple[str, str]: the search pattern and the sub pattern

        Raises:
            RuntimeError: if ``search_for_number`` is set and the regex
                contains neither ``[0-9]+`` nor ``\\d+``.
        """
        sub_pattern = getattr(shared.opts, key)

        # Wrap each (unescaped) character class *together with* its "+"/"*"
        # quantifier in a capture group. Putting the quantifier outside the
        # group ("([0-9])+") would make the group keep only its last
        # repetition, i.e. only the last digit of a multi-digit number.
        # NOTE(review): a pattern that uses "\d+" passes the check below but
        # gets no capture group here, so group(1) would fail - confirm whether
        # such patterns are expected to carry their own group.
        search_pattern = re.sub(r"(?<!\\)(\[[^\]]*\][+*]?)", r"(\1)", sub_pattern)

        if search_for_number and not re.search(r"\[0-9\]\+|\\d\+", sub_pattern):
            raise RuntimeError(f'Can\'t find a number with the regex for {key}: "{sub_pattern}"')

        return search_pattern, sub_pattern

    def get_batch_count_from_prompt(prompt: str) -> Tuple[int, str]:
        """Extracts the batch count from the prompt if specified, else uses the default value

        Args:
            prompt (str): the prompt

        Returns:
            Tuple[int, str]: the batch count and the prompt
        """
        search_pattern, sub_pattern = getRegexFromOpts("batchCountRegex")
        number_match = re.search(search_pattern, prompt)
        if not (number_match and number_match.group(1)):
            # No tag found: fall back to the UI value and keep the prompt as is.
            return p.n_iter, prompt

        number = int(number_match.group(1))
        if number < 1:
            # Values below 1 are not usable; fall back to the UI value.
            number = p.n_iter
        return number, re.sub(sub_pattern, '', prompt)

    def get_clip_skip_from_prompt(prompt: str) -> Tuple[int, str]:
        """Extracts the clip skip from the prompt if specified, else uses the default value

        Args:
            prompt (str): the prompt

        Returns:
            Tuple[int, str]: the clip skip and the prompt
        """
        search_pattern, sub_pattern = getRegexFromOpts("clipSkipRegex")
        number_match = re.search(search_pattern, prompt)
        if not (number_match and number_match.group(1)):
            return shared.opts.data["CLIP_stop_at_last_layers"], prompt

        number = int(number_match.group(1))
        if number < 1:
            number = shared.opts.data["CLIP_stop_at_last_layers"]
        return number, re.sub(sub_pattern, '', prompt)

    def get_style_from_prompt(prompt: str) -> Tuple[List[str], str]:
        """Extracts the style from the prompt if specified.

        Every style tag is removed from the prompt, except the last one,
        which is replaced by the prompt placeholder from the options.

        Args:
            prompt (str): the prompt

        Returns:
            Tuple[List[str], str]: the styles and the prompt
        """
        search_pattern, sub_pattern = getRegexFromOpts("styleRegex", False)
        style_matches = re.findall(search_pattern, prompt)
        if not style_matches:
            return [], prompt

        # The placeholder does not change between iterations; read it once.
        _, prompt_regex = getRegexFromOpts("promptRegex", False)
        last_index = len(style_matches) - 1
        for i in range(len(style_matches)):
            replacement = prompt_regex if i == last_index else ""
            # A callable inserts the text literally; a plain string would be
            # treated as a replacement template (backslash escapes/group refs).
            prompt = re.sub(sub_pattern, lambda _match, text=replacement: text, prompt, count=1)

            logger.debug_log(f"nr.: {i}, prompt: {prompt}", False)

        return list(style_matches), prompt

    def get_image_size_from_prompt(prompt: str) -> Tuple[int, int, str]:
        """Extracts the image size from the prompt if specified, else uses the default value

        Args:
            prompt (str): the prompt

        Returns:
            Tuple[int, int, str]: the width, height and the prompt.
            If the width and height are not specified, -1 is returned.

        Raises:
            RuntimeError: if the matched size can't be converted to two integers.
        """
        search_pattern, sub_pattern = getRegexFromOpts("widthHeightRegex", False)
        number_matches = re.search(search_pattern, prompt)
        if not number_matches:
            return -1, -1, prompt

        try:
            width, height = map(int, number_matches.groups())
        except ValueError as err:
            raise RuntimeError(f"Can't convert the image size to an integer: {number_matches[0]}") from err

        return width, height, re.sub(sub_pattern, '', prompt)

    def split_postive_and_negative_postive_prompt(prompt: str) -> Tuple[str, str]:
        """Splits the prompt into a positive and negative prompt.
        If a negative prompt is specified.

        Args:
            prompt (str): the prompt

        Returns:
            Tuple[str, str]: the positive and negative prompt
        """
        pattern = getattr(shared.opts, "negPromptRegex")
        parts = re.split(pattern, prompt)
        neg_prompt = parts[1] if len(parts) > 1 else ""
        return parts[0], neg_prompt

    utils = Utils()

    checkpoints: List[str] = utils.getCheckpointListFromInput(checkpoints_as_string)

    prompts: List[str] = utils.remove_index_from_string(prompts_as_string).split(";")
    # Drop empty/whitespace-only entries and remove line breaks from the rest.
    prompts = [prompt.replace('\n', '').strip() for prompt in prompts if not prompt.isspace() and prompt != '']

    if len(prompts) != len(checkpoints):
        logger.debug_log(f"len prompt: {len(prompts)}, len checkpoints{len(checkpoints)}")
        raise RuntimeError("amount of prompts don't match with amount of checkpoints")

    if not prompts:
        raise RuntimeError("can't run without a checkpoint and prompt")

    # The prompt placeholder is the same for every checkpoint; read it once.
    _, prompt_regex = getRegexFromOpts("promptRegex", False)

    all_batch_params: List[BatchParams] = []

    for checkpoint, template in zip(checkpoints, prompts):
        info = modules.sd_models.get_closet_checkpoint_match(checkpoint)
        if info is None:
            raise RuntimeError(f"Unknown checkpoint: {checkpoint}")

        # Each step strips its own tag from the template before the next runs.
        batch_count, template = get_batch_count_from_prompt(template)
        clip_skip, template = get_clip_skip_from_prompt(template)
        style, template = get_style_from_prompt(template)
        width, height, template = get_image_size_from_prompt(template)
        prompt_template, neg_prompt = split_postive_and_negative_postive_prompt(template)

        # The placeholder is substituted literally (str.replace, not a regex).
        prompt = prompt_template.replace(prompt_regex, p.prompt)
        hr_prompt = prompt_template.replace(prompt_regex, p.hr_prompt)
        neg_prompt = p.negative_prompt + neg_prompt

        batch_params = BatchParams(checkpoint, prompt, hr_prompt, neg_prompt, style, batch_count, clip_skip, width, height)
        all_batch_params.append(batch_params)

        logger.debug_log(f"batch_params: {batch_params}", False)

    return all_batch_params