|
import torch
|
|
from pathlib import Path
|
|
from utils import get_download_file
|
|
from stkey import read_safetensors_key
|
|
# Feature flag: True only when this diffusers install exposes
# BitsAndBytesConfig (needed for NF4 / bitsandbytes quantization).
# Any failure during the import is treated as "not available".
is_nf4 = False
try:
    from diffusers import BitsAndBytesConfig
except Exception:
    pass
else:
    is_nf4 = True
|
|
|
|
|
|
DTYPE_DEFAULT = "default"
|
|
DTYPE_DICT = {
|
|
"fp16": torch.float16,
|
|
"bf16": torch.bfloat16,
|
|
"fp32": torch.float32,
|
|
"fp8": torch.float8_e4m3fn,
|
|
}
|
|
|
|
QTYPES = []
|
|
|
|
def get_dtypes():
    """Return every selectable dtype option name.

    Order: the DTYPE_DICT keys, then DTYPE_DEFAULT, then any QTYPES entries.
    A new list is built on each call.
    """
    return [*DTYPE_DICT, DTYPE_DEFAULT, *QTYPES]
|
|
|
|
|
|
def get_dtype(dtype: str):
    """Map a dtype option name to a torch dtype.

    Quantized option names (QTYPES) resolve to torch.bfloat16; names found in
    DTYPE_DICT resolve to their entry; anything else falls back to
    torch.float16.
    """
    return torch.bfloat16 if dtype in QTYPES else DTYPE_DICT.get(dtype, torch.float16)
