# weights2weights / app.py
# (Hugging Face file-page header captured with this source: last commit
# "Update app.py" by amildravid4292, revision f112774 (verified), 21.4 kB.)
import os
# os.system("pip uninstall -y gradio")
# #os.system('pip install gradio==3.43.1')
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import gradio as gr
import sys
import os
import tqdm
sys.path.append(os.path.abspath(os.path.join("", "..")))
import torch
import gc
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
from utils import load_models, save_model_w2w, save_model_for_diffusers
from editing import get_direction, debias
from sampling import sample_weights
from lora_w2w import LoRAw2w
from huggingface_hub import snapshot_download
import numpy as np
# ---------------------------------------------------------------------------
# Module-level state shared by every callback below.
# NOTE: `global` statements at module scope are no-ops in Python; these names
# are module globals either way. Functions that rebind them (e.g. `unet`,
# `network`, `generator`, `original_image`) declare `global` themselves.
# ---------------------------------------------------------------------------
global device
global generator
global unet
global vae
global text_encoder
global tokenizer
global noise_scheduler
global network
global original_image
# All models and GPU-side tensors are placed on this device.
device = "cuda:0"
# Shared RNG on `device`; the inference functions re-seed it per request with
# `manual_seed(seed)` before drawing the initial latents.
generator = torch.Generator(device=device)
from gradio_imageslider import ImageSlider
import spaces
# Download the released w2w assets from the Hugging Face Hub (local cache path).
models_path = snapshot_download(repo_id="Snapchat/w2w")
# Tensors used with the LoRA networks are cast to bfloat16 (the dtype the
# networks are moved to) and placed on `device`.
# NOTE(review): judging by file names and usage (`v[:, :pcs]`), these appear to
# be the weight-space mean/std, principal-component basis and projection
# helpers — confirm against the w2w release.
mean = torch.load(f"{models_path}/files/mean.pt").bfloat16().to(device)
std = torch.load(f"{models_path}/files/std.pt").bfloat16().to(device)
v = torch.load(f"{models_path}/files/V.pt").bfloat16().to(device)
proj = torch.load(f"{models_path}/files/proj_1000pc.pt").bfloat16().to(device)
# Identity attribute table consumed by get_direction/debias; loaded as saved
# (no dtype/device conversion).
df = torch.load(f"{models_path}/files/identity_df.pt")
# NOTE(review): not referenced anywhere else in the code visible here.
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(device)
# Initial models. `unet` is later replaced by sample_model/invert/file_upload;
# vae, text_encoder, tokenizer and noise_scheduler are reused for every request.
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
@spaces.GPU()
def sample_model():
    """Sample a brand-new identity model into the global ``network``.

    Drops the current UNet, loads a clean one with ``load_models`` and draws a
    new LoRA network with ``sample_weights`` restricted to the first 1000
    principal components (columns of ``v``).

    Side effects: rebinds the module-level ``unet`` and ``network``.
    Returns None.
    """
    global unet, network
    # Release the old UNet reference before its replacement is loaded.
    del unet
    unet, _vae, _text_encoder, _tokenizer, _scheduler = load_models(device)
    first_1000_pcs = v[:, :1000]
    network = sample_weights(unet, proj, mean, std, first_1000_pcs, device, factor=1.00)
