File size: 6,484 Bytes
4120479
 
 
a8a382e
9b729f7
a8a382e
4120479
 
 
1475e41
3eb8dac
4120479
 
 
 
 
 
 
20706a7
4120479
 
 
 
 
 
 
 
 
 
 
 
9b729f7
 
ba35348
 
 
 
 
 
2f40f84
4ed3fa9
2f40f84
4120479
ce04d24
 
2130441
ce04d24
 
4120479
ce04d24
4120479
ce04d24
 
 
 
4120479
 
ad8076c
3eb8dac
4120479
ad8076c
e6d0ada
 
 
3eb8dac
e6d0ada
 
 
9b729f7
4120479
9b729f7
4120479
1d5c6d0
4120479
a5454cf
3eb8dac
1cd10c1
ad8076c
3eb8dac
 
 
4120479
 
 
ce04d24
4120479
 
ea9cf0a
4120479
 
 
 
 
 
 
 
a56c826
 
4120479
a56c826
 
4120479
 
 
 
 
 
 
 
ce04d24
 
 
 
 
41f8c1f
ce04d24
 
 
 
41f8c1f
ce04d24
 
 
 
 
 
 
 
 
b6fa736
ce04d24
 
4120479
 
 
 
 
 
b6fa736
 
a56c826
 
4120479
 
 
a56c826
4120479
d92fc15
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import gradio as gr
from time import sleep
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

import torch
import json
import random
import copy
import gc

lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")

# Images in the catalog are either absolute https:// URLs or paths relative to
# the LoraTheExplorer Space repository; relative ones are resolved against this.
_LORA_IMAGE_BASE_URL = "https://huggingface.co/spaces/multimodalart/LoraTheExplorer/resolve/main/"

with open(lora_list, "r", encoding="utf-8") as file:
    data = json.load(file)

# Normalized catalog entries. Optional keys (is_pivotal, text_embedding_weights,
# is_nc) are filled with defaults when the JSON omits them.
sdxl_loras = [
    {
        "image": item["image"] if item["image"].startswith("https://") else f'{_LORA_IMAGE_BASE_URL}{item["image"]}',
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item["trigger_word"],
        "weights": item["weights"],
        "is_compatible": item["is_compatible"],
        "is_pivotal": item.get("is_pivotal", False),
        "text_embedding_weights": item.get("text_embedding_weights", None),
        "is_nc": item.get("is_nc", False),
    }
    for item in data
]

# Download and deserialize every LoRA once at startup. Each catalog entry gets
# the local file path ("saved_name") and its weights ("state_dict"), which
# merge_and_run later deep-copies per request instead of reloading from disk.
for item in sdxl_loras:
    saved_name = hf_hub_download(item["repo"], item["weights"])

    if saved_name.endswith(".safetensors"):
        state_dict = load_file(saved_name)
    else:
        # NOTE(security): torch.load unpickles the file, which can run arbitrary
        # code. These weights come from third-party Hub repos listed in a remote
        # catalog; prefer .safetensors (or weights_only=True where the installed
        # torch supports it) for untrusted sources.
        # map_location="cpu" matches load_file's default device, so every cached
        # state dict lives in host memory regardless of where it was saved.
        state_dict = torch.load(saved_name, map_location="cpu")

    item["saved_name"] = saved_name
    item["state_dict"] = state_dict

# Custom styles for the gr.Blocks UI. Selectors target the elem_id / elem_classes
# values assigned in the layout (title, selected_random, plus_column, plus_button,
# prompt, run_button, random_column, roulette_group).
# The prompt input is narrowed by 160px and #run_button is absolutely positioned
# with squared-off left corners, so the Run button sits flush against the prompt.
# Below 1024px wide, the roulette row stacks vertically.
css = '''
#title{text-align:center;}
#title h1{font-size: 250%}
.selected_random img{object-fit: cover}
.plus_column{align-self: center}
.plus_button{font-size: 235% !important; text-align: center;margin-bottom: 19px}
#prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
#run_button{position:absolute;margin-top: 12px;right: 0;margin-right: 1.5em;border-bottom-left-radius: 0px;
    border-top-left-radius: 0px;}
.random_column{align-self: center}
@media (max-width: 1024px) {
.roulette_group{flex-direction: column}
}
'''

# Base SDXL pipeline, loaded once in float16. This module never moves it to a
# device or fuses anything into it: merge_and_run deep-copies it per request,
# moves the copy to CUDA, and fuses the LoRAs into that copy only.
original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)

def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, progress=gr.Progress(track_tqdm=True)):
    """Fuse the two selected LoRAs into a fresh copy of SDXL and generate one image.

    Args:
        prompt: Text prompt for generation.
        negative_prompt: Negative prompt; an empty string means "no negative prompt".
        shuffled_items: The two LoRA catalog entries chosen by ``shuffle_images``;
            each carries a ``state_dict`` and ``title`` built at module load.
        lora_1_scale: Strength with which the first LoRA is fused.
        lora_2_scale: Strength with which the second LoRA is fused.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors the
            pipeline's tqdm progress in the UI.

    Yields:
        The first image returned by the pipeline (exactly one yield).
    """
    lora_1, lora_2 = shuffled_items[0], shuffled_items[1]
    # Log identifying fields only: printing the entries themselves would dump
    # every tensor of both LoRA state dicts into the logs.
    print([lora_1["title"], lora_2["title"]])

    # Hand private copies to load_lora_weights so the cached module-level state
    # dicts stay intact for later requests, whatever the loader does to its input.
    state_dict_1 = copy.deepcopy(lora_1["state_dict"])
    state_dict_2 = copy.deepcopy(lora_2["state_dict"])

    # Fuse into a copy so LoRA weights never leak into the shared base pipeline.
    pipe = copy.deepcopy(original_pipe)
    try:
        pipe.to("cuda")

        # lora_scale must be passed by keyword: the first positional parameter of
        # fuse_lora is not the scale, so a positional value would be ignored.
        pipe.load_lora_weights(state_dict_1)
        pipe.fuse_lora(lora_scale=lora_1_scale)
        pipe.load_lora_weights(state_dict_2)
        pipe.fuse_lora(lora_scale=lora_2_scale)

        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt or None,
            num_inference_steps=20,
            width=768,
            height=768,
        ).images[0]
        yield image
    finally:
        # Release the per-request pipeline and GPU memory even if generation fails.
        del pipe, state_dict_1, state_dict_2
        gc.collect()
        torch.cuda.empty_cache()

