Suprath commited on
Commit
9f4b9c7
·
verified ·
1 Parent(s): 3e11bd3

Upload 54 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. CONTRIBUTING.md +23 -0
  2. Dockerfile +9 -0
  3. README.md +78 -7
  4. app.py +105 -0
  5. app.sh +4 -0
  6. checkpoints/.gitkeep +0 -0
  7. config/__init__.py +5 -0
  8. config/gradio.yaml +14 -0
  9. config/nota_wav2lip.yaml +44 -0
  10. data/.gitkeep +0 -0
  11. docker-compose.yml +11 -0
  12. docs/assets/fig5.png +0 -0
  13. docs/description.md +22 -0
  14. docs/footer.md +5 -0
  15. docs/header.md +10 -0
  16. docs/main.css +4 -0
  17. download.py +44 -0
  18. download.sh +7 -0
  19. face_detection/README.md +1 -0
  20. face_detection/__init__.py +7 -0
  21. face_detection/api.py +79 -0
  22. face_detection/detection/__init__.py +1 -0
  23. face_detection/detection/core.py +130 -0
  24. face_detection/detection/sfd/__init__.py +1 -0
  25. face_detection/detection/sfd/bbox.py +129 -0
  26. face_detection/detection/sfd/detect.py +112 -0
  27. face_detection/detection/sfd/net_s3fd.py +129 -0
  28. face_detection/detection/sfd/sfd_detector.py +59 -0
  29. face_detection/models.py +261 -0
  30. face_detection/utils.py +313 -0
  31. inference.py +82 -0
  32. inference.sh +15 -0
  33. nota_wav2lip/__init__.py +2 -0
  34. nota_wav2lip/audio.py +135 -0
  35. nota_wav2lip/demo.py +91 -0
  36. nota_wav2lip/gradio.py +91 -0
  37. nota_wav2lip/inference.py +111 -0
  38. nota_wav2lip/models/__init__.py +3 -0
  39. nota_wav2lip/models/base.py +55 -0
  40. nota_wav2lip/models/conv.py +34 -0
  41. nota_wav2lip/models/util.py +32 -0
  42. nota_wav2lip/models/wav2lip.py +85 -0
  43. nota_wav2lip/models/wav2lip_compressed.py +72 -0
  44. nota_wav2lip/preprocess/__init__.py +2 -0
  45. nota_wav2lip/preprocess/core.py +98 -0
  46. nota_wav2lip/preprocess/ffmpeg.py +5 -0
  47. nota_wav2lip/preprocess/lrs3_download.py +259 -0
  48. nota_wav2lip/util.py +5 -0
  49. nota_wav2lip/video.py +68 -0
  50. preprocess.py +28 -0
CONTRIBUTING.md ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to this repository
2
+
3
+ ## Install linter
4
+
5
+ First of all, you need to install `ruff` package to verify that you passed all conditions for formatting.
6
+
7
+ ```
8
+ pip install ruff==0.0.287
9
+ ```
10
+
11
+ ### Apply linter before PR
12
+
13
+ Please run the ruff check with the following command:
14
+
15
+ ```
16
+ ruff check .
17
+ ```
18
+
19
+ ### Auto-fix with fixable errors
20
+
21
+ ```
22
+ ruff check . --fix
23
+ ```
Dockerfile ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvcr.io/nvidia/pytorch:22.03-py3
2
+
3
+ ARG DEBIAN_FRONTEND=noninteractive
4
+ RUN apt-get update
5
+ RUN apt-get install ffmpeg libsm6 libxext6 tmux git -y
6
+
7
+ WORKDIR /workspace
8
+ COPY requirements.txt .
9
+ RUN pip install --no-cache -r requirements.txt
README.md CHANGED
@@ -1,13 +1,84 @@
1
  ---
2
- title: LipSync
3
- emoji: 🌍
4
  colorFrom: indigo
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 4.27.0
8
  app_file: app.py
9
- pinned: false
10
- license: unknown
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Compressed Wav2Lip
3
+ emoji: 🌟
4
  colorFrom: indigo
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 4.13.0
8
  app_file: app.py
9
+ pinned: true
10
+ license: apache-2.0
11
  ---
12
 
