{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[2024-06-28 00:45:26,702] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" ] } ],
"source": [
"import gradio as gr\n",
"import sys\n",
"import os\n",
"import tqdm\n",
"# Make the parent directory importable so the local utils / sampling modules resolve.\n",
"sys.path.append(os.path.abspath(os.path.join(\"\", \"..\")))\n",
"import torch\n",
"import gc\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"from PIL import Image\n",
"from utils import load_models, save_model_w2w, save_model_for_diffusers\n",
"from sampling import sample_weights"
] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [],
"source": [
"# Runtime configuration. Names assigned at notebook top level are already\n",
"# module globals, so no `global` declarations are needed here.\n",
"device = \"cuda:0\"\n",
"# Reseeded on every inference() call, so no seed is fixed here.\n",
"generator = torch.Generator(device=device)"
] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [],
"source": [
"# Tensors passed to sample_weights() in sample_model(), cast to bfloat16 on `device`.\n",
"# map_location=cpu avoids depending on the device the files were saved from.\n",
"# NOTE(review): torch.load unpickles its input -- only load trusted local artifacts.\n",
"mean = torch.load(\"files/mean.pt\", map_location=\"cpu\").bfloat16().to(device)\n",
"std = torch.load(\"files/std.pt\", map_location=\"cpu\").bfloat16().to(device)\n",
"v = torch.load(\"files/V.pt\", map_location=\"cpu\").bfloat16().to(device)\n",
"proj = torch.load(\"files/proj_1000pc.pt\", map_location=\"cpu\").bfloat16().to(device)\n",
"df = torch.load(\"files/identity_df.pt\")\n",
"weight_dimensions = torch.load(\"files/weight_dimensions.pt\")"
] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 10.79it/s]\n" ] }, {
"name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [],
"source": [
"# Assigned by sample_model(); inference() refuses to run while it is None.\n",
"network = None"
] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [],
"source": [
"def sample_model():\n",
"    \"\"\"Reload a fresh UNet and sample a new set of identity weights for it.\n",
"\n",
"    Rebinds the module-level `unet` and `network` that inference() reads.\n",
"    \"\"\"\n",
"    global unet\n",
"    global network\n",
"    del unet\n",
"    # NOTE(review): `network` presumably holds references into the old UNet;\n",
"    # releasing it and collecting lets that memory be reclaimed before the\n",
"    # reload -- confirm against sample_weights.\n",
"    network = None\n",
"    gc.collect()\n",
"    torch.cuda.empty_cache()\n",
"    # Only the UNet is kept; the other components returned by load_models are discarded.\n",
"    unet, _, _, _, _ = load_models(device)\n",
"    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)\n"
] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [],
"source": [
"@torch.no_grad()\n",
"def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):\n",
"    \"\"\"Generate one image with the currently sampled identity network.\n",
"\n",
"    Runs classifier-free guidance over the module-level `unet` (used inside the\n",
"    `network` context that sample_model() sets) and decodes the final latents with `vae`.\n",
"\n",
"    Args:\n",
"        prompt: text for the conditional branch.\n",
"        negative_prompt: text for the unconditional (negative) branch.\n",
"        guidance_scale: classifier-free guidance weight.\n",
"        ddim_steps: number of denoising steps passed to the scheduler.\n",
"        seed: RNG seed for the initial latent noise.\n",
"\n",
"    Returns:\n",
"        A one-element list holding the generated PIL image.\n",
"\n",
"    Raises:\n",
"        gr.Error: if no network has been sampled yet.\n",
"    \"\"\"\n",
"    global generator\n",
"    if network is None:\n",
"        raise gr.Error(\"Sample a model before generating images.\")\n",
"    # UI number widgets may deliver floats; Generator.manual_seed only accepts ints.\n",
"    generator = generator.manual_seed(int(seed))\n",
"    latents = torch.randn(\n",
"        (1, unet.config.in_channels, 512 // 8, 512 // 8),\n",
"        generator = generator,\n",
"        device = device\n",
"    ).bfloat16()\n",
"\n",
"    text_input = tokenizer(prompt, padding=\"max_length\", max_length=tokenizer.model_max_length, truncation=True, return_tensors=\"pt\")\n",
"    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]\n",
"\n",
"    # Truncate the negative prompt as well: without it an over-long negative prompt\n",
"    # produces more than max_length tokens and no longer lines up with text_embeddings.\n",
"    max_length = text_input.input_ids.shape[-1]\n",
"    uncond_input = tokenizer(\n",
"        [negative_prompt], padding=\"max_length\", max_length=max_length, truncation=True, return_tensors=\"pt\"\n",
"    )\n",
"    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]\n",
"    # Batch order [unconditional, conditional] must match the chunk(2) split below.\n",
"    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n",
"    noise_scheduler.set_timesteps(int(ddim_steps))\n",
"    latents = latents * noise_scheduler.init_noise_sigma\n",
"\n",
"    for t in tqdm.tqdm(noise_scheduler.timesteps):\n",
"        # Duplicate the latents so one UNet call covers both guidance branches.\n",
"        latent_model_input = torch.cat([latents] * 2)\n",
"        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)\n",
"        # NOTE(review): `network` presumably applies the sampled weights to the UNet\n",
"        # only inside this context -- confirm against sample_weights.\n",
"        with network:\n",
"            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample\n",
"        # classifier-free guidance\n",
"        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
"        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n",
"        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample\n",
"\n",
"    # NOTE(review): 0.18215 looks like the Stable Diffusion VAE latent scaling factor;\n",
"    # confirm it matches vae.config.scaling_factor for the VAE load_models returns.\n",
"    latents = 1 / 0.18215 * latents\n",
"    image = vae.decode(latents).sample\n",
"    image = (image / 2 + 0.5).clamp(0, 1)\n",
"    # numpy has no bfloat16 dtype, so cast to float32 before converting.\n",
"    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]\n",
"\n",
"    image = Image.fromarray((image * 255).round().astype(\"uint8\"))\n",
"\n",
"    return [image]"
] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7860\n", "Running on public URL: https://bc89b27b9704787832.gradio.live\n", "\n", "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co./spaces)\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "