|
import json
|
|
import torch
|
|
from safetensors.torch import load_file, save_file
|
|
from pathlib import Path
|
|
import gc
|
|
import gguf
|
|
from dequant import dequantize_tensor
|
|
|
|
import os
|
|
import argparse
|
|
import gradio as gr
|
|
|
|
|
|
import subprocess
|
|
# Purge pip's cache once at import time; the command's exit status is not checked.
subprocess.run('pip cache purge', shell=True)
|
|
|
|
import spaces
|
|
@spaces.GPU()
def spaces_dummy():
    """Do nothing; a placeholder function wrapped by ``spaces.GPU``.

    NOTE(review): presumably present only so the hosting runtime finds at
    least one GPU-decorated function -- confirm before removing.
    """
    return None
|
|
|
|
# Repositories that supply the non-transformer FLUX.1 components.
flux_dev_repo = "ChuckMcSneed/FLUX.1-dev"
flux_schnell_repo = "black-forest-labs/FLUX.1-schnell"
# Scratch directory used while downloading and copying components.
system_temp_dir = "temp"

# Use the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Gradient tracking is switched off globally for the whole process.
torch.set_grad_enabled(False)
|
|
|
|
# GGUF tensor types that ``dequant_tensor`` routes to ``dequantize_tensor``.
GGUF_QTYPE = [
    getattr(gguf.GGMLQuantizationType, type_name)
    for type_name in ("Q8_0", "Q5_1", "Q5_0", "Q4_1", "Q4_0", "F32", "F16")
]
|
|
|
|
TORCH_DTYPE = [torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.half,
|
|
torch.bfloat16, torch.complex32, torch.chalf, torch.complex64, torch.cfloat,
|
|
torch.complex128, torch.cdouble, torch.uint8, torch.uint16, torch.uint32, torch.uint64,
|
|
torch.int8, torch.int16, torch.short, torch.int32, torch.int, torch.int64, torch.long,
|
|
torch.bool, torch.float8_e4m3fn, torch.float8_e5m2]
|
|
|
|
TORCH_QUANTIZED_DTYPE = [torch.quint8, torch.qint8, torch.qint32, torch.quint4x2]
|
|
|
|
def list_sub(a, b):
    """Return the items of ``a`` that do not occur in ``b``.

    Order and duplicates of ``a`` are kept; membership is tested with ``in``.
    """
    remaining = []
    for item in a:
        if item in b:
            continue
        remaining.append(item)
    return remaining
|
|
|
|
def is_repo_name(s):
    """Check whether ``s`` has the shape of an ``owner/name`` repository id.

    Returns the ``re.Match`` object when ``s`` is exactly two non-empty
    segments joined by one ``/`` and contains no whitespace, comma or quote
    characters; returns ``None`` otherwise.
    """
    import re

    repo_pattern = re.compile(r'^[^/,\s\"\']+/[^/,\s\"\']+$')
    return repo_pattern.fullmatch(s)
|
|
|
|
def clear_cache():
    """Run the garbage collector, then release cached CUDA memory.

    The collector goes first so that tensors kept alive only by reference
    cycles are freed before the CUDA caching allocator is asked to give its
    unused blocks back.
    """
    gc.collect()
    torch.cuda.empty_cache()
|
|
|
|
def clear_sd(sd: dict):
|
|
for k in list(sd.keys()):
|
|
sd.pop(k)
|
|
del sd
|
|
torch.cuda.empty_cache()
|
|
gc.collect()
|
|
|
|
def clone_sd(sd: dict):
    """Replace every value of ``sd`` with a deep copy of itself, in place.

    Each entry is popped and re-inserted under the same key, so the mapping
    keeps its keys and their order. Nothing is returned.
    """
    from copy import deepcopy

    print("Cloning state dict.")
    # Snapshot the keys first: the dict is modified while walking them.
    for key in tuple(sd):
        sd[key] = deepcopy(sd.pop(key))

    torch.cuda.empty_cache()
    gc.collect()
|
|
|
|
def print_resource_usage():
    """Print the current CPU and RAM usage percentages reported by psutil."""
    import psutil

    cpu_usage = psutil.cpu_percent()
    # Take a single snapshot so ``used`` and ``total`` describe the same moment.
    memory = psutil.virtual_memory()
    ram_usage = memory.used / memory.total * 100
    print(f"CPU usage: {cpu_usage}% / RAM usage: {ram_usage}%")
|
|
|
|
def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)):
    """Download ``url`` into ``directory`` with ``gdown`` or ``aria2c``.

    The downloader is chosen from the host found in ``url``:

    * ``drive.google.com``: ``gdown --fuzzy`` run inside ``directory``.
    * ``huggingface.co``: ``aria2c``; ``?download=true`` is dropped and
      ``/blob/`` links are rewritten to ``/resolve/``.
    * ``civitai.com``: ``aria2c`` with ``?token=<civitai_api_key>`` appended
      after any existing query string is removed; skipped with a message when
      no key is given.
    * anything else: ``aria2c`` with the URL as-is.

    Commands are run as argument lists without a shell, so characters such as
    ``&`` or ``;`` in a user-supplied URL are passed through literally instead
    of being interpreted.

    Args:
        directory: Destination directory.
        url: Source URL; surrounding whitespace is stripped.
        civitai_api_key: API token used only for civitai.com URLs.
        progress: Gradio progress reporter.

    Returns:
        None. The downloader's exit status is not inspected; a downloader that
        cannot be started is reported with ``print``.
    """
    def run(command, cwd=None):
        # Best-effort, like the rest of this function: report and carry on.
        try:
            subprocess.run(command, cwd=cwd, check=False)
        except OSError as e:
            print(f"Failed to run {command[0]}: {e}")

    aria2c_opts = ["--console-log-level=error", "--summary-interval=10",
                   "-c", "-x", "16", "-k", "1M", "-s", "16"]
    directory = str(directory)

    progress(0, desc="Start downloading...")
    url = url.strip()
    if "drive.google.com" in url:
        # ``cwd`` keeps the working-directory change local to the child process.
        run(["gdown", "--fuzzy", url], cwd=directory)
    elif "huggingface.co" in url:
        url = url.replace("?download=true", "")
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
            run(["aria2c", *aria2c_opts, url, "-d", directory, "-o", url.split("/")[-1]])
        else:
            run(["aria2c", "--optimize-concurrent-downloads", *aria2c_opts,
                 url, "-d", directory, "-o", url.split("/")[-1]])
    elif "civitai.com" in url:
        if "?" in url:
            url = url.split("?")[0]
        if civitai_api_key:
            url = url + f"?token={civitai_api_key}"
            run(["aria2c", *aria2c_opts, "-d", directory, url])
        else:
            print("You need an API key to download Civitai models.")
    else:
        run(["aria2c", *aria2c_opts, "-d", directory, url])
|
|
|
|
def get_local_model_list(dir_path):
    """List the ``.safetensors`` entries directly inside ``dir_path``.

    Args:
        dir_path: Directory to scan (not recursive).

    Returns:
        A list of path strings, in the order the directory listing yields them.
    """
    # A real one-element tuple: with a bare string, ``suffix in ...`` becomes a
    # substring test and also accepts entries with no extension at all.
    valid_extensions = (".safetensors",)
    return [
        str(entry)
        for entry in Path(dir_path).glob("*")
        if entry.suffix in valid_extensions
    ]
