multimodalart HF staff commited on
Commit
af5ea8a
1 Parent(s): f01b73c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -7
app.py CHANGED
@@ -4,6 +4,7 @@ from diffusers import DiffusionPipeline
4
  from huggingface_hub import hf_hub_download
5
  from safetensors.torch import load_file
6
  from share_btn import community_icon_html, loading_icon_html, share_js
 
7
 
8
  import torch
9
  import json
@@ -13,6 +14,20 @@ import gc
13
 
14
  lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  with open(lora_list, "r") as file:
17
  data = json.load(file)
18
  sdxl_loras = [
@@ -66,7 +81,7 @@ div#share-btn-container > div {flex-direction: row;background: black;align-items
66
 
67
  original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
68
 
69
- def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, progress=gr.Progress(track_tqdm=True)):
70
  state_dict_1 = copy.deepcopy(shuffled_items[0]['state_dict'])
71
  state_dict_2 = copy.deepcopy(shuffled_items[1]['state_dict'])
72
  pipe = copy.deepcopy(original_pipe)
@@ -79,12 +94,15 @@ def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lor
79
 
80
  if negative_prompt == "":
81
  negative_prompt = None
82
-
83
- image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=20, width=768, height=768).images[0]
 
 
 
84
  del pipe
85
  gc.collect()
86
  torch.cuda.empty_cache()
87
- return image, gr.update(visible=True)
88
 
89
  def get_description(item):
90
  trigger_word = item["trigger_word"]
@@ -108,6 +126,15 @@ def shuffle_images():
108
 
109
  return title_1, prompt_description_1, repo_id_1, title_2, prompt_description_2, repo_id_2, prompt, two_shuffled_items, scale, scale
110
 
 
 
 
 
 
 
 
 
 
111
  with gr.Blocks(css=css) as demo:
112
  shuffled_items = gr.State()
