Spaces:
Runtime error
Runtime error
File size: 6,244 Bytes
4120479 a8a382e 9b729f7 a8a382e 4120479 1475e41 3eb8dac 4120479 20706a7 4120479 9b729f7 ba35348 2f40f84 4ed3fa9 2f40f84 4120479 ce04d24 4120479 ce04d24 4120479 ce04d24 4120479 3eb8dac 4120479 3eb8dac 9b729f7 4120479 9b729f7 4120479 1d5c6d0 4120479 3eb8dac 9b729f7 3eb8dac 4120479 ce04d24 4120479 ea9cf0a 4120479 a56c826 4120479 a56c826 4120479 ce04d24 b6fa736 ce04d24 4120479 b6fa736 a56c826 4120479 a56c826 4120479 d92fc15 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import gradio as gr
from time import sleep
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import torch
import json
import random
import copy
import gc
# LoRA catalogue: download the JSON index published by the LoraTheExplorer
# Space and normalise every entry into the dict shape used by this app.
lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")
with open(lora_list, "r", encoding="utf-8") as file:
    data = json.load(file)


def _image_url(image):
    """Return an absolute URL for a catalogue image path.

    Absolute ``https://`` URLs are kept as-is; relative paths are resolved
    against the LoraTheExplorer Space repository.
    """
    if image.startswith("https://"):
        return image
    return f"https://huggingface.co/spaces/multimodalart/LoraTheExplorer/resolve/main/{image}"


def _load_state_dict(path):
    """Load LoRA weights stored at *path* onto the CPU.

    NOTE(security): the non-safetensors branch uses ``torch.load``, which
    unpickles the file and can execute arbitrary code. These files come from
    third-party Hub repos listed in the catalogue; prefer ``.safetensors``.
    """
    if path.endswith(".safetensors"):
        return load_file(path)
    # map_location keeps startup independent of the device the file was saved on.
    return torch.load(path, map_location="cpu")


sdxl_loras = [
    {
        "image": _image_url(item["image"]),
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item["trigger_word"],
        "weights": item["weights"],
        "is_compatible": item["is_compatible"],
        "is_pivotal": item.get("is_pivotal", False),
        "text_embedding_weights": item.get("text_embedding_weights", None),
        "is_nc": item.get("is_nc", False),
    }
    for item in data
]

# Only compatible LoRAs can ever be picked by shuffle_images(), so only those
# are downloaded and kept in memory.
for item in sdxl_loras:
    if not item["is_compatible"]:
        continue
    saved_name = hf_hub_download(item["repo"], item["weights"])
    item["saved_name"] = saved_name
    item["state_dict"] = _load_state_dict(saved_name)
# Custom stylesheet passed to gr.Blocks(css=...). The selectors target the
# elem_id / elem_classes values set on components in the layout below:
# #title (page heading), #prompt / #run_button (the prompt textbox is narrowed
# by 160px and the Run button is absolutely positioned at its right edge),
# .plus_column / .plus_button (the "+" and "=" signs), .random_column (the two
# LoRA image cards) and .roulette_group (the main row, switched to a vertical
# flex direction on screens up to 1024px wide).
css = '''
#title{text-align:center;}
#title h1{font-size: 250%}
.plus_column{align-self: center}
.plus_button{font-size: 235% !important; text-align: center;margin-bottom: 19px}
#prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
#run_button{position:absolute;margin-top: 12px;right: 0;margin-right: 1.5em;border-bottom-left-radius: 0px;
border-top-left-radius: 0px;}
.random_column{align-self: center}
@media (max-width: 1024px) {
.roulette_group{flex-direction: column}
}
'''
# Base SDXL pipeline in float16, loaded on the CPU and never run directly:
# merge_and_run() deep-copies it for every request, because fuse_lora() merges
# LoRA weights into the model's own weights and this instance must stay clean.
original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
# Alias kept for backward compatibility. A separate deep copy is unnecessary
# (it only duplicated the whole model in host memory); never load or fuse
# LoRAs into this object.
pipe = original_pipe
def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, progress=gr.Progress(track_tqdm=True)):
    """Fuse the two shuffled LoRAs into a fresh SDXL copy and render one image.

    Args:
        prompt: Text prompt for the generation.
        negative_prompt: Optional negative prompt; an empty string means "none".
        shuffled_items: The two LoRA entries produced by ``shuffle_images``;
            each is a dict carrying a preloaded ``state_dict``.
        lora_1_scale: Fusion strength for the first LoRA.
        lora_2_scale: Fusion strength for the second LoRA.
        progress: Gradio progress tracker; ``track_tqdm`` mirrors the pipeline's
            tqdm progress bar in the UI.

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: If no LoRA pair has been shuffled yet.
    """
    if not shuffled_items or len(shuffled_items) < 2:
        raise gr.Error("No LoRAs selected yet, press 'Reshuffle!' and try again.")
    # Fusing modifies weights in place, so always work on a throwaway copy.
    pipe = copy.deepcopy(original_pipe)
    try:
        pipe.to("cuda")
        for item, scale in zip(shuffled_items[:2], (lora_1_scale, lora_2_scale)):
            # Pass a shallow copy: load_lora_weights() may rename/pop keys of the
            # dict it receives, which would corrupt the cached state dict for the
            # next request that picks the same LoRA.
            pipe.load_lora_weights(dict(item["state_dict"]))
            # The scale must be passed by keyword: fuse_lora()'s first positional
            # parameter is not the scale (fuse_unet / components, by version).
            pipe.fuse_lora(lora_scale=scale)
        image = pipe(
            prompt=prompt,
            # None is diffusers' "no negative prompt" value (False is not valid).
            negative_prompt=negative_prompt or None,
            num_inference_steps=22,
            width=768,
            height=768,
        ).images[0]
    finally:
        # Free the GPU copy even when loading or generation raised.
        del pipe
        gc.collect()
        torch.cuda.empty_cache()
    return image
