File size: 7,237 Bytes
0e0ee20
 
 
 
 
 
7dc34c1
0e0ee20
 
 
 
 
 
 
 
c59400c
0e0ee20
 
c59400c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e0ee20
 
 
 
 
 
 
 
 
 
 
d6802e8
dad6779
0e0ee20
 
 
 
 
 
 
 
701c4ad
 
 
 
0e0ee20
c59400c
 
 
 
 
0e0ee20
 
 
 
 
 
1c15549
0e0ee20
 
 
 
 
9640ef2
0e0ee20
 
 
 
 
 
 
 
 
 
d6802e8
 
 
 
 
0e0ee20
 
1fff27d
0e0ee20
 
 
 
 
 
1fff27d
 
8dce9c7
0e0ee20
 
1fff27d
 
1c15549
0e0ee20
 
 
1c15549
0e0ee20
 
 
 
 
 
 
 
 
 
 
 
 
 
1c15549
0e0ee20
 
 
 
d6802e8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import gradio as gr
import json
import logging
import torch
from PIL import Image
import spaces
from diffusers import DiffusionPipeline

# Load the LoRA catalogue from JSON. Each entry is a dict; the code in this
# file reads the keys "title", "repo", "image", "trigger_word" and, optionally,
# "weights" and "custom_alpha".
with open('loras.json', 'r', encoding='utf-8') as f:
    loras = json.load(f)

# Module logger (used by the patched LoRA loader to report unexpected keys).
logger = logging.getLogger(__name__)

# Initialize the base model
base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
# Keep a handle on the stock loader so run_lora can restore it after a
# patched load. A bound classmethod is just a reference, so no copy is needed
# (the previous `copy.deepcopy` also raised NameError: `copy` was never imported).
original_load_lora = pipe.load_lora_into_transformer
pipe.to("cuda")

def load_lora_into_transformer_patched(cls, state_dict, transformer, adapter_name=None, alpha=None, _pipeline=None, **kwargs):
    """Inject LoRA weights into ``transformer`` using an explicit ``lora_alpha``.

    Replacement for the pipeline's ``load_lora_into_transformer`` classmethod.
    It must be bound to the pipeline class (``cls`` supplies
    ``transformer_name`` and ``_optionally_disable_offloading``) and installed
    on the pipeline before ``load_lora_weights`` is called.

    Args:
        cls: Pipeline class providing ``transformer_name`` and
            ``_optionally_disable_offloading``.
        state_dict: Mapping of parameter names to tensors. Only keys starting
            with ``cls.transformer_name`` are used; the ``"<name>."`` prefix is
            stripped.
        transformer: Model the adapter is injected into.
        adapter_name: Adapter name; generated with ``get_adapter_name`` when
            None.
        alpha: Value used for ``lora_alpha``; 32 when None.
        _pipeline: Pipeline whose CPU-offload hooks are temporarily removed
            while loading and re-enabled afterwards.
        **kwargs: Any extra keyword arguments from the caller (e.g.
            ``network_alphas`` in some diffusers versions — TODO confirm) are
            accepted and ignored.

    Raises:
        ValueError: If ``adapter_name`` is already used on ``transformer``, or
            the LoRA is DoRA-enabled and the installed peft is older than 0.9.0.
    """
    prefix = f"{cls.transformer_name}."
    # Keep only transformer weights, dropping the prefix. One comprehension
    # replaces the previous O(n^2) "k in list" membership filter.
    state_dict = {
        k.replace(prefix, "", 1): v
        for k, v in state_dict.items()
        if k.startswith(cls.transformer_name)
    }
    if not state_dict:
        # Nothing addressed to the transformer: leave the model untouched.
        return

    from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
    from diffusers.utils import (
        convert_unet_state_dict_to_peft,
        get_adapter_name,
        get_peft_kwargs,
        is_peft_version,
    )

    # Check the first key to decide whether the weights need converting to
    # PEFT naming.
    first_key = next(iter(state_dict))
    if "lora_A" not in first_key:
        state_dict = convert_unet_state_dict_to_peft(state_dict)

    if adapter_name in getattr(transformer, "peft_config", {}):
        raise ValueError(
            f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
        )

    # Per-module rank, read from dimension 1 of every lora_B weight.
    rank = {key: val.shape[1] for key, val in state_dict.items() if "lora_B" in key}

    lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
    if "use_dora" in lora_config_kwargs:
        if lora_config_kwargs["use_dora"]:
            if is_peft_version("<", "0.9.0"):
                raise ValueError(
                    "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
                )
        elif is_peft_version("<", "0.9.0"):
            # Only older peft lacks the `use_dora` field; on newer peft keep it
            # so DoRA-enabled LoRAs are not silently loaded as plain LoRA.
            lora_config_kwargs.pop("use_dora")

    lora_config_kwargs["lora_alpha"] = 32 if alpha is None else alpha
    lora_config = LoraConfig(**lora_config_kwargs)

    if adapter_name is None:
        adapter_name = get_adapter_name(transformer)

    # If the pipeline was already offloaded to CPU, temporarily remove the
    # hooks; otherwise loading LoRA weights leads to an error.
    is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
    try:
        inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
        incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
    finally:
        # Restore offloading even if injection failed.
        if is_model_cpu_offload:
            _pipeline.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            _pipeline.enable_sequential_cpu_offload()

    if incompatible_keys is not None:
        # Only unexpected keys are reported.
        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
        if unexpected_keys:
            logging.getLogger(__name__).warning(
                "Loading adapter weights from state_dict led to unexpected keys not found in the model:  %s. ",
                unexpected_keys,
            )