|
|
|
|
def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    """Resolve ``url`` to a usable model location, downloading it if needed.

    Resolution order:

    1. not an ``http`` string, shaped like ``owner/name`` and not an existing
       path: treated as a Hub repo id and returned unchanged;
    2. not an ``http`` string and an existing path: returned unchanged;
    3. a file named like the last URL segment already sits in ``temp_dir``:
       that path is returned;
    4. otherwise the URL is downloaded into ``temp_dir`` and the first newly
       appeared model file is returned.

    Args:
        temp_dir: Directory for downloads and the cache lookup.
        url: Repo id, local path or download URL.
        civitai_key: API key forwarded to ``download_thing``.
        progress: Gradio progress reporter.

    Returns:
        The resolved repo id or path, or ``""`` when the download failed.
    """
    is_remote = "http" in url
    file_name = url.split('/')[-1]
    cached = Path(f"{temp_dir}/{file_name}")

    if not is_remote and is_repo_name(url) and not Path(url).exists():
        print(f"Use HF Repo: {url}")
        new_file = url
    elif not is_remote and Path(url).exists():
        print(f"Use local file: {url}")
        new_file = url
    elif cached.is_file():
        # ``is_file`` rather than ``exists``: a URL ending in "/" has an empty
        # last segment, which would otherwise match ``temp_dir`` itself.
        print(f"File to download alreday exists: {url}")
        new_file = f"{temp_dir}/{file_name}"
    else:
        print(f"Start downloading: {url}")
        before = get_local_model_list(temp_dir)
        try:
            download_thing(temp_dir, url.strip(), civitai_key)
        except Exception:
            print(f"Download failed: {url}")
            return ""
        after = get_local_model_list(temp_dir)
        added = list_sub(after, before)
        new_file = added[0] if added else ""

    if not new_file:
        print(f"Download failed: {url}")
        return ""
    print(f"Download completed: {url}")
    return new_file
|
|
|
|
def save_readme_md(dir, url):
    """Write a model-card ``README.md`` into ``dir``.

    The card always contains the same YAML front matter. When ``url`` contains
    ``http`` a "Converted from" line linking to it is appended.

    Args:
        dir: Directory that receives ``README.md`` (overwritten if present).
        url: Source of the converted model; only used when it is a URL.
    """
    front_matter = (
        "---\n"
        "license: other\n"
        "license_name: flux-1-dev-non-commercial-license\n"
        "license_link: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md\n"
        "language:\n"
        "- en\n"
        "library_name: diffusers\n"
        "pipeline_tag: text-to-image\n"
        "tags:\n"
        "- text-to-image\n"
        "- Flux\n"
        "---\n"
    )
    md = front_matter
    if "http" in url:
        md += f"Converted from [{url}]({url}).\n"

    path = str(Path(dir, "README.md"))
    with open(path, mode='w', encoding="utf-8") as f:
        f.write(md)
|
|
|
|
def is_repo_exists(repo_id):
    """Report whether ``repo_id`` exists on the Hugging Face Hub.

    Returns ``True`` or ``False`` from the Hub lookup. If the lookup itself
    raises, an error line is printed and ``True`` is returned.
    """
    from huggingface_hub import HfApi

    hub = HfApi()
    try:
        found = hub.repo_exists(repo_id=repo_id)
    except Exception:
        print(f"Error: Failed to connect {repo_id}. ")
        return True
    return True if found else False
|
|
|
|
def create_diffusers_repo(new_repo_id, diffusers_folder, is_private, is_overwrite, progress=gr.Progress(track_tqdm=True)):
    """Create a Hub repo and upload the top-level contents of a folder to it.

    Every direct child of ``diffusers_folder`` is uploaded under its own name:
    directories with ``upload_folder``, files with ``upload_file``. The token
    is read from the ``HF_TOKEN`` environment variable.

    Args:
        new_repo_id: Target repository id.
        diffusers_folder: Local folder whose children are uploaded.
        is_private: Passed as ``private`` to ``create_repo``.
        is_overwrite: Passed as ``exist_ok`` to ``create_repo``.
        progress: Gradio progress reporter.

    Returns:
        The repository URL, or ``""`` if any step raised (the error is printed).
    """
    from huggingface_hub import HfApi

    hf_token = os.environ.get("HF_TOKEN")
    api = HfApi()
    try:
        progress(0, desc="Start uploading...")
        api.create_repo(repo_id=new_repo_id, token=hf_token, private=is_private, exist_ok=is_overwrite)
        for path in Path(diffusers_folder).glob("*"):
            if path.is_dir():
                api.upload_folder(repo_id=new_repo_id, folder_path=str(path), path_in_repo=path.name, token=hf_token)
            elif path.is_file():
                api.upload_file(repo_id=new_repo_id, path_or_fileobj=str(path), path_in_repo=path.name, token=hf_token)
        progress(1, desc="Uploaded.")
    except Exception as e:
        print(f"Error: Failed to upload to {new_repo_id}. ")
        print(e)
        return ""
    return f"https://huggingface.co/{new_repo_id}"
|
|
|
|
|
|
|
|
|
|
# The ``with`` wraps only the definition: the context managers are active while
# the function is created (and scripted), not while it is later called.
with torch.no_grad(), torch.autocast(device):
    @torch.jit.script
    def swap_scale_shift(weight):
        """Swap the two halves of ``weight`` along dimension 0.

        The tensor is split into two chunks on dim 0 and re-joined with the
        second chunk first.
        """
        first_half, second_half = weight.chunk(2, dim=0)
        return torch.cat([second_half, first_half], dim=0)
