Spaces:
Runtime error
Runtime error
File size: 6,549 Bytes
4120479 a8a382e 9b729f7 a8a382e 4120479 1475e41 3eb8dac 4120479 20706a7 4120479 9b729f7 ba35348 2f40f84 008aa9b 2f40f84 4120479 374e672 ce04d24 374e672 ce04d24 2130441 ce04d24 4120479 ce04d24 4120479 374e672 ce04d24 4120479 ad8076c 3eb8dac 4120479 008aa9b 3eb8dac 0fac10f e6d0ada 9bd2ea4 4120479 9bd2ea4 4120479 1d5c6d0 4120479 a5454cf 3eb8dac 1cd10c1 3eb8dac 0fac10f 4120479 ce04d24 4120479 ea9cf0a 4120479 a56c826 4120479 06bf887 a56c826 06bf887 4120479 0d23c84 ce04d24 374e672 ce04d24 374e672 ce04d24 0d23c84 ce04d24 b6fa736 ce04d24 4120479 b6fa736 06bf887 4120479 a56c826 4120479 d92fc15 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import gradio as gr
from time import sleep
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import torch
import json
import random
import copy
import gc
def _load_state_dict(path):
    """Load a LoRA checkpoint from *path* into a CPU-resident state dict.

    ``.safetensors`` files are read with ``safetensors.torch.load_file``; any other
    extension is treated as a ``torch.save`` checkpoint.
    """
    if path.endswith(".safetensors"):
        return load_file(path)
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary code.
    # These files come from third-party Hub repos listed in sdxl_loras.json; consider
    # weights_only=True (torch>=1.13) or converting the checkpoints to .safetensors.
    # map_location="cpu" matches load_file's default device, so every cached state dict
    # lives on the same device regardless of where the checkpoint was saved.
    return torch.load(path, map_location="cpu")


# Catalogue of SDXL LoRAs shared with the LoraTheExplorer Space.
lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")
with open(lora_list, "r", encoding="utf-8") as file:
    data = json.load(file)

sdxl_loras = [
    {
        # Relative image paths are resolved against the LoraTheExplorer Space repo.
        "image": item["image"] if item["image"].startswith("https://") else f'https://huggingface.co/spaces/multimodalart/LoraTheExplorer/resolve/main/{item["image"]}',
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item["trigger_word"],
        "weights": item["weights"],
        "is_compatible": item["is_compatible"],
        "is_pivotal": item.get("is_pivotal", False),
        "text_embedding_weights": item.get("text_embedding_weights", None),
        "is_nc": item.get("is_nc", False),
    }
    for item in data
]

# Download each selectable LoRA once at startup and keep its weights in memory, so a
# request only copies the two state dicts it needs instead of re-reading them from disk.
for item in sdxl_loras:
    # Only compatible entries can be picked by the roulette, so skip downloading and
    # holding weights for the others (saves startup time and RAM).
    if not item["is_compatible"]:
        continue
    saved_name = hf_hub_download(item["repo"], item["weights"])
    item["saved_name"] = saved_name
    item["state_dict"] = _load_state_dict(saved_name)
# Custom stylesheet passed to gr.Blocks(css=...). Selectors target the elem_id /
# elem_classes values set on components in the UI definition (#title, #prompt,
# #run_button, .selected_random, .plus_column, .plus_button, .random_column).
# NOTE(review): no component in this file sets elem_classes="roulette_group", so the
# media-query rule below appears unused — confirm before relying on it.
css = '''
.gradio-container{max-width: 650px}
#title{text-align:center;}
#title h1{font-size: 250%}
.selected_random img{object-fit: cover}
.plus_column{align-self: center}
.plus_button{font-size: 235% !important; text-align: center;margin-bottom: 19px}
#prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
#run_button{position:absolute;margin-top: 12px;right: 0;margin-right: 1.5em;border-bottom-left-radius: 0px;
border-top-left-radius: 0px;}
.random_column{align-self: center; align-items: center}
@media (max-width: 1024px) {
.roulette_group{flex-direction: column}
}
'''
# Base SDXL pipeline in half precision, loaded once at import time. It is not moved to a
# device or modified here; merge_and_run works on a deep copy of it per request.
original_pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
)
def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, progress=gr.Progress(track_tqdm=True)):
    """Fuse the two selected LoRAs into a fresh copy of the base pipeline and generate one image.

    Args:
        prompt: Text prompt for generation.
        negative_prompt: Optional negative prompt; an empty value is passed as ``None``.
        shuffled_items: The two LoRA entries chosen by ``shuffle_images``; each is a dict
            with a ``"state_dict"`` key holding the cached LoRA weights.
        lora_1_scale: Fusion strength for the first LoRA.
        lora_2_scale: Fusion strength for the second LoRA.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors the pipeline's tqdm
            progress bar in the UI.

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: If two LoRAs have not been selected yet.
    """
    # gr.State() starts as None until shuffle_images has populated it.
    if not shuffled_items or len(shuffled_items) < 2:
        raise gr.Error("No LoRAs selected yet, please click 'Reshuffle!' and try again.")

    # Work on copies of the cached state dicts (presumably because load_lora_weights
    # may mutate the dict it receives) so the catalogue stays reusable across requests.
    state_dict_1 = copy.deepcopy(shuffled_items[0]["state_dict"])
    state_dict_2 = copy.deepcopy(shuffled_items[1]["state_dict"])
    if not negative_prompt:
        negative_prompt = None

    # fuse_lora writes the LoRA deltas into the model weights, so fuse into a throwaway
    # copy instead of the shared base pipeline.
    pipe = copy.deepcopy(original_pipe)
    try:
        pipe.to("cuda")
        pipe.load_lora_weights(state_dict_1)
        # Pass the scale by keyword: the first positional parameter of fuse_lora is
        # fuse_unet, so a positional float would leave lora_scale at its default of 1.0.
        pipe.fuse_lora(lora_scale=lora_1_scale)
        pipe.load_lora_weights(state_dict_2)
        pipe.fuse_lora(lora_scale=lora_2_scale)
        image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=20, width=768, height=768).images[0]
    finally:
        # Free the per-request pipeline copy even when generation fails, so repeated
        # errors do not accumulate GPU memory.
        del pipe
        gc.collect()
        torch.cuda.empty_cache()
    return image
