import os # os.system("pip uninstall -y gradio") # #os.system('pip install gradio==3.43.1') import torch import torchvision import torchvision.transforms as transforms from torch.utils.data import Dataset, DataLoader import gradio as gr import sys import os import tqdm sys.path.append(os.path.abspath(os.path.join("", ".."))) import torch import gc import warnings warnings.filterwarnings("ignore") from PIL import Image from utils import load_models, save_model_w2w, save_model_for_diffusers from editing import get_direction, debias from sampling import sample_weights from lora_w2w import LoRAw2w from huggingface_hub import snapshot_download import numpy as np global device global generator global unet global vae global text_encoder global tokenizer global noise_scheduler global network global original_image device = "cuda:0" generator = torch.Generator(device=device) from gradio_imageslider import ImageSlider import spaces models_path = snapshot_download(repo_id="Snapchat/w2w") mean = torch.load(f"{models_path}/files/mean.pt").bfloat16().to(device) std = torch.load(f"{models_path}/files/std.pt").bfloat16().to(device) v = torch.load(f"{models_path}/files/V.pt").bfloat16().to(device) proj = torch.load(f"{models_path}/files/proj_1000pc.pt").bfloat16().to(device) df = torch.load(f"{models_path}/files/identity_df.pt") weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt") pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(device) unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device) @spaces.GPU() def sample_model(): global unet del unet global network unet, _, _, _, _ = load_models(device) network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00) @torch.no_grad() def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed): global device global generator global unet global vae global text_encoder global tokenizer global noise_scheduler generator = generator.manual_seed(seed) latents = torch.randn( (1, unet.in_channels, 512 // 8, 512 // 8), generator = generator, device = device ).bfloat16() text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") text_embeddings = text_encoder(text_input.input_ids.to(device))[0] max_length = text_input.input_ids.shape[-1] uncond_input = tokenizer( [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt" ) uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0] text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) noise_scheduler.set_timesteps(ddim_steps) latents = latents * noise_scheduler.init_noise_sigma for i,t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)): latent_model_input = torch.cat([latents] * 2) latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t) with network: noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample #guidance noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) latents = noise_scheduler.step(noise_pred, t, latents).prev_sample latents = 1 / 0.18215 * latents image = vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0] image = Image.fromarray((image * 255).round().astype("uint8")) return image @torch.no_grad() def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4): global device global generator global unet global vae global text_encoder global tokenizer global noise_scheduler global young global pointy global wavy global large global original_image original_weights = network.proj.clone() #pad to same number of PCs pcs_original = original_weights.shape[1] pcs_edits = young.shape[1] padding = torch.zeros((1,pcs_original-pcs_edits)).to(device) young_pad = torch.cat((young, padding), 1) pointy_pad = torch.cat((pointy, padding), 1) wavy_pad = torch.cat((wavy, padding), 1) large_pad = torch.cat((large, padding), 1) edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*8e5*large_pad generator = generator.manual_seed(seed) latents = torch.randn( (1, unet.in_channels, 512 // 8, 512 // 8), generator = generator, device = device ).bfloat16() text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") text_embeddings = text_encoder(text_input.input_ids.to(device))[0] max_length = text_input.input_ids.shape[-1] uncond_input = tokenizer( [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt" ) uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0] text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) noise_scheduler.set_timesteps(ddim_steps) latents = latents * noise_scheduler.init_noise_sigma for i,t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)): latent_model_input = torch.cat([latents] * 2) latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t) if t>start_noise: pass elif t<=start_noise: network.proj = torch.nn.Parameter(edited_weights) network.reset() with network: noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample #guidance noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) latents = noise_scheduler.step(noise_pred, t, latents).prev_sample latents = 1 / 0.18215 * latents image = vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0] image = Image.fromarray((image * 255).round().astype("uint8")) #reset weights back to original network.proj = torch.nn.Parameter(original_weights) network.reset() return (original_image, image) def sample_then_run(): global original_image sample_model() prompt = "sks person" negative_prompt = "low quality, blurry, unfinished, nudity, weapon" seed = 5 cfg = 3.0 steps = 50 original_image = inference( prompt, negative_prompt, cfg, steps, seed) torch.save(network.proj, "model.pt" ) return (original_image, original_image), "model.pt" global young global pointy global wavy global large young = get_direction(df, "Young", pinverse, 1000, device) young = debias(young, "Male", df, pinverse, device) young = debias(young, "Pointy_Nose", df, pinverse, device) young = debias(young, "Wavy_Hair", df, pinverse, device) young = debias(young, "Chubby", df, pinverse, device) pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device) pointy = debias(pointy, "Young", df, pinverse, device) pointy = debias(pointy, "Male", df, pinverse, device) pointy = debias(pointy, "Wavy_Hair", df, pinverse, device) pointy = debias(pointy, "Chubby", df, pinverse, device) pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device) wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device) wavy = debias(wavy, "Young", df, pinverse, device) wavy = debias(wavy, "Male", df, pinverse, device) wavy = debias(wavy, "Pointy_Nose", df, pinverse, device) wavy = debias(wavy, "Chubby", df, pinverse, device) wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device) large = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device) large = debias(large, "Male", df, pinverse, device) large = debias(large, "Young", df, pinverse, device) large = debias(large, "Pointy_Nose", df, pinverse, device) large = debias(large, "Wavy_Hair", df, pinverse, device) large = debias(large, "Mustache", df, pinverse, device) large = debias(large, "No_Beard", df, pinverse, device) large = debias(large, "Sideburns", df, pinverse, device) large = debias(large, "Big_Nose", df, pinverse, device) large = debias(large, "Big_Lips", df, pinverse, device) large = debias(large, "Black_Hair", df, pinverse, device) large = debias(large, "Brown_Hair", df, pinverse, device) large = debias(large, "Pale_Skin", df, pinverse, device) large = debias(large, "Heavy_Makeup", df, pinverse, device) class CustomImageDataset(Dataset): def __init__(self, images, transform=None): self.images = images self.transform = transform def __len__(self): return len(self.images) def __getitem__(self, idx): image = self.images[idx] if self.transform: image = self.transform(image) return image def invert(dict, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1): global unet del unet global network image = dict["background"].convert("RGB").resize((512, 512)) mask = dict["layers"][0].convert("RGB").resize((512, 512)) unet, _, _, _, _ = load_models(device) proj = torch.zeros(1,pcs).bfloat16().to(device) network = LoRAw2w( proj, mean, std, v[:, :pcs], unet, rank=1, multiplier=1.0, alpha=27.0, train_method="xattn-strict" ).to(device, torch.bfloat16) ### load mask mask = transforms.Resize((64,64), interpolation=transforms.InterpolationMode.BILINEAR)(mask) mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:,0,:,:].unsqueeze(1) ### check if an actual mask was draw, otherwise mask is just all ones if torch.sum(mask) == 0: mask = torch.ones((1,1,64,64)).to(device).bfloat16() ### single image dataset image_transforms = transforms.Compose([transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), transforms.RandomCrop(512), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) train_dataset = CustomImageDataset([image], transform=image_transforms) train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True) ### optimizer optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay) ### training loop unet.train() for epoch in tqdm.tqdm(range(epochs)): for batch in train_dataloader: ### prepare inputs batch = batch.to(device).bfloat16() latents = vae.encode(batch).latent_dist.sample() latents = latents*0.18215 noise = torch.randn_like(latents) bsz = latents.shape[0] timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) timesteps = timesteps.long() noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) text_input = tokenizer("sks person", padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") text_embeddings = text_encoder(text_input.input_ids.to(device))[0] ### loss + sgd step with network: model_pred = unet(noisy_latents, timesteps, text_embeddings).sample loss = torch.nn.functional.mse_loss(mask*model_pred.float(), mask*noise.float(), reduction="mean") optim.zero_grad() loss.backward() optim.step() ### return optimized network return network def run_inversion(dict, pcs, epochs, weight_decay,lr): global network global original_image # init_image = dict["image"].convert("RGB").resize((512, 512)) # mask = dict["ma print(dict) network = invert( dict, pcs, epochs, weight_decay,lr) #sample an image prompt = "sks person" negative_prompt = "low quality, blurry, unfinished, nudity, weapon" seed = 5 cfg = 3.0 steps = 50 original_image = inference( prompt, negative_prompt, cfg, steps, seed) torch.save(network.proj, "model.pt" ) return (original_image, original_image), "model.pt" def file_upload(file): global unet del unet global network global device global original_image proj = torch.load(file.name).to(device) #pad to 10000 Principal components to keep everything consistent pcs = proj.shape[1] padding = torch.zeros((1,10000-pcs)).to(device) proj = torch.cat((proj, padding), 1) unet, _, _, _, _ = load_models(device) network = LoRAw2w( proj, mean, std, v[:, :10000], unet, rank=1, multiplier=1.0, alpha=27.0, train_method="xattn-strict" ).to(device, torch.bfloat16) prompt = "sks person" negative_prompt = "low quality, blurry, unfinished, nudity, weapon" seed = 5 cfg = 3.0 steps = 50 original_image = inference( prompt, negative_prompt, cfg, steps, seed) return (original_image, original_image) intro = """

