Spaces:
Runtime error
Runtime error
File size: 8,335 Bytes
4120479 a8a382e 9b729f7 ca25d4d a8a382e 4120479 1475e41 3eb8dac 4120479 20706a7 4120479 9b729f7 ba35348 2f40f84 008aa9b 2f40f84 4120479 edf408e ce04d24 2130441 b819231 ce04d24 4120479 b48fe41 4120479 374e672 bc0b5e4 ca25d4d bc0b5e4 ca25d4d 4120479 ad8076c 3eb8dac 4120479 008aa9b 3eb8dac 0fac10f e6d0ada 9bd2ea4 4120479 9bd2ea4 4120479 1d5c6d0 4120479 a5454cf 3eb8dac 1cd10c1 3eb8dac bc0b5e4 4120479 145660e 4120479 ea9cf0a 4120479 a56c826 4120479 06bf887 a56c826 06bf887 4120479 b819231 4120479 0d23c84 ce04d24 374e672 b819231 ce04d24 374e672 b819231 ce04d24 edf408e ce04d24 0d23c84 ce04d24 3b258c8 ce04d24 bc0b5e4 b48fe41 bc0b5e4 b48fe41 bc0b5e4 b48fe41 ca25d4d bc0b5e4 4120479 b6fa736 06bf887 4120479 bc0b5e4 ca25d4d 4120479 d92fc15 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import gradio as gr
from time import sleep
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from share_btn import community_icon_html, loading_icon_html, share_js
import torch
import json
import random
import copy
import gc
# Fetch the curated LoRA catalogue from the LoraTheExplorer Space and
# normalise every entry into the dict shape used by the rest of this app.
lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")
# JSON text is UTF-8 by specification; be explicit so titles containing
# non-ASCII characters load regardless of the host's default locale.
with open(lora_list, "r", encoding="utf-8") as file:
    data = json.load(file)

sdxl_loras = [
    {
        # Relative image paths are resolved against the LoraTheExplorer Space repo.
        "image": item["image"] if item["image"].startswith("https://") else f'https://huggingface.co/spaces/multimodalart/LoraTheExplorer/resolve/main/{item["image"]}',
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item["trigger_word"],
        "weights": item["weights"],
        "is_compatible": item["is_compatible"],
        # Optional catalogue fields, defaulted when absent.
        "is_pivotal": item.get("is_pivotal", False),
        "text_embedding_weights": item.get("text_embedding_weights", None),
        "is_nc": item.get("is_nc", False)
    }
    for item in data
]

# Download every LoRA once at startup and keep its state dict in memory, so a
# request only has to fuse weights that are already loaded.
for item in sdxl_loras:
    saved_name = hf_hub_download(item["repo"], item["weights"])
    if saved_name.endswith('.safetensors'):
        # safetensors.load_file loads onto the CPU by default.
        state_dict = load_file(saved_name)
    else:
        # map_location="cpu" matches the safetensors branch, so tensors pickled
        # from a GPU process do not end up on the GPU here.
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. The repos listed come from a remote catalogue;
        # consider weights_only=True if every checkpoint supports it.
        state_dict = torch.load(saved_name, map_location="cpu")
    item["saved_name"] = saved_name
    item["state_dict"] = state_dict
