dikdimon's picture
Upload extensions using SD-Hub extension
f4a41d8 verified
raw
history blame
11.4 kB
# Python
import hashlib
import json
import math
import re
import string
from copy import copy
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, List
# SD-WebUI
from modules import images, processing, shared
from modules.processing import Processed
from modules.processing import StableDiffusionProcessingTxt2Img as SD_Proc
# Local
from sd_advanced_grid.grid_settings import AxisOption
from sd_advanced_grid.utils import clean_name, logger
# ################################### Types ################################## #
if TYPE_CHECKING:
from PIL import Image
AxisSet = dict[str, tuple[str, Any]]  # axis id -> (axis label, applied value), filled by apply_axes
# ################################# Constants ################################ #
# digits used by convert(): 0-9 then A-Z, i.e. base 36
CHAR_SET = string.digits + string.ascii_uppercase
# filename-pattern keywords whose value may change between runs; generate_filename strips
# them so that an already rendered cell keeps a predictable name
PROB_PATTERNS = ["date", "datetime", "job_timestamp", "batch_number", "generation_number", "seed"]
# ############################# Helper Functions ############################# #


def convert(num: int) -> str:
    """Convert a non-negative integer into its base-36 representation using ``CHAR_SET``.

    The result is left-padded with "0" to at least two characters, e.g.
    ``convert(0) == "00"``, ``convert(35) == "0Z"``, ``convert(36) == "10"``.

    Raises:
        ValueError: if ``num`` is negative (floor division would never reach 0,
            so the digit loop would not terminate).
    """
    if num < 0:
        raise ValueError(f"convert() expects a non-negative integer, got {num}")
    base = len(CHAR_SET)
    digits: list[str] = []
    while num:
        num, digit = divmod(num, base)
        digits.append(CHAR_SET[digit])
    # digits were produced least-significant first
    return "".join(reversed(digits)).zfill(2)
def file_exist(folder: Path, cell_id: str):
    """Tell whether ``folder`` already holds an output file for cell ``cell_id``.

    Any entry named ``adv_cell-<cell_id>-<name>.<ext>`` whose stem still starts with
    the ``adv_cell-<cell_id>-`` prefix counts as a match.
    """
    prefix = f"adv_cell-{cell_id}-"
    candidates = sorted(folder.glob(prefix + "*.*"))
    return any(candidate.stem.startswith(prefix) for candidate in candidates)
def generate_filename(proc: SD_Proc, axis_set: AxisSet, keep_origin: bool = False):
    """Generate the base filename (no directory, no extension) for a cell image.

    Args:
        proc: the cell's processing object, handed to WebUI's ``FilenameGenerator``.
        axis_set: axis id -> (axis label, applied value) for the cell.
        keep_origin: when true, expand the user's ``samples_filename_pattern``
            (``"[seed]-[prompt_spaces]"`` if unset) after removing every keyword listed
            in ``PROB_PATTERNS``. When false, return the MD5 hex digest of ``axis_set``
            serialized as JSON with sorted keys and a 2-space indent.
    """
    if not keep_origin:
        # in JS: md5(JSON.stringify(axis_set, Object.keys(axis_set).sort(), 2))
        # NOTE(review): json.dumps escapes non-ASCII characters (ensure_ascii=True) and
        # writes integral floats as "7.0"; JSON.stringify does neither, so confirm the
        # frontend digest matches for such values.
        encoded = json.dumps(axis_set, sort_keys=True, indent=2).encode("utf-8")
        return hashlib.md5(encoded).hexdigest()

    # use pattern defined by the user
    # FIXME: how will the frontend know about the filename?
    # "[keyword]" optionally followed by one or more "<argument>" groups. Each argument is
    # bounded by "<...>", so a match can never run across neighbouring patterns. The former
    # greedy "<.+>" swallowed everything up to the last ">]", which hid any keyword in
    # between from the filter below.
    re_pattern = re.compile(r"(\[([^\[\]<>]+)(?:<[^<>\[\]]*>)*\])")
    width, height = proc.width, proc.height
    namegen = images.FilenameGenerator(proc, proc.seed, proc.prompt, {"width": width, "height": height})
    filename_pattern = shared.opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
    # remove patterns that may prevent existence detection
    for match in re_pattern.finditer(filename_pattern):
        pattern, keyword = match.groups()
        # NOTE(review): WebUI's FilenameGenerator looks keywords up in lowercase, so
        # "[Seed]" also resolves to the seed; compare case-insensitively.
        if keyword.lower() in PROB_PATTERNS:
            filename_pattern = filename_pattern.replace("-" + pattern, "").replace("_" + pattern, "").replace(pattern, "")
    return f"{namegen.apply(filename_pattern)}"