def get_description(item):
    """Build the Markdown hint shown under a LoRA card.

    Returns a ``(markdown, trigger_word)`` pair. When the entry has no
    (truthy) trigger word, the markdown is a generic notice instead.
    """
    word = item["trigger_word"]
    if not word:
        return "No trigger word, will be applied automatically", word
    return f"Trigger: `{word}`", word
def shuffle_images():
    """Pick two random compatible LoRAs and build the UI updates for them.

    Returns:
        A 6-tuple matching the ``outputs`` wiring of the shuffle events:
        updates for the first image, first description, second image, second
        description and the prompt box, followed by the list of the two chosen
        LoRA entries (stored in ``gr.State`` for ``merge_and_run``).

    Raises:
        gr.Error: If fewer than two compatible LoRAs are available.
    """
    compatible_items = [item for item in sdxl_loras if item['is_compatible']]
    if len(compatible_items) < 2:
        raise gr.Error("Not enough compatible LoRAs to shuffle.")
    two_shuffled_items = random.sample(compatible_items, 2)
    first, second = two_shuffled_items
    title_1 = gr.update(label=first['title'], value=first['image'])
    title_2 = gr.update(label=second['title'], value=second['image'])
    description_1, trigger_word_1 = get_description(first)
    description_2, trigger_word_2 = get_description(second)
    prompt_description_1 = gr.update(value=description_1, visible=True)
    prompt_description_2 = gr.update(value=description_2, visible=True)
    # Skip empty/missing trigger words so the prompt has no stray spaces or "None".
    prompt_text = " ".join(str(word) for word in (trigger_word_1, trigger_word_2) if word)
    prompt = gr.update(value=prompt_text)
    return title_1, prompt_description_1, title_2, prompt_description_2, prompt, two_shuffled_items
# Gradio UI: two randomly drawn LoRA cards, a prompt box with a Run button, the
# output image and advanced settings, plus the event wiring.
with gr.Blocks(css=css) as demo:
    # Holds the two LoRA entries from the last shuffle; read by merge_and_run.
    shuffled_items = gr.State()
    title = gr.HTML(
        "<h1>LoRA Roulette 🎲</h1>\n",
        elem_id="title"
    )
    with gr.Row(elem_classes="roulette_group"):
        with gr.Column(min_width=10, scale=16, elem_classes="plus_column"):
            gr.HTML("<p>These 2 random LoRAs are loaded to SDXL, find a fun way to combine them 🎨</p>")
            with gr.Row():
                with gr.Column(min_width=10, scale=8, elem_classes="random_column"):
                    lora_1 = gr.Image(interactive=False, height=263)
                    lora_1_prompt = gr.Markdown(visible=False)
                with gr.Column(min_width=10, scale=1, elem_classes="plus_column"):
                    plus = gr.HTML("+", elem_classes="plus_button")
                with gr.Column(min_width=10, scale=8, elem_classes="random_column"):
                    lora_2 = gr.Image(interactive=False, height=263)
                    lora_2_prompt = gr.Markdown(visible=False)
        with gr.Column(min_width=10, scale=1, elem_classes="plus_column"):
            equal = gr.HTML("=", elem_classes="plus_button")
        with gr.Column(min_width=10, scale=14):
            # gr.Group joins the prompt box and Run button visually; it replaces
            # gr.Box, which no longer exists in Gradio 4.
            with gr.Group():
                with gr.Row():
                    prompt = gr.Textbox(label="Your prompt", show_label=False, interactive=True, elem_id="prompt")
                    run_btn = gr.Button("Run", elem_id="run_button")
            output_image = gr.Image(label="Output", height=355)
            with gr.Accordion("Advanced settings", open=False):
                negative_prompt = gr.Textbox(label="Negative prompt")
                with gr.Row():
                    lora_1_scale = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
                    lora_2_scale = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
    shuffle_button = gr.Button("Reshuffle!")

    # Same outputs for the initial draw on page load and for manual reshuffles.
    shuffle_outputs = [lora_1, lora_1_prompt, lora_2, lora_2_prompt, prompt, shuffled_items]
    run_inputs = [prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale]
    demo.load(shuffle_images, inputs=[], outputs=shuffle_outputs, queue=False, show_progress="hidden")
    shuffle_button.click(shuffle_images, outputs=shuffle_outputs, queue=False, show_progress="hidden")
    run_btn.click(merge_and_run, inputs=run_inputs, outputs=[output_image])
    prompt.submit(merge_and_run, inputs=run_inputs, outputs=[output_image])
# Enable the request queue (needed for the progress-tracked generation events),
# then start the web server.
demo.queue()
demo.launch()