import os
import sys
import gc
import uuid
import warnings

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import tqdm
import gradio as gr
import spaces

sys.path.append(os.path.abspath(os.path.join("", "..")))
warnings.filterwarnings("ignore")

from editing import get_direction, debias
from sampling import sample_weights
from lora_w2w import LoRAw2w
from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from huggingface_hub import snapshot_download

models_path = snapshot_download(repo_id="Snapchat/w2w")

device = "cuda"
pretrained_model_name_or_path = "stablediffusionapi/realistic-vision-v51"
revision = None
weight_dtype = torch.bfloat16

# Load the scheduler from the full pipeline, then free the pipeline itself.
pipe = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    torch_dtype=torch.float16,
    safety_checker=None,
    requires_safety_checker=False,
).to(device)
noise_scheduler = pipe.scheduler
del pipe

# Load tokenizer and models.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
)
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae", revision=revision
)
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet", revision=revision
)

# Freeze the base model; only the w2w projection coefficients are ever trained.
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.to(device, dtype=weight_dtype)
vae.to(device, dtype=weight_dtype)
text_encoder.to(device, dtype=weight_dtype)

# Statistics of the weights2weights space: mean/std used to normalize the
# flattened LoRA weights, the PCA basis V, and PC projections of the
# identity dataset.
mean = torch.load(f"{models_path}/files/mean.pt", map_location=torch.device("cpu")).bfloat16().to(device)
std = torch.load(f"{models_path}/files/std.pt", map_location=torch.device("cpu")).bfloat16().to(device)
v = torch.load(f"{models_path}/files/V.pt", map_location=torch.device("cpu")).bfloat16().to(device)
proj = torch.load(f"{models_path}/files/proj_1000pc.pt", map_location=torch.device("cpu")).bfloat16().to(device)
df = torch.load(f"{models_path}/files/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device("cpu")).bfloat16().to(device)

# Semantic editing directions, each debiased against correlated attributes.
young = get_direction(df, "Young", pinverse, 1000, device)
young = debias(young, "Male", df, pinverse, device)
young = debias(young, "Pointy_Nose", df, pinverse, device)
young = debias(young, "Wavy_Hair", df, pinverse, device)
young = debias(young, "Chubby", df, pinverse, device)
young = debias(young, "No_Beard", df, pinverse, device)
young = debias(young, "Mustache", df, pinverse, device)

pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
pointy = debias(pointy, "Young", df, pinverse, device)
pointy = debias(pointy, "Male", df, pinverse, device)
pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
pointy = debias(pointy, "Chubby", df, pinverse, device)
pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)
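# Illustrative note: each direction above is a single row vector in the
# 1000-dimensional PC-coefficient space, so applying an edit is just a scaled
# addition to an identity's coefficients (see edit_inference below), e.g.:
#
#   edited = weights + strength * young_pad   # strength ~ a1 * 1e6 below
#
# The debias() calls project out correlated attributes so that moving along
# "Young", for example, does not also drift in gender or hairstyle.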
wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
wavy = debias(wavy, "Young", df, pinverse, device)
wavy = debias(wavy, "Male", df, pinverse, device)
wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
wavy = debias(wavy, "Chubby", df, pinverse, device)
wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)

thick = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device)
thick = debias(thick, "Male", df, pinverse, device)
thick = debias(thick, "Young", df, pinverse, device)
thick = debias(thick, "Pointy_Nose", df, pinverse, device)
thick = debias(thick, "Wavy_Hair", df, pinverse, device)
thick = debias(thick, "Mustache", df, pinverse, device)
thick = debias(thick, "No_Beard", df, pinverse, device)
thick = debias(thick, "Sideburns", df, pinverse, device)
thick = debias(thick, "Big_Nose", df, pinverse, device)
thick = debias(thick, "Big_Lips", df, pinverse, device)
thick = debias(thick, "Black_Hair", df, pinverse, device)
thick = debias(thick, "Brown_Hair", df, pinverse, device)
thick = debias(thick, "Pale_Skin", df, pinverse, device)
thick = debias(thick, "Heavy_Makeup", df, pinverse, device)


@torch.no_grad()
@spaces.GPU
def sample_then_run(net):
    device = "cuda"
    # Per-PC mean and standard deviation across the identity dataset.
    m = torch.mean(proj, 0)
    standev = torch.std(proj, 0)

    # Sample a new identity: draw only the first 1000 PCs, leave the rest zero.
    sample = torch.zeros([1, 10000]).to(device)
    for i in range(1000):
        sample[0, i] = torch.normal(m[i], standev[i], (1, 1))
    net = "model_" + str(uuid.uuid4())[:4] + ".pt"
    torch.save(sample, net)

    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
    seed = 5
    cfg = 3.0
    steps = 25
    image = inference(net, prompt, negative_prompt, cfg, steps, seed)
    return net, net, image


@torch.no_grad()
@spaces.GPU()
def inference(net, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    device = "cuda"
    # mean, std, and v are module-level tensors already on the GPU.
    weights = torch.load(net).to(device)
    network = LoRAw2w(
        weights, mean, std, v[:, :10000],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    generator = torch.Generator(device=device).manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).bfloat16()

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        with network:
            noise_pred = unet(
                latent_model_input, t,
                encoder_hidden_states=text_embeddings,
                timestep_cond=None,
            ).sample
        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # Decode latents to an image.
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))
    return image
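# Illustrative usage outside Gradio (values here are examples, not part of
# the app): sample a fresh identity, then render it with a new prompt.
#
#   net, _, img = sample_then_run(None)
#   img = inference(net, "sks person wearing a red hat",
#                   "low quality, blurry, unfinished",
#                   guidance_scale=3.0, ddim_steps=25, seed=5)
#   img.save("sample.png")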
@torch.no_grad()
@spaces.GPU()
def edit_inference(net, prompt, negative_prompt, guidance_scale, ddim_steps, seed,
                   start_noise, a1, a2, a3, a4):
    device = "cuda"
    weights = torch.load(net).to(device)
    network = LoRAw2w(
        weights, mean, std, v[:, :10000],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    # Pad the edit directions to the same number of PCs as the identity weights.
    pcs_original = weights.shape[1]
    pcs_edits = young.shape[1]
    padding = torch.zeros((1, pcs_original - pcs_edits)).to(device)
    young_pad = torch.cat((young.to(device), padding), 1)
    pointy_pad = torch.cat((pointy.to(device), padding), 1)
    wavy_pad = torch.cat((wavy.to(device), padding), 1)
    thick_pad = torch.cat((thick.to(device), padding), 1)
    edited_weights = (
        weights
        + a1 * 1e6 * young_pad
        + a2 * 1e6 * pointy_pad
        + a3 * 1e6 * wavy_pad
        + a4 * 2e6 * thick_pad
    )

    generator = torch.Generator(device=device).manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).bfloat16()

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        # Keep the original identity weights during the early, high-noise steps;
        # switch to the edited weights once t falls to start_noise.
        if t <= start_noise:
            network.proj = torch.nn.Parameter(edited_weights)
            network.reset()
        with network:
            noise_pred = unet(
                latent_model_input, t,
                encoder_hidden_states=text_embeddings,
                timestep_cond=None,
            ).sample
        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # Decode latents to an image.
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))
    return net, image
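# Illustrative usage (values are examples): re-render the same identity at a
# few edit strengths. start_noise is on the scheduler's timestep scale
# (roughly 0-1000); larger values apply the edit earlier in denoising.
#
#   for strength in (-0.5, 0.0, 0.5):
#       _, img = edit_inference(net, "sks person", "low quality, blurry",
#                               3.0, 25, 5, 800,
#                               a1=strength, a2=0, a3=0, a4=0)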
class CustomImageDataset(Dataset):
    """Minimal dataset wrapping a list of PIL images."""

    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        return image


@spaces.GPU
def invert(image, mask, pcs=10000, epochs=400, weight_decay=1e-10, lr=1e-1):
    device = "cuda"
    weights = torch.zeros(1, pcs).bfloat16().to(device)
    network = LoRAw2w(
        weights, mean, std, v[:, :pcs],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    ### load the mask and downsample it to the 64x64 latent resolution
    mask = transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
    mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:, 0, :, :].unsqueeze(1)

    ### check whether an actual mask was drawn; otherwise use an all-ones mask
    if torch.sum(mask) == 0:
        mask = torch.ones((1, 1, 64, 64)).to(device).bfloat16()

    ### single-image dataset
    image_transforms = transforms.Compose([
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.RandomCrop(512),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    train_dataset = CustomImageDataset(image, transform=image_transforms)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

    ### optimizer: only the PC coefficients in network.proj are trainable
    optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)

    ### training loop
    unet.train()
    for epoch in tqdm.tqdm(range(epochs)):
        for batch in train_dataloader:
            ### prepare inputs
            batch = batch.to(device).bfloat16()
            latents = vae.encode(batch).latent_dist.sample()
            latents = latents * 0.18215
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
            )
            timesteps = timesteps.long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
            text_input = tokenizer(
                "sks person",
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

            ### masked diffusion loss + optimizer step
            with network:
                model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
            loss = torch.nn.functional.mse_loss(
                mask * model_pred.float(), mask * noise.float(), reduction="mean"
            )
            optim.zero_grad()
            loss.backward()
            optim.step()

    ### pad the learned coefficients to 10000 PCs
    pcs_original = weights.shape[1]
    padding = torch.zeros((1, 10000 - pcs_original)).to(device)
    weights = network.proj.detach()
    weights = torch.cat((weights, padding), 1)

    net = "model_" + str(uuid.uuid4())[:4] + ".pt"
    torch.save(weights, net)
    return net


@spaces.GPU
def run_inversion(net, dict, pcs, epochs, weight_decay, lr):
    init_image = dict["background"].convert("RGB").resize((512, 512))
    mask = dict["layers"][0].convert("RGB").resize((512, 512))
    net = invert([init_image], mask, pcs, epochs, weight_decay, lr)

    # Sample an image with the newly inverted identity.
    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity"
    seed = 5
    cfg = 3.0
    steps = 25
    image = inference(net, prompt, negative_prompt, cfg, steps, seed)
    return net, net, image


@spaces.GPU
def file_upload(file, net):
    device = "cuda"
    weights = torch.load(file.name).to(device)

    # Pad to 10000 principal components to keep everything consistent.
    pcs = weights.shape[1]
    padding = torch.zeros((1, 10000 - pcs)).to(device)
    weights = torch.cat((weights, padding), 1)

    net = "model_" + str(uuid.uuid4())[:4] + ".pt"
    torch.save(weights, net)

    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
    seed = 5
    cfg = 3.0
    steps = 25
    image = inference(net, prompt, negative_prompt, cfg, steps, seed)
    return net, image
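# Illustrative note: all three entry points (sample_then_run, run_inversion,
# file_upload) produce a small .pt checkpoint holding a 1x10000 row of PC
# coefficients and pass its path around as `net`. For run_inversion, `dict`
# is the value of a gr.ImageEditor component, e.g.:
#
#   editor_value = {"background": Image.open("face.jpg"), "layers": [mask_layer]}
#   net, _, img = run_inversion(None, editor_value, pcs=10000, epochs=400,
#                               weight_decay=1e-10, lr=1e-1)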
intro = """
Project Page | Paper | Code