{ "cells": [
 { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
  "# Imports: standard library, third-party, then project-local modules.\n",
  "import gc\n",
  "import os\n",
  "import sys\n",
  "import warnings\n",
  "\n",
  "import gradio as gr\n",
  "import torch\n",
  "import tqdm\n",
  "from PIL import Image\n",
  "\n",
  "# Make the parent directory importable so the project-local `utils` and `sampling` modules resolve.\n",
  "sys.path.append(os.path.abspath(os.path.join(\"\", \"..\")))\n",
  "warnings.filterwarnings(\"ignore\")\n",
  "\n",
  "from utils import load_models, save_model_w2w, save_model_for_diffusers\n",
  "from sampling import sample_weights"
 ] },
 { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
  "# Every model and tensor in this notebook is placed on this device.\n",
  "device = \"cuda:0\"\n",
  "# Shared RNG for the initial diffusion latents; inference() re-seeds it from the UI seed on every call.\n",
  "generator = torch.Generator(device=device)"
 ] },
 { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
  "# Weight-space tensors passed to sample_weights() when a new identity model is sampled.\n",
  "# NOTE(review): presumably a PCA basis (V, proj) plus normalisation stats (mean, std); confirm in sampling.py.\n",
  "# map_location=\"cpu\" lets files saved from any GPU load here; tensors are then cast and moved to `device`.\n",
  "mean = torch.load(\"files/mean.pt\", map_location=\"cpu\").bfloat16().to(device)\n",
  "std = torch.load(\"files/std.pt\", map_location=\"cpu\").bfloat16().to(device)\n",
  "v = torch.load(\"files/V.pt\", map_location=\"cpu\").bfloat16().to(device)\n",
  "proj = torch.load(\"files/proj_1000pc.pt\", map_location=\"cpu\").bfloat16().to(device)\n",
  "# NOTE(review): `df` and `weight_dimensions` are not used by any cell in this notebook.\n",
  "df = torch.load(\"files/identity_df.pt\")\n",
  "weight_dimensions = torch.load(\"files/weight_dimensions.pt\")"
 ] },
 { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 10.79it/s]\n" ] }, {
"name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "global network" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def sample_model():\n", " global unet\n", " del unet\n", " global network\n", " unet, _, _, _, _ = load_models(device)\n", " network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)\n", " \n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "@torch.no_grad()\n", "def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):\n", " global device\n", " global generator \n", " global unet\n", " global vae \n", " global text_encoder\n", " global tokenizer\n", " global noise_scheduler\n", " generator = generator.manual_seed(seed)\n", " latents = torch.randn(\n", " (1, unet.in_channels, 512 // 8, 512 // 8),\n", " generator = generator,\n", " device = device\n", " ).bfloat16()\n", " \n", "\n", " text_input = tokenizer(prompt, padding=\"max_length\", max_length=tokenizer.model_max_length, truncation=True, return_tensors=\"pt\")\n", "\n", " text_embeddings = text_encoder(text_input.input_ids.to(device))[0]\n", "\n", " max_length = text_input.input_ids.shape[-1]\n", " uncond_input = tokenizer(\n", " [negative_prompt], padding=\"max_length\", max_length=max_length, return_tensors=\"pt\"\n", " )\n", " uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]\n", " text_embeddings = torch.cat([uncond_embeddings, text_embeddings])\n", " noise_scheduler.set_timesteps(ddim_steps) \n", " latents = latents * noise_scheduler.init_noise_sigma\n", " \n", " for i,t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):\n", " latent_model_input = torch.cat([latents] * 2)\n", " latent_model_input = 
noise_scheduler.scale_model_input(latent_model_input, timestep=t)\n", " with network:\n", " noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample\n", " #guidance\n", " noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n", " noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n", " latents = noise_scheduler.step(noise_pred, t, latents).prev_sample\n", " \n", " latents = 1 / 0.18215 * latents\n", " image = vae.decode(latents).sample\n", " image = (image / 2 + 0.5).clamp(0, 1)\n", " image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]\n", "\n", " image = Image.fromarray((image * 255).round().astype(\"uint8\"))\n", "\n", " return [image] " ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7860\n", "Running on public URL: https://bc89b27b9704787832.gradio.live\n", "\n", "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co./spaces)\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "name": "stderr", "output_type": "stream", "text": [ "Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 8.95it/s]\n", "Traceback (most recent call last):\n", " File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/gradio/routes.py\", line 437, in run_predict\n", " output = await app.get_blocks().process_api(\n", " File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/gradio/blocks.py\", line 1352, in process_api\n", " result = await self.call_function(\n", " File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/gradio/blocks.py\", line 1077, in call_function\n", " prediction = await anyio.to_thread.run_sync(\n", " File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/anyio/to_thread.py\", line 56, in run_sync\n", " return await get_async_backend().run_sync_in_worker_thread(\n", " File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/anyio/_backends/_asyncio.py\", line 2134, in run_sync_in_worker_thread\n", " return await future\n", " File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/anyio/_backends/_asyncio.py\", line 851, in run\n", " result = context.run(func, *args)\n", " File \"/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/torch/utils/_contextlib.py\", line 115, in decorate_context\n", " return func(*args, **kwargs)\n", " File \"/tmp/ipykernel_2844069/1186401021.py\", line 12, in inference\n", " (1, unet.in_channels, 512 // 8, 512 // 8),\n", "NameError: name 'unet' is not defined\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "css = ''\n", "with gr.Blocks(css=css) as demo:\n", " gr.Markdown(\"# weights2weights Demo\")\n", " gr.Markdown(\"Demo for the [h94/IP-Adapter-FaceID 
model](https://huggingface.co./h94/IP-Adapter-FaceID) - Generate AI images with your own face - Non-commercial license\")\n", " with gr.Row():\n", " with gr.Column():\n", " files = gr.Files(\n", " label=\"Upload a photo of your face to invert, or sample a new model\",\n", " file_types=[\"image\"]\n", " )\n", " uploaded_files = gr.Gallery(label=\"Your images\", visible=False, columns=5, rows=1, height=125)\n", "\n", " sample = gr.Button(\"Sample New Model\")\n", "\n", " with gr.Column(visible=False) as clear_button:\n", " remove_and_reupload = gr.ClearButton(value=\"Remove and upload new ones\", components=files, size=\"sm\")\n", " prompt = gr.Textbox(label=\"Prompt\",\n", " info=\"Make sure to include 'sks person'\" ,\n", " placeholder=\"sks person\", \n", " value=\"sks person\")\n", " negative_prompt = gr.Textbox(label=\"Negative Prompt\", placeholder=\"low quality, blurry, unfinished, cartoon\", value=\"low quality, blurry, unfinished, cartoon\")\n", " seed = gr.Number(value=5, precision=0, label=\"Seed\", interactive=True)\n", " cfg = gr.Slider(label=\"CFG\", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)\n", " steps = gr.Slider(label=\"Inference Steps\", precision=0, value=50, step=1, minimum=0, maximum=100, interactive=True)\n", "\n", "\n", " submit = gr.Button(\"Submit\")\n", "\n", " with gr.Column():\n", " gallery = gr.Gallery(label=\"Generated Images\")\n", "\n", " sample.click(fn=sample_model)\n", " \n", " submit.click(fn=inference,\n", " inputs=[prompt, negative_prompt, cfg, steps, seed],\n", " outputs=gallery)\n", " \n", "\n", "\n", "\n", " \n", " \n", "demo.launch(share=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "dblora2", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", 
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.18" } }, "nbformat": 4, "nbformat_minor": 2 }