Spaces:
Build error
Build error
from tkinter import filedialog, Tk | |
from easygui import msgbox | |
import os | |
import re | |
import gradio as gr | |
import easygui | |
import shutil | |
import sys | |
import json | |
from library.custom_logging import setup_logging | |
from datetime import datetime | |
# Module-level logger shared by every helper below; it is created once at
# import time by the project's setup_logging() helper.
log = setup_logging()
# Unicode glyphs, written as escape sequences.
folder_symbol = '\U0001f4c2'  # 📂 open file folder
refresh_symbol = '\U0001f504'  # 🔄 refresh arrows
save_style_symbol = '\U0001f4be'  # 💾 floppy disk
document_symbol = '\U0001F4C4'  # 📄 page facing up

# Preset identifiers of the Stable Diffusion v2 "base" models.
V2_BASE_MODELS = [
    'stabilityai/stable-diffusion-2-1-base/blob/main/v2-1_512-ema-pruned',
    'stabilityai/stable-diffusion-2-1-base',
    'stabilityai/stable-diffusion-2-base',
]

# Preset identifiers of the v2 models that use v-parameterization.
V_PARAMETERIZATION_MODELS = [
    'stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned',
    'stabilityai/stable-diffusion-2-1',
    'stabilityai/stable-diffusion-2',
]

# Preset identifiers of the v1.x models.
V1_MODELS = [
    'CompVis/stable-diffusion-v1-4',
    'runwayml/stable-diffusion-v1-5',
]

# Preset identifiers of the SDXL models.
SDXL_MODELS = [
    'stabilityai/stable-diffusion-xl-base-0.9',
    'stabilityai/stable-diffusion-xl-refiner-0.9',
]

# Every preset identifier, in the order: v2 base, v-parameterization, v1, SDXL.
ALL_PRESET_MODELS = [
    *V2_BASE_MODELS,
    *V_PARAMETERIZATION_MODELS,
    *V1_MODELS,
    *SDXL_MODELS,
]

# Environment variables whose presence disables the native Tk dialogs.
ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID']
def check_if_model_exist(
    output_name, output_dir, save_model_as, headless=False
):
    """Ask the user before overwriting an already existing model.

    Args:
        output_name: Base name (without extension) of the model to be saved.
        output_dir: Directory the model will be saved into.
        save_model_as: Output format. 'diffusers' / 'diffusers_safetensors'
            are checked as a folder, 'ckpt' / 'safetensors' as a file named
            ``<output_name>.<save_model_as>``.
        headless: When true, no dialog is shown and the check is skipped.

    Returns:
        True when a model with the same name exists and the user declined to
        overwrite it (the caller should abort); False in every other case.
    """
    if headless:
        log.info(
            'Headless mode, skipping verification if model already exist... if model already exist it will be overwritten...'
        )
        return False

    # 'diffusers_safetendors' is the historical misspelling of
    # 'diffusers_safetensors'; it stays accepted for backward compatibility.
    if save_model_as in (
        'diffusers',
        'diffusers_safetensors',
        'diffusers_safetendors',
    ):
        ckpt_folder = os.path.join(output_dir, output_name)
        if os.path.isdir(ckpt_folder):
            msg = f'A diffuser model with the same name {ckpt_folder} already exists. Do you want to overwrite it?'
            if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
                log.info(
                    'Aborting training due to existing model with same name...'
                )
                return True
    elif save_model_as in ('ckpt', 'safetensors'):
        ckpt_file = os.path.join(output_dir, output_name + '.' + save_model_as)
        if os.path.isfile(ckpt_file):
            msg = f'A model with the same file name {ckpt_file} already exists. Do you want to overwrite it?'
            if not easygui.ynbox(msg, 'Overwrite Existing Model?'):
                log.info(
                    'Aborting training due to existing model with same name...'
                )
                return True
    else:
        log.info(
            'Can\'t verify if existing model exist when save model is set a "same as source model", continuing to train model...'
        )
        return False

    return False
def output_message(msg='', title='', headless=False):
    """Report ``msg`` to the user.

    In headless mode the message is written to the module logger; otherwise
    it is shown in an easygui message box titled ``title``.
    """
    if not headless:
        msgbox(msg=msg, title=title)
        return
    log.info(msg)
