# flux-to-diffusers-test / convert_url_to_diffusers_flux_gr.py
# Author: John6666 — commit a64fccd ("Upload 10 files", verified)
import spaces
import json
import torch
from safetensors.torch import load_file, save_file
from pathlib import Path
import gc
import gguf
from dequant import dequantize_tensor # https://github.com/city96/ComfyUI-GGUF
from huggingface_hub import HfFolder
import os
import argparse
import gradio as gr
# also requires aria, gdown, peft, huggingface_hub, safetensors, transformers, accelerate, pytorch_lightning
import subprocess
# Clear pip's cache as soon as this module is imported (presumably to reclaim disk space on the host).
# Runs through the shell and is unchecked: a failure only yields a non-zero exit status, never an exception.
subprocess.run('pip cache purge', shell=True)
@spaces.GPU()
def spaces_dummy():
    """No-op function decorated with ``spaces.GPU``.

    NOTE(review): presumably present only so the Spaces GPU runtime finds at least one
    GPU-decorated function in this module — confirm against the ``spaces`` package docs.
    """
    return None
# Diffusers-layout FLUX.1 repos that supply the components not taken from the converted
# checkpoint (VAE, text encoders, tokenizers, configs), keyed by model type.
flux_diffusers_repos = {
    "dev": "ChuckMcSneed/FLUX.1-dev",
    "schnell": "black-forest-labs/FLUX.1-schnell",
    "dev fp8": "John6666/flux1-dev-fp8-flux",
    "schnell fp8": "John6666/flux1-schnell-fp8-flux",
}
# Scratch directory used for repo snapshots.
system_temp_dir = "temp"
# Preferred compute device for this process.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# The conversion only moves/reshapes weights, so autograd is switched off process-wide.
torch.set_grad_enabled(False)

# GGUF tensor types that are routed to `dequantize_tensor`.
GGUF_QTYPE = [
    gguf.GGMLQuantizationType.Q8_0,
    gguf.GGMLQuantizationType.Q5_1,
    gguf.GGMLQuantizationType.Q5_0,
    gguf.GGMLQuantizationType.Q4_1,
    gguf.GGMLQuantizationType.Q4_0,
    gguf.GGMLQuantizationType.F32,
    gguf.GGMLQuantizationType.F16,
]
# Ordinary (non-quantized) torch dtypes, aliases included; tensors of these dtypes are just cast.
# NOTE(review): torch.uint16/uint32/uint64 only exist in recent PyTorch releases — confirm the pinned version.
TORCH_DTYPE = [
    torch.float32, torch.float, torch.float64, torch.double,
    torch.float16, torch.half, torch.bfloat16,
    torch.complex32, torch.chalf, torch.complex64, torch.cfloat,
    torch.complex128, torch.cdouble,
    torch.uint8, torch.uint16, torch.uint32, torch.uint64,
    torch.int8, torch.int16, torch.short, torch.int32, torch.int, torch.int64, torch.long,
    torch.bool, torch.float8_e4m3fn, torch.float8_e5m2,
]
# Torch quantized dtypes (need torch.dequantize before casting).
TORCH_QUANTIZED_DTYPE = [torch.quint8, torch.qint8, torch.qint32, torch.quint4x2]
def get_token():
    """Return the locally stored Hugging Face token.

    Falls back to "" if reading the token raises; ``HfFolder.get_token`` itself may
    return None when no token is stored (assumption — confirm for the installed version).
    """
    try:
        return HfFolder.get_token()
    except Exception:
        return ""
def list_sub(a, b):
    """Return the items of ``a`` that are not in ``b``, keeping ``a``'s order and duplicates."""
    return list(filter(lambda item: item not in b, a))
def is_repo_name(s):
    """Return a ``re.Match`` if ``s`` has the shape ``owner/name``, otherwise None.

    Both sides must be non-empty and contain no '/', ',', whitespace or quote characters,
    so URLs and paths with more than one '/' do not match.
    """
    import re
    repo_id_re = re.compile(r'^[^/,\s\"\']+/[^/,\s\"\']+$')
    return repo_id_re.fullmatch(s)
def clear_cache():
    """Release PyTorch's cached CUDA memory, then run Python's garbage collector."""
    torch.cuda.empty_cache()
    gc.collect()
def clear_sd(sd: dict):
    """Empty the state dict ``sd`` in place, then free cached CUDA memory and run GC.

    The caller's dict object stays alive but no longer references any tensor.
    """
    sd.clear()
    torch.cuda.empty_cache()
    gc.collect()
def clone_sd(sd: dict):
    """Replace every value of ``sd`` in place with a deep copy, keeping key order.

    Afterwards frees cached CUDA memory and runs GC.
    """
    from copy import deepcopy
    print("Cloning state dict.")
    for key in tuple(sd):
        value = sd.pop(key)
        sd[key] = deepcopy(value)
    torch.cuda.empty_cache()
    gc.collect()
def print_resource_usage():
    """Print current CPU utilisation and RAM usage (used / total, in percent).

    Note: ``psutil.cpu_percent()`` without an interval compares against the previous call,
    so the very first reading in a process is not meaningful.
    """
    import psutil
    cpu_usage = psutil.cpu_percent()
    # Take a single memory snapshot so "used" and "total" come from the same reading.
    mem = psutil.virtual_memory()
    ram_usage = mem.used / mem.total * 100
    print(f"CPU usage: {cpu_usage}% / RAM usage: {ram_usage}%")
def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)):
    """Download ``url`` into ``directory`` with gdown (Google Drive) or aria2c (everything else).

    Commands are run with argument lists (no shell), so characters in a user-supplied URL or
    API key cannot be interpreted by a shell. Exit statuses are not checked; callers detect
    success by looking for the new file. A missing ``gdown``/``aria2c`` binary raises
    ``FileNotFoundError``.

    Args:
        directory: destination directory.
        url: source URL (surrounding whitespace is ignored).
        civitai_api_key: token appended as ``?token=`` for civitai.com; without it
            civitai downloads are skipped.
        progress: Gradio progress tracker.
    """
    progress(0, desc="Start downloading...")
    url = url.strip()
    if url.startswith("-"):
        # Would be parsed as a command-line option by gdown/aria2c, not as a URL.
        print(f"Refusing to download: {url!r} looks like a command-line option, not a URL.")
        return
    directory = str(directory)
    # Shared aria2c options: quiet logging, resume (-c), 16 connections/splits, 1M pieces.
    aria2c_common = [
        "--console-log-level=error", "--summary-interval=10",
        "-c", "-x", "16", "-k", "1M", "-s", "16",
    ]
    if "drive.google.com" in url:
        # Run gdown inside the target directory without touching this process's cwd.
        subprocess.run(["gdown", "--fuzzy", url], cwd=directory)
    elif "huggingface.co" in url:
        url = url.replace("?download=true", "")
        if "/blob/" in url:
            # The web-page form of the link; the raw file lives under /resolve/.
            url = url.replace("/blob/", "/resolve/")
            subprocess.run(["aria2c", *aria2c_common, url,
                            "-d", directory, "-o", url.split('/')[-1]])
        else:
            subprocess.run(["aria2c", "--optimize-concurrent-downloads", *aria2c_common, url,
                            "-d", directory, "-o", url.split('/')[-1]])
    elif "civitai.com" in url:
        # Drop any existing query string before adding the token.
        url = url.split("?")[0]
        if civitai_api_key:
            subprocess.run(["aria2c", *aria2c_common, "-d", directory,
                            f"{url}?token={civitai_api_key}"])
        else:
            print("You need an API key to download Civitai models.")
    else:
        subprocess.run(["aria2c", *aria2c_common, "-d", directory, url])
def get_local_model_list(dir_path):
    """Return the paths of the ``.safetensors`` files directly inside ``dir_path``.

    Subdirectories are not searched, and directories are never returned. A missing
    ``dir_path`` yields an empty list.
    """
    # One-element tuple: a bare parenthesised string would make `in` a substring test and
    # wrongly accept files with no suffix ("" in ".safetensors") or suffixes like ".tensors".
    valid_extensions = ('.safetensors',)
    model_list = []
    for file in Path(dir_path).glob("*"):
        if file.suffix in valid_extensions and file.is_file():
            model_list.append(str(Path(f"{dir_path}/{file.name}")))
    return model_list