def update_selection(evt: gr.SelectData):
    """Handle a click on the LoRA gallery.

    Args:
        evt: Gradio selection event; ``evt.index`` is the position of the
            clicked item, used to index ``loras``.

    Returns:
        A 3-tuple for the outputs ``[prompt, selected_info, selected_index]``:
        a placeholder update for the prompt box, a Markdown line linking to the
        LoRA's Hugging Face repo, and the selected index.
    """
    selected_lora = loras[evt.index]
    new_placeholder = f"Type a prompt for {selected_lora['title']}"
    lora_repo = selected_lora["repo"]
    # Link to the model page on the Hugging Face Hub (host was mistyped as
    # "huggingface.co." before).
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
    return (
        gr.update(placeholder=new_placeholder),
        updated_text,
        evt.index,
    )

@spaces.GPU(duration=90)
def run_lora(prompt, cfg_scale, steps, selected_index, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
    """Generate one image with the base pipeline plus the selected LoRA.

    The LoRA is loaded for this call only and unloaded afterwards, even if
    generation fails.

    Args:
        prompt: User prompt; the LoRA's trigger word is appended to it.
        cfg_scale: Passed as ``guidance_scale``.
        steps: Number of inference steps.
        selected_index: Index into ``loras``, or None if nothing is selected.
        seed: Seed for the CUDA ``torch.Generator``.
        width: Output width in pixels.
        height: Output height in pixels.
        lora_scale: LoRA strength, forwarded as ``joint_attention_kwargs["scale"]``.
        progress: Gradio progress tracker (tracks tqdm bars); not used directly.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no LoRA has been selected.
    """
    import functools

    if selected_index is None:
        raise gr.Error("You must select a LoRA before proceeding.")

    selected_lora = loras[selected_index]
    lora_path = selected_lora["repo"]
    trigger_word = selected_lora["trigger_word"]

    # The loader override must be installed *before* load_lora_weights, which
    # is what calls load_lora_into_transformer. A plain function stored on the
    # instance is not bound, so the pipeline class is bound as `cls` explicitly.
    # NOTE(review): the value of "custom_alpha" is not used; the patch applies
    # its default lora_alpha (32) as before.
    if "custom_alpha" in selected_lora:
        pipe.load_lora_into_transformer = functools.partial(
            load_lora_into_transformer_patched, type(pipe)
        )

    try:
        # Load LoRA weights
        if "weights" in selected_lora:
            pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
        else:
            pipe.load_lora_weights(lora_path)

        # Set random seed for reproducibility
        generator = torch.Generator(device="cuda").manual_seed(int(seed))

        # Generate image
        image = pipe(
            prompt=f"{prompt} {trigger_word}",
            num_inference_steps=int(steps),
            guidance_scale=cfg_scale,
            width=int(width),
            height=int(height),
            generator=generator,
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]
    finally:
        # Always restore the stock loader and drop the LoRA so a failed
        # request does not leave an adapter attached for the next one.
        pipe.load_lora_into_transformer = original_load_lora
        pipe.unload_lora_weights()

    return image

# Build the UI: prompt + Generate button on top, LoRA gallery beside the
# output image, and generation settings below.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("# FLUX.1 LoRA the Explorer")
    # Index of the LoRA chosen in the gallery; None until the user clicks one.
    selected_index = gr.State(None)
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
        with gr.Column(scale=1):
            generate_button = gr.Button("Generate", variant="primary")
    with gr.Row():
        with gr.Column(scale=1):
            selected_info = gr.Markdown("")
            # One (image, caption) pair per catalogue entry, in loras order so
            # that gallery indices match indices into `loras`.
            gallery = gr.Gallery(
                [(item["image"], item["title"]) for item in loras],
                label="LoRA Gallery",
                allow_preview=False,
                columns=2
            )

        with gr.Column(scale=2):
            result = gr.Image(label="Generated Image")

    # Generation settings.
    with gr.Row():
        with gr.Column():
            with gr.Row():
                cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
                steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=30)

            with gr.Row():
                width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
                height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)

            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=2**32-1, step=1, value=0, randomize=True)
                lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=1)

    # Clicking a gallery item updates the prompt placeholder, the info line and
    # the stored index (the three values returned by update_selection).
    gallery.select(update_selection, outputs=[prompt, selected_info, selected_index])

    # Input order must match run_lora's positional parameters.
    generate_button.click(
        fn=run_lora,
        inputs=[prompt, cfg_scale, steps, selected_index, seed, width, height, lora_scale],
        outputs=[result]
    )

app.queue()
app.launch()