# Custom stylesheet passed to gr.Blocks(css=...). Its selectors target the
# elem_id / elem_classes values given to components in the Blocks layout
# (e.g. #title, #prompt, #run_button, .plus_button, #share-btn-container).
css = '''
.gradio-container{max-width: 650px! important}
#title{text-align:center;}
#title h1{font-size: 250%}
.selected_random img{object-fit: cover}
.selected_random [data-testid="block-label"] span{display: none}
.plus_column{align-self: center}
.plus_button{font-size: 235% !important; text-align: center;margin-bottom: 19px}
#prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
#run_button{position:absolute;margin-top: 36px;right: 0;margin-right: 1.5em;border-bottom-left-radius: 0px;
border-top-left-radius: 0px;}
.random_column{align-self: center; align-items: center}
#share-btn-container{padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; max-width: 13rem; margin-left: auto;margin-top: 0.35em;}
div#share-btn-container > div {flex-direction: row;background: black;align-items: center}
#share-btn-container:hover {background-color: #060606}
#share-btn {all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important;right:0;font-size: 15px;}
#share-btn * {all: unset}
#share-btn-container div:nth-child(-n+2){width: auto !important;min-height: 0px !important;}
#share-btn-container .wrap {display: none !important}
#share-btn-container.hidden {display: none!important}
'''
# Base SDXL pipeline in half precision, loaded once at startup. Treat it as a
# pristine template: copy it before fusing any LoRA weights into it.
original_pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
)
def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, progress=gr.Progress(track_tqdm=True)):
    """Fuse the two selected LoRAs into a fresh SDXL copy and render one image.

    Args:
        prompt: Text prompt for the generation.
        negative_prompt: Optional negative prompt; "" or None means none.
        shuffled_items: The two catalogue entries picked by the shuffle step;
            each must carry a ``state_dict``.
        lora_1_scale: Fusion strength for the first LoRA.
        lora_2_scale: Fusion strength for the second LoRA.
        progress: Gradio progress tracker. It is not used directly;
            ``track_tqdm=True`` mirrors the pipeline's tqdm bars in the UI.

    Returns:
        A ``(image, update)`` pair: the first generated image and a
        ``gr.update`` that makes the post-generation row visible.

    Raises:
        gr.Error: if two LoRAs have not been selected yet.
    """
    if not shuffled_items or len(shuffled_items) < 2:
        raise gr.Error('No LoRAs selected yet - press "Reshuffle!" and try again.')

    # Hand diffusers private copies of the cached state dicts.
    # NOTE(review): presumably because load_lora_weights may mutate the dict it
    # is given; the copy keeps the cached originals reusable for later requests.
    state_dict_1 = copy.deepcopy(shuffled_items[0]['state_dict'])
    state_dict_2 = copy.deepcopy(shuffled_items[1]['state_dict'])

    # Fuse into a throwaway copy so the shared base pipeline is never modified.
    pipe = copy.deepcopy(original_pipe)
    try:
        pipe.to("cuda")
        # The scale must be passed by keyword: fuse_lora's first positional
        # parameter is not the scale, so passing it positionally ignored the slider.
        pipe.load_lora_weights(state_dict_1)
        pipe.fuse_lora(lora_scale=lora_1_scale)
        pipe.load_lora_weights(state_dict_2)
        pipe.fuse_lora(lora_scale=lora_2_scale)
        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt or None,  # treat "" as "no negative prompt"
            num_inference_steps=20,
            width=768,
            height=768,
        ).images[0]
    finally:
        # Release the GPU copy even when loading, fusing or inference fails.
        del pipe
        gc.collect()
        torch.cuda.empty_cache()
    return image, gr.update(visible=True)
def get_description(item):
    """Return ``(caption, trigger_word)`` for a LoRA catalogue entry.

    The caption is Markdown shown under the LoRA thumbnail. When the entry's
    trigger word is falsy, the caption says the LoRA applies automatically.
    """
    word = item["trigger_word"]
    if not word:
        return "No trigger, applied automatically", word
    return f"Trigger: `{word}`", word
def shuffle_images():
    """Pick two distinct compatible LoRAs at random and build the UI updates.

    Returns, in order: the image/label update and the caption update for
    LoRA 1; the same pair for LoRA 2; a prompt update seeded with the
    non-empty trigger words; the two chosen catalogue entries (for the
    session ``gr.State``); and a reset to 0.7 for each of the two scale
    sliders.

    Raises:
        ValueError: if the catalogue has fewer than two compatible LoRAs.
    """
    compatible_items = [item for item in sdxl_loras if item['is_compatible']]
    # random.sample draws without replacement, so the two picks are always distinct.
    two_shuffled_items = random.sample(compatible_items, 2)
    first, second = two_shuffled_items

    title_1 = gr.update(label=first['title'], value=first['image'])
    title_2 = gr.update(label=second['title'], value=second['image'])
    description_1, trigger_word_1 = get_description(first)
    description_2, trigger_word_2 = get_description(second)
    prompt_description_1 = gr.update(value=description_1, visible=True)
    prompt_description_2 = gr.update(value=description_2, visible=True)
    # Skip missing/empty trigger words (same truthiness rule as get_description)
    # so the seeded prompt has no stray spaces.
    prompt = gr.update(value=" ".join(word for word in (trigger_word_1, trigger_word_2) if word))
    scale = gr.update(value=0.7)
    return title_1, prompt_description_1, title_2, prompt_description_2, prompt, two_shuffled_items, scale, scale
