# weights2weights / app.py
# Author: amildravid4292 — commit 12b3d57 ("Update app.py"), 4.01 kB.
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import gradio as gr
import sys
import tqdm
import uuid
sys.path.append(os.path.abspath(os.path.join("", "..")))
import gc
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
import numpy as np
from editing import get_direction, debias
from sampling import sample_weights
from lora_w2w import LoRAw2w
from transformers import CLIPTextModel
from lora_w2w import LoRAw2w
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
from transformers import AutoTokenizer, PretrainedConfig
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
UNet2DConditionModel,
PNDMScheduler,
StableDiffusionPipeline
)
from huggingface_hub import snapshot_download
import spaces
# Hugging Face Hub repository holding the precomputed weights2weights artifacts
# (mean/std, principal directions, projections, identity table).
W2W_REPO_ID = "Snapchat/w2w"
# Local directory of the downloaded repository snapshot; the artifact loaders
# below read files from its `files/` subfolder.
models_path = snapshot_download(repo_id=W2W_REPO_ID)
@spaces.GPU
def load_models(device):
    """Load the frozen Stable Diffusion components used by the demo.

    Every component comes from the ``stablediffusionapi/realistic-vision-v51``
    checkpoint on the Hugging Face Hub.

    Args:
        device: Target device (e.g. ``"cuda"``) for the UNet, VAE and text
            encoder.

    Returns:
        ``(unet, vae, text_encoder, tokenizer, noise_scheduler)``.
        - The UNet, VAE and text encoder have gradients disabled, are cast to
          ``torch.bfloat16`` and are moved to ``device``.
        - ``noise_scheduler`` is the scheduler bundled with the pipeline.
    """
    pretrained_model_name_or_path = "stablediffusionapi/realistic-vision-v51"
    revision = None
    weight_dtype = torch.bfloat16

    # The full pipeline is loaded only to obtain its configured scheduler.
    # Keep it on the CPU: copying every sub-model to `device` only to discard
    # it immediately wastes GPU memory and transfer time.
    pipe = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        revision=revision,
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    )
    noise_scheduler = pipe.scheduler
    del pipe
    # Free the discarded pipeline's weights before loading the parts we keep.
    gc.collect()

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
    )
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path, subfolder="vae", revision=revision
    )
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="unet", revision=revision
    )

    # Inference only: freeze each network, then cast it and move it to `device`.
    for model in (unet, vae, text_encoder):
        model.requires_grad_(False)
        model.to(device, dtype=weight_dtype)

    return unet, vae, text_encoder, tokenizer, noise_scheduler
device = "cuda"


def _load_to_device(filename):
    """Load ``files/<filename>`` from the snapshot on CPU, cast it to bfloat16 and move it to ``device``."""
    cpu = torch.device('cpu')
    loaded = torch.load(f"{models_path}/files/{filename}", map_location=cpu)
    return loaded.bfloat16().to(device)


# Precomputed weights2weights statistics, each cast to bfloat16 on `device`.
# NOTE(review): torch.load unpickles these files; they come from the Hub
# snapshot above, so only load them from a trusted repository.
mean = _load_to_device("mean.pt")
std = _load_to_device("std.pt")
v = _load_to_device("V.pt")
proj = _load_to_device("proj_1000pc.pt")
# Loaded as-is: no map_location, dtype cast or device move.
df = torch.load(f"{models_path}/files/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
pinverse = _load_to_device("pinverse_1000pc.pt")

# Frozen Stable Diffusion components shared by all requests.
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
@spaces.GPU
def sample_then_run(net=None, factor=1.0):
    """Draw a random point in the weights2weights PCA space and name a new model.

    The button's click handler passes the ``net`` ``gr.State`` as an input, so
    Gradio calls this function with one positional argument. ``net`` accepts
    that argument; its previous value is not used.

    Args:
        net: Previous value of the ``net`` state (ignored).
        factor: Scale applied to each component's standard deviation before
            sampling. ``1.0`` samples with the empirical spread of ``proj``.

    Returns:
        A fresh filename ``"model_XXXX.pt"``, where ``XXXX`` are the first four
        characters of a random UUID4.
    """
    # Mean and standard deviation of every principal component (column of `proj`).
    m = torch.mean(proj, 0)
    standev = torch.std(proj, 0)
    # One independent Gaussian draw per component, in a single vectorized call
    # instead of one tiny kernel launch per component. `proj` is stored in
    # bfloat16, so the draw is done in float32. The result shape is
    # (1, num_components) and it lives on the same device as `proj`.
    pc_sample = torch.normal(m.float(), factor * standev.float()).unsqueeze(0)
    # TODO(review): `pc_sample` is not consumed yet; presumably it should be
    # turned into LoRA weights (e.g. via the imported `sample_weights`) — confirm.
    model_name = "model_" + str(uuid.uuid4())[:4] + ".pt"
    return model_name
# Gradio UI: a single button that samples a new identity model and keeps its
# generated filename in per-session state.
with gr.Blocks(css="style.css") as demo:
    # Per-session holder for the most recently sampled model's filename.
    net = gr.State()
    with gr.Column():
        with gr.Row():
            with gr.Column():
                sample = gr.Button("🎲 Sample New Model")

    # The state is both the handler's input and its output.
    sample.click(fn=sample_then_run, inputs=[net], outputs=[net])

# Enable request queuing, then start the server.
demo.queue()
demo.launch()