File size: 10,751 Bytes
4120479
 
 
a8a382e
9b729f7
ca25d4d
af5ea8a
c6550c9
4120479
 
 
1475e41
3eb8dac
4120479
 
 
af5ea8a
 
 
 
 
 
 
 
 
 
 
 
 
 
4120479
 
 
 
20706a7
4120479
 
 
 
 
 
 
 
 
 
 
 
9b729f7
 
ba35348
 
 
 
 
 
2f40f84
008aa9b
2f40f84
4120479
edf408e
ce04d24
 
2130441
b819231
ce04d24
 
4120479
b48fe41
4120479
374e672
bc0b5e4
ca25d4d
 
bc0b5e4
ca25d4d
 
 
 
92c7c82
4120479
 
ad8076c
3eb8dac
d9a8a67
008aa9b
 
3eb8dac
0fac10f
e6d0ada
9bd2ea4
4120479
9bd2ea4
4120479
1d5c6d0
4120479
a5454cf
af5ea8a
 
 
 
 
3eb8dac
 
 
af5ea8a
4120479
 
 
145660e
4120479
 
ea9cf0a
4120479
 
0052f82
4120479
0052f82
 
4120479
 
 
a56c826
 
4120479
06bf887
a56c826
0052f82
4120479
af5ea8a
 
 
 
 
 
 
 
 
4120479
 
 
 
b819231
4120479
 
 
0d23c84
 
ce04d24
d8635e5
92c7c82
d1f8613
ce04d24
 
 
d8635e5
92c7c82
d1f8613
ce04d24
edf408e
ce04d24
0d23c84
ce04d24
 
3b258c8
ce04d24
3e21e1d
92c7c82
b48fe41
bc0b5e4
b48fe41
bc0b5e4
b48fe41
ca25d4d
 
 
 
af5ea8a
4120479
 
af5ea8a
 
4120479
 
 
b6fa736
 
0052f82
 
4120479
af5ea8a
 
 
 
ca25d4d
4120479
d92fc15
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import copy
import gc
import json
import random
from pathlib import Path
from time import sleep
from uuid import uuid4

import gradio as gr
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import CommitScheduler, InferenceClient, hf_hub_download
from safetensors.torch import load_file

from share_btn import community_icon_html, loading_icon_html, share_js

# JSON catalogue of SDXL LoRAs (title, repo, weights file, trigger word, ...)
# published alongside the LoraTheExplorer Space.
lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")

# Ratings from this process go into a fresh uuid4-named folder, which is also
# used as the folder name inside the dataset repo (path_in_repo below), so this
# run's files stay separate from previously uploaded ones.
IMAGE_DATASET_DIR = Path("image_dataset") / f"train-{uuid4()}"
IMAGE_DATASET_DIR.mkdir(parents=True, exist_ok=True)
# One JSON object per line; save_preferences appends a line per rating.
IMAGE_JSONL_PATH = IMAGE_DATASET_DIR / "metadata.jsonl"

# Background uploader that periodically commits IMAGE_DATASET_DIR to the
# dataset repo. CommitScheduler's `every` is expressed in minutes.
scheduler = CommitScheduler(
    repo_id="multimodalart/lora-fusing-preferences",
    repo_type="dataset",
    folder_path=IMAGE_DATASET_DIR,
    path_in_repo=IMAGE_DATASET_DIR.name,
    every=10,
)

# NOTE(review): `client` is not referenced anywhere else in this file; kept
# in case other code relies on it.
client = InferenceClient()

# Preview images in the catalogue are either absolute https URLs or paths
# relative to the LoraTheExplorer Space repository; relative ones are
# resolved against this base URL.
_LORA_IMAGE_BASE_URL = "https://huggingface.co/spaces/multimodalart/LoraTheExplorer/resolve/main/"

with open(lora_list, "r", encoding="utf-8") as file:
    data = json.load(file)

# Normalised LoRA records used by the rest of the app. Required keys raise
# KeyError if an entry lacks them; optional keys fall back to defaults.
sdxl_loras = [
    {
        "image": item["image"] if item["image"].startswith("https://") else f'{_LORA_IMAGE_BASE_URL}{item["image"]}',
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item["trigger_word"],
        "weights": item["weights"],
        "is_compatible": item["is_compatible"],
        "is_pivotal": item.get("is_pivotal", False),
        "text_embedding_weights": item.get("text_embedding_weights", None),
        "is_nc": item.get("is_nc", False),
    }
    for item in data
]