def update_my_data(my_data):
    """Normalize a loaded configuration dict in place and return it.

    Steps, in order:
      * default ``optimizer`` from the ``use_8bit_adam`` flag when missing;
      * force ``model_list`` to ``'custom'`` when it is empty or the
        configured model is not one of ``ALL_PRESET_MODELS``;
      * convert selected integer and float settings stored as strings;
      * rename the ``LoCon`` LoRA type to ``LyCORIS/LoCon``;
      * reset ``save_model_as`` to ``safetensors`` for LoRA / TI configs
        whose stored value is neither ``safetensors`` nor ``ckpt``.

    Args:
        my_data: Configuration mapping; it is modified in place.

    Returns:
        The same ``my_data`` object.
    """
    # Update the optimizer based on the use_8bit_adam flag
    use_8bit_adam = my_data.get('use_8bit_adam', False)
    my_data.setdefault('optimizer', 'AdamW8bit' if use_8bit_adam else 'AdamW')

    # Update model_list to custom if empty or pretrained_model_name_or_path is not a preset model
    model_list = my_data.get('model_list', [])
    pretrained_model_name_or_path = my_data.get(
        'pretrained_model_name_or_path', ''
    )
    if (
        not model_list
        or pretrained_model_name_or_path not in ALL_PRESET_MODELS
    ):
        my_data['model_list'] = 'custom'

    # Convert values to int if they are strings
    for key in ['epoch', 'save_every_n_epochs', 'lr_warmup']:
        value = my_data.get(key, 0)
        if isinstance(value, str) and value.strip().isdigit():
            my_data[key] = int(value)
        elif not value:
            my_data[key] = 0

    # Convert values to float if they are strings. str.isdigit() must not be
    # used here: it rejects '0.0001' and '1e-4', which are the usual way
    # these settings are written.
    for key in ['noise_offset', 'learning_rate', 'text_encoder_lr', 'unet_lr']:
        value = my_data.get(key, 0)
        if isinstance(value, str) and value.strip():
            try:
                my_data[key] = float(value)
            except ValueError:
                # Not a number: leave the stored value untouched.
                pass
        elif not value:
            my_data[key] = 0

    # Update LoRA_type if it is set to LoCon
    if my_data.get('LoRA_type', 'Standard') == 'LoCon':
        my_data['LoRA_type'] = 'LyCORIS/LoCon'

    # Update model save choices due to changes for LoRA and TI training
    if 'save_model_as' in my_data:
        if (
            my_data.get('LoRA_type') or my_data.get('num_vectors_per_token')
        ) and my_data.get('save_model_as') not in ['safetensors', 'ckpt']:
            message = 'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}'
            if my_data.get('LoRA_type'):
                log.info(message.format('LoRA'))
            if my_data.get('num_vectors_per_token'):
                log.info(message.format('TI'))
            my_data['save_model_as'] = 'safetensors'

    return my_data
def get_dir_and_file(file_path):
    """Return ``(directory, file name)`` for ``file_path``.

    The two parts are the head and tail that ``os.path.split`` would produce.
    """
    return os.path.dirname(file_path), os.path.basename(file_path)
def get_file_path(
    file_path='', default_extension='.json', extension_name='Config files'
):
    """Let the user choose an existing file through a Tk open-file dialog.

    The dialog is skipped, and ``file_path`` returned unchanged, when any
    variable listed in ``ENV_EXCLUSION`` is present in the environment or
    when running on macOS (``sys.platform == 'darwin'``).

    Args:
        file_path: Current path; seeds the dialog's directory and file name
            and is the fallback when nothing is selected.
        default_extension: Extension offered by the dialog's file filter.
        extension_name: Label shown for that file filter.

    Returns:
        The selected path, or ``file_path`` when nothing was selected.
    """
    dialogs_disabled = sys.platform == 'darwin' or any(
        var in os.environ for var in ENV_EXCLUSION
    )
    if dialogs_disabled:
        return file_path

    previous_path = file_path
    initial_dir, initial_file = get_dir_and_file(previous_path)

    # Hidden root window, kept on top so the dialog is not buried.
    root = Tk()
    root.wm_attributes('-topmost', 1)
    root.withdraw()

    file_types = (
        (extension_name, f'*{default_extension}'),
        ('All files', '*.*'),
    )
    selected_path = filedialog.askopenfilename(
        filetypes=file_types,
        defaultextension=default_extension,
        initialfile=initial_file,
        initialdir=initial_dir,
    )

    root.destroy()

    # Nothing selected: keep the path we started with.
    return selected_path if selected_path else previous_path
def get_any_file_path(file_path=''):
    """Let the user choose any existing file through a Tk open-file dialog.

    The dialog is skipped, and ``file_path`` returned unchanged, when any
    variable listed in ``ENV_EXCLUSION`` is present in the environment or
    when running on macOS (``sys.platform == 'darwin'``).

    Args:
        file_path: Current path; seeds the dialog's directory and file name
            and is the fallback when nothing is selected.

    Returns:
        The selected path, or ``file_path`` when the dialog was cancelled.
    """
    if (
        not any(var in os.environ for var in ENV_EXCLUSION)
        and sys.platform != 'darwin'
    ):
        current_file_path = file_path
        initial_dir, initial_file = get_dir_and_file(file_path)

        root = Tk()
        try:
            root.wm_attributes('-topmost', 1)
            root.withdraw()
            file_path = filedialog.askopenfilename(
                initialdir=initial_dir,
                initialfile=initial_file,
            )
        finally:
            # Always release the hidden root window, even if the dialog fails.
            root.destroy()

        # A cancelled dialog yields an empty value; test truthiness (as
        # get_file_path does) rather than comparing with '' so that an empty
        # tuple is treated as "cancelled" as well.
        if not file_path:
            file_path = current_file_path

    return file_path