def get_description(item):
    """Return ``(markdown_description, trigger_word)`` for a LoRA catalogue entry.

    The description tells the user which trigger word to use in the prompt, or says
    that the LoRA has none. The raw ``trigger_word`` value is returned unchanged as the
    second element.
    """
    word = item["trigger_word"]
    if not word:
        description = "No trigger word, will be applied automatically"
    else:
        description = f"Trigger: `{word}`"
    return description, word
def shuffle_images():
    """Pick two distinct compatible LoRAs at random and build the matching UI updates.

    Returns:
        A tuple in the order of the outputs wired to this handler:
        image update and trigger-word markdown for LoRA 1, the same pair for LoRA 2,
        the prefilled prompt, the list of the two selected entries (stored in gr.State),
        and a scale reset for each of the two sliders.
    """
    compatible_items = [item for item in sdxl_loras if item["is_compatible"]]
    # random.sample draws two distinct entries without shuffling the whole list.
    two_shuffled_items = random.sample(compatible_items, 2)

    image_updates = []
    description_updates = []
    trigger_words = []
    for entry in two_shuffled_items:
        image_updates.append(gr.update(label=entry["title"], value=entry["image"]))
        description, trigger_word = get_description(entry)
        description_updates.append(gr.update(value=description, visible=True))
        trigger_words.append(trigger_word)

    # Join only non-empty trigger words so a LoRA without one does not leave stray
    # spaces (or the text "None") in the suggested prompt.
    prompt = gr.update(value=" ".join(word for word in trigger_words if word))
    scale = gr.update(value=0.7)
    return (
        image_updates[0],
        description_updates[0],
        image_updates[1],
        description_updates[1],
        prompt,
        two_shuffled_items,
        scale,
        scale,
    )
with gr.Blocks(css=css) as demo:
    # Holds the two LoRA entries picked by shuffle_images; read by merge_and_run.
    shuffled_items = gr.State()
    title = gr.HTML("<h1>LoRA Roulette 🎲</h1>", elem_id="title")
    with gr.Column():
        with gr.Column(min_width=10, scale=16, elem_classes="plus_column"):
            gr.HTML("<p>These 2 random LoRAs are loaded into SDXL, find a fun way to combine them 🎨</p>")
            with gr.Row():
                with gr.Column(min_width=10, scale=4, elem_classes="random_column"):
                    lora_1 = gr.Image(interactive=False, height=150, elem_classes="selected_random")
                    lora_1_prompt = gr.Markdown(visible=False)
                with gr.Column(min_width=10, scale=1, elem_classes="plus_column"):
                    plus = gr.HTML("+", elem_classes="plus_button")
                with gr.Column(min_width=10, scale=4, elem_classes="random_column"):
                    lora_2 = gr.Image(interactive=False, height=150, elem_classes="selected_random")
                    lora_2_prompt = gr.Markdown(visible=False)
                with gr.Column(min_width=10, scale=1, elem_classes="plus_column"):
                    equal = gr.HTML("=", elem_classes="plus_button")
        with gr.Column(min_width=10, scale=14):
            # gr.Box was removed in Gradio 4; gr.Group exists in both 3.x and 4.x.
            with gr.Group():
                with gr.Row():
                    prompt = gr.Textbox(label="Your prompt", show_label=False, interactive=True, elem_id="prompt")
                    run_btn = gr.Button("Run", elem_id="run_button")
            output_image = gr.Image(label="Output", height=355)
    with gr.Accordion("Advanced settings", open=False):
        negative_prompt = gr.Textbox(label="Negative prompt")
        with gr.Row():
            lora_1_scale = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
            lora_2_scale = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
    shuffle_button = gr.Button("Reshuffle!")

    # Order must match the tuple returned by shuffle_images.
    shuffle_outputs = [lora_1, lora_1_prompt, lora_2, lora_2_prompt, prompt, shuffled_items, lora_1_scale, lora_2_scale]
    # Order must match merge_and_run's positional parameters.
    run_inputs = [prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale]

    # Pick an initial pair as soon as the page loads, and again on every reshuffle.
    demo.load(shuffle_images, inputs=[], outputs=shuffle_outputs, queue=False, show_progress="hidden")
    shuffle_button.click(shuffle_images, outputs=shuffle_outputs, queue=False, show_progress="hidden")
    # Both the Run button and pressing Enter in the prompt box trigger generation.
    run_btn.click(merge_and_run, inputs=run_inputs, outputs=[output_image])
    prompt.submit(merge_and_run, inputs=run_inputs, outputs=[output_image])
# Enable the request queue (gr.Progress tracking in merge_and_run relies on it), then
# start the web server.
demo.queue()
demo.launch()