with gr.Blocks(css=css) as demo:
    # Per-session holder for the two catalogue entries chosen by shuffle_images.
    shuffled_items = gr.State()
    title = gr.HTML(
        "<h1>LoRA Roulette 🎲</h1>\n"
        "<p>These random LoRAs are loaded into SDXL, can you find a fun way to combine them? 🎨</p>\n",
        elem_id="title",
    )
    with gr.Column():
        # The "LoRA 1 + LoRA 2 =" strip.
        with gr.Column(min_width=10, scale=16, elem_classes="plus_column"):
            with gr.Row():
                with gr.Column(min_width=10, scale=4, elem_classes="random_column"):
                    lora_1 = gr.Image(interactive=False, height=150, elem_classes="selected_random", show_share_button=False, show_download_button=False)
                    lora_1_prompt = gr.Markdown(visible=False)
                with gr.Column(min_width=10, scale=1, elem_classes="plus_column"):
                    plus = gr.HTML("+", elem_classes="plus_button")
                with gr.Column(min_width=10, scale=4, elem_classes="random_column"):
                    lora_2 = gr.Image(interactive=False, height=150, elem_classes="selected_random", show_share_button=False, show_download_button=False)
                    lora_2_prompt = gr.Markdown(visible=False)
                with gr.Column(min_width=10, scale=2, elem_classes="plus_column"):
                    equal = gr.HTML("=", elem_classes="plus_button")
        # Prompt, result and controls.
        with gr.Column(min_width=10, scale=14):
            with gr.Box():
                with gr.Row():
                    prompt = gr.Textbox(label="Your prompt", info="Rearrange the trigger words into a coherent prompt", show_label=False, interactive=True, elem_id="prompt")
                    run_btn = gr.Button("Run", elem_id="run_button")
            output_image = gr.Image(label="Output", height=355)
            # Starts hidden; the generate handler returns gr.update(visible=True) for it.
            with gr.Row(visible=False) as post_gen_info:
                with gr.Column(min_width=10):
                    # NOTE(review): no click handlers are attached to the rating buttons in this file.
                    thumbs_up = gr.Button("👍")
                with gr.Column(min_width=10):
                    thumbs_down = gr.Button("👎")
                with gr.Column(min_width=10):
                    with gr.Group(elem_id="share-btn-container") as share_group:
                        community_icon = gr.HTML(community_icon_html)
                        loading_icon = gr.HTML(loading_icon_html)
                        share_button = gr.Button("Share to community", elem_id="share-btn")
            with gr.Accordion("Advanced settings", open=False):
                negative_prompt = gr.Textbox(label="Negative prompt")
                with gr.Row():
                    lora_1_scale = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
                    lora_2_scale = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
            shuffle_button = gr.Button("Reshuffle!")

    # Shared component lists, so paired triggers (page load / Reshuffle,
    # Run button / Enter in the prompt) cannot drift out of sync.
    shuffle_outputs = [lora_1, lora_1_prompt, lora_2, lora_2_prompt, prompt, shuffled_items, lora_1_scale, lora_2_scale]
    generate_inputs = [prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale]
    generate_outputs = [output_image, post_gen_info]

    demo.load(shuffle_images, inputs=[], outputs=shuffle_outputs, queue=False, show_progress="hidden")
    shuffle_button.click(shuffle_images, outputs=shuffle_outputs, queue=False, show_progress="hidden")
    run_btn.click(merge_and_run, inputs=generate_inputs, outputs=generate_outputs)
    prompt.submit(merge_and_run, inputs=generate_inputs, outputs=generate_outputs)
    # Browser-side only: share_js runs in the client, with no Python callback.
    share_button.click(None, [], [], _js=share_js)

demo.queue()
demo.launch()