@torch.no_grad()
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    """Generate one image from the current identity model (global ``network``).

    Runs classifier-free-guided denoising with the module-level ``unet``,
    ``text_encoder``, ``tokenizer``, ``noise_scheduler`` and ``vae``. The LoRA
    ``network`` is active (``with network:``) for every UNet call. Latents are
    sized for a 512x512 output (512 // 8 per side).

    Args:
        prompt: conditioning text (the UI asks users to include "sks person").
        negative_prompt: text for the unconditional branch of guidance; it is
            truncated to the tokenizer's maximum length like the prompt.
        guidance_scale: classifier-free guidance weight.
        ddim_steps: number of scheduler steps.
        seed: seed for the shared ``generator`` that draws the initial latents.

    Returns:
        The decoded image as a ``PIL.Image.Image``.
    """
    global generator  # rebound below; every other global is only read

    generator = generator.manual_seed(int(seed))
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    # truncation=True: without it an over-long negative prompt yields more than
    # max_length tokens and cannot be encoded/concatenated with the prompt.
    uncond_input = tokenizer(
        [negative_prompt],
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    # Batch [unconditional, conditional] so one UNet call serves both branches.
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(int(ddim_steps))
    latents = latents * noise_scheduler.init_noise_sigma
    for t in tqdm.tqdm(noise_scheduler.timesteps):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        with network:
            noise_pred = unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None
            ).sample
        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # Undo the latent scaling (the inverse of the 0.18215 factor applied after
    # vae.encode during inversion) and decode to [0, 1] pixels.
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    return Image.fromarray((image * 255).round().astype("uint8"))
@torch.no_grad()
def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
    """Generate an image from the current identity model with slider edits applied.

    The four module-level edit directions (``young``, ``pointy``, ``wavy``,
    ``large``) are scaled by ``a1``..``a4`` (times 1e6, 1e6, 1e6 and 8e5) and
    added to ``network.proj``. Denoising steps with ``t > start_noise`` use the
    unedited weights; steps with ``t <= start_noise`` use the edited ones. The
    unedited weights are always restored afterwards, also when generation fails,
    so edits never accumulate in the global ``network``.

    Args:
        prompt, negative_prompt, guidance_scale, ddim_steps, seed: same meaning
            as for ``inference``.
        start_noise: timestep threshold at or below which the edit is injected.
        a1, a2, a3, a4: slider strengths for the young / pointy-nose /
            wavy-hair / bushy-eyebrows directions.

    Returns:
        ``(original_image, image)``: the stored reference image and the edited sample.
    """
    global generator  # rebound below; every other global is only read

    original_weights = network.proj.clone()
    pcs_original = original_weights.shape[1]

    def _match_pcs(direction):
        # Give the direction exactly `pcs_original` coefficients: truncate when
        # the model uses fewer PCs (e.g. inverted with a small
        # "# Principal Components"), otherwise zero-pad the missing ones.
        direction = direction[:, :pcs_original]
        padding = torch.zeros((1, pcs_original - direction.shape[1])).to(device)
        return torch.cat((direction, padding), 1)

    edited_weights = (
        original_weights
        + a1 * 1e6 * _match_pcs(young)
        + a2 * 1e6 * _match_pcs(pointy)
        + a3 * 1e6 * _match_pcs(wavy)
        + a4 * 8e5 * _match_pcs(large)
    )

    generator = generator.manual_seed(int(seed))
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    # truncation=True keeps an over-long negative prompt at max_length tokens.
    uncond_input = tokenizer(
        [negative_prompt],
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    # Batch [unconditional, conditional] so one UNet call serves both branches.
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(int(ddim_steps))
    latents = latents * noise_scheduler.init_noise_sigma
    try:
        for t in tqdm.tqdm(noise_scheduler.timesteps):
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
            if t <= start_noise:
                # Inject the edited weights for this and every later step.
                network.proj = torch.nn.Parameter(edited_weights)
                network.reset()
            with network:
                noise_pred = unet(
                    latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None
                ).sample
            # Classifier-free guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

        # Undo the latent scaling and decode to [0, 1] pixels.
        latents = 1 / 0.18215 * latents
        image = vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
        image = Image.fromarray((image * 255).round().astype("uint8"))
    finally:
        # Always reset the weights back to the original, even on failure.
        network.proj = torch.nn.Parameter(original_weights)
        network.reset()
    return (original_image, image)
def sample_then_run():
    """Sample a new identity model, render its reference image and save it.

    Calls ``sample_model`` and then renders "sks person" with the default
    negative prompt (cfg 3.0, 50 steps, seed 5). The result is stored in the
    global ``original_image`` and ``network.proj`` is saved to ``model.pt``.

    Returns:
        ``((original_image, original_image), "model.pt")`` for the image slider
        and the download widget.
    """
    global original_image

    sample_model()
    original_image = inference(
        prompt="sks person",
        negative_prompt="low quality, blurry, unfinished, nudity, weapon",
        guidance_scale=3.0,
        ddim_steps=50,
        seed=5,
    )

    checkpoint_path = "model.pt"
    torch.save(network.proj, checkpoint_path)
    return (original_image, original_image), checkpoint_path
def _debiased_direction(attribute, confounders):
    """Return the 1000-PC edit direction for ``attribute``.

    ``debias`` is applied once for each name in ``confounders``, in the given
    order.
    """
    direction = get_direction(df, attribute, pinverse, 1000, device)
    for confounder in confounders:
        direction = debias(direction, confounder, df, pinverse, device)
    return direction


# Edit directions behind the four UI sliders
# ("Young", "Pointy Nose", "Curly Hair", "Thick Eyebrows").
young = _debiased_direction(
    "Young",
    ("Male", "Pointy_Nose", "Wavy_Hair", "Chubby"),
)
pointy = _debiased_direction(
    "Pointy_Nose",
    ("Young", "Male", "Wavy_Hair", "Chubby", "Heavy_Makeup"),
)
wavy = _debiased_direction(
    "Wavy_Hair",
    ("Young", "Male", "Pointy_Nose", "Chubby", "Heavy_Makeup"),
)
large = _debiased_direction(
    "Bushy_Eyebrows",
    (
        "Male",
        "Young",
        "Pointy_Nose",
        "Wavy_Hair",
        "Mustache",
        "No_Beard",
        "Sideburns",
        "Big_Nose",
        "Big_Lips",
        "Black_Hair",
        "Brown_Hair",
        "Pale_Skin",
        "Heavy_Makeup",
    ),
)
class CustomImageDataset(Dataset):
    """Map-style dataset over an in-memory list of images.

    Args:
        images: the samples, served in list order.
        transform: optional callable applied to a sample each time it is
            fetched; skipped when falsy.
    """

    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        sample = self.images[idx]
        return self.transform(sample) if self.transform else sample
def _mask_from_layer(layer):
    """Turn the drawn ImageEditor layer into a (1, 1, 64, 64) loss weight.

    The layer is resized to the 64x64 latent grid and its first (red) channel
    is used, rescaled from 0..255 to 0..1. An all-ones mask is returned when
    ``layer`` is None or nothing was drawn on it.
    """
    if layer is not None:
        layer = layer.convert("RGB").resize((512, 512))
        layer = transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BILINEAR)(layer)
        mask = (
            torchvision.transforms.functional.pil_to_tensor(layer)
            .unsqueeze(0)
            .to(device)
            .bfloat16()[:, 0, :, :]
            .unsqueeze(1)
        )
        if torch.sum(mask) != 0:
            # pil_to_tensor yields 0..255; rescale so a drawn mask weights the
            # loss on the same 0..1 scale as the all-ones fallback.
            return mask / 255
    return torch.ones((1, 1, 64, 64)).to(device).bfloat16()


