import hashlib
import json
import math
import re
import string
from copy import copy
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any

from modules import images, processing, shared
from modules.processing import Processed
from modules.processing import StableDiffusionProcessingTxt2Img as SD_Proc

from sd_advanced_grid.grid_settings import AxisOption
from sd_advanced_grid.utils import clean_name, logger

if TYPE_CHECKING:
    from PIL import Image

# mapping of axis id -> (axis label, applied value)
AxisSet = dict[str, tuple[str, Any]]

# alphabet used to encode an axis index as part of a cell id
CHAR_SET = string.digits + string.ascii_uppercase
# pattern keywords stripped from filenames because their values vary between runs
PROB_PATTERNS = ["date", "datetime", "job_timestamp", "batch_number", "generation_number", "seed"]


def convert(num: int) -> str:
    """Convert a decimal number into a base-36 alphanumeric value, zero-padded to two characters."""
    base = len(CHAR_SET)
    converted = ""
    while num:
        digit = num % base
        converted += CHAR_SET[digit]
        num //= base
    return converted[::-1].zfill(2)
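
# Example: convert(1) -> "01", convert(35) -> "0Z", convert(36) -> "10";
# apply_axes() concatenates one such two-character code per axis to build a cell id.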


def file_exist(folder: Path, cell_id: str) -> bool:
    """Check whether an image for this cell id has already been saved to the folder."""
    # the glob pattern already restricts matches to the cell id prefix
    return any(folder.glob(f"adv_cell-{cell_id}-*.*"))


def generate_filename(proc: SD_Proc, axis_set: AxisSet, keep_origin: bool = False) -> str:
    """Generate a filename for each image based on the data to be processed."""
    if keep_origin:
        # reuse the WebUI filename pattern, minus keywords that are not reproducible between runs
        re_pattern = re.compile(r"(\[([^\[\]<>]+)(?:<.+>|)\])")
        width, height = proc.width, proc.height
        namegen = images.FilenameGenerator(proc, proc.seed, proc.prompt, {"width": width, "height": height})
        filename_pattern = shared.opts.samples_filename_pattern or "[seed]-[prompt_spaces]"

        for match in re_pattern.finditer(filename_pattern):
            pattern, keyword = match.groups()
            if keyword in PROB_PATTERNS:
                filename_pattern = (
                    filename_pattern.replace("-" + pattern, "").replace("_" + pattern, "").replace(pattern, "")
                )

        file_name = f"{namegen.apply(filename_pattern)}"
    else:
        # hash the axis values so that identical settings always map to the same file
        encoded = json.dumps(axis_set, sort_keys=True, indent=2).encode("utf-8")
        file_name = hashlib.md5(encoded).hexdigest()

    return file_name
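
# For example, with keep_origin=False a cell file ends up named something like
# "adv_cell-0103-<md5 hexdigest>.png" once GridCell.run() prepends the cell id
# (the extension comes from shared.opts.samples_format).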


def apply_axes(set_proc: SD_Proc, axes_settings: list[AxisOption]):
    """
    Run through each axis to apply its currently active value,
    then select the next available value on an axis.
    """
    axes = axes_settings.copy()
    axes.sort(key=lambda axis: axis.cost)  # apply the cheapest axes first

    excs: list[Exception] = []
    axis_set: AxisSet = {}
    axis_code = ["00"] * len(axes_settings)
    should_iter = True

    for axis in axes:
        axis_code[axes_settings.index(axis)] = convert(axis.index + 1)
        try:
            axis.apply(set_proc)
        except RuntimeError as err:
            excs.append(err)
        else:
            axis_set[axis.id] = (axis.label, axis.value)
        if should_iter:
            should_iter = not axis.next()
    return axis_set, "".join(axis_code[::-1]), excs
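
# A sketch of the iteration, assuming AxisOption.next() returns True while the
# axis still has values left before wrapping around: with two axes of lengths 2
# and 3, successive calls walk all 6 combinations like an odometer, with the
# cheapest axis changing fastest.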


def prepare_jobs(adv_proc: SD_Proc, axes_settings: list[AxisOption], jobs: int, name: str, batches: int = 1):
    """Create a dedicated processing instance for each variation, with different axis values."""

    if batches > 1:
        # NOTE: batches > 1 is currently a no-op; the argument is accepted but unused
        pass

    cells: list[GridCell] = []

    for _ in range(jobs):
        set_proc = copy(adv_proc)
        processing.fix_seed(set_proc)
        # copy mutable attributes so each cell can be tweaked independently
        set_proc.override_settings = copy(adv_proc.override_settings)
        set_proc.extra_generation_params = copy(set_proc.extra_generation_params)
        set_proc.extra_generation_params["Adv. Grid"] = name
        axis_set, axis_code, errors = apply_axes(set_proc, axes_settings)
        if errors:
            logger.debug(f"Detected issues for {axis_code}:", errors)
            continue
        cell = GridCell(axis_code, set_proc, axis_set)
        cells.append(cell)

    return cells
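
# generate_grid() below sizes the `jobs` argument as math.prod of the axis
# lengths: one processing instance per combination of axis values.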


def combine_processed(processed_result: Processed, processed: Processed):
    """Combine all processed data to allow a single display in the SD WebUI."""
    if processed_result.index_of_first_image == 0:
        # first call: clear out the placeholder data from the accumulator
        processed_result.images = []
        processed_result.all_prompts = []
        processed_result.all_negative_prompts = []
        processed_result.all_seeds = []
        processed_result.all_subseeds = []
        processed_result.infotexts = []
        processed_result.index_of_first_image = 1

    if processed.images:
        processed_result.images.extend(processed.images)
        processed_result.all_prompts.extend(processed.all_prompts)
        processed_result.all_negative_prompts.extend(processed.all_negative_prompts)
        processed_result.all_seeds.extend(processed.all_seeds)
        processed_result.all_subseeds.extend(processed.all_subseeds)
        processed_result.infotexts.extend(processed.infotexts)

    return processed_result
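
# Typical usage (see generate_grid below): fold each finished cell into one
# grid-level Processed so the WebUI shows a single combined result.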


@dataclass
class GridCell:
    """A single grid variation: one processing instance plus the axis values that produced it."""

    cell_id: str
    proc: SD_Proc
    axis_set: AxisSet
    processed: Processed = field(init=False)
    job_count: int = field(init=False, default=1)
    skipped: bool = field(init=False, default=False)
    failed: bool = field(init=False, default=False)

    def __post_init__(self):
        if self.proc.enable_hr:
            # the hires-fix second pass counts as an extra job
            self.job_count *= 2

    def run(self, save_to: Path, overwrite: bool = False, for_web: bool = False):
        """Render the image for this cell, save it to disk, and store the Processed result."""

        total_steps = self.proc.steps + (
            (self.proc.hr_second_pass_steps or self.proc.steps) if self.proc.enable_hr else 0
        )

        if file_exist(save_to, self.cell_id) and not overwrite:
            # keep the progress bar and the job counter in sync when skipping
            self.skipped = True
            if shared.total_tqdm._tqdm:
                shared.total_tqdm._tqdm.update(total_steps)
            shared.state.nextjob()
            if self.proc.enable_hr:
                shared.state.nextjob()
            logger.debug(f"Skipping cell #{self.cell_id}, file already exists.")
            return

        logger.info(
            f"Running image generation for cell {self.cell_id} with the following attributes:",
            [f"{label}: {value}" for label, value in self.axis_set.values()],
        )

        processed = None
        try:
            processed = processing.process_images(self.proc)
        except Exception:
            logger.error(f"Skipping cell #{self.cell_id} due to a rendering error.")

        if shared.state.interrupted:
            return

        if shared.state.skipped:
            self.skipped = True
            shared.state.skipped = False  # reset the flag for the next cell
            if shared.total_tqdm._tqdm:
                shared.total_tqdm._tqdm.update(total_steps - shared.state.sampling_step)
            logger.warn(f"Skipping cell #{self.cell_id}, requested by the system.")
            return

        if not processed or not processed.images or not any(processed.images):
            logger.warn(f"No images were generated for cell #{self.cell_id}")
            self.failed = True
            return

        base_name = generate_filename(self.proc, self.axis_set, not for_web)
        file_name = f"adv_cell-{self.cell_id}-{base_name}"
        file_ext = shared.opts.samples_format
        file_path = save_to.joinpath(f"{file_name}.{file_ext}")

        info_text = processing.create_infotext(
            self.proc, self.proc.all_prompts, self.proc.all_seeds, self.proc.all_subseeds
        )
        processed.infotexts[0] = info_text
        image: Image.Image = processed.images[0]

        images.save_image(
            image,
            path=str(file_path.parent),
            basename="",
            info=info_text,
            forced_filename=file_path.stem,
            extension=file_path.suffix[1:],
            save_to_dirs=False,
        )

        # keep the saved path rather than the in-memory image
        processed.images[0] = str(file_path)
        self.processed = processed

        if for_web:
            # also write a 512px thumbnail for the web viewer
            file_path = save_to.parent.joinpath("thumbnails", f"{file_name}.png")
            file_path.parent.mkdir(parents=True, exist_ok=True)
            thumb = image.copy()
            thumb.thumbnail((512, 512))
            thumb.save(file_path)

        logger.debug(f"Cell {self.cell_id} saved as {file_path.stem}")
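
# Each cell therefore produces images/adv_cell-<id>-<name>.<ext> and, when
# for_web is set, thumbnails/adv_cell-<id>-<name>.png next to the images folder.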


def generate_grid(
    adv_proc: SD_Proc,
    grid_name: str,
    overwrite: bool,
    batches: int,
    test: bool,
    axes: list[AxisOption],
    for_web: bool = False,
):
    grid_path = Path(adv_proc.outpath_grids, f"adv_grid_{clean_name(grid_name)}")

    processed = Processed(adv_proc, [], adv_proc.seed, "", adv_proc.subseed)

    # one job per combination of axis values
    approx_jobs = math.prod(axis.length for axis in axes)
    cells = prepare_jobs(adv_proc, axes, approx_jobs, grid_name, batches)

    grid_path.mkdir(parents=True, exist_ok=True)
    grid_data = {
        "name": grid_name,
        "params": json.loads(processed.js()),
        "axis": [axis.dict() for axis in axes],
    }

    with grid_path.joinpath("config.json").open(mode="w", encoding="UTF-8") as file:
        file.write(json.dumps(grid_data, indent=2))

    if test:
        # dry run: the config is written but no image is generated
        return processed

    shared.state.job_count = sum(cell.job_count for cell in cells)
    shared.state.processing_has_refined_job_count = True

    logger.info(f"Starting generation of {len(cells)} variants")

    for i, cell in enumerate(cells):
        job_info = f"Generating variant #{i + 1} out of {len(cells)} - "
        shared.state.textinfo = job_info
        shared.state.job = job_info
        cell.run(save_to=grid_path.joinpath("images"), overwrite=overwrite, for_web=for_web)
        cell.proc.close()
        if shared.state.interrupted:
            logger.warn("Process interrupted. Cancelling all jobs.")
            break
        if not cell.skipped and not cell.failed:
            combine_processed(processed, cell.processed)

    return processed