Interpreting the Weight Space of Customized Diffusion Models (aka weights2weights)

Project Page | Paper | Code | Duplicate Space

""" with gr.Blocks(css="style.css") as demo: gr.HTML(intro) with gr.Tab("Model Editing"): gr.Markdown(""" Click the `Sample New Model` to sample a new identity-encoding model or upload a model to get started ✨ """) with gr.Column(): with gr.Row(): with gr.Column(): sample = gr.Button("🎲 Sample New Model") file_output1 = gr.File(label="Download Sampled Model", container=True, interactive=False) file_input = gr.File(label="Upload Model", container=True) with gr.Column(): image_slider1 = ImageSlider(position=0.5, type="pil", height=512, width=512, label= "Reference Identity | Generated Samples by User") prompt1 = gr.Textbox(label="Prompt", info="Make sure to include 'sks person'" , placeholder="sks person", value="sks person") seed1 = gr.Number(value=5, label="Seed", precision=0, interactive=True) with gr.Row(): a1_1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True) a2_1 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True) with gr.Row(): a3_1 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True) a4_1 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True) with gr.Accordion("Advanced Options", open=False): cfg1= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True) steps1 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True) negative_prompt1 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon") injection_step1 = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True) submit1 = gr.Button("Generate") with gr.Tab("Inversion"): gr.Markdown(""" Upload an image and optionally define a mask by drawing over the face. Then click `invert` to get started ✨ """) with gr.Row(): with gr.Column(): input_image = gr.ImageEditor(elem_id="image_upload", type='pil', label="Upload image and draw to define mask", height=512, width=512, brush=gr.Brush(), layers=False) lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True) pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True) epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True) weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True) invert_button = gr.Button("Invert") file_output2 = gr.File(label="Download Inverted Model", container=True, interactive=False) with gr.Column(): image_slider2 = ImageSlider(position=0.5, type="pil", height=512, width=512, label= "Reference Identity | Generated Samples by User") prompt2 = gr.Textbox(label="Prompt", info="Make sure to include 'sks person'" , placeholder="sks person", value="sks person") seed2 = gr.Number(value=5, label="Seed", precision=0, interactive=True) with gr.Row(): a1_2 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True) a2_2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True) with gr.Row(): a3_2 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True) a4_2 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True) with gr.Accordion("Advanced Options", open=False): cfg2= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True) steps2 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True) negative_prompt2 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon") injection_step2 = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True) submit2 = gr.Button("Generate") sample.click(fn=sample_then_run, outputs=[image_slider1, file_output1]) submit1.click(fn=edit_inference, inputs=[ prompt1, negative_prompt1, cfg1, steps1, seed1, injection_step1, a1_1, a2_1, a3_1, a4_1], outputs=image_slider1) file_input.change(fn=file_upload, inputs=file_input, outputs = image_slider1) invert_button.click(fn=run_inversion, inputs=[input_image, pcs, epochs, weight_decay,lr], outputs = [image_slider2, file_output2]) submit2.click(fn=edit_inference, inputs=[ prompt2, negative_prompt2, cfg2, steps2, seed2, injection_step2, a1_2, a2_2, a3_2, a4_2], outputs=image_slider2) demo.queue().launch()