def remove_doublequote(file_path):
    """Return ``file_path`` with every double-quote character removed.

    ``None`` is passed through unchanged.
    """
    # Identity test: None is a singleton and must not be compared with !=.
    if file_path is not None:
        file_path = file_path.replace('"', '')

    return file_path
def get_folder_path(folder_path=''):
    """Let the user choose a directory through a Tk folder dialog.

    The dialog is skipped, and ``folder_path`` returned unchanged, when any
    variable listed in ``ENV_EXCLUSION`` is present in the environment or
    when running on macOS (``sys.platform == 'darwin'``).

    Args:
        folder_path: Current folder; the dialog opens in the directory part
            returned by ``get_dir_and_file(folder_path)``, and the value is
            the fallback when nothing is selected.

    Returns:
        The selected folder, or ``folder_path`` when the dialog was cancelled.
    """
    if (
        not any(var in os.environ for var in ENV_EXCLUSION)
        and sys.platform != 'darwin'
    ):
        current_folder_path = folder_path
        initial_dir, _ = get_dir_and_file(folder_path)

        root = Tk()
        try:
            root.wm_attributes('-topmost', 1)
            root.withdraw()
            folder_path = filedialog.askdirectory(initialdir=initial_dir)
        finally:
            # Always release the hidden root window, even if the dialog fails.
            root.destroy()

        # A cancelled dialog yields an empty value; test truthiness (as
        # get_file_path does) so an empty tuple also counts as "cancelled".
        if not folder_path:
            folder_path = current_folder_path

    return folder_path
def get_saveasfile_path(
    file_path='', defaultextension='.json', extension_name='Config files'
):
    """Let the user choose a file to save to through a Tk save dialog.

    The dialog is skipped, and ``file_path`` returned unchanged, when any
    variable listed in ``ENV_EXCLUSION`` is present in the environment or
    when running on macOS (``sys.platform == 'darwin'``).

    ``filedialog.asksaveasfile`` opens the chosen file for writing; only its
    name is needed here, so the handle is closed before returning.

    Args:
        file_path: Current path; seeds the dialog's directory and file name
            and is the fallback when nothing is selected.
        defaultextension: Extension used for the file filter and appended by
            the dialog when the user types a bare name.
        extension_name: Label shown for the file filter.

    Returns:
        The chosen path, or ``file_path`` when the dialog was cancelled.
    """
    if (
        not any(var in os.environ for var in ENV_EXCLUSION)
        and sys.platform != 'darwin'
    ):
        current_file_path = file_path
        initial_dir, initial_file = get_dir_and_file(file_path)

        root = Tk()
        try:
            root.wm_attributes('-topmost', 1)
            root.withdraw()
            save_file = filedialog.asksaveasfile(
                filetypes=(
                    (f'{extension_name}', f'{defaultextension}'),
                    ('All files', '*'),
                ),
                defaultextension=defaultextension,
                initialdir=initial_dir,
                initialfile=initial_file,
            )
        finally:
            # Always release the hidden root window, even if the dialog fails.
            root.destroy()

        if save_file is None:
            # Dialog cancelled: keep the path we started with.
            file_path = current_file_path
        else:
            # Release the open handle; leaving it open leaks a descriptor
            # and can keep the file locked for the code that writes it next.
            save_file.close()
            log.info(save_file.name)
            file_path = save_file.name

    return file_path
def get_saveasfilename_path(
    file_path='', extensions='*', extension_name='Config files'
):
    """Let the user choose a save file name through a Tk save dialog.

    Unlike ``filedialog.asksaveasfile``, ``asksaveasfilename`` only returns
    the name and does not open the file.

    The dialog is skipped, and ``file_path`` returned unchanged, when any
    variable listed in ``ENV_EXCLUSION`` is present in the environment or
    when running on macOS (``sys.platform == 'darwin'``).

    Args:
        file_path: Current path; seeds the dialog's directory and file name
            and is the fallback when nothing is selected.
        extensions: Pattern used for the file filter and passed as the
            dialog's default extension.
        extension_name: Label shown for the file filter.

    Returns:
        The chosen path, or ``file_path`` when the dialog was cancelled.
    """
    if (
        not any(var in os.environ for var in ENV_EXCLUSION)
        and sys.platform != 'darwin'
    ):
        current_file_path = file_path
        initial_dir, initial_file = get_dir_and_file(file_path)

        root = Tk()
        try:
            root.wm_attributes('-topmost', 1)
            root.withdraw()
            save_file_path = filedialog.asksaveasfilename(
                filetypes=(
                    (f'{extension_name}', f'{extensions}'),
                    ('All files', '*'),
                ),
                defaultextension=extensions,
                initialdir=initial_dir,
                initialfile=initial_file,
            )
        finally:
            # Always release the hidden root window, even if the dialog fails.
            root.destroy()

        # A cancelled dialog yields an empty value; test truthiness (as
        # get_file_path does) so an empty tuple also counts as "cancelled".
        if not save_file_path:
            file_path = current_file_path
        else:
            file_path = save_file_path

    return file_path