def invert(dict, pcs=10000, epochs=400, weight_decay=1e-10, lr=1e-1):
    """Invert a single image into an identity model by optimising PC coefficients.

    A fresh UNet is loaded and wrapped in a ``LoRAw2w`` network whose ``proj``
    coefficients (one per principal component) start at zero. They are then
    fitted with Adam on the masked noise-prediction loss for the prompt
    "sks person".

    Args:
        dict: value of the ``gr.ImageEditor`` component. ``dict["background"]``
            is the uploaded image and ``dict["layers"]`` the drawn layers (the
            first one is used as the mask). The name shadows the builtin but is
            kept for backward compatibility.
        pcs: number of principal components to optimise.
        epochs: number of passes over the one-image dataset.
        weight_decay: Adam weight decay.
        lr: Adam learning rate.

    Returns:
        The optimised network. Also rebinds the module-level ``unet`` and
        ``network``.

    Raises:
        gr.Error: if no image was uploaded. This is raised before the current
            UNet is released, so the app stays usable.
    """
    global unet
    global network

    # Validate and prepare inputs before releasing the global UNet, so a bad
    # request cannot leave the app without a model.
    if dict is None or dict["background"] is None:
        raise gr.Error("Please upload an image before running the inversion.")
    pcs = int(pcs)  # slider values are used as tensor sizes / slice bounds
    epochs = int(epochs)
    image = dict["background"].convert("RGB").resize((512, 512))
    layers = dict["layers"]
    mask = _mask_from_layer(layers[0] if layers else None)

    del unet
    unet, _, _, _, _ = load_models(device)
    proj = torch.zeros(1, pcs).bfloat16().to(device)
    network = LoRAw2w(
        proj, mean, std, v[:, :pcs],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    # Single-image dataset. The image is already 512x512, so Resize/RandomCrop
    # keep its size; pixels are normalised to [-1, 1].
    image_transforms = transforms.Compose([
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.RandomCrop(512),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    train_dataset = CustomImageDataset([image], transform=image_transforms)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

    # Only the network's parameters are optimised.
    optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)

    # The conditioning prompt is fixed, so encode it once instead of every
    # epoch. No gradient is needed because the text encoder is not optimised.
    with torch.no_grad():
        text_input = tokenizer(
            "sks person",
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    unet.train()
    for _ in tqdm.tqdm(range(epochs)):
        for batch in train_dataloader:
            batch = batch.to(device).bfloat16()
            # Encoding to latents needs no gradient: the loss only has to
            # differentiate through the UNet to reach the network.
            with torch.no_grad():
                latents = vae.encode(batch).latent_dist.sample() * 0.18215
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Masked noise-prediction loss + optimiser step.
            with network:
                model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
            loss = torch.nn.functional.mse_loss(
                mask * model_pred.float(), mask * noise.float(), reduction="mean"
            )
            optim.zero_grad()
            loss.backward()
            optim.step()

    return network
def run_inversion(dict, pcs, epochs, weight_decay, lr):
    """Invert the uploaded image, render the resulting identity and save it.

    Delegates the optimisation to ``invert``, then renders "sks person" with
    the default negative prompt (cfg 3.0, 50 steps, seed 5) into the global
    ``original_image`` and saves ``network.proj`` to ``model.pt``.

    Returns:
        ``((original_image, original_image), "model.pt")`` for the image slider
        and the download widget.
    """
    global network, original_image

    network = invert(dict, pcs, epochs, weight_decay, lr)

    original_image = inference(
        prompt="sks person",
        negative_prompt="low quality, blurry, unfinished, nudity, weapon",
        guidance_scale=3.0,
        ddim_steps=50,
        seed=5,
    )

    weights_path = "model.pt"
    torch.save(network.proj, weights_path)
    return (original_image, original_image), weights_path
def file_upload(file):
    """Load an uploaded identity model (saved ``proj`` coefficients) and render it.

    The coefficients are zero-padded to 10000 principal components. A fresh
    UNet is loaded and wrapped in a new ``LoRAw2w`` network, and the reference
    image is rendered with the default prompt and seed.

    Args:
        file: the uploaded file from ``gr.File`` (its ``.name`` is the path on
            disk), or None when the component is cleared.

    Returns:
        ``(original_image, original_image)`` for the image slider, or a no-op
        ``gr.update()`` when there is no file.

    Raises:
        gr.Error: if the file holds more than 10000 coefficients. This is
            raised before the current UNet is released.
    """
    global unet
    global network
    global original_image

    if file is None:
        # `change` also fires when the upload is cleared; keep the current model.
        return gr.update()

    # NOTE(review): torch.load unpickles the uploaded file, so a crafted upload
    # can execute arbitrary code on the server. Files written by this app only
    # hold a tensor, so `weights_only=True` should suffice; confirm that the
    # deployed torch version supports it before switching.
    proj = torch.load(file.name).to(device)

    # Pad to 10000 principal components to keep everything consistent.
    pcs = proj.shape[1]
    if pcs > 10000:
        raise gr.Error(f"Expected at most 10000 principal components, got {pcs}.")
    padding = torch.zeros((1, 10000 - pcs)).to(device)
    proj = torch.cat((proj, padding), 1)

    # Release the current UNet only after the upload has been read, so a bad
    # file cannot leave the app without a model.
    del unet
    unet, _, _, _, _ = load_models(device)
    network = LoRAw2w(
        proj, mean, std, v[:, :10000],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    original_image = inference(
        prompt="sks person",
        negative_prompt="low quality, blurry, unfinished, nudity, weapon",
        guidance_scale=3.0,
        ddim_steps=50,
        seed=5,
    )
    return (original_image, original_image)
# Header HTML rendered at the top of the demo: title plus links to the project
# page, paper, code and a "Duplicate Space" badge.
intro = """
<div style="display: flex;align-items: center;justify-content: center">
<h2 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Interpreting the Weight Space of Customized Diffusion Models (aka <b> <em>weights2weights</em></b>)</h2>
</div>
<p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
<a href="https://snap-research.github.io/weights2weights/" target="_blank">Project Page</a> | <a href="https://arxiv.org/abs/2406.09413" target="_blank">Paper</a>
|
<a href="https://github.com/snap-research/weights2weights" target="_blank">Code</a> |
<a href="https://huggingface.co/spaces/Snapchat/w2w-demo?duplicate=true" target="_blank" style="
display: inline-block;
">
<img style="margin-top: -1em;margin-bottom: 0em;position: absolute;" src="https://bit.ly/3CWLGkA" alt="Duplicate Space"></a>
</p>
"""
# ---------------------------------------------------------------------------
# Gradio UI. "Model Editing" tab: sample or upload a model, then edit it with
# the attribute sliders. "Inversion" tab: invert an uploaded image into a
# model, then edit it the same way.
# ---------------------------------------------------------------------------
with gr.Blocks(css="style.css") as demo:
    gr.HTML(intro)
    with gr.Tab("Model Editing"):
        gr.Markdown("""
Click the `Sample New Model` to sample a new identity-encoding model or upload a model to get started ✨
""")
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    sample = gr.Button("🎲 Sample New Model")
                    file_output1 = gr.File(label="Download Sampled Model", container=True, interactive=False)
                    file_input = gr.File(label="Upload Model", container=True)
                with gr.Column():
                    image_slider1 = ImageSlider(position=0.5, type="pil", height=512, width=512, label="Reference Identity | Generated Samples by User")
                    prompt1 = gr.Textbox(label="Prompt",
                                         info="Make sure to include 'sks person'",
                                         placeholder="sks person",
                                         value="sks person")
                    seed1 = gr.Number(value=5, label="Seed", precision=0, interactive=True)
                    with gr.Row():
                        a1_1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                        a2_1 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                    with gr.Row():
                        a3_1 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                        a4_1 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                    with gr.Accordion("Advanced Options", open=False):
                        cfg1 = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
                        # minimum=1: with 0 steps nothing is denoised (and some
                        # schedulers fail outright on 0 inference steps).
                        steps1 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=1, maximum=100, interactive=True)
                        negative_prompt1 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
                        injection_step1 = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
                    submit1 = gr.Button("Generate")
    with gr.Tab("Inversion"):
        gr.Markdown("""
Upload an image and optionally define a mask by drawing over the face. Then click `invert` to get started ✨
""")
        with gr.Row():
            with gr.Column():
                input_image = gr.ImageEditor(elem_id="image_upload", type='pil', label="Upload image and draw to define mask", height=512, width=512, brush=gr.Brush(), layers=False)
                lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
                pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
                epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)
                weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
                invert_button = gr.Button("Invert")
                file_output2 = gr.File(label="Download Inverted Model", container=True, interactive=False)
            with gr.Column():
                image_slider2 = ImageSlider(position=0.5, type="pil", height=512, width=512, label="Reference Identity | Generated Samples by User")
                prompt2 = gr.Textbox(label="Prompt",
                                     info="Make sure to include 'sks person'",
                                     placeholder="sks person",
                                     value="sks person")
                seed2 = gr.Number(value=5, label="Seed", precision=0, interactive=True)
                with gr.Row():
                    a1_2 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                    a2_2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                with gr.Row():
                    a3_2 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                    a4_2 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                with gr.Accordion("Advanced Options", open=False):
                    cfg2 = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
                    # minimum=1: see the Model Editing tab.
                    steps2 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=1, maximum=100, interactive=True)
                    negative_prompt2 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
                    injection_step2 = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
                submit2 = gr.Button("Generate")

    # Event wiring.
    sample.click(fn=sample_then_run, outputs=[image_slider1, file_output1])
    submit1.click(fn=edit_inference, inputs=[prompt1, negative_prompt1, cfg1, steps1, seed1, injection_step1, a1_1, a2_1, a3_1, a4_1], outputs=image_slider1)
    file_input.change(fn=file_upload, inputs=file_input, outputs=image_slider1)
    invert_button.click(fn=run_inversion, inputs=[input_image, pcs, epochs, weight_decay, lr], outputs=[image_slider2, file_output2])
    submit2.click(fn=edit_inference, inputs=[prompt2, negative_prompt2, cfg2, steps2, seed2, injection_step2, a1_2, a2_2, a3_2, a4_2], outputs=image_slider2)

demo.queue().launch()