def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    """Resolve ``url`` to something loadable: a Hub repo id, a local path, or a downloaded file.

    Resolution order:
      1. a non-URL ``owner/name`` string that is not a local path -> returned as a repo id;
      2. a non-URL string naming an existing local path -> returned unchanged;
      3. a file named like the URL's last path segment already in ``temp_dir`` -> reused;
      4. otherwise download into ``temp_dir`` and return the newly appeared ``.safetensors``.

    Returns "" when nothing was given or the download failed.
    """
    url = url.strip()
    if not url:
        print("Download failed: no URL or path given.")
        return ""
    file_name = url.split('/')[-1]
    cached_path = f"{temp_dir}/{file_name}"
    if "http" not in url and is_repo_name(url) and not Path(url).exists():
        print(f"Use HF Repo: {url}")
        new_file = url
    elif "http" not in url and Path(url).exists():
        print(f"Use local file: {url}")
        new_file = url
    elif file_name and Path(cached_path).exists():
        # `file_name` is empty for URLs ending in "/"; then cached_path would be temp_dir itself.
        print(f"File to download already exists: {url}")
        new_file = cached_path
    else:
        print(f"Start downloading: {url}")
        before = get_local_model_list(temp_dir)
        try:
            download_thing(temp_dir, url, civitai_key)
        except Exception:
            print(f"Download failed: {url}")
            return ""
        # The downloaded file is whichever .safetensors file was not there before.
        added = list_sub(get_local_model_list(temp_dir), before)
        new_file = added[0] if added else ""
        if not new_file:
            print(f"Download failed: {url}")
            return ""
    print(f"Download completed: {url}")
    return new_file
def save_readme_md(dir, url):
    """Write a Hugging Face model-card ``README.md`` (YAML front matter) into ``dir``.

    When ``url`` contains "http", a "Converted from [url](url)." line is appended so the
    source is recorded; repo ids and local paths are not mentioned. Overwrites any existing
    README.md.
    """
    md = """---
license: other
license_name: flux-1-dev-non-commercial-license
license_link: https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md
language:
- en
library_name: diffusers
pipeline_tag: text-to-image
tags:
- text-to-image
- Flux
---
"""
    if "http" in url:
        md += f"Converted from [{url}]({url}).\n"
    path = str(Path(dir, "README.md"))
    with open(path, mode='w', encoding="utf-8") as f:
        f.write(md)
def is_repo_exists(repo_id):
    """Return whether ``repo_id`` exists on the Hugging Face Hub.

    If the lookup itself fails, an error is printed and True is returned, presumably so
    callers treat an unverifiable repo as already taken.
    """
    from huggingface_hub import HfApi
    api = HfApi()
    try:
        return bool(api.repo_exists(repo_id=repo_id))
    except Exception:
        print(f"Error: Failed to connect {repo_id}. ")
        return True
def create_diffusers_repo(new_repo_id, diffusers_folder, is_private, is_overwrite, progress=gr.Progress(track_tqdm=True)):
    """Create ``new_repo_id`` on the Hub and upload the top-level entries of ``diffusers_folder``.

    Subdirectories are uploaded with ``upload_folder`` and files with ``upload_file``
    (one commit each). ``is_overwrite`` is passed as ``exist_ok``, so an existing repo is
    only reused when overwriting is allowed.

    Returns:
        The repo URL on success, or "" on any failure (the error is printed).
    """
    from huggingface_hub import HfApi
    hf_token = get_token()
    api = HfApi()
    try:
        progress(0, desc="Start uploading...")
        api.create_repo(repo_id=new_repo_id, token=hf_token, private=is_private, exist_ok=is_overwrite)
        for path in Path(diffusers_folder).glob("*"):
            if path.is_dir():
                api.upload_folder(repo_id=new_repo_id, folder_path=str(path), path_in_repo=path.name, token=hf_token)
            elif path.is_file():
                api.upload_file(repo_id=new_repo_id, path_or_fileobj=str(path), path_in_repo=path.name, token=hf_token)
        progress(1, desc="Uploaded.")
    except Exception as e:
        print(f"Error: Failed to upload to {new_repo_id}. ")
        print(e)
        return ""
    # Canonical host name (no trailing dot) so the returned link resolves normally.
    return f"https://huggingface.co/{new_repo_id}"
# Adapted from diffusers' scripts/convert_flux_to_diffusers.py:
# https://github.com/huggingface/diffusers/blob/main/scripts/convert_flux_to_diffusers.py
# The original (SD3-style) AdaLayerNormContinuous projection output is split as (shift, scale),
# while the diffusers implementation splits it as (scale, shift). Swapping the two halves of the
# projection weights lets the diffusers module be used unchanged.
# NOTE: this `with` only wraps the definition (and TorchScript compilation); it does not apply
# no_grad/autocast when the function is later called.
with torch.no_grad(), torch.autocast(device):
    @torch.jit.script
    def swap_scale_shift(weight):
        # Exchange the two halves of `weight` along dim 0.
        halves = weight.chunk(2, dim=0)
        return torch.cat([halves[1], halves[0]], dim=0)
