weights2weights / app.py
amildravid4292's picture
Update app.py
9bff723 verified
raw
history blame
19.4 kB
import os
# Reinstall a pinned Gradio *before* `gradio` is imported below. The UI code in this
# file uses Gradio 3.x-style arguments (e.g. gr.Image(source=..., tool='sketch')),
# presumably the reason 3.43.1 is forced here.
# NOTE(review): running pip at import time is slow and fragile; pinning gradio in the
# Space's requirements would be more robust.
os.system("pip uninstall -y gradio")
os.system('pip install gradio==3.43.1')
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import gradio as gr
import sys
import tqdm
# Make the parent of the current working directory importable (presumably so the
# project-local modules imported below resolve when launched from elsewhere).
sys.path.append(os.path.abspath(os.path.join("", "..")))
import gc
import warnings
# Silence every warning process-wide (this also hides library deprecation notices).
warnings.filterwarnings("ignore")
from PIL import Image
import numpy as np
# Project-local helpers (roles inferred from their names and usage below):
# model loading, edit-direction computation, weight sampling, and the w2w LoRA network.
from utils import load_models
from editing import get_direction, debias
from sampling import sample_weights
from lora_w2w import LoRAw2w
from huggingface_hub import snapshot_download
# Shared state read (and partly replaced) by the Gradio callbacks below.
device = "cuda:0"
# One generator for the whole app; each inference call reseeds it.
generator = torch.Generator(device=device)

# Fetch (or reuse the locally cached copy of) the Snapchat/w2w artifacts.
models_path = snapshot_download(repo_id="Snapchat/w2w")
_files_dir = f"{models_path}/files"


def _load_bf16(filename):
    """Load `filename` from the w2w files directory as a bfloat16 tensor on `device`."""
    return torch.load(f"{_files_dir}/{filename}").bfloat16().to(device)


mean = _load_bf16("mean.pt")
std = _load_bf16("std.pt")
v = _load_bf16("V.pt")
proj = _load_bf16("proj_1000pc.pt")
# The next two are used as-is (no dtype/device conversion).
df = torch.load(f"{_files_dir}/identity_df.pt")
weight_dimensions = torch.load(f"{_files_dir}/weight_dimensions.pt")
pinverse = _load_bf16("pinverse_1000pc.pt")

unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
def sample_model():
    """Replace the current identity model with a freshly sampled one.

    Reloads a clean UNet and samples a new w2w network over the first 1000
    principal components (``factor=1.00``). Both are stored in the module
    globals ``unet`` and ``network``.
    """
    global unet, network
    # Drop the old UNet *and* the network built on it before reloading. The old
    # network was constructed with the old UNet and presumably keeps references
    # to it, so deleting only `unet` would not release its GPU memory.
    unet = None
    network = None
    gc.collect()
    torch.cuda.empty_cache()
    unet, _, _, _, _ = load_models(device)
    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)
@torch.no_grad()
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    """Generate one image from the currently loaded identity model.

    Runs classifier-free-guided denoising with the global ``unet``. The global
    ``network`` is entered as a context manager around every UNet call. The
    final latents are decoded with ``vae``.

    Args:
        prompt: Text prompt; should contain "sks person" to invoke the identity.
        negative_prompt: Text for the unconditional branch of guidance.
        guidance_scale: Classifier-free guidance weight.
        ddim_steps: Number of scheduler denoising steps.
        seed: Seed for the initial latent noise.

    Returns:
        A ``PIL.Image.Image`` of the generated sample.
    """
    global generator
    generator = generator.manual_seed(int(seed))
    # Latent grid is 512 // 8 on each side.
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    # truncation=True: without it a long negative prompt tokenizes to more than
    # max_length tokens, and its embedding no longer matches the conditional one
    # in the torch.cat below.
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, truncation=True, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(int(ddim_steps))
    latents = latents * noise_scheduler.init_noise_sigma
    for t in tqdm.tqdm(noise_scheduler.timesteps):
        # Duplicate latents so the unconditional and conditional passes run as one batch.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # Undo the latent scaling factor before decoding.
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    # Map decoder output from [-1, 1] to [0, 1], then to an 8-bit HWC image.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))
    return image