# Download every LoRA's weights once at start-up and cache the state dict on
# its catalogue record; merge_and_run deep-copies these for each request.
for item in sdxl_loras:
    saved_name = hf_hub_download(item["repo"], item["weights"])

    if saved_name.endswith('.safetensors'):
        state_dict = load_file(saved_name)
    else:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code, and these weights come from third-party Hub repos.
        # Consider weights_only=True (torch>=1.13) or requiring .safetensors.
        # map_location="cpu" keeps the tensors on the CPU, matching
        # load_file's default, instead of restoring them to whatever device
        # they were saved from.
        state_dict = torch.load(saved_name, map_location="cpu")

    item["saved_name"] = saved_name
    item["state_dict"] = state_dict

# Custom CSS for the Gradio UI; the selectors target the elem_id / elem_classes
# values assigned to components in the gr.Blocks layout further down.
css = '''
.gradio-container{max-width: 650px! important}
#title{text-align:center;}
#title h1{font-size: 250%}
.selected_random img{object-fit: cover}
.selected_random [data-testid="block-label"] span{display: none}
.plus_column{align-self: center}
.plus_button{font-size: 235% !important; text-align: center;margin-bottom: 19px}
#prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
#run_button{position:absolute;margin-top: 36px;right: 0;margin-right: 1.5em;border-bottom-left-radius: 0px;
    border-top-left-radius: 0px;}
.random_column{align-self: center; align-items: center}
#share-btn-container{padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; max-width: 13rem; margin-left: auto;margin-top: 0.35em;}
div#share-btn-container > div {flex-direction: row;background: black;align-items: center}
#share-btn-container:hover {background-color: #060606}
#share-btn {all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important;right:0;font-size: 15px;}
#share-btn * {all: unset}
#share-btn-container div:nth-child(-n+2){width: auto !important;min-height: 0px !important;}
#share-btn-container .wrap {display: none !important}
#share-btn-container.hidden {display: none!important}
#post_gen_info{margin-top: .5em}
'''

# Base SDXL pipeline, loaded once in float16 at start-up. merge_and_run never
# fuses LoRAs into it directly; it deep-copies it for every request.
original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)

def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, seed=-1, progress=gr.Progress(track_tqdm=True)):
    """Fuse two LoRAs into a private copy of SDXL and generate one image.

    Args:
        prompt: Text prompt for the generation.
        negative_prompt: Negative prompt; an empty string means "none".
        shuffled_items: LoRA records picked by ``shuffle_images``; the
            ``state_dict`` of the first two entries is used.
        lora_1_scale: Fusion strength for the first LoRA.
        lora_2_scale: Fusion strength for the second LoRA.
        seed: RNG seed; any negative value selects a random seed.
        progress: Gradio progress tracker (mirrors the pipeline's tqdm bars).

    Returns:
        ``(image, gr.update(visible=True), seed)``: the generated image, an
        update revealing the post-generation controls, and the seed used.
    """
    # The cached state dicts are shared by all requests, so hand the loader
    # copies (presumably because diffusers' loader may modify its input).
    state_dict_1 = copy.deepcopy(shuffled_items[0]["state_dict"])
    state_dict_2 = copy.deepcopy(shuffled_items[1]["state_dict"])

    if negative_prompt == "":
        negative_prompt = None

    # UI sliders may deliver floats; torch.Generator.manual_seed needs an int.
    seed = int(seed)
    if seed < 0:
        seed = random.randint(0, 2147483647)

    # Fusing writes the LoRA deltas into the model weights, so work on a
    # throw-away copy of the base pipeline.
    pipe = copy.deepcopy(original_pipe)
    try:
        pipe.to("cuda")
        # lora_scale must be passed by keyword: the first positional
        # parameter of fuse_lora is not the scale, so a positional value
        # would silently fuse at the default strength.
        pipe.load_lora_weights(state_dict_1)
        pipe.fuse_lora(lora_scale=lora_1_scale)
        pipe.load_lora_weights(state_dict_2)
        pipe.fuse_lora(lora_scale=lora_2_scale)

        generator = torch.Generator(device="cuda").manual_seed(seed)
        image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=20, width=768, height=768, generator=generator).images[0]
    finally:
        # Release the per-request pipeline and its GPU memory even on failure.
        del pipe
        gc.collect()
        torch.cuda.empty_cache()
    return image, gr.update(visible=True), seed

