File size: 5,352 Bytes
4120479
 
 
a8a382e
 
4120479
 
 
1475e41
4120479
 
 
 
 
 
 
20706a7
4120479
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f40f84
 
 
4120479
 
 
 
 
 
 
 
 
 
84e6fa0
4120479
7efd9a0
2f40f84
 
1d5c6d0
2f40f84
4120479
2f40f84
4120479
1d5c6d0
2f40f84
4120479
 
1d5c6d0
4120479
 
 
 
 
7efd9a0
4120479
 
ea9cf0a
4120479
 
 
 
 
 
 
 
a56c826
 
4120479
a56c826
 
4120479
 
 
 
 
88dfda8
4120479
 
 
 
 
1d5c6d0
a56c826
4120479
 
 
1d5c6d0
a56c826
4120479
 
 
 
 
 
 
 
 
 
 
 
a56c826
 
4120479
 
 
a56c826
4120479
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr
from time import sleep
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download

import torch
import json
import random
import copy

# Download the LoRA manifest that ships with the LoraTheExplorer Space.
lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")

# Manifest entries may give either an absolute https:// image URL or a path
# relative to the Space's file tree; relative paths are resolved against this base.
_LORA_IMAGE_BASE_URL = "https://huggingface.co/spaces/multimodalart/LoraTheExplorer/resolve/main/"

with open(lora_list, "r", encoding="utf-8") as file:
    data = json.load(file)

# Normalize each manifest entry into the dict shape the rest of the app reads;
# optional keys fall back to the defaults below when absent.
sdxl_loras = [
    {
        "image": item["image"] if item["image"].startswith("https://") else f'{_LORA_IMAGE_BASE_URL}{item["image"]}',
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item["trigger_word"],
        "weights": item["weights"],
        "is_compatible": item["is_compatible"],
        "is_pivotal": item.get("is_pivotal", False),
        "text_embedding_weights": item.get("text_embedding_weights", None),
        "is_nc": item.get("is_nc", False),
    }
    for item in data
]

# Download every LoRA's weight file up front (at import time) so generation
# requests only need the local path stored under "saved_name".
saved_names = [
    hf_hub_download(item["repo"], item["weights"]) for item in sdxl_loras
]

for item, saved_name in zip(sdxl_loras, saved_names):
    item["saved_name"] = saved_name

# Custom CSS for the Blocks UI; each #id selector targets a component created
# with the matching elem_id (title, plus_column, plus_button, prompt, run_button).
# #prompt and #run_button are styled so the Run button sits at the right edge
# of the prompt textbox.
css = '''
#title{text-align:center}
#plus_column{align-self: center}
#plus_button{font-size: 250%; text-align: center}
.gradio-container{width: 700px !important; margin: 0 auto !important}
#prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
#run_button{position:absolute;margin-top: 57px;right: 0;margin-right: 0.8em;border-bottom-left-radius: 0px;
    border-top-left-radius: 0px;}
'''

#@spaces.GPU
def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, progress=gr.Progress(track_tqdm=True)):
    """Fuse the two selected LoRAs into a fresh SDXL base pipeline and render one image.

    A new pipeline is loaded on every call. fuse_lora bakes the LoRA deltas into
    the model weights, so reusing one pipeline would stack LoRAs across requests.

    Args:
        prompt: Text prompt from the UI textbox.
        negative_prompt: Negative prompt text; an empty string means "none".
        shuffled_items: The two LoRA dicts placed in gr.State by shuffle_images();
            each must carry a "saved_name" path to its downloaded weight file.
        lora_1_scale: Strength at which the first LoRA is fused.
        lora_2_scale: Strength at which the second LoRA is fused.
        progress: Gradio progress tracker; track_tqdm mirrors the pipeline's
            tqdm progress bars in the UI.

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: If two LoRAs have not been selected yet.
    """
    if not shuffled_items or len(shuffled_items) < 2:
        raise gr.Error("Two LoRAs must be selected before running; try reshuffling.")

    # torch_dtype here already casts every component to fp16.
    pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
    pipe.to("cuda")

    print("Loading LoRAs")
    # fuse_lora's first positional parameter is NOT the scale (it selects which
    # components to fuse: `fuse_unet` in older diffusers, `components` in newer),
    # so the scale must be passed by keyword or the slider value is ignored.
    pipe.load_lora_weights(shuffled_items[0]["saved_name"])
    pipe.fuse_lora(lora_scale=lora_1_scale)
    pipe.load_lora_weights(shuffled_items[1]["saved_name"])
    pipe.fuse_lora(lora_scale=lora_2_scale)

    # An empty textbox means "no negative prompt"; the pipeline's documented
    # value for that is None (it expects str / list[str] otherwise).
    if not negative_prompt:
        negative_prompt = None

    print("Running inference")
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=25, guidance_scale=7).images[0]
    return image

