import json
import torch
from safetensors.torch import load_file, save_file
from pathlib import Path
import gc
import gguf
from dequant import dequantize_tensor  # https://github.com/city96/ComfyUI-GGUF
import os
import argparse
import gradio as gr  # also requires aria, gdown, peft, huggingface_hub, safetensors, transformers, accelerate, pytorch_lightning
import subprocess

subprocess.run('pip cache purge', shell=True)

import spaces


@spaces.GPU()
def spaces_dummy():
    """No-op decorated with ``spaces.GPU()`` (presumably so the Space is allocated a GPU — confirm)."""
    pass


flux_dev_repo = "ChuckMcSneed/FLUX.1-dev"
flux_schnell_repo = "black-forest-labs/FLUX.1-schnell"
system_temp_dir = "temp"

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_grad_enabled(False)

# GGUF tensor types that dequant_tensor() hands to dequantize_tensor().
GGUF_QTYPE = [gguf.GGMLQuantizationType.Q8_0, gguf.GGMLQuantizationType.Q5_1, gguf.GGMLQuantizationType.Q5_0,
              gguf.GGMLQuantizationType.Q4_1, gguf.GGMLQuantizationType.Q4_0, gguf.GGMLQuantizationType.F32,
              gguf.GGMLQuantizationType.F16]
# Ordinary (non-quantized) torch dtypes; these are simply cast.
TORCH_DTYPE = [torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.half, torch.bfloat16,
               torch.complex32, torch.chalf, torch.complex64, torch.cfloat, torch.complex128, torch.cdouble,
               torch.uint8, torch.uint16, torch.uint32, torch.uint64, torch.int8, torch.int16, torch.short,
               torch.int32, torch.int, torch.int64, torch.long, torch.bool, torch.float8_e4m3fn, torch.float8_e5m2]
# torch-native quantized dtypes.
TORCH_QUANTIZED_DTYPE = [torch.quint8, torch.qint8, torch.qint32, torch.quint4x2]


def list_sub(a, b):
    """Return the elements of *a* that are not in *b*, keeping *a*'s order."""
    return [e for e in a if e not in b]


def is_repo_name(s):
    """Return a match object if *s* looks like a Hugging Face repo id ("owner/name"), else None."""
    import re
    return re.fullmatch(r'^[^/,\s\"\']+/[^/,\s\"\']+$', s)


def clear_cache():
    """Release cached CUDA memory and run the garbage collector."""
    torch.cuda.empty_cache()
    gc.collect()


def clear_sd(sd: dict):
    """Empty the state dict *sd* in place, then release CUDA cache and run GC."""
    sd.clear()
    torch.cuda.empty_cache()
    gc.collect()


def clone_sd(sd: dict):
    """Replace every value of *sd* in place with a deep copy of itself."""
    from copy import deepcopy
    print("Cloning state dict.")
    for k in list(sd.keys()):
        sd[k] = deepcopy(sd.pop(k))
    torch.cuda.empty_cache()
    gc.collect()


def print_resource_usage():
    """Print current CPU usage and the percentage of RAM in use (used / total)."""
    import psutil
    cpu_usage = psutil.cpu_percent()
    vm = psutil.virtual_memory()  # sample once so used/total come from the same snapshot
    ram_usage = vm.used / vm.total * 100
    print(f"CPU usage: {cpu_usage}% / RAM usage: {ram_usage}%")


def download_thing(directory, url, civitai_api_key="", progress=gr.Progress(track_tqdm=True)):
    """Download *url* into *directory* with gdown (Google Drive) or aria2c (everything else).

    Commands are run as argument lists (no shell), so characters such as ``&``, ``;`` or ``?``
    in a user-supplied URL or API key are passed through literally instead of being
    interpreted by a shell. Like the previous ``os.system`` calls, a non-zero exit status
    is not raised; callers detect failure by checking for the downloaded file.

    Args:
        directory: destination directory.
        url: URL to fetch (surrounding whitespace is stripped).
        civitai_api_key: token appended to civitai.com URLs; civitai downloads are
            skipped (with a message) when it is empty.
        progress: Gradio progress tracker.
    """
    progress(0, desc="Start downloading...")
    url = url.strip()
    directory = str(directory)
    aria2c_opts = ["--console-log-level=error", "--summary-interval=10",
                   "-c", "-x", "16", "-k", "1M", "-s", "16"]
    if "drive.google.com" in url:
        # gdown saves into its working directory; pass cwd instead of os.chdir() so the
        # process-wide cwd is never changed (nor left changed if something fails).
        subprocess.run(["gdown", "--fuzzy", url], cwd=directory)
    elif "huggingface.co" in url:
        url = url.replace("?download=true", "")
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
            subprocess.run(["aria2c", *aria2c_opts, url, "-d", directory, "-o", url.split('/')[-1]])
        else:
            subprocess.run(["aria2c", "--optimize-concurrent-downloads", *aria2c_opts,
                            url, "-d", directory, "-o", url.split('/')[-1]])
    elif "civitai.com" in url:
        if "?" in url:
            url = url.split("?")[0]
        if civitai_api_key:
            url = url + f"?token={civitai_api_key}"
            subprocess.run(["aria2c", *aria2c_opts, "-d", directory, url])
        else:
            print("You need an API key to download Civitai models.")
    else:
        subprocess.run(["aria2c", *aria2c_opts, "-d", directory, url])


def get_local_model_list(dir_path):
    """Return the ``.safetensors`` files directly inside *dir_path* as path strings."""
    # Must be a tuple: a bare string turns `in` into a substring test, which matched
    # extension-less files and directories (suffix "") as models.
    valid_extensions = ('.safetensors',)
    return [str(Path(f"{dir_path}/{file.name}"))
            for file in Path(dir_path).glob("*")
            if file.is_file() and file.suffix in valid_extensions]