|
|
|
|
# The ``with`` wraps only the definition: the context managers are active while
# the function is created, not while it is later called.
with torch.no_grad(), torch.autocast(device):
    def convert_flux_transformer_checkpoint_to_diffusers(
            original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0,
            progress=gr.Progress(track_tqdm=True)):
        """Rename and split FLUX.1 transformer weights into Diffusers keys.

        Entries are *moved*: each converted key is popped from
        ``original_state_dict``, which is therefore consumed in place. Every
        source key is optional, so a partial state dict (for example a single
        shard) converts whatever it contains. Fused tensors are handled
        specially:

        * ``double_blocks.{i}.{img,txt}_attn.qkv`` is cut into three equal
          chunks on dim 0 (q, k, v);
        * ``single_blocks.{i}.linear1`` is split on dim 0 into q, k, v of size
          ``inner_dim`` and an MLP part of size ``int(inner_dim * mlp_ratio)``;
        * ``final_layer.adaLN_modulation.1`` has its two halves swapped.

        A block whose fused keys are only partly present is reported as a
        "Key error" and left unconverted; a block with none of them is skipped
        silently.

        Args:
            original_state_dict: Source mapping of key to tensor; mutated.
            num_layers: Number of ``double_blocks`` to look for.
            num_single_layers: Number of ``single_blocks`` to look for.
            inner_dim: Size of each q/k/v slice in the single blocks.
            mlp_ratio: Multiplier giving the MLP slice size in single blocks.
            progress: Gradio progress reporter.

        Returns:
            A new dict holding the converted entries.
        """
        converted_state_dict = {}

        def conv(ckey: str, okey: str):
            # Move one tensor to its new key if the source key exists.
            if okey in original_state_dict:
                progress(0, desc=f"Converting {okey} => {ckey}")
                print(f"Converting {okey} => {ckey}")
                converted_state_dict[ckey] = original_state_dict.pop(okey)
                gc.collect()

        def convswap(ckey: str, okey: str):
            # Like ``conv`` but with the two halves of the tensor swapped.
            if okey in original_state_dict:
                progress(0, desc=f"Converting (swap) {okey} => {ckey}")
                print(f"Converting {okey} => {ckey} (swap)")
                converted_state_dict[ckey] = swap_scale_shift(original_state_dict.pop(okey))
                gc.collect()

        def conv_pairs(dst_prefix: str, src_prefix: str, pairs):
            # Apply ``conv`` to each (diffusers suffix, original suffix) pair.
            for dst_suffix, src_suffix in pairs:
                conv(f"{dst_prefix}{dst_suffix}", f"{src_prefix}{src_suffix}")

        def conv_embedder(dst_name: str, src_name: str):
            # Two-layer embedder: in_layer/out_layer become linear_1/linear_2.
            for dst_layer, src_layer in (("linear_1", "in_layer"), ("linear_2", "out_layer")):
                for param in ("weight", "bias"):
                    conv(f"time_text_embed.{dst_name}.{dst_layer}.{param}",
                         f"{src_name}.{src_layer}.{param}")

        def convqkv(i: int, prefix: str):
            # Split the fused image/text QKV tensors of one double block.
            src = f"double_blocks.{i}."
            required = (f"{src}img_attn.qkv.weight", f"{src}txt_attn.qkv.weight",
                        f"{src}img_attn.qkv.bias", f"{src}txt_attn.qkv.bias")
            found = [key for key in required if key in original_state_dict]
            if not found:
                # Block absent from this state dict: nothing to do, as in ``conv``.
                return
            if len(found) < len(required):
                progress(0, desc=f"Key error in converting Q, K, V (double_blocks.{i}).")
                print(f"Key error in converting Q, K, V (double_blocks.{i}).")
                return
            progress(0, desc=f"Converting Q, K, V (double_blocks.{i}).")
            print(f"Converting Q, K, V (double_blocks.{i}).")
            img_w, txt_w, img_b, txt_b = (
                torch.chunk(original_state_dict.pop(key), 3, dim=0) for key in required
            )
            # ``torch.cat`` on a single chunk is kept from the original code --
            # presumably to give each slice its own storage; TODO confirm.
            for proj, weight, bias in zip(("to_q", "to_k", "to_v"), img_w, img_b, strict=True):
                converted_state_dict[f"{prefix}attn.{proj}.weight"] = torch.cat([weight])
                converted_state_dict[f"{prefix}attn.{proj}.bias"] = torch.cat([bias])
            for proj, weight, bias in zip(("add_q_proj", "add_k_proj", "add_v_proj"),
                                          txt_w, txt_b, strict=True):
                converted_state_dict[f"{prefix}attn.{proj}.weight"] = torch.cat([weight])
                converted_state_dict[f"{prefix}attn.{proj}.bias"] = torch.cat([bias])
            gc.collect()

        def convqkvmlp(i: int, prefix: str):
            # Split the fused q/k/v/mlp projection of one single block.
            weight_key = f"single_blocks.{i}.linear1.weight"
            bias_key = f"single_blocks.{i}.linear1.bias"
            has_weight = weight_key in original_state_dict
            has_bias = bias_key in original_state_dict
            if not (has_weight or has_bias):
                # Block absent from this state dict: nothing to do.
                return
            if not (has_weight and has_bias):
                progress(0, desc=f"Key error in converting Q, K, V, mlp (single_blocks.{i}).")
                print(f"Key error in converting Q, K, V, mlp (single_blocks.{i}).")
                return
            progress(0, desc=f"Converting Q, K, V, mlp (single_blocks.{i}).")
            print(f"Converting Q, K, V, mlp (single_blocks.{i}).")
            mlp_hidden_dim = int(inner_dim * mlp_ratio)
            split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
            weights = torch.split(original_state_dict.pop(weight_key), split_size, dim=0)
            biases = torch.split(original_state_dict.pop(bias_key), split_size, dim=0)
            targets = ("attn.to_q", "attn.to_k", "attn.to_v", "proj_mlp")
            for target, weight, bias in zip(targets, weights, biases, strict=True):
                converted_state_dict[f"{prefix}{target}.weight"] = torch.cat([weight])
                converted_state_dict[f"{prefix}{target}.bias"] = torch.cat([bias])
            gc.collect()

        progress(0, desc="Converting FLUX.1 state dict to Diffusers format.")

        # Timestep and pooled-text embedders.
        conv_embedder("timestep_embedder", "time_in")
        conv_embedder("text_embedder", "vector_in")

        # Guidance embedder, only when the checkpoint has guidance keys.
        has_guidance = any("guidance" in k for k in original_state_dict)
        if has_guidance:
            conv_embedder("guidance_embedder", "guidance_in")

        # Context (text) and image input projections.
        conv("context_embedder.weight", "txt_in.weight")
        conv("context_embedder.bias", "txt_in.bias")
        conv("x_embedder.weight", "img_in.weight")
        conv("x_embedder.bias", "img_in.bias")

        progress(0.25, desc="Converting FLUX.1 state dict to Diffusers format.")

        double_before_qkv = (
            ("norm1.linear.weight", "img_mod.lin.weight"),
            ("norm1.linear.bias", "img_mod.lin.bias"),
            ("norm1_context.linear.weight", "txt_mod.lin.weight"),
            ("norm1_context.linear.bias", "txt_mod.lin.bias"),
        )
        double_after_qkv = (
            ("attn.norm_q.weight", "img_attn.norm.query_norm.scale"),
            ("attn.norm_k.weight", "img_attn.norm.key_norm.scale"),
            ("attn.norm_added_q.weight", "txt_attn.norm.query_norm.scale"),
            ("attn.norm_added_k.weight", "txt_attn.norm.key_norm.scale"),
            ("ff.net.0.proj.weight", "img_mlp.0.weight"),
            ("ff.net.0.proj.bias", "img_mlp.0.bias"),
            ("ff.net.2.weight", "img_mlp.2.weight"),
            ("ff.net.2.bias", "img_mlp.2.bias"),
            ("ff_context.net.0.proj.weight", "txt_mlp.0.weight"),
            ("ff_context.net.0.proj.bias", "txt_mlp.0.bias"),
            ("ff_context.net.2.weight", "txt_mlp.2.weight"),
            ("ff_context.net.2.bias", "txt_mlp.2.bias"),
            ("attn.to_out.0.weight", "img_attn.proj.weight"),
            ("attn.to_out.0.bias", "img_attn.proj.bias"),
            ("attn.to_add_out.weight", "txt_attn.proj.weight"),
            ("attn.to_add_out.bias", "txt_attn.proj.bias"),
        )
        for i in range(num_layers):
            block_prefix = f"transformer_blocks.{i}."
            source_prefix = f"double_blocks.{i}."
            conv_pairs(block_prefix, source_prefix, double_before_qkv)
            convqkv(i, block_prefix)
            conv_pairs(block_prefix, source_prefix, double_after_qkv)

        progress(0.5, desc="Converting FLUX.1 state dict to Diffusers format.")

        single_before_qkv = (
            ("norm.linear.weight", "modulation.lin.weight"),
            ("norm.linear.bias", "modulation.lin.bias"),
        )
        single_after_qkv = (
            ("attn.norm_q.weight", "norm.query_norm.scale"),
            ("attn.norm_k.weight", "norm.key_norm.scale"),
            ("proj_out.weight", "linear2.weight"),
            ("proj_out.bias", "linear2.bias"),
        )
        for i in range(num_single_layers):
            block_prefix = f"single_transformer_blocks.{i}."
            source_prefix = f"single_blocks.{i}."
            conv_pairs(block_prefix, source_prefix, single_before_qkv)
            convqkvmlp(i, block_prefix)
            conv_pairs(block_prefix, source_prefix, single_after_qkv)

        progress(0.75, desc="Converting FLUX.1 state dict to Diffusers format.")
        conv("proj_out.weight", "final_layer.linear.weight")
        conv("proj_out.bias", "final_layer.linear.bias")
        convswap("norm_out.linear.weight", "final_layer.adaLN_modulation.1.weight")
        convswap("norm_out.linear.bias", "final_layer.adaLN_modulation.1.bias")

        progress(1, desc="Converting FLUX.1 state dict to Diffusers format.")
        return converted_state_dict
|
|
|
|
|
|
def read_safetensors_metadata(path):
    """Return a copy of the ``__metadata__`` entry of a safetensors header.

    The file is read as an 8-byte little-endian length followed by that many
    bytes of UTF-8 JSON. An empty dict is returned when the header has no
    ``__metadata__`` key.
    """
    with open(path, 'rb') as stream:
        length_prefix = stream.read(8)
        raw_header = stream.read(int.from_bytes(length_prefix, 'little'))
        parsed = json.loads(raw_header.decode('utf-8'))
    user_metadata = parsed.get('__metadata__', {})
    return user_metadata.copy()
|
|
|
|
def normalize_key(k: str):
    """Remove the known component prefixes from a checkpoint key.

    Each marker is removed wherever it occurs in the key, not only at the
    start, and the markers are applied in the order listed.
    """
    markers = (
        "vae.",
        "model.diffusion_model.",
        "text_encoders.clip_l.transformer.",
        "text_encoders.t5xxl.transformer.",
    )
    for marker in markers:
        k = k.replace(marker, "")
    return k