def _match_pcs(direction, n_pcs):
    """Return the (1, k) edit `direction` resized to (1, n_pcs) coefficients.

    Shorter directions are zero-padded. Longer ones are truncated to their first
    ``n_pcs`` coefficients, which happens when the current model was inverted
    with fewer principal components than the edit directions use.
    """
    n_dir = direction.shape[1]
    if n_dir > n_pcs:
        # NOTE(review): assumes coefficient i means the same principal component in
        # both spaces (both index leading columns of the shared `v`), so keeping the
        # leading entries restricts the edit to the model's subspace.
        return direction[:, :n_pcs]
    padding = torch.zeros((1, n_pcs - n_dir)).to(device)
    return torch.cat((direction, padding), 1)


@torch.no_grad()
def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
    """Generate an image from the current model with attribute edits applied.

    The edit moves the model's coefficients (``network.proj``) linearly along the
    precomputed ``young``, ``pointy``, ``wavy`` and ``large`` directions, scaled by
    the slider values ``a1``..``a4``. The edited coefficients are used only for
    timesteps ``t <= start_noise``; earlier steps use the unedited model. The
    original coefficients are always restored afterwards, even if sampling fails.

    Args:
        prompt, negative_prompt, guidance_scale, ddim_steps, seed: as in ``inference``.
        start_noise: Timestep threshold at/below which the edited weights are used.
        a1, a2, a3, a4: Edit strengths for young / pointy / wavy / large.

    Returns:
        A ``PIL.Image.Image`` of the edited sample.

    Raises:
        gr.Error: If no model has been sampled, inverted or uploaded yet.
    """
    global generator
    if globals().get("network") is None:
        raise gr.Error("Please sample, invert, or upload a model first.")

    original_weights = network.proj.clone()
    n_pcs = original_weights.shape[1]
    # Fixed scale constants applied to the [-1, 1] slider values (a4 uses twice the factor).
    edited_weights = (
        original_weights
        + a1 * 1e6 * _match_pcs(young, n_pcs)
        + a2 * 1e6 * _match_pcs(pointy, n_pcs)
        + a3 * 1e6 * _match_pcs(wavy, n_pcs)
        + a4 * 2e6 * _match_pcs(large, n_pcs)
    )

    generator = generator.manual_seed(int(seed))
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    # truncation=True keeps a long negative prompt at max_length tokens so the
    # two embeddings can be concatenated.
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, truncation=True, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(int(ddim_steps))
    latents = latents * noise_scheduler.init_noise_sigma
    try:
        for t in tqdm.tqdm(noise_scheduler.timesteps):
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
            # Switch to the edited coefficients once t reaches the injection step.
            if t <= start_noise:
                network.proj = torch.nn.Parameter(edited_weights)
                network.reset()
            with network:
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
            # Classifier-free guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

        latents = 1 / 0.18215 * latents
        image = vae.decode(latents).sample
    finally:
        # Always put the unedited coefficients back so later generations are not
        # silently affected by a failed edit.
        network.proj = torch.nn.Parameter(original_weights)
        network.reset()

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))
    return image
def sample_then_run():
    """Sample a brand-new identity model, render a preview, and save its weights.

    Returns:
        A ``(preview_image, path)`` pair, where ``path`` is the file holding the
        sampled ``network.proj`` for download.
    """
    sample_model()
    preview = inference(
        prompt="sks person",
        negative_prompt="low quality, blurry, unfinished, nudity, weapon",
        guidance_scale=3.0,
        ddim_steps=50,
        seed=5,
    )
    output_path = "model.pt"
    torch.save(network.proj, output_path)
    return preview, output_path