def add_pre_postfix(
    folder: str = '',
    prefix: str = '',
    postfix: str = '',
    caption_file_ext: str = '.caption',
) -> None:
    """
    Add prefix and/or postfix to the caption file of every image in a folder.

    For each ``.jpg``/``.jpeg``/``.png``/``.webp`` file (case-insensitive) the
    caption file with the same base name is updated. When an image has no
    caption file yet, one is created containing only the prefix and/or postfix.
    Existing content has trailing whitespace stripped before the postfix is
    appended. Nothing is done when both ``prefix`` and ``postfix`` are empty.

    Args:
        folder (str): Path to the folder containing the images and captions.
        prefix (str, optional): Text placed before the caption content.
        postfix (str, optional): Text placed after the caption content.
        caption_file_ext (str, optional): Extension of the caption files.
    """
    if prefix == '' and postfix == '':
        return

    image_extensions = ('.jpg', '.jpeg', '.png', '.webp')
    image_files = [
        f for f in os.listdir(folder) if f.lower().endswith(image_extensions)
    ]

    for image_file in image_files:
        caption_file_name = os.path.splitext(image_file)[0] + caption_file_ext
        caption_file_path = os.path.join(folder, caption_file_name)

        if not os.path.exists(caption_file_path):
            with open(caption_file_path, 'w', encoding='utf8') as f:
                separator = ' ' if prefix and postfix else ''
                f.write(f'{prefix}{separator}{postfix}')
        else:
            with open(caption_file_path, 'r+', encoding='utf8') as f:
                content = f.read().rstrip()
                f.seek(0, 0)

                prefix_separator = ' ' if prefix else ''
                postfix_separator = ' ' if postfix else ''
                f.write(
                    f'{prefix}{prefix_separator}{content}{postfix_separator}{postfix}'
                )
                # The new text can be shorter than the old one (trailing
                # whitespace was stripped); drop whatever is left of the old
                # content so it does not end up after the postfix.
                f.truncate()
def has_ext_files(folder_path: str, file_extension: str) -> bool:
    """
    Tell whether a folder holds at least one entry ending with an extension.

    Args:
        folder_path (str): Path to the folder to inspect.
        file_extension (str): Suffix the entry names are tested against.

    Returns:
        bool: True as soon as one matching entry is found, False otherwise.
    """
    return any(
        entry.endswith(file_extension) for entry in os.listdir(folder_path)
    )
def find_replace(
    folder_path: str = '',
    caption_file_ext: str = '.caption',
    search_text: str = '',
    replace_text: str = '',
) -> None:
    """
    Find and replace text in caption files within a folder.

    A message box is shown, and nothing else is done, when the folder holds
    no file with the caption extension. An empty ``search_text`` is a no-op.
    Caption files are read and written as UTF-8, matching add_pre_postfix;
    bytes that cannot be decoded are dropped (``errors='ignore'``).

    Args:
        folder_path (str, optional): Path to the folder containing caption files.
        caption_file_ext (str, optional): Extension of the caption files.
        search_text (str, optional): Text to search for in the caption files.
        replace_text (str, optional): Text to replace the search text with.
    """
    log.info('Running caption find/replace')

    if not has_ext_files(folder_path, caption_file_ext):
        msgbox(
            f'No files with extension {caption_file_ext} were found in {folder_path}...'
        )
        return

    if search_text == '':
        return

    caption_files = [
        f for f in os.listdir(folder_path) if f.endswith(caption_file_ext)
    ]

    for caption_file in caption_files:
        caption_path = os.path.join(folder_path, caption_file)
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # would mangle non-ASCII captions written as UTF-8 elsewhere.
        with open(
            caption_path, 'r', encoding='utf8', errors='ignore'
        ) as f:
            content = f.read()

        content = content.replace(search_text, replace_text)

        with open(caption_path, 'w', encoding='utf8') as f:
            f.write(content)