def get_description(item):
    """Return ``(caption, trigger_word)`` for a LoRA record.

    The caption is Markdown telling the user which trigger word to use, or
    that the LoRA needs none when the record's trigger word is falsy.
    """
    word = item["trigger_word"]
    if not word:
        return "No trigger, applied automatically", word
    return f"Trigger: `{word}`", word
    
def shuffle_images():
    """Pick two distinct random compatible LoRAs and build the UI updates for them.

    Returns, in order: the image/label update, trigger-caption update and
    repo-id update for LoRA 1; the same three for LoRA 2; the pre-filled
    prompt update; the two selected records (stored in ``gr.State``); and two
    identical scale-slider updates.

    Raises:
        gr.Error: if fewer than two LoRAs in ``sdxl_loras`` are compatible.
    """
    compatible_items = [item for item in sdxl_loras if item["is_compatible"]]
    if len(compatible_items) < 2:
        raise gr.Error("LoRA Roulette needs at least two compatible LoRAs in the catalogue.")
    # random.sample draws two distinct records without touching the pool.
    two_shuffled_items = random.sample(compatible_items, 2)
    item_1, item_2 = two_shuffled_items

    title_1 = gr.update(label=item_1["title"], value=item_1["image"])
    title_2 = gr.update(label=item_2["title"], value=item_2["image"])
    repo_id_1 = gr.update(value=item_1["repo"])
    repo_id_2 = gr.update(value=item_2["repo"])
    description_1, trigger_word_1 = get_description(item_1)
    description_2, trigger_word_2 = get_description(item_2)

    prompt_description_1 = gr.update(value=description_1, visible=True)
    prompt_description_2 = gr.update(value=description_2, visible=True)
    # Only non-empty trigger words go into the prompt, so a LoRA without one
    # does not leave stray spaces (or the text "None") behind.
    prompt = gr.update(value=" ".join(word for word in (trigger_word_1, trigger_word_2) if word))
    # Both scale sliders are reset to the same default on every shuffle.
    scale = gr.update(value=0.7)

    return title_1, prompt_description_1, repo_id_1, title_2, prompt_description_2, repo_id_2, prompt, two_shuffled_items, scale, scale

def save_preferences(lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, generated_image, thumbs_direction, seed):
    """Persist one rating: the generated image as PNG plus a JSONL metadata row.

    Args:
        lora_1_id: Repo id of the first LoRA.
        lora_1_scale: Fusion strength of the first LoRA.
        lora_2_id: Repo id of the second LoRA.
        lora_2_scale: Fusion strength of the second LoRA.
        prompt: Prompt text to record.
        generated_image: Object with a PIL-style ``save(path)`` method
            (e.g. a PIL image).
        thumbs_direction: Rating label, e.g. ``"up"`` or ``"down"``.
        seed: Seed to record alongside the rating.

    Returns:
        A ``gr.update`` revealing the "thanks for evaluating" message.
    """
    image_path = IMAGE_DATASET_DIR / f"{uuid4()}.png"
    record = {
        "prompt": prompt,
        "file_name": image_path.name,
        "lora_1_id": lora_1_id,
        "lora_1_scale": lora_1_scale,
        "lora_2_id": lora_2_id,
        "lora_2_scale": lora_2_scale,
        "thumbs_direction": thumbs_direction,
        "seed": seed,
    }
    # Hold the scheduler's lock so these writes don't interleave with a
    # background commit of the same folder.
    with scheduler.lock:
        generated_image.save(image_path)
        with IMAGE_JSONL_PATH.open("a", encoding="utf-8") as f:
            json.dump(record, f)
            f.write("\n")
    return gr.update(visible=True)

