#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
EchoMimic Gradio web UI: audio-driven talking-head video generation.

Loads the EchoMimic model stack (VAE, reference / denoising UNets, face
locator, Whisper-based audio encoder, MTCNN face detector), exposes a
Gradio form for a reference image + driving audio, and writes the
generated frames to an mp4 muxed with the input audio.

NOTE(review): the saved notebook output shows this cell previously failed
with `ModuleNotFoundError: No module named 'cv2'` — the kernel environment
needs `opencv-python` (and the other third-party deps) installed.
"""

import os
import random
import argparse
from datetime import datetime
from pathlib import Path

import cv2
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from omegaconf import OmegaConf
from PIL import Image
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d_echo import EchoUNet3DConditionModel
from src.models.whisper.audio2feature import load_audio_model
from src.pipelines.pipeline_echo_mimic import Audio2VideoPipeline
from src.utils.util import save_videos_grid, crop_and_pad
from src.models.face_locator import FaceLocator
from moviepy.editor import VideoFileClip, AudioFileClip
from facenet_pytorch import MTCNN

import gradio as gr

# Default values for every tunable exposed in the "Configuration" accordion.
default_values = {
    "width": 512,
    "height": 512,
    "length": 1200,
    "seed": 420,
    "facemask_dilation_ratio": 0.1,
    "facecrop_dilation_ratio": 0.5,
    "context_frames": 12,
    "context_overlap": 3,
    "cfg": 2.5,
    "steps": 30,
    "sample_rate": 16000,
    "fps": 24,
    "device": "cuda",
}

# Make a static ffmpeg build (needed by moviepy) visible on PATH.
# BUGFIX: os.getenv('PATH') can be None (TypeError on `in`), and a raw
# substring test is not a path-entry test — compare against the split
# entries and join with os.pathsep instead.
ffmpeg_path = os.getenv("FFMPEG_PATH")
path_env = os.environ.get("PATH", "")
if ffmpeg_path is None:
    print("please download ffmpeg-static and export to FFMPEG_PATH. \nFor example: export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static")
elif ffmpeg_path not in path_env.split(os.pathsep):
    print("add ffmpeg to path")
    os.environ["PATH"] = f"{ffmpeg_path}{os.pathsep}{path_env}"


config_path = "./configs/prompts/animation.yaml"
config = OmegaConf.load(config_path)
weight_dtype = torch.float16 if config.weight_dtype == "fp16" else torch.float32

# BUGFIX: the original fell back to "cpu" here but still hard-coded "cuda"
# in several .to(...) calls below, crashing on CPU-only machines. Use
# `device` consistently everywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"

inference_config_path = config.inference_config
infer_config = OmegaConf.load(inference_config_path)

############# model_init started #############

## vae init
vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path).to(device, dtype=weight_dtype)

## reference net init
reference_unet = UNet2DConditionModel.from_pretrained(
    config.pretrained_base_model_path,
    subfolder="unet",
).to(dtype=weight_dtype, device=device)
reference_unet.load_state_dict(torch.load(config.reference_unet_path, map_location="cpu"))

## denoising net init
if os.path.exists(config.motion_module_path):
    ### stage1 + stage2: motion module weights are available
    denoising_unet = EchoUNet3DConditionModel.from_pretrained_2d(
        config.pretrained_base_model_path,
        config.motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    ).to(dtype=weight_dtype, device=device)
else:
    ### only stage1: no motion module, disable temporal attention
    denoising_unet = EchoUNet3DConditionModel.from_pretrained_2d(
        config.pretrained_base_model_path,
        "",
        subfolder="unet",
        unet_additional_kwargs={
            "use_motion_module": False,
            "unet_use_temporal_attention": False,
            "cross_attention_dim": infer_config.unet_additional_kwargs.cross_attention_dim,
        },
    ).to(dtype=weight_dtype, device=device)

# strict=False: stage-1-only checkpoints lack motion-module keys.
denoising_unet.load_state_dict(torch.load(config.denoising_unet_path, map_location="cpu"), strict=False)

## face locator init
face_locator = FaceLocator(320, conditioning_channels=1, block_out_channels=(16, 32, 96, 256)).to(dtype=weight_dtype, device=device)
face_locator.load_state_dict(torch.load(config.face_locator_path))

## load audio processor params
audio_processor = load_audio_model(model_path=config.audio_model_path, device=device)

## load face detector params
face_detector = MTCNN(image_size=320, margin=0, min_face_size=20, thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True, device=device)

############# model_init finished #############

scheduler = DDIMScheduler(**OmegaConf.to_container(infer_config.noise_scheduler_kwargs))

pipe = Audio2VideoPipeline(
    vae=vae,
    reference_unet=reference_unet,
    denoising_unet=denoising_unet,
    audio_guider=audio_processor,
    face_locator=face_locator,
    scheduler=scheduler,
).to(device, dtype=weight_dtype)


def select_face(det_bboxes, probs):
    """Pick the largest face box whose detection prob exceeds 0.8.

    Args:
        det_bboxes: MTCNN detections in xyxy order, or None.
        probs: per-box detection probabilities, or None.

    Returns:
        The largest qualifying box (xyxy), or None if nothing qualifies.
    """
    if det_bboxes is None or probs is None:
        return None
    confident = [box for box, prob in zip(det_bboxes, probs) if prob > 0.8]
    if not confident:
        return None
    # Largest by area; ties resolve to the first detection, as before.
    return max(confident, key=lambda box: (box[3] - box[1]) * (box[2] - box[0]))


def process_video(uploaded_img, uploaded_audio, width, height, length, seed, facemask_dilation_ratio, facecrop_dilation_ratio, context_frames, context_overlap, cfg, steps, sample_rate, fps, device):
    """Run the full EchoMimic pipeline for one image + audio pair.

    Returns the path of the rendered mp4 (with audio track muxed in).
    NOTE(review): the `device` parameter is accepted for UI symmetry but is
    unused — all models were placed on the module-level `device` at init.
    """
    if seed is not None and seed > -1:
        generator = torch.manual_seed(seed)
    else:
        generator = torch.manual_seed(random.randint(100, 1000000))

    #### face mask prepare
    face_img = cv2.imread(uploaded_img)
    face_mask = np.zeros((face_img.shape[0], face_img.shape[1])).astype('uint8')
    det_bboxes, probs = face_detector.detect(face_img)
    select_bbox = select_face(det_bboxes, probs)
    if select_bbox is None:
        # No confident face: condition on the whole frame.
        face_mask[:, :] = 255
    else:
        xyxy = np.round(select_bbox[:4]).astype('int')
        rb, re, cb, ce = xyxy[1], xyxy[3], xyxy[0], xyxy[2]
        r_pad = int((re - rb) * facemask_dilation_ratio)
        c_pad = int((ce - cb) * facemask_dilation_ratio)
        # BUGFIX: clamp dilated start indices at 0 — a negative start would
        # be interpreted as wrap-around indexing and corrupt the mask.
        face_mask[max(rb - r_pad, 0):re + r_pad, max(cb - c_pad, 0):ce + c_pad] = 255

        #### face crop (dilated around the detected box, clipped to frame)
        r_pad_crop = int((re - rb) * facecrop_dilation_ratio)
        c_pad_crop = int((ce - cb) * facecrop_dilation_ratio)
        crop_rect = [max(0, cb - c_pad_crop), max(0, rb - r_pad_crop), min(ce + c_pad_crop, face_img.shape[1]), min(re + r_pad_crop, face_img.shape[0])]
        face_img = crop_and_pad(face_img, crop_rect)
        face_mask = crop_and_pad(face_mask, crop_rect)
    face_img = cv2.resize(face_img, (width, height))
    face_mask = cv2.resize(face_mask, (width, height))

    # BGR (OpenCV) -> RGB for PIL; mask -> (1,1,1,H,W) float in [0,1].
    ref_image_pil = Image.fromarray(face_img[:, :, [2, 1, 0]])
    face_mask_tensor = torch.Tensor(face_mask).to(dtype=weight_dtype, device=device).unsqueeze(0).unsqueeze(0).unsqueeze(0) / 255.0

    video = pipe(
        ref_image_pil,
        uploaded_audio,
        face_mask_tensor,
        width,
        height,
        length,
        steps,
        cfg,
        generator=generator,
        audio_sample_rate=sample_rate,
        context_frames=context_frames,
        fps=fps,
        context_overlap=context_overlap
    ).videos

    save_dir = Path("output/tmp")
    save_dir.mkdir(exist_ok=True, parents=True)
    output_video_path = save_dir / "output_video.mp4"
    save_videos_grid(video, str(output_video_path), n_rows=1, fps=fps)

    # Mux the driving audio back onto the silent rendered video.
    video_clip = VideoFileClip(str(output_video_path))
    audio_clip = AudioFileClip(uploaded_audio)
    final_output_path = save_dir / "output_video_with_audio.mp4"
    video_clip = video_clip.set_audio(audio_clip)
    video_clip.write_videofile(str(final_output_path), codec="libx264", audio_codec="aac")

    return final_output_path


with gr.Blocks() as demo:
    gr.Markdown('# EchoMimic')
    gr.Markdown('![]()')
    with gr.Row():
        with gr.Column():
            uploaded_img = gr.Image(type="filepath", label="Reference Image")
            uploaded_audio = gr.Audio(type="filepath", label="Input Audio")
        with gr.Column():
            output_video = gr.Video()

    with gr.Accordion("Configuration", open=False):
        width = gr.Slider(label="Width", minimum=128, maximum=1024, value=default_values["width"])
        height = gr.Slider(label="Height", minimum=128, maximum=1024, value=default_values["height"])
        length = gr.Slider(label="Length", minimum=100, maximum=5000, value=default_values["length"])
        seed = gr.Slider(label="Seed", minimum=0, maximum=10000, value=default_values["seed"])
        facemask_dilation_ratio = gr.Slider(label="Facemask Dilation Ratio", minimum=0.0, maximum=1.0, step=0.01, value=default_values["facemask_dilation_ratio"])
        facecrop_dilation_ratio = gr.Slider(label="Facecrop Dilation Ratio", minimum=0.0, maximum=1.0, step=0.01, value=default_values["facecrop_dilation_ratio"])
        context_frames = gr.Slider(label="Context Frames", minimum=0, maximum=50, step=1, value=default_values["context_frames"])
        context_overlap = gr.Slider(label="Context Overlap", minimum=0, maximum=10, step=1, value=default_values["context_overlap"])
        cfg = gr.Slider(label="CFG", minimum=0.0, maximum=10.0, step=0.1, value=default_values["cfg"])
        steps = gr.Slider(label="Steps", minimum=1, maximum=100, step=1, value=default_values["steps"])
        sample_rate = gr.Slider(label="Sample Rate", minimum=8000, maximum=48000, step=1000, value=default_values["sample_rate"])
        fps = gr.Slider(label="FPS", minimum=1, maximum=60, step=1, value=default_values["fps"])
        device = gr.Radio(label="Device", choices=["cuda", "cpu"], value=default_values["device"])

    generate_button = gr.Button("Generate Video")

    def generate_video(uploaded_img, uploaded_audio, width, height, length, seed, facemask_dilation_ratio, facecrop_dilation_ratio, context_frames, context_overlap, cfg, steps, sample_rate, fps, device):
        """Gradio click handler: delegate to process_video and return the mp4 path."""
        # (The original also assigned the path to a local `output_video`,
        # shadowing the component to no effect — removed.)
        return process_video(
            uploaded_img, uploaded_audio, width, height, length, seed, facemask_dilation_ratio, facecrop_dilation_ratio, context_frames, context_overlap, cfg, steps, sample_rate, fps, device
        )

    generate_button.click(
        generate_video,
        inputs=[
            uploaded_img,
            uploaded_audio,
            width,
            height,
            length,
            seed,
            facemask_dilation_ratio,
            facecrop_dilation_ratio,
            context_frames,
            context_overlap,
            cfg,
            steps,
            sample_rate,
            fps,
            device
        ],
        outputs=output_video
    )

parser = argparse.ArgumentParser(description='EchoMimic')
parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')
parser.add_argument('--server_port', type=int, default=7680, help='Server port')
# BUGFIX: in a Jupyter kernel sys.argv carries the kernel's own arguments
# (e.g. `-f .../kernel.json`), so parse_args() raises SystemExit. Parse
# only the known flags and ignore the rest.
args, _ = parser.parse_known_args()

if __name__ == '__main__':
    demo.launch(server_name=args.server_name, server_port=args.server_port, inbrowser=True, share=True)