import math
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from PIL import Image
from torch import Tensor, nn


@dataclass
class Heatmap:
    """A labelled, scored image (presumably one heatmap per label -- confirm with callers)."""

    label: str
    score: float
    image: Image.Image


@dataclass
class LabelData:
    """Tag names plus, per category, the row indices of the tags in that category."""

    # tag names, in the row order of selected_tags.csv
    names: list[str]
    # indices into `names` of the rows whose category is 9 / 0 / 4 respectively
    rating: list[np.int64]
    general: list[np.int64]
    character: list[np.int64]


@dataclass
class ImageLabels:
    """Labels for one image: two string renderings and per-category score maps."""

    # NOTE(review): the exact formatting of `caption` and `booru` is decided by
    # code outside this section -- confirm before relying on it.
    caption: str
    booru: str
    # label name -> score, one mapping per category
    rating: dict[str, float]
    general: dict[str, float]
    character: dict[str, float]


@lru_cache(maxsize=5)
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """
    Download ``selected_tags.csv`` from a Hugging Face repo and index its tags.

    Results are memoised per ``(repo_id, revision, token)``; the cached
    ``LabelData`` instance is shared between callers, so treat it as read-only.

    Args:
        repo_id: Hub repository that contains ``selected_tags.csv``.
        revision: Optional revision passed through to ``hf_hub_download``.
        token: Optional access token passed through to ``hf_hub_download``.

    Returns:
        The tag names and the row indices of the rating (category 9),
        general (category 0) and character (category 4) tags.

    Raises:
        FileNotFoundError: If the download fails with an ``HfHubHTTPError``.
    """
    # keep the try body to the one call that can raise HfHubHTTPError
    try:
        downloaded = hf_hub_download(
            repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
        )
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e
    csv_path = Path(downloaded).resolve()

    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    tag_data = LabelData(
        names=df["name"].tolist(),
        rating=list(np.where(df["category"] == 9)[0]),
        general=list(np.where(df["category"] == 0)[0]),
        character=list(np.where(df["category"] == 4)[0]),
    )

    return tag_data


def mcut_threshold(probs: np.ndarray) -> float:
    """
    Maximum Cut Thresholding (MCut)
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    Sorts the probabilities in descending order, finds the largest gap between
    two neighbours and returns the midpoint of that gap.

    Args:
        probs: 1-D array holding at least two probabilities.

    Returns:
        The threshold, as a Python float.

    Raises:
        ValueError: If fewer than two probabilities are given (there is no gap
            to cut at).
    """
    if probs.size < 2:
        # previously surfaced as an opaque "argmax of an empty sequence" error
        raise ValueError("mcut_threshold requires at least two probabilities")

    probs = probs[probs.argsort()[::-1]]
    diffs = probs[:-1] - probs[1:]
    idx = diffs.argmax()
    thresh = (probs[idx] + probs[idx + 1]) / 2
    return float(thresh)


def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """
    Return ``image`` as an RGB image, flattening any transparency onto white.

    Args:
        image: Image in any mode.

    Returns:
        An RGB image; the input itself when it is already RGB.
    """
    # convert to RGB/RGBA if not already (deals with palette images etc.)
    if image.mode not in ["RGB", "RGBA"]:
        # "LA"/"PA" carry an alpha band without a "transparency" info entry;
        # route them through RGBA too so they get composited rather than
        # having their alpha silently dropped.
        has_alpha = "transparency" in image.info or image.mode in ("LA", "PA")
        image = image.convert("RGBA") if has_alpha else image.convert("RGB")
    # convert RGBA to RGB with white background
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image