|
|
|
|
|
|
from diffusers import (
|
|
DPMSolverMultistepScheduler,
|
|
DPMSolverSinglestepScheduler,
|
|
KDPM2DiscreteScheduler,
|
|
EulerDiscreteScheduler,
|
|
EulerAncestralDiscreteScheduler,
|
|
HeunDiscreteScheduler,
|
|
LMSDiscreteScheduler,
|
|
DDIMScheduler,
|
|
DEISMultistepScheduler,
|
|
UniPCMultistepScheduler,
|
|
LCMScheduler,
|
|
PNDMScheduler,
|
|
KDPM2AncestralDiscreteScheduler,
|
|
DPMSolverSDEScheduler,
|
|
EDMDPMSolverMultistepScheduler,
|
|
DDPMScheduler,
|
|
EDMEulerScheduler,
|
|
TCDScheduler,
|
|
)
|
|
|
|
|
|
# Sampler display name -> (diffusers scheduler class, config overrides).
# The override dict holds keyword arguments for that scheduler's config.
#
# NOTE: EDMEulerScheduler and EDMDPMSolverMultistepScheduler have no
# `use_karras_sigmas` option; their noise schedule is selected with
# `sigma_schedule` ("karras" or "exponential"). Per the diffusers docs the
# default is already "karras", so the plain "EDM" and "EDM Karras" entries
# are presumably equivalent unless the source config overrides it. Confirm
# whether the non-Karras entries were meant to be "exponential".
SCHEDULER_CONFIG_MAP = {
    "DPM++ 2M": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False}),
    "DPM++ 2M Karras": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
    "DPM++ 2M SDE": (DPMSolverMultistepScheduler, {"use_karras_sigmas": False, "algorithm_type": "sde-dpmsolver++"}),
    "DPM++ 2M SDE Karras": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++"}),
    "DPM++ 2S": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}),
    "DPM++ 2S Karras": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
    "DPM++ 1S": (DPMSolverMultistepScheduler, {"solver_order": 1}),
    "DPM++ 1S Karras": (DPMSolverMultistepScheduler, {"solver_order": 1, "use_karras_sigmas": True}),
    "DPM++ 3M": (DPMSolverMultistepScheduler, {"solver_order": 3}),
    "DPM++ 3M Karras": (DPMSolverMultistepScheduler, {"solver_order": 3, "use_karras_sigmas": True}),
    "DPM++ SDE": (DPMSolverSDEScheduler, {"use_karras_sigmas": False}),
    "DPM++ SDE Karras": (DPMSolverSDEScheduler, {"use_karras_sigmas": True}),
    "DPM2": (KDPM2DiscreteScheduler, {}),
    "DPM2 Karras": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
    "DPM2 a": (KDPM2AncestralDiscreteScheduler, {}),
    "DPM2 a Karras": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
    "Euler": (EulerDiscreteScheduler, {}),
    "Euler a": (EulerAncestralDiscreteScheduler, {}),
    # NOTE(review): prediction_type="sample" only suits models trained to
    # predict x0 (e.g. some 1-step distilled UNets); confirm this is intended.
    "Euler trailing": (EulerDiscreteScheduler, {"timestep_spacing": "trailing", "prediction_type": "sample"}),
    "Euler a trailing": (EulerAncestralDiscreteScheduler, {"timestep_spacing": "trailing"}),
    "Heun": (HeunDiscreteScheduler, {}),
    "Heun Karras": (HeunDiscreteScheduler, {"use_karras_sigmas": True}),
    "LMS": (LMSDiscreteScheduler, {}),
    "LMS Karras": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
    "DDIM": (DDIMScheduler, {}),
    "DDIM trailing": (DDIMScheduler, {"timestep_spacing": "trailing"}),
    "DEIS": (DEISMultistepScheduler, {}),
    "UniPC": (UniPCMultistepScheduler, {}),
    "UniPC Karras": (UniPCMultistepScheduler, {"use_karras_sigmas": True}),
    "PNDM": (PNDMScheduler, {}),
    "Euler EDM": (EDMEulerScheduler, {}),
    # EDM schedulers take `sigma_schedule`, not `use_karras_sigmas`.
    "Euler EDM Karras": (EDMEulerScheduler, {"sigma_schedule": "karras"}),
    "DPM++ 2M EDM": (EDMDPMSolverMultistepScheduler, {"solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++"}),
    "DPM++ 2M EDM Karras": (EDMDPMSolverMultistepScheduler, {"sigma_schedule": "karras", "solver_order": 2, "solver_type": "midpoint", "final_sigmas_type": "zero", "algorithm_type": "dpmsolver++"}),
    "DDPM": (DDPMScheduler, {}),
    # DPMSolverMultistep variants: Lu lambdas / Euler step at the final step.
    "DPM++ 2M Lu": (DPMSolverMultistepScheduler, {"use_lu_lambdas": True}),
    "DPM++ 2M Ef": (DPMSolverMultistepScheduler, {"euler_at_final": True}),
    "DPM++ 2M SDE Lu": (DPMSolverMultistepScheduler, {"use_lu_lambdas": True, "algorithm_type": "sde-dpmsolver++"}),
    "DPM++ 2M SDE Ef": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "euler_at_final": True}),
    # Few-step schedulers.
    "LCM": (LCMScheduler, {}),
    "TCD": (TCDScheduler, {}),
    "LCM trailing": (LCMScheduler, {"timestep_spacing": "trailing"}),
    "TCD trailing": (TCDScheduler, {"timestep_spacing": "trailing"}),
    # Same scheduler settings as "LCM"/"TCD"; presumably the name itself is
    # checked elsewhere to trigger extra loading — verify against callers.
    "LCM Auto-Loader": (LCMScheduler, {}),
    "TCD Auto-Loader": (TCDScheduler, {}),
}
|
|
|
|
|
|
def get_scheduler_config(name: str):
    """Look up the (scheduler class, config overrides) pair for a sampler name.

    Names not present in SCHEDULER_CONFIG_MAP fall back to the "Euler a" entry.
    """
    if name in SCHEDULER_CONFIG_MAP:
        return SCHEDULER_CONFIG_MAP[name]
    return SCHEDULER_CONFIG_MAP["Euler a"]
