dikdimon's picture
Upload extensions using SD-Hub extension
f4a41d8 verified
raw
history blame
9.2 kB
"""This module provides utility functions."""
from scripts.Logger import Logger
import os
import re
import requests
from typing import List
import modules
from modules.sd_models import read_state_dict
from modules.sd_models_config import (find_checkpoint_config, config_default, config_sd2, config_sd2v, config_sd2_inpainting,
config_depth_model, config_unclip, config_unopenclip, config_inpainting, config_instruct_pix2pix, config_alt_diffusion)
import sys
# Prepend "<directory containing this file>/scripts" to the module search
# path so modules placed there can be imported by bare name.
# NOTE(review): this runs after the imports above, so it does not affect
# how `scripts.Logger` is resolved — confirm which later imports rely on it.
sys.path.insert(
    0,
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts"),
)
class Utils():
    """
    Helper methods shared by several classes of the extension.

    Checkpoint lists are comma-separated and prompt lists are
    semicolon-separated. Entries may carry inline ``@index:<n>`` and
    ``@version:<tag>`` markers and a `` [<hash>]`` suffix, which the
    methods below add or strip.
    """

    # File name of the help document inside the extension root.
    _HELP_MD_NAME = "HelpBatchCheckpointsPrompt.md"
    # Raw download location of the same file in the upstream repository.
    _HELP_MD_URL = ("https://raw.githubusercontent.com/h43lb1t0/"
                    "BatchCheckpointPrompt/main/" + _HELP_MD_NAME)

    # Every tag written by add_model_version_to_string.
    _VERSION_TAGS = ("sd1", "sd2", "sd2v", "sd2-inpainting", "depth",
                     "unclip", "unopenclip", "sd1-inpainting", "pix2pix", "alt")
    # Regex alternation is tried left to right, so longer tags must come
    # first: otherwise "sd2" matches inside "@version:sd2v" and leaves a
    # stray "v" behind (likewise "sd1" vs "sd1-inpainting").
    _VERSION_PATTERN = re.compile(
        "@version:(?:"
        + "|".join(map(re.escape,
                       sorted(_VERSION_TAGS, key=len, reverse=True)))
        + ")")

    def __init__(self) -> None:
        self.logger = Logger()
        self.logger.debug = False
        # Parent of the directory that contains this file.
        script_path = os.path.dirname(
            os.path.dirname(os.path.abspath(__file__)))
        self.held_md_file_name = os.path.join(script_path, self._HELP_MD_NAME)
        # Built from the bare file name: the local path above is absolute and
        # already ends in ".md", so it must not be interpolated into the URL.
        self.held_md_url = self._HELP_MD_URL

    def split_prompts(self, text: str) -> List[str]:
        """Split the prompts by ``;``, dropping blank entries and newlines.

        Args:
            text (str): the input string

        Returns:
            List[str]: the prompts, each with newlines removed and
            surrounding whitespace stripped
        """
        # `prompt.strip()` is falsy exactly for empty or all-whitespace parts.
        return [prompt.replace("\n", "").strip()
                for prompt in text.split(";") if prompt.strip()]

    def remove_index_from_string(self, input: str) -> str:
        """Remove every ``@index:<n>`` marker from the string.

        Args:
            input (str): the input string

        Returns:
            str: the string without index markers, stripped
        """
        return re.sub(r"@index:\d+", "", input).strip()

    def remove_model_version_from_string(self, checkpoints_text: str) -> str:
        """Remove every known ``@version:<tag>`` marker from the string.

        Args:
            checkpoints_text (str): the input string with all checkpoints

        Returns:
            str: the string without the model version markers
        """
        return self._VERSION_PATTERN.sub("", checkpoints_text)

    def get_clean_checkpoint_path(self, checkpoint: str) -> str:
        """Remove the `` [<hash>]`` suffix from the checkpoint name(s).

        Args:
            checkpoint (str): the input string with hash

        Returns:
            str: the string without the hash, stripped
        """
        return re.sub(r' \[.*?\]', '', checkpoint).strip()

    def getCheckpointListFromInput(self, checkpoints_text: str, clean: bool = True) -> List[str]:
        """Get a list of checkpoints from the comma-separated input string.

        Version markers are always removed.

        Args:
            checkpoints_text (str): the input string with all checkpoints
            clean (bool): also remove the index markers and hashes

        Returns:
            List[str]: a list of checkpoints, blank entries dropped
        """
        self.logger.debug_log(f"checkpoints: {checkpoints_text}")
        checkpoints_text = self.remove_model_version_from_string(checkpoints_text)
        if clean:
            checkpoints_text = self.remove_index_from_string(checkpoints_text)
            checkpoints_text = self.get_clean_checkpoint_path(checkpoints_text)
        return [checkpoint.replace("\n", "").strip()
                for checkpoint in checkpoints_text.split(",")
                if checkpoint.strip()]

    def get_help_md(self) -> str:
        """Get the help markdown file.

        If the file is not found locally, it is downloaded from the GitHub
        repository and saved to ``self.held_md_file_name``. Network and
        file-system errors are logged and answered with a fallback message
        instead of being raised.

        Returns:
            str: the help file contents, or a fallback message
        """
        md = "could not get help file. Check Github for more information"
        if not os.path.isfile(self.held_md_file_name):
            self.logger.debug_log("downloading help md")
            try:
                # Timeout so an unreachable host cannot hang the caller.
                result = requests.get(self.held_md_url, timeout=10)
            except requests.RequestException as err:
                self.logger.debug_log(f"could not download help md: {err}")
                return md
            if result.status_code != 200:
                self.logger.debug_log(
                    f"could not download help md: HTTP {result.status_code}")
                return md
            try:
                with open(self.held_md_file_name, "wb") as file:
                    file.write(result.content)
            except OSError as err:
                # Caching failed; still show what was downloaded.
                self.logger.debug_log(f"could not save help md: {err}")
                return result.text
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # may not decode the UTF-8 file.
        with open(self.held_md_file_name, encoding="utf-8",
                  errors="replace") as f:
            return f.read()

    def add_index_to_string(self, text: str, is_checkpoint: bool = True) -> str:
        """Append a fresh ``@index:<n>`` marker to every entry.

        Args:
            text (str): the input string
            is_checkpoint (bool): True if the string is a checkpoint list,
                False if it is a prompt list

        Returns:
            str: the entries with index markers, one per line
            (checkpoints end in ",\\n", prompts in ";\\n\\n")
        """
        if is_checkpoint:
            return "".join(
                f"{self.remove_index_from_string(checkpoint)} @index:{i},\n"
                for i, checkpoint in enumerate(
                    self.getCheckpointListFromInput(text)))
        return "".join(
            f"{self.remove_index_from_string(prompt)} @index:{i};\n\n"
            for i, prompt in enumerate(self.split_prompts(text)))

    def add_model_version_to_string(self, checkpoints_text: str) -> str:
        """Add the model version to every checkpoint in the string.
        EXPERIMENTAL!

        Each checkpoint's state dict is loaded to detect its config. Entries
        for which no checkpoint can be found are kept without a version tag.

        Args:
            checkpoints_text (str): the input string with all checkpoints

        Returns:
            str: the string with the model version
        """
        # (config, tag) pairs checked in order; the first match wins, as in
        # an if/elif chain. An unrecognised config is written out unchanged.
        # Built per call because the config_* names are module-level imports.
        config_tags = (
            (config_default, "sd1"),
            (config_sd2, "sd2"),
            (config_sd2v, "sd2v"),
            (config_sd2_inpainting, "sd2-inpainting"),
            (config_depth_model, "depth"),
            (config_unclip, "unclip"),
            (config_unopenclip, "unopenclip"),
            (config_inpainting, "sd1-inpainting"),
            (config_instruct_pix2pix, "pix2pix"),
            (config_alt_diffusion, "alt"),
        )
        parts = []
        # Clean each raw entry individually so the raw text (which keeps its
        # @index marker and hash) and the lookup name can never get out of
        # step.
        for raw in self.getCheckpointListFromInput(checkpoints_text, clean=False):
            checkpoint = self.get_clean_checkpoint_path(
                self.remove_index_from_string(raw))
            if not checkpoint:
                continue
            info = modules.sd_models.get_closet_checkpoint_match(checkpoint)
            if info is None:
                self.logger.debug_log(f"no checkpoint found for {checkpoint}")
                parts.append(f"{raw},\n\n")
                continue
            config = find_checkpoint_config(read_state_dict(info.filename), None)
            version_string = next(
                (tag for cfg, tag in config_tags if config == cfg), config)
            parts.append(f"{raw} @version:{version_string},\n\n")
        return "".join(parts)

    def remove_element_at_index(self, checkpoints: str, prompts: str, index: List[int]) -> List[str]:
        """Remove the entries at the given indices from both lists.

        Args:
            checkpoints (str): the input string with all checkpoints
            prompts (str): the input string with all prompts
            index (List[int]): the indices to remove

        Returns:
            List[str]: the re-indexed checkpoints and prompts, or the
            unchanged inputs if the lengths or indices are invalid
        """
        checkpoints_list = self.getCheckpointListFromInput(checkpoints)
        prompts_list = self.split_prompts(prompts)
        if (len(checkpoints_list) != len(prompts_list)
                and len(prompts_list) > len(index)):
            self.logger.debug_log(
                f"checkpoints and prompts are not the same length cp: {len(checkpoints_list)} p: {len(prompts_list)}")
            return [checkpoints, prompts]
        if not index or any(not 0 <= i < len(checkpoints_list) for i in index):
            self.logger.debug_log("index is out of range")
            return [checkpoints, prompts]
        # Pop from the highest index down so that earlier removals do not
        # shift the positions of the entries still to be removed.
        for i in sorted(set(index), reverse=True):
            checkpoints_list.pop(i)
            # The prompt list may be shorter when the lengths differ.
            if i < len(prompts_list):
                prompts_list.pop(i)
        result = [self.add_index_to_string(",".join(checkpoints_list), True),
                  self.add_index_to_string(";".join(prompts_list), False)]
        self.logger.debug_log(f"result: {result}")
        return result