with torch.no_grad(), torch.autocast(device):
    def convert_flux_transformer_checkpoint_to_diffusers(
            original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0,
            progress=gr.Progress(track_tqdm=True)):
        """Convert an original-layout FLUX.1 transformer state dict to Diffusers key names.

        Tensors are popped out of ``original_state_dict`` as they are converted, so the input
        dict is consumed. Keys that are absent are skipped, which allows converting a partial
        (e.g. sharded) state dict. Keys this mapping does not cover stay behind in
        ``original_state_dict`` and are NOT part of the result.

        Args:
            original_state_dict: transformer weights with prefix-free original keys
                (e.g. ``double_blocks.0.img_attn.qkv.weight``).
            num_layers: number of ``double_blocks`` indices to visit.
            num_single_layers: number of ``single_blocks`` indices to visit.
            inner_dim: hidden size; sizes the Q/K/V slices of ``single_blocks.*.linear1``.
            mlp_ratio: MLP hidden size is ``int(inner_dim * mlp_ratio)``.
            progress: Gradio progress tracker.

        Returns:
            A new dict keyed with Diffusers names.
        """
        def conv(cdict: dict, odict: dict, ckey: str, okey: str):
            # Move one tensor from `okey` to `ckey`, if present.
            if okey in odict:
                progress(0, desc=f"Converting {okey} => {ckey}")
                print(f"Converting {okey} => {ckey}")
                cdict[ckey] = odict.pop(okey)
                gc.collect()

        def convswap(cdict: dict, odict: dict, ckey: str, okey: str):
            # Like conv(), but exchanges the two halves along dim 0 (shift/scale -> scale/shift).
            if okey in odict:
                progress(0, desc=f"Converting (swap) {okey} => {ckey}")
                print(f"Converting {okey} => {ckey} (swap)")
                cdict[ckey] = swap_scale_shift(odict.pop(okey))
                gc.collect()

        def conv_pairs(cdict: dict, odict: dict, new_prefix: str, old_prefix: str, pairs):
            # Apply conv() to every (new_suffix, old_suffix) pair under the given prefixes.
            for new_suffix, old_suffix in pairs:
                conv(cdict, odict, f"{new_prefix}{new_suffix}", f"{old_prefix}{old_suffix}")

        def convqkv(cdict: dict, odict: dict, i: int, block_prefix: str):
            # Split the fused image/text QKV projections of double block `i` into Q, K, V.
            src = f"double_blocks.{i}."
            qkv_keys = (f"{src}img_attn.qkv.weight", f"{src}txt_attn.qkv.weight",
                        f"{src}img_attn.qkv.bias", f"{src}txt_attn.qkv.bias")
            present = [key in odict for key in qkv_keys]
            if not any(present):
                # Nothing for this block here (e.g. its tensors live in another shard).
                return
            if not all(present):
                progress(0, desc=f"Key error in converting Q, K, V (double_blocks.{i}).")
                print(f"Key error in converting Q, K, V (double_blocks.{i}).")
                return
            progress(0, desc=f"Converting Q, K, V (double_blocks.{i}).")
            print(f"Converting Q, K, V (double_blocks.{i}).")
            img_w, txt_w, img_b, txt_b = [odict.pop(key) for key in qkv_keys]
            # chunk() returns views of the fused tensor; clone() gives each slice its own storage.
            for names, weight, bias in (
                (("to_q", "to_k", "to_v"), img_w, img_b),
                (("add_q_proj", "add_k_proj", "add_v_proj"), txt_w, txt_b),
            ):
                for name, w, b in zip(names, torch.chunk(weight, 3, dim=0), torch.chunk(bias, 3, dim=0)):
                    cdict[f"{block_prefix}attn.{name}.weight"] = w.clone()
                    cdict[f"{block_prefix}attn.{name}.bias"] = b.clone()
            gc.collect()

        def convqkvmlp(cdict: dict, odict: dict, i: int, block_prefix: str, inner_dim: int, mlp_ratio: float):
            # Split single block `i`'s fused linear1 (Q | K | V | MLP-in) into its four parts.
            w_key = f"single_blocks.{i}.linear1.weight"
            b_key = f"single_blocks.{i}.linear1.bias"
            has_w, has_b = w_key in odict, b_key in odict
            if not (has_w or has_b):
                return
            if not (has_w and has_b):
                progress(0, desc=f"Key error in converting Q, K, V, mlp (single_blocks.{i}).")
                print(f"Key error in converting Q, K, V, mlp (single_blocks.{i}).")
                return
            progress(0, desc=f"Converting Q, K, V, mlp (single_blocks.{i}).")
            print(f"Converting Q, K, V, mlp (single_blocks.{i}).")
            mlp_hidden_dim = int(inner_dim * mlp_ratio)
            split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
            weights = torch.split(odict.pop(w_key), split_size, dim=0)
            biases = torch.split(odict.pop(b_key), split_size, dim=0)
            for name, w, b in zip(("attn.to_q", "attn.to_k", "attn.to_v", "proj_mlp"), weights, biases):
                cdict[f"{block_prefix}{name}.weight"] = w.clone()
                cdict[f"{block_prefix}{name}.bias"] = b.clone()
            gc.collect()

        converted_state_dict = {}
        cdict, odict = converted_state_dict, original_state_dict
        progress(0, desc="Converting FLUX.1 state dict to Diffusers format.")
        # time_in -> timestep_embedder, vector_in -> text_embedder
        conv_pairs(cdict, odict, "", "", (
            ("time_text_embed.timestep_embedder.linear_1.weight", "time_in.in_layer.weight"),
            ("time_text_embed.timestep_embedder.linear_1.bias", "time_in.in_layer.bias"),
            ("time_text_embed.timestep_embedder.linear_2.weight", "time_in.out_layer.weight"),
            ("time_text_embed.timestep_embedder.linear_2.bias", "time_in.out_layer.bias"),
            ("time_text_embed.text_embedder.linear_1.weight", "vector_in.in_layer.weight"),
            ("time_text_embed.text_embedder.linear_1.bias", "vector_in.in_layer.bias"),
            ("time_text_embed.text_embedder.linear_2.weight", "vector_in.out_layer.weight"),
            ("time_text_embed.text_embedder.linear_2.bias", "vector_in.out_layer.bias"),
        ))
        # Guidance embedder, only when some remaining key mentions "guidance".
        if any("guidance" in k for k in odict):
            conv_pairs(cdict, odict, "", "", (
                ("time_text_embed.guidance_embedder.linear_1.weight", "guidance_in.in_layer.weight"),
                ("time_text_embed.guidance_embedder.linear_1.bias", "guidance_in.in_layer.bias"),
                ("time_text_embed.guidance_embedder.linear_2.weight", "guidance_in.out_layer.weight"),
                ("time_text_embed.guidance_embedder.linear_2.bias", "guidance_in.out_layer.bias"),
            ))
        # context_embedder <- txt_in, x_embedder <- img_in
        conv_pairs(cdict, odict, "", "", (
            ("context_embedder.weight", "txt_in.weight"),
            ("context_embedder.bias", "txt_in.bias"),
            ("x_embedder.weight", "img_in.weight"),
            ("x_embedder.bias", "img_in.bias"),
        ))
        progress(0.25, desc="Converting FLUX.1 state dict to Diffusers format.")

        # Double transformer blocks: renames before and after the fused-QKV split.
        double_pre_qkv = (
            ("norm1.linear.weight", "img_mod.lin.weight"),
            ("norm1.linear.bias", "img_mod.lin.bias"),
            ("norm1_context.linear.weight", "txt_mod.lin.weight"),
            ("norm1_context.linear.bias", "txt_mod.lin.bias"),
        )
        double_post_qkv = (
            ("attn.norm_q.weight", "img_attn.norm.query_norm.scale"),
            ("attn.norm_k.weight", "img_attn.norm.key_norm.scale"),
            ("attn.norm_added_q.weight", "txt_attn.norm.query_norm.scale"),
            ("attn.norm_added_k.weight", "txt_attn.norm.key_norm.scale"),
            ("ff.net.0.proj.weight", "img_mlp.0.weight"),
            ("ff.net.0.proj.bias", "img_mlp.0.bias"),
            ("ff.net.2.weight", "img_mlp.2.weight"),
            ("ff.net.2.bias", "img_mlp.2.bias"),
            ("ff_context.net.0.proj.weight", "txt_mlp.0.weight"),
            ("ff_context.net.0.proj.bias", "txt_mlp.0.bias"),
            ("ff_context.net.2.weight", "txt_mlp.2.weight"),
            ("ff_context.net.2.bias", "txt_mlp.2.bias"),
            ("attn.to_out.0.weight", "img_attn.proj.weight"),
            ("attn.to_out.0.bias", "img_attn.proj.bias"),
            ("attn.to_add_out.weight", "txt_attn.proj.weight"),
            ("attn.to_add_out.bias", "txt_attn.proj.bias"),
        )
        for i in range(num_layers):
            new_prefix, old_prefix = f"transformer_blocks.{i}.", f"double_blocks.{i}."
            conv_pairs(cdict, odict, new_prefix, old_prefix, double_pre_qkv)
            convqkv(cdict, odict, i, new_prefix)
            conv_pairs(cdict, odict, new_prefix, old_prefix, double_post_qkv)
        progress(0.5, desc="Converting FLUX.1 state dict to Diffusers format.")

        # Single transformer blocks.
        single_pre_qkv = (
            ("norm.linear.weight", "modulation.lin.weight"),
            ("norm.linear.bias", "modulation.lin.bias"),
        )
        single_post_qkv = (
            ("attn.norm_q.weight", "norm.query_norm.scale"),
            ("attn.norm_k.weight", "norm.key_norm.scale"),
            ("proj_out.weight", "linear2.weight"),
            ("proj_out.bias", "linear2.bias"),
        )
        for i in range(num_single_layers):
            new_prefix, old_prefix = f"single_transformer_blocks.{i}.", f"single_blocks.{i}."
            conv_pairs(cdict, odict, new_prefix, old_prefix, single_pre_qkv)
            convqkvmlp(cdict, odict, i, new_prefix, inner_dim, mlp_ratio)
            conv_pairs(cdict, odict, new_prefix, old_prefix, single_post_qkv)
        progress(0.75, desc="Converting FLUX.1 state dict to Diffusers format.")

        # Final layer; the adaLN modulation halves must be reordered for diffusers.
        conv(cdict, odict, "proj_out.weight", "final_layer.linear.weight")
        conv(cdict, odict, "proj_out.bias", "final_layer.linear.bias")
        convswap(cdict, odict, "norm_out.linear.weight", "final_layer.adaLN_modulation.1.weight")
        convswap(cdict, odict, "norm_out.linear.bias", "final_layer.adaLN_modulation.1.bias")
        progress(1, desc="Converting FLUX.1 state dict to Diffusers format.")
        return converted_state_dict