def _attribute_direction(attribute, confounders):
    """Return the edit direction for `attribute`, passed through `debias` once per confounder.

    The confounders are applied in the given order (presumably each pass removes
    correlation with that attribute; see `editing.debias`).
    """
    direction = get_direction(df, attribute, pinverse, 1000, device)
    for confounder in confounders:
        direction = debias(direction, confounder, df, pinverse, device)
    return direction


# Edit directions driven by the four sliders in the UI.
young = _attribute_direction("Young", ("Male", "Pointy_Nose", "Wavy_Hair", "Chubby"))
pointy = _attribute_direction("Pointy_Nose", ("Young", "Male", "Wavy_Hair", "Chubby", "Heavy_Makeup"))
wavy = _attribute_direction("Wavy_Hair", ("Young", "Male", "Pointy_Nose", "Chubby", "Heavy_Makeup"))
# Bound to the "Thick Eyebrows" slider.
large = _attribute_direction(
    "Bushy_Eyebrows",
    (
        "Male",
        "Young",
        "Pointy_Nose",
        "Wavy_Hair",
        "Mustache",
        "No_Beard",
        "Sideburns",
        "Big_Nose",
        "Big_Lips",
        "Black_Hair",
        "Brown_Hair",
        "Pale_Skin",
        "Heavy_Makeup",
    ),
)
class CustomImageDataset(Dataset):
    """Map-style dataset over an in-memory sequence of images.

    Args:
        images: Indexable sequence of items (``invert`` passes a list of PIL images).
        transform: Optional callable applied to each item when it is accessed.
    """

    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        item = self.images[idx]
        return self.transform(item) if self.transform else item
def _prepare_inversion_mask(mask):
    """Turn the drawn PIL mask into a (1, 1, 64, 64) bfloat16 loss weight on `device`.

    The mask is resized to the 64x64 latent grid (512 // 8) and its first channel
    is kept. Values are rescaled from 0..255 to 0..1, so a drawn mask weights the
    loss on the same scale as the all-ones fallback used when nothing was drawn.
    """
    mask = transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
    mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).float()[:, :1] / 255.0
    mask = mask.bfloat16()
    # No mask drawn: weight every latent position equally.
    if torch.sum(mask) == 0:
        mask = torch.ones((1, 1, 64, 64)).to(device).bfloat16()
    return mask


