# weights2weights / app.py
import gradio as gr
import sys
import os
import tqdm
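# Add the parent directory to the import path so the repo's local modules
# (utils, sampling, editing) can be imported below.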
sys.path.append(os.path.abspath(os.path.join("", "..")))
import torch
import gc
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
from utils import load_models, save_model_w2w, save_model_for_diffusers
from sampling import sample_weights
from editing import get_direction, debias
from huggingface_hub import snapshot_download
device = "cuda:0"
generator = torch.Generator(device=device)
models_path = snapshot_download(repo_id="Snapchat/w2w")
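# Load the precomputed weights2weights (w2w) space assets from the Snapchat/w2w repo:
# the mean/std of the flattened model weights, the principal directions V, per-identity
# projections onto the first 1000 principal components, an attribute dataframe for the
# training identities, and a pseudo-inverse used when computing edit directions.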
mean = torch.load(f"{models_path}/mean.pt").bfloat16().to(device)
std = torch.load(f"{models_path}/std.pt").bfloat16().to(device)
v = torch.load(f"{models_path}/V.pt").bfloat16().to(device)
proj = torch.load(f"{models_path}/proj_1000pc.pt").bfloat16().to(device)
df = torch.load(f"{models_path}/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/weight_dimensions.pt")
pinverse = torch.load(f"{models_path}/pinverse_1000pc.pt").bfloat16().to(device)
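# Load the base diffusion model components (UNet, VAE, text encoder, tokenizer, scheduler).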
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
def sample_model():
    global unet
    global network
    # Reload a fresh base UNet, then sample a new identity's weights from the w2w prior.
    del unet
    unet, _, _, _, _ = load_models(device)
    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor=1.00)
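# Text-to-image sampling with the currently sampled identity. This is a standard
# diffusers-style denoising loop with classifier-free guidance; the identity's
# weights are applied to the UNet through the `network` context manager.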
@torch.no_grad()
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    global device
    global generator
    global unet
    global vae
    global text_encoder
    global tokenizer
    global noise_scheduler

    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    # Encode the prompt and the negative prompt for classifier-free guidance.
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        # Predict noise with the sampled identity weights applied via the network context.
        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample

        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # Decode the latents to an image (0.18215 is the SD latent scaling factor).
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))
    return [image]
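# Same denoising loop as `inference`, but partway through sampling the identity
# weights are shifted along the Young / Pointy Nose / Undereye Bags directions,
# scaled by the slider values a1, a2, a3. `start_noise` controls the timestep at
# which the edited weights are injected.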
@torch.no_grad()
def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3):
    global device
    global generator
    global unet
    global vae
    global text_encoder
    global tokenizer
    global noise_scheduler
    global young
    global pointy
    global bags

    # Shift the sampled identity's weights along the attribute directions by the slider amounts.
    original_weights = network.proj.clone()
    edited_weights = original_weights + a1 * young + a2 * pointy + a3 * bags

    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    # Encode the prompt and the negative prompt for classifier-free guidance.
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        # Keep the original weights at high-noise timesteps; inject the edited weights
        # once the timestep drops to start_noise or below.
        if t <= start_noise:
            network.proj = torch.nn.Parameter(edited_weights)
            network.reset()

        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample

        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))

    # Reset the network weights back to the original sampled identity.
    network.proj = torch.nn.Parameter(original_weights)
    network.reset()
    return [image]
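# Callback for the "Sample New Model" button: draw a new identity from the w2w
# prior and render a preview image with default settings.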
def sample_then_run():
    global young_val
    global pointy_val
    global bags_val
    global young
    global pointy
    global bags

    # Sample a new identity model, then record its coordinate along each editing direction.
    sample_model()
    young_val = network.proj @ young[0] / (torch.norm(young) ** 2).item()
    pointy_val = network.proj @ pointy[0] / (torch.norm(pointy) ** 2).item()
    bags_val = network.proj @ bags[0] / (torch.norm(bags) ** 2).item()

    # Generate a preview image of the sampled identity with default settings.
    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, cartoon"
    seed = 5
    cfg = 3.0
    steps = 50
    image = inference(prompt, negative_prompt, cfg, steps, seed)
    return image
# Attribute editing directions in the w2w weight space.
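# Note: get_direction presumably fits a linear attribute direction in the w2w
# projection space, and debias appears to remove the component correlated with the
# given attribute (here "Male") so the "Young" edit is less entangled with gender.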
young = get_direction(df, "Young", pinverse, 1000, device)
young = debias(young, "Male", df, pinverse, device)
young_max = torch.max(proj @ young[0] / (torch.norm(young) ** 2)).item()
young_min = torch.min(proj @ young[0] / (torch.norm(young) ** 2)).item()

pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
pointy_max = torch.max(proj @ pointy[0] / (torch.norm(pointy) ** 2)).item()
pointy_min = torch.min(proj @ pointy[0] / (torch.norm(pointy) ** 2)).item()

bags = get_direction(df, "Bags_Under_Eyes", pinverse, 1000, device)
bags_max = torch.max(proj @ bags[0] / (torch.norm(bags) ** 2)).item()
bags_min = torch.min(proj @ bags[0] / (torch.norm(bags) ** 2)).item()
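# Note: the *_max/*_min ranges computed above are not referenced by the UI below;
# the attribute sliders use fixed bounds instead.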
css = ''
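# Gradio UI: the left column samples a new identity model and previews it, the middle
# column holds the generation and edit controls, and the right column shows the edited result.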
with gr.Blocks(css=css) as demo:
gr.Markdown("# <em>weights2weights</em> Demo")
gr.Markdown("Demo for the [h94/IP-Adapter-FaceID model](https://huggingface.co./h94/IP-Adapter-FaceID) - Generate AI images with your own face - Non-commercial license")
    with gr.Row():
        with gr.Column():
            sample = gr.Button("Sample New Model")
            gallery1 = gr.Gallery(label="Identity from Sampled Model")
        with gr.Column():
            prompt = gr.Textbox(label="Prompt",
                                info="Make sure to include 'sks person'",
                                placeholder="sks person",
                                value="sks person")
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
            seed = gr.Number(value=5, label="Seed", interactive=True)
            cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
            steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
            injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
            with gr.Row():
                a1 = gr.Slider(label="Young", value=0, step=1, minimum=-1000000, maximum=1000000, interactive=True)
                a2 = gr.Slider(label="Pointy Nose", value=0, step=1, minimum=-1000000, maximum=1000000, interactive=True)
                a3 = gr.Slider(label="Undereye Bags", value=0, step=1, minimum=-1000000, maximum=1000000, interactive=True)
            submit = gr.Button("Submit")
        with gr.Column():
            gallery2 = gr.Gallery(label="Identity from Edited Model")
    sample.click(fn=sample_then_run, outputs=gallery1)
    submit.click(fn=edit_inference,
                 inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3],
                 outputs=gallery2)
demo.launch(share=True)