|
|
|
|
def load_json_list(path: str):
    """Load a JSON file and return its content converted with ``list()``.

    Any failure (missing file, invalid JSON, non-iterable content) is printed
    and an empty list is returned instead.
    """
    try:
        with open(path, encoding='utf-8') as f:
            loaded = json.load(f)
        return list(loaded)
    except Exception as err:
        print(err)
    return []
|
|
|
|
|
|
|
|
|
|
# The ``with`` wraps only the definition, not later calls of the function.
with torch.no_grad():
    def to_safetensors(sd: dict, path: str, pattern: str, size: str, progress=gr.Progress(track_tqdm=True)):
        """Move ``sd`` to the CPU and save it, sharded, under ``path``.

        Args:
            sd: State dict; its tensors are replaced in place by CPU tensors.
            path: Output directory, created if missing.
            pattern: Shard file-name pattern for ``save_torch_state_dict``.
            size: Maximum shard size for ``save_torch_state_dict``.
            progress: Gradio progress reporter (not used here).

        Errors raised while moving or saving are printed, not propagated.
        """
        from huggingface_hub import save_torch_state_dict

        print(f"Saving a temporary file to disk: {path}")
        os.makedirs(path, exist_ok=True)
        try:
            for name in sd:
                sd[name] = sd[name].to(device="cpu")
            save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
        except Exception as err:
            print(err)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The ``with`` wraps only the definition, not later calls of the function.
with torch.no_grad():
    def to_safetensors_flux_module(sd: dict, path: str, pattern: str, size: str,
                                   quantization: bool=False, name: str = "",
                                   metadata: dict | None = None, progress=gr.Progress(track_tqdm=True)):
        """Save one FLUX.1 component state dict as sharded safetensors.

        Args:
            sd: State dict; its tensors are replaced in place by CPU tensors.
            path: Output directory, created if missing.
            pattern: Shard file-name pattern for ``save_torch_state_dict``.
            size: Maximum shard size for ``save_torch_state_dict``.
            quantization: When true, only progress messages are emitted (see
                the note in the body).
            name: Component name used in progress and log messages.
            metadata: Optional metadata forwarded to ``save_torch_state_dict``.
            progress: Gradio progress reporter.

        Any exception is printed rather than propagated; the garbage collector
        always runs before returning.
        """
        from huggingface_hub import save_torch_state_dict

        try:
            progress(0, desc=f"Preparing to save FLUX.1 {name} to Diffusers format.")
            print(f"Preparing to save FLUX.1 {name} to Diffusers format.")
            for k, v in sd.items():
                sd[k] = v.to(device="cpu")
            progress(0, desc=f"Loading FLUX.1 {name}.")
            print(f"Loading FLUX.1 {name}.")
            os.makedirs(path, exist_ok=True)
            if quantization:
                # NOTE(review): no file is written on this path even though the
                # messages say "Saving"/"Saved" -- quantized export looks
                # unimplemented; confirm the intended behaviour.
                progress(0.5, desc=f"Saving quantized FLUX.1 {name} to {path}")
                print(f"Saving quantized FLUX.1 {name} to {path}")
            else:
                progress(0.5, desc=f"Saving FLUX.1 {name} to: {path}")
                print(f"Saving FLUX.1 {name} to: {path}")
                save_kwargs = {}
                if metadata is not None:
                    progress(0.5, desc=f"Saving FLUX.1 {name} metadata to: {path}")
                    save_kwargs["metadata"] = metadata
                save_torch_state_dict(state_dict=sd, save_directory=path,
                                      filename_pattern=pattern, max_shard_size=size,
                                      **save_kwargs)
            progress(1, desc=f"Saved FLUX.1 {name} to: {path}")
            print(f"Saved FLUX.1 {name} to: {path}")
        except Exception as e:
            print(e)
        finally:
            gc.collect()
|
|
|
|
# File names of the JSON key lists, one per FLUX.1 component.
flux_transformer_json = "flux_transformer_keys.json"
flux_t5xxl_json = "flux_t5xxl_keys.json"
flux_clip_json = "flux_clip_keys.json"
flux_vae_json = "flux_vae_keys.json"
# Key sets loaded at import time; a file that cannot be read yields an empty
# set because load_json_list returns [] on any error.
keys_flux_t5xxl = set(load_json_list(flux_t5xxl_json))
keys_flux_transformer = set(load_json_list(flux_transformer_json))
keys_flux_clip = set(load_json_list(flux_clip_json))
keys_flux_vae = set(load_json_list(flux_vae_json))
|
|
|
|
# The ``with`` wraps only the definition, not later calls of the function.
with torch.no_grad():
    def dequant_tensor(v: torch.Tensor, dtype: torch.dtype, dequant: bool):
        """Cast ``v`` to ``dtype``, dequantizing first when requested.

        With ``dequant`` false the tensor is simply cast (or returned as-is if
        it already has ``dtype``). With ``dequant`` true:

        * a tensor whose dtype is listed in ``TORCH_DTYPE`` is cast;
        * otherwise, a tensor whose ``tensor_type`` attribute is listed in
          ``GGUF_QTYPE`` goes through ``dequantize_tensor``;
        * anything else goes through ``torch.dequantize`` and is then cast.

        NOTE(review): the dtype check runs before the GGUF check, so a GGUF
        tensor stored with an ordinary torch dtype is cast rather than
        dequantized -- confirm this order is intended.

        Returns:
            The converted tensor, or ``None`` if the conversion raised (the
            error is printed).
        """
        try:
            if not dequant:
                return v.to(dtype) if v.dtype != dtype else v
            # Plain torch tensors have no ``tensor_type``; treat it as absent
            # instead of failing on the attribute lookup.
            qtype = getattr(v, "tensor_type", None)
            if v.dtype in TORCH_DTYPE:
                return v.to(dtype) if v.dtype != dtype else v
            if qtype in GGUF_QTYPE:
                return dequantize_tensor(v, dtype)
            # Torch-quantized dtypes and anything unrecognised share this path.
            return torch.dequantize(v).to(dtype)
        except Exception as e:
            print(e)
            return None
|
|
|
|
# The ``with`` wraps only the definition, not later calls of the function.
with torch.no_grad():
    def normalize_flux_state_dict(path: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
                                  dequant: bool = False, progress=gr.Progress(track_tqdm=True)):
        """Rewrite a FLUX.1 safetensors file with normalized keys and dtype.

        Every tensor of ``path`` is re-keyed with ``normalize_key`` and passed
        through ``dequant_tensor``; the result is saved in ``savepath`` as
        ``<stem>_fixed<suffix>`` together with the source file's metadata.

        Args:
            path: Source safetensors file.
            savepath: Output directory, created if missing.
            dtype: Target dtype for the tensors.
            dequant: Forwarded to ``dequant_tensor``.
            progress: Gradio progress reporter.

        Returns:
            None. If the conversion loop raises, the error is printed and
            nothing is saved.
        """
        progress(0, desc=f"Loading and normalizing FLUX.1 safetensors: {path}")
        print(f"Loading and normalizing FLUX.1 safetensors: {path}")
        new_sd = dict()
        state_dict = load_file(path, device="cpu")
        try:
            # Snapshot the keys: entries are popped while walking them.
            for old_key in tuple(state_dict):
                tensor = state_dict.pop(old_key)
                new_key = normalize_key(old_key)
                print(f"{old_key} => {new_key}")
                new_sd[new_key] = dequant_tensor(tensor, dtype, dequant)
        except Exception as err:
            print(err)
            return
        finally:
            clear_sd(state_dict)

        source = Path(path)
        new_path = str(Path(savepath, f"{source.stem}_fixed{source.suffix}"))
        metadata = read_safetensors_metadata(path)
        progress(0.5, desc=f"Saving FLUX.1 safetensors: {new_path}")
        print(f"Saving FLUX.1 safetensors: {new_path}")
        os.makedirs(savepath, exist_ok=True)
        # Source metadata is merged last, so it overrides the default "format".
        save_file(new_sd, new_path, metadata={"format": "pt", **metadata})
        progress(1, desc=f"Saved FLUX.1 safetensors: {new_path}")
        print(f"Saved FLUX.1 safetensors: {new_path}")
        clear_sd(new_sd)