def read_safetensors_metadata(path):
    """Return a copy of the ``__metadata__`` mapping stored in a safetensors header ({} if absent).

    Only the header is read: an 8-byte little-endian length followed by that many bytes
    of UTF-8 JSON. Tensor data is never touched.
    """
    with open(path, 'rb') as fh:
        length_bytes = fh.read(8)
        header_bytes = fh.read(int.from_bytes(length_bytes, 'little'))
    header = json.loads(header_bytes.decode('utf-8'))
    return header.get('__metadata__', {}).copy()
def normalize_key(k: str):
    """Strip known component prefixes (VAE, diffusion model, CLIP-L, T5-XXL) from a checkpoint key.

    Each fragment is removed wherever it occurs in the key (not only at the start),
    applied in a fixed order.
    """
    for fragment in ("vae.",
                     "model.diffusion_model.",
                     "text_encoders.clip_l.transformer.",
                     "text_encoders.t5xxl.transformer."):
        k = k.replace(fragment, "")
    return k
def load_json_list(path: str):
    """Load the JSON file at ``path`` and return its top-level value as a list.

    A JSON object yields its keys. On any error (missing file, bad JSON, ...) the error is
    printed and [] is returned.
    """
    try:
        with open(path, encoding='utf-8') as fh:
            data = json.load(fh)
        return list(data)
    except Exception as err:
        print(err)
        return []
# References:
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/modeling_utils.py
# https://huggingface.co/docs/huggingface_hub/v0.24.5/package_reference/serialization
# https://huggingface.co/docs/huggingface_hub/index
with torch.no_grad():
    def to_safetensors(sd: dict, path: str, pattern: str, size: str, progress=gr.Progress(track_tqdm=True)):
        """Save ``sd`` into directory ``path`` as safetensors shards.

        Every tensor in ``sd`` is first replaced in place by its CPU copy. ``pattern`` and
        ``size`` are forwarded to ``save_torch_state_dict`` as ``filename_pattern`` and
        ``max_shard_size``. Errors are printed, not raised.
        """
        from huggingface_hub import save_torch_state_dict
        print(f"Saving a temporary file to disk: {path}")
        os.makedirs(path, exist_ok=True)
        try:
            for key in list(sd):
                sd[key] = sd[key].to(device="cpu")
            save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
        except Exception as e:
            print(e)
# https://discuss.huggingface.co/t/t5forconditionalgeneration-checkpoint-size-mismatch-19418/24119
# https://github.com/huggingface/transformers/issues/13769
# https://github.com/huggingface/optimum-quanto/issues/278
# https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/serialization/_torch.py
# https://huggingface.co/docs/accelerate/usage_guides/big_modeling
with torch.no_grad():
    def to_safetensors_flux_module(sd: dict, path: str, pattern: str, size: str,
                                   quantization: bool=False, name: str = "",
                                   metadata: dict | None = None, progress=gr.Progress(track_tqdm=True)):
        """Save one FLUX.1 component state dict ``sd`` as (possibly sharded) safetensors in ``path``.

        Tensors are moved to CPU in place before saving. ``pattern``/``size`` are forwarded
        to ``save_torch_state_dict`` as ``filename_pattern``/``max_shard_size``; ``metadata``
        is written into the files when given. ``quantization`` only changes the log message,
        since no quantization happens here. ``name`` labels the component in messages.
        Errors are printed, not raised.
        """
        from huggingface_hub import save_torch_state_dict
        try:
            if sd is None:
                # extract_norm_flux_module_sd returns None when extraction failed.
                print(f"Nothing to save for FLUX.1 {name}: state dict is None.")
                return
            progress(0, desc=f"Preparing to save FLUX.1 {name} to Diffusers format.")
            print(f"Preparing to save FLUX.1 {name} to Diffusers format.")
            for k, v in sd.items():
                sd[k] = v.to(device="cpu")
            progress(0, desc=f"Loading FLUX.1 {name}.")
            print(f"Loading FLUX.1 {name}.")
            os.makedirs(path, exist_ok=True)
            if quantization:
                progress(0.5, desc=f"Saving quantized FLUX.1 {name} to {path}")
                print(f"Saving quantized FLUX.1 {name} to {path}")
            else:
                progress(0.5, desc=f"Saving FLUX.1 {name} to: {path}")
            print(f"Saving FLUX.1 {name} to: {path}")
            save_kwargs = {}
            if metadata is not None:
                progress(0.5, desc=f"Saving FLUX.1 {name} metadata to: {path}")
                save_kwargs["metadata"] = metadata
            save_torch_state_dict(state_dict=sd, save_directory=path,
                                  filename_pattern=pattern, max_shard_size=size, **save_kwargs)
            progress(1, desc=f"Saved FLUX.1 {name} to: {path}")
            print(f"Saved FLUX.1 {name} to: {path}")
        except Exception as e:
            print(e)
        finally:
            gc.collect()
# Per-component key lists used to pick each FLUX.1 component's tensors out of a single-file
# checkpoint. A missing or unreadable JSON file yields an empty set (load_json_list returns []).
flux_transformer_json = "flux_transformer_keys.json"
flux_t5xxl_json = "flux_t5xxl_keys.json"
flux_clip_json = "flux_clip_keys.json"
flux_vae_json = "flux_vae_keys.json"
# Loaded in the order T5-XXL, transformer, CLIP, VAE.
keys_flux_t5xxl, keys_flux_transformer, keys_flux_clip, keys_flux_vae = (
    set(load_json_list(json_path))
    for json_path in (flux_t5xxl_json, flux_transformer_json, flux_clip_json, flux_vae_json)
)
with torch.no_grad():
    def dequant_tensor(v: torch.Tensor, dtype: torch.dtype, dequant: bool):
        """Return ``v`` converted to ``dtype``, dequantizing it first when ``dequant`` is set.

        With ``dequant`` false this is a plain cast (``v`` itself when already ``dtype``).
        With ``dequant`` true:
          * objects carrying a ``tensor_type`` listed in GGUF_QTYPE go through
            ``dequantize_tensor``. This is checked first, because their storage dtype may
            itself be an ordinary dtype such as uint8;
          * tensors with an ordinary dtype (TORCH_DTYPE) are just cast;
          * anything else, including TORCH_QUANTIZED_DTYPE, goes through ``torch.dequantize``.

        Raises:
            Whatever the conversion raises; the error is printed first. Callers wrap this in
            try/except, and re-raising keeps ``None`` out of the state dict.
        """
        try:
            if not dequant:
                return v if v.dtype == dtype else v.to(dtype)
            # Plain torch.Tensor has no `tensor_type`; only GGUF tensor wrappers carry one.
            qtype = getattr(v, "tensor_type", None)
            if qtype is not None and qtype in GGUF_QTYPE:
                return dequantize_tensor(v, dtype)
            if v.dtype in TORCH_DTYPE:
                return v if v.dtype == dtype else v.to(dtype)
            # Torch quantized dtypes (TORCH_QUANTIZED_DTYPE) and any other dtype.
            return torch.dequantize(v).to(dtype)
        except Exception as e:
            print(e)
            raise
with torch.no_grad():
    def normalize_flux_state_dict(path: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
                                  dequant: bool = False, progress=gr.Progress(track_tqdm=True)):
        """Rewrite the FLUX.1 single-file checkpoint ``path`` with normalized keys.

        Each tensor is renamed with ``normalize_key``, converted by ``dequant_tensor`` and
        written to ``<savepath>/<stem>_fixed<suffix>``. The source file's safetensors metadata
        is carried over, with ``"format": "pt"`` used unless the source metadata sets it.
        If renaming/conversion fails, the error is printed and nothing is written.
        Returns None.
        """
        progress(0, desc=f"Loading and normalizing FLUX.1 safetensors: {path}")
        print(f"Loading and normalizing FLUX.1 safetensors: {path}")
        renamed = dict()
        source_sd = load_file(path, device="cpu")
        try:
            for old_key in list(source_sd):
                tensor = source_sd.pop(old_key)
                new_key = normalize_key(old_key)
                print(f"{old_key} => {new_key}")
                renamed[new_key] = dequant_tensor(tensor, dtype, dequant)
        except Exception as e:
            print(e)
            return
        finally:
            clear_sd(source_sd)
        src = Path(path)
        new_path = str(Path(savepath, src.stem + "_fixed" + src.suffix))
        metadata = read_safetensors_metadata(path)
        progress(0.5, desc=f"Saving FLUX.1 safetensors: {new_path}")
        print(f"Saving FLUX.1 safetensors: {new_path}")
        os.makedirs(savepath, exist_ok=True)
        save_file(renamed, new_path, metadata={"format": "pt", **metadata})
        progress(1, desc=f"Saved FLUX.1 safetensors: {new_path}")
        print(f"Saved FLUX.1 safetensors: {new_path}")
        clear_sd(renamed)
