amildravid4292 commited on
Commit
d8a6bcb
·
verified ·
1 Parent(s): aa1da5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -99
app.py CHANGED
@@ -60,7 +60,6 @@ def sample_model():
60
 
61
 
62
 
63
-
64
  @torch.no_grad()
65
  def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
66
  global device
@@ -140,7 +139,7 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
140
  large_pad = torch.cat((large, padding), 1)
141
 
142
 
143
- edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*2e6*large_pad
144
 
145
  generator = generator.manual_seed(seed)
146
  latents = torch.randn(
@@ -200,22 +199,18 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
200
  return image
201
 
202
 
203
-
204
  def sample_then_run():
205
  sample_model()
206
  prompt = "sks person"
207
- negative_prompt = "low quality, blurry, unfinished, nudity"
208
  seed = 5
209
  cfg = 3.0
210
  steps = 50
211
  image = inference( prompt, negative_prompt, cfg, steps, seed)
212
- torch.save(network.proj, "sampled_model.pt" )
213
- return image, "sampled_model.pt"
214
-
215
-
216
 
217
 
218
- #directions
219
  global young
220
  global pointy
221
  global wavy
@@ -226,8 +221,7 @@ young = debias(young, "Male", df, pinverse, device)
226
  young = debias(young, "Pointy_Nose", df, pinverse, device)
227
  young = debias(young, "Wavy_Hair", df, pinverse, device)
228
  young = debias(young, "Chubby", df, pinverse, device)
229
- young_max = torch.max(proj@young[0]/(torch.norm(young))**2).item()
230
- young_min = torch.min(proj@young[0]/(torch.norm(young))**2).item()
231
 
232
  pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
233
  pointy = debias(pointy, "Young", df, pinverse, device)
@@ -235,8 +229,7 @@ pointy = debias(pointy, "Male", df, pinverse, device)
235
  pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
236
  pointy = debias(pointy, "Chubby", df, pinverse, device)
237
  pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)
238
- pointy_max = torch.max(proj@pointy[0]/(torch.norm(pointy))**2).item()
239
- pointy_min = torch.min(proj@pointy[0]/(torch.norm(pointy))**2).item()
240
 
241
 
242
  wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
@@ -245,16 +238,24 @@ wavy = debias(wavy, "Male", df, pinverse, device)
245
  wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
246
  wavy = debias(wavy, "Chubby", df, pinverse, device)
247
  wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)
248
- wavy_max = torch.max(proj@wavy[0]/(torch.norm(wavy))**2).item()
249
- wavy_min = torch.min(proj@wavy[0]/(torch.norm(wavy))**2).item()
250
 
251
- large = get_direction(df, "Chubby", pinverse, 1000, device)
 
252
  large = debias(large, "Male", df, pinverse, device)
253
  large = debias(large, "Young", df, pinverse, device)
254
  large = debias(large, "Pointy_Nose", df, pinverse, device)
255
  large = debias(large, "Wavy_Hair", df, pinverse, device)
256
- large_max = torch.max(proj@large[0]/(torch.norm(large))**2).item()
257
- large_min = torch.min(proj@large[0]/(torch.norm(large))**2).item()
 
 
 
 
 
 
 
 
 
258
 
259
  class CustomImageDataset(Dataset):
260
  def __init__(self, images, transform=None):
@@ -352,13 +353,16 @@ def run_inversion(dict, pcs, epochs, weight_decay,lr):
352
 
353
  #sample an image
354
  prompt = "sks person"
355
- negative_prompt = "low quality, blurry, unfinished, nudity"
356
  seed = 5
357
  cfg = 3.0
358
  steps = 50
359
  image = inference( prompt, negative_prompt, cfg, steps, seed)
360
- torch.save(network.proj, "inverted_model.pt" )
361
- return image, "inverted_model.pt"
 
 
 
362
 
363
 
364
 
@@ -367,6 +371,7 @@ def file_upload(file):
367
  del unet
368
  global network