def get_download_file(temp_dir, url, civitai_key, progress=gr.Progress(track_tqdm=True)):
    """Resolve *url* to something loadable: an HF repo id, a local file, or a fresh download.

    Returns:
        str: the repo id / local path to use, or "" if the download failed.
    """
    is_remote = "http" in url
    cached_path = f"{temp_dir}/{url.split('/')[-1]}"
    if not is_remote and is_repo_name(url) and not Path(url).exists():
        print(f"Use HF Repo: {url}")
        new_file = url
    elif not is_remote and Path(url).exists():
        print(f"Use local file: {url}")
        new_file = url
    elif Path(cached_path).exists():
        print(f"File to download already exists: {url}")
        new_file = cached_path
    else:
        print(f"Start downloading: {url}")
        before = get_local_model_list(temp_dir)
        try:
            download_thing(temp_dir, url.strip(), civitai_key)
        except Exception:
            print(f"Download failed: {url}")
            return ""
        # The downloaded file is whatever new .safetensors appeared in temp_dir.
        added = list_sub(get_local_model_list(temp_dir), before)
        new_file = added[0] if added else ""
        if not new_file:
            print(f"Download failed: {url}")
            return ""
    print(f"Download completed: {url}")
    return new_file


def save_readme_md(dir, url):
    """Write a README.md with model-card front matter into *dir*.

    A "Converted from" line is appended when *url* is an http(s) URL.
    """
    orig_url = url if "http" in url else ""
    md = """---
license: other
license_name: flux-1-dev-non-commercial-license
license_link: https://huggingface.co./black-forest-labs/FLUX.1-dev/blob/main/LICENSE.
language:
- en
library_name: diffusers
pipeline_tag: text-to-image
tags:
- text-to-image
- Flux
---
"""
    if orig_url:
        md += f"Converted from [{orig_url}]({orig_url}).\n"
    path = str(Path(dir, "README.md"))
    with open(path, mode='w', encoding="utf-8") as f:
        f.write(md)


def is_repo_exists(repo_id):
    """Return True if *repo_id* exists on the Hub; also True when the check itself fails."""
    from huggingface_hub import HfApi
    api = HfApi()
    try:
        return bool(api.repo_exists(repo_id=repo_id))
    except Exception as e:
        print(f"Error: Failed to connect {repo_id}. ")
        print(e)
        return True  # for safe: treat "unknown" as "exists"


def create_diffusers_repo(new_repo_id, diffusers_folder, is_private, is_overwrite, progress=gr.Progress(track_tqdm=True)):
    """Create *new_repo_id* on the Hub and upload every top-level entry of *diffusers_folder*.

    Uses the ``HF_TOKEN`` environment variable for authentication.

    Returns:
        str: the repo URL on success, "" on any failure.
    """
    from huggingface_hub import HfApi
    hf_token = os.environ.get("HF_TOKEN")
    api = HfApi()
    try:
        progress(0, desc="Start uploading...")
        api.create_repo(repo_id=new_repo_id, token=hf_token, private=is_private, exist_ok=is_overwrite)
        for path in Path(diffusers_folder).glob("*"):
            if path.is_dir():
                api.upload_folder(repo_id=new_repo_id, folder_path=str(path), path_in_repo=path.name, token=hf_token)
            elif path.is_file():
                api.upload_file(repo_id=new_repo_id, path_or_fileobj=str(path), path_in_repo=path.name, token=hf_token)
        progress(1, desc="Uploaded.")
        url = f"https://huggingface.co./{new_repo_id}"
    except Exception as e:
        print(f"Error: Failed to upload to {new_repo_id}. ")
        print(e)
        return ""
    return url


# https://github.com/huggingface/diffusers/blob/main/scripts/convert_flux_to_diffusers.py
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
with torch.no_grad(), torch.autocast(device):
    @torch.jit.script
    def swap_scale_shift(weight):
        # Swap the two halves along dim 0: (shift, scale) -> (scale, shift).
        shift, scale = weight.chunk(2, dim=0)
        new_weight = torch.cat([scale, shift], dim=0)
        return new_weight