with torch.no_grad():
    def extract_norm_flux_module_sd(path: str, dtype: torch.dtype = torch.bfloat16,
                                    dequant: bool = False, name: str = "", keys: set | None = None,
                                    progress=gr.Progress(track_tqdm=True)):
        """Load ``path`` and return only the tensors whose original key is in ``keys``, renamed.

        Args:
            path: safetensors file, loaded fully onto CPU.
            dtype: target dtype, forwarded to ``dequant_tensor``.
            dequant: forwarded to ``dequant_tensor``.
            name: component label used in log messages.
            keys: original (un-normalized) checkpoint keys to keep. None or empty keeps
                nothing, and a warning is printed.
            progress: Gradio progress tracker.

        Returns:
            dict mapping ``normalize_key(k)`` to the converted tensor, or None if processing
            failed (the error is printed).
        """
        # None replaces the old `{}` default, which was a dict literal shared by every call.
        keys = set() if keys is None else set(keys)
        if not keys:
            print(f"Warning: no keys to extract for FLUX.1 {name}: {path}")
        progress(0, desc=f"Loading and normalizing FLUX.1 {name} safetensors: {path}")
        print(f"Loading and normalizing FLUX.1 {name} safetensors: {path}")
        new_sd = dict()
        state_dict = load_file(path, device="cpu")
        try:
            # Drop unwanted tensors first so they can be freed before any conversion work.
            for k in list(state_dict.keys()):
                if k not in keys:
                    state_dict.pop(k)
            gc.collect()
            for k in list(state_dict.keys()):
                v = state_dict.pop(k)
                nk = normalize_key(k)
                progress(0.5, desc=f"{k} => {nk}")
                print(f"{k} => {nk}")
                new_sd[nk] = dequant_tensor(v, dtype, dequant)
        except Exception as e:
            print(e)
            return None
        finally:
            progress(1, desc=f"Normalized FLUX.1 {name} safetensors: {path}")
            print(f"Normalized FLUX.1 {name} safetensors: {path}")
            clear_sd(state_dict)
        return new_sd
with torch.no_grad():
    def convert_flux_transformer_sd_to_diffusers(sd: dict, progress=gr.Progress(track_tqdm=True)):
        """Convert a key-normalized FLUX.1 transformer state dict to Diffusers key names.

        The numbers of double/single blocks to visit are derived from the highest
        ``double_blocks.N.`` / ``single_blocks.N.`` index present in ``sd``. This covers
        full checkpoints, partial (sharded) dicts and models of other depth. Hidden size
        (3072) and MLP ratio (4.0) remain fixed.

        Returns:
            The converted dict. On failure the error is printed and ``sd`` itself is
            returned; it may already have been partly consumed by the converter.
        """
        import re
        progress(0, desc="Converting FLUX.1 state dict to Diffusers format.")
        print("Converting FLUX.1 state dict to Diffusers format.")

        def block_count(prefix: str) -> int:
            # One past the highest `<prefix>.<N>.` index among the keys; 0 when none match.
            pattern = re.compile(rf"{re.escape(prefix)}\.(\d+)\.")
            indices = [int(m.group(1)) for key in sd if (m := pattern.match(key))]
            return max(indices) + 1 if indices else 0

        num_layers = block_count("double_blocks")
        num_single_layers = block_count("single_blocks")
        inner_dim = 3072
        mlp_ratio = 4.0
        try:
            sd = convert_flux_transformer_checkpoint_to_diffusers(
                sd, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
            )
        except Exception as e:
            print(e)
        finally:
            progress(1, desc="Converted FLUX.1 state dict to Diffusers format.")
            print("Converted FLUX.1 state dict to Diffusers format.")
            gc.collect()
        return sd
with torch.no_grad():
    def load_sharded_safetensors(path: str):
        """Load every ``*.safetensors`` file directly in ``path`` and merge them into one CPU dict.

        Files are read in sorted order, so the result is deterministic; if two files share a
        key, the later file wins. On error the message is printed and whatever has been
        merged so far is returned.
        """
        import glob
        sd = {}
        try:
            for filepath in sorted(glob.glob(f"{path}/*.safetensors")):
                sharded_sd = load_file(str(filepath), device="cpu")
                # Merge into the running dict directly instead of rebuilding it per shard.
                for k, v in sharded_sd.items():
                    sd[k] = v.to(device="cpu")
                clear_sd(sharded_sd)
        except Exception as e:
            print(e)
        return sd
# https://huggingface.co/docs/safetensors/api/torch
with torch.no_grad():
    def convert_flux_transformer_sd_to_diffusers_sharded(sd: dict, path: str, pattern: str,
                                                         size: str, progress=gr.Progress(track_tqdm=True)):
        """Convert ``sd`` to Diffusers keys one on-disk shard at a time, using ``path`` as scratch.

        ``sd`` is written out as shards (``pattern``/``size`` as for ``save_torch_state_dict``)
        and then emptied with ``clear_sd``. Each shard is converted with
        ``convert_flux_transformer_sd_to_diffusers`` and rewritten in place, and all shards are
        merged back into the returned dict.

        NOTE(review): shards are converted independently. If a double block's fused Q/K/V
        tensors land in different shards, they are reported as a key error and not
        converted. Confirm that ``size`` keeps them together.

        Errors are printed; the returned dict is whatever ``sd`` holds at that point.
        """
        from huggingface_hub import save_torch_state_dict
        import glob
        try:
            progress(0, desc=f"Saving temporary files to disk: {path}")
            print(f"Saving temporary files to disk: {path}")
            os.makedirs(path, exist_ok=True)
            # Build the lookup set once rather than once per key.
            transformer_keys = set(keys_flux_transformer)
            for k, v in sd.items():
                if k in transformer_keys:
                    sd[k] = v.to(device="cpu")
            save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
            clear_sd(sd)
            progress(0.25, desc=f"Saved temporary files to disk: {path}")
            print(f"Saved temporary files to disk: {path}")
            for filepath in sorted(glob.glob(f"{path}/*.safetensors")):
                progress(0.25, desc=f"Processing temporary files: {str(filepath)}")
                print(f"Processing temporary files: {str(filepath)}")
                sharded_sd = load_file(str(filepath), device="cpu")
                sharded_sd = convert_flux_transformer_sd_to_diffusers(sharded_sd)
                for k, v in sharded_sd.items():
                    sharded_sd[k] = v.to(device="cpu")
                save_file(sharded_sd, str(filepath))
                clear_sd(sharded_sd)
            print(f"Loading temporary files from disk: {path}")
            sd = load_sharded_safetensors(path)
            print(f"Loaded temporary files from disk: {path}")
        except Exception as e:
            print(e)
        return sd