def apply_axes(set_proc: SD_Proc, axes_settings: list[AxisOption]):
    """Apply the current value of every axis to ``set_proc``, then advance the axes.

    Axes are applied cheapest first (sorted by ``axis.cost``; the sort is stable, so ties
    keep their original order). Advancing works like an odometer: the cheapest axis is
    stepped with ``axis.next()``, and each following axis is stepped only while the
    previous ``next()`` call returned a falsy value.

    Returns:
        ``(axis_set, axis_code, errors)`` where
        - ``axis_set`` maps ``axis.id`` to ``(axis.label, axis.value)`` for every axis
          whose ``apply`` succeeded;
        - ``axis_code`` concatenates ``convert(axis.index + 1)`` for every axis, in
          reverse order of ``axes_settings``;
        - ``errors`` lists the ``RuntimeError``s raised by ``axis.apply``.
    """
    # Capture each axis's position in ``axes_settings`` before reordering. Looking it up
    # afterwards with ``list.index`` depended on ``__eq__`` (axes comparing equal would
    # share one code slot) and rescanned the list for every axis.
    ordered = sorted(enumerate(axes_settings), key=lambda pair: pair[1].cost)  # reorder to avoid heavy changes
    errors: list[Exception] = []
    axis_set: AxisSet = {}
    axis_code = ["00"] * len(axes_settings)
    should_iter = True
    # self.proc.styles = self.proc.styles[:] # allows for multiple styles axis
    for position, axis in ordered:
        axis_code[position] = convert(axis.index + 1)
        try:
            axis.apply(set_proc)
        except RuntimeError as err:
            errors.append(err)
        else:
            axis_set[axis.id] = (axis.label, axis.value)
        if should_iter:
            # keep stepping to the next axis only while next() reports a falsy value
            should_iter = not axis.next()
    return axis_set, "".join(reversed(axis_code)), errors
def prepare_jobs(adv_proc: SD_Proc, axes_settings: list[AxisOption], jobs: int, name: str, batches: int = 1):
    """Create one ``GridCell`` per axis combination, each with its own processing copy.

    Runs ``jobs`` iterations. Each one shallow-copies the base processing object, gives
    it its own ``override_settings`` and ``extra_generation_params`` (tagged with
    ``"Adv. Grid": name``), then applies and advances the axes through ``apply_axes``.
    Combinations whose axes raised errors are logged and dropped, so the result may
    hold fewer than ``jobs`` cells.

    ``batches`` is currently unused. Batching would only be possible with prompt,
    negative_prompt, seeds or subseed axes.
    """
    # Resolve a random seed (-1) once, on a private copy so ``adv_proc`` is untouched,
    # so that every cell shares it. Calling fix_seed on each cell drew a new random seed
    # per cell, and cells then differed by more than their axis values.
    base_proc = copy(adv_proc)
    processing.fix_seed(base_proc)

    cells: list[GridCell] = []
    for _ in range(jobs):
        set_proc = copy(base_proc)
        # shallow copy shares these containers with adv_proc; give each cell its own
        set_proc.override_settings = copy(adv_proc.override_settings)
        set_proc.extra_generation_params = copy(set_proc.extra_generation_params)
        set_proc.extra_generation_params["Adv. Grid"] = name
        axis_set, axis_code, errors = apply_axes(set_proc, axes_settings)
        if errors:
            logger.debug(f"Detected issues for {axis_code}:", errors)
            # TODO: option to break here
            continue
        cells.append(GridCell(axis_code, set_proc, axis_set))
    return cells