def get_description(item):
    """Return a ``(markdown_description, trigger_word)`` pair for a LoRA entry.

    When the entry has a (truthy) trigger word, the description names it in
    inline code; otherwise it says the LoRA needs no trigger word. The raw
    trigger word is returned unchanged as the second element.
    """
    word = item["trigger_word"]
    if not word:
        return "No trigger word, will be applied automatically", word
    return f"LoRA trigger word: `{word}`", word
    
def shuffle_images():
    """Pick two distinct compatible LoRAs at random and build the matching UI updates.

    Returns:
        A 6-tuple in the order of the outputs wired to this callback:
        (image update for LoRA 1, description update for LoRA 1,
         image update for LoRA 2, description update for LoRA 2,
         prompt textbox update, list of the two selected LoRA dicts for gr.State).

    Raises:
        gr.Error: If fewer than two compatible LoRAs are available.
    """
    compatible_items = [item for item in sdxl_loras if item["is_compatible"]]
    if len(compatible_items) < 2:
        raise gr.Error("At least two compatible LoRAs are needed to play LoRA Roulette.")

    # Draw two distinct entries without shuffling the whole list.
    two_shuffled_items = random.sample(compatible_items, 2)
    first, second = two_shuffled_items

    title_1 = gr.update(label=first["title"], value=first["image"])
    title_2 = gr.update(label=second["title"], value=second["image"])

    description_1, trigger_word_1 = get_description(first)
    description_2, trigger_word_2 = get_description(second)

    prompt_description_1 = gr.update(value=description_1, visible=True)
    prompt_description_2 = gr.update(value=description_2, visible=True)
    # Join only the non-empty trigger words so the prompt has no stray
    # leading/trailing spaces (or a literal "None") when a LoRA has none.
    prompt_text = " ".join(word for word in (trigger_word_1, trigger_word_2) if word)
    prompt = gr.update(value=prompt_text)

    return title_1, prompt_description_1, title_2, prompt_description_2, prompt, two_shuffled_items

# Build the LoRA Roulette UI: two LoRA preview cards, a prompt box with a Run
# button, the generated image, advanced settings, and a reshuffle button.
with gr.Blocks(css=css) as demo:
    # Holds the two LoRA dicts chosen by shuffle_images() for merge_and_run().
    shuffled_items = gr.State()
    title = gr.HTML(
        '''<h1>LoRA Roulette 🎲</h1>
        <h4>These 2 LoRAs are loaded into SDXL at random, find a fun way to combine them 🎨</h4>
        ''',
        elem_id="title"
    )
    with gr.Row():
        with gr.Column(min_width=10, scale=6):
            lora_1 = gr.Image(interactive=False, height=300)
            lora_1_prompt = gr.Markdown(visible=False)
        with gr.Column(min_width=10, scale=1, elem_id="plus_column"):
            plus = gr.HTML("+", elem_id="plus_button")
        with gr.Column(min_width=10, scale=6):
            lora_2 = gr.Image(interactive=False, height=300)
            lora_2_prompt = gr.Markdown(visible=False)
    with gr.Row():
        prompt = gr.Textbox(label="Your prompt", info="arrange the trigger words of the two LoRAs in a coherent sentence", interactive=True, elem_id="prompt")
        run_btn = gr.Button("Run", elem_id="run_button")

    output_image = gr.Image()
    with gr.Accordion("Advanced settings", open=False):
        negative_prompt = gr.Textbox(label="Negative prompt")
        with gr.Row():
            lora_1_scale = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
            lora_2_scale = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
    shuffle_button = gr.Button("Reshuffle LoRAs!")

    # Order must match the tuple returned by shuffle_images().
    shuffle_outputs = [lora_1, lora_1_prompt, lora_2, lora_2_prompt, prompt, shuffled_items]
    # Order must match merge_and_run's positional parameters.
    run_inputs = [prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale]

    # Draw an initial pair when the page loads, and a new pair on demand.
    demo.load(shuffle_images, inputs=[], outputs=shuffle_outputs, queue=False, show_progress="hidden")
    shuffle_button.click(shuffle_images, outputs=shuffle_outputs, queue=False, show_progress="hidden")

    # Both the Run button and pressing Enter in the prompt box start generation.
    run_btn.click(merge_and_run, inputs=run_inputs, outputs=[output_image])
    prompt.submit(merge_and_run, inputs=run_inputs, outputs=[output_image])

demo.queue()
demo.launch()