with torch.no_grad():
    def extract_normalized_flux_state_dict_sharded(loadpath: str, dtype: torch.dtype,
            dequant: bool, path: str, pattern: str, size: str, progress=gr.Progress(track_tqdm=True)):
        """Extract the FLUX.1 transformer tensors of ``loadpath`` with normalized keys via disk shards.

        The whole checkpoint is re-saved as shards under ``path`` (``pattern``/``size`` as for
        ``save_torch_state_dict``). Each shard is then filtered to ``keys_flux_transformer``,
        renamed/converted by ``extract_norm_flux_module_sd`` and rewritten in place. Finally
        the shards are merged back with ``load_sharded_safetensors``.

        Returns:
            The merged dict. On failure the error is printed and ``sd`` is returned as it
            stands at that point: ``{}`` if loading failed or once the shards were written.
        """
        from huggingface_hub import save_torch_state_dict
        import glob
        # Bound before the try so a failing load_file() cannot leave `sd` unassigned at `return`.
        sd = {}
        try:
            progress(0, desc=f"Loading model file: {loadpath}")
            print(f"Loading model file: {loadpath}")
            sd = load_file(loadpath, device="cpu")
            progress(0, desc=f"Saving temporary files to disk: {path}")
            print(f"Saving temporary files to disk: {path}")
            os.makedirs(path, exist_ok=True)
            for k, v in sd.items():
                sd[k] = v.to(device="cpu")
            save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size)
            clear_sd(sd)
            progress(0.25, desc=f"Saved temporary files to disk: {path}")
            print(f"Saved temporary files to disk: {path}")
            for filepath in sorted(glob.glob(f"{path}/*.safetensors")):
                progress(0.25, desc=f"Processing temporary files: {str(filepath)}")
                print(f"Processing temporary files: {str(filepath)}")
                sharded_sd = extract_norm_flux_module_sd(str(filepath), dtype, dequant,
                                                         "Transformer", keys_flux_transformer)
                if sharded_sd is None:
                    # extract_norm_flux_module_sd already printed the underlying error.
                    raise RuntimeError(f"Failed to normalize temporary file: {filepath}")
                for k, v in sharded_sd.items():
                    sharded_sd[k] = v.to(device="cpu")
                save_file(sharded_sd, str(filepath))
                clear_sd(sharded_sd)
                print(f"Processed temporary files: {str(filepath)}")
            print(f"Loading temporary files from disk: {path}")
            sd = load_sharded_safetensors(path)
            print(f"Loaded temporary files from disk: {path}")
        except Exception as e:
            print(e)
        return sd
def download_repo(repo_name, path, use_original=["vae", "text_encoder"], progress=gr.Progress(track_tqdm=True)):
    """Snapshot-download ``repo_name`` into ``path``, skipping files this script does not need.

    The repo's transformer weights, docs and images are always skipped. Its
    ``text_encoder_2`` weights are skipped too, unless "text_encoder_2" is in
    ``use_original``. Errors are printed, not raised.
    """
    from huggingface_hub import snapshot_download
    print(f"Downloading {repo_name}.")
    try:
        ignore = ["transformer/diffusion*.*", "*.sft", ".*", "README*", "*.md", "*.index",
                  "*.jpg", "*.jpeg", "*.png", "*.webp"]
        if "text_encoder_2" not in use_original:
            # Presumably the T5 weights are taken from the converted checkpoint instead.
            ignore.insert(1, "text_encoder_2/model*.*")
        snapshot_download(repo_id=repo_name, local_dir=path, ignore_patterns=ignore)
    except Exception as e:
        print(e)
def copy_missing_files(from_path, to_path, use_original=["vae", "text_encoder"]):
    """Copy components from a downloaded Diffusers repo at ``from_path`` into ``to_path``.

    Full component folders (weights included) are copied for each of "text_encoder_2",
    "text_encoder" and "vae" that appears in ``use_original``, and always for
    "tokenizer_2". Afterwards every remaining non-tensor file (configs, tokenizer
    files, ...) of the whole tree is copied. Existing destination folders are merged into.
    """
    import shutil
    skipped_media = (".*", "README*", "*.md", "*.jpg", "*.jpeg", "*.png", "*.webp")
    components = (
        ("text_encoder_2", "Text Encoder 2", "text_encoder_2" in use_original),
        ("text_encoder", "Text Encoder 1", "text_encoder" in use_original),
        ("vae", "VAE", "vae" in use_original),
        ("tokenizer_2", "Tokenizer 2", True),
    )
    for subdir, label, wanted in components:
        if not wanted:
            continue
        src = str(Path(from_path, subdir))
        dst = str(Path(to_path, subdir))
        print(f"Copying {label} files {src} to {dst}")
        shutil.copytree(src, dst, ignore=shutil.ignore_patterns(*skipped_media), dirs_exist_ok=True)
    print(f"Copying non-tensor files {from_path} to {to_path}")
    shutil.copytree(from_path, to_path, ignore=shutil.ignore_patterns(
        "*.safetensors", "*.bin", "*.sft", ".*", "README*", "*.md", "*.index",
        "*.jpg", "*.jpeg", "*.png", "*.webp", "*.index.json"), dirs_exist_ok=True)
def save_flux_other_diffusers(path: str, model_type: str = "dev", use_original: list = ["vae", "text_encoder"], progress=gr.Progress(track_tqdm=True)):
    """Populate ``path`` with the non-transformer parts of a Diffusers FLUX.1 layout.

    Snapshots the repo registered for ``model_type`` (unknown types fall back to "dev")
    into ``system_temp_dir``. Copies configs plus the components named in ``use_original``
    into ``path``, then deletes the scratch directory.
    """
    import shutil
    progress(0, desc="Loading FLUX.1 Components.")
    print("Loading FLUX.1 Components.")
    scratch = system_temp_dir
    repo = flux_diffusers_repos.get(model_type, flux_diffusers_repos.get("dev", None))
    for folder in (scratch, path):
        os.makedirs(folder, exist_ok=True)
    download_repo(repo, scratch, use_original)
    progress(0.5, desc="Saving FLUX.1 Components.")
    print("Saving FLUX.1 Components.")
    copy_missing_files(scratch, path, use_original)
    # NOTE(review): removes the whole system_temp_dir; confirm no caller keeps other files
    # (e.g. downloaded checkpoints) in that same directory.
    shutil.rmtree(scratch)
with torch.no_grad():
    def fix_flux_safetensors(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
                             quantization: bool = False, model_type: str = "dev", dequant: bool = False):
        """Write the Diffusers side files for ``model_type`` plus a key-normalized copy of ``loadpath``.

        Both outputs go into ``savepath``. ``quantization`` is accepted but not used.
        """
        save_flux_other_diffusers(path=savepath, model_type=model_type)
        normalize_flux_state_dict(path=loadpath, savepath=savepath, dtype=dtype, dequant=dequant)
        clear_cache()