def combine_processed(processed_result: Processed, processed: Processed):
    """Append one cell's ``processed`` output to ``processed_result`` for a single WebUI display.

    On the first call (``index_of_first_image == 0``) the per-image lists of
    ``processed_result`` are emptied and ``index_of_first_image`` is set to 1, turning
    the object into the container for all cells. The lists of ``processed`` are then
    appended only if it holds at least one image.

    Returns:
        ``processed_result``, modified in place.
    """
    merged_attrs = ("images", "all_prompts", "all_negative_prompts", "all_seeds", "all_subseeds", "infotexts")
    if processed_result.index_of_first_image == 0:
        # first merge: recycle the template object as an empty container
        for attr_name in merged_attrs:
            setattr(processed_result, attr_name, [])
        processed_result.index_of_first_image = 1
    if not processed.images:
        # an empty image list means this cell produced nothing worth merging
        return processed_result
    for attr_name in merged_attrs:
        getattr(processed_result, attr_name).extend(getattr(processed, attr_name))
    return processed_result
# ####################### Logic For Individual Variant ####################### #
@dataclass
class GridCell:
    """One variant of the grid: a processing copy plus the axis values applied to it."""

    # init
    cell_id: str  # code built by apply_axes from each axis's current index
    proc: SD_Proc  # processing copy with this cell's axis values applied
    axis_set: dict[str, tuple[str, Any]]  # axis id -> (axis label, applied value)
    processed: Processed = field(init=False)  # only assigned by run() after a successful save
    job_count: int = field(init=False, default=1)  # WebUI jobs this cell accounts for
    skipped: bool = field(init=False, default=False)  # set when rendering was skipped
    failed: bool = field(init=False, default=False)  # set when no image was produced

    def __post_init__(self):
        if self.proc.enable_hr:
            # the hires. fix second pass is counted as one more job
            # NOTE: there might be some extensions that add jobs
            self.job_count *= 2

    def run(self, save_to: Path, overwrite: bool = False, for_web: bool = False):
        """Render this cell into ``save_to`` unless it is already there.

        If a file for ``cell_id`` exists in ``save_to`` and ``overwrite`` is false, the
        cell is marked ``skipped`` and the progress counters are advanced without
        rendering. Otherwise ``process_images`` runs. On success the first image is saved
        as ``adv_cell-<cell_id>-<base name>.<samples_format>``, ``processed.images[0]`` is
        replaced by that path, and ``self.processed`` is set. The base name comes from the
        axis-set hash when ``for_web`` is true (plus a PNG thumbnail fitting 512x512 in
        ``save_to.parent / "thumbnails"``), and from the user's filename pattern otherwise.

        ``failed`` is set when no image was produced. A skip requested by WebUI sets
        ``skipped``. On interruption the method returns without setting either flag.
        """
        total_steps = self.proc.steps + (
            (self.proc.hr_second_pass_steps or self.proc.steps) if self.proc.enable_hr else 0
        )
        if file_exist(save_to, self.cell_id) and not overwrite:
            # pylint: disable=protected-access
            self.skipped = True
            if shared.total_tqdm._tqdm:
                # update console progress bar
                shared.total_tqdm._tqdm.update(total_steps)
            shared.state.nextjob()
            if self.proc.enable_hr:
                # NOTE: not sure if this is needed or automatic, progress bar update is finicky
                shared.state.nextjob()
            logger.debug(f"Skipping cell #{self.cell_id}, file already exist.")
            return

        logger.info(
            f"Running image generation for cell {self.cell_id} with the following attributes:",
            [f"{label}: {value}" for label, value in self.axis_set.values()],
        )

        # All the magic happens here
        processed = None
        try:
            processed = processing.process_images(self.proc)
        except Exception as err:  # pylint: disable=broad-except
            # One broken cell must not abort the whole grid. Unlike the former bare
            # `except:`, this lets KeyboardInterrupt/SystemExit propagate and records the cause.
            logger.error(f"Skipping cell #{self.cell_id} due to a rendering error: {err}")

        if shared.state.interrupted:
            return

        if shared.state.skipped:
            # pylint: disable=protected-access
            self.skipped = True
            shared.state.skipped = False
            if shared.total_tqdm._tqdm:
                # update console progress bar (to be tested)
                shared.total_tqdm._tqdm.update(total_steps - shared.state.sampling_step)
            logger.warn(f"Skipping cell #{self.cell_id}, requested by the system.")
            return

        if not processed or not processed.images or not any(processed.images):
            logger.warn(f"No images were generated for cell #{self.cell_id}")
            self.failed = True
            return

        base_name = generate_filename(self.proc, self.axis_set, not for_web)
        file_name = f"adv_cell-{self.cell_id}-{base_name}"
        file_ext = shared.opts.samples_format
        file_path = save_to.joinpath(f"{file_name}.{file_ext}")

        info_text = processing.create_infotext(
            self.proc, self.proc.all_prompts, self.proc.all_seeds, self.proc.all_subseeds
        )
        processed.infotexts[0] = info_text
        image: Image.Image = processed.images[0]
        images.save_image(
            image,
            path=str(file_path.parent),
            basename="",
            info=info_text,
            forced_filename=file_path.stem,
            extension=file_path.suffix[1:],
            save_to_dirs=False,
        )
        # keep the saved path instead of the in-memory image in the combined result
        processed.images[0] = str(file_path)
        self.processed = processed
        # image.thumbnail((512, 512)) # could be useful to reduce memory usage (need testing)

        if for_web:
            # create and save thumbnail
            file_path = save_to.parent.joinpath("thumbnails", f"{file_name}.png")
            file_path.parent.mkdir(parents=True, exist_ok=True)
            thumb = image.copy()
            thumb.thumbnail((512, 512))
            thumb.save(file_path)

        logger.debug(f"Cell {self.cell_id} saved as {file_path.stem}")
