"""Convert HunyuanVideo LoRA weights between the default (kohya-style) format and the Diffusers format."""

import argparse
import logging

import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open

from utils import model_utils

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

_DIFFUSERS_PREFIX = "diffusion_model"
_TARGET_FORMATS = ("default", "other")


def convert_from_diffusers(prefix, weights_sd):
    """Convert a Diffusers-format LoRA state dict to the default LoRA format.

    Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight,
                       "diffusion_model.module.name.lora_B.weight": weight, ...}
    Default format:   {"prefix_module_name.lora_down.weight": weight,
                       "prefix_module_name.lora_up.weight": weight,
                       "prefix_module_name.alpha": tensor(rank), ...}

    Diffusers files carry no alpha, so each module's alpha is set to its rank
    (alpha / rank == 1, i.e. no extra scaling).

    Args:
        prefix: String prepended to every converted LoRA name (e.g. "lora_unet_").
        weights_sd: Mapping of Diffusers-format keys to tensors.

    Returns:
        A new dict in the default LoRA format. Keys whose first dotted component
        is not "diffusion_model" (or that contain no dot at all) are skipped with
        a warning.
    """
    new_weights_sd = {}
    lora_dims = {}
    for key, weight in weights_sd.items():
        # partition (not split+unpack) so a dot-less key is skipped instead of raising
        diffusers_prefix, sep, key_body = key.partition(".")
        if not sep or diffusers_prefix != _DIFFUSERS_PREFIX:
            logger.warning(f"unexpected key: {key} in diffusers format")
            continue

        # Flatten the module path to underscores, then restore the dots that separate
        # the LoRA name from its ".lora_down." / ".lora_up." suffix.
        new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
        new_weights_sd[new_key] = weight

        lora_name = new_key.split(".")[0]  # before first dot
        if lora_name not in lora_dims and "lora_down" in new_key:
            # the rank is taken from dim 0 of lora_down (mirrors convert_to_diffusers)
            lora_dims[lora_name] = weight.shape[0]

    # add alpha equal to rank so alpha / rank == 1
    for lora_name, dim in lora_dims.items():
        new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)
    return new_weights_sd


def _lora_name_to_module_name(prefix, lora_name):
    """Map a default-format LoRA name back to a HunyuanVideo module path.

    Example: "lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn_qkv".
    Every "_" is first turned into "." and then the underscores that belong to
    HunyuanVideo module names are restored. Ugly but works. The replacements are
    applied in this exact order.
    """
    module_name = lora_name[len(prefix) :].replace("_", ".")  # strip prefix, "_" -> "."
    for dotted, fixed in (
        ("double.blocks.", "double_blocks."),  # fix double blocks
        ("single.blocks.", "single_blocks."),  # fix single blocks
        ("img.", "img_"),  # fix img
        ("txt.", "txt_"),  # fix txt
        ("attn.", "attn_"),  # fix attn
    ):
        module_name = module_name.replace(dotted, fixed)
    return module_name


def convert_to_diffusers(prefix, weights_sd):
    """Convert a default-format LoRA state dict to the Diffusers format.

    The Diffusers format has no alpha, so alpha / rank is baked into the weights.
    Both lora_down (-> lora_A) and lora_up (-> lora_B) are multiplied by
    sqrt(alpha / rank), so their product is scaled by alpha / rank.

    Args:
        prefix: Prefix that default-format LoRA names start with (e.g. "lora_unet_").
            Keys without it are not converted (a warning reports how many).
        weights_sd: Mapping of default-format keys to tensors.

    Returns:
        A new dict with keys "diffusion_model.<module>.lora_A.weight" /
        "diffusion_model.<module>.lora_B.weight".
    """
    # Collect alphas first: they are needed to scale both halves of each LoRA pair.
    lora_alphas = {}
    for key, weight in weights_sd.items():
        if key.startswith(prefix) and key.endswith(".alpha"):
            lora_name = key.split(".", 1)[0]  # before first dot
            lora_alphas.setdefault(lora_name, weight)

    new_weights_sd = {}
    num_skipped = 0
    for key, weight in weights_sd.items():
        if not key.startswith(prefix):
            num_skipped += 1
            continue
        if key.endswith(".alpha"):
            continue  # already consumed above

        lora_name = key.split(".", 1)[0]  # before first dot
        module_name = _lora_name_to_module_name(prefix, lora_name)

        if "lora_down" in key:
            new_key = f"{_DIFFUSERS_PREFIX}.{module_name}.lora_A.weight"
            dim = weight.shape[0]
        elif "lora_up" in key:
            new_key = f"{_DIFFUSERS_PREFIX}.{module_name}.lora_B.weight"
            dim = weight.shape[1]
        else:
            logger.warning(f"unexpected key: {key} in default LoRA format")
            continue

        # scale weight by alpha
        if lora_name in lora_alphas:
            # We scale both down and up, so each gets sqrt(alpha / rank).
            # Computed as a Python float so a half-precision alpha tensor does not
            # round the scale. Multiplying by a Python float keeps weight's dtype.
            scale = (float(lora_alphas[lora_name]) / dim) ** 0.5
            weight = weight * scale
        else:
            logger.warning(f"missing alpha for {lora_name}")

        new_weights_sd[new_key] = weight

    if num_skipped:
        logger.warning(f"skipped {num_skipped} keys not starting with {prefix!r}; they are not converted")
    return new_weights_sd


def convert(input_file, output_file, target_format):
    """Load a LoRA safetensors file, convert it, and save it to output_file.

    Args:
        input_file: Path of the source .safetensors file.
        output_file: Path the converted .safetensors file is written to.
        target_format: "default" (Diffusers -> default LoRA) or "other"
            (default LoRA -> Diffusers).

    Raises:
        ValueError: If target_format is not "default" or "other". This is checked
            before the input file is read.
    """
    if target_format not in _TARGET_FORMATS:
        raise ValueError(f"unknown target format: {target_format}")

    logger.info(f"loading {input_file}")
    weights_sd = load_file(input_file)
    with safe_open(input_file, framework="pt") as f:
        metadata = f.metadata()  # may be None when the file has no metadata

    logger.info(f"converting to {target_format}")
    prefix = "lora_unet_"
    if target_format == "default":
        new_weights_sd = convert_from_diffusers(prefix, weights_sd)
        metadata = metadata or {}
        # project helper; presumably records hashes of new_weights_sd into metadata
        model_utils.precalculate_safetensors_hashes(new_weights_sd, metadata)
    else:  # "other"
        # NOTE(review): input metadata is copied unchanged here; if it carries
        # weight hashes they describe the pre-conversion tensors — confirm intent.
        new_weights_sd = convert_to_diffusers(prefix, weights_sd)

    logger.info(f"saving to {output_file}")
    save_file(new_weights_sd, output_file, metadata=metadata)
    logger.info("done")


def parse_args():
    """Parse the command-line options: --input, --output and --target."""
    parser = argparse.ArgumentParser(description="Convert LoRA weights between default and other formats")
    parser.add_argument("--input", type=str, required=True, help="input model file")
    parser.add_argument("--output", type=str, required=True, help="output model file")
    parser.add_argument("--target", type=str, required=True, choices=["other", "default"], help="target format")
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    convert(args.input, args.output, args.target)