def color_aug_changed(color_aug):
    """Build the checkbox update that follows a color-augmentation change.

    When ``color_aug`` is truthy a message box informs the user and the
    returned update sets the checkbox to unchecked and non-interactive;
    otherwise the update sets it to checked and interactive.
    """
    if not color_aug:
        return gr.Checkbox.update(value=True, interactive=True)

    msgbox(
        'Disabling "Cache latent" because "Color augmentation" has been selected...'
    )
    return gr.Checkbox.update(value=False, interactive=False)
def save_inference_file(output_dir, v2, v_parameterization, output_name):
    """Copy the v2 inference YAML next to every saved model file.

    Each regular file in ``output_dir`` whose name starts with
    ``output_name`` gets a sibling ``<base name>.yaml``: a copy of
    ``./v2_inference/v2-inference-v.yaml`` when both ``v2`` and
    ``v_parameterization`` are set, of ``./v2_inference/v2-inference.yaml``
    when only ``v2`` is set. Nothing is copied when ``v2`` is false.
    """
    for entry in os.listdir(output_dir):
        # Only entries belonging to this training run are of interest.
        if not entry.startswith(output_name):
            continue
        # Directories (e.g. diffusers folders) are skipped.
        if not os.path.isfile(os.path.join(output_dir, entry)):
            continue

        base_name = os.path.splitext(entry)[0]

        if v2 and v_parameterization:
            log.info(
                f'Saving v2-inference-v.yaml as {output_dir}/{base_name}.yaml'
            )
            shutil.copy(
                './v2_inference/v2-inference-v.yaml',
                f'{output_dir}/{base_name}.yaml',
            )
        elif v2:
            log.info(
                f'Saving v2-inference.yaml as {output_dir}/{base_name}.yaml'
            )
            shutil.copy(
                './v2_inference/v2-inference.yaml',
                f'{output_dir}/{base_name}.yaml',
            )
def set_pretrained_model_name_or_path_input(
    model_list, pretrained_model_name_or_path, pretrained_model_name_or_path_file, pretrained_model_name_or_path_folder, v2, v_parameterization, sdxl
):
    """Derive the Gradio updates that follow a change of the model selection.

    * Preset model (member of ``SDXL_MODELS``, ``V2_BASE_MODELS``,
      ``V_PARAMETERIZATION_MODELS`` or ``V1_MODELS``): the v2 /
      v_parameterization / sdxl checkboxes are set to the values of that
      model family and hidden, the path textbox is set to the preset name and
      hidden, and the file / folder buttons are hidden.
    * ``'custom'``: all six components are made visible again.
    * Anything else: the received values are returned unchanged, so the
      function always yields the seven-element tuple.

    Returns:
        Tuple ``(model_list, pretrained_model_name_or_path,
        pretrained_model_name_or_path_file,
        pretrained_model_name_or_path_folder, v2, v_parameterization, sdxl)``.
    """
    model_name = str(model_list)

    # (preset names, log message, v2, v_parameterization, sdxl), tested in
    # this order.
    presets = (
        (
            SDXL_MODELS,
            'SDXL model selected. Setting sdxl parameters',
            False, False, True,
        ),
        (
            V2_BASE_MODELS,
            'SD v2 base model selected. Setting --v2 parameter',
            True, False, False,
        ),
        (
            V_PARAMETERIZATION_MODELS,
            'SD v2 model selected. Setting --v2 and --v_parameterization parameters',
            True, True, False,
        ),
        (
            V1_MODELS,
            'SD v1.4 model selected.',
            False, False, False,
        ),
    )

    for preset_models, message, v2_value, v_param_value, sdxl_value in presets:
        if model_name in preset_models:
            log.info(message)
            v2 = gr.Checkbox.update(value=v2_value, visible=False)
            v_parameterization = gr.Checkbox.update(
                value=v_param_value, visible=False
            )
            sdxl = gr.Checkbox.update(value=sdxl_value, visible=False)
            pretrained_model_name_or_path = gr.Textbox.update(
                value=model_name, visible=False
            )
            pretrained_model_name_or_path_file = gr.Button.update(
                visible=False
            )
            pretrained_model_name_or_path_folder = gr.Button.update(
                visible=False
            )
            break
    else:
        # Not a preset: only 'custom' changes anything.
        if model_list == 'custom':
            v2 = gr.Checkbox.update(visible=True)
            v_parameterization = gr.Checkbox.update(visible=True)
            sdxl = gr.Checkbox.update(visible=True)
            pretrained_model_name_or_path = gr.Textbox.update(visible=True)
            pretrained_model_name_or_path_file = gr.Button.update(visible=True)
            pretrained_model_name_or_path_folder = gr.Button.update(
                visible=True
            )

    # Single exit point: an unrecognized selection no longer falls off the
    # end of the function and returns None.
    return model_list, pretrained_model_name_or_path, pretrained_model_name_or_path_file, pretrained_model_name_or_path_folder, v2, v_parameterization, sdxl