|
|
|
|
# The ``with`` wraps only the definition, not later calls of the function.
with torch.no_grad():
    def extract_norm_flux_module_sd(path: str, dtype: torch.dtype = torch.bfloat16,
                                    dequant: bool = False, name: str = "",
                                    keys: set | frozenset = frozenset(),
                                    progress=gr.Progress(track_tqdm=True)):
        """Load one component's tensors from a safetensors file, normalized.

        Only entries whose key is in ``keys`` are kept. Each kept key is
        rewritten with ``normalize_key`` and its tensor passed through
        ``dequant_tensor``.

        Args:
            path: Source safetensors file.
            dtype: Target dtype for the tensors.
            dequant: Forwarded to ``dequant_tensor``.
            name: Component name used in progress and log messages.
            keys: Source keys to keep; the empty default keeps nothing.
            progress: Gradio progress reporter.

        Returns:
            The new state dict, or ``None`` if processing raised (the error is
            printed). The "Normalized" message is emitted in both cases.
        """
        progress(0, desc=f"Loading and normalizing FLUX.1 {name} safetensors: {path}")
        print(f"Loading and normalizing FLUX.1 {name} safetensors: {path}")
        new_sd = dict()
        state_dict = load_file(path, device="cpu")
        try:
            # Drop everything that is not wanted before converting anything.
            for k in list(state_dict.keys()):
                if k not in keys:
                    state_dict.pop(k)
            gc.collect()
            # Only wanted keys remain, so no second membership test is needed.
            for k in list(state_dict.keys()):
                v = state_dict.pop(k)
                nk = normalize_key(k)
                progress(0.5, desc=f"{k} => {nk}")
                print(f"{k} => {nk}")
                new_sd[nk] = dequant_tensor(v, dtype, dequant)
        except Exception as e:
            print(e)
            return None
        finally:
            progress(1, desc=f"Normalized FLUX.1 {name} safetensors: {path}")
            print(f"Normalized FLUX.1 {name} safetensors: {path}")
            clear_sd(state_dict)
        return new_sd
|
|
|
|
# The ``with`` wraps only the definition, not later calls of the function.
with torch.no_grad():
    def convert_flux_transformer_sd_to_diffusers(sd: dict, progress=gr.Progress(track_tqdm=True)):
        """Convert a FLUX.1 transformer state dict with fixed model dimensions.

        Calls ``convert_flux_transformer_checkpoint_to_diffusers`` with 19
        double blocks, 38 single blocks, an inner dimension of 3072 and an MLP
        ratio of 4.0.

        Returns:
            The converted dict, or the ``sd`` that was passed in if the
            conversion raised (the error is printed). The "Converted" message
            is emitted in both cases.
        """
        progress(0, desc="Converting FLUX.1 state dict to Diffusers format.")
        print("Converting FLUX.1 state dict to Diffusers format.")
        layout = {"num_layers": 19, "num_single_layers": 38, "inner_dim": 3072}
        mlp_ratio = 4.0
        try:
            sd = convert_flux_transformer_checkpoint_to_diffusers(
                sd, layout["num_layers"], layout["num_single_layers"], layout["inner_dim"],
                mlp_ratio=mlp_ratio,
            )
        except Exception as err:
            print(err)
        finally:
            progress(1, desc="Converted FLUX.1 state dict to Diffusers format.")
            print("Converted FLUX.1 state dict to Diffusers format.")
            gc.collect()
        return sd
|
|
|
|
# The ``with`` wraps only the definition, not later calls of the function.
with torch.no_grad():
    def load_sharded_safetensors(path: str):
        """Merge every ``*.safetensors`` file in ``path`` into one dict.

        Files are visited in ``glob`` order; on a duplicate key the file loaded
        later wins. If loading raises, the error is printed and whatever was
        merged so far is returned.
        """
        import glob

        merged = {}
        try:
            for shard_path in glob.glob(f"{path}/*.safetensors"):
                shard = load_file(str(shard_path), device="cpu")
                for key in shard:
                    shard[key] = shard[key].to(device="cpu")
                merged.update(shard)
                clear_sd(shard)
        except Exception as err:
            print(err)
        return merged
|
|
|
|
|
|
# The ``with`` wraps only the definition, not later calls of the function.
with torch.no_grad():
    def convert_flux_transformer_sd_to_diffusers_sharded(sd: dict, path: str, pattern: str,
                                                         size: str, progress=gr.Progress(track_tqdm=True)):
        """Convert a transformer state dict shard by shard through the disk.

        ``sd`` is written to ``path`` as shards and emptied; each shard file is
        then loaded, converted with
        ``convert_flux_transformer_sd_to_diffusers`` and written back in place;
        finally all shards are reloaded and merged.

        Args:
            sd: Source state dict; emptied once it has been written to disk.
            path: Directory for the temporary shard files, created if missing.
            pattern: Shard file-name pattern for ``save_torch_state_dict``.
            size: Maximum shard size for ``save_torch_state_dict``.
            progress: Gradio progress reporter.

        Returns:
            The merged, converted state dict. If a step raises, the error is
            printed and ``sd`` is returned in whatever state it was left.
        """
        from huggingface_hub import save_torch_state_dict
        import glob

        try:
            progress(0, desc=f"Saving temporary files to disk: {path}")
            print(f"Saving temporary files to disk: {path}")
            os.makedirs(path, exist_ok=True)
            # ``keys_flux_transformer`` is already a set; test against it
            # directly instead of rebuilding a set for every key.
            for k, v in sd.items():
                if k in keys_flux_transformer:
                    sd[k] = v.to(device="cpu")
            save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
            clear_sd(sd)
            progress(0.25, desc=f"Saved temporary files to disk: {path}")
            print(f"Saved temporary files to disk: {path}")
            for filepath in glob.glob(f"{path}/*.safetensors"):
                progress(0.25, desc=f"Processing temporary files: {str(filepath)}")
                print(f"Processing temporary files: {str(filepath)}")
                sharded_sd = load_file(str(filepath), device="cpu")
                sharded_sd = convert_flux_transformer_sd_to_diffusers(sharded_sd)
                for k, v in sharded_sd.items():
                    sharded_sd[k] = v.to(device="cpu")
                save_file(sharded_sd, str(filepath))
                clear_sd(sharded_sd)
            print(f"Loading temporary files from disk: {path}")
            sd = load_sharded_safetensors(path)
            print(f"Loaded temporary files from disk: {path}")
        except Exception as e:
            print(e)
        return sd
|
|
|
|
# The ``with`` wraps only the definition, not later calls of the function.
with torch.no_grad():
    def extract_normalized_flux_state_dict_sharded(loadpath: str, dtype: torch.dtype,
                                                   dequant: bool, path: str, pattern: str, size: str,
                                                   progress=gr.Progress(track_tqdm=True)):
        """Extract and normalize the transformer weights shard by shard.

        The file at ``loadpath`` is loaded, written to ``path`` as shards and
        released; each shard is then reduced to the transformer keys with
        ``extract_norm_flux_module_sd`` and written back in place; finally all
        shards are reloaded and merged.

        Args:
            loadpath: Source safetensors file.
            dtype: Target dtype for the tensors.
            dequant: Forwarded to ``extract_norm_flux_module_sd``.
            path: Directory for the temporary shard files, created if missing.
            pattern: Shard file-name pattern for ``save_torch_state_dict``.
            size: Maximum shard size for ``save_torch_state_dict``.
            progress: Gradio progress reporter.

        Returns:
            The merged state dict. If a step raises, the error is printed and
            the dict as it stood at that point (possibly empty) is returned.
        """
        from huggingface_hub import save_torch_state_dict
        import glob

        # Bound up front so the final ``return`` is valid even when the very
        # first load fails.
        sd = {}
        try:
            progress(0, desc=f"Loading model file: {loadpath}")
            print(f"Loading model file: {loadpath}")
            sd = load_file(loadpath, device="cpu")
            progress(0, desc=f"Saving temporary files to disk: {path}")
            print(f"Saving temporary files to disk: {path}")
            os.makedirs(path, exist_ok=True)
            for k, v in sd.items():
                sd[k] = v.to(device="cpu")
            save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
            clear_sd(sd)
            progress(0.25, desc=f"Saved temporary files to disk: {path}")
            print(f"Saved temporary files to disk: {path}")
            for filepath in glob.glob(f"{path}/*.safetensors"):
                progress(0.25, desc=f"Processing temporary files: {str(filepath)}")
                print(f"Processing temporary files: {str(filepath)}")
                sharded_sd = extract_norm_flux_module_sd(str(filepath), dtype, dequant,
                                                         "Transformer", keys_flux_transformer)
                for k, v in sharded_sd.items():
                    sharded_sd[k] = v.to(device="cpu")
                save_file(sharded_sd, str(filepath))
                clear_sd(sharded_sd)
                print(f"Processed temporary files: {str(filepath)}")
            print(f"Loading temporary files from disk: {path}")
            sd = load_sharded_safetensors(path)
            print(f"Loaded temporary files from disk: {path}")
        except Exception as e:
            print(e)
        return sd
