Spaces: Running on Zero
amildravid4292 committed · Update app.py
app.py CHANGED
@@ -290,29 +290,29 @@ def sample_then_run(self):
 
 
 
+class CustomImageDataset(Dataset):
+    def __init__(self, images, transform=None):
+        self.images = images
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        image = self.images[idx]
+        if self.transform:
+            image = self.transform(image)
+        return image
+
+@spaces.GPU
+def invert(image, mask, pcs=10000, epochs=400, weight_decay=1e-10, lr=1e-1):
 
+    del unet
+    del network
+    unet, _, _, _, _ = load_models(device)
 
+    proj = torch.zeros(1, pcs).bfloat16().to(device)
+    network = LoRAw2w(proj, mean, std, v[:, :pcs],
                       unet,
                       rank=1,
                       multiplier=1.0,
@@ -320,87 +320,83 @@ def sample_then_run(self):
                       train_method="xattn-strict"
     ).to(device, torch.bfloat16)
 
+    ### load mask
+    mask = transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
+    mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:, 0, :, :].unsqueeze(1)
+    ### check if an actual mask was drawn, otherwise mask is just all ones
+    if torch.sum(mask) == 0:
+        mask = torch.ones((1, 1, 64, 64)).to(device).bfloat16()
 
+    ### single-image dataset
+    image_transforms = transforms.Compose([transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
                                            transforms.RandomCrop(512),
                                            transforms.ToTensor(),
                                            transforms.Normalize([0.5], [0.5])])
 
 
+    train_dataset = CustomImageDataset(image, transform=image_transforms)
+    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
 
+    ### optimizer
+    optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)
 
+    ### training loop
+    unet.train()
+    for epoch in tqdm.tqdm(range(epochs)):
+        for batch in train_dataloader:
+            ### prepare inputs
+            batch = batch.to(device).bfloat16()
+            latents = vae.encode(batch).latent_dist.sample()
+            latents = latents * 0.18215
+            noise = torch.randn_like(latents)
+            bsz = latents.shape[0]
 
+            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+            timesteps = timesteps.long()
+            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            text_input = tokenizer("sks person", padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+            text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
+
+            ### loss + sgd step
+            with network:
+                model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
+            loss = torch.nn.functional.mse_loss(mask * model_pred.float(), mask * noise.float(), reduction="mean")
+            optim.zero_grad()
+            loss.backward()
+            optim.step()
+
+    ### return optimized network
+    return network
 
 
+@spaces.GPU
+def run_inversion(dict, pcs, epochs, weight_decay, lr):
+    init_image = dict["image"].convert("RGB").resize((512, 512))
+    mask = dict["mask"].convert("RGB").resize((512, 512))
+    network = invert([init_image], mask, pcs, epochs, weight_decay, lr)
+    # sample an image
+    prompt = "sks person"
+    negative_prompt = "low quality, blurry, unfinished, nudity"
+    seed = 5
+    cfg = 3.0
+    steps = 25
+    image = inference(prompt, negative_prompt, cfg, steps, seed)
-        torch.save(network.proj, "model.pt" )
-        return image, "model.pt"
+    torch.save(network.proj, "model.pt")
+    return image, "model.pt"
 
 
+@spaces.GPU
+def file_upload(file):
-        proj = torch.load(file.name).to(device)
+    proj = torch.load(file.name).to(device)
 
-        #pad to 10000 Principal components to keep everything consistent
-        pcs = proj.shape[1]
-        padding = torch.zeros((1,10000-pcs)).to(device)
-        proj = torch.cat((proj, padding), 1)
+    # pad to 10000 principal components to keep everything consistent
+    pcs = proj.shape[1]
+    padding = torch.zeros((1, 10000 - pcs)).to(device)
+    proj = torch.cat((proj, padding), 1)
+    unet, _, _, _, _ = load_models(device)
 
 
+    network = LoRAw2w(proj, mean, std, v[:, :10000],
                       unet,
                       rank=1,
                       multiplier=1.0,
@@ -409,13 +405,13 @@ def sample_then_run(self):
     ).to(device, torch.bfloat16)
 
 
+    prompt = "sks person"
+    negative_prompt = "low quality, blurry, unfinished, nudity"
+    seed = 5
+    cfg = 3.0
+    steps = 25
+    image = inference(prompt, negative_prompt, cfg, steps, seed)
+    return image
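The `@spaces.GPU` decorators added throughout are the Hugging Face `spaces` package's ZeroGPU hook: the Space holds no GPU at rest, and a device is attached only while a decorated function runs. A minimal sketch of the pattern (`gpu_probe` is a made-up example, not from app.py):

```python
# ZeroGPU pattern: `spaces.GPU` is the real decorator; `gpu_probe` is a made-up example.
import spaces
import torch

@spaces.GPU  # a GPU is attached for the duration of this call, then released
def gpu_probe():
    return torch.zeros(1, device="cuda").device

print(gpu_probe())  # on a ZeroGPU Space this prints: cuda:0
```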
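`CustomImageDataset` wraps a plain Python list of PIL images so the single uploaded photo can flow through a standard `DataLoader`; because the transform resizes the short side to 512 and then applies `RandomCrop(512)`, each of the 400 epochs sees a slightly different crop of the same image. A self-contained sketch of that behavior (the blank 640x480 image is a stand-in for the upload):

```python
# Stand-in demo of the single-image dataset (blank image instead of a real upload).
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

class CustomImageDataset(Dataset):
    def __init__(self, images, transform=None):
        self.images = images              # a plain Python list of PIL images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        return image

image_transforms = transforms.Compose([
    transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.RandomCrop(512),           # a fresh crop every epoch
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),   # [0, 1] -> [-1, 1], the range SD's VAE expects
])

img = Image.new("RGB", (640, 480))
loader = DataLoader(CustomImageDataset([img], transform=image_transforms),
                    batch_size=1, shuffle=True)
print(next(iter(loader)).shape)           # torch.Size([1, 3, 512, 512])
```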
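The loop body in `invert` is the standard epsilon-prediction diffusion objective, restricted to the drawn region: the VAE encodes the 512x512 crop into 4x64x64 latents scaled by 0.18215 (hence the mask resize to 64x64, i.e. 512/8), the scheduler noises them at a random timestep, and the UNet's prediction is compared against the injected noise under the mask. A sketch of just that arithmetic with stand-in tensors; `DDPMScheduler` is the real diffusers class, while the VAE and UNet outputs are faked with random tensors:

```python
# Masked denoising loss with stand-in tensors; only DDPMScheduler is real here.
import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

latents = 0.18215 * torch.randn(1, 4, 64, 64)   # stand-in for vae.encode(...).latent_dist.sample()
noise = torch.randn_like(latents)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,)).long()
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

mask = torch.ones(1, 1, 64, 64)                 # all-ones mask = ordinary full-image loss
model_pred = torch.randn_like(latents)          # stand-in for unet(...).sample

loss = torch.nn.functional.mse_loss(mask * model_pred, mask * noise, reduction="mean")
print(loss.item())
```

One consequence of masking this way: `reduction="mean"` divides by the full latent size rather than the mask area, so a small mask also scales the loss (and the effective gradient) down.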
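`file_upload` zero-pads whatever projection a user uploads to the fixed 10000-principal-component basis, so a single `LoRAw2w(proj, mean, std, v[:, :10000], ...)` construction handles every file; a zero coefficient simply contributes nothing along that PC direction. The padding step in isolation (the 3000-component input is a made-up example):

```python
# Zero-padding an uploaded projection to the fixed 10000-PC basis (3000 is made up).
import torch

proj = torch.randn(1, 3000)              # e.g. a model saved with pcs=3000
pcs = proj.shape[1]
padding = torch.zeros((1, 10000 - pcs))
proj = torch.cat((proj, padding), dim=1)
print(proj.shape)                        # torch.Size([1, 10000])
```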
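`run_inversion` reads `dict["image"]` and `dict["mask"]`, the payload shape of Gradio 3.x's sketch-enabled image input. The Space's UI code is not part of this diff, so the wiring below is an illustrative assumption only, with every component name and slider range invented:

```python
# Hypothetical Gradio 3.x wiring for run_inversion; NOT the Space's actual UI code.
# gr.Image(tool="sketch", type="pil") hands the callback {"image": ..., "mask": ...}.
import gradio as gr
# run_inversion is the function defined in app.py above

with gr.Blocks() as demo:
    input_image = gr.Image(source="upload", tool="sketch", type="pil")
    pcs = gr.Slider(1, 10000, value=10000, step=1, label="# principal components")
    epochs = gr.Slider(1, 1000, value=400, step=1, label="epochs")
    weight_decay = gr.Number(value=1e-10, label="weight decay")
    lr = gr.Number(value=1e-1, label="learning rate")
    out_image = gr.Image(label="sample")
    out_file = gr.File(label="inverted model")
    gr.Button("Invert").click(
        run_inversion,
        inputs=[input_image, pcs, epochs, weight_decay, lr],
        outputs=[out_image, out_file],
    )

demo.launch()
```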