"""This module provides a function to get all batch parameters from the UI input.

The entry point is :func:`get_all_batch_params`, which pairs every checkpoint
with its prompt and extracts the inline options (batch count, clip skip,
styles, image size and negative prompt) from each prompt.
"""
from dataclasses import dataclass
from typing import Union, List, Tuple
import re
import os

from scripts.Logger import Logger
from scripts.Utils import Utils

import modules
import modules.shared as shared


@dataclass()
class BatchParams:
    """Dataclass to store the parameters for a batch

    Args:
        checkpoint (str): the checkpoint name
        prompt (str): the prompt
        hr_prompt (str): the hires. fix prompt
        neg_prompt (str): the negative prompt
        style (List[str]): the style (A1111 styles)
        batch_count (int, optional): the batch count. Defaults to -1. (don't overwrite the UI value)
        clip_skip (int, optional): the clip skip. Defaults to 1.
        width (int, optional): the width. Defaults to -1. (don't overwrite the UI value)
        height (int, optional): the height. Defaults to -1. (don't overwrite the UI value)
    """

    checkpoint: str
    prompt: str
    hr_prompt: str
    neg_prompt: str
    style: List[str]
    batch_count: int = -1
    clip_skip: int = 1
    width: int = -1
    height: int = -1

    def __repr__(self) -> str:
        # Only the file name of the checkpoint is shown to keep log output short.
        checkpoint_name: str = os.path.basename(self.checkpoint)
        return (
            f"BatchParams: {checkpoint_name},\n"
            f"prompt: {self.prompt},\n"
            f"hr_prompt: {self.hr_prompt},\n"
            f"style: {self.style},\n"
            f"neg_prompt: {self.neg_prompt},\n"
            f"batch_count: {self.batch_count},\n"
            f"clip_skip: {self.clip_skip},\n"
            f"size: {self.width}x{self.height}"
        )


logger = Logger()

# An unescaped regex character class such as "[0-9]" together with its
# optional trailing quantifier ("+", "*", "?", "{m}", "{m,n}", optionally lazy).
_CHAR_CLASS_WITH_QUANTIFIER = re.compile(
    r"(?<!\\)\[(?:\\.|[^\]\\])+\](?:[+*?]|\{\d+(?:,\d*)?\})?\??"
)

# The literal "\d+" shorthand inside a user supplied regex.
_DIGITS_SHORTHAND = re.compile(r"\\d\+")


def _to_search_pattern(sub_pattern: str, search_for_number: bool) -> str:
    """Turn a user regex into a search regex whose groups capture the values.

    Every character class is wrapped, together with its quantifier, in a
    capturing group (e.g. "[0-9]+" -> "([0-9]+)"). The quantifier must be inside
    the group: for "([0-9])+" a match's group only holds the last repetition,
    i.e. the last digit. Classes that are already wrapped in parentheses are
    left untouched so they are not captured twice.

    If a number is expected but the pattern still has no capturing group
    (e.g. it uses "\\d+"), the first "\\d+" is wrapped so group(1) exists.

    Args:
        sub_pattern (str): the regex as configured by the user
        search_for_number (bool): whether group(1) must capture a number

    Returns:
        str: the search pattern
    """

    def wrap(match: "re.Match[str]") -> str:
        token = match.group(0)
        start, end = match.span()
        already_grouped = (sub_pattern[start - 1:start] == "("
                           and sub_pattern[end:end + 1] == ")")
        return token if already_grouped else f"({token})"

    search_pattern = _CHAR_CLASS_WITH_QUANTIFIER.sub(wrap, sub_pattern)
    if search_for_number and re.compile(search_pattern).groups == 0:
        search_pattern = _DIGITS_SHORTHAND.sub(
            lambda m: f"({m.group(0)})", search_pattern, count=1)
    return search_pattern