def invert(image, mask, pcs=10000, epochs=400, weight_decay=1e-10, lr=1e-1):
    """Invert an image into the w2w weight space.

    Reloads a clean UNet and builds a ``LoRAw2w`` network whose coefficients over
    the first ``pcs`` principal components start at zero. Those coefficients are
    optimized with Adam on a noise-prediction MSE loss (prompt "sks person"),
    optionally restricted to the user-drawn mask.

    Args:
        image: List of PIL images used as the training set (one image in practice).
        mask: PIL mask image; an all-zero mask means "use the whole image".
        pcs: Number of principal components to optimize.
        epochs: Passes over the dataset.
        weight_decay: Adam weight decay.
        lr: Adam learning rate.

    Returns:
        The optimized ``LoRAw2w`` network (also stored in the global ``network``).
    """
    global unet, network
    # Drop the old UNet *and* the network built on it before reloading; the old
    # network presumably keeps the old UNet alive otherwise.
    unet = None
    network = None
    gc.collect()
    torch.cuda.empty_cache()
    unet, _, _, _, _ = load_models(device)

    # Gradio sliders may deliver floats; tensor sizes and range() need ints.
    pcs = int(pcs)
    epochs = int(epochs)
    proj = torch.zeros(1, pcs).bfloat16().to(device)
    network = LoRAw2w(
        proj, mean, std, v[:, :pcs],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    mask = _prepare_inversion_mask(mask)

    # Single-image dataset.
    image_transforms = transforms.Compose([
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.RandomCrop(512),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    train_dataset = CustomImageDataset(image, transform=image_transforms)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

    optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)

    # The conditioning prompt never changes, so encode it once. No gradients are
    # needed because only the network's coefficients are optimized.
    with torch.no_grad():
        text_input = tokenizer("sks person", padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    unet.train()
    for _ in tqdm.tqdm(range(epochs)):
        for batch in train_dataloader:
            batch = batch.to(device).bfloat16()
            # The VAE is not trained; skip building an autograd graph through it.
            with torch.no_grad():
                latents = vae.encode(batch).latent_dist.sample()
                latents = latents * 0.18215
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            with network:
                model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
            loss = torch.nn.functional.mse_loss(mask * model_pred.float(), mask * noise.float(), reduction="mean")
            optim.zero_grad()
            loss.backward()
            optim.step()

    # Subsequent sampling should run the UNet in inference mode.
    unet.eval()
    return network
def run_inversion(dict, pcs, epochs, weight_decay, lr):
    """Invert the uploaded image (with optional mask), preview it, and save the weights.

    Args:
        dict: Value of the sketch-enabled ``gr.Image``, with PIL images under
            ``"image"`` and ``"mask"``. The name shadows the builtin but is kept
            for interface compatibility.
        pcs, epochs, weight_decay, lr: Forwarded to ``invert``.

    Returns:
        A ``(preview_image, "model.pt")`` pair; ``model.pt`` holds ``network.proj``.

    Raises:
        gr.Error: If no image was uploaded.
    """
    global network
    if dict is None or dict.get("image") is None:
        raise gr.Error("Please upload an image before clicking Invert.")
    init_image = dict["image"].convert("RGB").resize((512, 512))
    mask = dict["mask"].convert("RGB").resize((512, 512))
    network = invert([init_image], mask, int(pcs), int(epochs), weight_decay, lr)

    # Preview sample from the inverted model.
    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)
    torch.save(network.proj, "model.pt")
    return image, "model.pt"
def file_upload(file):
    """Load a model previously downloaded from this demo and render a preview.

    The uploaded coefficients are zero-padded to 10000 principal components. A
    fresh UNet is loaded and a ``LoRAw2w`` network is built from the padded
    coefficients.

    Args:
        file: Value of the ``gr.File`` component. ``None`` when the upload is cleared.

    Returns:
        The preview image, or ``gr.update()`` (no change) when ``file`` is ``None``.

    Raises:
        gr.Error: If the uploaded coefficients have more than 10000 components.
    """
    global unet, network
    # gr.File fires .change with None when cleared. Bail out *before* touching the
    # loaded model so the app keeps a working UNet.
    if file is None:
        return gr.update()

    # SECURITY NOTE(review): torch.load unpickles the uploaded file, which can run
    # arbitrary code from an untrusted upload; consider `weights_only=True` if the
    # installed torch version supports it.
    proj = torch.load(file.name, map_location=device)

    # Pad to 10000 principal components to keep everything consistent.
    total_pcs = 10000
    pcs = proj.shape[1]
    if pcs > total_pcs:
        raise gr.Error(f"Uploaded model has {pcs} components; at most {total_pcs} are supported.")
    padding = torch.zeros((1, total_pcs - pcs)).to(device)
    proj = torch.cat((proj, padding), 1)

    # Drop the old UNet and the network built on it before reloading.
    unet = None
    network = None
    gc.collect()
    torch.cuda.empty_cache()
    unet, _, _, _, _ = load_models(device)
    # The basis must have as many columns as the (padded) coefficient vector.
    network = LoRAw2w(
        proj, mean, std, v[:, :total_pcs],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)
    return image
# HTML banner (title, subtitle, project/paper links and a "Duplicate Space" badge)
# rendered at the top of the page via gr.HTML(intro).
intro = """
<div style="display: flex;align-items: center;justify-content: center">
<h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">weights2weights</h1>
<h3 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Interpreting the Weight Space of Customized Diffusion Models</h3>
</div>
<p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
<a href="https://snap-research.github.io/weights2weights/" target="_blank">project page</a> | <a href="https://arxiv.org/abs/2406.09413" target="_blank">paper</a>
|
<a href="https://huggingface.co./spaces/Snapchat/w2w-demo?duplicate=true" target="_blank" style="
display: inline-block;
">
<img style="margin-top: -1em;margin-bottom: 0em;position: absolute;" src="https://bit.ly/3CWLGkA" alt="Duplicate Space"></a>
</p>
"""
# Gradio UI. Layout nesting is reconstructed; the component definitions and the
# event wiring at the bottom are what determine behavior.
with gr.Blocks(css="style.css") as demo:
    gr.HTML(intro)
    gr.Markdown("""<div style="text-align: justify;"> In this demo, you can get an identity-encoding model by sampling or inverting. To use a model previously downloaded from this demo see \"Uploading a model\" in the Advanced options. Next, you can generate new samples from it, or edit the identity encoded in the model and generate samples from the edited model. We provide detailed instructions and tips at the bottom of the page.""")
    with gr.Column():
        with gr.Row():
            # Step 1: obtain a model (sample or invert).
            with gr.Column():
                gr.Markdown("""1) Either sample a new model, or upload an image (optionally draw a mask over the face) and click `invert`. """)
                input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload image and draw to define mask",
                                       width=512, height=512, brush_color='#00FFFF', mask_opacity=0.4)
                with gr.Row():
                    sample = gr.Button("🎲 Sample New Model")
                    invert_button = gr.Button("⬆️ Invert")

            # Step 2: generate / edit with the current model.
            with gr.Column():
                gr.Markdown("""2) Generate images of the sampled/inverted identity or edit the identity and generate new images. """)
                gallery = gr.Image(label="Image", height=512, width=512, interactive=False)
                submit = gr.Button("Generate")
                prompt = gr.Textbox(label="Prompt",
                                    info="Make sure to include 'sks person'",
                                    placeholder="sks person",
                                    value="sks person")
                seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)

                # Editing sliders (a1..a4 feed edit_inference).
                with gr.Column():
                    with gr.Row():
                        a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                        a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                    with gr.Row():
                        a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                        a4 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)

        with gr.Accordion("Advanced Options", open=False):
            with gr.Tab("Inversion"):
                with gr.Row():
                    lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
                    pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
                with gr.Row():
                    epochs = gr.Slider(label="Epochs", value=800, step=1, minimum=1, maximum=2000, interactive=True)
                    weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
            with gr.Tab("Sampling"):
                with gr.Row():
                    cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
                    # At least one denoising step is required by the scheduler.
                    steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=1, maximum=100, interactive=True)
                with gr.Row():
                    negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
                    injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
            with gr.Tab("Uploading a model"):
                gr.Markdown("""<div style="text-align: justify;">Upload a model below downloaded from this demo.""")
                # Must be interactive, otherwise users cannot upload a file here.
                file_input = gr.File(label="Upload Model", container=True, interactive=True)

    gr.Markdown("""<div style="text-align: justify;"> After sampling a new model or inverting, you can download the model below.""")
    with gr.Row():
        file_output = gr.File(label="Download Sampled/Inverted Model", container=True, interactive=False)

    invert_button.click(fn=run_inversion,
                        inputs=[input_image, pcs, epochs, weight_decay, lr],
                        outputs=[gallery, file_output])
    sample.click(fn=sample_then_run, outputs=[gallery, file_output])
    submit.click(
        fn=edit_inference, inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4], outputs=[gallery]
    )
    # Show the uploaded model's preview in the output image, like sample/invert do.
    file_input.change(fn=file_upload, inputs=file_input, outputs=gallery)

demo.queue().launch()