369
  global device
 
370
 
371
 
372
 
@@ -379,18 +384,18 @@ def file_upload(file):
379
 
380
  unet, _, _, _, _ = load_models(device)
381
 
382
-
383
- network = LoRAw2w( proj, mean, std, v[:, :pcs],
384
  unet,
385
  rank=1,
386
  multiplier=1.0,
387
  alpha=27.0,
388
  train_method="xattn-strict"
389
  ).to(device, torch.bfloat16)
390
-
391
 
392
  prompt = "sks person"
393
- negative_prompt = "low quality, blurry, unfinished, nudity"
394
  seed = 5
395
  cfg = 3.0
396
  steps = 50
@@ -401,6 +406,11 @@ def file_upload(file):
401
 
402
 
403
 
 
 
 
 
 
404
  intro = """
405
  <div style="display: flex;align-items: center;justify-content: center">
406
  <h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">weights2weights</h1>
@@ -420,101 +430,100 @@ intro = """
420
 
421
  with gr.Blocks(css="style.css") as demo:
422
  gr.HTML(intro)
423
- with gr.Tab("Sampling Models + Editing"):
424
- with gr.Row():
425
- with gr.Column():
426
- gallery1 = gr.Image(label="Identity from Sampled or Uploaded Model")
427
- sample = gr.Button("Sample New Model")
428
- file_input = gr.File(label="Upload Model", container=True)
429
- file_input.change(fn=file_upload, inputs=file_input, outputs = gallery1)
 
 
 
 
 
 
 
 
 
 
 
 
 
430
 
431
 
432
 
433
 
434
 
435
- with gr.Column():
436
- gallery2 = gr.Image(label="Identity from Edited Model")
437
- prompt = gr.Textbox(label="Prompt",
438
  info="Make sure to include 'sks person'" ,
439
  placeholder="sks person",
440
  value="sks person")
441
- seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
442
- with gr.Row():
443
- a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
444
- a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
445
- with gr.Row():
446
- a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
447
- a4 = gr.Slider(label="- placeholder +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
448
-
449
 
450
- with gr.Accordion("Advanced Options", open=False):
451
- with gr.Column():
452
- cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
453
- steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
454
- injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
455
- negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity", value="low quality, blurry, unfinished, nudity")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
 
457
- submit = gr.Button("Generate")
458
 
 
459
  file_output = gr.File(label="Download Sampled Model", container=True, interactive=False)
 
460
 
461
 
462
 
463
-
 
 
 
464
 
465
 
466
  sample.click(fn=sample_then_run, outputs=[gallery1, file_output])
467
 
468
- submit.click(fn=edit_inference,
469
- inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4],
 
 
 
470
  outputs=gallery2)
471
-
472
-
473
-
474
- with gr.Tab("Inversion"):
475
- with gr.Row():
476
- with gr.Column():
477
- input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload image and draw to define mask",
478
- height=512, width=512, brush_color='#00FFFF', mask_opacity=0.6)
479
-
480
-
481
- lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
482
- pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
483
- with gr.Accordion("Advanced Options", open=False):
484
- with gr.Column():
485
- epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)
486
- weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
487
-
488
-
489
 
490
- invert_button = gr.Button("Invert")
491
-
492
- with gr.Column():
493
- gallery = gr.Image(label="Sample from Inverted Model", height=512, width=512)
494
- prompt = gr.Textbox(label="Prompt",
495
- info="Make sure to include 'sks person'" ,
496
- placeholder="sks person",
497
- value="sks person")
498
- seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
499
- with gr.Accordion("Advanced Options", open=False):
500
- with gr.Column():
501
- cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
502
- steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
503
- negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity", value="low quality, blurry, unfinished, nudity")
504
- submit = gr.Button("Generate")
505
-
506
- file_output = gr.File(label="Download Inverted Model", container=True, interactive=False)
507
 
508
-
509
 
