import os

# Pin the Gradio version this demo was built against before importing it.
os.system("pip uninstall -y gradio")
os.system("pip install gradio==3.43.1")

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import gradio as gr
import sys
import tqdm

sys.path.append(os.path.abspath(os.path.join("", "..")))

import gc
import warnings
warnings.filterwarnings("ignore")

from PIL import Image
import numpy as np
from utils import load_models
from editing import get_direction, debias
from sampling import sample_weights
from lora_w2w import LoRAw2w
from huggingface_hub import snapshot_download
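
# Handles shared and mutated by the Gradio callbacks below.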
global device
global generator
global unet
global vae
global text_encoder
global tokenizer
global noise_scheduler
global network

device = "cuda:0"
generator = torch.Generator(device=device)

# Download the released w2w assets: weight-space statistics, PCA basis,
# projections, and the identity/attribute dataframe.
models_path = snapshot_download(repo_id="Snapchat/w2w")
mean = torch.load(f"{models_path}/files/mean.pt").bfloat16().to(device)
std = torch.load(f"{models_path}/files/std.pt").bfloat16().to(device)
v = torch.load(f"{models_path}/files/V.pt").bfloat16().to(device)
proj = torch.load(f"{models_path}/files/proj_1000pc.pt").bfloat16().to(device)
df = torch.load(f"{models_path}/files/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(device)

unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
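
# Sample a new identity: reload a clean UNet, then draw LoRA weights from the
# learned w2w distribution over the first 1000 principal components.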
def sample_model():
    global unet
    del unet
    global network
    unet, _, _, _, _ = load_models(device)
    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)
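
# Plain text-to-image sampling with classifier-free guidance. The `with network:`
# context applies the current identity's LoRA weights to the UNet at each
# denoising step; 1/0.18215 undoes the Stable Diffusion latent scaling factor
# before VAE decoding.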
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    global device
    global generator
    global unet
    global vae
    global text_encoder
    global tokenizer
    global noise_scheduler

    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
        # guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))
    return image
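
# Like `inference`, but shifts the identity's PC coefficients along the
# precomputed attribute directions (scaled by the slider values a1..a4) and
# injects the edited weights only for timesteps at or below `start_noise`.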
def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
    global device
    global generator
    global unet
    global vae
    global text_encoder
    global tokenizer
    global noise_scheduler
    global young
    global pointy
    global wavy
    global large

    original_weights = network.proj.clone()

    # pad the edit directions to the model's number of PCs
    pcs_original = original_weights.shape[1]
    pcs_edits = young.shape[1]
    padding = torch.zeros((1, pcs_original - pcs_edits)).to(device)
    young_pad = torch.cat((young, padding), 1)
    pointy_pad = torch.cat((pointy, padding), 1)
    wavy_pad = torch.cat((wavy, padding), 1)
    large_pad = torch.cat((large, padding), 1)

    edited_weights = original_weights + a1 * 1e6 * young_pad + a2 * 1e6 * pointy_pad + a3 * 1e6 * wavy_pad + a4 * 2e6 * large_pad

    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        # switch to the edited weights once the timestep falls to start_noise or below
        if t <= start_noise:
            network.proj = torch.nn.Parameter(edited_weights)
            network.reset()

        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
        # guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))

    # reset weights back to original
    network.proj = torch.nn.Parameter(original_weights)
    network.reset()
    return image
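
# Callback for the "Sample New Model" button: sample an identity, render a
# preview with default settings, and expose the coefficients as model.pt.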
def sample_then_run():
    sample_model()
    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)
    torch.save(network.proj, "model.pt")
    return image, "model.pt"
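
# Precompute semantic edit directions in PC-coefficient space from the identity
# dataframe, then debias each direction against correlated attributes so that,
# e.g., "Young" does not also shift gender or nose shape.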
global young
global pointy
global wavy
global large

young = get_direction(df, "Young", pinverse, 1000, device)
young = debias(young, "Male", df, pinverse, device)
young = debias(young, "Pointy_Nose", df, pinverse, device)
young = debias(young, "Wavy_Hair", df, pinverse, device)
young = debias(young, "Chubby", df, pinverse, device)

pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
pointy = debias(pointy, "Young", df, pinverse, device)
pointy = debias(pointy, "Male", df, pinverse, device)
pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
pointy = debias(pointy, "Chubby", df, pinverse, device)
pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)

wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
wavy = debias(wavy, "Young", df, pinverse, device)
wavy = debias(wavy, "Male", df, pinverse, device)
wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
wavy = debias(wavy, "Chubby", df, pinverse, device)
wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)

large = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device)
large = debias(large, "Male", df, pinverse, device)
large = debias(large, "Young", df, pinverse, device)
large = debias(large, "Pointy_Nose", df, pinverse, device)
large = debias(large, "Wavy_Hair", df, pinverse, device)
large = debias(large, "Mustache", df, pinverse, device)
large = debias(large, "No_Beard", df, pinverse, device)
large = debias(large, "Sideburns", df, pinverse, device)
large = debias(large, "Big_Nose", df, pinverse, device)
large = debias(large, "Big_Lips", df, pinverse, device)
large = debias(large, "Black_Hair", df, pinverse, device)
large = debias(large, "Brown_Hair", df, pinverse, device)
large = debias(large, "Pale_Skin", df, pinverse, device)
large = debias(large, "Heavy_Makeup", df, pinverse, device)
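
# Minimal in-memory dataset wrapping the uploaded image(s) for inversion.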
class CustomImageDataset(Dataset):
    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        return image
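
# Invert an image into w2w space: rebuild a clean UNet, then optimize only the
# LoRA network's PC coefficients (rank 1, train_method="xattn-strict") with the
# standard denoising MSE loss, optionally restricted to the drawn face mask.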
def invert(images, mask, pcs=10000, epochs=400, weight_decay=1e-10, lr=1e-1):
    global unet
    del unet
    global network
    unet, _, _, _, _ = load_models(device)

    proj = torch.zeros(1, pcs).bfloat16().to(device)
    network = LoRAw2w(proj, mean, std, v[:, :pcs],
                      unet,
                      rank=1,
                      multiplier=1.0,
                      alpha=27.0,
                      train_method="xattn-strict"
                      ).to(device, torch.bfloat16)

    ### load mask
    mask = transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
    mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:, 0, :, :].unsqueeze(1)

    ### check if an actual mask was drawn, otherwise mask is just all ones
    if torch.sum(mask) == 0:
        mask = torch.ones((1, 1, 64, 64)).to(device).bfloat16()

    ### single image dataset
    image_transforms = transforms.Compose([transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
                                           transforms.RandomCrop(512),
                                           transforms.ToTensor(),
                                           transforms.Normalize([0.5], [0.5])])
    train_dataset = CustomImageDataset(images, transform=image_transforms)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

    ### optimizer
    optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)

    ### training loop
    unet.train()
    for epoch in tqdm.tqdm(range(epochs)):
        for batch in train_dataloader:
            ### prepare inputs
            batch = batch.to(device).bfloat16()
            latents = vae.encode(batch).latent_dist.sample()
            latents = latents * 0.18215
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]

            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            text_input = tokenizer("sks person", padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
            text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

            ### loss + sgd step
            with network:
                model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
                loss = torch.nn.functional.mse_loss(mask * model_pred.float(), mask * noise.float(), reduction="mean")
                optim.zero_grad()
                loss.backward()
                optim.step()

    ### return optimized network
    return network
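
# Callback for the "Invert" button: pull the image and sketched mask out of the
# gr.Image dict, run the optimization, then render a preview and save model.pt.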
def run_inversion(input_dict, pcs, epochs, weight_decay, lr):
    global network
    init_image = input_dict["image"].convert("RGB").resize((512, 512))
    mask = input_dict["mask"].convert("RGB").resize((512, 512))
    network = invert([init_image], mask, pcs, epochs, weight_decay, lr)

    # sample an image with the inverted identity
    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)
    torch.save(network.proj, "model.pt")
    return image, "model.pt"
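
# Callback for model uploads: load saved PC coefficients, zero-pad them to the
# full 10000-PC basis, rebuild the LoRA network, and render a preview image.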
def file_upload(file):
    global unet
    del unet
    global network
    global device

    proj = torch.load(file.name).to(device)

    # pad to 10000 principal components to keep everything consistent
    pcs = proj.shape[1]
    padding = torch.zeros((1, 10000 - pcs)).to(device)
    proj = torch.cat((proj, padding), 1)

    unet, _, _, _, _ = load_models(device)
    # slice the full 10000-PC basis so it matches the padded projection
    network = LoRAw2w(proj, mean, std, v[:, :10000],
                      unet,
                      rank=1,
                      multiplier=1.0,
                      alpha=27.0,
                      train_method="xattn-strict"
                      ).to(device, torch.bfloat16)

    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)
    return image
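
# Demo header: title plus links to the project page, the paper, and a
# "Duplicate Space" badge.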
intro = """
<div style="display: flex;align-items: center;justify-content: center">
    <h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">weights2weights</h1>
    <h3 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Interpreting the Weight Space of Customized Diffusion Models</h3>
</div>
<p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
    <a href="https://snap-research.github.io/weights2weights/" target="_blank">project page</a> | <a href="https://arxiv.org/abs/2406.09413" target="_blank">paper</a>
    |
    <a href="https://huggingface.co./spaces/Snapchat/w2w-demo?duplicate=true" target="_blank" style="display: inline-block;">
        <img style="margin-top: -1em;margin-bottom: 0em;position: absolute;" src="https://bit.ly/3CWLGkA" alt="Duplicate Space"></a>
</p>
"""
with gr.Blocks(css="style.css") as demo:
    gr.HTML(intro)

    gr.Markdown("""<div style="text-align: justify;">In this demo, you can get an identity-encoding model by sampling or inverting. To use a model previously downloaded from this demo, see "Uploading a model" in the Advanced Options. You can then generate new samples from the model, or edit the identity it encodes and generate samples from the edited model. We provide detailed instructions and tips at the bottom of the page.</div>""")

    with gr.Column():
        with gr.Row():
            with gr.Column():
                gr.Markdown("""1) Either sample a new model, or upload an image (optionally draw a mask over the face) and click `invert`.""")
                input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload image and draw to define mask",
                                       width=512, height=512, brush_color='#00FFFF', mask_opacity=0.4)
                with gr.Row():
                    sample = gr.Button("🎲 Sample New Model")
                    invert_button = gr.Button("⬆️ Invert")

            with gr.Column():
                gr.Markdown("""2) Generate images of the sampled/inverted identity, or edit the identity and generate new images.""")
                gallery = gr.Image(label="Image", height=512, width=512, interactive=False)
                submit = gr.Button("Generate")

                prompt = gr.Textbox(label="Prompt",
                                    info="Make sure to include 'sks person'",
                                    placeholder="sks person",
                                    value="sks person")
                seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)

                # Editing sliders
                with gr.Column():
                    with gr.Row():
                        a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                        a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                    with gr.Row():
                        a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                        a4 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
                with gr.Accordion("Advanced Options", open=False):
                    with gr.Tab("Inversion"):
                        with gr.Row():
                            lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
                            pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
                        with gr.Row():
                            epochs = gr.Slider(label="Epochs", value=800, step=1, minimum=1, maximum=2000, interactive=True)
                            weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
                    with gr.Tab("Sampling"):
                        with gr.Row():
                            cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
                            steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
                        with gr.Row():
                            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
                            injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
                    with gr.Tab("Uploading a model"):
                        gr.Markdown("""<div style="text-align: justify;">Upload a model downloaded from this demo below.</div>""")
                        # must be interactive so users can actually upload a file
                        file_input = gr.File(label="Upload Model", container=True, interactive=True)

                gr.Markdown("""<div style="text-align: justify;">After sampling a new model or inverting, you can download the model below.</div>""")
                with gr.Row():
                    file_output = gr.File(label="Download Sampled/Inverted Model", container=True, interactive=False)
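
    # Wire the buttons and the upload widget to their callbacks.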
    invert_button.click(fn=run_inversion,
                        inputs=[input_image, pcs, epochs, weight_decay, lr],
                        outputs=[gallery, file_output])

    sample.click(fn=sample_then_run, outputs=[gallery, file_output])

    submit.click(fn=edit_inference,
                 inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4],
                 outputs=[gallery])

    file_input.change(fn=file_upload, inputs=file_input, outputs=input_image)

demo.queue().launch()