|
|
|
|
def download_repo(repo_name, path, use_original=["vae", "text_encoder"], progress=gr.Progress(track_tqdm=True)):
    """Download a snapshot of ``repo_name`` into ``path``, minus heavy files.

    Transformer weights, ``*.sft`` files, dotfiles, READMEs/markdown, index
    files and images are always skipped. The ``text_encoder_2`` model files
    are skipped too unless ``"text_encoder_2"`` is listed in ``use_original``.
    Errors are printed, not propagated.
    """
    from huggingface_hub import snapshot_download

    print(f"Downloading {repo_name}.")
    try:
        skipped = ["transformer/diffusion*.*"]
        if "text_encoder_2" not in use_original:
            skipped.append("text_encoder_2/model*.*")
        skipped += ["*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.png", "*.webp"]
        snapshot_download(repo_id=repo_name, local_dir=path, ignore_patterns=skipped)
    except Exception as err:
        print(err)
|
|
|
|
def copy_nontensor_files(from_path, to_path, use_original=["vae", "text_encoder"]):
    """Copy reusable component folders and all non-tensor files.

    Steps, in order:

    1. for each of ``text_encoder_2``, ``text_encoder`` and ``vae`` listed in
       ``use_original``: copy that folder, weights included;
    2. copy ``tokenizer_2`` unconditionally;
    3. copy the whole tree while skipping tensor, index, doc and image files.

    Dotfiles, READMEs, markdown and images are skipped in every step. Existing
    destination directories are merged into.
    """
    import shutil

    def skip_docs_and_images():
        return shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.png", "*.webp")

    def copy_folder(folder, label):
        src = str(Path(from_path, folder))
        dst = str(Path(to_path, folder))
        print(f"Copying {label} files {src} to {dst}")
        shutil.copytree(src, dst, ignore=skip_docs_and_images(), dirs_exist_ok=True)

    optional_components = (
        ("text_encoder_2", "Text Encoder 2"),
        ("text_encoder", "Text Encoder 1"),
        ("vae", "VAE"),
    )
    for folder, label in optional_components:
        if folder in use_original:
            copy_folder(folder, label)

    copy_folder("tokenizer_2", "Tokenizer 2")

    print(f"Copying non-tensor files {from_path} to {to_path}")
    skip_tensors = shutil.ignore_patterns(
        "*.safetensors", "*.bin", "*.sft", ".*", "README*", "*.md",
        "*.index", "*.jpg", "*.png", "*.webp", "*.index.json",
    )
    shutil.copytree(from_path, to_path, ignore=skip_tensors, dirs_exist_ok=True)
|
|
|
|
def save_flux_other_diffusers(path: str, model_type: str = "dev", use_original: list = ["vae", "text_encoder"], progress=gr.Progress(track_tqdm=True)):
    """Fetch the stock FLUX.1 components and copy them into ``path``.

    The schnell repo is used when ``model_type`` is ``"schnell"``, the dev repo
    otherwise. The repo is downloaded into ``system_temp_dir``, copied with
    ``copy_nontensor_files``, and ``system_temp_dir`` is then deleted.

    Args:
        path: Destination directory, created if missing.
        model_type: ``"schnell"`` or anything else for the dev repo.
        use_original: Forwarded to ``download_repo`` and ``copy_nontensor_files``.
        progress: Gradio progress reporter.
    """
    import shutil

    progress(0, desc="Loading FLUX.1 Components.")
    print("Loading FLUX.1 Components.")
    temppath = system_temp_dir
    repo = flux_schnell_repo if model_type == "schnell" else flux_dev_repo
    for folder in (temppath, path):
        os.makedirs(folder, exist_ok=True)
    download_repo(repo, temppath, use_original)
    progress(0.5, desc="Saving FLUX.1 Components.")
    print("Saving FLUX.1 Components.")
    copy_nontensor_files(temppath, path, use_original)
    shutil.rmtree(temppath)
|
|
|
|
# The ``with`` wraps only the definition, not later calls of the function.
with torch.no_grad():
    def fix_flux_safetensors(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
                             quantization: bool = False, model_type: str = "dev", dequant: bool = False):
        """Write a normalized copy of a FLUX.1 safetensors file plus components.

        The stock components for ``model_type`` are saved into ``savepath``
        first, then ``loadpath`` is normalized into the same directory and the
        caches are cleared. ``quantization`` is accepted but not used here.
        """
        save_flux_other_diffusers(path=savepath, model_type=model_type)
        normalize_flux_state_dict(path=loadpath, savepath=savepath, dtype=dtype, dequant=dequant)
        clear_cache()
|
|
|
|
# Applied as a decorator so that gradient tracking is disabled during every
# call, not merely while the ``def`` statement itself is executed.
@torch.no_grad()
def flux_to_diffusers_lowmem(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
                             quantization: bool = False, model_type: str = "dev",
                             dequant: bool = False, use_original: list = ["vae", "text_encoder"],
                             new_repo_id: str = "", local: bool = False, progress=gr.Progress(track_tqdm=True)):
    """Split the checkpoint at ``loadpath`` into per-module folders under ``savepath``.

    Modules are handled one after another (each state dict is cleared before
    the next is extracted): VAE, text encoder, text encoder 2 and finally the
    transformer, which is additionally passed through
    ``convert_flux_transformer_sd_to_diffusers``. Components listed in
    ``use_original`` are not extracted from the checkpoint; the list is handed
    to ``save_flux_other_diffusers`` at the end instead.

    Args:
        loadpath: Source checkpoint file. It is deleted after the transformer
            weights have been extracted unless ``local`` is true.
        savepath: Output root; a trailing ``/`` is ignored.
        dtype: Dtype used for text encoder 2 and the transformer. The VAE and
            the first text encoder are always extracted as ``torch.bfloat16``.
        quantization: Forwarded to ``to_safetensors_flux_module``.
        model_type: Forwarded to ``save_flux_other_diffusers``.
        dequant: Forwarded to ``extract_norm_flux_module_sd``.
        use_original: Names of components that are not taken from the checkpoint.
        new_repo_id: Accepted for signature compatibility; not used here.
        local: When true, ``loadpath`` is kept on disk.
        progress: Accepted for signature compatibility; not used here.
    """
    out_root = savepath.removesuffix("/")
    unet_sd_path = out_root + "/transformer"
    unet_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
    unet_sd_size = "9.5GB"

    # One row per auxiliary module:
    # (name checked against use_original, label, key list, dtype,
    #  output directory, file name pattern, size string)
    side_modules = (
        ("vae", "VAE", keys_flux_vae, torch.bfloat16,
         out_root + "/vae", "diffusion_pytorch_model{suffix}.safetensors", "9.5GB"),
        ("text_encoder", "Text Encoder", keys_flux_clip, torch.bfloat16,
         out_root + "/text_encoder", "model{suffix}.safetensors", "9.5GB"),
        ("text_encoder_2", "Text Encoder 2", keys_flux_t5xxl, dtype,
         out_root + "/text_encoder_2", "model{suffix}.safetensors", "5GB"),
    )

    print_resource_usage()
    # Metadata stored in the source file wins over the default "format" entry.
    metadata = {"format": "pt", **read_safetensors_metadata(loadpath)}
    clear_cache()
    print_resource_usage()

    for name, label, keys, module_dtype, module_path, module_pattern, module_size in side_modules:
        if name in use_original:
            continue
        module_sd = extract_norm_flux_module_sd(loadpath, module_dtype, dequant, label, keys)
        to_safetensors_flux_module(module_sd, module_path, module_pattern, module_size,
                                   quantization, label, None)
        clear_sd(module_sd)
        print_resource_usage()

    unet_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Transformer",
                                          keys_flux_transformer)
    clear_cache()
    print_resource_usage()
    if not local:
        # Everything needed has been read; remove the download before the
        # conversion step runs.
        os.remove(loadpath)
        print("Deleted downloaded file.")
    clear_cache()
    print_resource_usage()
    unet_sd = convert_flux_transformer_sd_to_diffusers(unet_sd)
    clear_cache()
    print_resource_usage()
    to_safetensors_flux_module(unet_sd, unet_sd_path, unet_sd_pattern, unet_sd_size,
                               quantization, "Transformer", metadata)
    clear_sd(unet_sd)
    print_resource_usage()
    save_flux_other_diffusers(savepath, model_type, use_original)
    print_resource_usage()
