File size: 4,363 Bytes
0e0ee20
 
 
 
 
 
7dc34c1
7039ded
0e0ee20
 
 
 
 
 
 
 
 
 
f3e96f9
c59400c
0e0ee20
 
 
 
 
 
 
 
 
 
 
d6802e8
f3e96f9
0e0ee20
 
 
 
 
 
 
 
701c4ad
 
 
 
c59400c
0e0ee20
f3e96f9
 
0e0ee20
 
 
 
 
 
 
 
 
 
f3e96f9
0e0ee20
 
fd8e800
 
0e0ee20
fd8e800
0e0ee20
1441e58
07d3eff
504da62
 
 
07d3eff
504da62
 
 
 
 
0e0ee20
d6802e8
 
 
02302e4
07d3eff
0e0ee20
db98dea
1fff27d
0e0ee20
 
 
 
db98dea
0e0ee20
1fff27d
db98dea
8dce9c7
0e0ee20
 
2c6d128
 
 
 
 
 
 
 
 
 
 
 
 
 
0e0ee20
 
07d3eff
 
 
0e0ee20
f3e96f9
0e0ee20
 
 
 
d6802e8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import copy
import json
import logging
import random

import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline
from PIL import Image

# Load the LoRA catalogue from the JSON file in the working directory.
# Each entry is read elsewhere in this file through the keys "repo", "title",
# "image", "trigger_word" and, optionally, "weights".
# The encoding is pinned to UTF-8 so decoding does not depend on the
# platform's default locale.
with open('loras.json', 'r', encoding='utf-8') as f:
    loras = json.load(f)

# Upper bound for seeds, both for the UI slider and for random draws
# (the largest unsigned 32-bit value).
MAX_SEED = (1 << 32) - 1

# Build the shared base pipeline once at import time, in bfloat16,
# then move it onto the GPU.
base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(
    base_model,
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

def update_selection(evt: gr.SelectData):
    """Handle a click on a gallery tile.

    Looks up the clicked LoRA in ``loras`` by ``evt.index`` and returns the
    three values wired to the ``gallery.select`` outputs, in order:

    1. an update that sets the prompt box placeholder to mention the LoRA's
       title,
    2. a markdown heading linking to the LoRA's repository on the Hub,
    3. the selected index itself, to be stored in state.
    """
    selected_lora = loras[evt.index]
    lora_repo = selected_lora["repo"]
    new_placeholder = f"Type a prompt for {selected_lora['title']}"
    # The host is written without a trailing dot so the link is the
    # canonical Hub URL for the repository.
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"
    return (
        gr.update(placeholder=new_placeholder),
        updated_text,
        evt.index,
    )

@spaces.GPU(duration=90)
def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
    """Generate one image with the selected LoRA applied to the shared pipeline.

    This is a generator function: it yields the generated image once, and the
    LoRA weights are unloaded from ``pipe`` afterwards.

    Args:
        prompt: User prompt; the LoRA's trigger word is appended to it.
        cfg_scale: Passed to the pipeline as ``guidance_scale``.
        steps: Passed to the pipeline as ``num_inference_steps``.
        selected_index: Index into ``loras`` of the chosen LoRA, or ``None``
            when nothing has been selected yet.
        randomize_seed: When truthy, ``seed`` is replaced by a random value
            in ``[0, MAX_SEED]``.
        seed: Seed for the CUDA generator (ignored when ``randomize_seed``).
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        lora_scale: Passed as ``joint_attention_kwargs={"scale": ...}``.
        progress: Not referenced in the body.
            NOTE(review): presumably declared so Gradio mirrors tqdm
            progress in the UI -- confirm.

    Yields:
        The first image returned by the pipeline.

    Raises:
        gr.Error: If ``selected_index`` is ``None``.
    """
    if selected_index is None:
        raise gr.Error("You must select a LoRA before proceeding.")

    selected_lora = loras[selected_index]
    lora_path = selected_lora["repo"]
    trigger_word = selected_lora["trigger_word"]

    # Resolve the seed before touching the pipeline, so a failure here
    # cannot leave LoRA weights attached to the shared model.
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # int() guards against the UI delivering the slider value as a float.
    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    # Load LoRA weights; a specific weight file is requested only when the
    # catalogue entry names one.
    load_kwargs = {}
    if "weights" in selected_lora:
        load_kwargs["weight_name"] = selected_lora["weights"]
    pipe.load_lora_weights(lora_path, **load_kwargs)

    try:
        # Generate image
        image = pipe(
            prompt=f"{prompt} {trigger_word}",
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]

        yield image
    finally:
        # `pipe` is shared by every request: always detach the LoRA, even if
        # generation raised or the generator was closed early, so the next
        # call starts from the plain base model.
        pipe.unload_lora_weights()
    

# Custom stylesheet handed to gr.Blocks: stretches the generate button to the
# full height of its container and centres/sizes the title heading and logo.
css = (
    "\n"
    "#gen_btn{height: 100%}\n"
    "#title{text-align: center;}\n"
    "#title h1{font-size: 3em; display:inline-flex; align-items:center}\n"
    "#title img{width: 100px; margin-right: 0.5em}\n"
)
# Build the UI: title, prompt row, LoRA gallery next to the result image,
# an "Advanced Settings" accordion, and the event wiring.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as app:
    # Page heading; styled through the "#title" rules in `css`.
    title = gr.HTML(
        """<h1><img src="https://i.imgur.com/vT48NAO.png" alt="LoRA"> FLUX LoRA the Explorer</h1>""",
        elem_id="title",
    )
    # Index of the chosen LoRA in `loras`; stays None until a gallery item
    # is selected (run_lora rejects None).
    selected_index = gr.State(None)
    # Top row: prompt box and the generate button.
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
        with gr.Column(scale=1, elem_id="gen_column"):
            generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
    # Middle row: LoRA picker on the left, generated image on the right.
    with gr.Row():
        with gr.Column(scale=3):
            # Filled by update_selection with a markdown link to the selected repo.
            selected_info = gr.Markdown("")
            # One tile per catalogue entry, as (image, title) pairs.
            gallery = gr.Gallery(
                [(item["image"], item["title"]) for item in loras],
                label="LoRA Gallery",
                allow_preview=False,
                columns=3
            )
            
        with gr.Column(scale=4):
            result = gr.Image(label="Generated Image")

    # Generation parameters, collapsed by default.
    with gr.Row():
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Column():
                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=30)
                
                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
                
                with gr.Row():
                    # When checked, run_lora ignores the seed slider and draws a random seed.
                    randomize_seed = gr.Checkbox(True, label="Randomize seed")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
                    lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=1, step=0.01, value=0.85)

    # Selecting a gallery tile updates the prompt placeholder, the info
    # markdown and the stored index (the three values update_selection returns).
    gallery.select(update_selection, outputs=[prompt, selected_info, selected_index])

    # Run run_lora when the button is clicked or the prompt is submitted;
    # the inputs list follows run_lora's positional parameter order.
    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale],
        outputs=[result]
    )

# Enable the request queue, then start serving the app.
app.queue().launch()