with torch.no_grad(), torch.autocast(device):
    def convert_flux_transformer_checkpoint_to_diffusers(
            original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0,
            progress=gr.Progress(track_tqdm=True)):
        """Rename/split FLUX.1 transformer weights into the Diffusers key layout.

        Tensors are moved (popped) out of *original_state_dict* into a new dict, so the
        input is consumed as conversion proceeds. Absent keys are skipped, which lets this
        run on a partial (sharded) state dict.

        Args:
            original_state_dict: FLUX.1-format state dict; mutated (entries popped).
            num_layers: number of double-stream blocks (``double_blocks.{i}``).
            num_single_layers: number of single-stream blocks (``single_blocks.{i}``).
            inner_dim: width of each of q/k/v in the fused single-block ``linear1``.
            mlp_ratio: MLP hidden size as a multiple of *inner_dim*.
            progress: Gradio progress tracker.

        Returns:
            dict: the converted state dict.
        """
        def conv(cdict: dict, odict: dict, ckey: str, okey: str):
            # Move odict[okey] to cdict[ckey] when present.
            if okey in odict:
                progress(0, desc=f"Converting {okey} => {ckey}")
                print(f"Converting {okey} => {ckey}")
                cdict[ckey] = odict.pop(okey)
                gc.collect()

        def convswap(cdict: dict, odict: dict, ckey: str, okey: str):
            # Like conv(), but swaps the (shift, scale) halves into Diffusers' (scale, shift).
            if okey in odict:
                progress(0, desc=f"Converting (swap) {okey} => {ckey}")
                print(f"Converting {okey} => {ckey} (swap)")
                cdict[ckey] = swap_scale_shift(odict.pop(okey))
                gc.collect()

        def convqkv(cdict: dict, odict: dict, i: int, block_prefix: str):
            # Split double block i's fused img/txt qkv weights and biases into separate projections.
            src = f"double_blocks.{i}."
            img_w, txt_w = f"{src}img_attn.qkv.weight", f"{src}txt_attn.qkv.weight"
            img_b, txt_b = f"{src}img_attn.qkv.bias", f"{src}txt_attn.qkv.bias"
            present = [k in odict for k in (img_w, txt_w, img_b, txt_b)]
            if not any(present):
                # Nothing for this block here (e.g. a shard that does not hold it). Popping
                # would raise KeyError and abort the whole conversion.
                return
            if not all(present):
                progress(0, desc=f"Key error in converting Q, K, V (double_blocks.{i}).")
                print(f"Key error in converting Q, K, V (double_blocks.{i}).")
                return
            progress(0, desc=f"Converting Q, K, V (double_blocks.{i}).")
            print(f"Converting Q, K, V (double_blocks.{i}).")
            sample_q, sample_k, sample_v = torch.chunk(odict.pop(img_w), 3, dim=0)
            context_q, context_k, context_v = torch.chunk(odict.pop(txt_w), 3, dim=0)
            sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(odict.pop(img_b), 3, dim=0)
            context_q_bias, context_k_bias, context_v_bias = torch.chunk(odict.pop(txt_b), 3, dim=0)
            pieces = (
                ("attn.to_q.weight", sample_q), ("attn.to_q.bias", sample_q_bias),
                ("attn.to_k.weight", sample_k), ("attn.to_k.bias", sample_k_bias),
                ("attn.to_v.weight", sample_v), ("attn.to_v.bias", sample_v_bias),
                ("attn.add_q_proj.weight", context_q), ("attn.add_q_proj.bias", context_q_bias),
                ("attn.add_k_proj.weight", context_k), ("attn.add_k_proj.bias", context_k_bias),
                ("attn.add_v_proj.weight", context_v), ("attn.add_v_proj.bias", context_v_bias),
            )
            for suffix, tensor in pieces:
                # chunk() returns views of the fused tensor; cat() of one piece materialises
                # an independent tensor.
                cdict[f"{block_prefix}{suffix}"] = torch.cat([tensor])
            gc.collect()

        def convqkvmlp(cdict: dict, odict: dict, i: int, inner_dim: int, mlp_ratio: float, block_prefix: str):
            # Split single block i's fused linear1 into q, k, v and the MLP input projection.
            w_key, b_key = f"single_blocks.{i}.linear1.weight", f"single_blocks.{i}.linear1.bias"
            has_w, has_b = w_key in odict, b_key in odict
            if not (has_w or has_b):
                return  # nothing for this block here (see convqkv)
            if not (has_w and has_b):
                progress(0, desc=f"Key error in converting Q, K, V, mlp (single_blocks.{i}).")
                print(f"Key error in converting Q, K, V, mlp (single_blocks.{i}).")
                return
            progress(0, desc=f"Converting Q, K, V, mlp (single_blocks.{i}).")
            print(f"Converting Q, K, V, mlp (single_blocks.{i}).")
            mlp_hidden_dim = int(inner_dim * mlp_ratio)
            split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
            q, k, v, mlp = torch.split(odict.pop(w_key), split_size, dim=0)
            q_bias, k_bias, v_bias, mlp_bias = torch.split(odict.pop(b_key), split_size, dim=0)
            pieces = (
                ("attn.to_q.weight", q), ("attn.to_q.bias", q_bias),
                ("attn.to_k.weight", k), ("attn.to_k.bias", k_bias),
                ("attn.to_v.weight", v), ("attn.to_v.bias", v_bias),
                ("proj_mlp.weight", mlp), ("proj_mlp.bias", mlp_bias),
            )
            for suffix, tensor in pieces:
                cdict[f"{block_prefix}{suffix}"] = torch.cat([tensor])
            gc.collect()

        converted_state_dict = {}

        def conv_pairs(dst_prefix: str, src_prefix: str, pairs):
            # conv() every (diffusers_suffix, flux_suffix) pair under the given prefixes.
            for ckey, okey in pairs:
                conv(converted_state_dict, original_state_dict, f"{dst_prefix}{ckey}", f"{src_prefix}{okey}")

        def conv_embedder(dst: str, src: str):
            # MLP embedder: linear_1 <- in_layer, linear_2 <- out_layer (weight then bias).
            conv_pairs(f"{dst}.", f"{src}.", [
                (f"{dst_layer}.{param}", f"{src_layer}.{param}")
                for dst_layer, src_layer in (("linear_1", "in_layer"), ("linear_2", "out_layer"))
                for param in ("weight", "bias")
            ])

        double_pre_qkv = (
            # norm1 / norm1_context
            ("norm1.linear.weight", "img_mod.lin.weight"),
            ("norm1.linear.bias", "img_mod.lin.bias"),
            ("norm1_context.linear.weight", "txt_mod.lin.weight"),
            ("norm1_context.linear.bias", "txt_mod.lin.bias"),
        )
        double_post_qkv = (
            # qk_norm
            ("attn.norm_q.weight", "img_attn.norm.query_norm.scale"),
            ("attn.norm_k.weight", "img_attn.norm.key_norm.scale"),
            ("attn.norm_added_q.weight", "txt_attn.norm.query_norm.scale"),
            ("attn.norm_added_k.weight", "txt_attn.norm.key_norm.scale"),
            # ff img_mlp / txt_mlp
            ("ff.net.0.proj.weight", "img_mlp.0.weight"),
            ("ff.net.0.proj.bias", "img_mlp.0.bias"),
            ("ff.net.2.weight", "img_mlp.2.weight"),
            ("ff.net.2.bias", "img_mlp.2.bias"),
            ("ff_context.net.0.proj.weight", "txt_mlp.0.weight"),
            ("ff_context.net.0.proj.bias", "txt_mlp.0.bias"),
            ("ff_context.net.2.weight", "txt_mlp.2.weight"),
            ("ff_context.net.2.bias", "txt_mlp.2.bias"),
            # output projections
            ("attn.to_out.0.weight", "img_attn.proj.weight"),
            ("attn.to_out.0.bias", "img_attn.proj.bias"),
            ("attn.to_add_out.weight", "txt_attn.proj.weight"),
            ("attn.to_add_out.bias", "txt_attn.proj.bias"),
        )
        single_pre_qkv = (
            # norm.linear <- modulation.lin
            ("norm.linear.weight", "modulation.lin.weight"),
            ("norm.linear.bias", "modulation.lin.bias"),
        )
        single_post_qkv = (
            # qk norm
            ("attn.norm_q.weight", "norm.query_norm.scale"),
            ("attn.norm_k.weight", "norm.key_norm.scale"),
            # output projection
            ("proj_out.weight", "linear2.weight"),
            ("proj_out.bias", "linear2.bias"),
        )

        progress(0, desc="Converting FLUX.1 state dict to Diffusers format.")
        conv_embedder("time_text_embed.timestep_embedder", "time_in")
        conv_embedder("time_text_embed.text_embedder", "vector_in")
        has_guidance = any("guidance" in k for k in original_state_dict)
        if has_guidance:
            conv_embedder("time_text_embed.guidance_embedder", "guidance_in")
        conv_pairs("", "", (
            ("context_embedder.weight", "txt_in.weight"),
            ("context_embedder.bias", "txt_in.bias"),
            ("x_embedder.weight", "img_in.weight"),
            ("x_embedder.bias", "img_in.bias"),
        ))

        progress(0.25, desc="Converting FLUX.1 state dict to Diffusers format.")
        for i in range(num_layers):
            block_prefix = f"transformer_blocks.{i}."
            src_prefix = f"double_blocks.{i}."
            conv_pairs(block_prefix, src_prefix, double_pre_qkv)
            convqkv(converted_state_dict, original_state_dict, i, block_prefix)
            conv_pairs(block_prefix, src_prefix, double_post_qkv)

        progress(0.5, desc="Converting FLUX.1 state dict to Diffusers format.")
        for i in range(num_single_layers):
            block_prefix = f"single_transformer_blocks.{i}."
            src_prefix = f"single_blocks.{i}."
            conv_pairs(block_prefix, src_prefix, single_pre_qkv)
            convqkvmlp(converted_state_dict, original_state_dict, i, inner_dim, mlp_ratio, block_prefix)
            conv_pairs(block_prefix, src_prefix, single_post_qkv)

        progress(0.75, desc="Converting FLUX.1 state dict to Diffusers format.")
        conv_pairs("", "", (
            ("proj_out.weight", "final_layer.linear.weight"),
            ("proj_out.bias", "final_layer.linear.bias"),
        ))
        convswap(converted_state_dict, original_state_dict, "norm_out.linear.weight", "final_layer.adaLN_modulation.1.weight")
        convswap(converted_state_dict, original_state_dict, "norm_out.linear.bias", "final_layer.adaLN_modulation.1.bias")
        progress(1, desc="Converting FLUX.1 state dict to Diffusers format.")
        return converted_state_dict