|
|
|
|
# Applied as a decorator so that gradient tracking is disabled during every
# call, not merely while the ``def`` statement itself is executed.
@torch.no_grad()
def flux_to_diffusers_lowmem2(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
                              quantization: bool = False, model_type: str = "dev",
                              dequant: bool = False, use_original: list = ["vae", "text_encoder"],
                              new_repo_id: str = "", progress=gr.Progress(track_tqdm=True)):
    """Variant of the low-memory conversion that uses the ``*_sharded`` helpers.

    The VAE, text encoder and text encoder 2 are exported the same way as the
    transformer-independent modules elsewhere in this file (skipping any name
    listed in ``use_original``). The transformer is extracted with
    ``extract_normalized_flux_state_dict_sharded`` and converted with
    ``convert_flux_transformer_sd_to_diffusers_sharded``, both of which are
    given a scratch path below ``system_temp_dir``.

    Args:
        loadpath: Source checkpoint file.
        savepath: Output root; a trailing ``/`` is ignored.
        dtype: Dtype used for text encoder 2 and the transformer. The VAE and
            the first text encoder are always extracted as ``torch.bfloat16``.
        quantization: Forwarded to ``to_safetensors_flux_module``.
        model_type: Forwarded to ``save_flux_other_diffusers``.
        dequant: Forwarded to the extraction helpers.
        use_original: Names of components that are not taken from the checkpoint.
        new_repo_id: Accepted for signature compatibility; not used here.
        progress: Accepted for signature compatibility; not used here.
    """
    out_root = savepath.removesuffix("/")
    unet_sd_path = out_root + "/transformer"
    unet_temp_path = system_temp_dir.removesuffix("/") + "/sharded"
    unet_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
    unet_sd_size = "10GB"
    unet_temp_size = "5GB"

    # One row per auxiliary module:
    # (name checked against use_original, label, key list, dtype,
    #  output directory, file name pattern, size string)
    side_modules = (
        ("vae", "VAE", keys_flux_vae, torch.bfloat16,
         out_root + "/vae", "diffusion_pytorch_model{suffix}.safetensors", "10GB"),
        ("text_encoder", "Text Encoder", keys_flux_clip, torch.bfloat16,
         out_root + "/text_encoder", "model{suffix}.safetensors", "10GB"),
        ("text_encoder_2", "Text Encoder 2", keys_flux_t5xxl, dtype,
         out_root + "/text_encoder_2", "model{suffix}.safetensors", "5GB"),
    )

    print_resource_usage()
    # Metadata stored in the source file wins over the default "format" entry.
    metadata = {"format": "pt", **read_safetensors_metadata(loadpath)}
    clear_cache()
    print_resource_usage()

    for name, label, keys, module_dtype, module_path, module_pattern, module_size in side_modules:
        if name in use_original:
            continue
        module_sd = extract_norm_flux_module_sd(loadpath, module_dtype, dequant, label, keys)
        to_safetensors_flux_module(module_sd, module_path, module_pattern, module_size,
                                   quantization, label, None)
        clear_sd(module_sd)
        print_resource_usage()

    unet_sd = extract_normalized_flux_state_dict_sharded(loadpath, dtype, dequant,
                                                         unet_temp_path, unet_sd_pattern, unet_temp_size)
    clear_cache()
    print_resource_usage()
    unet_sd = convert_flux_transformer_sd_to_diffusers_sharded(unet_sd, unet_temp_path,
                                                               unet_sd_pattern, unet_temp_size)
    clear_cache()
    print_resource_usage()
    to_safetensors_flux_module(unet_sd, unet_sd_path, unet_sd_pattern, unet_sd_size,
                               quantization, "Transformer", metadata)
    clear_sd(unet_sd)
    print_resource_usage()
    save_flux_other_diffusers(savepath, model_type, use_original)
    print_resource_usage()
|
|
|
|
def convert_url_to_diffusers_flux(url, civitai_key="", is_upload_sf=False, data_type="bf16",
                                  model_type="dev", dequant=False, use_original=["vae", "text_encoder"],
                                  hf_user="", hf_repo="", q=None, progress=gr.Progress(track_tqdm=True)):
    """Download a checkpoint from ``url`` and convert it with ``flux_to_diffusers_lowmem``.

    Args:
        url: Location handed to ``get_download_file``.
        civitai_key: API key handed to ``get_download_file``.
        is_upload_sf: Accepted for signature compatibility; not used here.
        data_type: ``"fp8"``, ``"fp16"`` or ``"qfloat8"``; anything else means
            bfloat16. ``"qfloat8"`` keeps bfloat16 and turns quantization on.
        model_type: Forwarded to ``flux_to_diffusers_lowmem``.
        dequant: Forwarded to ``flux_to_diffusers_lowmem``.
        use_original: Forwarded to ``flux_to_diffusers_lowmem``; never mutated here.
        hf_user: User part of the repo id forwarded to the converter.
        hf_repo: Repo part of the repo id; the output folder name is used when empty.
        q: Optional queue-like object. When given, exactly one result is put
            on it: the output folder name, or ``""`` on any failure, so that a
            process waiting on the queue never blocks indefinitely.
        progress: Progress callback, invoked with a fraction and ``desc``.

    Returns:
        The output folder name, or ``""`` when the download produced no file.
    """
    def report(result):
        # q is None when the function is called directly rather than as a
        # worker process target.
        if q is not None:
            q.put(result)

    progress(0, desc="Start converting...")
    temp_dir = "."
    try:
        print_resource_usage()
        new_file = get_download_file(temp_dir, url, civitai_key)
        if not new_file:
            print(f"Not found: {url}")
            report("")
            return ""
        new_repo_name = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_")

        dtype = {"fp8": torch.float8_e4m3fn, "fp16": torch.float16}.get(data_type, torch.bfloat16)
        quantization = data_type == "qfloat8"

        repo_part = hf_repo if hf_repo != "" else Path(new_repo_name).stem
        new_repo_id = f"{hf_user}/{repo_part}"
        flux_to_diffusers_lowmem(new_file, new_repo_name, dtype, quantization, model_type, dequant, use_original, new_repo_id)
    except Exception:
        # Unblock whoever is waiting on the queue, then let the error surface.
        report("")
        raise

    progress(1, desc="Converted.")
    report(new_repo_name)
    return new_repo_name