# `@torch.no_grad()` disables autograd for every call of the function. The previous
# `with torch.no_grad(): def ...` only covered the function definition itself.
@torch.no_grad()
def flux_to_diffusers_lowmem(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
                             quantization: bool = False, model_type: str = "dev",
                             dequant: bool = False, use_original: list = ["vae", "text_encoder"],
                             new_repo_id: str = "", local: bool = False, progress=gr.Progress(track_tqdm=True)):
    """Convert a single-file FLUX.1 checkpoint into a Diffusers folder, one component at a time.

    Much lower memory consumption, but higher disk load: each component is extracted from
    ``loadpath``, written under ``savepath`` and freed before the next one is processed.

    Args:
        loadpath: path of the single-file checkpoint.
        savepath: output folder; weights go to its transformer/, text_encoder/,
            text_encoder_2/ and vae/ subfolders.
        dtype: dtype for the transformer and Text Encoder 2 (VAE and Text Encoder use bf16).
        quantization: forwarded to ``to_safetensors_flux_module``.
        model_type: selects the reference repo used by ``save_flux_other_diffusers``.
        dequant: forwarded to ``extract_norm_flux_module_sd``.
        use_original: components ("vae", "text_encoder", "text_encoder_2") taken from the
            reference repo instead of being converted from the checkpoint.
        new_repo_id: not used by this function; kept for interface compatibility.
        local: when False, ``loadpath`` is deleted right after the transformer weights are
            extracted (to free disk space before writing them).
        progress: not used by this function; kept for interface compatibility.
    """
    unet_sd_path = savepath.removesuffix("/") + "/transformer"
    unet_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
    unet_sd_size = "9.5GB"
    te_sd_path = savepath.removesuffix("/") + "/text_encoder_2"
    te_sd_pattern = "model{suffix}.safetensors"
    te_sd_size = "5GB"
    clip_sd_path = savepath.removesuffix("/") + "/text_encoder"
    clip_sd_pattern = "model{suffix}.safetensors"
    clip_sd_size = "9.5GB"
    vae_sd_path = savepath.removesuffix("/") + "/vae"
    vae_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
    vae_sd_size = "9.5GB"
    print_resource_usage() #
    # The checkpoint's own metadata is carried over into the transformer shards.
    metadata = {"format": "pt", **read_safetensors_metadata(loadpath)}
    clear_cache()
    print_resource_usage() #
    if "vae" not in use_original:
        vae_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "VAE",
                                             keys_flux_vae)
        to_safetensors_flux_module(vae_sd, vae_sd_path, vae_sd_pattern, vae_sd_size,
                                   quantization, "VAE", None)
        clear_sd(vae_sd)
        print_resource_usage() #
    if "text_encoder" not in use_original:
        clip_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "Text Encoder",
                                              keys_flux_clip)
        to_safetensors_flux_module(clip_sd, clip_sd_path, clip_sd_pattern, clip_sd_size,
                                   quantization, "Text Encoder", None)
        clear_sd(clip_sd)
        print_resource_usage() #
    if "text_encoder_2" not in use_original:
        te_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Text Encoder 2",
                                            keys_flux_t5xxl)
        to_safetensors_flux_module(te_sd, te_sd_path, te_sd_pattern, te_sd_size,
                                   quantization, "Text Encoder 2", None)
        clear_sd(te_sd)
        print_resource_usage() #
    unet_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Transformer",
                                          keys_flux_transformer)
    clear_cache()
    print_resource_usage() #
    if not local:
        # Downloaded input is no longer needed; drop it before writing the transformer.
        os.remove(loadpath)
        print("Deleted downloaded file.")
        clear_cache()
        print_resource_usage() #
    unet_sd = convert_flux_transformer_sd_to_diffusers(unet_sd)
    clear_cache()
    print_resource_usage() #
    to_safetensors_flux_module(unet_sd, unet_sd_path, unet_sd_pattern, unet_sd_size,
                               quantization, "Transformer", metadata)
    clear_sd(unet_sd)
    print_resource_usage() #
    save_flux_other_diffusers(savepath, model_type, use_original)
    print_resource_usage() #
# `@torch.no_grad()` disables autograd for every call of the function. The previous
# `with torch.no_grad(): def ...` only covered the function definition itself.
@torch.no_grad()
def flux_to_diffusers_lowmem2(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16,
                              quantization: bool = False, model_type: str = "dev",
                              dequant: bool = False, use_original: list = ["vae", "text_encoder"],
                              new_repo_id: str = "", progress=gr.Progress(track_tqdm=True)):
    """Convert a single-file FLUX.1 checkpoint into a Diffusers folder via on-disk shards.

    Lowest memory consumption, but highest disk load: besides converting one component at a
    time, the transformer is normalized and converted through intermediate shards written
    under ``system_temp_dir``/sharded.

    Args:
        loadpath: path of the single-file checkpoint (not deleted by this function).
        savepath: output folder; weights go to its transformer/, text_encoder/,
            text_encoder_2/ and vae/ subfolders.
        dtype: dtype for the transformer and Text Encoder 2 (VAE and Text Encoder use bf16).
        quantization: forwarded to ``to_safetensors_flux_module``.
        model_type: selects the reference repo used by ``save_flux_other_diffusers``.
        dequant: forwarded to the extraction helpers.
        use_original: components ("vae", "text_encoder", "text_encoder_2") taken from the
            reference repo instead of being converted from the checkpoint.
        new_repo_id: not used by this function; kept for interface compatibility.
        progress: not used by this function; kept for interface compatibility.
    """
    unet_sd_path = savepath.removesuffix("/") + "/transformer"
    # NOTE(review): this lives inside system_temp_dir, which save_flux_other_diffusers
    # removes with shutil.rmtree at the end of this function.
    unet_temp_path = system_temp_dir.removesuffix("/") + "/sharded"
    unet_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
    unet_sd_size = "10GB"
    unet_temp_size = "5GB"
    te_sd_path = savepath.removesuffix("/") + "/text_encoder_2"
    te_sd_pattern = "model{suffix}.safetensors"
    te_sd_size = "5GB"
    clip_sd_path = savepath.removesuffix("/") + "/text_encoder"
    clip_sd_pattern = "model{suffix}.safetensors"
    clip_sd_size = "10GB"
    vae_sd_path = savepath.removesuffix("/") + "/vae"
    vae_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors"
    vae_sd_size = "10GB"
    print_resource_usage() #
    # The checkpoint's own metadata is carried over into the transformer shards.
    metadata = {"format": "pt", **read_safetensors_metadata(loadpath)}
    clear_cache()
    print_resource_usage() #
    if "vae" not in use_original:
        vae_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "VAE",
                                             keys_flux_vae)
        to_safetensors_flux_module(vae_sd, vae_sd_path, vae_sd_pattern, vae_sd_size,
                                   quantization, "VAE", None)
        clear_sd(vae_sd)
        print_resource_usage() #
    if "text_encoder" not in use_original:
        clip_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "Text Encoder",
                                              keys_flux_clip)
        to_safetensors_flux_module(clip_sd, clip_sd_path, clip_sd_pattern, clip_sd_size,
                                   quantization, "Text Encoder", None)
        clear_sd(clip_sd)
        print_resource_usage() #
    if "text_encoder_2" not in use_original:
        te_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Text Encoder 2",
                                            keys_flux_t5xxl)
        to_safetensors_flux_module(te_sd, te_sd_path, te_sd_pattern, te_sd_size,
                                   quantization, "Text Encoder 2", None)
        clear_sd(te_sd)
        print_resource_usage() #
    unet_sd = extract_normalized_flux_state_dict_sharded(loadpath, dtype, dequant,
                                                         unet_temp_path, unet_sd_pattern, unet_temp_size)
    clear_cache()
    print_resource_usage() #
    unet_sd = convert_flux_transformer_sd_to_diffusers_sharded(unet_sd, unet_temp_path,
                                                               unet_sd_pattern, unet_temp_size)
    clear_cache()
    print_resource_usage() #
    to_safetensors_flux_module(unet_sd, unet_sd_path, unet_sd_pattern, unet_sd_size,
                               quantization, "Transformer", metadata)
    clear_sd(unet_sd)
    print_resource_usage() #
    save_flux_other_diffusers(savepath, model_type, use_original)
    print_resource_usage() #
def convert_url_to_diffusers_flux(url, civitai_key="", is_upload_sf=False, data_type="bf16",
                                  model_type="dev", dequant=False, use_original=["vae", "text_encoder"],
                                  hf_user="", hf_repo="", q=None, progress=gr.Progress(track_tqdm=True)):
    """Download a FLUX.1 single-file checkpoint and convert it into a local Diffusers folder.

    Designed to run as a ``multiprocessing`` target. When ``q`` is given, exactly one result
    is always put on it: the output folder name on success, or "" on failure (including
    exceptions). This means a parent blocked on ``q.get()`` cannot deadlock.

    Args:
        url: download URL of the checkpoint.
        civitai_key: Civitai API key forwarded to ``get_download_file``.
        is_upload_sf: if True, the downloaded file is kept and moved into the output folder;
            otherwise it is deleted.
        data_type: "fp8", "fp16", "qfloat8" (bf16 + quantization); anything else means bf16.
        model_type: key of ``flux_diffusers_repos`` for the reference components.
        dequant: forwarded to ``flux_to_diffusers_lowmem``.
        use_original: components copied from the reference repo instead of converted.
        hf_user, hf_repo: used to build the repo id forwarded to the converter.
        q: optional queue that receives the result.
        progress: Gradio progress callback.

    Returns:
        The output folder name, or "" when the download failed.
    """
    import shutil
    result = ""
    try:
        progress(0, desc="Start converting...")
        temp_dir = "."
        print_resource_usage() #
        new_file = get_download_file(temp_dir, url, civitai_key)
        if not new_file:
            print(f"Not found: {url}")
            return ""
        new_repo_name = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_") #
        dtype = {"fp8": torch.float8_e4m3fn, "fp16": torch.float16}.get(data_type, torch.bfloat16)
        quantization = data_type == "qfloat8"
        new_repo_id = f"{hf_user}/{hf_repo}" if hf_repo != "" else f"{hf_user}/{Path(new_repo_name).stem}"
        # With local=False the converter deletes the download mid-way to save disk space.
        # Keep it (local=True) only when it must be moved into the output afterwards.
        flux_to_diffusers_lowmem(new_file, new_repo_name, dtype, quantization, model_type, dequant,
                                 use_original, new_repo_id, local=bool(is_upload_sf))
        if is_upload_sf:
            shutil.move(str(Path(new_file).resolve()), str(Path(new_repo_name, Path(new_file).name).resolve()))
        else:
            # Usually already removed by the converter; remove it here if it is still present.
            Path(new_file).unlink(missing_ok=True)
        progress(1, desc="Converted.")
        result = new_repo_name
        return result
    finally:
        if q is not None:
            q.put(result)