def get_description(item):
    """Describe how a LoRA is activated.

    Returns a ``(markdown, trigger_word)`` pair. The markdown tells the user
    which trigger word to type, or that none is needed when the entry has an
    empty/missing one. The raw trigger word is passed through unchanged.
    """
    word = item["trigger_word"]
    if not word:
        return "No trigger word, will be applied automatically", word
    return f"Trigger: `{word}`", word
    
def shuffle_images():
    """Pick two distinct compatible LoRAs at random and build the UI updates.

    Returns:
        A 6-tuple in the order of the ``outputs`` wired to this handler:
        image update for LoRA 1, its trigger-word description, image update for
        LoRA 2, its description, the prompt textbox update, and the list of the
        two chosen catalog entries (kept in ``gr.State`` for ``merge_and_run``).

    Raises:
        ValueError: if fewer than two compatible LoRAs exist (from ``random.sample``).
    """
    compatible_items = [item for item in sdxl_loras if item["is_compatible"]]
    # Draw two distinct entries without shuffling the whole catalog.
    two_shuffled_items = random.sample(compatible_items, 2)

    updates = []
    trigger_words = []
    for item in two_shuffled_items:
        description, trigger_word = get_description(item)
        updates.append(gr.update(label=item["title"], value=item["image"]))
        updates.append(gr.update(value=description, visible=True))
        # Skip empty/missing trigger words so the prompt never contains stray
        # spaces or a literal "None".
        if trigger_word:
            trigger_words.append(trigger_word)

    prompt = gr.update(value=" ".join(trigger_words))
    return (*updates, prompt, two_shuffled_items)

with gr.Blocks(css=css) as demo:
    # Per-session store for the two LoRA entries picked by shuffle_images;
    # merge_and_run reads it to know which weights to fuse.
    shuffled_items = gr.State()
    title = gr.HTML(
        '''<h1>LoRA Roulette 🎲</h1>
        ''',
        elem_id="title"
    )
    # Layout: [LoRA 1] + [LoRA 2] = [prompt / Run / output image]
    with gr.Row(elem_classes="roulette_group"):
        with gr.Column(min_width=10, scale=16, elem_classes="plus_column"):
            gr.HTML("<p>These 2 random LoRAs are loaded to SDXL, find a fun way to combine them 🎨</p>")
            with gr.Row():
                with gr.Column(min_width=10, scale=8, elem_classes="random_column"):
                    lora_1 = gr.Image(interactive=False, height=263, elem_classes="selected_random")
                    lora_1_prompt = gr.Markdown(visible=False)
                with gr.Column(min_width=10, scale=1, elem_classes="plus_column"):
                    plus = gr.HTML("+", elem_classes="plus_button")
                with gr.Column(min_width=10, scale=8, elem_classes="random_column"):
                    lora_2 = gr.Image(interactive=False, height=263, elem_classes="selected_random")
                    lora_2_prompt = gr.Markdown(visible=False)
                with gr.Column(min_width=10, scale=1, elem_classes="plus_column"):
                    equal = gr.HTML("=", elem_classes="plus_button")
        with gr.Column(min_width=10, scale=14):
            with gr.Box():
                with gr.Row():
                    prompt = gr.Textbox(label="Your prompt", show_label=False, interactive=True, elem_id="prompt")
                    run_btn = gr.Button("Run", elem_id="run_button")
                output_image = gr.Image(label="Output", height=355)

    with gr.Accordion("Advanced settings", open=False):
        negative_prompt = gr.Textbox(label="Negative prompt")
        with gr.Row():
            lora_1_scale = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
            lora_2_scale = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
    shuffle_button = gr.Button("Reshuffle!")

    # Shared wiring: both shuffle triggers fill the same components, and both
    # generation triggers read the same inputs (order matches the handlers).
    shuffle_outputs = [lora_1, lora_1_prompt, lora_2, lora_2_prompt, prompt, shuffled_items]
    run_inputs = [prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale]

    # Pick a fresh pair on page load and on "Reshuffle!". shuffle_images does no
    # model work, so these bypass the queue and hide the progress indicator.
    demo.load(shuffle_images, inputs=[], outputs=shuffle_outputs, queue=False, show_progress="hidden")
    shuffle_button.click(shuffle_images, outputs=shuffle_outputs, queue=False, show_progress="hidden")

    # The Run button and pressing Enter in the prompt box both generate an image.
    run_btn.click(merge_and_run, inputs=run_inputs, outputs=[output_image])
    prompt.submit(merge_and_run, inputs=run_inputs, outputs=[output_image])

demo.queue()
demo.launch()