### | |
### Gradio common GUI section | |
### | |
def get_pretrained_model_name_or_path_file(
    model_list, pretrained_model_name_or_path
):
    """Open the file picker seeded with the current model path.

    NOTE(review): the path chosen by the user is neither returned nor stored,
    and ``model_list`` is unused, so this function always returns None. It
    looks unfinished (see the disabled call below) - confirm before relying
    on it.
    """
    get_any_file_path(pretrained_model_name_or_path)
    # set_model_list(model_list, pretrained_model_name_or_path)
def get_int_or_default(kwargs, key, default_value=0):
    """Read ``kwargs[key]`` as an int, falling back to ``default_value``.

    Ints are returned as they are, floats are truncated with ``int()``, and
    strings are parsed as an integer first and, failing that, as a float that
    is then truncated (so ``'3.0'`` gives 3). A value that cannot be
    converted, or that is of another type, is logged and replaced by
    ``default_value``. A missing key yields ``default_value`` run through the
    same conversion.

    Args:
        kwargs: Mapping holding the raw settings.
        key: Name of the setting to read.
        default_value: Value used when the key is missing or unusable.

    Returns:
        The converted integer, or ``default_value``.
    """
    value = kwargs.get(key, default_value)
    if isinstance(value, int):
        return value

    if isinstance(value, (str, float)):
        try:
            return int(value)
        except (ValueError, OverflowError):
            # Strings such as '3.0' or '1e3', or a NaN / infinite float.
            pass
        if isinstance(value, str):
            try:
                return int(float(value))
            except (ValueError, OverflowError):
                pass
        log.info(f'{key} could not be converted to an int, setting value to {default_value}')
        return default_value

    log.info(f'{key} is not an int, float or a string, setting value to {default_value}')
    return default_value
def get_float_or_default(kwargs, key, default_value=0.0):
    """Read ``kwargs[key]`` as a float, falling back to ``default_value``.

    Floats are returned as they are; ints and numeric strings are converted
    with ``float()``. A string that is not a number, or a value of another
    type, is logged and replaced by ``default_value``. A missing key yields
    ``default_value`` run through the same conversion.

    Args:
        kwargs: Mapping holding the raw settings.
        key: Name of the setting to read.
        default_value: Value used when the key is missing or unusable.

    Returns:
        The converted float, or ``default_value``.
    """
    value = kwargs.get(key, default_value)
    if isinstance(value, float):
        return value
    if isinstance(value, int):
        return float(value)
    if isinstance(value, str):
        try:
            return float(value)
        except ValueError:
            # e.g. '' or 'abc': honour the "or default" contract instead of
            # letting the conversion error escape.
            log.info(f'{key} could not be converted to a float, setting value to {default_value}')
            return default_value

    log.info(f'{key} is not an int, float or a string, setting value to {default_value}')
    return default_value
def get_str_or_default(kwargs, key, default_value=""):
    """Read ``kwargs[key]`` as a string, falling back to ``default_value``.

    Strings are returned as they are; ints and floats are converted with
    ``str()``. Values of any other type (including None) are replaced by
    ``default_value``.

    Args:
        kwargs: Mapping holding the raw settings.
        key: Name of the setting to read.
        default_value: Value used when the key is missing or unusable.

    Returns:
        The string form of the value, or ``default_value``.
    """
    value = kwargs.get(key, default_value)
    if isinstance(value, str):
        return value
    # The second numeric branch used to test for str again and could never
    # run, so floats fell through to the default; handle both numeric types.
    if isinstance(value, (int, float)):
        return str(value)
    return default_value