113
  title = gr.HTML(
@@ -147,9 +174,11 @@ with gr.Blocks(css=css) as demo:
147
  community_icon = gr.HTML(community_icon_html)
148
  loading_icon = gr.HTML(loading_icon_html)
149
  share_button = gr.Button("Share to community", elem_id="share-btn")
150
-
151
  with gr.Accordion("Advanced settings", open=False):
152
  negative_prompt = gr.Textbox(label="Negative prompt")
 
 
153
  with gr.Row():
154
  lora_1_scale = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
155
  lora_2_scale = gr.Slider(label="LoRa 2 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
@@ -158,8 +187,10 @@ with gr.Blocks(css=css) as demo:
158
  demo.load(shuffle_images, inputs=[], outputs=[lora_1, lora_1_prompt, lora_1_id, lora_2, lora_2_prompt, lora_2_id, prompt, shuffled_items, lora_1_scale, lora_2_scale], queue=False, show_progress="hidden")
159
  shuffle_button.click(shuffle_images, outputs=[lora_1, lora_1_prompt, lora_1_id, lora_2, lora_2_prompt, lora_2_id, prompt, shuffled_items, lora_1_scale, lora_2_scale], queue=False, show_progress="hidden")
160
 
161
- run_btn.click(merge_and_run, inputs=[prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale], outputs=[output_image, post_gen_info])
162
- prompt.submit(merge_and_run, inputs=[prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale], outputs=[output_image, post_gen_info])
 
 
163
  share_button.click(None, [], [], _js=share_js)
164
  demo.queue()
165
  demo.launch()
 
4
  from huggingface_hub import hf_hub_download
5
  from safetensors.torch import load_file
6
  from share_btn import community_icon_html, loading_icon_html, share_js
7
+ from uuid import uuid4
8
 
9
  import torch
10
  import json
 
14
 
15
  lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")
16
 
17
+ IMAGE_DATASET_DIR = Path("image_dataset") / f"train-{uuid4()}"
18
+ IMAGE_DATASET_DIR.mkdir(parents=True, exist_ok=True)
19
+ IMAGE_JSONL_PATH = IMAGE_DATASET_DIR / "metadata.jsonl"
20
+
21
+ scheduler = CommitScheduler(
22
+ repo_id="multimodalart/lora-fusing-preferences",
23
+ repo_type="dataset",
24
+ folder_path=IMAGE_DATASET_DIR,
25
+ path_in_repo=IMAGE_DATASET_DIR.name,
26
+ every=10
27
+ )
28
+
29
+ client = InferenceClient()
30
+
31
  with open(lora_list, "r") as file:
32
  data = json.load(file)
33
  sdxl_loras = [
 
81
 
82
  original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
83
 
84
+ def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, seed, progress=gr.Progress(track_tqdm=True)):
85
  state_dict_1 = copy.deepcopy(shuffled_items[0]['state_dict'])
86
  state_dict_2 = copy.deepcopy(shuffled_items[1]['state_dict'])
87
  pipe = copy.deepcopy(original_pipe)
 
94
 
95
  if negative_prompt == "":
96
  negative_prompt = None
97
+
98
+ if(seed < 0):
99
+ seed = random.randint(0, 2147483647)
100
+ generator = torch.Generator(device="cuda").manual_seed(seed)
101
+ image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=20, width=768, height=768, generator=generator).images[0]
102
  del pipe
103
  gc.collect()
104
  torch.cuda.empty_cache()
105
+ return image, gr.update(visible=True), seed
106
 
107
  def get_description(item):
108
  trigger_word = item["trigger_word"]
 
126
 
127
  return title_1, prompt_description_1, repo_id_1, title_2, prompt_description_2, repo_id_2, prompt, two_shuffled_items, scale, scale
128
 
129
+ def save_preferences(lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, generated_image, thumbs_direction, seed):
130
+ image_path = IMAGE_DATASET_DIR / f"{uuid4()}.png"
131
+ with scheduler.lock:
132
+ generated_image.save(image_path)
133
+ with IMAGE_JSONL_PATH.open("a") as f:
134
+ json.dump({"prompt": prompt, "file_name":image_path.name, "lora_1_id": lora_2_id, "lora_1_scale": lora_1_scale, "lora_2_id": lora_2_id, "lora_2_scale": lora_2_scale, "thumbs_direction": thumbs_direction, "seed": seed}, f)
135
+ f.write("\n")
136
+ return gr.update(visible=True)
137
+
138
  with gr.Blocks(css=css) as demo:
139
  shuffled_items = gr.State()
140
  title = gr.HTML(
 
174
  community_icon = gr.HTML(community_icon_html)
175
  loading_icon = gr.HTML(loading_icon_html)
176
  share_button = gr.Button("Share to community", elem_id="share-btn")
177
+ post_eval = gr.Markdown("Thanks for evaluating. The dataset with evaluations is [here](#)", visible=False)
178
  with gr.Accordion("Advanced settings", open=False):
179
  negative_prompt = gr.Textbox(label="Negative prompt")
180
+ seed = gr.Slider(label="Seed", info="-1 denotes a random seed", minimum=-1, maximum=2147483647, value=-1)
181
+ last_used_seed = gr.Slider(label="Last used seed", info="The seed used in the last generation", minimum=0, maximum=2147483647, value=-1, interactive=False)
182
  with gr.Row():
183
  lora_1_scale = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
184
  lora_2_scale = gr.Slider(label="LoRa 2 Scale", minimum=0, maximum=1, step=0.1, value=0.7)
 
187
  demo.load(shuffle_images, inputs=[], outputs=[lora_1, lora_1_prompt, lora_1_id, lora_2, lora_2_prompt, lora_2_id, prompt, shuffled_items, lora_1_scale, lora_2_scale], queue=False, show_progress="hidden")
188
  shuffle_button.click(shuffle_images, outputs=[lora_1, lora_1_prompt, lora_1_id, lora_2, lora_2_prompt, lora_2_id, prompt, shuffled_items, lora_1_scale, lora_2_scale], queue=False, show_progress="hidden")
189
 
190
+ run_btn.click(merge_and_run, inputs=[prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale], outputs=[output_image, post_gen_info, last_used_seed])
191
+ prompt.submit(merge_and_run, inputs=[prompt, negative_prompt, shuffled_items, lora_1_scale, lora_2_scale], outputs=[output_image, post_gen_info, last_used_seed])
192
+ thumbs_up.click(save_preferences, inputs=[lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, output_image, gr.State("up"), seed], outputs=[post_eval])
193
+ thumbs_down.click(save_preferences, inputs=[lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, output_image, gr.State("down"), seed], outputs=[post_eval])
194
  share_button.click(None, [], [], _js=share_js)
195
  demo.queue()
196
  demo.launch()