def get_all_batch_params(p: Union[modules.processing.StableDiffusionProcessingTxt2Img,
                                  modules.processing.StableDiffusionProcessingImg2Img],
                         checkpoints_as_string: str,
                         prompts_as_string: str) -> List[BatchParams]:
    """Get all batch parameters from the input

    Checkpoints and prompts are paired by position. Inline tags (batch count,
    clip skip, styles, image size, negative prompt) are extracted from each
    prompt using the regexes configured in the settings and removed from it.

    Args:
        p (Union[modules.processing.StableDiffusionProcessingTxt2Img,
                 modules.processing.StableDiffusionProcessingImg2Img]): the processing object
        checkpoints_as_string (str): the checkpoints as string
        prompts_as_string (str): the prompts as string, separated by ";"

    Returns:
        List[BatchParams]: the batch parameters, one per checkpoint

    Raises:
        RuntimeError: if the numbers of prompts and checkpoints differ, if both
            are empty, if a checkpoint is unknown, if a number regex from the
            settings contains no number, or if an image size can't be converted.
    """

    def get_regex_from_opts(key: str, search_for_number: bool = True) -> Tuple[str, str]:
        """Get the regex from the options. As the user can change the regex, it is checked if the regex is valid.

        Args:
            key (str): the key
            search_for_number (bool, optional): If true checks if the regex is valid. Defaults to True.

        Returns:
            Tuple[str, str]: the search pattern and the sub pattern
        """
        sub_pattern = getattr(shared.opts, key)
        if search_for_number and not re.search(r"\[0-9\]\+|\\d\+", sub_pattern):
            raise RuntimeError(f'Can\'t find a number with the regex for {key}: "{sub_pattern}"')
        return _to_search_pattern(sub_pattern, search_for_number), sub_pattern

    utils = Utils()

    def get_batch_count_from_prompt(prompt: str) -> Tuple[int, str]:
        """Extracts the batch count from the prompt if specified, else uses the default value

        Args:
            prompt (str): the prompt

        Returns:
            Tuple[int, str]: the batch count (p.n_iter if absent or < 1) and the prompt
        """
        search_pattern, sub_pattern = get_regex_from_opts("batchCountRegex")
        number_match = re.search(search_pattern, prompt)
        if number_match and number_match.group(1):
            number = int(number_match.group(1))
            # Non-positive counts fall back to the batch count from the UI.
            number = p.n_iter if number < 1 else number
            prompt = re.sub(sub_pattern, '', prompt)
        else:
            number = p.n_iter
        return number, prompt

    def get_clip_skip_from_prompt(prompt: str) -> Tuple[int, str]:
        """Extracts the clip skip from the prompt if specified, else uses the default value

        Args:
            prompt (str): the prompt

        Returns:
            Tuple[int, str]: the clip skip (the "CLIP_stop_at_last_layers" option
            if absent or < 1) and the prompt
        """
        search_pattern, sub_pattern = get_regex_from_opts("clipSkipRegex")
        number_match = re.search(search_pattern, prompt)
        if number_match and number_match.group(1):
            number = int(number_match.group(1))
            number = shared.opts.data["CLIP_stop_at_last_layers"] if number < 1 else number
            prompt = re.sub(sub_pattern, '', prompt)
        else:
            number = shared.opts.data["CLIP_stop_at_last_layers"]
        return number, prompt

    def get_style_from_prompt(prompt: str) -> Tuple[List[str], str]:
        """Extracts the style from the prompt if specified.

        All style tags are removed, except the last one, which is replaced by
        the prompt placeholder ("promptRegex" option).

        Args:
            prompt (str): the prompt

        Returns:
            Tuple[List[str], str]: the styles and the prompt
        """
        styles: List[str] = []
        search_pattern, sub_pattern = get_regex_from_opts("styleRegex", False)
        style_matches = re.findall(search_pattern, prompt)
        if not style_matches:
            return styles, prompt

        _, placeholder = get_regex_from_opts("promptRegex", False)
        last_index = len(style_matches) - 1
        for i, stl in enumerate(style_matches):
            styles.append(stl)
            replacement = placeholder if i == last_index else ""
            # A callable replacement is inserted literally (no backslash or
            # group-reference processing), matching the str.replace used for
            # the placeholder in the main loop.
            prompt = re.sub(sub_pattern, lambda _m, text=replacement: text, prompt, count=1)
            logger.debug_log(f"nr.: {i}, prompt: {prompt}", False)
        return styles, prompt

    def get_image_size_from_prompt(prompt: str) -> Tuple[int, int, str]:
        """Extracts the image size from the prompt if specified, else uses the default value

        Args:
            prompt (str): the prompt

        Returns:
            Tuple[int, int, str]: the width, height and the prompt. If the width and height are not specified, -1 is returned.

        Raises:
            RuntimeError: if the matched groups can't be converted to a width and height
        """
        # e.g. {{size:[0-9]+x[0-9]+}}
        search_pattern, sub_pattern = get_regex_from_opts("widthHeightRegex", False)
        size_match = re.search(search_pattern, prompt)
        if not size_match:
            return -1, -1, prompt
        try:
            # ValueError: non-numeric text or not exactly two groups;
            # TypeError: a group did not participate in the match (None).
            width, height = map(int, size_match.groups())
        except (TypeError, ValueError) as err:
            raise RuntimeError(f"Can't convert the image size to an integer: {size_match[0]}") from err
        return width, height, re.sub(sub_pattern, '', prompt)

    def split_positive_and_negative_prompt(prompt: str) -> Tuple[str, str]:
        """Splits the prompt into a positive and negative prompt, if a negative prompt is specified.

        Args:
            prompt (str): the prompt

        Returns:
            Tuple[str, str]: the positive and negative prompt ("" if none)
        """
        pattern = getattr(shared.opts, "negPromptRegex")
        parts = re.split(pattern, prompt)
        # NOTE(review): only parts[0] and parts[1] are used. If the regex has a
        # capturing group, re.split returns the captured text as parts[1] and any
        # text after the tag (parts[2:]) is dropped — confirm this is intended.
        neg_prompt = parts[1] if len(parts) > 1 else ""
        return parts[0], neg_prompt

    all_batch_params: List[BatchParams] = []

    checkpoints: List[str] = utils.getCheckpointListFromInput(checkpoints_as_string)
    prompts: List[str] = utils.remove_index_from_string(prompts_as_string).split(";")
    # Drop empty / whitespace-only entries and flatten multi-line prompts.
    prompts = [prompt.replace('\n', '').strip() for prompt in prompts if prompt.strip()]

    if len(prompts) != len(checkpoints):
        logger.debug_log(f"len prompt: {len(prompts)}, len checkpoints: {len(checkpoints)}")
        raise RuntimeError("amount of prompts don't match with amount of checkpoints")

    if not prompts:
        raise RuntimeError("can't run without a checkpoint and prompt")

    # The placeholder in each prompt that is replaced by the UI prompt.
    _, prompt_regex = get_regex_from_opts("promptRegex", False)

    for checkpoint, prompt in zip(checkpoints, prompts):
        info = modules.sd_models.get_closet_checkpoint_match(checkpoint)
        if info is None:
            raise RuntimeError(f"Unknown checkpoint: {checkpoint}")

        # Each step removes its tag from the prompt before the next one runs.
        batch_count, prompt = get_batch_count_from_prompt(prompt)
        clip_skip, prompt = get_clip_skip_from_prompt(prompt)
        style, prompt = get_style_from_prompt(prompt)
        width, height, prompt = get_image_size_from_prompt(prompt)
        prompt_template, neg_prompt = split_positive_and_negative_prompt(prompt)

        batch_params = BatchParams(
            checkpoint,
            prompt_template.replace(prompt_regex, p.prompt),
            # NOTE(review): img2img processing objects may not define hr_prompt — confirm.
            prompt_template.replace(prompt_regex, p.hr_prompt),
            # NOTE(review): joined without a separator; users must supply one themselves.
            p.negative_prompt + neg_prompt,
            style,
            batch_count,
            clip_skip,
            width,
            height,
        )
        all_batch_params.append(batch_params)
        logger.debug_log(f"batch_params: {batch_params}", False)

    return all_batch_params