510
-
511
- invert_button.click(fn=run_inversion,
512
- inputs=[input_image, pcs, epochs, weight_decay,lr],
513
- outputs = [gallery, file_output])
514
-
515
- submit.click(fn=inference,
516
- inputs=[prompt, negative_prompt, cfg, steps, seed,],
517
- outputs=gallery)
518
-
519
 
520
- demo.queue().launch()
 
 
60
 
61
 
62
 
 
63
  @torch.no_grad()
64
  def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
65
  global device
 
139
  large_pad = torch.cat((large, padding), 1)
140
 
141
 
142
+ edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*8e5*large_pad
143
 
144
  generator = generator.manual_seed(seed)
145
  latents = torch.randn(
 
199
  return image
200
 
201
 
 
202
  def sample_then_run():
203
  sample_model()
204
  prompt = "sks person"
205
+ negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
206
  seed = 5
207
  cfg = 3.0
208
  steps = 50
209
  image = inference( prompt, negative_prompt, cfg, steps, seed)
210
+ torch.save(network.proj, "model.pt" )
211
+ return image, "model.pt"
 
 
212
 
213
 
 
214
  global young
215
  global pointy
216
  global wavy
 
221
  young = debias(young, "Pointy_Nose", df, pinverse, device)
222
  young = debias(young, "Wavy_Hair", df, pinverse, device)
223
  young = debias(young, "Chubby", df, pinverse, device)
224
+
 
225
 
226
  pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
227
  pointy = debias(pointy, "Young", df, pinverse, device)
 
229
  pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
230
  pointy = debias(pointy, "Chubby", df, pinverse, device)
231
  pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)
232
+
 
233
 
234
 
235
  wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
 
238
  wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
239
  wavy = debias(wavy, "Chubby", df, pinverse, device)
240
  wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)
 
 
241
 
242
+
243
+ large = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device)
244
  large = debias(large, "Male", df, pinverse, device)
245
  large = debias(large, "Young", df, pinverse, device)
246
  large = debias(large, "Pointy_Nose", df, pinverse, device)
247
  large = debias(large, "Wavy_Hair", df, pinverse, device)
248
+ large = debias(large, "Mustache", df, pinverse, device)
249
+ large = debias(large, "No_Beard", df, pinverse, device)
250
+ large = debias(large, "Sideburns", df, pinverse, device)
251
+ large = debias(large, "Big_Nose", df, pinverse, device)
252
+ large = debias(large, "Big_Lips", df, pinverse, device)
253
+ large = debias(large, "Black_Hair", df, pinverse, device)
254
+ large = debias(large, "Brown_Hair", df, pinverse, device)
255
+ large = debias(large, "Pale_Skin", df, pinverse, device)
256
+ large = debias(large, "Heavy_Makeup", df, pinverse, device)
257
+
258
+
259
 
260
  class CustomImageDataset(Dataset):
261
  def __init__(self, images, transform=None):
 
353
 
354
  #sample an image
355
  prompt = "sks person"
356
+ negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
357
  seed = 5
358
  cfg = 3.0
359
  steps = 50
360
  image = inference( prompt, negative_prompt, cfg, steps, seed)
361
+ torch.save(network.proj, "model.pt" )
362
+ return image, "model.pt"
363
+
364
+
365
+
366
 
367
 
368
 
 
371
  del unet
372
  global network
373
  global device
374
+
375
 
376
 
377
 
 
384
 
385
  unet, _, _, _, _ = load_models(device)
386
 
387
+
388
+ network = LoRAw2w( proj, mean, std, v[:, :10000],
389
  unet,
390
  rank=1,
391
  multiplier=1.0,
392
  alpha=27.0,
393
  train_method="xattn-strict"
394
  ).to(device, torch.bfloat16)
395
+
396
 
397
  prompt = "sks person"
398
+ negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
399
  seed = 5
400
  cfg = 3.0
401
  steps = 50
 
406
 
407
 
408
 
