amildravid4292 commited on
Commit
aa1da5e
·
verified ·
1 Parent(s): cbf6563

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -47
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import os
3
  os.system("pip uninstall -y gradio")
4
  os.system('pip install gradio==3.43.1')
5
 
@@ -9,6 +8,7 @@ import torchvision.transforms as transforms
9
  from torch.utils.data import Dataset, DataLoader
10
  import gradio as gr
11
  import sys
 
12
  import tqdm
13
  sys.path.append(os.path.abspath(os.path.join("", "..")))
14
  import torch
@@ -112,6 +112,7 @@ def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
112
 
113
 
114
 
 
115
  @torch.no_grad()
116
  def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
117
 
@@ -129,8 +130,17 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
129
 
130
  original_weights = network.proj.clone()
131
 
 
 
 
 
 
 
 
 
 
132
 
133
- edited_weights = original_weights+a1*1e6*young+a2*1e6*pointy+a3*1e6*wavy+a4*2e6*large
134
 
135
  generator = generator.manual_seed(seed)
136
  latents = torch.randn(
@@ -191,16 +201,16 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
191
 
192
 
193
 
194
-
195
  def sample_then_run():
196
  sample_model()
197
  prompt = "sks person"
198
- negative_prompt = "low quality, blurry, unfinished, cartoon"
199
  seed = 5
200
  cfg = 3.0
201
  steps = 50
202
  image = inference( prompt, negative_prompt, cfg, steps, seed)
203
- return image
 
204
 
205
 
206
 
@@ -342,20 +352,55 @@ def run_inversion(dict, pcs, epochs, weight_decay,lr):
342
 
343
  #sample an image
344
  prompt = "sks person"
345
- negative_prompt = "low quality, blurry, unfinished, cartoon"
346
  seed = 5
347
  cfg = 3.0
348
  steps = 50
349
  image = inference( prompt, negative_prompt, cfg, steps, seed)
350
- torch.save(network.proj, "model.pt" )
351
- return image, "model.pt"
352
 
353
 
354
 
 
 
 
 
 
355
 
 
 
 
 
 
 
 
 
 
 
356
 
357
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
 
 
 
 
359
  intro = """
360
  <div style="display: flex;align-items: center;justify-content: center">
361
  <h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">weights2weights</h1>
@@ -378,39 +423,47 @@ with gr.Blocks(css="style.css") as demo:
378
  with gr.Tab("Sampling Models + Editing"):
379
  with gr.Row():
380
  with gr.Column():
381
- gallery1 = gr.Image(label="Identity from Sampled Model")
382
  sample = gr.Button("Sample New Model")
383
- gallery2 = gr.Image(label="Identity from Edited Model")
 
 
 
 
384
 
385
 
386
- with gr.Row():
387
- with gr.Column():
388
  prompt = gr.Textbox(label="Prompt",
389
- info="Make sure to include 'sks person'" ,
390
- placeholder="sks person",
391
- value="sks person")
392
- negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
393
- with gr.Row():
394
  a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
395
-
396
- a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
397
- with gr.Row():
398
  a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
399
- a4 = gr.Slider(label="- placeholder for some fourth attribute +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
 
 
 
 
 
 
 
 
 
 
 
 
400
 
401
-
402
- with gr.Accordion("Advanced Options", open=False):
403
- with gr.Column():
404
- seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
405
- cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
406
- steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
407
- injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
408
 
409
 
410
 
411
- submit = gr.Button("Generate")
412
 
413
- sample.click(fn=sample_then_run, outputs=gallery1)
414
 
415
  submit.click(fn=edit_inference,
416
  inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4],
@@ -421,15 +474,18 @@ with gr.Blocks(css="style.css") as demo:
421
  with gr.Tab("Inversion"):
422
  with gr.Row():
423
  with gr.Column():
424
- input_image = gr.Image(sources='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload image and draw to define mask",
425
- height=512, width=512, brush_color='#00FFFF', mask_opacity=0.6)
426
- # input_image = gr.ImageEditor(sources='upload', elem_id="image_upload", type='pil', label="Upload image and draw to define mask",
427
- # height=512, width=512)
428
 
429
  lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
430
- weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
431
  pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
432
- epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)
 
 
 
 
 
433
 
434
  invert_button = gr.Button("Invert")
435
 
@@ -439,13 +495,15 @@ with gr.Blocks(css="style.css") as demo:
439
  info="Make sure to include 'sks person'" ,
440
  placeholder="sks person",
441
  value="sks person")
442
- negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
443
  seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
444
- cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
445
- steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
 
 
 
446
  submit = gr.Button("Generate")
447
 
448
- file_output = gr.File(label="Download Model", container=False)
449
 
450
 
451
 
@@ -459,10 +517,4 @@ with gr.Blocks(css="style.css") as demo:
459
  outputs=gallery)
460
 
461
 
462
-
463
-
464
-
465
-
466
-
467
- demo.queue().launch() #share=True)
468
-
 
1
  import os
 
2
  os.system("pip uninstall -y gradio")
3
  os.system('pip install gradio==3.43.1')
4
 
 
8
  from torch.utils.data import Dataset, DataLoader
9
  import gradio as gr
10
  import sys
11
+ import os
12
  import tqdm
13
  sys.path.append(os.path.abspath(os.path.join("", "..")))
14
  import torch
 
112
 
113
 
114
 
115
+
116
  @torch.no_grad()
117
  def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
118
 
 
130
 
131
  original_weights = network.proj.clone()
132
 
133
+ #pad to same number of PCs
134
+ pcs_original = original_weights.shape[1]
135
+ pcs_edits = young.shape[1]
136
+ padding = torch.zeros((1,pcs_original-pcs_edits)).to(device)
137
+ young_pad = torch.cat((young, padding), 1)
138
+ pointy_pad = torch.cat((pointy, padding), 1)
139
+ wavy_pad = torch.cat((wavy, padding), 1)
140
+ large_pad = torch.cat((large, padding), 1)
141
+
142
 
143
+ edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*2e6*large_pad
144
 
145
  generator = generator.manual_seed(seed)
146
  latents = torch.randn(
 
201
 
202
 
203
 
 
204
  def sample_then_run():
205
  sample_model()
206
  prompt = "sks person"
207
+ negative_prompt = "low quality, blurry, unfinished, nudity"
208
  seed = 5
209
  cfg = 3.0
210
  steps = 50
211
  image = inference( prompt, negative_prompt, cfg, steps, seed)
212
+ torch.save(network.proj, "sampled_model.pt" )
213
+ return image, "sampled_model.pt"
214
 
215
 
216
 
 
352
 
353
  #sample an image
354
  prompt = "sks person"
355
+ negative_prompt = "low quality, blurry, unfinished, nudity"
356
  seed = 5
357
  cfg = 3.0
358
  steps = 50
359
  image = inference( prompt, negative_prompt, cfg, steps, seed)
360
+ torch.save(network.proj, "inverted_model.pt" )
361
+ return image, "inverted_model.pt"
362
 
363
 
364
 
365
+ def file_upload(file):
366
+ global unet
367
+ del unet
368
+ global network
369
+ global device
370
 
371
+
372
+
373
+ proj = torch.load(file.name).to(device)
374
+
375
+ #pad to 10000 Principal components to keep everything consistent
376
+ pcs = proj.shape[1]
377
+ padding = torch.zeros((1,10000-pcs)).to(device)
378
+ proj = torch.cat((proj, padding), 1)
379
+
380
+ unet, _, _, _, _ = load_models(device)
381
 
382
 
383
+ network = LoRAw2w( proj, mean, std, v[:, :pcs],
384
+ unet,
385
+ rank=1,
386
+ multiplier=1.0,
387
+ alpha=27.0,
388
+ train_method="xattn-strict"
389
+ ).to(device, torch.bfloat16)
390
+
391
+
392
+ prompt = "sks person"
393
+ negative_prompt = "low quality, blurry, unfinished, nudity"
394
+ seed = 5
395
+ cfg = 3.0
396
+ steps = 50
397
+ image = inference( prompt, negative_prompt, cfg, steps, seed)
398
+ return image
399
+
400
 
401
+
402
+
403
+
404
  intro = """
405
  <div style="display: flex;align-items: center;justify-content: center">
406
  <h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">weights2weights</h1>
 
423
  with gr.Tab("Sampling Models + Editing"):
424
  with gr.Row():
425
  with gr.Column():
426
+ gallery1 = gr.Image(label="Identity from Sampled or Uploaded Model")
427
  sample = gr.Button("Sample New Model")
428
+ file_input = gr.File(label="Upload Model", container=True)
429
+ file_input.change(fn=file_upload, inputs=file_input, outputs = gallery1)
430
+
431
+
432
+
433
 
434
 
435
+ with gr.Column():
436
+ gallery2 = gr.Image(label="Identity from Edited Model")
437
  prompt = gr.Textbox(label="Prompt",
438
+ info="Make sure to include 'sks person'" ,
439
+ placeholder="sks person",
440
+ value="sks person")
441
+ seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
442
+ with gr.Row():
443
  a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
444
+ a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
445
+ with gr.Row():
 
446
  a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
447
+ a4 = gr.Slider(label="- placeholder +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
448
+
449
+
450
+ with gr.Accordion("Advanced Options", open=False):
451
+ with gr.Column():
452
+ cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
453
+ steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
454
+ injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
455
+ negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity", value="low quality, blurry, unfinished, nudity")
456
+
457
+ submit = gr.Button("Generate")
458
+
459
+ file_output = gr.File(label="Download Sampled Model", container=True, interactive=False)
460
 
 
 
 
 
 
 
 
461
 
462
 
463
 
464
+
465
 
466
+ sample.click(fn=sample_then_run, outputs=[gallery1, file_output])
467
 
468
  submit.click(fn=edit_inference,
469
  inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4],
 
474
  with gr.Tab("Inversion"):
475
  with gr.Row():
476
  with gr.Column():
477
+ input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload image and draw to define mask",
478
+ height=512, width=512, brush_color='#00FFFF', mask_opacity=0.6)
479
+
 
480
 
481
  lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
 
482
  pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
483
+ with gr.Accordion("Advanced Options", open=False):
484
+ with gr.Column():
485
+ epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)
486
+ weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
487
+
488
+
489
 
490
  invert_button = gr.Button("Invert")
491
 
 
495
  info="Make sure to include 'sks person'" ,
496
  placeholder="sks person",
497
  value="sks person")
 
498
  seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
499
+ with gr.Accordion("Advanced Options", open=False):
500
+ with gr.Column():
501
+ cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
502
+ steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
503
+ negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity", value="low quality, blurry, unfinished, nudity")
504
  submit = gr.Button("Generate")
505
 
506
+ file_output = gr.File(label="Download Inverted Model", container=True, interactive=False)
507
 
508
 
509
 
 
517
  outputs=gallery)
518
 
519
 
520
+ demo.queue().launch()