def run_cmd_training(**kwargs):
    """Build the basic-training portion of the command line.

    Each recognized keyword contributes one ``' --option...'`` fragment, in a
    fixed order; falsy values are skipped, except ``seed`` and
    ``optimizer_args`` which are only skipped when equal to ``''``.
    ``--optimizer_type`` is always emitted (default ``AdamW``). A warm-up
    step count is ignored, with a log message, when the scheduler is
    ``constant``.

    Returns:
        The concatenated fragments; every fragment starts with a space.
    """
    fragments = []

    learning_rate = kwargs.get('learning_rate', '')
    if learning_rate:
        fragments.append(f' --learning_rate="{learning_rate}"')

    lr_scheduler = kwargs.get('lr_scheduler', '')
    if lr_scheduler:
        fragments.append(f' --lr_scheduler="{lr_scheduler}"')

    lr_warmup_steps = kwargs.get('lr_warmup_steps', '')
    if lr_warmup_steps and lr_scheduler == 'constant':
        log.info('Can\'t use LR warmup with LR Scheduler constant... ignoring...')
    elif lr_warmup_steps:
        fragments.append(f' --lr_warmup_steps="{lr_warmup_steps}"')

    for option in ('train_batch_size', 'max_train_steps'):
        option_value = kwargs.get(option, '')
        if option_value:
            fragments.append(f' --{option}="{option_value}"')

    save_every_n_epochs = kwargs.get('save_every_n_epochs')
    if save_every_n_epochs:
        fragments.append(
            f' --save_every_n_epochs="{int(save_every_n_epochs)}"'
        )

    for option in ('mixed_precision', 'save_precision'):
        option_value = kwargs.get(option, '')
        if option_value:
            fragments.append(f' --{option}="{option_value}"')

    # A seed of 0 is meaningful, so only the empty string means "not set".
    seed = kwargs.get('seed', '')
    if seed != '':
        fragments.append(f' --seed="{seed}"')

    caption_extension = kwargs.get('caption_extension', '')
    if caption_extension:
        fragments.append(f' --caption_extension="{caption_extension}"')

    for switch in ('cache_latents', 'cache_latents_to_disk'):
        if kwargs.get(switch):
            fragments.append(f' --{switch}')

    optimizer_type = kwargs.get('optimizer', 'AdamW')
    fragments.append(f' --optimizer_type="{optimizer_type}"')

    optimizer_args = kwargs.get('optimizer_args', '')
    if optimizer_args != '':
        fragments.append(f' --optimizer_args {optimizer_args}')

    return ''.join(fragments)
def run_cmd_advanced_training(**kwargs):
    """Build the advanced-training portion of the command line.

    Each recognized keyword contributes one ``' --option...'`` fragment, in a
    fixed order. Numeric settings are converted with ``int()`` / ``float()``
    and only emitted when they differ from their neutral value (see the
    comparisons below); ``--bucket_reso_steps`` is always emitted (default
    64). Boolean settings become bare switches. The noise options depend on
    ``noise_offset_type``: ``'Original'`` (default) uses ``noise_offset`` /
    ``adaptive_noise_scale``, any other value uses the multires options.

    Returns:
        The concatenated fragments; every fragment starts with a space.
    """
    fragments = []
    emit = fragments.append

    max_train_epochs = kwargs.get('max_train_epochs', '')
    if max_train_epochs:
        emit(f' --max_train_epochs={max_train_epochs}')

    max_data_loader_n_workers = kwargs.get('max_data_loader_n_workers', '')
    if max_data_loader_n_workers:
        emit(f' --max_data_loader_n_workers="{max_data_loader_n_workers}"')

    max_token_length = int(kwargs.get('max_token_length', 75))
    if max_token_length > 75:
        emit(f' --max_token_length={max_token_length}')

    clip_skip = int(kwargs.get('clip_skip', 1))
    if clip_skip > 1:
        emit(f' --clip_skip={clip_skip}')

    resume = kwargs.get('resume', '')
    if resume:
        emit(f' --resume="{resume}"')

    # Positive integer options, emitted quoted.
    for option in ('keep_tokens', 'caption_dropout_every_n_epochs'):
        amount = int(kwargs.get(option, 0))
        if amount > 0:
            emit(f' --{option}="{amount}"')

    caption_dropout_rate = float(kwargs.get('caption_dropout_rate', 0))
    if caption_dropout_rate > 0:
        emit(f' --caption_dropout_rate="{caption_dropout_rate}"')

    vae_batch_size = int(kwargs.get('vae_batch_size', 0))
    if vae_batch_size > 0:
        emit(f' --vae_batch_size="{vae_batch_size}"')

    bucket_reso_steps = int(kwargs.get('bucket_reso_steps', 64))
    emit(f' --bucket_reso_steps={bucket_reso_steps}')

    for option in (
        'save_every_n_steps',
        'save_last_n_steps',
        'save_last_n_steps_state',
    ):
        amount = int(kwargs.get(option, 0))
        if amount > 0:
            emit(f' --{option}="{amount}"')

    min_snr_gamma = int(kwargs.get('min_snr_gamma', 0))
    if min_snr_gamma >= 1:
        emit(f' --min_snr_gamma={min_snr_gamma}')

    min_timestep = int(kwargs.get('min_timestep', 0))
    if min_timestep > 0:
        emit(f' --min_timestep={min_timestep}')

    max_timestep = int(kwargs.get('max_timestep', 1000))
    if max_timestep < 1000:
        emit(f' --max_timestep={max_timestep}')

    # Boolean settings: each truthy one becomes a bare switch of the same name.
    for switch in (
        'save_state',
        'mem_eff_attn',
        'color_aug',
        'flip_aug',
        'shuffle_caption',
        'gradient_checkpointing',
        'full_fp16',
        'xformers',
        'persistent_data_loader_workers',
        'bucket_no_upscale',
        'random_crop',
        'scale_v_pred_loss_like_noise_pred',
    ):
        if kwargs.get(switch):
            emit(f' --{switch}')

    if kwargs.get('noise_offset_type', 'Original') == 'Original':
        noise_offset = float(kwargs.get('noise_offset', 0))
        if noise_offset > 0:
            emit(f' --noise_offset={noise_offset}')

        # The adaptive scale is only emitted together with a noise offset.
        adaptive_noise_scale = float(kwargs.get('adaptive_noise_scale', 0))
        if adaptive_noise_scale != 0 and noise_offset > 0:
            emit(f' --adaptive_noise_scale={adaptive_noise_scale}')
    else:
        multires_noise_iterations = int(
            kwargs.get('multires_noise_iterations', 0)
        )
        if multires_noise_iterations > 0:
            emit(f' --multires_noise_iterations="{multires_noise_iterations}"')

        multires_noise_discount = float(
            kwargs.get('multires_noise_discount', 0)
        )
        if multires_noise_discount > 0:
            emit(f' --multires_noise_discount="{multires_noise_discount}"')

    # Free-form text appended verbatim.
    additional_parameters = kwargs.get('additional_parameters', '')
    if additional_parameters:
        emit(f' {additional_parameters}')

    if kwargs.get('use_wandb'):
        emit(' --log_with wandb')

    wandb_api_key = kwargs.get('wandb_api_key', '')
    if wandb_api_key:
        emit(f' --wandb_api_key="{wandb_api_key}"')

    return ''.join(fragments)
