amildravid4292 commited on
Commit
cf6ad0d
·
verified ·
1 Parent(s): 51836fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -93
app.py CHANGED
@@ -290,29 +290,29 @@ def sample_then_run(self):
290
 
291
 
292
 
293
- class CustomImageDataset(Dataset):
294
- def __init__(self, images, transform=None):
295
- self.images = images
296
- self.transform = transform
297
-
298
- def __len__(self):
299
- return len(self.images)
300
-
301
- def __getitem__(self, idx):
302
- image = self.images[idx]
303
- if self.transform:
304
- image = self.transform(image)
305
- return image
306
-
307
- @spaces.GPU
308
- def invert(self, image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
309
 
310
- del unet
311
- del network
312
- unet, _, _, _, _ = load_models(device)
313
 
314
- proj = torch.zeros(1,pcs).bfloat16().to(device)
315
- network = LoRAw2w( proj, mean, std, v[:, :pcs],
316
  unet,
317
  rank=1,
318
  multiplier=1.0,
@@ -320,87 +320,83 @@ def sample_then_run(self):
320
  train_method="xattn-strict"
321
  ).to(device, torch.bfloat16)
322
 
323
- ### load mask
324
- mask = transforms.Resize((64,64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
325
- mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:,0,:,:].unsqueeze(1)
326
- ### check if an actual mask was draw, otherwise mask is just all ones
327
- if torch.sum(mask) == 0:
328
- mask = torch.ones((1,1,64,64)).to(device).bfloat16()
329
 
330
- ### single image dataset
331
- image_transforms = transforms.Compose([transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
332
  transforms.RandomCrop(512),
333
  transforms.ToTensor(),
334
  transforms.Normalize([0.5], [0.5])])
335
 
336
 
337
- train_dataset = CustomImageDataset(image, transform=image_transforms)
338
- train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
339
 
340
- ### optimizer
341
- optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)
342
 
343
- ### training loop
344
- unet.train()
345
- for epoch in tqdm.tqdm(range(epochs)):
346
- for batch in train_dataloader:
347
- ### prepare inputs
348
- batch = batch.to(device).bfloat16()
349
- latents = vae.encode(batch).latent_dist.sample()
350
- latents = latents*0.18215
351
- noise = torch.randn_like(latents)
352
- bsz = latents.shape[0]
353
 
354
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
355
- timesteps = timesteps.long()
356
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
357
- text_input = tokenizer("sks person", padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
358
- text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
359
-
360
- ### loss + sgd step
361
- with network:
362
- model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
363
- loss = torch.nn.functional.mse_loss(mask*model_pred.float(), mask*noise.float(), reduction="mean")
364
- optim.zero_grad()
365
- loss.backward()
366
- optim.step()
367
-
368
- ### return optimized network
369
- return network
370
 
371
 
372
- @spaces.GPU
373
- def run_inversion(self, dict, pcs, epochs, weight_decay,lr):
374
- init_image = dict["image"].convert("RGB").resize((512, 512))
375
- mask = dict["mask"].convert("RGB").resize((512, 512))
376
- network = invert([init_image], mask, pcs, epochs, weight_decay,lr)
377
-
378
-
379
- #sample an image
380
- prompt = "sks person"
381
- negative_prompt = "low quality, blurry, unfinished, nudity"
382
- seed = 5
383
- cfg = 3.0
384
- steps = 25
385
- image = inference( prompt, negative_prompt, cfg, steps, seed)
386
- torch.save(network.proj, "model.pt" )
387
- return image, "model.pt"
388
 
389
 
390
- @spaces.GPU
391
- def file_upload(self, file):
392
-
393
- proj = torch.load(file.name).to(device)
394
-
395
- #pad to 10000 Principal components to keep everything consistent
396
- pcs = proj.shape[1]
397
- padding = torch.zeros((1,10000-pcs)).to(device)
398
- proj = torch.cat((proj, padding), 1)
399
 
400
- unet, _, _, _, _ = load_models(device)
 
 
 
 
401
 
402
 
403
- network = LoRAw2w( proj, mean, std, v[:, :10000],
404
  unet,
405
  rank=1,
406
  multiplier=1.0,
@@ -409,13 +405,13 @@ def sample_then_run(self):
409
  ).to(device, torch.bfloat16)
410
 
411
 
412
- prompt = "sks person"
413
- negative_prompt = "low quality, blurry, unfinished, nudity"
414
- seed = 5
415
- cfg = 3.0
416
- steps = 25
417
- image = inference( prompt, negative_prompt, cfg, steps, seed)
418
- return image
419
 
420
 
421
 
 
290
 
291
 
292
 
293
+ class CustomImageDataset(Dataset):
294
+ def __init__(self, images, transform=None):
295
+ self.images = images
296
+ self.transform = transform
297
+
298
+ def __len__(self):
299
+ return len(self.images)
300
+
301
+ def __getitem__(self, idx):
302
+ image = self.images[idx]
303
+ if self.transform:
304
+ image = self.transform(image)
305
+ return image
306
+
307
+ @spaces.GPU
308
+ def invert(self, image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
309
 
310
+ del unet
311
+ del network
312
+ unet, _, _, _, _ = load_models(device)
313
 
314
+ proj = torch.zeros(1,pcs).bfloat16().to(device)
315
+ network = LoRAw2w( proj, mean, std, v[:, :pcs],
316
  unet,
317
  rank=1,
318
  multiplier=1.0,
 
320
  train_method="xattn-strict"
321
  ).to(device, torch.bfloat16)
322
 
323
+ ### load mask
324
+ mask = transforms.Resize((64,64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
325
+ mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:,0,:,:].unsqueeze(1)
326
+ ### check if an actual mask was draw, otherwise mask is just all ones
327
+ if torch.sum(mask) == 0:
328
+ mask = torch.ones((1,1,64,64)).to(device).bfloat16()
329
 
330
+ ### single image dataset
331
+ image_transforms = transforms.Compose([transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
332
  transforms.RandomCrop(512),
333
  transforms.ToTensor(),
334
  transforms.Normalize([0.5], [0.5])])
335
 
336
 
337
+ train_dataset = CustomImageDataset(image, transform=image_transforms)
338
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
339
 
340
+ ### optimizer
341
+ optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)
342
 
343
+ ### training loop
344
+ unet.train()
345
+ for epoch in tqdm.tqdm(range(epochs)):
346
+ for batch in train_dataloader:
347
+ ### prepare inputs
348
+ batch = batch.to(device).bfloat16()
349
+ latents = vae.encode(batch).latent_dist.sample()
350
+ latents = latents*0.18215
351
+ noise = torch.randn_like(latents)
352
+ bsz = latents.shape[0]
353
 
354
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
355
+ timesteps = timesteps.long()
356
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
357
+ text_input = tokenizer("sks person", padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
358
+ text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
359
+
360
+ ### loss + sgd step
361
+ with network:
362
+ model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
363
+ loss = torch.nn.functional.mse_loss(mask*model_pred.float(), mask*noise.float(), reduction="mean")
364
+ optim.zero_grad()
365
+ loss.backward()
366
+ optim.step()
367
+
368
+ ### return optimized network
369
+ return network
370
 
371
 
372
+ @spaces.GPU
373
+ def run_inversion(self, dict, pcs, epochs, weight_decay,lr):
374
+ init_image = dict["image"].convert("RGB").resize((512, 512))
375
+ mask = dict["mask"].convert("RGB").resize((512, 512))
376
+ network = invert([init_image], mask, pcs, epochs, weight_decay,lr)
377
+ #sample an image
378
+ prompt = "sks person"
379
+ negative_prompt = "low quality, blurry, unfinished, nudity"
380
+ seed = 5
381
+ cfg = 3.0
382
+ steps = 25
383
+ image = inference( prompt, negative_prompt, cfg, steps, seed)
384
+ torch.save(network.proj, "model.pt" )
385
+ return image, "model.pt"
 
 
386
 
387
 
388
+ @spaces.GPU
389
+ def file_upload(self, file):
390
+ proj = torch.load(file.name).to(device)
 
 
 
 
 
 
391
 
392
+ #pad to 10000 Principal components to keep everything consistent
393
+ pcs = proj.shape[1]
394
+ padding = torch.zeros((1,10000-pcs)).to(device)
395
+ proj = torch.cat((proj, padding), 1)
396
+ unet, _, _, _, _ = load_models(device)
397
 
398
 
399
+ network = LoRAw2w( proj, mean, std, v[:, :10000],
400
  unet,
401
  rank=1,
402
  multiplier=1.0,
 
405
  ).to(device, torch.bfloat16)
406
 
407
 
408
+ prompt = "sks person"
409
+ negative_prompt = "low quality, blurry, unfinished, nudity"
410
+ seed = 5
411
+ cfg = 3.0
412
+ steps = 25
413
+ image = inference( prompt, negative_prompt, cfg, steps, seed)
414
+ return image
415
 
416
 
417