# ########################## Generation Entry Point ########################## #
def generate_grid(adv_proc: SD_Proc, grid_name: str, overwrite: bool, batches: int, test: bool, axes: list[AxisOption], for_web=False):
    """Generate every variant of an advanced grid and return the combined result.

    Writes ``config.json`` (grid name, base parameters from ``processed.js()``, and each
    axis's ``dict()``) into ``<outpath_grids>/adv_grid_<clean name>``. It then renders
    every cell into the ``images`` sub-folder. With ``test`` set, it returns right after
    writing the config, without generating anything.

    Returns:
        A ``Processed`` holding the images and infotexts of every rendered cell.
        Skipped and failed cells are left out, and generation stops at the first
        interruption.
    """
    grid_path = Path(adv_proc.outpath_grids, f"adv_grid_{clean_name(grid_name)}")
    processed = Processed(adv_proc, [], adv_proc.seed, "", adv_proc.subseed)
    # upper bound: prepare_jobs drops combinations whose axes fail to apply
    aprox_jobs = math.prod(axis.length for axis in axes)
    cells = prepare_jobs(adv_proc, axes, aprox_jobs, grid_name, batches)

    grid_path.mkdir(parents=True, exist_ok=True)
    grid_data = {
        "name": grid_name,
        "params": json.loads(processed.js()),
        "axis": [axis.dict() for axis in axes],
        # "cells": [{ "id": cell.cell_id, "set": cell.axis_set } for cell in cells] # for testing only
    }
    with grid_path.joinpath("config.json").open(mode="w", encoding="UTF-8") as file:
        file.write(json.dumps(grid_data, indent=2))

    if test:
        return processed

    shared.state.job_count = sum(cell.job_count for cell in cells)
    shared.state.processing_has_refined_job_count = True

    logger.info(f"Starting generation of {len(cells)} variants")
    for number, cell in enumerate(cells, start=1):
        job_info = f"Generating variant #{number} out of {len(cells)} - "
        shared.state.textinfo = job_info  # type: ignore
        shared.state.job = job_info  # seems to be unused
        try:
            cell.run(save_to=grid_path.joinpath("images"), overwrite=overwrite, for_web=for_web)
        finally:
            # release the cell's processing object even when run() raises
            cell.proc.close()
        if shared.state.interrupted:
            logger.warn("Process interupted. Cancelling all jobs.")
            break
        if not cell.skipped and not cell.failed:
            combine_processed(processed, cell.processed)
    return processed