# UI layout and event wiring: two random LoRA cards, a prompt box, the output
# image, thumbs-up/down feedback, and advanced settings.
with gr.Blocks(css=css) as demo:
    # Holds the two LoRA records chosen by shuffle_images for merge_and_run.
    shuffled_items = gr.State()
    title = gr.HTML(
        '''<h1>LoRA Roulette 🎲</h1>
        <p>This random LoRAs are loaded into SDXL, can you find a fun way to combine them? 🎨</p>
        ''',
        elem_id="title"
    )
    with gr.Column():
        with gr.Column(min_width=10, scale=16, elem_classes="plus_column"):
            with gr.Row():
                with gr.Column(min_width=10, scale=3, elem_classes="random_column"):
                    lora_1 = gr.Image(interactive=False, height=150, elem_classes="selected_random", elem_id="randomLoRA_1", show_share_button=False, show_download_button=False)
                    lora_1_id = gr.Textbox(visible=False, elem_id="random_lora_1_id")
                    lora_1_prompt = gr.Markdown(visible=False)
                with gr.Column(min_width=10, scale=1, elem_classes="plus_column"):
                    plus = gr.HTML("+", elem_classes="plus_button")
                with gr.Column(min_width=10, scale=3, elem_classes="random_column"):
                    lora_2 = gr.Image(interactive=False, height=150, elem_classes="selected_random", elem_id="randomLoRA_2", show_share_button=False, show_download_button=False)
                    lora_2_id = gr.Textbox(visible=False, elem_id="random_lora_2_id")
                    lora_2_prompt = gr.Markdown(visible=False)
                with gr.Column(min_width=10, scale=2, elem_classes="plus_column"):
                    equal = gr.HTML("=", elem_classes="plus_button")
        with gr.Column(min_width=10, scale=14):
            with gr.Box():
                with gr.Row():
                    prompt = gr.Textbox(label="Your prompt", info="Rearrange the trigger words into a coherent prompt", show_label=False, interactive=True, elem_id="prompt")
                    run_btn = gr.Button("Run", elem_id="run_button")
                # type="pil": this component is also an input to
                # save_preferences, which calls .save() on the image; the
                # default numpy type would hand it an ndarray (no .save()).
                output_image = gr.Image(label="Output", height=355, elem_id="output_image", type="pil")
                with gr.Row(visible=False, elem_id="post_gen_info") as post_gen_info:
                    with gr.Column(min_width=10):
                        thumbs_up = gr.Button("👍")
                    with gr.Column(min_width=10):
                        thumbs_down = gr.Button("👎")
                    with gr.Column(min_width=10):
                        with gr.Group(elem_id="share-btn-container") as share_group:
                            community_icon = gr.HTML(community_icon_html)
                            loading_icon = gr.HTML(loading_icon_html)
                            share_button = gr.Button("Share to community", elem_id="share-btn")
                post_eval = gr.Markdown("Thanks for evaluating. The dataset with evaluations is [here](#)", visible=False)
    with gr.Accordion("Advanced settings", open=False):
        negative_prompt = gr.Textbox(label="Negative prompt")
        seed = gr.Slider(label="Seed", info="-1 denotes a random seed", minimum=-1, maximum=2147483647, value=-1)
        last_used_seed = gr.Slider(label="Last used seed", info="The seed used in the last generation", minimum=0, maximum=2147483647, value=-1, interactive=False)
        with gr.Row():
            lora_1_scale = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
            lora_2_scale = gr.Slider(label="LoRa 2 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
    shuffle_button = gr.Button("Reshuffle!")

    # Order must match the tuple returned by shuffle_images.
    shuffle_outputs = [lora_1, lora_1_prompt, lora_1_id, lora_2, lora_2_prompt, lora_2_id, prompt, shuffled_items, lora_1_scale, lora_2_scale]
    demo.load(shuffle_images, inputs=[], outputs=shuffle_outputs, queue=False, show_progress="hidden")
    shuffle_button.click(shuffle_images, outputs=shuffle_outputs, queue=False, show_progress="hidden")

    # The seed slider is passed through so a user-chosen seed is honoured
    # instead of merge_and_run always falling back to its default of -1.
    run_inputs = [prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale, seed]
    run_outputs = [output_image, post_gen_info, last_used_seed]
    run_btn.click(merge_and_run, inputs=run_inputs, outputs=run_outputs)
    prompt.submit(merge_and_run, inputs=run_inputs, outputs=run_outputs)

    # Record last_used_seed (the seed merge_and_run actually used) rather than
    # the seed slider, which is -1 whenever a random seed was requested.
    thumbs_up.click(save_preferences, inputs=[lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, output_image, gr.State("up"), last_used_seed], outputs=[post_eval])
    thumbs_down.click(save_preferences, inputs=[lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, output_image, gr.State("down"), last_used_seed], outputs=[post_eval])
    share_button.click(None, [], [], _js=share_js)
demo.queue()
demo.launch()