409
+
410
+
411
+
412
+
413
+
414
  intro = """
415
  <div style="display: flex;align-items: center;justify-content: center">
416
  <h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">weights2weights</h1>
 
430
 
431
  with gr.Blocks(css="style.css") as demo:
432
  gr.HTML(intro)
433
+ with gr.Row():
434
+ with gr.Column():
435
+ gr.Markdown("""<div style="text-align: justify;"> Click below to sample an identity-encoding model.""")
436
+ sample = gr.Button("Sample New Model")
437
+ gr.Markdown("""<div style="text-align: justify;"> Or upload an image below and click \"invert\". You can also optionally draw over the face to define a mask.""")
438
+ input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload image and draw to define mask",
439
+ height=512, width=512, brush_color='#00FFFF', mask_opacity=0.6)
440
+
441
+ lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
442
+ pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
443
+ with gr.Accordion("Advanced Options", open=False):
444
+ with gr.Column():
445
+ epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)
446
+ weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
447
+
448
+ invert_button = gr.Button("Invert")
449
+
450
+ gr.Markdown("""<div style="text-align: justify;"> Or you can upload a model below downloaded from this demo.""")
451
+
452
+ file_input = gr.File(label="Upload Model", container=True)
453
 
454
 
455
 
456
 
457
 
458
+ with gr.Column():
459
+ gallery1 = gr.Image(label="Identity from Original Model", interactive=False)
460
+ prompt1 = gr.Textbox(label="Prompt",
461
  info="Make sure to include 'sks person'" ,
462
  placeholder="sks person",
463
  value="sks person")
464
+ seed1 = gr.Number(value=5, label="Seed", precision=0, interactive=True)
 
 
 
 
 
 
 
465
 
466
+ with gr.Accordion("Advanced Options", open=False):
467
+ with gr.Column():
468
+ cfg1= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
469
+ steps1 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
470
+ negative_prompt1 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
471
+
472
+ submit1 = gr.Button("Generate")
473
+
474
+
475
+
476
+ with gr.Column():
477
+ gallery2 = gr.Image(label="Identity from Edited Model", interactive=False )
478
+ with gr.Row():
479
+ a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
480
+ a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
481
+ with gr.Row():
482
+ a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
483
+ a4 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
484
+ prompt2 = gr.Textbox(label="Prompt",
485
+ info="Make sure to include 'sks person'" ,
486
+ placeholder="sks person",
487
+ value="sks person")
488
+ seed2 = gr.Number(value=5, label="Seed", precision=0, interactive=True)
489
+ with gr.Accordion("Advanced Options", open=False):
490
+ with gr.Column():
491
+ cfg2 = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
492
+ steps2 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
493
+ injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
494
+ negative_prompt2 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
495
+
496
+ submit2 = gr.Button("Generate")
497
+
498
+
499
 
500
+ gr.Markdown("""<div style="text-align: justify;"> After sampling a new model or inverting, you can download the model below.""")
501
 
502
+ with gr.Row():
503
  file_output = gr.File(label="Download Sampled Model", container=True, interactive=False)
504
+
505
 
506
 
507
 
508
+
509
+ invert_button.click(fn=run_inversion,
510
+ inputs=[input_image, pcs, epochs, weight_decay,lr],
511
+ outputs = [gallery1, file_output])
512
 
513
 
514
  sample.click(fn=sample_then_run, outputs=[gallery1, file_output])
515
 
516
+ submit1.click(fn=inference,
517
+ inputs=[prompt1, negative_prompt1, cfg1, steps1, seed1],
518
+ outputs=gallery1)
519
+ submit2.click(fn=edit_inference,
520
+ inputs=[prompt2, negative_prompt2, cfg2, steps2, seed2, injection_step, a1, a2, a3, a4],
521
  outputs=gallery2)
522
+ file_input.change(fn=file_upload, inputs=file_input, outputs = gallery1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
 
 
525
 
526
+
 
 
 
 
 
 
 
 
527
 
528
+
529
+ demo.queue().launch()