def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    """
    Centre ``image`` on a new square RGB canvas whose side is its longest edge.

    Args:
        image: Image to pad.
        fill: RGB colour of the padding (white by default).

    Returns:
        A new square RGB image; the input is not modified.
    """
    w, h = image.size
    # get the largest dimension so we can pad to a square
    px = max(image.size)
    # pad to square with the fill colour
    canvas = Image.new("RGB", (px, px), fill)
    canvas.paste(image, ((px - w) // 2, (px - h) // 2))
    return canvas


def preprocess_image(
    image: Image.Image,
    size_px: int | tuple[int, int],
    upscale: bool = True,
) -> Image.Image:
    """
    Preprocess an image to be square and centered on a white background.

    Args:
        image: Source image in any mode.
        size_px: Target size; an int means a square of that side.
        upscale: Whether an image smaller than the target may be enlarged.

    Returns:
        A new RGB image resized to fit the target size.

    Raises:
        ValueError: If the padded image is smaller than the target and
            ``upscale`` is False.
    """
    if isinstance(size_px, int):
        size_px = (size_px, size_px)

    # ensure RGB and pad to square
    image = pil_ensure_rgb(image)
    image = pil_pad_square(image)

    # resize to target size
    if image.size[0] < size_px[0] or image.size[1] < size_px[1]:
        if upscale is False:
            raise ValueError("Image is smaller than target size, and upscaling is disabled")
        # NOTE: resize() stretches to exactly size_px, so a non-square target
        # distorts the (square) padded image.
        image = image.resize(size_px, Image.LANCZOS)
    if image.size[0] > size_px[0] or image.size[1] > size_px[1]:
        # thumbnail() shrinks in place; `image` is the padded copy made above,
        # so the caller's image is untouched.
        image.thumbnail(size_px, Image.BICUBIC)

    return image


def pil_make_grid(
    images: list[Image.Image],
    max_cols: int = 8,
    padding: int = 4,
    bg_color: tuple[int, int, int] = (40, 42, 54),  # dracula background color
    partial_rows: bool = True,
) -> Image.Image:
    """
    Tile equally-sized images, row by row, into one padded grid image.

    The column count is ``floor(sqrt(len(images)))`` capped at ``max_cols``.

    Args:
        images: Images to tile; all are assumed to share the first one's size.
        max_cols: Upper bound on the number of columns (at least 1).
        padding: Gap in pixels between cells and around the grid border.
        bg_color: RGB colour of the canvas behind the images.
        partial_rows: If False, an incomplete final row is dropped, along with
            the images that would have been in it.

    Returns:
        A new RGB image containing the grid.

    Raises:
        ValueError: If ``images`` is empty or ``max_cols`` is less than 1
            (both previously failed with a ZeroDivisionError).
    """
    if not images:
        raise ValueError("images must contain at least one image")
    if max_cols < 1:
        raise ValueError("max_cols must be at least 1")

    n_images = len(images)
    n_cols = min(math.floor(math.sqrt(n_images)), max_cols)
    n_rows = math.ceil(n_images / n_cols)

    # if the final row is not full and partial_rows is False, remove a row
    if n_cols * n_rows > n_images and not partial_rows:
        n_rows -= 1

    # assumes all images are same size
    image_width, image_height = images[0].size

    canvas_width = ((image_width + padding) * n_cols) + padding
    canvas_height = ((image_height + padding) * n_rows) + padding

    canvas = Image.new("RGB", (canvas_width, canvas_height), bg_color)
    # only paste what fits; images of a dropped partial row lie outside the canvas
    for i, img in enumerate(images[: n_cols * n_rows]):
        x = (i % n_cols) * (image_width + padding) + padding
        y = (i // n_cols) * (image_height + padding) + padding
        canvas.paste(img, (x, y))

    return canvas


# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
# Emoticon-style names that contain underscores.
# NOTE(review): the bare "_" entry looks like a mangled copy of "<o>_<o>" from
# the upstream list linked above -- confirm against upstream before changing it.
kaomojis = [
    "0_0",
    "(o)_(o)",
    "+_+",
    "+_-",
    "._.",
    "_",
    "<|>_<|>",
    "=_=",
    ">_<",
    "3_3",
    "6_9",
    ">_o",
    "@_@",
    "^_^",
    "o_o",
    "u_u",
    "x_x",
    "|_|",
    "||_||",
]