def convert_url_to_fixed_flux_safetensors(url, civitai_key="", is_upload_sf=False, data_type="bf16",
                                          model_type="dev", dequant=False, q=None, progress=gr.Progress(track_tqdm=True)):
    """Download a FLUX.1 single file and write a key-normalized copy via ``fix_flux_safetensors``.

    Designed to run as a ``multiprocessing`` target. When ``q`` is given, exactly one result
    is always put on it: the output folder name on success, or "" on failure (including
    exceptions).

    Args:
        url: download URL of the checkpoint.
        civitai_key: Civitai API key forwarded to ``get_download_file``.
        is_upload_sf: accepted for interface parity; the download is always deleted here.
        data_type: "fp8", "fp16", "qfloat8" (bf16 + quantization); anything else means bf16.
        model_type: key of ``flux_diffusers_repos`` for the reference components.
        dequant: forwarded to ``fix_flux_safetensors``.
        q: optional queue that receives the result.
        progress: Gradio progress callback.

    Returns:
        The output folder name, or "" when the download failed.
    """
    result = ""
    try:
        progress(0, desc="Start converting...")
        temp_dir = "."
        print_resource_usage() #
        new_file = get_download_file(temp_dir, url, civitai_key)
        if not new_file:
            print(f"Not found: {url}")
            return ""
        new_repo_name = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_") #
        dtype = {"fp8": torch.float8_e4m3fn, "fp16": torch.float16}.get(data_type, torch.bfloat16)
        quantization = data_type == "qfloat8"
        # Keyword arguments: passed positionally, model_type/dequant previously landed in
        # the quantization/model_type slots.
        fix_flux_safetensors(new_file, new_repo_name, dtype, quantization=quantization,
                             model_type=model_type, dequant=dequant)
        os.remove(new_file)
        progress(1, desc="Converted.")
        result = new_repo_name
        return result
    finally:
        if q is not None:
            q.put(result)
def convert_url_to_diffusers_repo_flux(dl_url, hf_user, hf_repo, hf_token, civitai_key="", is_private=True, is_overwrite=False,
                                       is_upload_sf=False, data_type="bf16", model_type="dev", dequant=False,
                                       repo_urls=[], fix_only=False, use_original=["vae", "text_encoder"],
                                       progress=gr.Progress(track_tqdm=True)):
    """Convert the model at ``dl_url`` in a child process and upload the result to the Hub.

    Returns:
        A pair of Gradio updates: the repo-URL list (value and choices) and the Markdown
        listing every repo created so far. On any failure the list is left as it was and the
        Markdown is cleared.
    """
    import multiprocessing as mp
    import queue
    import shutil

    def unchanged():
        # Every failure path returns the same two outputs the Gradio handler expects.
        return gr.update(value=repo_urls, choices=repo_urls), gr.update(value="")

    if not hf_user:
        print(f"Invalid user name: {hf_user}")
        progress(1, desc=f"Invalid user name: {hf_user}")
        return unchanged()
    # Only persist a token we actually have; saving an empty one would clobber a stored token.
    token = hf_token or os.environ.get("HF_TOKEN")
    if token: HfFolder.save_token(token)
    if not civitai_key and os.environ.get("CIVITAI_API_KEY"): civitai_key = os.environ.get("CIVITAI_API_KEY")
    q = mp.Queue()
    if fix_only:
        p = mp.Process(target=convert_url_to_fixed_flux_safetensors, args=(dl_url, civitai_key,
                       is_upload_sf, data_type, model_type, dequant, q))
    else:
        p = mp.Process(target=convert_url_to_diffusers_flux, args=(dl_url, civitai_key,
                       is_upload_sf, data_type, model_type, dequant, use_original, hf_user, hf_repo, q))
    p.start()
    # Poll instead of blocking forever: if the child dies without reporting (crash, OOM
    # kill), a plain q.get() would hang this handler indefinitely.
    new_path = ""
    while True:
        try:
            new_path = q.get(timeout=10)
            break
        except queue.Empty:
            if p.is_alive():
                continue
            # The child has exited; its result may have been flushed just before exit.
            try:
                new_path = q.get(timeout=1)
            except queue.Empty:
                pass
            break
    p.join()
    if not new_path:
        print(f"Conversion failed: {dl_url}")
        progress(1, desc=f"Conversion failed: {dl_url}")
        return unchanged()
    new_repo_id = f"{hf_user}/{Path(new_path).stem}"
    if hf_repo != "": new_repo_id = f"{hf_user}/{hf_repo}"
    if not is_repo_name(new_repo_id):
        print(f"Invalid repo name: {new_repo_id}")
        progress(1, desc=f"Invalid repo name: {new_repo_id}")
        return unchanged()
    if not is_overwrite and is_repo_exists(new_repo_id):
        print(f"Repo already exists: {new_repo_id}")
        progress(1, desc=f"Repo already exists: {new_repo_id}")
        return unchanged()
    save_readme_md(new_path, dl_url)
    repo_url = create_diffusers_repo(new_repo_id, new_path, is_private, is_overwrite)
    shutil.rmtree(new_path)
    # Build a new list rather than mutating the caller's (or the default) list in place.
    urls = list(repo_urls) if repo_urls else []
    urls.append(repo_url)
    md = "Your new repo:<br>"
    for u in urls:
        md += f"[{str(u).split('/')[-2]}/{str(u).split('/')[-1]}]({str(u)})<br>"
    return gr.update(value=urls, choices=urls), gr.update(value=md)
if __name__ == "__main__":
    import queue
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", default=None, type=str, required=False, help="URL of the model to convert.")
    parser.add_argument("--file", default=None, type=str, required=False, help="Filename of the model to convert.")
    parser.add_argument("--fix", action="store_true", help="Only fix the keys of the local model.")
    parser.add_argument("--civitai_key", default=None, type=str, required=False, help="Civitai API Key (If you want to download file from Civitai).")
    parser.add_argument("--dtype", type=str, default="fp8")
    parser.add_argument("--model", type=str, default="dev")
    parser.add_argument("--dequant", action="store_true", help="Dequantize model.")
    args = parser.parse_args()
    # parser.error instead of assert: validation must survive `python -O`.
    if args.url is None and args.file is None:
        parser.error("Must provide --url or --file!")
    # "qfloat8" (quantization) is not supported for local files; it falls back to bf16.
    dtype = {"fp8": torch.float8_e4m3fn, "fp16": torch.float16}.get(args.dtype, torch.bfloat16)
    quantization = False
    use_original = ["vae", "text_encoder", "text_encoder_2"]
    new_repo_id = ""
    use_local = True  # never delete the user's own input file
    if args.file is not None and Path(args.file).exists():
        if args.fix: normalize_flux_state_dict(args.file, ".", dtype, args.dequant)
        else: flux_to_diffusers_lowmem(args.file, Path(args.file).stem, dtype, quantization,
                                       args.model, args.dequant, use_original, new_repo_id, use_local)
    elif args.url is not None:
        # convert_url_to_diffusers_flux reports its result on q; give it a throwaway queue.
        convert_url_to_diffusers_flux(args.url, args.civitai_key, False, args.dtype, args.model,
                                      args.dequant, q=queue.Queue())
    else:
        parser.error(f"File not found: {args.file}")