def verify_image_folder_pattern(folder_path):
    """Check that every sub-folder of ``folder_path`` is named ``<number>_<text>``.

    Problems are reported through the logger: the path is not a directory, a
    sub-folder name does not start with digits, an underscore and word
    characters, or there is no sub-folder at all.

    Args:
        folder_path: Directory expected to contain the image sub-folders.

    Returns:
        Currently always True: the failure value is deliberately set to True
        so that a false positive of this check cannot stop a training run.
    """
    false_response = True  # temporarily set to true to prevent stopping training in case of false positive
    true_response = True

    # Check if the folder exists
    if not os.path.isdir(folder_path):
        # '\\' spells the literal backslash of the Windows-style doc path;
        # a bare '\i' is an invalid escape sequence.
        log.error(f"The provided path '{folder_path}' is not a valid folder. Please follow the folder structure documentation found at docs\\image_folder_structure.md ...")
        return false_response

    # One or more digits, an underscore, then one or more word characters
    # (letters, digits or underscores). Example: 123_example_folder
    pattern = r'^\d+_\w+'

    # Get the list of sub-folders in the directory
    subfolders = [
        os.path.join(folder_path, subfolder)
        for subfolder in os.listdir(folder_path)
        if os.path.isdir(os.path.join(folder_path, subfolder))
    ]

    # Check the pattern of each sub-folder
    matching_subfolders = [
        subfolder
        for subfolder in subfolders
        if re.match(pattern, os.path.basename(subfolder))
    ]

    # Report non-matching sub-folders, sorted so the message is stable
    non_matching_subfolders = sorted(set(subfolders) - set(matching_subfolders))
    if non_matching_subfolders:
        offending = ', '.join(non_matching_subfolders)
        log.error(f"The following folders do not match the required pattern <number>_<text>: {offending}")
        log.error("Please follow the folder structure documentation found at docs\\image_folder_structure.md ...")
        return false_response

    # Check if no sub-folders exist
    if not matching_subfolders:
        log.error(f"No image folders found in {folder_path}. Please follow the folder structure documentation found at docs\\image_folder_structure.md ...")
        return false_response

    log.info(f'Valid image folder names found in: {folder_path}')
    return true_response
def SaveConfigFile(parameters, file_path: str, exclusion=('file_path', 'save_as', 'headless', 'print_only')):
    """Write the given settings to ``file_path`` as indented JSON.

    Args:
        parameters: Iterable of ``(name, value)`` pairs.
        file_path: Destination file; it is created or overwritten.
        exclusion: Names that must not be written. The default is an
            immutable tuple so it cannot be altered between calls; any
            container supporting ``in`` (list, set, tuple) is accepted.

    The pairs are sorted by name before being written, so the keys appear in
    alphabetical order in the file.
    """
    # Return the values of the variables as a dictionary
    variables = {
        name: value
        for name, value in sorted(parameters, key=lambda x: x[0])
        if name not in exclusion
    }

    # Save the data to the selected file
    with open(file_path, 'w') as file:
        json.dump(variables, file, indent=2)
def save_to_file(content):
    """Append ``content`` and a newline to ``logs/print_command.txt``.

    The path is relative to the current working directory; the ``logs``
    folder is created first when it does not exist yet.
    """
    file_path = 'logs/print_command.txt'
    # Opening in append mode does not create missing parent folders.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'a') as file:
        file.write(content + '\n')