|
|
|
|
def convert_url_to_fixed_flux_safetensors(url, civitai_key="", is_upload_sf=False, data_type="bf16",
                                          model_type="dev", dequant=False, q=None, progress=gr.Progress(track_tqdm=True)):
    """Download a checkpoint from ``url`` and run ``fix_flux_safetensors`` on it.

    Args:
        url: Location handed to ``get_download_file``.
        civitai_key: API key handed to ``get_download_file``.
        is_upload_sf: Accepted for signature compatibility; not used here.
        data_type: ``"fp8"``, ``"fp16"`` or ``"qfloat8"``; anything else means
            bfloat16. ``"qfloat8"`` keeps bfloat16 and turns quantization on.
        model_type: Forwarded to ``fix_flux_safetensors``.
        dequant: Forwarded to ``fix_flux_safetensors``.
        q: Optional queue-like object. When given, exactly one result is put
            on it: the output folder name, or ``""`` on any failure, so that a
            process waiting on the queue never blocks indefinitely.
        progress: Progress callback, invoked with a fraction and ``desc``.

    Returns:
        The output folder name, or ``""`` when the download produced no file.
    """
    def report(result):
        # q is None when the function is called directly rather than as a
        # worker process target.
        if q is not None:
            q.put(result)

    progress(0, desc="Start converting...")
    temp_dir = "."
    try:
        print_resource_usage()
        new_file = get_download_file(temp_dir, url, civitai_key)
        if not new_file:
            print(f"Not found: {url}")
            report("")
            return ""
        new_repo_name = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_")

        dtype = {"fp8": torch.float8_e4m3fn, "fp16": torch.float16}.get(data_type, torch.bfloat16)
        quantization = data_type == "qfloat8"

        # Parameter order is (loadpath, savepath, dtype, quantization,
        # model_type, dequant); leaving out `quantization` would shift
        # model_type and dequant into the wrong slots.
        fix_flux_safetensors(new_file, new_repo_name, dtype, quantization, model_type, dequant)

        os.remove(new_file)
    except Exception:
        # Unblock whoever is waiting on the queue, then let the error surface.
        report("")
        raise

    progress(1, desc="Converted.")
    report(new_repo_name)
    return new_repo_name
|
|
|
|
def convert_url_to_diffusers_repo_flux(dl_url, hf_user, hf_repo, hf_token, civitai_key="", is_private=True, is_overwrite=False,
                                       is_upload_sf=False, data_type="bf16", model_type="dev", dequant=False,
                                       repo_urls=[], fix_only=False, use_original=["vae", "text_encoder"],
                                       progress=gr.Progress(track_tqdm=True)):
    """Convert a model in a worker process and upload the result as a repo.

    The conversion runs in a separate ``multiprocessing.Process`` which reports
    the output folder through a queue. The folder is then uploaded with
    ``create_diffusers_repo`` and removed from disk.

    Args:
        dl_url: Download location forwarded to the worker.
        hf_user: Owner of the target repo; must be non-empty.
        hf_repo: Target repo name; the output folder name is used when empty.
        hf_token: Stored in ``HF_TOKEN`` when that variable is not already set.
        civitai_key: Falls back to the ``CIVITAI_API_KEY`` environment variable.
        is_private: Forwarded to ``create_diffusers_repo``.
        is_overwrite: Forwarded to ``create_diffusers_repo``; also disables the
            "repo already exists" check.
        is_upload_sf, data_type, model_type, dequant: Forwarded to the worker.
        repo_urls: Previously created repo URLs; not modified in place.
        fix_only: Use ``convert_url_to_fixed_flux_safetensors`` as the worker
            instead of ``convert_url_to_diffusers_flux``.
        use_original: Forwarded to the worker (ignored when ``fix_only``).
        progress: Progress callback, invoked with a fraction and ``desc``.

    Returns:
        A pair of ``gr.update`` objects: the repo URL list and a markdown
        summary. On every failure path the list is returned unchanged and the
        markdown is empty.
    """
    import multiprocessing as mp
    import shutil
    from queue import Empty

    def unchanged():
        return gr.update(value=repo_urls, choices=repo_urls), gr.update(value="")

    if not hf_user:
        print(f"Invalid user name: {hf_user}")
        progress(1, desc=f"Invalid user name: {hf_user}")
        return unchanged()
    if hf_token and not os.environ.get("HF_TOKEN"):
        os.environ['HF_TOKEN'] = hf_token
    if not civitai_key and os.environ.get("CIVITAI_API_KEY"):
        civitai_key = os.environ.get("CIVITAI_API_KEY")

    q = mp.Queue()
    if fix_only:
        p = mp.Process(target=convert_url_to_fixed_flux_safetensors, args=(dl_url, civitai_key,
                       is_upload_sf, data_type, model_type, dequant, q))
    else:
        p = mp.Process(target=convert_url_to_diffusers_flux, args=(dl_url, civitai_key,
                       is_upload_sf, data_type, model_type, dequant, use_original, hf_user, hf_repo, q))
    p.start()

    # Poll instead of blocking: a worker that dies before reporting would
    # otherwise leave q.get() waiting forever.
    new_path = ""
    while True:
        try:
            new_path = q.get(timeout=1.0)
            break
        except Empty:
            if p.is_alive():
                continue
            # The worker has exited; pick up a result that may have arrived
            # between the timeout and the liveness check.
            try:
                new_path = q.get(timeout=0.5)
            except Empty:
                new_path = ""
            break
    p.join()

    if not new_path:
        print(f"Conversion failed: {dl_url}")
        progress(1, desc=f"Conversion failed: {dl_url}")
        return unchanged()

    repo_part = hf_repo if hf_repo != "" else Path(new_path).stem
    new_repo_id = f"{hf_user}/{repo_part}"
    if not is_repo_name(new_repo_id):
        print(f"Invalid repo name: {new_repo_id}")
        progress(1, desc=f"Invalid repo name: {new_repo_id}")
        return unchanged()
    if not is_overwrite and is_repo_exists(new_repo_id):
        print(f"Repo already exists: {new_repo_id}")
        progress(1, desc=f"Repo already exists: {new_repo_id}")
        return unchanged()

    repo_url = create_diffusers_repo(new_repo_id, new_path, is_private, is_overwrite)
    shutil.rmtree(new_path)

    # Work on a copy so neither the caller's list nor the shared default
    # argument is modified.
    urls = list(repo_urls) if repo_urls else []
    urls.append(repo_url)

    def link(u):
        parts = str(u).split('/')
        return f"[{parts[-2]}/{parts[-1]}]({str(u)})<br>"

    md = "Your new repo:<br>" + "".join(link(u) for u in urls)
    return gr.update(value=urls, choices=urls), gr.update(value=md)
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: convert a local file (--file) or a download
    # (--url) without going through the Gradio UI.
    import queue

    parser = argparse.ArgumentParser()
    parser.add_argument("--url", default=None, type=str, required=False, help="URL of the model to convert.")
    parser.add_argument("--file", default=None, type=str, required=False, help="Filename of the model to convert.")
    parser.add_argument("--fix", action="store_true", help="Only fix the keys of the local model.")
    parser.add_argument("--civitai_key", default=None, type=str, required=False, help="Civitai API Key (If you want to download file from Civitai).")
    parser.add_argument("--dtype", type=str, default="fp8")
    parser.add_argument("--model", type=str, default="dev")
    parser.add_argument("--dequant", action="store_true", help="Dequantize model.")
    args = parser.parse_args()
    # Validate with parser.error rather than assert, which is stripped
    # when Python runs with -O.
    if args.url is None and args.file is None:
        parser.error("Must provide --url or --file!")

    # "qfloat8" keeps bfloat16 weights and enables quantization; unknown
    # names fall back to bfloat16.
    dtype = {"fp8": torch.float8_e4m3fn, "fp16": torch.float16}.get(args.dtype, torch.bfloat16)
    quantization = args.dtype == "qfloat8"

    use_original = ["vae", "text_encoder"]
    new_repo_id = ""
    use_local = True

    if args.file is not None and Path(args.file).exists():
        if args.fix:
            normalize_flux_state_dict(args.file, ".", dtype, args.dequant)
        else:
            flux_to_diffusers_lowmem(args.file, Path(args.file).stem, dtype, quantization,
                                     args.model, args.dequant, use_original, new_repo_id, use_local)
    elif args.url is not None:
        # The converter reports its result through `q`; no parent process is
        # listening here, so hand it a throwaway queue. A missing key is
        # passed as "" to match the function's own default.
        convert_url_to_diffusers_flux(args.url, args.civitai_key or "", False, args.dtype, args.model,
                                      args.dequant, q=queue.SimpleQueue())
    else:
        # --file was given but does not exist and there is no --url fallback.
        parser.error(f"File not found: {args.file}")
|
|
|