def read_safetensors_metadata(path):
    """Return the ``__metadata__`` dict from a safetensors file header ({} if absent)."""
    with open(path, 'rb') as f:
        # safetensors layout: 8-byte little-endian header length, then the JSON header.
        header_size = int.from_bytes(f.read(8), 'little')
        header_json = f.read(header_size).decode('utf-8')
    header = json.loads(header_json)
    return dict(header.get('__metadata__', {}))


def normalize_key(k: str):
    """Strip the VAE / diffusion-model / text-encoder prefixes used in single-file checkpoints."""
    return k.replace("vae.", "").replace("model.diffusion_model.", "")\
        .replace("text_encoders.clip_l.transformer.", "")\
        .replace("text_encoders.t5xxl.transformer.", "")
def load_json_list(path: str): try: with open(path, encoding='utf-8') as f: return list(json.load(f)) except Exception as e: print(e) return [] # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/modeling_utils.py # https://huggingface.co./docs/huggingface_hub/v0.24.5/package_reference/serialization # https://huggingface.co./docs/huggingface_hub/index with torch.no_grad(): def to_safetensors(sd: dict, path: str, pattern: str, size: str, progress=gr.Progress(track_tqdm=True)): from huggingface_hub import save_torch_state_dict print(f"Saving a temporary file to disk: {path}") os.makedirs(path, exist_ok=True) try: for k, v in sd.items(): sd[k] = v.to(device="cpu") save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size) except Exception as e: print(e) # https://discuss.huggingface.co/t/t5forconditionalgeneration-checkpoint-size-mismatch-19418/24119 # https://github.com/huggingface/transformers/issues/13769 # https://github.com/huggingface/optimum-quanto/issues/278 # https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/serialization/_torch.py # https://huggingface.co./docs/accelerate/usage_guides/big_modeling with torch.no_grad(): def to_safetensors_flux_module(sd: dict, path: str, pattern: str, size: str, quantization: bool=False, name: str = "", metadata: dict | None = None, progress=gr.Progress(track_tqdm=True)): from huggingface_hub import save_torch_state_dict, save_torch_model from accelerate import init_empty_weights try: progress(0, desc=f"Preparing to save FLUX.1 {name} to Diffusers format.") print(f"Preparing to save FLUX.1 {name} to Diffusers format.") for k, v in sd.items(): sd[k] = v.to(device="cpu") progress(0, desc=f"Loading FLUX.1 {name}.") print(f"Loading FLUX.1 {name}.") os.makedirs(path, exist_ok=True) if quantization: progress(0.5, desc=f"Saving quantized FLUX.1 {name} to {path}") print(f"Saving quantized FLUX.1 {name} to {path}") else: progress(0.5, desc=f"Saving FLUX.1 {name} to: 
{path}") if False and path.endswith("/transformer"): from diffusers import FluxTransformer2DModel has_guidance = any("guidance" in k for k in sd) with init_empty_weights(): model = FluxTransformer2DModel(guidance_embeds=has_guidance) model.to("cpu") model.load_state_dict(sd, strict=True) print(f"Saving FLUX.1 {name} to: {path} (FluxTransformer2DModel)") if metadata is not None: progress(0.5, desc=f"Saving FLUX.1 {name} metadata to: {path}") save_torch_model(model=model, save_directory=path, filename_pattern=pattern, max_shard_size=size, metadata=metadata) else: save_torch_model(model=model, save_directory=path, filename_pattern=pattern, max_shard_size=size) else: print(f"Saving FLUX.1 {name} to: {path}") if metadata is not None: progress(0.5, desc=f"Saving FLUX.1 {name} metadata to: {path}") save_torch_state_dict(state_dict=sd, save_directory=path, filename_pattern=pattern, max_shard_size=size, metadata=metadata) else: save_torch_state_dict(state_dict=sd, save_directory=path, filename_pattern=pattern, max_shard_size=size) progress(1, desc=f"Saved FLUX.1 {name} to: {path}") print(f"Saved FLUX.1 {name} to: {path}") except Exception as e: print(e) finally: gc.collect() flux_transformer_json = "flux_transformer_keys.json" flux_t5xxl_json = "flux_t5xxl_keys.json" flux_clip_json = "flux_clip_keys.json" flux_vae_json = "flux_vae_keys.json" keys_flux_t5xxl = set(load_json_list(flux_t5xxl_json)) keys_flux_transformer = set(load_json_list(flux_transformer_json)) keys_flux_clip = set(load_json_list(flux_clip_json)) keys_flux_vae = set(load_json_list(flux_vae_json)) with torch.no_grad(): def dequant_tensor(v: torch.Tensor, dtype: torch.dtype, dequant: bool): try: #print(f"shape: {v.shape} / dim: {v.ndim}") if dequant: qtype = v.tensor_type if v.dtype in TORCH_DTYPE: return v.to(dtype) if v.dtype != dtype else v elif qtype in GGUF_QTYPE: return dequantize_tensor(v, dtype) elif torch.dtype in TORCH_QUANTIZED_DTYPE: return torch.dequantize(v).to(dtype) else: return 
torch.dequantize(v).to(dtype) else: return v.to(dtype) if v.dtype != dtype else v except Exception as e: print(e) with torch.no_grad(): def normalize_flux_state_dict(path: str, savepath: str, dtype: torch.dtype = torch.bfloat16, dequant: bool = False, progress=gr.Progress(track_tqdm=True)): progress(0, desc=f"Loading and normalizing FLUX.1 safetensors: {path}") print(f"Loading and normalizing FLUX.1 safetensors: {path}") new_sd = dict() state_dict = load_file(path, device="cpu") try: for k in list(state_dict.keys()): v = state_dict.pop(k) nk = normalize_key(k) print(f"{k} => {nk}") # new_sd[nk] = dequant_tensor(v, dtype, dequant) except Exception as e: print(e) return finally: clear_sd(state_dict) new_path = str(Path(savepath, Path(path).stem + "_fixed" + Path(path).suffix)) metadata = read_safetensors_metadata(path) progress(0.5, desc=f"Saving FLUX.1 safetensors: {new_path}") print(f"Saving FLUX.1 safetensors: {new_path}") os.makedirs(savepath, exist_ok=True) save_file(new_sd, new_path, metadata={"format": "pt", **metadata}) progress(1, desc=f"Saved FLUX.1 safetensors: {new_path}") print(f"Saved FLUX.1 safetensors: {new_path}") clear_sd(new_sd) with torch.no_grad(): def extract_norm_flux_module_sd(path: str, dtype: torch.dtype = torch.bfloat16, dequant: bool = False, name: str = "", keys: set = {}, progress=gr.Progress(track_tqdm=True)): progress(0, desc=f"Loading and normalizing FLUX.1 {name} safetensors: {path}") print(f"Loading and normalizing FLUX.1 {name} safetensors: {path}") new_sd = dict() state_dict = load_file(path, device="cpu") try: for k in list(state_dict.keys()): if k not in keys: state_dict.pop(k) gc.collect() for k in list(state_dict.keys()): v = state_dict.pop(k) if k in keys: nk = normalize_key(k) progress(0.5, desc=f"{k} => {nk}") # print(f"{k} => {nk}") # new_sd[nk] = dequant_tensor(v, dtype, dequant) #print_resource_usage() # except Exception as e: print(e) return None finally: progress(1, desc=f"Normalized FLUX.1 {name} safetensors: {path}") 
print(f"Normalized FLUX.1 {name} safetensors: {path}") clear_sd(state_dict) return new_sd with torch.no_grad(): def convert_flux_transformer_sd_to_diffusers(sd: dict, progress=gr.Progress(track_tqdm=True)): progress(0, desc="Converting FLUX.1 state dict to Diffusers format.") print("Converting FLUX.1 state dict to Diffusers format.") num_layers = 19 num_single_layers = 38 inner_dim = 3072 mlp_ratio = 4.0 try: sd = convert_flux_transformer_checkpoint_to_diffusers( sd, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio ) except Exception as e: print(e) finally: progress(1, desc="Converted FLUX.1 state dict to Diffusers format.") print("Converted FLUX.1 state dict to Diffusers format.") gc.collect() return sd with torch.no_grad(): def load_sharded_safetensors(path: str): import glob sd = {} try: for filepath in glob.glob(f"{path}/*.safetensors"): sharded_sd = load_file(str(filepath), device="cpu") for k, v in sharded_sd.items(): sharded_sd[k] = v.to(device="cpu") sd = sd | sharded_sd.copy() clear_sd(sharded_sd) except Exception as e: print(e) return sd # https://huggingface.co./docs/safetensors/api/torch with torch.no_grad(): def convert_flux_transformer_sd_to_diffusers_sharded(sd: dict, path: str, pattern: str, size: str, progress=gr.Progress(track_tqdm=True)): from huggingface_hub import save_torch_state_dict#, load_torch_model import glob try: progress(0, desc=f"Saving temporary files to disk: {path}") print(f"Saving temporary files to disk: {path}") os.makedirs(path, exist_ok=True) for k, v in sd.items(): if k in set(keys_flux_transformer): sd[k] = v.to(device="cpu") save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size) clear_sd(sd) progress(0.25, desc=f"Saved temporary files to disk: {path}") print(f"Saved temporary files to disk: {path}") for filepath in glob.glob(f"{path}/*.safetensors"): progress(0.25, desc=f"Processing temporary files: {str(filepath)}") print(f"Processing temporary files: {str(filepath)}") sharded_sd = 
load_file(str(filepath), device="cpu") sharded_sd = convert_flux_transformer_sd_to_diffusers(sharded_sd) for k, v in sharded_sd.items(): sharded_sd[k] = v.to(device="cpu") save_file(sharded_sd, str(filepath)) clear_sd(sharded_sd) print(f"Loading temporary files from disk: {path}") sd = load_sharded_safetensors(path) print(f"Loaded temporary files from disk: {path}") except Exception as e: print(e) return sd with torch.no_grad(): def extract_normalized_flux_state_dict_sharded(loadpath: str, dtype: torch.dtype, dequant: bool, path: str, pattern: str, size: str, progress=gr.Progress(track_tqdm=True)): from huggingface_hub import save_torch_state_dict#, load_torch_model import glob try: progress(0, desc=f"Loading model file: {loadpath}") print(f"Loading model file: {loadpath}") sd = load_file(loadpath, device="cpu") progress(0, desc=f"Saving temporary files to disk: {path}") print(f"Saving temporary files to disk: {path}") os.makedirs(path, exist_ok=True) for k, v in sd.items(): sd[k] = v.to(device="cpu") save_torch_state_dict(sd, path, filename_pattern=pattern, max_shard_size=size) clear_sd(sd) progress(0.25, desc=f"Saved temporary files to disk: {path}") print(f"Saved temporary files to disk: {path}") for filepath in glob.glob(f"{path}/*.safetensors"): progress(0.25, desc=f"Processing temporary files: {str(filepath)}") print(f"Processing temporary files: {str(filepath)}") sharded_sd = extract_norm_flux_module_sd(str(filepath), dtype, dequant, "Transformer", keys_flux_transformer) for k, v in sharded_sd.items(): sharded_sd[k] = v.to(device="cpu") save_file(sharded_sd, str(filepath)) clear_sd(sharded_sd) print(f"Processed temporary files: {str(filepath)}") print(f"Loading temporary files from disk: {path}") sd = load_sharded_safetensors(path) print(f"Loaded temporary files from disk: {path}") except Exception as e: print(e) return sd def download_repo(repo_name, path, use_original=["vae", "text_encoder"], progress=gr.Progress(track_tqdm=True)): from huggingface_hub 
import snapshot_download print(f"Downloading {repo_name}.") try: if "text_encoder_2" in use_original: snapshot_download(repo_id=repo_name, local_dir=path, ignore_patterns=["transformer/diffusion*.*", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.png", "*.webp"]) else: snapshot_download(repo_id=repo_name, local_dir=path, ignore_patterns=["transformer/diffusion*.*", "text_encoder_2/model*.*", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.png", "*.webp"]) except Exception as e: print(e) def copy_nontensor_files(from_path, to_path, use_original=["vae", "text_encoder"]): import shutil if "text_encoder_2" in use_original: te_from = str(Path(from_path, "text_encoder_2")) te_to = str(Path(to_path, "text_encoder_2")) print(f"Copying Text Encoder 2 files {te_from} to {te_to}") shutil.copytree(te_from, te_to, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.png", "*.webp"), dirs_exist_ok=True) if "text_encoder" in use_original: te1_from = str(Path(from_path, "text_encoder")) te1_to = str(Path(to_path, "text_encoder")) print(f"Copying Text Encoder 1 files {te1_from} to {te1_to}") shutil.copytree(te1_from, te1_to, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.png", "*.webp"), dirs_exist_ok=True) if "vae" in use_original: vae_from = str(Path(from_path, "vae")) vae_to = str(Path(to_path, "vae")) print(f"Copying VAE files {vae_from} to {vae_to}") shutil.copytree(vae_from, vae_to, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.png", "*.webp"), dirs_exist_ok=True) tn2_from = str(Path(from_path, "tokenizer_2")) tn2_to = str(Path(to_path, "tokenizer_2")) print(f"Copying Tokenizer 2 files {tn2_from} to {tn2_to}") shutil.copytree(tn2_from, tn2_to, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.png", "*.webp"), dirs_exist_ok=True) print(f"Copying non-tensor files {from_path} to {to_path}") shutil.copytree(from_path, to_path, ignore=shutil.ignore_patterns("*.safetensors", "*.bin", "*.sft", 
".*", "README*", "*.md", "*.index", "*.jpg", "*.png", "*.webp", "*.index.json"), dirs_exist_ok=True) def save_flux_other_diffusers(path: str, model_type: str = "dev", use_original: list = ["vae", "text_encoder"], progress=gr.Progress(track_tqdm=True)): import shutil progress(0, desc="Loading FLUX.1 Components.") print("Loading FLUX.1 Components.") temppath = system_temp_dir if model_type == "schnell": repo = flux_schnell_repo else: repo = flux_dev_repo os.makedirs(temppath, exist_ok=True) os.makedirs(path, exist_ok=True) download_repo(repo, temppath, use_original) progress(0.5, desc="Saving FLUX.1 Components.") print("Saving FLUX.1 Components.") copy_nontensor_files(temppath, path, use_original) shutil.rmtree(temppath) with torch.no_grad(): def fix_flux_safetensors(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16, quantization: bool = False, model_type: str = "dev", dequant: bool = False): save_flux_other_diffusers(savepath, model_type) normalize_flux_state_dict(loadpath, savepath, dtype, dequant) clear_cache() with torch.no_grad(): # Much lower memory consumption, but higher disk load def flux_to_diffusers_lowmem(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16, quantization: bool = False, model_type: str = "dev", dequant: bool = False, use_original: list = ["vae", "text_encoder"], new_repo_id: str = "", local: bool = False, progress=gr.Progress(track_tqdm=True)): unet_sd_path = savepath.removesuffix("/") + "/transformer" unet_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors" unet_sd_size = "9.5GB" te_sd_path = savepath.removesuffix("/") + "/text_encoder_2" te_sd_pattern = "model{suffix}.safetensors" te_sd_size = "5GB" clip_sd_path = savepath.removesuffix("/") + "/text_encoder" clip_sd_pattern = "model{suffix}.safetensors" clip_sd_size = "9.5GB" vae_sd_path = savepath.removesuffix("/") + "/vae" vae_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors" vae_sd_size = "9.5GB" print_resource_usage() # metadata = 
{"format": "pt", **read_safetensors_metadata(loadpath)} clear_cache() print_resource_usage() # if "vae" not in use_original: vae_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "VAE", keys_flux_vae) to_safetensors_flux_module(vae_sd, vae_sd_path, vae_sd_pattern, vae_sd_size, quantization, "VAE", None) clear_sd(vae_sd) print_resource_usage() # if "text_encoder" not in use_original: clip_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "Text Encoder", keys_flux_clip) to_safetensors_flux_module(clip_sd, clip_sd_path, clip_sd_pattern, clip_sd_size, quantization, "Text Encoder", None) clear_sd(clip_sd) print_resource_usage() # if "text_encoder_2" not in use_original: te_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Text Encoder 2", keys_flux_t5xxl) to_safetensors_flux_module(te_sd, te_sd_path, te_sd_pattern, te_sd_size, quantization, "Text Encoder 2", None) clear_sd(te_sd) print_resource_usage() # unet_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Transformer", keys_flux_transformer) clear_cache() print_resource_usage() # if not local: os.remove(loadpath) print("Deleted downloaded file.") clear_cache() print_resource_usage() # unet_sd = convert_flux_transformer_sd_to_diffusers(unet_sd) clear_cache() print_resource_usage() # to_safetensors_flux_module(unet_sd, unet_sd_path, unet_sd_pattern, unet_sd_size, quantization, "Transformer", metadata) clear_sd(unet_sd) print_resource_usage() # save_flux_other_diffusers(savepath, model_type, use_original) print_resource_usage() # with torch.no_grad(): # lowest memory consumption, but higheest disk load def flux_to_diffusers_lowmem2(loadpath: str, savepath: str, dtype: torch.dtype = torch.bfloat16, quantization: bool = False, model_type: str = "dev", dequant: bool = False, use_original: list = ["vae", "text_encoder"], new_repo_id: str = "", progress=gr.Progress(track_tqdm=True)): unet_sd_path = savepath.removesuffix("/") + "/transformer" unet_temp_path = 
system_temp_dir.removesuffix("/") + "/sharded" unet_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors" unet_sd_size = "10GB" unet_temp_size = "5GB" te_sd_path = savepath.removesuffix("/") + "/text_encoder_2" te_sd_pattern = "model{suffix}.safetensors" te_sd_size = "5GB" clip_sd_path = savepath.removesuffix("/") + "/text_encoder" clip_sd_pattern = "model{suffix}.safetensors" clip_sd_size = "10GB" vae_sd_path = savepath.removesuffix("/") + "/vae" vae_sd_pattern = "diffusion_pytorch_model{suffix}.safetensors" vae_sd_size = "10GB" print_resource_usage() # metadata = {"format": "pt", **read_safetensors_metadata(loadpath)} clear_cache() print_resource_usage() # if "vae" not in use_original: vae_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "VAE", keys_flux_vae) to_safetensors_flux_module(vae_sd, vae_sd_path, vae_sd_pattern, vae_sd_size, quantization, "VAE", None) clear_sd(vae_sd) print_resource_usage() # if "text_encoder" not in use_original: clip_sd = extract_norm_flux_module_sd(loadpath, torch.bfloat16, dequant, "Text Encoder", keys_flux_clip) to_safetensors_flux_module(clip_sd, clip_sd_path, clip_sd_pattern, clip_sd_size, quantization, "Text Encoder", None) clear_sd(clip_sd) print_resource_usage() # if "text_encoder_2" not in use_original: te_sd = extract_norm_flux_module_sd(loadpath, dtype, dequant, "Text Encoder 2", keys_flux_t5xxl) to_safetensors_flux_module(te_sd, te_sd_path, te_sd_pattern, te_sd_size, quantization, "Text Encoder 2", None) clear_sd(te_sd) print_resource_usage() # unet_sd = extract_normalized_flux_state_dict_sharded(loadpath, dtype, dequant, unet_temp_path, unet_sd_pattern, unet_temp_size) clear_cache() print_resource_usage() # unet_sd = convert_flux_transformer_sd_to_diffusers_sharded(unet_sd, unet_temp_path, unet_sd_pattern, unet_temp_size) clear_cache() print_resource_usage() # to_safetensors_flux_module(unet_sd, unet_sd_path, unet_sd_pattern, unet_sd_size, quantization, "Transformer", metadata) clear_sd(unet_sd) 
print_resource_usage() # save_flux_other_diffusers(savepath, model_type, use_original) print_resource_usage() # def convert_url_to_diffusers_flux(url, civitai_key="", is_upload_sf=False, data_type="bf16", model_type="dev", dequant=False, use_original=["vae", "text_encoder"], hf_user="", hf_repo="", q=None, progress=gr.Progress(track_tqdm=True)): progress(0, desc="Start converting...") temp_dir = "." print_resource_usage() # new_file = get_download_file(temp_dir, url, civitai_key) if not new_file: print(f"Not found: {url}") return "" new_repo_name = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_") # dtype = torch.bfloat16 quantization = False if data_type == "fp8": dtype = torch.float8_e4m3fn elif data_type == "fp16": dtype = torch.float16 elif data_type == "qfloat8": dtype = torch.bfloat16 quantization = True else: dtype = torch.bfloat16 new_repo_id = f"{hf_user}/{Path(new_repo_name).stem}" if hf_repo != "": new_repo_id = f"{hf_user}/{hf_repo}" flux_to_diffusers_lowmem(new_file, new_repo_name, dtype, quantization, model_type, dequant, use_original, new_repo_id) """if is_upload_sf: import shutil shutil.move(str(Path(new_file).resolve()), str(Path(new_repo_name, Path(new_file).name).resolve())) else: os.remove(new_file)""" progress(1, desc="Converted.") q.put(new_repo_name) return new_repo_name def convert_url_to_fixed_flux_safetensors(url, civitai_key="", is_upload_sf=False, data_type="bf16", model_type="dev", dequant=False, q=None, progress=gr.Progress(track_tqdm=True)): progress(0, desc="Start converting...") temp_dir = "." 
print_resource_usage() # new_file = get_download_file(temp_dir, url, civitai_key) if not new_file: print(f"Not found: {url}") return "" new_repo_name = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_") # dtype = torch.bfloat16 quantization = False if data_type == "fp8": dtype = torch.float8_e4m3fn elif data_type == "fp16": dtype = torch.float16 elif data_type == "qfloat8": dtype = torch.bfloat16 quantization = True else: dtype = torch.bfloat16 fix_flux_safetensors(new_file, new_repo_name, dtype, model_type, dequant) os.remove(new_file) progress(1, desc="Converted.") q.put(new_repo_name) return new_repo_name def convert_url_to_diffusers_repo_flux(dl_url, hf_user, hf_repo, hf_token, civitai_key="", is_private=True, is_overwrite=False, is_upload_sf=False, data_type="bf16", model_type="dev", dequant=False, repo_urls=[], fix_only=False, use_original=["vae", "text_encoder"], progress=gr.Progress(track_tqdm=True)): import multiprocessing as mp import shutil if not hf_user: print(f"Invalid user name: {hf_user}") progress(1, desc=f"Invalid user name: {hf_user}") return gr.update(value=repo_urls, choices=repo_urls), gr.update(value="") if hf_token and not os.environ.get("HF_TOKEN"): os.environ['HF_TOKEN'] = hf_token if not civitai_key and os.environ.get("CIVITAI_API_KEY"): civitai_key = os.environ.get("CIVITAI_API_KEY") q = mp.Queue() if fix_only: p = mp.Process(target=convert_url_to_fixed_flux_safetensors, args=(dl_url, civitai_key, is_upload_sf, data_type, model_type, dequant, q)) #new_path = convert_url_to_fixed_flux_safetensors(dl_url, civitai_key, is_upload_sf, data_type, model_type, dequant) else: p = mp.Process(target=convert_url_to_diffusers_flux, args=(dl_url, civitai_key, is_upload_sf, data_type, model_type, dequant, use_original, hf_user, hf_repo, q)) #new_path = convert_url_to_diffusers_flux(dl_url, civitai_key, is_upload_sf, data_type, model_type, dequant) p.start() new_path = q.get() p.join() if not new_path: return "" new_repo_id = 
f"{hf_user}/{Path(new_path).stem}" if hf_repo != "": new_repo_id = f"{hf_user}/{hf_repo}" if not is_repo_name(new_repo_id): print(f"Invalid repo name: {new_repo_id}") progress(1, desc=f"Invalid repo name: {new_repo_id}") return gr.update(value=repo_urls, choices=repo_urls), gr.update(value="") if not is_overwrite and is_repo_exists(new_repo_id): print(f"Repo already exists: {new_repo_id}") progress(1, desc=f"Repo already exists: {new_repo_id}") return gr.update(value=repo_urls, choices=repo_urls), gr.update(value="") #save_readme_md(new_path, dl_url) repo_url = create_diffusers_repo(new_repo_id, new_path, is_private, is_overwrite) shutil.rmtree(new_path) if not repo_urls: repo_urls = [] repo_urls.append(repo_url) md = "Your new repo:
" for u in repo_urls: md += f"[{str(u).split('/')[-2]}/{str(u).split('/')[-1]}]({str(u)})
" return gr.update(value=repo_urls, choices=repo_urls), gr.update(value=md) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--url", default=None, type=str, required=False, help="URL of the model to convert.") parser.add_argument("--file", default=None, type=str, required=False, help="Filename of the model to convert.") parser.add_argument("--fix", action="store_true", help="Only fix the keys of the local model.") parser.add_argument("--civitai_key", default=None, type=str, required=False, help="Civitai API Key (If you want to download file from Civitai).") parser.add_argument("--dtype", type=str, default="fp8") parser.add_argument("--model", type=str, default="dev") parser.add_argument("--dequant", action="store_true", help="Dequantize model.") args = parser.parse_args() assert (args.url, args.file) != (None, None), "Must provide --url or --file!" dtype = torch.bfloat16 quantization = False if args.dtype == "fp8": dtype = torch.float8_e4m3fn elif args.dtype == "fp16": dtype = torch.float16 elif args.dtype == "qfloat8": dtype = torch.bfloat16 quantization = True else: dtype = torch.bfloat16 use_original = ["vae", "text_encoder"] new_repo_id = "" use_local = True if args.file is not None and Path(args.file).exists(): if args.fix: normalize_flux_state_dict(args.file, ".", dtype, args.dequant) else: flux_to_diffusers_lowmem(args.file, Path(args.file).stem, dtype, quantization, args.model, args.dequant, use_original, new_repo_id, use_local) elif args.url is not None: convert_url_to_diffusers_flux(args.url, args.civitai_key, False, args.dtype, args.model, args.dequant)