13
+ # 28× Compressed Wav2Lip by Nota AI
14
+
15
+ Official codebase for [**Accelerating Speech-Driven Talking Face Generation with 28× Compressed Wav2Lip**](https://arxiv.org/abs/2304.00471).
16
+
17
+ - Presented at [ICCV'23 Demo](https://iccv2023.thecvf.com/demos-111.php) Track; [On-Device Intelligence Workshop](https://sites.google.com/g.harvard.edu/on-device-workshop-23/home) @ MLSys'23; [NVIDIA GTC 2023](https://www.nvidia.com/en-us/on-demand/search/?facet.mimetype[]=event%20session&layout=list&page=1&q=52409&sort=relevance&sortDir=desc) Poster.
18
+
19
+
20
+ ## Installation
21
+ #### Docker (recommended)
22
+ ```bash
23
+ git clone https://github.com/Nota-NetsPresso/nota-wav2lip.git
24
+ cd nota-wav2lip
25
+ docker compose run --service-ports --name nota-compressed-wav2lip compressed-wav2lip bash
26
+ ```
27
+
28
+ #### Conda
29
+ <details>
30
+ <summary>Click</summary>
31
+
32
+ ```bash
33
+ git clone https://github.com/Nota-NetsPresso/nota-wav2lip.git
34
+ cd nota-wav2lip
35
+ apt-get update
36
+ apt-get install ffmpeg libsm6 libxext6 tmux git -y
37
+ conda create -n nota-wav2lip python=3.9
38
+ conda activate nota-wav2lip
39
+ pip install -r requirements.txt
40
+ ```
41
+ </details>
42
+
43
+ ## Gradio Demo
44
+ Use the below script to run the [nota-ai/compressed-wav2lip demo](https://huggingface.co/spaces/nota-ai/compressed-wav2lip). The models and sample data will be downloaded automatically.
45
+
46
+ ```bash
47
+ bash app.sh
48
+ ```
49
+
50
+ ## Inference
51
+ (1) Download YouTube videos in the LRS3-TED label text file and preprocess them properly.
52
+ - Download `lrs3_v0.4_txt.zip` from [this link](https://mmai.io/datasets/lip_reading/).
53
+ - Unzip the file and make a folder structure: `./data/lrs3_v0.4_txt/lrs3_v0.4/test`
54
+ - Run `bash download.sh`
55
+ - Run `bash preprocess.sh`
56
+
57
+ (2) Run the script to compare the original Wav2Lip with Nota's compressed version.
58
+
59
+ ```bash
60
+ bash inference.sh
61
+ ```
62
+
63
+ ## License
64
+ - All rights related to this repository and the compressed models are reserved by Nota Inc.
65
+ - The intended use is strictly limited to research and non-commercial projects.
66
+
67
+ ## Contact
68
+ - To obtain compression code and assistance, kindly contact Nota AI ([email protected]). These are provided as part of our business solutions.
69
+ - For Q&A about this repo, use this board: [Nota-NetsPresso/discussions](https://github.com/orgs/Nota-NetsPresso/discussions)
70
+
71
+ ## Acknowledgment
72
+ - [NVIDIA Applied Research Accelerator Program](https://www.nvidia.com/en-us/industries/higher-education-research/applied-research-program/) for supporting this research.
73
+ - [Wav2Lip](https://github.com/Rudrabha/Wav2Lip) and [LRS3-TED](https://www.robots.ox.ac.uk/~vgg/data/lip_reading/) for facilitating the development of the original Wav2Lip.
74
+
75
+ ## Citation
76
+ ```bibtex
77
+ @article{kim2023unified,
78
+ title={A Unified Compression Framework for Efficient Speech-Driven Talking-Face Generation},
79
+ author={Kim, Bo-Kyeong and Kang, Jaemin and Seo, Daeun and Park, Hancheol and Choi, Shinkook and Song, Hyoung-Kyu and Kim, Hyungshin and Lim, Sungsu},
80
+ journal={MLSys Workshop on On-Device Intelligence (ODIW)},
81
+ year={2023},
82
+ url={https://arxiv.org/abs/2304.00471}
83
+ }
84
+ ```
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from pathlib import Path
4
+
5
+ import gradio as gr
6
+
7
+ from config import hparams as hp
8
+ from config import hparams_gradio as hp_gradio
9
+ from nota_wav2lip import Wav2LipModelComparisonGradio
10
+
11
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+ device = hp_gradio.device
13
+ print(f'Using {device} for inference.')
14
+ video_label_dict = hp_gradio.sample.video
15
+ audio_label_dict = hp_gradio.sample.audio
16
+
17
+ LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL', None)
18
+ LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL', None)
19
+ LRS_INFERENCE_SAMPLE = os.getenv('LRS_INFERENCE_SAMPLE', None)
20
+
21
+ if not Path(hp.inference.model.wav2lip.checkpoint).exists() and LRS_ORIGINAL_URL is not None:
22
+ subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.wav2lip.checkpoint} {LRS_ORIGINAL_URL}", shell=True)
23
+ if not Path(hp.inference.model.nota_wav2lip.checkpoint).exists() and LRS_COMPRESSED_URL is not None:
24
+ subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.nota_wav2lip.checkpoint} {LRS_COMPRESSED_URL}", shell=True)
25
+
26
+ path_inference_sample = "sample.tar.gz"
27
+ if not Path(path_inference_sample).exists() and LRS_INFERENCE_SAMPLE is not None:
28
+ subprocess.call(f"wget --no-check-certificate -O {path_inference_sample} {LRS_INFERENCE_SAMPLE}", shell=True)
29
+ subprocess.call(f"tar -zxvf {path_inference_sample}", shell=True)
30
+
31
+
32
+ if __name__ == "__main__":
33
+
34
+ servicer = Wav2LipModelComparisonGradio(
35
+ device=device,
36
+ video_label_dict=video_label_dict,
37
+ audio_label_list=audio_label_dict,
38
+ default_video='v1',
39
+ default_audio='a1'
40
+ )
41
+
42
+ for video_name in sorted(video_label_dict):
43
+ video_stem = Path(video_label_dict[video_name])
44
+ servicer.update_video(video_stem, video_stem.with_suffix('.json'),
45
+ name=video_name)
46
+
47
+ for audio_name in sorted(audio_label_dict):
48
+ audio_path = Path(audio_label_dict[audio_name])
49
+ servicer.update_audio(audio_path, name=audio_name)
50
+
51
+ with gr.Blocks(theme='nota-ai/theme', css=Path('docs/main.css').read_text()) as demo:
52
+ gr.Markdown(Path('docs/header.md').read_text())
53
+ gr.Markdown(Path('docs/description.md').read_text())
54
+ with gr.Row():
55
+ with gr.Column(variant='panel'):
56
+
57
+ gr.Markdown('## Select input video and audio', sanitize_html=False)
58
+ # Define samples
59
+ sample_video = gr.Video(interactive=False, label="Input Video")
60
+ sample_audio = gr.Audio(interactive=False, label="Input Audio")
61
+
62
+ # Define radio inputs
63
+ video_selection = gr.components.Radio(video_label_dict,
64
+ type='value', label="Select an input video:")
65
+ audio_selection = gr.components.Radio(audio_label_dict,
66
+ type='value', label="Select an input audio:")
67
+ # Define button inputs
68
+ with gr.Row(equal_height=True):
69
+ generate_original_button = gr.Button(value="Generate with Original Model", variant="primary")
70
+ generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary")
71
+ with gr.Column(variant='panel'):
72
+ # Define original model output components
73
+ gr.Markdown('## Original Wav2Lip')
74
+ original_model_output = gr.Video(label="Original Model", interactive=False)
75
+ with gr.Column():
76
+ with gr.Row(equal_height=True):
77
+ original_model_inference_time = gr.Textbox(value="", label="Total inference time (sec)")
78
+ original_model_fps = gr.Textbox(value="", label="FPS")
79
+ original_model_params = gr.Textbox(value=servicer.params['wav2lip'], label="# Parameters")
80
+ with gr.Column(variant='panel'):
81
+ # Define compressed model output components
82
+ gr.Markdown('## Compressed Wav2Lip (Ours)')
83
+ compressed_model_output = gr.Video(label="Compressed Model", interactive=False)
84
+ with gr.Column():
85
+ with gr.Row(equal_height=True):
86
+ compressed_model_inference_time = gr.Textbox(value="", label="Total inference time (sec)")
87
+ compressed_model_fps = gr.Textbox(value="", label="FPS")
88
+ compressed_model_params = gr.Textbox(value=servicer.params['nota_wav2lip'], label="# Parameters")
89
+
90
+ # Switch video and audio samples when selecting the raido button
91
+ video_selection.change(fn=servicer.switch_video_samples, inputs=video_selection, outputs=sample_video)
92
+ audio_selection.change(fn=servicer.switch_audio_samples, inputs=audio_selection, outputs=sample_audio)
93
+
94
+ # Click the generate button for original model
95
+ generate_original_button.click(servicer.generate_original_model,
96
+ inputs=[video_selection, audio_selection],
97
+ outputs=[original_model_output, original_model_inference_time, original_model_fps])
98
+ # Click the generate button for compressed model
99
+ generate_compressed_button.click(servicer.generate_compressed_model,
100
+ inputs=[video_selection, audio_selection],
101
+ outputs=[compressed_model_output, compressed_model_inference_time, compressed_model_fps])
102
+
103
+ gr.Markdown(Path('docs/footer.md').read_text())
104
+
105
+ demo.queue().launch()
app.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ export LRS_ORIGINAL_URL=https://netspresso-huggingface-demo-checkpoint.s3.us-east-2.amazonaws.com/compressed-wav2lip/lrs3-wav2lip.pth && \
2
+ export LRS_COMPRESSED_URL=https://netspresso-huggingface-demo-checkpoint.s3.us-east-2.amazonaws.com/compressed-wav2lip/lrs3-nota-wav2lip.pth && \
3
+ export LRS_INFERENCE_SAMPLE=https://netspresso-huggingface-demo-checkpoint.s3.us-east-2.amazonaws.com/data/compressed-wav2lip-inference/sample.tar.gz && \
4
+ python app.py
checkpoints/.gitkeep ADDED
File without changes
config/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from omegaconf import DictConfig, OmegaConf
2
+
3
+ hparams: DictConfig = OmegaConf.load("config/nota_wav2lip.yaml")
4
+
5
+ hparams_gradio: DictConfig = OmegaConf.load("config/gradio.yaml")
config/gradio.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ device: cpu
2
+ sample:
3
+ video:
4
+ v1: "sample/2145_orig"
5
+ v2: "sample/2942_orig"
6
+ v3: "sample/4598_orig"
7
+ v4: "sample/4653_orig"
8
+ v5: "sample/13692_orig"
9
+ audio:
10
+ a1: "sample/1673_orig.wav"
11
+ a2: "sample/9948_orig.wav"
12
+ a3: "sample/11028_orig.wav"
13
+ a4: "sample/12640_orig.wav"
14
+ a5: "sample/5592_orig.wav"
config/nota_wav2lip.yaml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ inference:
3
+ batch_size: 1
4
+ frame:
5
+ h: 224
6
+ w: 224
7
+ model:
8
+ wav2lip:
9
+ checkpoint: "checkpoints/lrs3-wav2lip.pth"
10
+ nota_wav2lip:
11
+ checkpoint: "checkpoints/lrs3-nota-wav2lip.pth"
12
+
13
+ audio:
14
+ num_mels: 80
15
+ rescale: True
16
+ rescaling_max: 0.9
17
+
18
+ use_lws: False
19
+
20
+ n_fft: 800 # Extra window size is filled with 0 paddings to match this parameter
21
+ hop_size: 200 # For 16000Hz, 200 : 12.5 ms (0.0125 * sample_rate)
22
+ win_size: 800 # For 16000Hz, 800 : 50 ms (If None, win_size : n_fft) (0.05 * sample_rate)
23
+ sample_rate: 16000 # 16000Hz (corresponding to librispeech) (sox --i <filename>)
24
+
25
+ frame_shift_ms: ~
26
+
27
+ signal_normalization: True
28
+ allow_clipping_in_normalization: True
29
+ symmetric_mels: True
30
+ max_abs_value: 4.
31
+ preemphasize: True
32
+ preemphasis: 0.97
33
+
34
+ # Limits
35
+ min_level_db: -100
36
+ ref_level_db: 20
37
+ fmin: 55
38
+ fmax: 7600
39
+
40
+ face:
41
+ video_fps: 25
42
+ img_size: 96
43
+ mel_step_size: 16
44
+
data/.gitkeep ADDED
File without changes
docker-compose.yml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: "3.9"
2
+ services:
3
+ compressed-wav2lip:
4
+ image: nota-compressed-wav2lip:dev
5
+ build: ./
6
+ container_name: nota-compressed-wav2lip
7
+ ipc: host
8
+ ports:
9
+ - "7860:7860"
10
+ volumes:
11
+ - ./:/workspace
docs/assets/fig5.png ADDED
docs/description.md ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ This demo showcases a lightweight model for speech-driven talking-face synthesis, a **28× Compressed Wav2Lip**. The key features of our approach are:
2
+ - compact generator built by removing the residual blocks and reducing the channel width from Wav2Lip.
3
+ - knowledge distillation to effectively train the small-capacity generator without adversarial learning.
4
+ - selective quantization to accelerate inference on edge GPUs without noticeable performance degradation.
5
+
6
+ <!-- To demonstrate the efficacy of our approach, we provide a latency comparison of different precisions on NVIDIA Jetson edge GPUs in Figure 5. Our approach achieves a remarkable 8× to 17× speedup with FP16 precision, and a 19× speedup on Xavier NX with mixed precision. -->
7
+ The below figure shows a latency comparison at different precisions on NVIDIA Jetson edge GPUs, highlighting a 8× to 17× speedup at FP16 and a 19× speedup on Xavier NX at mixed precision.
8
+
9
+ <center>
10
+ <img alt="compressed-wav2lip-performance" src="https://huggingface.co/spaces/nota-ai/compressed-wav2lip/resolve/2b86e2aa4921d3422f0769ed02dce9898d1e0470/docs/assets/fig5.png" width="70%" />
11
+ </center>
12
+
13
+ <br/>
14
+
15
+ The generation speed may vary depending on network traffic. Nevertheless, our compresed Wav2Lip _consistently_ delivers a faster inference than the original model, while maintaining similar visual quality. Different from the paper, in this demo, we measure **total processing time** and **FPS** throughout loading the preprocessed video and audio, generating with the model, and merging lip-synced facial images with the original video.
16
+
17
+ <br/>
18
+
19
+
20
+ ### Notice
21
+ - This work was accepted to [Demo] [**ICCV 2023 Demo Track**](https://iccv2023.thecvf.com/demos-111.php); [[Paper](https://arxiv.org/abs/2304.00471)] [**On-Device Intelligence Workshop (ODIW) @ MLSys 2023**](https://sites.google.com/g.harvard.edu/on-device-workshop-23/home); [Poster] [**NVIDIA GPU Technology Conference (GTC) as Poster Spotlight**](https://www.nvidia.com/en-us/on-demand/search/?facet.mimetype[]=event%20session&layout=list&page=1&q=52409&sort=relevance&sortDir=desc).
22
+ - We thank [NVIDIA Applied Research Accelerator Program](https://www.nvidia.com/en-us/industries/higher-education-research/applied-research-program/) for supporting this research and [Wav2Lip's Authors](https://github.com/Rudrabha/Wav2Lip) for their pioneering research.
docs/footer.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ <p align="center">
2
+ <a href="https://netspresso.ai/"><img src="https://huggingface.co/spaces/nota-ai/theme/resolve/main/docs/logo/nota_favicon_800x800.png" width="96px" height="96px"></a>
3
+ </p>
4
+
5
+ <br/>
docs/header.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # <center>Lightweight Speech-Driven Talking-Face Synthesis Demo</center>
2
+
3
+ <br/>
4
+
5
+ <p align="center">
6
+ <a href="https://arxiv.org/abs/2304.00471"><img src="https://img.shields.io/badge/arXiv-2304.00471-b31b1b.svg?style=flat-square" style="display:inline;"></a>
7
+ <a href="https://huggingface.co/spaces/nota-ai/efficient_wav2lip"><img src="https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fnota-ai%2Fefficient_wav2lip&count_bg=%23325AC8&title_bg=%23112344&icon=&icon_color=%23E7E7E7&title=HITS&edge_flat=true" style="display:inline;"></a>
8
+ </p>
9
+
10
+ <br/>
docs/main.css ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ h1, h2, h3 {
2
+ text-align: center;
3
+ display:block;
4
+ }
download.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from nota_wav2lip.preprocess import get_cropped_face_from_lrs3_label
4
+
5
+
6
+ def parse_args():
7
+
8
+ parser = argparse.ArgumentParser(description="NotaWav2Lip: Get LRS3 video sample with the label text file")
9
+
10
+ parser.add_argument(
11
+ '-i',
12
+ '--input-file',
13
+ type=str,
14
+ required=True,
15
+ help="Path of the label text file downloaded from https://mmai.io/datasets/lip_reading"
16
+ )
17
+
18
+ parser.add_argument(
19
+ '-o',
20
+ '--output-dir',
21
+ type=str,
22
+ default="sample_video_lrs3",
23
+ help="Output directory to save the result. Defaults: sample_video_lrs3"
24
+ )
25
+
26
+ parser.add_argument(
27
+ '--ignore-cache',
28
+ action='store_true',
29
+ help="Whether to force downloading and resampling video and overwrite pre-existing files"
30
+ )
31
+
32
+ args = parser.parse_args()
33
+
34
+ return args
35
+
36
+
37
+ if __name__ == '__main__':
38
+ args = parse_args()
39
+
40
+ get_cropped_face_from_lrs3_label(
41
+ args.input_file,
42
+ video_root_dir=args.output_dir,
43
+ ignore_cache = args.ignore_cache
44
+ )
download.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # example for audio source
2
+ python download.py\
3
+ -i data/lrs3_v0.4_txt/lrs3_v0.4/test/sxnlvwprfSc/00007.txt
4
+
5
+ # example for video source
6
+ python download.py\
7
+ -i data/lrs3_v0.4_txt/lrs3_v0.4/test/Li4S1yyrsTI/00010.txt
face_detection/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
face_detection/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ __author__ = """Adrian Bulat"""
4
+ __email__ = '[email protected]'
5
+ __version__ = '1.0.1'
6
+
7
+ from .api import FaceAlignment, LandmarksType, NetworkSize
face_detection/api.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import os
3
+ import torch
4
+ from torch.utils.model_zoo import load_url
5
+ from enum import Enum
6
+ import numpy as np
7
+ import cv2
8
+ try:
9
+ import urllib.request as request_file
10
+ except BaseException:
11
+ import urllib as request_file
12
+
13
+ from .models import FAN, ResNetDepth
14
+ from .utils import *
15
+
16
+
17
+ class LandmarksType(Enum):
18
+ """Enum class defining the type of landmarks to detect.
19
+
20
+ ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
21
+ ``_2halfD`` - this points represent the projection of the 3D points into 3D
22
+ ``_3D`` - detect the points ``(x,y,z)``` in a 3D space
23
+
24
+ """
25
+ _2D = 1
26
+ _2halfD = 2
27
+ _3D = 3
28
+
29
+
30
+ class NetworkSize(Enum):
31
+ # TINY = 1
32
+ # SMALL = 2
33
+ # MEDIUM = 3
34
+ LARGE = 4
35
+
36
+ def __new__(cls, value):
37
+ member = object.__new__(cls)
38
+ member._value_ = value
39
+ return member
40
+
41
+ def __int__(self):
42
+ return self.value
43
+
44
+ ROOT = os.path.dirname(os.path.abspath(__file__))
45
+
46
+ class FaceAlignment:
47
+ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
48
+ device='cuda', flip_input=False, face_detector='sfd', verbose=False):
49
+ self.device = device
50
+ self.flip_input = flip_input
51
+ self.landmarks_type = landmarks_type
52
+ self.verbose = verbose
53
+
54
+ network_size = int(network_size)
55
+
56
+ if 'cuda' in device:
57
+ torch.backends.cudnn.benchmark = True
58
+
59
+ # Get the face detector
60
+ face_detector_module = __import__('face_detection.detection.' + face_detector,
61
+ globals(), locals(), [face_detector], 0)
62
+ self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
63
+
64
+ def get_detections_for_batch(self, images):
65
+ images = images[..., ::-1]
66
+ detected_faces = self.face_detector.detect_from_batch(images.copy())
67
+ results = []
68
+
69
+ for i, d in enumerate(detected_faces):
70
+ if len(d) == 0:
71
+ results.append(None)
72
+ continue
73
+ d = d[0]
74
+ d = np.clip(d, 0, None)
75
+
76
+ x1, y1, x2, y2 = map(int, d[:-1])
77
+ results.append((x1, y1, x2, y2))
78
+
79
+ return results
face_detection/detection/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .core import FaceDetector
face_detection/detection/core.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import glob
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+ import torch
6
+ import cv2
7
+
8
+
9
+ class FaceDetector(object):
10
+ """An abstract class representing a face detector.
11
+
12
+ Any other face detection implementation must subclass it. All subclasses
13
+ must implement ``detect_from_image``, that return a list of detected
14
+ bounding boxes. Optionally, for speed considerations detect from path is
15
+ recommended.
16
+ """
17
+
18
+ def __init__(self, device, verbose):
19
+ self.device = device
20
+ self.verbose = verbose
21
+
22
+ if verbose:
23
+ if 'cpu' in device:
24
+ logger = logging.getLogger(__name__)
25
+ logger.warning("Detection running on CPU, this may be potentially slow.")
26
+
27
+ if 'cpu' not in device and 'cuda' not in device:
28
+ if verbose:
29
+ logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
30
+ raise ValueError
31
+
32
+ def detect_from_image(self, tensor_or_path):
33
+ """Detects faces in a given image.
34
+
35
+ This function detects the faces present in a provided BGR(usually)
36
+ image. The input can be either the image itself or the path to it.
37
+
38
+ Arguments:
39
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
40
+ to an image or the image itself.
41
+
42
+ Example::
43
+
44
+ >>> path_to_image = 'data/image_01.jpg'
45
+ ... detected_faces = detect_from_image(path_to_image)
46
+ [A list of bounding boxes (x1, y1, x2, y2)]
47
+ >>> image = cv2.imread(path_to_image)
48
+ ... detected_faces = detect_from_image(image)
49
+ [A list of bounding boxes (x1, y1, x2, y2)]
50
+
51
+ """
52
+ raise NotImplementedError
53
+
54
+ def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
55
+ """Detects faces from all the images present in a given directory.
56
+
57
+ Arguments:
58
+ path {string} -- a string containing a path that points to the folder containing the images
59
+
60
+ Keyword Arguments:
61
+ extensions {list} -- list of string containing the extensions to be
62
+ consider in the following format: ``.extension_name`` (default:
63
+ {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
64
+ folder recursively (default: {False}) show_progress_bar {bool} --
65
+ display a progressbar (default: {True})
66
+
67
+ Example:
68
+ >>> directory = 'data'
69
+ ... detected_faces = detect_from_directory(directory)
70
+ {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
71
+
72
+ """
73
+ if self.verbose:
74
+ logger = logging.getLogger(__name__)
75
+
76
+ if len(extensions) == 0:
77
+ if self.verbose:
78
+ logger.error("Expected at list one extension, but none was received.")
79
+ raise ValueError
80
+
81
+ if self.verbose:
82
+ logger.info("Constructing the list of images.")
83
+ additional_pattern = '/**/*' if recursive else '/*'
84
+ files = []
85
+ for extension in extensions:
86
+ files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
87
+
88
+ if self.verbose:
89
+ logger.info("Finished searching for images. %s images found", len(files))
90
+ logger.info("Preparing to run the detection.")
91
+
92
+ predictions = {}
93
+ for image_path in tqdm(files, disable=not show_progress_bar):
94
+ if self.verbose:
95
+ logger.info("Running the face detector on image: %s", image_path)
96
+ predictions[image_path] = self.detect_from_image(image_path)
97
+
98
+ if self.verbose:
99
+ logger.info("The detector was successfully run on all %s images", len(files))
100
+
101
+ return predictions
102
+
103
+ @property
104
+ def reference_scale(self):
105
+ raise NotImplementedError
106
+
107
+ @property
108
+ def reference_x_shift(self):
109
+ raise NotImplementedError
110
+
111
+ @property
112
+ def reference_y_shift(self):
113
+ raise NotImplementedError
114
+
115
+ @staticmethod
116
+ def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
117
+ """Convert path (represented as a string) or torch.tensor to a numpy.ndarray
118
+
119
+ Arguments:
120
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
121
+ """
122
+ if isinstance(tensor_or_path, str):
123
+ return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
124
+ elif torch.is_tensor(tensor_or_path):
125
+ # Call cpu in case its coming from cuda
126
+ return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
127
+ elif isinstance(tensor_or_path, np.ndarray):
128
+ return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
129
+ else:
130
+ raise TypeError
face_detection/detection/sfd/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .sfd_detector import SFDDetector as FaceDetector
face_detection/detection/sfd/bbox.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import os
3
+ import sys
4
+ import cv2
5
+ import random
6
+ import datetime
7
+ import time
8
+ import math
9
+ import argparse
10
+ import numpy as np
11
+ import torch
12
+
13
+ try:
14
+ from iou import IOU
15
+ except BaseException:
16
+ # IOU cython speedup 10x
17
+ def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
18
+ sa = abs((ax2 - ax1) * (ay2 - ay1))
19
+ sb = abs((bx2 - bx1) * (by2 - by1))
20
+ x1, y1 = max(ax1, bx1), max(ay1, by1)
21
+ x2, y2 = min(ax2, bx2), min(ay2, by2)
22
+ w = x2 - x1
23
+ h = y2 - y1
24
+ if w < 0 or h < 0:
25
+ return 0.0
26
+ else:
27
+ return 1.0 * w * h / (sa + sb - w * h)
28
+
29
+
30
+ def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
31
+ xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
32
+ dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
33
+ dw, dh = math.log(ww / aww), math.log(hh / ahh)
34
+ return dx, dy, dw, dh
35
+
36
+
37
+ def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
38
+ xc, yc = dx * aww + axc, dy * ahh + ayc
39
+ ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
40
+ x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
41
+ return x1, y1, x2, y2
42
+
43
+
44
+ def nms(dets, thresh):
45
+ if 0 == len(dets):
46
+ return []
47
+ x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
48
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
49
+ order = scores.argsort()[::-1]
50
+
51
+ keep = []
52
+ while order.size > 0:
53
+ i = order[0]
54
+ keep.append(i)
55
+ xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
56
+ xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
57
+
58
+ w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
59
+ ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
60
+
61
+ inds = np.where(ovr <= thresh)[0]
62
+ order = order[inds + 1]
63
+
64
+ return keep
65
+
66
+
67
+ def encode(matched, priors, variances):
68
+ """Encode the variances from the priorbox layers into the ground truth boxes
69
+ we have matched (based on jaccard overlap) with the prior boxes.
70
+ Args:
71
+ matched: (tensor) Coords of ground truth for each prior in point-form
72
+ Shape: [num_priors, 4].
73
+ priors: (tensor) Prior boxes in center-offset form
74
+ Shape: [num_priors,4].
75
+ variances: (list[float]) Variances of priorboxes
76
+ Return:
77
+ encoded boxes (tensor), Shape: [num_priors, 4]
78
+ """
79
+
80
+ # dist b/t match center and prior's center
81
+ g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
82
+ # encode variance
83
+ g_cxcy /= (variances[0] * priors[:, 2:])
84
+ # match wh / prior wh
85
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
86
+ g_wh = torch.log(g_wh) / variances[1]
87
+ # return target for smooth_l1_loss
88
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
89
+
90
+
91
+ def decode(loc, priors, variances):
92
+ """Decode locations from predictions using priors to undo
93
+ the encoding we did for offset regression at train time.
94
+ Args:
95
+ loc (tensor): location predictions for loc layers,
96
+ Shape: [num_priors,4]
97
+ priors (tensor): Prior boxes in center-offset form.
98
+ Shape: [num_priors,4].
99
+ variances: (list[float]) Variances of priorboxes
100
+ Return:
101
+ decoded bounding box predictions
102
+ """
103
+
104
+ boxes = torch.cat((
105
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
106
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
107
+ boxes[:, :2] -= boxes[:, 2:] / 2
108
+ boxes[:, 2:] += boxes[:, :2]
109
+ return boxes
110
+
111
+ def batch_decode(loc, priors, variances):
112
+ """Decode locations from predictions using priors to undo
113
+ the encoding we did for offset regression at train time.
114
+ Args:
115
+ loc (tensor): location predictions for loc layers,
116
+ Shape: [num_priors,4]
117
+ priors (tensor): Prior boxes in center-offset form.
118
+ Shape: [num_priors,4].
119
+ variances: (list[float]) Variances of priorboxes
120
+ Return:
121
+ decoded bounding box predictions
122
+ """
123
+
124
+ boxes = torch.cat((
125
+ priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
126
+ priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
127
+ boxes[:, :, :2] -= boxes[:, :, 2:] / 2
128
+ boxes[:, :, 2:] += boxes[:, :, :2]
129
+ return boxes
face_detection/detection/sfd/detect.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ import os
5
+ import sys
6
+ import cv2
7
+ import random
8
+ import datetime
9
+ import math
10
+ import argparse
11
+ import numpy as np
12
+
13
+ import scipy.io as sio
14
+ import zipfile
15
+ from .net_s3fd import s3fd
16
+ from .bbox import *
17
+
18
+
19
+ def detect(net, img, device):
20
+ img = img - np.array([104, 117, 123])
21
+ img = img.transpose(2, 0, 1)
22
+ img = img.reshape((1,) + img.shape)
23
+
24
+ if 'cuda' in device:
25
+ torch.backends.cudnn.benchmark = True
26
+
27
+ img = torch.from_numpy(img).float().to(device)
28
+ BB, CC, HH, WW = img.size()
29
+ with torch.no_grad():
30
+ olist = net(img)
31
+
32
+ bboxlist = []
33
+ for i in range(len(olist) // 2):
34
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
35
+ olist = [oelem.data.cpu() for oelem in olist]
36
+ for i in range(len(olist) // 2):
37
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
38
+ FB, FC, FH, FW = ocls.size() # feature map size
39
+ stride = 2**(i + 2) # 4,8,16,32,64,128
40
+ anchor = stride * 4
41
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
42
+ for Iindex, hindex, windex in poss:
43
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
44
+ score = ocls[0, 1, hindex, windex]
45
+ loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
46
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
47
+ variances = [0.1, 0.2]
48
+ box = decode(loc, priors, variances)
49
+ x1, y1, x2, y2 = box[0] * 1.0
50
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
51
+ bboxlist.append([x1, y1, x2, y2, score])
52
+ bboxlist = np.array(bboxlist)
53
+ if 0 == len(bboxlist):
54
+ bboxlist = np.zeros((1, 5))
55
+
56
+ return bboxlist
57
+
58
+ def batch_detect(net, imgs, device):
59
+ imgs = imgs - np.array([104, 117, 123])
60
+ imgs = imgs.transpose(0, 3, 1, 2)
61
+
62
+ if 'cuda' in device:
63
+ torch.backends.cudnn.benchmark = True
64
+
65
+ imgs = torch.from_numpy(imgs).float().to(device)
66
+ BB, CC, HH, WW = imgs.size()
67
+ with torch.no_grad():
68
+ olist = net(imgs)
69
+
70
+ bboxlist = []
71
+ for i in range(len(olist) // 2):
72
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
73
+ olist = [oelem.data.cpu() for oelem in olist]
74
+ for i in range(len(olist) // 2):
75
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
76
+ FB, FC, FH, FW = ocls.size() # feature map size
77
+ stride = 2**(i + 2) # 4,8,16,32,64,128
78
+ anchor = stride * 4
79
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
80
+ for Iindex, hindex, windex in poss:
81
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
82
+ score = ocls[:, 1, hindex, windex]
83
+ loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
84
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
85
+ variances = [0.1, 0.2]
86
+ box = batch_decode(loc, priors, variances)
87
+ box = box[:, 0] * 1.0
88
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
89
+ bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
90
+ bboxlist = np.array(bboxlist)
91
+ if 0 == len(bboxlist):
92
+ bboxlist = np.zeros((1, BB, 5))
93
+
94
+ return bboxlist
95
+
96
+ def flip_detect(net, img, device):
97
+ img = cv2.flip(img, 1)
98
+ b = detect(net, img, device)
99
+
100
+ bboxlist = np.zeros(b.shape)
101
+ bboxlist[:, 0] = img.shape[1] - b[:, 2]
102
+ bboxlist[:, 1] = b[:, 1]
103
+ bboxlist[:, 2] = img.shape[1] - b[:, 0]
104
+ bboxlist[:, 3] = b[:, 3]
105
+ bboxlist[:, 4] = b[:, 4]
106
+ return bboxlist
107
+
108
+
109
+ def pts_to_bb(pts):
110
+ min_x, min_y = np.min(pts, axis=0)
111
+ max_x, max_y = np.max(pts, axis=0)
112
+ return np.array([min_x, min_y, max_x, max_y])
face_detection/detection/sfd/net_s3fd.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class L2Norm(nn.Module):
7
+ def __init__(self, n_channels, scale=1.0):
8
+ super(L2Norm, self).__init__()
9
+ self.n_channels = n_channels
10
+ self.scale = scale
11
+ self.eps = 1e-10
12
+ self.weight = nn.Parameter(torch.Tensor(self.n_channels))
13
+ self.weight.data *= 0.0
14
+ self.weight.data += self.scale
15
+
16
+ def forward(self, x):
17
+ norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
18
+ x = x / norm * self.weight.view(1, -1, 1, 1)
19
+ return x
20
+
21
+
22
+ class s3fd(nn.Module):
23
+ def __init__(self):
24
+ super(s3fd, self).__init__()
25
+ self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
26
+ self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
27
+
28
+ self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
29
+ self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
30
+
31
+ self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
32
+ self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
33
+ self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
34
+
35
+ self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
36
+ self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
37
+ self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
38
+
39
+ self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
40
+ self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
41
+ self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
42
+
43
+ self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
44
+ self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
45
+
46
+ self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
47
+ self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
48
+
49
+ self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
50
+ self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
51
+
52
+ self.conv3_3_norm = L2Norm(256, scale=10)
53
+ self.conv4_3_norm = L2Norm(512, scale=8)
54
+ self.conv5_3_norm = L2Norm(512, scale=5)
55
+
56
+ self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
57
+ self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
58
+ self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
59
+ self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
60
+ self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
61
+ self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
62
+
63
+ self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
64
+ self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
65
+ self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
66
+ self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
67
+ self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
68
+ self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
69
+
70
+ def forward(self, x):
71
+ h = F.relu(self.conv1_1(x))
72
+ h = F.relu(self.conv1_2(h))
73
+ h = F.max_pool2d(h, 2, 2)
74
+
75
+ h = F.relu(self.conv2_1(h))
76
+ h = F.relu(self.conv2_2(h))
77
+ h = F.max_pool2d(h, 2, 2)
78
+
79
+ h = F.relu(self.conv3_1(h))
80
+ h = F.relu(self.conv3_2(h))
81
+ h = F.relu(self.conv3_3(h))
82
+ f3_3 = h
83
+ h = F.max_pool2d(h, 2, 2)
84
+
85
+ h = F.relu(self.conv4_1(h))
86
+ h = F.relu(self.conv4_2(h))
87
+ h = F.relu(self.conv4_3(h))
88
+ f4_3 = h
89
+ h = F.max_pool2d(h, 2, 2)
90
+
91
+ h = F.relu(self.conv5_1(h))
92
+ h = F.relu(self.conv5_2(h))
93
+ h = F.relu(self.conv5_3(h))
94
+ f5_3 = h
95
+ h = F.max_pool2d(h, 2, 2)
96
+
97
+ h = F.relu(self.fc6(h))
98
+ h = F.relu(self.fc7(h))
99
+ ffc7 = h
100
+ h = F.relu(self.conv6_1(h))
101
+ h = F.relu(self.conv6_2(h))
102
+ f6_2 = h
103
+ h = F.relu(self.conv7_1(h))
104
+ h = F.relu(self.conv7_2(h))
105
+ f7_2 = h
106
+
107
+ f3_3 = self.conv3_3_norm(f3_3)
108
+ f4_3 = self.conv4_3_norm(f4_3)
109
+ f5_3 = self.conv5_3_norm(f5_3)
110
+
111
+ cls1 = self.conv3_3_norm_mbox_conf(f3_3)
112
+ reg1 = self.conv3_3_norm_mbox_loc(f3_3)
113
+ cls2 = self.conv4_3_norm_mbox_conf(f4_3)
114
+ reg2 = self.conv4_3_norm_mbox_loc(f4_3)
115
+ cls3 = self.conv5_3_norm_mbox_conf(f5_3)
116
+ reg3 = self.conv5_3_norm_mbox_loc(f5_3)
117
+ cls4 = self.fc7_mbox_conf(ffc7)
118
+ reg4 = self.fc7_mbox_loc(ffc7)
119
+ cls5 = self.conv6_2_mbox_conf(f6_2)
120
+ reg5 = self.conv6_2_mbox_loc(f6_2)
121
+ cls6 = self.conv7_2_mbox_conf(f7_2)
122
+ reg6 = self.conv7_2_mbox_loc(f7_2)
123
+
124
+ # max-out background label
125
+ chunk = torch.chunk(cls1, 4, 1)
126
+ bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
127
+ cls1 = torch.cat([bmax, chunk[3]], dim=1)
128
+
129
+ return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
face_detection/detection/sfd/sfd_detector.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ from torch.utils.model_zoo import load_url
4
+
5
+ from ..core import FaceDetector
6
+
7
+ from .net_s3fd import s3fd
8
+ from .bbox import *
9
+ from .detect import *
10
+
11
+ models_urls = {
12
+ 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
13
+ }
14
+
15
+
16
+ class SFDDetector(FaceDetector):
17
+ def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
18
+ super(SFDDetector, self).__init__(device, verbose)
19
+
20
+ # Initialise the face detector
21
+ if not os.path.isfile(path_to_detector):
22
+ model_weights = load_url(models_urls['s3fd'])
23
+ else:
24
+ model_weights = torch.load(path_to_detector)
25
+
26
+ self.face_detector = s3fd()
27
+ self.face_detector.load_state_dict(model_weights)
28
+ self.face_detector.to(device)
29
+ self.face_detector.eval()
30
+
31
+ def detect_from_image(self, tensor_or_path):
32
+ image = self.tensor_or_path_to_ndarray(tensor_or_path)
33
+
34
+ bboxlist = detect(self.face_detector, image, device=self.device)
35
+ keep = nms(bboxlist, 0.3)
36
+ bboxlist = bboxlist[keep, :]
37
+ bboxlist = [x for x in bboxlist if x[-1] > 0.5]
38
+
39
+ return bboxlist
40
+
41
+ def detect_from_batch(self, images):
42
+ bboxlists = batch_detect(self.face_detector, images, device=self.device)
43
+ keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
44
+ bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
45
+ bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
46
+
47
+ return bboxlists
48
+
49
+ @property
50
+ def reference_scale(self):
51
+ return 195
52
+
53
+ @property
54
+ def reference_x_shift(self):
55
+ return 0
56
+
57
+ @property
58
+ def reference_y_shift(self):
59
+ return 0
face_detection/models.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+
7
+ def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
8
+ "3x3 convolution with padding"
9
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
10
+ stride=strd, padding=padding, bias=bias)
11
+
12
+
13
+ class ConvBlock(nn.Module):
14
+ def __init__(self, in_planes, out_planes):
15
+ super(ConvBlock, self).__init__()
16
+ self.bn1 = nn.BatchNorm2d(in_planes)
17
+ self.conv1 = conv3x3(in_planes, int(out_planes / 2))
18
+ self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
19
+ self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
20
+ self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
21
+ self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
22
+
23
+ if in_planes != out_planes:
24
+ self.downsample = nn.Sequential(
25
+ nn.BatchNorm2d(in_planes),
26
+ nn.ReLU(True),
27
+ nn.Conv2d(in_planes, out_planes,
28
+ kernel_size=1, stride=1, bias=False),
29
+ )
30
+ else:
31
+ self.downsample = None
32
+
33
+ def forward(self, x):
34
+ residual = x
35
+
36
+ out1 = self.bn1(x)
37
+ out1 = F.relu(out1, True)
38
+ out1 = self.conv1(out1)
39
+
40
+ out2 = self.bn2(out1)
41
+ out2 = F.relu(out2, True)
42
+ out2 = self.conv2(out2)
43
+
44
+ out3 = self.bn3(out2)
45
+ out3 = F.relu(out3, True)
46
+ out3 = self.conv3(out3)
47
+
48
+ out3 = torch.cat((out1, out2, out3), 1)
49
+
50
+ if self.downsample is not None:
51
+ residual = self.downsample(residual)
52
+
53
+ out3 += residual
54
+
55
+ return out3
56
+
57
+
58
+ class Bottleneck(nn.Module):
59
+
60
+ expansion = 4
61
+
62
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
63
+ super(Bottleneck, self).__init__()
64
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
65
+ self.bn1 = nn.BatchNorm2d(planes)
66
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
67
+ padding=1, bias=False)
68
+ self.bn2 = nn.BatchNorm2d(planes)
69
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
70
+ self.bn3 = nn.BatchNorm2d(planes * 4)
71
+ self.relu = nn.ReLU(inplace=True)
72
+ self.downsample = downsample
73
+ self.stride = stride
74
+
75
+ def forward(self, x):
76
+ residual = x
77
+
78
+ out = self.conv1(x)
79
+ out = self.bn1(out)
80
+ out = self.relu(out)
81
+
82
+ out = self.conv2(out)
83
+ out = self.bn2(out)
84
+ out = self.relu(out)
85
+
86
+ out = self.conv3(out)
87
+ out = self.bn3(out)
88
+
89
+ if self.downsample is not None:
90
+ residual = self.downsample(x)
91
+
92
+ out += residual
93
+ out = self.relu(out)
94
+
95
+ return out
96
+
97
+
98
+ class HourGlass(nn.Module):
99
+ def __init__(self, num_modules, depth, num_features):
100
+ super(HourGlass, self).__init__()
101
+ self.num_modules = num_modules
102
+ self.depth = depth
103
+ self.features = num_features
104
+
105
+ self._generate_network(self.depth)
106
+
107
+ def _generate_network(self, level):
108
+ self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
109
+
110
+ self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
111
+
112
+ if level > 1:
113
+ self._generate_network(level - 1)
114
+ else:
115
+ self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
116
+
117
+ self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
118
+
119
+ def _forward(self, level, inp):
120
+ # Upper branch
121
+ up1 = inp
122
+ up1 = self._modules['b1_' + str(level)](up1)
123
+
124
+ # Lower branch
125
+ low1 = F.avg_pool2d(inp, 2, stride=2)
126
+ low1 = self._modules['b2_' + str(level)](low1)
127
+
128
+ if level > 1:
129
+ low2 = self._forward(level - 1, low1)
130
+ else:
131
+ low2 = low1
132
+ low2 = self._modules['b2_plus_' + str(level)](low2)
133
+
134
+ low3 = low2
135
+ low3 = self._modules['b3_' + str(level)](low3)
136
+
137
+ up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
138
+
139
+ return up1 + up2
140
+
141
+ def forward(self, x):
142
+ return self._forward(self.depth, x)
143
+
144
+
145
+ class FAN(nn.Module):
146
+
147
+ def __init__(self, num_modules=1):
148
+ super(FAN, self).__init__()
149
+ self.num_modules = num_modules
150
+
151
+ # Base part
152
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
153
+ self.bn1 = nn.BatchNorm2d(64)
154
+ self.conv2 = ConvBlock(64, 128)
155
+ self.conv3 = ConvBlock(128, 128)
156
+ self.conv4 = ConvBlock(128, 256)
157
+
158
+ # Stacking part
159
+ for hg_module in range(self.num_modules):
160
+ self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
161
+ self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
162
+ self.add_module('conv_last' + str(hg_module),
163
+ nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
164
+ self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
165
+ self.add_module('l' + str(hg_module), nn.Conv2d(256,
166
+ 68, kernel_size=1, stride=1, padding=0))
167
+
168
+ if hg_module < self.num_modules - 1:
169
+ self.add_module(
170
+ 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
171
+ self.add_module('al' + str(hg_module), nn.Conv2d(68,
172
+ 256, kernel_size=1, stride=1, padding=0))
173
+
174
+ def forward(self, x):
175
+ x = F.relu(self.bn1(self.conv1(x)), True)
176
+ x = F.avg_pool2d(self.conv2(x), 2, stride=2)
177
+ x = self.conv3(x)
178
+ x = self.conv4(x)
179
+
180
+ previous = x
181
+
182
+ outputs = []
183
+ for i in range(self.num_modules):
184
+ hg = self._modules['m' + str(i)](previous)
185
+
186
+ ll = hg
187
+ ll = self._modules['top_m_' + str(i)](ll)
188
+
189
+ ll = F.relu(self._modules['bn_end' + str(i)]
190
+ (self._modules['conv_last' + str(i)](ll)), True)
191
+
192
+ # Predict heatmaps
193
+ tmp_out = self._modules['l' + str(i)](ll)
194
+ outputs.append(tmp_out)
195
+
196
+ if i < self.num_modules - 1:
197
+ ll = self._modules['bl' + str(i)](ll)
198
+ tmp_out_ = self._modules['al' + str(i)](tmp_out)
199
+ previous = previous + ll + tmp_out_
200
+
201
+ return outputs
202
+
203
+
204
+ class ResNetDepth(nn.Module):
205
+
206
+ def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
207
+ self.inplanes = 64
208
+ super(ResNetDepth, self).__init__()
209
+ self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
210
+ bias=False)
211
+ self.bn1 = nn.BatchNorm2d(64)
212
+ self.relu = nn.ReLU(inplace=True)
213
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
214
+ self.layer1 = self._make_layer(block, 64, layers[0])
215
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
216
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
217
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
218
+ self.avgpool = nn.AvgPool2d(7)
219
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
220
+
221
+ for m in self.modules():
222
+ if isinstance(m, nn.Conv2d):
223
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
224
+ m.weight.data.normal_(0, math.sqrt(2. / n))
225
+ elif isinstance(m, nn.BatchNorm2d):
226
+ m.weight.data.fill_(1)
227
+ m.bias.data.zero_()
228
+
229
+ def _make_layer(self, block, planes, blocks, stride=1):
230
+ downsample = None
231
+ if stride != 1 or self.inplanes != planes * block.expansion:
232
+ downsample = nn.Sequential(
233
+ nn.Conv2d(self.inplanes, planes * block.expansion,
234
+ kernel_size=1, stride=stride, bias=False),
235
+ nn.BatchNorm2d(planes * block.expansion),
236
+ )
237
+
238
+ layers = []
239
+ layers.append(block(self.inplanes, planes, stride, downsample))
240
+ self.inplanes = planes * block.expansion
241
+ for i in range(1, blocks):
242
+ layers.append(block(self.inplanes, planes))
243
+
244
+ return nn.Sequential(*layers)
245
+
246
+ def forward(self, x):
247
+ x = self.conv1(x)
248
+ x = self.bn1(x)
249
+ x = self.relu(x)
250
+ x = self.maxpool(x)
251
+
252
+ x = self.layer1(x)
253
+ x = self.layer2(x)
254
+ x = self.layer3(x)
255
+ x = self.layer4(x)
256
+
257
+ x = self.avgpool(x)
258
+ x = x.view(x.size(0), -1)
259
+ x = self.fc(x)
260
+
261
+ return x
face_detection/utils.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import os
3
+ import sys
4
+ import time
5
+ import torch
6
+ import math
7
+ import numpy as np
8
+ import cv2
9
+
10
+
11
+ def _gaussian(
12
+ size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
13
+ height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
14
+ mean_vert=0.5):
15
+ # handle some defaults
16
+ if width is None:
17
+ width = size
18
+ if height is None:
19
+ height = size
20
+ if sigma_horz is None:
21
+ sigma_horz = sigma
22
+ if sigma_vert is None:
23
+ sigma_vert = sigma
24
+ center_x = mean_horz * width + 0.5
25
+ center_y = mean_vert * height + 0.5
26
+ gauss = np.empty((height, width), dtype=np.float32)
27
+ # generate kernel
28
+ for i in range(height):
29
+ for j in range(width):
30
+ gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
31
+ sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
32
+ if normalize:
33
+ gauss = gauss / np.sum(gauss)
34
+ return gauss
35
+
36
+
37
+ def draw_gaussian(image, point, sigma):
38
+ # Check if the gaussian is inside
39
+ ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
40
+ br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
41
+ if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
42
+ return image
43
+ size = 6 * sigma + 1
44
+ g = _gaussian(size)
45
+ g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
46
+ g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
47
+ img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
48
+ img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
49
+ assert (g_x[0] > 0 and g_y[1] > 0)
50
+ image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
51
+ ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
52
+ image[image > 1] = 1
53
+ return image
54
+
55
+
56
+ def transform(point, center, scale, resolution, invert=False):
57
+ """Generate and affine transformation matrix.
58
+
59
+ Given a set of points, a center, a scale and a targer resolution, the
60
+ function generates and affine transformation matrix. If invert is ``True``
61
+ it will produce the inverse transformation.
62
+
63
+ Arguments:
64
+ point {torch.tensor} -- the input 2D point
65
+ center {torch.tensor or numpy.array} -- the center around which to perform the transformations
66
+ scale {float} -- the scale of the face/object
67
+ resolution {float} -- the output resolution
68
+
69
+ Keyword Arguments:
70
+ invert {bool} -- define wherever the function should produce the direct or the
71
+ inverse transformation matrix (default: {False})
72
+ """
73
+ _pt = torch.ones(3)
74
+ _pt[0] = point[0]
75
+ _pt[1] = point[1]
76
+
77
+ h = 200.0 * scale
78
+ t = torch.eye(3)
79
+ t[0, 0] = resolution / h
80
+ t[1, 1] = resolution / h
81
+ t[0, 2] = resolution * (-center[0] / h + 0.5)
82
+ t[1, 2] = resolution * (-center[1] / h + 0.5)
83
+
84
+ if invert:
85
+ t = torch.inverse(t)
86
+
87
+ new_point = (torch.matmul(t, _pt))[0:2]
88
+
89
+ return new_point.int()
90
+
91
+
92
+ def crop(image, center, scale, resolution=256.0):
93
+ """Center crops an image or set of heatmaps
94
+
95
+ Arguments:
96
+ image {numpy.array} -- an rgb image
97
+ center {numpy.array} -- the center of the object, usually the same as of the bounding box
98
+ scale {float} -- scale of the face
99
+
100
+ Keyword Arguments:
101
+ resolution {float} -- the size of the output cropped image (default: {256.0})
102
+
103
+ Returns:
104
+ [type] -- [description]
105
+ """ # Crop around the center point
106
+ """ Crops the image around the center. Input is expected to be an np.ndarray """
107
+ ul = transform([1, 1], center, scale, resolution, True)
108
+ br = transform([resolution, resolution], center, scale, resolution, True)
109
+ # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
110
+ if image.ndim > 2:
111
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0],
112
+ image.shape[2]], dtype=np.int32)
113
+ newImg = np.zeros(newDim, dtype=np.uint8)
114
+ else:
115
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int)
116
+ newImg = np.zeros(newDim, dtype=np.uint8)
117
+ ht = image.shape[0]
118
+ wd = image.shape[1]
119
+ newX = np.array(
120
+ [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
121
+ newY = np.array(
122
+ [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
123
+ oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
124
+ oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
125
+ newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
126
+ ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
127
+ newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
128
+ interpolation=cv2.INTER_LINEAR)
129
+ return newImg
130
+
131
+
132
+ def get_preds_fromhm(hm, center=None, scale=None):
133
+ """Obtain (x,y) coordinates given a set of N heatmaps. If the center
134
+ and the scale is provided the function will return the points also in
135
+ the original coordinate frame.
136
+
137
+ Arguments:
138
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
139
+
140
+ Keyword Arguments:
141
+ center {torch.tensor} -- the center of the bounding box (default: {None})
142
+ scale {float} -- face scale (default: {None})
143
+ """
144
+ max, idx = torch.max(
145
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
146
+ idx += 1
147
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
148
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
149
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
150
+
151
+ for i in range(preds.size(0)):
152
+ for j in range(preds.size(1)):
153
+ hm_ = hm[i, j, :]
154
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
155
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
156
+ diff = torch.FloatTensor(
157
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
158
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
159
+ preds[i, j].add_(diff.sign_().mul_(.25))
160
+
161
+ preds.add_(-.5)
162
+
163
+ preds_orig = torch.zeros(preds.size())
164
+ if center is not None and scale is not None:
165
+ for i in range(hm.size(0)):
166
+ for j in range(hm.size(1)):
167
+ preds_orig[i, j] = transform(
168
+ preds[i, j], center, scale, hm.size(2), True)
169
+
170
+ return preds, preds_orig
171
+
172
+ def get_preds_fromhm_batch(hm, centers=None, scales=None):
173
+ """Obtain (x,y) coordinates given a set of N heatmaps. If the centers
174
+ and the scales is provided the function will return the points also in
175
+ the original coordinate frame.
176
+
177
+ Arguments:
178
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
179
+
180
+ Keyword Arguments:
181
+ centers {torch.tensor} -- the centers of the bounding box (default: {None})
182
+ scales {float} -- face scales (default: {None})
183
+ """
184
+ max, idx = torch.max(
185
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
186
+ idx += 1
187
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
188
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
189
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
190
+
191
+ for i in range(preds.size(0)):
192
+ for j in range(preds.size(1)):
193
+ hm_ = hm[i, j, :]
194
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
195
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
196
+ diff = torch.FloatTensor(
197
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
198
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
199
+ preds[i, j].add_(diff.sign_().mul_(.25))
200
+
201
+ preds.add_(-.5)
202
+
203
+ preds_orig = torch.zeros(preds.size())
204
+ if centers is not None and scales is not None:
205
+ for i in range(hm.size(0)):
206
+ for j in range(hm.size(1)):
207
+ preds_orig[i, j] = transform(
208
+ preds[i, j], centers[i], scales[i], hm.size(2), True)
209
+
210
+ return preds, preds_orig
211
+
212
+ def shuffle_lr(parts, pairs=None):
213
+ """Shuffle the points left-right according to the axis of symmetry
214
+ of the object.
215
+
216
+ Arguments:
217
+ parts {torch.tensor} -- a 3D or 4D object containing the
218
+ heatmaps.
219
+
220
+ Keyword Arguments:
221
+ pairs {list of integers} -- [order of the flipped points] (default: {None})
222
+ """
223
+ if pairs is None:
224
+ pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
225
+ 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
226
+ 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
227
+ 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
228
+ 62, 61, 60, 67, 66, 65]
229
+ if parts.ndimension() == 3:
230
+ parts = parts[pairs, ...]
231
+ else:
232
+ parts = parts[:, pairs, ...]
233
+
234
+ return parts
235
+
236
+
237
+ def flip(tensor, is_label=False):
238
+ """Flip an image or a set of heatmaps left-right
239
+
240
+ Arguments:
241
+ tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
242
+
243
+ Keyword Arguments:
244
+ is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False})
245
+ """
246
+ if not torch.is_tensor(tensor):
247
+ tensor = torch.from_numpy(tensor)
248
+
249
+ if is_label:
250
+ tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
251
+ else:
252
+ tensor = tensor.flip(tensor.ndimension() - 1)
253
+
254
+ return tensor
255
+
256
+ # From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
257
+
258
+
259
+ def appdata_dir(appname=None, roaming=False):
260
+ """ appdata_dir(appname=None, roaming=False)
261
+
262
+ Get the path to the application directory, where applications are allowed
263
+ to write user specific files (e.g. configurations). For non-user specific
264
+ data, consider using common_appdata_dir().
265
+ If appname is given, a subdir is appended (and created if necessary).
266
+ If roaming is True, will prefer a roaming directory (Windows Vista/7).
267
+ """
268
+
269
+ # Define default user directory
270
+ userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
271
+ if userDir is None:
272
+ userDir = os.path.expanduser('~')
273
+ if not os.path.isdir(userDir): # pragma: no cover
274
+ userDir = '/var/tmp' # issue #54
275
+
276
+ # Get system app data dir
277
+ path = None
278
+ if sys.platform.startswith('win'):
279
+ path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
280
+ path = (path2 or path1) if roaming else (path1 or path2)
281
+ elif sys.platform.startswith('darwin'):
282
+ path = os.path.join(userDir, 'Library', 'Application Support')
283
+ # On Linux and as fallback
284
+ if not (path and os.path.isdir(path)):
285
+ path = userDir
286
+
287
+ # Maybe we should store things local to the executable (in case of a
288
+ # portable distro or a frozen application that wants to be portable)
289
+ prefix = sys.prefix
290
+ if getattr(sys, 'frozen', None):
291
+ prefix = os.path.abspath(os.path.dirname(sys.executable))
292
+ for reldir in ('settings', '../settings'):
293
+ localpath = os.path.abspath(os.path.join(prefix, reldir))
294
+ if os.path.isdir(localpath): # pragma: no cover
295
+ try:
296
+ open(os.path.join(localpath, 'test.write'), 'wb').close()
297
+ os.remove(os.path.join(localpath, 'test.write'))
298
+ except IOError:
299
+ pass # We cannot write in this directory
300
+ else:
301
+ path = localpath
302
+ break
303
+
304
+ # Get path specific for this app
305
+ if appname:
306
+ if path == userDir:
307
+ appname = '.' + appname.lstrip('.') # Make it a hidden directory
308
+ path = os.path.join(path, appname)
309
+ if not os.path.isdir(path): # pragma: no cover
310
+ os.mkdir(path)
311
+
312
+ # Done
313
+ return path
inference.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import subprocess
4
+ from pathlib import Path
5
+
6
+ from config import hparams as hp
7
+ from nota_wav2lip import Wav2LipModelComparisonDemo
8
+
9
+ LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL', None)
10
+ LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL', None)
11
+
12
+ if not Path(hp.inference.model.wav2lip.checkpoint).exists() and LRS_ORIGINAL_URL is not None:
13
+ subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.wav2lip.checkpoint} {LRS_ORIGINAL_URL}", shell=True)
14
+ if not Path(hp.inference.model.nota_wav2lip.checkpoint).exists() and LRS_COMPRESSED_URL is not None:
15
+ subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.nota_wav2lip.checkpoint} {LRS_COMPRESSED_URL}", shell=True)
16
+
17
+ def parse_args():
18
+
19
+ parser = argparse.ArgumentParser(description="NotaWav2Lip: Inference snippet for your own video and audio pair")
20
+
21
+ parser.add_argument(
22
+ '-a',
23
+ '--audio-input',
24
+ type=str,
25
+ required=True,
26
+ help="Path of the audio file"
27
+ )
28
+
29
+ parser.add_argument(
30
+ '-v',
31
+ '--video-frame-input',
32
+ type=str,
33
+ required=True,
34
+ help="Input directory with face image sequence. We recommend to extract the face image sequence with `preprocess.py`."
35
+ )
36
+
37
+ parser.add_argument(
38
+ '-b',
39
+ '--bbox-input',
40
+ type=str,
41
+ help="Path of the file with bbox coordinates. We recommend to extract the json file with `preprocess.py`."
42
+ "If None, it pretends that the json file is located at the same directory with face images: {VIDEO_FRAME_INPUT}.with_suffix('.json')."
43
+ )
44
+
45
+ parser.add_argument(
46
+ '-m',
47
+ '--model',
48
+ choices=['wav2lip', 'nota_wav2lip'],
49
+ default='nota_wav2ilp',
50
+ help="Model for generating talking video. Defaults: nota_wav2lip"
51
+ )
52
+
53
+ parser.add_argument(
54
+ '-o',
55
+ '--output-dir',
56
+ type=str,
57
+ default="result",
58
+ help="Output directory to save the result. Defaults: result"
59
+ )
60
+
61
+ parser.add_argument(
62
+ '-d',
63
+ '--device',
64
+ choices=['cpu', 'cuda'],
65
+ default='cpu',
66
+ help="Device setting for model inference. Defaults: cpu"
67
+ )
68
+
69
+ args = parser.parse_args()
70
+
71
+ return args
72
+
73
+ if __name__ == "__main__":
74
+ args = parse_args()
75
+ bbox_input = args.bbox_input if args.bbox_input is not None \
76
+ else Path(args.video_frame_input).with_suffix('.json')
77
+
78
+ servicer = Wav2LipModelComparisonDemo(device=args.device, result_dir=args.output_dir, model_list=args.model)
79
+ servicer.update_audio(args.audio_input, name='a0')
80
+ servicer.update_video(args.video_frame_input, bbox_input, name='v0')
81
+
82
+ servicer.save_as_video('a0', 'v0', args.model)
inference.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Original Wav2Lip
2
+ python inference.py\
3
+ -a "sample_video_lrs3/sxnlvwprf_c-00007.wav"\
4
+ -v "sample_video_lrs3/Li4-1yyrsTI-00010"\
5
+ -m "wav2lip"\
6
+ -o "result_original"\
7
+ --device cpu
8
+
9
+ # Nota's Wav2Lip (28× Compressed)
10
+ python inference.py\
11
+ -a "sample_video_lrs3/sxnlvwprf_c-00007.wav"\
12
+ -v "sample_video_lrs3/Li4-1yyrsTI-00010"\
13
+ -m "nota_wav2lip"\
14
+ -o "result_nota"\
15
+ --device cpu
nota_wav2lip/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from nota_wav2lip.demo import Wav2LipModelComparisonDemo
2
+ from nota_wav2lip.gradio import Wav2LipModelComparisonGradio
nota_wav2lip/audio.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import librosa.filters
3
+ import numpy as np
4
+ from scipy import signal
5
+ from scipy.io import wavfile
6
+
7
+ from config import hparams
8
+
9
+ hp = hparams.audio
10
+
11
+ def load_wav(path, sr):
12
+ return librosa.core.load(path, sr=sr)[0]
13
+
14
+ def save_wav(wav, path, sr):
15
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
16
+ #proposed by @dsmiller
17
+ wavfile.write(path, sr, wav.astype(np.int16))
18
+
19
+ def save_wavenet_wav(wav, path, sr):
20
+ librosa.output.write_wav(path, wav, sr=sr)
21
+
22
+ def preemphasis(wav, k, preemphasize=True):
23
+ if preemphasize:
24
+ return signal.lfilter([1, -k], [1], wav)
25
+ return wav
26
+
27
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
28
+ if inv_preemphasize:
29
+ return signal.lfilter([1], [1, -k], wav)
30
+ return wav
31
+
32
+ def get_hop_size():
33
+ hop_size = hp.hop_size
34
+ if hop_size is None:
35
+ assert hp.frame_shift_ms is not None
36
+ hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
37
+ return hop_size
38
+
39
+ def linearspectrogram(wav):
40
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
41
+ S = _amp_to_db(np.abs(D)) - hp.ref_level_db
42
+
43
+ if hp.signal_normalization:
44
+ return _normalize(S)
45
+ return S
46
+
47
+ def melspectrogram(wav):
48
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
49
+ S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
50
+
51
+ if hp.signal_normalization:
52
+ return _normalize(S)
53
+ return S
54
+
55
+ def _lws_processor():
56
+ import lws
57
+ return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
58
+
59
+ def _stft(y):
60
+ if hp.use_lws:
61
+ return _lws_processor(hp).stft(y).T
62
+ else:
63
+ return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
64
+
65
+ ##########################################################
66
+ #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
67
+ def num_frames(length, fsize, fshift):
68
+ """Compute number of time frames of spectrogram
69
+ """
70
+ pad = (fsize - fshift)
71
+ M = (length + pad * 2 - fsize) // fshift + 1 if length % fshift == 0 else (length + pad * 2 - fsize) // fshift + 2
72
+ return M
73
+
74
+
75
+ def pad_lr(x, fsize, fshift):
76
+ """Compute left and right padding
77
+ """
78
+ M = num_frames(len(x), fsize, fshift)
79
+ pad = (fsize - fshift)
80
+ T = len(x) + 2 * pad
81
+ r = (M - 1) * fshift + fsize - T
82
+ return pad, pad + r
83
+ ##########################################################
84
+ #Librosa correct padding
85
+ def librosa_pad_lr(x, fsize, fshift):
86
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
87
+
88
+ # Conversions
89
+ _mel_basis = None
90
+
91
+ def _linear_to_mel(spectogram):
92
+ global _mel_basis
93
+ if _mel_basis is None:
94
+ _mel_basis = _build_mel_basis()
95
+ return np.dot(_mel_basis, spectogram)
96
+
97
+ def _build_mel_basis():
98
+ assert hp.fmax <= hp.sample_rate // 2
99
+ return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels,
100
+ fmin=hp.fmin, fmax=hp.fmax)
101
+
102
+ def _amp_to_db(x):
103
+ min_level = np.exp(hp.min_level_db / 20 * np.log(10))
104
+ return 20 * np.log10(np.maximum(min_level, x))
105
+
106
+ def _db_to_amp(x):
107
+ return np.power(10.0, (x) * 0.05)
108
+
109
+ def _normalize(S):
110
+ if hp.allow_clipping_in_normalization:
111
+ if hp.symmetric_mels:
112
+ return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
113
+ -hp.max_abs_value, hp.max_abs_value)
114
+ else:
115
+ return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
116
+
117
+ assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
118
+ if hp.symmetric_mels:
119
+ return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
120
+ else:
121
+ return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
122
+
123
+ def _denormalize(D):
124
+ if hp.allow_clipping_in_normalization:
125
+ if hp.symmetric_mels:
126
+ return (((np.clip(D, -hp.max_abs_value,
127
+ hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
128
+ + hp.min_level_db)
129
+ else:
130
+ return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
131
+
132
+ if hp.symmetric_mels:
133
+ return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
134
+ else:
135
+ return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
nota_wav2lip/demo.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import platform
3
+ import subprocess
4
+ import time
5
+ from pathlib import Path
6
+ from typing import Dict, Iterator, List, Literal, Optional, Union
7
+
8
+ import cv2
9
+ import numpy as np
10
+
11
+ from config import hparams as hp
12
+ from nota_wav2lip.inference import Wav2LipInferenceImpl
13
+ from nota_wav2lip.util import FFMPEG_LOGGING_MODE
14
+ from nota_wav2lip.video import AudioSlicer, VideoSlicer
15
+
16
+
17
+ class Wav2LipModelComparisonDemo:
18
+ def __init__(self, device='cpu', result_dir='./temp', model_list: Optional[Union[str, List[str]]]=None):
19
+ if model_list is None:
20
+ model_list: List[str] = ['wav2lip', 'nota_wav2lip']
21
+ if isinstance(model_list, str) and len(model_list) != 0:
22
+ model_list: List[str] = [model_list]
23
+ super().__init__()
24
+ self.video_dict: Dict[str, VideoSlicer] = {}
25
+ self.audio_dict: Dict[str, AudioSlicer] = {}
26
+
27
+ self.model_zoo: Dict[str, Wav2LipInferenceImpl] = {}
28
+ for model_name in model_list:
29
+ assert model_name in hp.inference.model, f"{model_name} not in hp.inference_model: {hp.inference.model}"
30
+ self.model_zoo[model_name] = Wav2LipInferenceImpl(
31
+ model_name, hp_inference_model=hp.inference.model[model_name], device=device
32
+ )
33
+
34
+ self._params_zoo: Dict[str, str] = {
35
+ model_name: self.model_zoo[model_name].params for model_name in self.model_zoo
36
+ }
37
+
38
+ self.result_dir: Path = Path(result_dir)
39
+ self.result_dir.mkdir(exist_ok=True)
40
+
41
+ @property
42
+ def params(self):
43
+ return self._params_zoo
44
+
45
+ def _infer(
46
+ self,
47
+ audio_name: str,
48
+ video_name: str,
49
+ model_type: Literal['wav2lip', 'nota_wav2lip']
50
+ ) -> Iterator[np.ndarray]:
51
+ audio_iterable: AudioSlicer = self.audio_dict[audio_name]
52
+ video_iterable: VideoSlicer = self.video_dict[video_name]
53
+ target_model = self.model_zoo[model_type]
54
+ return target_model.inference_with_iterator(audio_iterable, video_iterable)
55
+
56
+ def update_audio(self, audio_path, name=None):
57
+ _name = name if name is not None else Path(audio_path).stem
58
+ self.audio_dict.update(
59
+ {_name: AudioSlicer(audio_path)}
60
+ )
61
+
62
+ def update_video(self, frame_dir_path, bbox_path, name=None):
63
+ _name = name if name is not None else Path(frame_dir_path).stem
64
+ self.video_dict.update(
65
+ {_name: VideoSlicer(frame_dir_path, bbox_path)}
66
+ )
67
+
68
+ def save_as_video(self, audio_name, video_name, model_type):
69
+
70
+ output_video_path = self.result_dir / 'generated_with_audio.mp4'
71
+ frame_only_video_path = self.result_dir / 'generated.mp4'
72
+ audio_path = self.audio_dict[audio_name].audio_path
73
+
74
+ out = cv2.VideoWriter(str(frame_only_video_path),
75
+ cv2.VideoWriter_fourcc(*'mp4v'),
76
+ hp.face.video_fps,
77
+ (hp.inference.frame.w, hp.inference.frame.h))
78
+ start = time.time()
79
+ for frame in self._infer(audio_name=audio_name, video_name=video_name, model_type=model_type):
80
+ out.write(frame)
81
+ inference_time = time.time() - start
82
+ out.release()
83
+
84
+ command = f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {audio_path} -i {frame_only_video_path} -strict -2 -q:v 1 {output_video_path}"
85
+ subprocess.call(command, shell=platform.system() != 'Windows')
86
+
87
+ # The number of frames of generated video
88
+ video_frames_num = len(self.audio_dict[audio_name])
89
+ inference_fps = video_frames_num / inference_time
90
+
91
+ return output_video_path, inference_time, inference_fps
nota_wav2lip/gradio.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ from pathlib import Path
3
+
4
+ from nota_wav2lip.demo import Wav2LipModelComparisonDemo
5
+
6
+
7
+ class Wav2LipModelComparisonGradio(Wav2LipModelComparisonDemo):
8
+ def __init__(
9
+ self,
10
+ device='cpu',
11
+ result_dir='./temp',
12
+ video_label_dict=None,
13
+ audio_label_list=None,
14
+ default_video='v1',
15
+ default_audio='a1'
16
+ ) -> None:
17
+ if audio_label_list is None:
18
+ audio_label_list = {}
19
+ if video_label_dict is None:
20
+ video_label_dict = {}
21
+ super().__init__(device, result_dir)
22
+ self._video_label_dict = {k: Path(v).with_suffix('.mp4') for k, v in video_label_dict.items()}
23
+ self._audio_label_dict = audio_label_list
24
+ self._default_video = default_video
25
+ self._default_audio = default_audio
26
+
27
+ self._lock = threading.Lock() # lock for asserting that concurrency_count == 1
28
+
29
+ def _is_valid_input(self, video_selection, audio_selection):
30
+ assert video_selection in self._video_label_dict, \
31
+ f"Your input ({video_selection}) is not in {self._video_label_dict}!!!"
32
+ assert audio_selection in self._audio_label_dict, \
33
+ f"Your input ({audio_selection}) is not in {self._audio_label_dict}!!!"
34
+
35
+ def generate_original_model(self, video_selection, audio_selection):
36
+ try:
37
+ self._is_valid_input(video_selection, audio_selection)
38
+
39
+ with self._lock:
40
+ output_video_path, inference_time, inference_fps = \
41
+ self.save_as_video(audio_name=audio_selection,
42
+ video_name=video_selection,
43
+ model_type='wav2lip')
44
+
45
+ return str(output_video_path), format(inference_time, ".2f"), format(inference_fps, ".1f")
46
+ except KeyboardInterrupt:
47
+ exit()
48
+ except Exception as e:
49
+ print(e)
50
+ pass
51
+
52
+ def generate_compressed_model(self, video_selection, audio_selection):
53
+ try:
54
+ self._is_valid_input(video_selection, audio_selection)
55
+
56
+ with self._lock:
57
+ output_video_path, inference_time, inference_fps = \
58
+ self.save_as_video(audio_name=audio_selection,
59
+ video_name=video_selection,
60
+ model_type='nota_wav2lip')
61
+
62
+ return str(output_video_path), format(inference_time, ".2f"), format(inference_fps, ".1f")
63
+ except KeyboardInterrupt:
64
+ exit()
65
+ except Exception as e:
66
+ print(e)
67
+ pass
68
+
69
+ def switch_video_samples(self, video_selection):
70
+ try:
71
+ if video_selection not in self._video_label_dict:
72
+ return self._video_label_dict[self._default_video]
73
+ return self._video_label_dict[video_selection]
74
+
75
+ except KeyboardInterrupt:
76
+ exit()
77
+ except Exception as e:
78
+ print(e)
79
+ pass
80
+
81
+ def switch_audio_samples(self, audio_selection):
82
+ try:
83
+ if audio_selection not in self._audio_label_dict:
84
+ return self._audio_label_dict[self._default_audio]
85
+ return self._audio_label_dict[audio_selection]
86
+
87
+ except KeyboardInterrupt:
88
+ exit()
89
+ except Exception as e:
90
+ print(e)
91
+ pass
nota_wav2lip/inference.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Iterable, Iterator, List, Tuple
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from omegaconf import DictConfig
8
+ from tqdm import tqdm
9
+
10
+ from config import hparams as hp
11
+ from nota_wav2lip.models.util import count_params, load_model
12
+
13
+
14
+ class Wav2LipInferenceImpl:
15
+ def __init__(self, model_name: str, hp_inference_model: DictConfig, device='cpu'):
16
+ self.model: nn.Module = load_model(
17
+ model_name,
18
+ device=device,
19
+ **hp_inference_model
20
+ )
21
+ self.device = device
22
+ self._params: str = self._format_param(count_params(self.model))
23
+
24
+ @property
25
+ def params(self):
26
+ return self._params
27
+
28
+ @staticmethod
29
+ def _format_param(num_params: int) -> str:
30
+ params_in_million = num_params / 1e6
31
+ return f"{params_in_million:.1f}M"
32
+
33
+ @staticmethod
34
+ def _reset_batch() -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[List[int]]]:
35
+ return [], [], [], []
36
+
37
+ def get_data_iterator(
38
+ self,
39
+ audio_iterable: Iterable[np.ndarray],
40
+ video_iterable: List[Tuple[np.ndarray, List[int]]]
41
+ ) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray, List[int]]]:
42
+ img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch()
43
+
44
+ for i, m in enumerate(audio_iterable):
45
+ idx = i % len(video_iterable)
46
+ _frame_to_save, coords = video_iterable[idx]
47
+ frame_to_save = _frame_to_save.copy()
48
+ face = frame_to_save[coords[0]:coords[1], coords[2]:coords[3]].copy()
49
+
50
+ face: np.ndarray = cv2.resize(face, (hp.face.img_size, hp.face.img_size))
51
+
52
+ img_batch.append(face)
53
+ mel_batch.append(m)
54
+ frame_batch.append(frame_to_save)
55
+ coords_batch.append(coords)
56
+
57
+ if len(img_batch) >= hp.inference.batch_size:
58
+ img_batch = np.asarray(img_batch)
59
+ mel_batch = np.asarray(mel_batch)
60
+
61
+ img_masked = img_batch.copy()
62
+ img_masked[:, hp.face.img_size // 2:] = 0
63
+
64
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
65
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
66
+
67
+ yield img_batch, mel_batch, frame_batch, coords_batch
68
+ img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch()
69
+
70
+ if len(img_batch) > 0:
71
+ img_batch = np.asarray(img_batch)
72
+ mel_batch = np.asarray(mel_batch)
73
+
74
+ img_masked = img_batch.copy()
75
+ img_masked[:, hp.face.img_size // 2:] = 0
76
+
77
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
78
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
79
+
80
+ yield img_batch, mel_batch, frame_batch, coords_batch
81
+
82
+ @torch.no_grad()
83
+ def inference_with_iterator(
84
+ self,
85
+ audio_iterable: Iterable[np.ndarray],
86
+ video_iterable: List[Tuple[np.ndarray, List[int]]]
87
+ ) -> Iterator[np.ndarray]:
88
+ data_iterator = self.get_data_iterator(audio_iterable, video_iterable)
89
+
90
+ for (img_batch, mel_batch, frames, coords) in \
91
+ tqdm(data_iterator, total=int(np.ceil(float(len(audio_iterable)) / hp.inference.batch_size))):
92
+
93
+ img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(self.device)
94
+ mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(self.device)
95
+
96
+ preds: torch.Tensor = self.forward(mel_batch, img_batch)
97
+
98
+ preds = preds.cpu().numpy().transpose(0, 2, 3, 1) * 255.
99
+ for pred, frame, coord in zip(preds, frames, coords):
100
+ y1, y2, x1, x2 = coord
101
+ pred = cv2.resize(pred.astype(np.uint8), (x2 - x1, y2 - y1))
102
+
103
+ frame[y1:y2, x1:x2] = pred
104
+ yield frame
105
+
106
+ @torch.no_grad()
107
+ def forward(self, audio_sequences: torch.Tensor, face_sequences: torch.Tensor) -> torch.Tensor:
108
+ return self.model(audio_sequences, face_sequences)
109
+
110
+ def __call__(self, *args, **kwargs):
111
+ return self.forward(*args, **kwargs)
nota_wav2lip/models/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .base import Wav2LipBase
2
+ from .wav2lip import Wav2Lip
3
+ from .wav2lip_compressed import NotaWav2Lip
nota_wav2lip/models/base.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import final
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ class Wav2LipBase(nn.Module):
8
+ def __init__(self) -> None:
9
+ super().__init__()
10
+
11
+ self.audio_encoder = nn.Sequential()
12
+ self.face_encoder_blocks = nn.ModuleList([])
13
+ self.face_decoder_blocks = nn.ModuleList([])
14
+ self.output_block = nn.Sequential()
15
+
16
+ @final
17
+ def forward(self, audio_sequences, face_sequences):
18
+ # audio_sequences = (B, T, 1, 80, 16)
19
+ B = audio_sequences.size(0)
20
+
21
+ input_dim_size = len(face_sequences.size())
22
+ if input_dim_size > 4:
23
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
24
+ face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
25
+
26
+ audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
27
+
28
+ feats = []
29
+ x = face_sequences
30
+ for f in self.face_encoder_blocks:
31
+ x = f(x)
32
+ feats.append(x)
33
+
34
+ x = audio_embedding
35
+ for f in self.face_decoder_blocks:
36
+ x = f(x)
37
+ try:
38
+ x = torch.cat((x, feats[-1]), dim=1)
39
+ except Exception as e:
40
+ print(x.size())
41
+ print(feats[-1].size())
42
+ raise e
43
+
44
+ feats.pop()
45
+
46
+ x = self.output_block(x)
47
+
48
+ if input_dim_size > 4:
49
+ x = torch.split(x, B, dim=0) # [(B, C, H, W)]
50
+ outputs = torch.stack(x, dim=2) # (B, C, T, H, W)
51
+
52
+ else:
53
+ outputs = x
54
+
55
+ return outputs
nota_wav2lip/models/conv.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+
6
+ class Conv2d(nn.Module):
7
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
8
+ super().__init__(*args, **kwargs)
9
+ self.conv_block = nn.Sequential(
10
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
11
+ nn.BatchNorm2d(cout)
12
+ )
13
+ self.act = nn.ReLU()
14
+ self.residual = residual
15
+
16
+ def forward(self, x):
17
+ out = self.conv_block(x)
18
+ if self.residual:
19
+ out += x
20
+ return self.act(out)
21
+
22
+
23
+ class Conv2dTranspose(nn.Module):
24
+ def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
25
+ super().__init__(*args, **kwargs)
26
+ self.conv_block = nn.Sequential(
27
+ nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
28
+ nn.BatchNorm2d(cout)
29
+ )
30
+ self.act = nn.ReLU()
31
+
32
+ def forward(self, x):
33
+ out = self.conv_block(x)
34
+ return self.act(out)
nota_wav2lip/models/util.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Type
2
+
3
+ import torch
4
+
5
+ from nota_wav2lip.models import NotaWav2Lip, Wav2Lip, Wav2LipBase
6
+
7
+ MODEL_REGISTRY: Dict[str, Type[Wav2LipBase]] = {
8
+ 'wav2lip': Wav2Lip,
9
+ 'nota_wav2lip': NotaWav2Lip
10
+ }
11
+
12
+ def _load(checkpoint_path, device):
13
+ assert device in ['cpu', 'cuda']
14
+
15
+ print(f"Load checkpoint from: {checkpoint_path}")
16
+ if device == 'cuda':
17
+ return torch.load(checkpoint_path)
18
+ return torch.load(checkpoint_path, map_location=lambda storage, _: storage)
19
+
20
+ def load_model(model_name: str, device, checkpoint, **kwargs) -> Wav2LipBase:
21
+
22
+ cls = MODEL_REGISTRY[model_name.lower()]
23
+ assert issubclass(cls, Wav2LipBase)
24
+
25
+ model = cls(**kwargs)
26
+ checkpoint = _load(checkpoint, device)
27
+ model.load_state_dict(checkpoint)
28
+ model = model.to(device)
29
+ return model.eval()
30
+
31
+ def count_params(model):
32
+ return sum(p.numel() for p in model.parameters())
nota_wav2lip/models/wav2lip.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from nota_wav2lip.models.base import Wav2LipBase
5
+ from nota_wav2lip.models.conv import Conv2d, Conv2dTranspose
6
+
7
+
8
+ class Wav2Lip(Wav2LipBase):
9
+ def __init__(self):
10
+ super().__init__()
11
+
12
+ self.face_encoder_blocks = nn.ModuleList([
13
+ nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96
14
+
15
+ nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48
16
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
17
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
18
+
19
+ nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24
20
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
21
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
22
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
23
+
24
+ nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12
25
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
26
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
27
+
28
+ nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6
29
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
30
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
31
+
32
+ nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3
33
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
34
+
35
+ nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
36
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
37
+
38
+ self.audio_encoder = nn.Sequential(
39
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
40
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
41
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
42
+
43
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
44
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
45
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
46
+
47
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
48
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
49
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
50
+
51
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
52
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
53
+
54
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
55
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
56
+
57
+ self.face_decoder_blocks = nn.ModuleList([
58
+ nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),),
59
+
60
+ nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3
61
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
62
+
63
+ nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
64
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
65
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6
66
+
67
+ nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
68
+ Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
69
+ Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12
70
+
71
+ nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
72
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
73
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24
74
+
75
+ nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
76
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
77
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48
78
+
79
+ nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
80
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
81
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96
82
+
83
+ self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
84
+ nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
85
+ nn.Sigmoid())
nota_wav2lip/models/wav2lip_compressed.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from nota_wav2lip.models.base import Wav2LipBase
5
+ from nota_wav2lip.models.conv import Conv2d, Conv2dTranspose
6
+
7
+
8
+ class NotaWav2Lip(Wav2LipBase):
9
+ def __init__(self, nef=4, naf=8, ndf=8, x_size=96):
10
+ super().__init__()
11
+
12
+ assert x_size in [96, 128]
13
+ self.ker_sz_last = x_size // 32
14
+
15
+ self.face_encoder_blocks = nn.ModuleList([
16
+ nn.Sequential(Conv2d(6, nef, kernel_size=7, stride=1, padding=3)), # 96,96
17
+
18
+ nn.Sequential(Conv2d(nef, nef * 2, kernel_size=3, stride=2, padding=1),), # 48,48
19
+
20
+ nn.Sequential(Conv2d(nef * 2, nef * 4, kernel_size=3, stride=2, padding=1),), # 24,24
21
+
22
+ nn.Sequential(Conv2d(nef * 4, nef * 8, kernel_size=3, stride=2, padding=1),), # 12,12
23
+
24
+ nn.Sequential(Conv2d(nef * 8, nef * 16, kernel_size=3, stride=2, padding=1),), # 6,6
25
+
26
+ nn.Sequential(Conv2d(nef * 16, nef * 32, kernel_size=3, stride=2, padding=1),), # 3,3
27
+
28
+ nn.Sequential(Conv2d(nef * 32, nef * 32, kernel_size=self.ker_sz_last, stride=1, padding=0), # 1, 1
29
+ Conv2d(nef * 32, nef * 32, kernel_size=1, stride=1, padding=0)), ])
30
+
31
+ self.audio_encoder = nn.Sequential(
32
+ Conv2d(1, naf, kernel_size=3, stride=1, padding=1),
33
+
34
+ Conv2d(naf, naf * 2, kernel_size=3, stride=(3, 1), padding=1),
35
+
36
+ Conv2d(naf * 2, naf * 4, kernel_size=3, stride=3, padding=1),
37
+
38
+ Conv2d(naf * 4, naf * 8, kernel_size=3, stride=(3, 2), padding=1),
39
+
40
+ Conv2d(naf * 8, naf * 16, kernel_size=3, stride=1, padding=0),
41
+ Conv2d(naf * 16, naf * 16, kernel_size=1, stride=1, padding=0), )
42
+
43
+ self.face_decoder_blocks = nn.ModuleList([
44
+ nn.Sequential(Conv2d(naf * 16, naf * 16, kernel_size=1, stride=1, padding=0), ),
45
+
46
+ nn.Sequential(Conv2dTranspose(nef * 32 + naf * 16, ndf * 16, kernel_size=self.ker_sz_last, stride=1, padding=0),),
47
+ # 3,3 # 512+512 = 1024
48
+
49
+ nn.Sequential(
50
+ Conv2dTranspose(nef * 32 + ndf * 16, ndf * 16, kernel_size=3, stride=2, padding=1, output_padding=1),), # 6, 6
51
+ # 512+512 = 1024
52
+
53
+ nn.Sequential(
54
+ Conv2dTranspose(nef * 16 + ndf * 16, ndf * 12, kernel_size=3, stride=2, padding=1, output_padding=1),), # 12, 12
55
+ # 256+512 = 768
56
+
57
+ nn.Sequential(
58
+ Conv2dTranspose(nef * 8 + ndf * 12, ndf * 8, kernel_size=3, stride=2, padding=1, output_padding=1),), # 24, 24
59
+ # 128+384 = 512
60
+
61
+ nn.Sequential(
62
+ Conv2dTranspose(nef * 4 + ndf * 8, ndf * 4, kernel_size=3, stride=2, padding=1, output_padding=1),), # 48, 48
63
+ # 64+256 = 320
64
+
65
+ nn.Sequential(
66
+ Conv2dTranspose(nef * 2 + ndf * 4, ndf * 2, kernel_size=3, stride=2, padding=1, output_padding=1),), # 96,96
67
+ # 32+128 = 160
68
+ ])
69
+
70
+ self.output_block = nn.Sequential(Conv2d(nef + ndf * 2, ndf, kernel_size=3, stride=1, padding=1), # 16+64 = 80
71
+ nn.Conv2d(ndf, 3, kernel_size=1, stride=1, padding=0),
72
+ nn.Sigmoid())
nota_wav2lip/preprocess/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from nota_wav2lip.preprocess.core import get_preprocessed_data
2
+ from nota_wav2lip.preprocess.lrs3_download import get_cropped_face_from_lrs3_label
nota_wav2lip/preprocess/core.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import platform
3
+ import subprocess
4
+ from pathlib import Path
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from loguru import logger
9
+ from tqdm import tqdm
10
+
11
+ import face_detection
12
+ from nota_wav2lip.util import FFMPEG_LOGGING_MODE
13
+
14
+ detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device='cpu')
15
+ PADDING = [0, 10, 0, 0]
16
+
17
+
18
+ def get_smoothened_boxes(boxes, T):
19
+ for i in range(len(boxes)):
20
+ window = boxes[len(boxes) - T:] if i + T > len(boxes) else boxes[i:i + T]
21
+ boxes[i] = np.mean(window, axis=0)
22
+ return boxes
23
+
24
+
25
+ def face_detect(images, pads, no_smooth=False, batch_size=1):
26
+
27
+ predictions = []
28
+ images_array = [cv2.imread(str(image)) for image in images]
29
+ for i in tqdm(range(0, len(images_array), batch_size)):
30
+ predictions.extend(detector.get_detections_for_batch(np.array(images_array[i:i + batch_size])))
31
+
32
+ results = []
33
+ pady1, pady2, padx1, padx2 = pads
34
+ for rect, image_array in zip(predictions, images_array):
35
+ if rect is None:
36
+ cv2.imwrite('temp/faulty_frame.jpg', image_array) # check this frame where the face was not detected.
37
+ raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
38
+
39
+ y1 = max(0, rect[1] - pady1)
40
+ y2 = min(image_array.shape[0], rect[3] + pady2)
41
+ x1 = max(0, rect[0] - padx1)
42
+ x2 = min(image_array.shape[1], rect[2] + padx2)
43
+ results.append([x1, y1, x2, y2])
44
+
45
+ boxes = np.array(results)
46
+ bbox_format = "(y1, y2, x1, x2)"
47
+ if not no_smooth:
48
+ boxes = get_smoothened_boxes(boxes, T=5)
49
+ outputs = {
50
+ 'bbox': {str(image_path): tuple(map(int, (y1, y2, x1, x2))) for image_path, (x1, y1, x2, y2) in zip(images, boxes)},
51
+ 'format': bbox_format
52
+ }
53
+ return outputs
54
+
55
+
56
+ def save_video_frame(video_path, output_dir=None):
57
+ video_path = Path(video_path)
58
+ output_dir = output_dir if output_dir is not None else video_path.with_suffix('')
59
+ output_dir.mkdir(exist_ok=True)
60
+ return subprocess.call(
61
+ f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {video_path} -r 25 -f image2 {output_dir}/%05d.jpg",
62
+ shell=platform.system() != 'Windows'
63
+ )
64
+
65
+
66
+ def save_audio_file(video_path, output_path=None):
67
+ video_path = Path(video_path)
68
+ output_path = output_path if output_path is not None else video_path.with_suffix('.wav')
69
+ subprocess.call(
70
+ f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {video_path} -vn -acodec pcm_s16le -ar 16000 -ac 1 {output_path}",
71
+ shell=platform.system() != 'Windows'
72
+ )
73
+
74
+
75
+ def save_bbox_file(video_path, bbox_dict, output_path=None):
76
+ video_path = Path(video_path)
77
+ output_path = output_path if output_path is not None else video_path.with_suffix('.json')
78
+
79
+ with open(output_path, 'w') as f:
80
+ json.dump(bbox_dict, f, indent=4)
81
+
82
+ def get_preprocessed_data(video_path: Path):
83
+ video_path = Path(video_path)
84
+
85
+ image_sequence_dir = video_path.with_suffix('')
86
+ audio_path = video_path.with_suffix('.wav')
87
+ face_bbox_json_path = video_path.with_suffix('.json')
88
+
89
+ logger.info(f"Save 25 FPS video frames as image files ... will be saved at {video_path}")
90
+ save_video_frame(video_path=video_path, output_dir=image_sequence_dir)
91
+
92
+ logger.info(f"Save the audio as wav file ... will be saved at {audio_path}")
93
+ save_audio_file(video_path=video_path, output_path=audio_path) # bonus
94
+
95
+ # Load images, extract bboxes and save the coords(to directly use as array indicies)
96
+ logger.info(f"Extract face boxes and save the coords with json format ... will be saved at {face_bbox_json_path}")
97
+ results = face_detect(sorted(image_sequence_dir.glob("*.jpg")), pads=PADDING)
98
+ save_bbox_file(video_path, results, output_path=face_bbox_json_path)
nota_wav2lip/preprocess/ffmpeg.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ FFMPEG_LOGGING_MODE = {
2
+ 'DEBUG': "",
3
+ 'INFO': "-v quiet -stats",
4
+ 'ERROR': "-hide_banner -loglevel error",
5
+ }
nota_wav2lip/preprocess/lrs3_download.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import platform
2
+ import subprocess
3
+ from pathlib import Path
4
+ from typing import Dict, List, Tuple, TypedDict, Union
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import yt_dlp
9
+ from loguru import logger
10
+ from tqdm import tqdm
11
+
12
+ from nota_wav2lip.util import FFMPEG_LOGGING_MODE
13
+
14
+
15
+ class LabelInfo(TypedDict):
16
+ text: str
17
+ conf: int
18
+ url: str
19
+ bbox_xywhn: Dict[int, Tuple[float, float, float, float]]
20
+
21
+ def frame_to_time(frame_id: int, fps=25) -> str:
22
+ seconds = frame_id / fps
23
+
24
+ hours = int(seconds // 3600)
25
+ seconds -= 3600 * hours
26
+
27
+ minutes = int(seconds // 60)
28
+ seconds -= 60 * minutes
29
+
30
+ seconds_int = int(seconds)
31
+ seconds_milli = int((seconds - int(seconds)) * 1e3)
32
+
33
+ return f"{hours:02d}:{minutes:02d}:{seconds_int:02d}.{seconds_milli:03d}" # HH:MM:SS.mmm
34
+
35
+ def save_audio_file(input_path, start_frame_id, to_frame_id, output_path=None):
36
+ input_path = Path(input_path)
37
+ output_path = output_path if output_path is not None else input_path.with_suffix('.wav')
38
+
39
+ ss = frame_to_time(start_frame_id)
40
+ to = frame_to_time(to_frame_id)
41
+ subprocess.call(
42
+ f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {input_path} -vn -acodec pcm_s16le -ss {ss} -to {to} -ar 16000 -ac 1 {output_path}",
43
+ shell=platform.system() != 'Windows'
44
+ )
45
+
46
+ def merge_video_audio(video_path, audio_path, output_path):
47
+ subprocess.call(
48
+ f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {video_path} -i {audio_path} -strict experimental {output_path}",
49
+ shell=platform.system() != 'Windows'
50
+ )
51
+
52
+ def parse_lrs3_label(label_path) -> LabelInfo:
53
+ label_text = Path(label_path).read_text()
54
+ label_splitted = label_text.split('\n')
55
+
56
+ # Label validation
57
+ assert label_splitted[0].startswith("Text:")
58
+ assert label_splitted[1].startswith("Conf:")
59
+ assert label_splitted[2].startswith("Ref:")
60
+ assert label_splitted[4].startswith("FRAME")
61
+
62
+ label_info = LabelInfo(bbox_xywhn={})
63
+ label_info['text'] = label_splitted[0][len("Text: "):].strip()
64
+ label_info['conf'] = int(label_splitted[1][len("Conf: "):])
65
+ label_info['url'] = label_splitted[2][len("Ref: "):].strip()
66
+
67
+ for label_line in label_splitted[5:]:
68
+ bbox_splitted = [x.strip() for x in label_line.split('\t')]
69
+ if len(bbox_splitted) != 5:
70
+ continue
71
+ frame_index = int(bbox_splitted[0])
72
+ bbox_xywhn = tuple(map(float, bbox_splitted[1:]))
73
+ label_info['bbox_xywhn'][frame_index] = bbox_xywhn
74
+
75
+ return label_info
76
+
77
+ def _get_cropped_bbox(bbox_info_xywhn, original_width, original_height):
78
+
79
+ bbox_info = bbox_info_xywhn
80
+ x = bbox_info[0] * original_width
81
+ y = bbox_info[1] * original_height
82
+ w = bbox_info[2] * original_width
83
+ h = bbox_info[3] * original_height
84
+
85
+ x_min = max(0, int(x - 0.5 * w))
86
+ y_min = max(0, int(y))
87
+ x_max = min(original_width, int(x + 1.5 * w))
88
+ y_max = min(original_height, int(y + 1.5 * h))
89
+
90
+ cropped_width = x_max - x_min
91
+ cropped_height = y_max - y_min
92
+
93
+ if cropped_height > cropped_width:
94
+ offset = cropped_height - cropped_width
95
+ offset_low = min(x_min, offset // 2)
96
+ offset_high = min(offset - offset_low, original_width - x_max)
97
+ x_min -= offset_low
98
+ x_max += offset_high
99
+ else:
100
+ offset = cropped_width - cropped_height
101
+ offset_low = min(y_min, offset // 2)
102
+ offset_high = min(offset - offset_low, original_width - y_max)
103
+ y_min -= offset_low
104
+ y_max += offset_high
105
+
106
+ return x_min, y_min, x_max, y_max
107
+
108
+ def _get_smoothened_boxes(bbox_dict, bbox_smoothen_window):
109
+ boxes = [np.array(bbox_dict[frame_id]) for frame_id in sorted(bbox_dict)]
110
+ for i in range(len(boxes)):
111
+ window = boxes[len(boxes) - bbox_smoothen_window:] if i + bbox_smoothen_window > len(boxes) else boxes[i:i + bbox_smoothen_window]
112
+ boxes[i] = np.mean(window, axis=0)
113
+
114
+ for idx, frame_id in enumerate(sorted(bbox_dict)):
115
+ bbox_dict[frame_id] = (np.rint(boxes[idx])).astype(int).tolist()
116
+ return bbox_dict
117
+
118
+ def download_video_from_youtube(youtube_ref, output_path):
119
+ ydl_url = f"https://www.youtube.com/watch?v={youtube_ref}"
120
+ ydl_opts = {
121
+ 'format': 'bestvideo[ext=mp4][height<=720]+bestaudio[ext=m4a]/best[ext=mp4][height<=720]',
122
+ 'outtmpl': str(output_path),
123
+ }
124
+
125
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
126
+ ydl.download([ydl_url])
127
+
128
+ def resample_video(input_path, output_path):
129
+ subprocess.call(
130
+ f"ffmpeg {FFMPEG_LOGGING_MODE['INFO']} -y -i {input_path} -r 25 -preset veryfast {output_path}",
131
+ shell=platform.system() != 'Windows'
132
+ )
133
+
134
+ def _get_smoothen_xyxy_bbox(
135
+ label_bbox_xywhn: Dict[int, Tuple[float, float, float, float]],
136
+ original_width: int,
137
+ original_height: int,
138
+ bbox_smoothen_window: int = 5
139
+ ) -> Dict[int, Tuple[float, float, float, float]]:
140
+
141
+ label_bbox_xyxy: Dict[int, Tuple[float, float, float, float]] = {}
142
+ for frame_id in sorted(label_bbox_xywhn):
143
+ frame_bbox_xywhn = label_bbox_xywhn[frame_id]
144
+ bbox_xyxy = _get_cropped_bbox(frame_bbox_xywhn, original_width, original_height)
145
+ label_bbox_xyxy[frame_id] = bbox_xyxy
146
+
147
+ label_bbox_xyxy = _get_smoothened_boxes(label_bbox_xyxy, bbox_smoothen_window=bbox_smoothen_window)
148
+ return label_bbox_xyxy
149
+
150
+ def get_start_end_frame_id(
151
+ label_bbox_xywhn: Dict[int, Tuple[float, float, float, float]],
152
+ ) -> Tuple[int, int]:
153
+ frame_ids = list(label_bbox_xywhn.keys())
154
+ start_frame_id = min(frame_ids)
155
+ to_frame_id = max(frame_ids)
156
+ return start_frame_id, to_frame_id
157
+
158
+ def crop_video_with_bbox(
159
+ input_path,
160
+ label_bbox_xywhn: Dict[int, Tuple[float, float, float, float]],
161
+ start_frame_id,
162
+ to_frame_id,
163
+ output_path,
164
+ bbox_smoothen_window = 5,
165
+ frame_width = 224,
166
+ frame_height = 224,
167
+ fps = 25,
168
+ interpolation = cv2.INTER_CUBIC,
169
+ ):
170
+ def frame_generator(cap):
171
+ if not cap.isOpened():
172
+ raise IOError("Error: Could not open video.")
173
+
174
+ while True:
175
+ ret, frame = cap.read()
176
+ if not ret:
177
+ break
178
+ yield frame
179
+
180
+ cap.release()
181
+
182
+ cap = cv2.VideoCapture(str(input_path))
183
+ original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
184
+ original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
185
+ label_bbox_xyxy = _get_smoothen_xyxy_bbox(label_bbox_xywhn, original_width, original_height, bbox_smoothen_window=bbox_smoothen_window)
186
+
187
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
188
+ out = cv2.VideoWriter(str(output_path), fourcc, fps, (frame_width, frame_height))
189
+
190
+ for frame_id, frame in tqdm(enumerate(frame_generator(cap))):
191
+ if start_frame_id <= frame_id <= to_frame_id:
192
+ x_min, y_min, x_max, y_max = label_bbox_xyxy[frame_id]
193
+
194
+ frame_cropped = frame[y_min:y_max, x_min:x_max]
195
+ frame_cropped = cv2.resize(frame_cropped, (frame_width, frame_height), interpolation=interpolation)
196
+ out.write(frame_cropped)
197
+
198
+ out.release()
199
+
200
+
201
+ def get_cropped_face_from_lrs3_label(
202
+ label_text_path: Union[Path, str],
203
+ video_root_dir: Union[Path, str],
204
+ bbox_smoothen_window: int = 5,
205
+ frame_width: int = 224,
206
+ frame_height: int = 224,
207
+ fps: int = 25,
208
+ interpolation = cv2.INTER_CUBIC,
209
+ ignore_cache: bool = False,
210
+ ):
211
+ label_text_path = Path(label_text_path)
212
+ label_info = parse_lrs3_label(label_text_path)
213
+ start_frame_id, to_frame_id = get_start_end_frame_id(label_info['bbox_xywhn'])
214
+
215
+ video_root_dir = Path(video_root_dir)
216
+ video_cache_dir = video_root_dir / ".cache"
217
+ video_cache_dir.mkdir(parents=True, exist_ok=True)
218
+
219
+ output_video: Path = video_cache_dir / f"{label_info['url']}.mp4"
220
+ output_resampled_video: Path = output_video.with_name(f"{output_video.stem}-25fps.mp4")
221
+ output_cropped_audio: Path = output_video.with_name(f"{output_video.stem}-{label_text_path.stem}-cropped.wav")
222
+ output_cropped_video: Path = output_video.with_name(f"{output_video.stem}-{label_text_path.stem}-cropped.mp4")
223
+ output_cropped_with_audio: Path = video_root_dir / output_video.with_name(f"{output_video.stem}-{label_text_path.stem}.mp4").name
224
+
225
+ if not output_video.exists() or ignore_cache:
226
+ youtube_ref = label_info['url']
227
+ logger.info(f"Download Youtube video(https://www.youtube.com/watch?v={youtube_ref}) ... will be saved at {output_video}")
228
+ download_video_from_youtube(youtube_ref, output_path=output_video)
229
+
230
+ if not output_resampled_video.exists() or ignore_cache:
231
+ logger.info(f"Resampling video to 25 FPS ... will be saved at {output_resampled_video}")
232
+ resample_video(input_path=output_video, output_path=output_resampled_video)
233
+
234
+ if not output_cropped_audio.exists() or ignore_cache:
235
+ logger.info(f"Cut audio file with the given timestamps ... will be saved at {output_cropped_audio}")
236
+ save_audio_file(
237
+ output_resampled_video,
238
+ start_frame_id=start_frame_id,
239
+ to_frame_id=to_frame_id,
240
+ output_path=output_cropped_audio
241
+ )
242
+
243
+ logger.info(f"Naive crop the face region with the given frame labels ... will be saved at {output_cropped_video}")
244
+ crop_video_with_bbox(
245
+ output_resampled_video,
246
+ label_info['bbox_xywhn'],
247
+ start_frame_id,
248
+ to_frame_id,
249
+ output_path=output_cropped_video,
250
+ bbox_smoothen_window=bbox_smoothen_window,
251
+ frame_width=frame_width,
252
+ frame_height=frame_height,
253
+ fps=fps,
254
+ interpolation=interpolation
255
+ )
256
+
257
+ if not output_cropped_with_audio.exists() or ignore_cache:
258
+ logger.info(f"Merge an audio track with the cropped face sequence ... will be saved at {output_cropped_with_audio}")
259
+ merge_video_audio(output_cropped_video, output_cropped_audio, output_cropped_with_audio)
nota_wav2lip/util.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ FFMPEG_LOGGING_MODE = {
2
+ 'DEBUG': "",
3
+ 'INFO': "-v quiet -stats",
4
+ 'ERROR': "-hide_banner -loglevel error",
5
+ }
nota_wav2lip/video.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from typing import List, Tuple, Union
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+ import nota_wav2lip.audio as audio
9
+ from config import hparams as hp
10
+
11
+
12
+ class VideoSlicer:
13
+ def __init__(self, frame_dir: Union[Path, str], bbox_path: Union[Path, str]):
14
+ self.fps = hp.face.video_fps
15
+ self.frame_dir = frame_dir
16
+ self.frame_path_list = sorted(Path(self.frame_dir).glob("*.jpg"))
17
+ self.frame_array_list: List[np.ndarray] = [cv2.imread(str(image)) for image in self.frame_path_list]
18
+
19
+ with open(bbox_path, 'r') as f:
20
+ metadata = json.load(f)
21
+ self.bbox: List[List[int]] = [metadata['bbox'][key] for key in sorted(metadata['bbox'].keys())]
22
+ self.bbox_format = metadata['format']
23
+ assert len(self.bbox) == len(self.frame_array_list)
24
+
25
+ def __len__(self):
26
+ return len(self.frame_array_list)
27
+
28
+ def __getitem__(self, idx) -> Tuple[np.ndarray, List[int]]:
29
+ bbox = self.bbox[idx]
30
+ frame_original: np.ndarray = self.frame_array_list[idx]
31
+ # return frame_original[bbox[0]:bbox[1], bbox[2]:bbox[3], :]
32
+ return frame_original, bbox
33
+
34
+
35
+ class AudioSlicer:
36
+ def __init__(self, audio_path: Union[Path, str]):
37
+ self.fps = hp.face.video_fps
38
+ self.mel_chunks = self._audio_chunk_generator(audio_path)
39
+ self._audio_path = audio_path
40
+
41
+ @property
42
+ def audio_path(self):
43
+ return self._audio_path
44
+
45
+ def __len__(self):
46
+ return len(self.mel_chunks)
47
+
48
+ def _audio_chunk_generator(self, audio_path):
49
+ wav: np.ndarray = audio.load_wav(audio_path, hp.audio.sample_rate)
50
+ mel: np.ndarray = audio.melspectrogram(wav)
51
+
52
+ if np.isnan(mel.reshape(-1)).sum() > 0:
53
+ raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
54
+
55
+ mel_chunks: List[np.ndarray] = []
56
+ mel_idx_multiplier = 80. / self.fps
57
+
58
+ i = 0
59
+ while True:
60
+ start_idx = int(i * mel_idx_multiplier)
61
+ if start_idx + hp.face.mel_step_size > len(mel[0]):
62
+ mel_chunks.append(mel[:, len(mel[0]) - hp.face.mel_step_size:])
63
+ return mel_chunks
64
+ mel_chunks.append(mel[:, start_idx: start_idx + hp.face.mel_step_size])
65
+ i += 1
66
+
67
+ def __getitem__(self, idx: int) -> np.ndarray:
68
+ return self.mel_chunks[idx]
preprocess.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from nota_wav2lip.preprocess import get_preprocessed_data
4
+
5
+
6
+ def parse_args():
7
+
8
+ parser = argparse.ArgumentParser(description="NotaWav2Lip: Preprocess the facial video with face detection")
9
+
10
+ parser.add_argument(
11
+ '-i',
12
+ '--input-file',
13
+ type=str,
14
+ required=True,
15
+ help="Path of the facial video. We recommend that the video is one of LRS3 data samples, which is the result of `download.py`."
16
+ "The extracted features and facial image sequences are saved at the same location with the input file."
17
+ )
18
+
19
+ args = parser.parse_args()
20
+
21
+ return args
22
+
23
+ if __name__ == '__main__':
24
+ args = parse_args()
25
+
26
+ get_preprocessed_data(
27
+ args.input_file,
28
+ )