# FLUX.1-Merged

This repository provides the merged parameters of black-forest-labs/FLUX.1-dev and black-forest-labs/FLUX.1-schnell: weights shared by both models are averaged element-wise, while FLUX.1-dev's guidance-embedding weights are kept as-is.

## Merge & Upload

```python
from diffusers import FluxTransformer2DModel
from huggingface_hub import snapshot_download
from huggingface_hub import upload_folder
from accelerate import init_empty_weights
from diffusers.models.model_loading_utils import load_model_dict_into_meta
import safetensors.torch
import glob
import torch


# Initialize the model with empty weights
with init_empty_weights():
    config = FluxTransformer2DModel.load_config("black-forest-labs/FLUX.1-dev", subfolder="transformer")
    model = FluxTransformer2DModel.from_config(config)

# Download the model checkpoints
dev_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-dev", allow_patterns="transformer/*")
schnell_ckpt = snapshot_download(repo_id="black-forest-labs/FLUX.1-schnell", allow_patterns="transformer/*")

# Get the paths to the model shards
dev_shards = sorted(glob.glob(f"{dev_ckpt}/transformer/*.safetensors"))
schnell_shards = sorted(glob.glob(f"{schnell_ckpt}/transformer/*.safetensors"))

# Merge the state dictionaries: average the weights shared by both models and
# keep FLUX.1-dev's guidance-embedding weights (FLUX.1-schnell does not have them)
merged_state_dict = {}
guidance_state_dict = {}

for i in range(len(dev_shards)):
    state_dict_dev_temp = safetensors.torch.load_file(dev_shards[i])
    state_dict_schnell_temp = safetensors.torch.load_file(schnell_shards[i])

    keys = list(state_dict_dev_temp.keys())
    for k in keys:
        if "guidance" not in k:
            merged_state_dict[k] = (state_dict_dev_temp.pop(k) + state_dict_schnell_temp.pop(k)) / 2
        else:
            guidance_state_dict[k] = state_dict_dev_temp.pop(k)

    if len(state_dict_dev_temp) > 0:
        raise ValueError(f"There should not be any residue but got: {list(state_dict_dev_temp.keys())}.")
    if len(state_dict_schnell_temp) > 0:
        raise ValueError(f"There should not be any residue but got: {list(state_dict_schnell_temp.keys())}.")

# Update the merged state dictionary with the guidance state dictionary
merged_state_dict.update(guidance_state_dict)

# Load the merged state dictionary into the model
load_model_dict_into_meta(model, merged_state_dict)

# Save the merged model
model.to(torch.bfloat16).save_pretrained("transformer")

# Upload the merged model to the Hugging Face Hub
upload_folder(
    repo_id="prithivMLmods/Flux.1-Merged",  # Replace with your Hugging Face username and desired repo name
    folder_path="transformer",
    path_in_repo="transformer",
)
```
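
`upload_folder` expects the target repository to already exist, so create it first with `create_repo` if needed. It is also worth sanity-checking that the merged checkpoint loads back cleanly before uploading; the snippet below is a minimal sketch that assumes the merged weights were saved to the local `transformer` directory by the script above and that the repo id matches yours.

```python
from huggingface_hub import create_repo
from diffusers import FluxTransformer2DModel
import torch

# Create the target repo if it does not exist yet (no-op when it already does)
create_repo("prithivMLmods/Flux.1-Merged", exist_ok=True)

# Reload the merged transformer from the local folder to confirm it was saved correctly
merged_transformer = FluxTransformer2DModel.from_pretrained(
    "transformer", torch_dtype=torch.bfloat16
)
print(sum(p.numel() for p in merged_transformer.parameters()))  # expected to be roughly 12B
```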

## Inference

```python
from diffusers import FluxPipeline
import torch

pipeline = FluxPipeline.from_pretrained(
    "prithivMLmods/Flux.1-Merged", torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline(
    prompt="a tiny astronaut hatching from an egg on the moon",
    guidance_scale=3.5,
    num_inference_steps=4,
    height=880,
    width=1184,
    max_sequence_length=512,
    generator=torch.manual_seed(0),
).images[0]
image.save("merged_flux.png")
```
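
On GPUs with limited VRAM, a common alternative to moving the whole pipeline to CUDA is model CPU offloading, which keeps submodules on the CPU and moves them to the GPU only while they run. The sketch below is a variant of the example above (not specific to this checkpoint) and assumes a CUDA device is available.

```python
from diffusers import FluxPipeline
import torch

pipeline = FluxPipeline.from_pretrained(
    "prithivMLmods/Flux.1-Merged", torch_dtype=torch.bfloat16
)
# Offload submodules to CPU and move each to the GPU only when it is needed
pipeline.enable_model_cpu_offload()

image = pipeline(
    prompt="a tiny astronaut hatching from an egg on the moon",
    guidance_scale=3.5,
    num_inference_steps=4,
    generator=torch.manual_seed(0),
).images[0]
image.save("merged_flux_offload.png")
```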