import gradio as gr
import sys
import os
import tqdm
sys.path.append(os.path.abspath(os.path.join("", "..")))
import torch
import gc
import warnings
warnings.filterwarnings("ignore")

from PIL import Image
from utils import load_models, save_model_w2w, save_model_for_diffusers
from sampling import sample_weights
from editing import get_direction, debias
from huggingface_hub import snapshot_download

global device
global generator
global unet
global vae
global text_encoder
global tokenizer
global noise_scheduler
global young_val
global pointy_val
global bags_val

device = "cuda:0"
generator = torch.Generator(device=device)

# w2w statistics and bases downloaded from the Hugging Face Hub
models_path = snapshot_download(repo_id="Snapchat/w2w")
mean = torch.load(f"{models_path}/mean.pt").bfloat16().to(device)
std = torch.load(f"{models_path}/std.pt").bfloat16().to(device)
v = torch.load(f"{models_path}/V.pt").bfloat16().to(device)
proj = torch.load(f"{models_path}/proj_1000pc.pt").bfloat16().to(device)
df = torch.load(f"{models_path}/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")
pinverse = torch.load(f"{models_path}/pinverse_1000pc.pt").bfloat16().to(device)

# base Stable Diffusion components
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)

global network


def sample_model():
    # reload a fresh UNet and sample a new identity from the w2w prior
    global unet
    del unet
    global network
    unet, _, _, _, _ = load_models(device)
    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)


@torch.no_grad()
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    global device
    global generator
    global unet
    global vae
    global text_encoder
    global tokenizer
    global noise_scheduler

    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    # encode the prompt and the negative prompt for classifier-free guidance
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        with network:
            noise_pred = unet(latent_model_input, t,
                              encoder_hidden_states=text_embeddings,
                              timestep_cond=None).sample

        # classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # decode the latents with the VAE and convert to a PIL image
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))
    return [image]


@torch.no_grad()
def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3):
    global device
    global generator
    global unet
    global vae
    global text_encoder
    global tokenizer
    global noise_scheduler
    global young
    global pointy
    global bags

    # shift the sampled identity along the semantic edit directions in w2w space
    original_weights = network.proj.clone()
    edited_weights = original_weights + a1 * young + a2 * pointy + a3 * bags

    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        # inject the edited weights only once the timestep drops to start_noise
        if t > start_noise:
            pass
        elif t <= start_noise:
            network.proj = torch.nn.Parameter(edited_weights)
            network.reset()

        with network:
            noise_pred = unet(latent_model_input, t,
                              encoder_hidden_states=text_embeddings,
                              timestep_cond=None).sample

        # classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))

    # reset weights back to original
    network.proj = torch.nn.Parameter(original_weights)
    network.reset()
    return [image]


def sample_then_run():
    global young_val
    global pointy_val
    global bags_val
    global young
    global pointy
    global bags

    sample_model()
    # coordinate of the sampled identity along each edit direction
    young_val = network.proj @ young[0] / (torch.norm(young) ** 2).item()
    pointy_val = network.proj @ pointy[0] / (torch.norm(pointy) ** 2).item()
    bags_val = network.proj @ bags[0] / (torch.norm(bags) ** 2).item()

    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, cartoon"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)
    return image


# semantic edit directions
global young
global pointy
global bags

young = get_direction(df, "Young", pinverse, 1000, device)
young = debias(young, "Male", df, pinverse, device)  # remove correlation with the "Male" attribute
young_max = torch.max(proj @ young[0] / (torch.norm(young)) ** 2).item()
young_min = torch.min(proj @ young[0] / (torch.norm(young)) ** 2).item()

pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
pointy_max = torch.max(proj @ pointy[0] / (torch.norm(pointy)) ** 2).item()
pointy_min = torch.min(proj @ pointy[0] / (torch.norm(pointy)) ** 2).item()

bags = get_direction(df, "Bags_Under_Eyes", pinverse, 1000, device)
bags_max = torch.max(proj @ bags[0] / (torch.norm(bags)) ** 2).item()
bags_min = torch.min(proj @ bags[0] / (torch.norm(bags)) ** 2).item()

intro = """
project page | paper |
""" with gr.Blocks(css="style.css") as demo: gr.HTML(intro) with gr.Row(): with gr.Column(): gallery1 = gr.Gallery(label="Identity from Sampled Model") sample = gr.Button("Sample New Model") gallery2 = gr.Gallery(label="Identity from Edited Model") with gr.Row(): with gr.Column(): prompt = gr.Textbox(label="Prompt", info="Make sure to include 'sks person'" , placeholder="sks person", value="sks person") negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon") with gr.Row(): a1 = gr.Slider(label="Young", value=0, step=1, minimum=-1000000, maximum=1000000, interactive=True) a2 = gr.Slider(label="Pointy Nose", value=0, step=1, minimum=-1000000, maximum=1000000, interactive=True) a3 = gr.Slider(label="Undereye Bags", value=0, step=1, minimum=-1000000, maximum=1000000, interactive=True) with gr.Accordion("Advanced Options", open=False): with gr.Column(): seed = gr.Number(value=5, label="Seed", interactive=True) cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True) steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True) injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True) submit = gr.Button("Submit") #with gr.Column(): #gallery2 = gr.Gallery(label="Identity from Edited Model") sample.click(fn=sample_then_run, outputs=gallery1) submit.click(fn=edit_inference, inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3], outputs=gallery2) demo.launch(share=True)