|
|
|
|
|
|
def fuse_loras(pipe, lora_dict: dict, temp_dir: str, civitai_key: str = "", dkwargs: "dict | None" = None):
    """Fetch, load and fuse a set of LoRAs into ``pipe``, then drop the adapters.

    Args:
        pipe: pipeline object providing ``load_lora_weights``, ``set_adapters``,
            ``fuse_lora`` and ``unload_lora_weights``.
        lora_dict: mapping of LoRA identifier (passed to ``get_download_file``)
            to its adapter weight. Empty/falsy identifiers are skipped.
        temp_dir: directory passed to ``get_download_file``.
        civitai_key: API key passed to ``get_download_file``.
        dkwargs: extra keyword arguments forwarded to ``pipe.load_lora_weights``.
            Defaults to no extra arguments.

    Returns:
        ``pipe`` itself. It is returned unchanged when ``lora_dict`` is empty or
        not a dict, or when no LoRA file could be obtained.
    """
    if not lora_dict or not isinstance(lora_dict, dict):
        return pipe
    # None default instead of a shared mutable `{}` default.
    if dkwargs is None:
        dkwargs = {}

    adapter_names = []
    adapter_weights = []
    for lora_id, weight in lora_dict.items():
        if not lora_id:
            continue
        new_lora_file = get_download_file(temp_dir, lora_id, civitai_key)
        if not new_lora_file or not Path(new_lora_file).exists():
            print(f"LoRA file not found: {lora_id}")
            continue
        lora_path = Path(new_lora_file)
        adapter_name = lora_path.stem
        try:
            pipe.load_lora_weights(new_lora_file, weight_name=lora_path.name, adapter_name=adapter_name, low_cpu_mem_usage=False, **dkwargs)
        finally:
            # Remove the local copy even when loading fails, so failed loads
            # don't leave files behind in temp_dir.
            # NOTE(review): if get_download_file can return a caller-supplied
            # local path unchanged, this deletes the caller's file — confirm.
            if lora_path.exists():
                lora_path.unlink()
        adapter_names.append(adapter_name)
        adapter_weights.append(weight)

    if not adapter_names:
        return pipe

    # Activate all loaded adapters with their weights, bake them into the base
    # weights, then unload the now-redundant adapter modules.
    pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
    pipe.fuse_lora(adapter_names=adapter_names, lora_scale=1.0)
    pipe.unload_lora_weights()
    return pipe
|
|
|
|
|
|
# Marker tensor name -> model family label. get_model_type_from_key walks this
# mapping in insertion order and returns the first label whose marker key is
# present, so the order of these pairs is significant.
MODEL_TYPE_KEY = dict([
    ("model.diffusion_model.output_blocks.1.1.norm.bias", "SDXL"),
    ("model.diffusion_model.input_blocks.11.0.out_layers.3.weight", "SD 1.5"),
    # FLUX keys appear both with and without the "model.diffusion_model." prefix.
    ("double_blocks.0.img_attn.norm.key_norm.scale", "FLUX"),
    ("model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", "FLUX"),
    ("model.diffusion_model.joint_blocks.9.x_block.attn.ln_k.weight", "SD 3.5"),
])
|
|
|
|
|
|
def get_model_type_from_key(path: str):
    """Guess a checkpoint's model family from the tensor key names it contains.

    Reads the key names with ``read_safetensors_key(path)`` and checks the
    marker keys of MODEL_TYPE_KEY in insertion order. The label of the first
    marker present is returned.

    Returns:
        The matched label. If nothing matches, or the keys cannot be read,
        returns "SDXL".
    """
    default = "SDXL"
    try:
        # Build the membership set once. The previous version rebuilt it for
        # every marker key, which cost O(len(keys)) per probe.
        keys = set(read_safetensors_key(path))
    except Exception as e:
        # Best-effort detection: fall back to the default, but say why.
        print(f"Failed to read model keys, assuming {default}: {e}")
        return default
    for marker, model_type in MODEL_TYPE_KEY.items():
        if marker in keys:
            print(f"Model type is {model_type}.")
            return model_type
    print("Model type could not be identified.")
    return default
|
|
|
|
|
|
def get_process_dtype(dtype: str, model_type: str):
    """Pick the torch dtype used for processing a model loaded as ``dtype``.

    For "fp8" and quantized (QTYPES) options the result depends on the model
    family: torch.bfloat16 for "FLUX" and "SD 3.5", torch.float16 otherwise.
    Other names resolve through DTYPE_DICT, falling back to torch.float16.
    """
    reduced_precision = {"fp8", *QTYPES}
    if dtype not in reduced_precision:
        return DTYPE_DICT.get(dtype, torch.float16)
    if model_type in ("FLUX", "SD 3.5"):
        return torch.bfloat16
    return torch.float16
|
|
|