Spaces:
Running
on
Zero
Running
on
Zero
[email protected]
commited on
Commit
·
2b34e02
1
Parent(s):
1c08139
update
Browse files- .gitignore +133 -0
- README.md +4 -4
- app.py +83 -0
- assets/audios/audio files.txt +1 -0
- assets/coords/coord files.txt +1 -0
- assets/coords/sample1.npz +3 -0
- assets/coords/sample2.npz +3 -0
- assets/coords/sample3.npz +3 -0
- assets/coords/sample4.npz +3 -0
- assets/coords/sample5.npz +3 -0
- assets/videos/sample1.mp4 +0 -0
- assets/videos/sample2.mp4 +0 -0
- assets/videos/sample3.mp4 +0 -0
- assets/videos/sample4.mp4 +0 -0
- assets/videos/sample5.mp4 +0 -0
- attributtes_utils.py +47 -0
- audio.py +160 -0
- checkpoints/obama-fp16.safetensors +3 -0
- face_detection/README.md +1 -0
- face_detection/__init__.py +7 -0
- face_detection/api.py +79 -0
- face_detection/detection/__init__.py +1 -0
- face_detection/detection/core.py +130 -0
- face_detection/detection/sfd/__init__.py +1 -0
- face_detection/detection/sfd/bbox.py +129 -0
- face_detection/detection/sfd/detect.py +112 -0
- face_detection/detection/sfd/net_s3fd.py +129 -0
- face_detection/detection/sfd/s3fd.pth +3 -0
- face_detection/detection/sfd/sfd_detector.py +59 -0
- face_detection/models.py +261 -0
- face_detection/utils.py +313 -0
- fete_model.py +296 -0
- inference_util.py +338 -0
- preprocess_videos.py +120 -0
- requirements.txt +10 -0
.gitignore
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
pip-wheel-metadata/
|
24 |
+
share/python-wheels/
|
25 |
+
*.egg-info/
|
26 |
+
.installed.cfg
|
27 |
+
*.egg
|
28 |
+
MANIFEST
|
29 |
+
|
30 |
+
face_detection/face_detection.pb
|
31 |
+
assets/audios/*.wav
|
32 |
+
|
33 |
+
# PyInstaller
|
34 |
+
# Usually these files are written by a python script from a template
|
35 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
36 |
+
*.manifest
|
37 |
+
*.spec
|
38 |
+
|
39 |
+
# Installer logs
|
40 |
+
pip-log.txt
|
41 |
+
pip-delete-this-directory.txt
|
42 |
+
|
43 |
+
# Unit test / coverage reports
|
44 |
+
htmlcov/
|
45 |
+
.tox/
|
46 |
+
.nox/
|
47 |
+
.coverage
|
48 |
+
.coverage.*
|
49 |
+
.cache
|
50 |
+
nosetests.xml
|
51 |
+
coverage.xml
|
52 |
+
*.cover
|
53 |
+
*.py,cover
|
54 |
+
.hypothesis/
|
55 |
+
.pytest_cache/
|
56 |
+
|
57 |
+
# Translations
|
58 |
+
*.mo
|
59 |
+
*.pot
|
60 |
+
|
61 |
+
# Django stuff:
|
62 |
+
*.log
|
63 |
+
local_settings.py
|
64 |
+
db.sqlite3
|
65 |
+
db.sqlite3-journal
|
66 |
+
|
67 |
+
# Flask stuff:
|
68 |
+
instance/
|
69 |
+
.webassets-cache
|
70 |
+
|
71 |
+
# Scrapy stuff:
|
72 |
+
.scrapy
|
73 |
+
|
74 |
+
# Sphinx documentation
|
75 |
+
docs/_build/
|
76 |
+
|
77 |
+
# PyBuilder
|
78 |
+
target/
|
79 |
+
|
80 |
+
# Jupyter Notebook
|
81 |
+
.ipynb_checkpoints
|
82 |
+
|
83 |
+
# IPython
|
84 |
+
profile_default/
|
85 |
+
ipython_config.py
|
86 |
+
|
87 |
+
# pyenv
|
88 |
+
.python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
98 |
+
__pypackages__/
|
99 |
+
|
100 |
+
# Celery stuff
|
101 |
+
celerybeat-schedule
|
102 |
+
celerybeat.pid
|
103 |
+
|
104 |
+
# SageMath parsed files
|
105 |
+
*.sage.py
|
106 |
+
|
107 |
+
# Environments
|
108 |
+
.env
|
109 |
+
.venv
|
110 |
+
env/
|
111 |
+
venv/
|
112 |
+
ENV/
|
113 |
+
env.bak/
|
114 |
+
venv.bak/
|
115 |
+
|
116 |
+
# Spyder project settings
|
117 |
+
.spyderproject
|
118 |
+
.spyproject
|
119 |
+
|
120 |
+
# Rope project settings
|
121 |
+
.ropeproject
|
122 |
+
|
123 |
+
# mkdocs documentation
|
124 |
+
/site
|
125 |
+
|
126 |
+
# mypy
|
127 |
+
.mypy_cache/
|
128 |
+
.dmypy.json
|
129 |
+
dmypy.json
|
130 |
+
|
131 |
+
# Pyre type checker
|
132 |
+
.pyre/
|
133 |
+
.vercel
|
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.41.0
|
8 |
app_file: app.py
|
|
|
1 |
---
|
2 |
+
title: FETE
|
3 |
+
emoji: 👀
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: gray
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.41.0
|
8 |
app_file: app.py
|
app.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
from natsort import natsorted
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
from inference_util import init_model, infenrece
|
7 |
+
from attributtes_utils import input_pose, input_emotion, input_blink
|
8 |
+
|
9 |
+
model = init_model()
|
10 |
+
|
11 |
+
|
12 |
+
def process(input_vid, audio_path, pose_select, emotion_select, blink_select):
|
13 |
+
pose = input_pose(pose_select)
|
14 |
+
emotion = input_emotion(emotion_select)
|
15 |
+
blink = input_blink(blink_select)
|
16 |
+
|
17 |
+
print("input_vid: ", input_vid)
|
18 |
+
result = infenrece(model, os.path.join("./assets/videos/", input_vid), os.path.join("./assets/audios/", audio_path), pose, emotion, blink)
|
19 |
+
print("result: ", result)
|
20 |
+
|
21 |
+
print("finished !")
|
22 |
+
|
23 |
+
return result # , gr.Group.update(visible=True)
|
24 |
+
|
25 |
+
|
26 |
+
available_videos = natsorted(glob.glob("./assets/videos/*.mp4"))
|
27 |
+
available_videos = [os.path.basename(x) for x in available_videos]
|
28 |
+
|
29 |
+
# prepare audio
|
30 |
+
for video in available_videos:
|
31 |
+
audio = video.replace(".mp4", ".wav")
|
32 |
+
if not os.path.exists(os.path.join("./assets/audios/", audio)):
|
33 |
+
os.system(f"ffmpeg -i ./assets/videos/{video} -vn -acodec pcm_s16le -ar 16000 -ac 1 ./assets/audios/{audio}")
|
34 |
+
available_audios = natsorted(glob.glob("./assets/audios/*.wav"))
|
35 |
+
available_audios = [os.path.basename(x) for x in available_audios]
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
with gr.Blocks() as demo:
|
40 |
+
gr.HTML(
|
41 |
+
"""
|
42 |
+
<h1 style="text-align: center; font-size: 40px; font-family: 'Times New Roman', Times, serif;">
|
43 |
+
Free-View Expressive Talking Head Video Editing
|
44 |
+
</h1>
|
45 |
+
<p style="text-align: center; font-size: 20px; font-family: 'Times New Roman', Times, serif;">
|
46 |
+
<a href="https://sky24h.github.io/websites/icassp2023_free-view_video-editing/" target="_blank">
|
47 |
+
<b>Project Page</b>
|
48 |
+
<br>
|
49 |
+
</a>
|
50 |
+
<a style="text-align: center; display:inline-block" href="https://huggingface.co/spaces/sky24h/FETE?duplicate=true">
|
51 |
+
<img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg#center" alt="Duplicate Space">
|
52 |
+
</a>
|
53 |
+
</p>
|
54 |
+
""")
|
55 |
+
with gr.Column(elem_id="col-container"):
|
56 |
+
with gr.Row():
|
57 |
+
with gr.Column():
|
58 |
+
# select and preview video from a list of examples
|
59 |
+
video_preview = gr.Video(label="Video Preview", elem_id="video-preview", height=360, value="./assets/videos/sample1.mp4")
|
60 |
+
video_input = gr.Dropdown(available_videos, label="Input Video", value="sample1.mp4")
|
61 |
+
audio_preview = gr.Audio(label="Audio Preview", elem_id="audio-preview", height=360, value="./assets/audios/sample2.wav")
|
62 |
+
audio_input = gr.Dropdown(available_audios, label="Input Audio", value="sample2.wav")
|
63 |
+
pose_select = gr.Radio(["front", "left_right_shaking"], label="Pose", value="front")
|
64 |
+
emotion_select = gr.Radio(["neutral", "happy", "angry", "surprised"], label="Emotion", value="neutral")
|
65 |
+
blink_select = gr.Radio(["yes", "no"], label="Blink", value="yes")
|
66 |
+
# with gr.Row():
|
67 |
+
with gr.Column():
|
68 |
+
video_out = gr.Video(label="Video Output", elem_id="video-output", height=360)
|
69 |
+
# titile: Free-View Expressive Talking Head Video Editing
|
70 |
+
|
71 |
+
submit_btn = gr.Button("Generate video")
|
72 |
+
|
73 |
+
inputs = [video_input, audio_input, pose_select, emotion_select, blink_select]
|
74 |
+
outputs = [video_out]
|
75 |
+
|
76 |
+
video_preview_output = [video_preview]
|
77 |
+
audio_preview_output = [audio_preview]
|
78 |
+
|
79 |
+
video_input.select(lambda x: "./assets/videos/" + x, video_input, video_preview_output)
|
80 |
+
audio_input.select(lambda x: "./assets/audios/" + x, audio_input, audio_preview_output)
|
81 |
+
submit_btn.click(process, inputs, outputs)
|
82 |
+
|
83 |
+
demo.queue(max_size=12).launch()
|
assets/audios/audio files.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
assets/coords/coord files.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
assets/coords/sample1.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5323c1d6ed9dbda978859aed00bd6b4ec2ca122dbc556821edd4add44181c360
|
3 |
+
size 647
|
assets/coords/sample2.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aaf2424b58f2c32b3b55d172a4437bd73687ca2346e4928575c481fa024cf41d
|
3 |
+
size 1150
|
assets/coords/sample3.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fb9c284ec586ffe5cb9af9e715907810ce6f07694ef4245c29bb06646a04f3ed
|
3 |
+
size 839
|
assets/coords/sample4.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fa3cd8b488cbd368c22b8d10f072d8a55cc73c3c0e8cea6e11420f2743cc00c4
|
3 |
+
size 901
|
assets/coords/sample5.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1ee15a8dd3b47bc036a4502ccd60a0a5c29262a5581593e8f97c11e18b389e67
|
3 |
+
size 974
|
assets/videos/sample1.mp4
ADDED
Binary file (184 kB). View file
|
|
assets/videos/sample2.mp4
ADDED
Binary file (625 kB). View file
|
|
assets/videos/sample3.mp4
ADDED
Binary file (374 kB). View file
|
|
assets/videos/sample4.mp4
ADDED
Binary file (562 kB). View file
|
|
assets/videos/sample5.mp4
ADDED
Binary file (698 kB). View file
|
|
attributtes_utils.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
|
5 |
+
def input_pose(pose_select="front"):
|
6 |
+
step = 1
|
7 |
+
if pose_select == "front":
|
8 |
+
pose = [[0.0, 0.0, 0.0] for i in range(0, 10, step)]#-20 to 20
|
9 |
+
elif pose_select == "left_right_shaking":
|
10 |
+
pose = [[-i, 0.0, 0.0] for i in range(0, 20, step)]#0 to -20
|
11 |
+
pose += [[i - 20.0, 0.0, 0.0] for i in range(0, 40, step)] # -20 to 20
|
12 |
+
pose += [[20.0 - i, 0.0, 0.0] for i in range(0, 20, step)] # 20 to 0
|
13 |
+
pose = pose + pose
|
14 |
+
pose = pose + pose
|
15 |
+
pose = pose + pose
|
16 |
+
# pose = pose + pose[::-1]
|
17 |
+
else:
|
18 |
+
raise ValueError("pose_select Error")
|
19 |
+
|
20 |
+
return pose
|
21 |
+
|
22 |
+
|
23 |
+
EMOTIONS = ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']
|
24 |
+
def input_emotion(emotion_select="neutral"):
|
25 |
+
sacle_factor = 2
|
26 |
+
if emotion_select == "neutral":
|
27 |
+
emotion = [[0.0,0.0,0.0,0.0,0.0,0.0,1.0] for _ in range(2)]#((i%50))*0.04
|
28 |
+
elif emotion_select == "happy":
|
29 |
+
emotion = [[0.0,0.0,0.0,1.0,0.0,0.0,0.0] for _ in range(2)]#((i%50))*0.04
|
30 |
+
elif emotion_select == "angry":
|
31 |
+
emotion = [[1.0,0.0,0.0,0.0,0.0,0.0,0.0] for _ in range(2)]
|
32 |
+
elif emotion_select == "surprised":
|
33 |
+
emotion = [[0.0,0.0,0.0,0.0,0.0,1.0,0.0] for _ in range(2)]
|
34 |
+
else:
|
35 |
+
raise ValueError("emotion_select Error")
|
36 |
+
|
37 |
+
return emotion * sacle_factor
|
38 |
+
|
39 |
+
|
40 |
+
def input_blink(blink_select="yes"):
|
41 |
+
if blink_select == "yes":
|
42 |
+
blink = [[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [0.8], [0.6], [0.0], [0.0], [1.0]]
|
43 |
+
blink = blink + blink
|
44 |
+
else:
|
45 |
+
blink = [[1.0] for _ in range(2)]
|
46 |
+
return blink
|
47 |
+
|
audio.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import librosa
|
2 |
+
import librosa.filters
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
# import tensorflow as tf
|
6 |
+
from scipy import signal
|
7 |
+
from scipy.io import wavfile
|
8 |
+
|
9 |
+
hp_num_mels = 80
|
10 |
+
hp_rescale = True
|
11 |
+
hp_rescaling_max = 0.9
|
12 |
+
hp_use_lws = False
|
13 |
+
hp_n_fft = 800
|
14 |
+
hp_hop_size = 200
|
15 |
+
hp_win_size = 800
|
16 |
+
hp_sample_rate = 16000
|
17 |
+
hp_frame_shift_ms = None
|
18 |
+
hp_signal_normalization = True
|
19 |
+
hp_allow_clipping_in_normalization = True
|
20 |
+
hp_symmetric_mels = True
|
21 |
+
hp_max_abs_value = 4.0
|
22 |
+
hp_preemphasize = True
|
23 |
+
hp_preemphasis = 0.97
|
24 |
+
hp_min_level_db = -100
|
25 |
+
hp_ref_level_db = 20
|
26 |
+
hp_fmin = 55
|
27 |
+
hp_fmax = 7600
|
28 |
+
|
29 |
+
|
30 |
+
def load_wav(path, sr):
|
31 |
+
return librosa.core.load(path, sr=sr)[0]
|
32 |
+
|
33 |
+
|
34 |
+
def save_wav(wav, path, sr):
|
35 |
+
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
36 |
+
# proposed by @dsmiller
|
37 |
+
wavfile.write(path, sr, wav.astype(np.int16))
|
38 |
+
|
39 |
+
|
40 |
+
def save_wavenet_wav(wav, path, sr):
|
41 |
+
librosa.output.write_wav(path, wav, sr=sr)
|
42 |
+
|
43 |
+
|
44 |
+
def preemphasis(wav, k, preemphasize=True):
|
45 |
+
if preemphasize:
|
46 |
+
return signal.lfilter([1, -k], [1], wav)
|
47 |
+
return wav
|
48 |
+
|
49 |
+
|
50 |
+
def inv_preemphasis(wav, k, inv_preemphasize=True):
|
51 |
+
if inv_preemphasize:
|
52 |
+
return signal.lfilter([1], [1, -k], wav)
|
53 |
+
return wav
|
54 |
+
|
55 |
+
|
56 |
+
def get_hop_size():
|
57 |
+
hop_size = hp_hop_size
|
58 |
+
if hop_size is None:
|
59 |
+
assert hp_frame_shift_ms is not None
|
60 |
+
hop_size = int(hp_frame_shift_ms / 1000 * hp_sample_rate)
|
61 |
+
return hop_size
|
62 |
+
|
63 |
+
|
64 |
+
def linearspectrogram(wav):
|
65 |
+
D = _stft(preemphasis(wav, hp_preemphasis, hp_preemphasize))
|
66 |
+
S = _amp_to_db(np.abs(D)) - hp_ref_level_db
|
67 |
+
if hp_signal_normalization:
|
68 |
+
return _normalize(S)
|
69 |
+
return S
|
70 |
+
|
71 |
+
|
72 |
+
def melspectrogram(wav):
|
73 |
+
D = _stft(preemphasis(wav, hp_preemphasis, hp_preemphasize))
|
74 |
+
S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp_ref_level_db
|
75 |
+
if hp_signal_normalization:
|
76 |
+
return _normalize(S)
|
77 |
+
return S
|
78 |
+
|
79 |
+
|
80 |
+
def _lws_processor():
|
81 |
+
import lws
|
82 |
+
|
83 |
+
return lws.lws(hp_n_fft, get_hop_size(), fftsize=hp_win_size, mode="speech")
|
84 |
+
|
85 |
+
|
86 |
+
def _stft(y):
|
87 |
+
if hp_use_lws:
|
88 |
+
return _lws_processor(hp).stft(y).T
|
89 |
+
else:
|
90 |
+
return librosa.stft(y=y, n_fft=hp_n_fft, hop_length=get_hop_size(), win_length=hp_win_size)
|
91 |
+
|
92 |
+
|
93 |
+
##########################################################
|
94 |
+
# Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
|
95 |
+
def num_frames(length, fsize, fshift):
|
96 |
+
"""Compute number of time frames of spectrogram"""
|
97 |
+
pad = fsize - fshift
|
98 |
+
if length % fshift == 0:
|
99 |
+
M = (length + pad * 2 - fsize) // fshift + 1
|
100 |
+
else:
|
101 |
+
M = (length + pad * 2 - fsize) // fshift + 2
|
102 |
+
return M
|
103 |
+
|
104 |
+
|
105 |
+
def pad_lr(x, fsize, fshift):
|
106 |
+
"""Compute left and right padding"""
|
107 |
+
M = num_frames(len(x), fsize, fshift)
|
108 |
+
pad = fsize - fshift
|
109 |
+
T = len(x) + 2 * pad
|
110 |
+
r = (M - 1) * fshift + fsize - T
|
111 |
+
return pad, pad + r
|
112 |
+
|
113 |
+
|
114 |
+
##########################################################
|
115 |
+
# Librosa correct padding
|
116 |
+
def librosa_pad_lr(x, fsize, fshift):
|
117 |
+
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
|
118 |
+
|
119 |
+
|
120 |
+
# Conversions
|
121 |
+
_mel_basis = None
|
122 |
+
|
123 |
+
|
124 |
+
def _linear_to_mel(spectogram):
|
125 |
+
global _mel_basis
|
126 |
+
if _mel_basis is None:
|
127 |
+
_mel_basis = _build_mel_basis()
|
128 |
+
return np.dot(_mel_basis, spectogram)
|
129 |
+
|
130 |
+
|
131 |
+
def _build_mel_basis():
|
132 |
+
assert hp_fmax <= hp_sample_rate // 2
|
133 |
+
return librosa.filters.mel(hp_sample_rate, hp_n_fft, n_mels=hp_num_mels, fmin=hp_fmin, fmax=hp_fmax)
|
134 |
+
|
135 |
+
|
136 |
+
def _amp_to_db(x):
|
137 |
+
min_level = np.exp(hp_min_level_db / 20 * np.log(10))
|
138 |
+
return 20 * np.log10(np.maximum(min_level, x))
|
139 |
+
|
140 |
+
|
141 |
+
def _normalize(S):
|
142 |
+
if hp_allow_clipping_in_normalization:
|
143 |
+
if hp_symmetric_mels:
|
144 |
+
return np.clip(
|
145 |
+
(2 * hp_max_abs_value) * ((S - hp_min_level_db) / (-hp_min_level_db)) - hp_max_abs_value,
|
146 |
+
-hp_max_abs_value,
|
147 |
+
hp_max_abs_value,
|
148 |
+
)
|
149 |
+
else:
|
150 |
+
return np.clip(
|
151 |
+
hp_max_abs_value * ((S - hp_min_level_db) / (-hp_min_level_db)),
|
152 |
+
0,
|
153 |
+
hp_max_abs_value,
|
154 |
+
)
|
155 |
+
|
156 |
+
assert S.max() <= 0 and S.min() - hp_min_level_db >= 0
|
157 |
+
if hp_symmetric_mels:
|
158 |
+
return (2 * hp_max_abs_value) * ((S - hp_min_level_db) / (-hp_min_level_db)) - hp_max_abs_value
|
159 |
+
else:
|
160 |
+
return hp_max_abs_value * ((S - hp_min_level_db) / (-hp_min_level_db))
|
checkpoints/obama-fp16.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a11a77110f573c5d379f7ce0a14b2c85be12c41723bef601c612f06d0d1f7d55
|
3 |
+
size 96452948
|
face_detection/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
|
face_detection/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
__author__ = """Adrian Bulat"""
|
4 |
+
__email__ = '[email protected]'
|
5 |
+
__version__ = '1.0.1'
|
6 |
+
|
7 |
+
from .api import FaceAlignment, LandmarksType, NetworkSize
|
face_detection/api.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from torch.utils.model_zoo import load_url
|
5 |
+
from enum import Enum
|
6 |
+
import numpy as np
|
7 |
+
import cv2
|
8 |
+
try:
|
9 |
+
import urllib.request as request_file
|
10 |
+
except BaseException:
|
11 |
+
import urllib as request_file
|
12 |
+
|
13 |
+
from .models import FAN, ResNetDepth
|
14 |
+
from .utils import *
|
15 |
+
|
16 |
+
|
17 |
+
class LandmarksType(Enum):
|
18 |
+
"""Enum class defining the type of landmarks to detect.
|
19 |
+
|
20 |
+
``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
|
21 |
+
``_2halfD`` - this points represent the projection of the 3D points into 3D
|
22 |
+
``_3D`` - detect the points ``(x,y,z)``` in a 3D space
|
23 |
+
|
24 |
+
"""
|
25 |
+
_2D = 1
|
26 |
+
_2halfD = 2
|
27 |
+
_3D = 3
|
28 |
+
|
29 |
+
|
30 |
+
class NetworkSize(Enum):
|
31 |
+
# TINY = 1
|
32 |
+
# SMALL = 2
|
33 |
+
# MEDIUM = 3
|
34 |
+
LARGE = 4
|
35 |
+
|
36 |
+
def __new__(cls, value):
|
37 |
+
member = object.__new__(cls)
|
38 |
+
member._value_ = value
|
39 |
+
return member
|
40 |
+
|
41 |
+
def __int__(self):
|
42 |
+
return self.value
|
43 |
+
|
44 |
+
ROOT = os.path.dirname(os.path.abspath(__file__))
|
45 |
+
|
46 |
+
class FaceAlignment:
|
47 |
+
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
|
48 |
+
device='cuda', flip_input=False, face_detector='sfd', verbose=False):
|
49 |
+
self.device = device
|
50 |
+
self.flip_input = flip_input
|
51 |
+
self.landmarks_type = landmarks_type
|
52 |
+
self.verbose = verbose
|
53 |
+
|
54 |
+
network_size = int(network_size)
|
55 |
+
|
56 |
+
if 'cuda' in device:
|
57 |
+
torch.backends.cudnn.benchmark = True
|
58 |
+
|
59 |
+
# Get the face detector
|
60 |
+
face_detector_module = __import__('face_detection.detection.' + face_detector,
|
61 |
+
globals(), locals(), [face_detector], 0)
|
62 |
+
self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
|
63 |
+
|
64 |
+
def get_detections_for_batch(self, images):
|
65 |
+
images = images[..., ::-1]
|
66 |
+
detected_faces = self.face_detector.detect_from_batch(images.copy())
|
67 |
+
results = []
|
68 |
+
|
69 |
+
for i, d in enumerate(detected_faces):
|
70 |
+
if len(d) == 0:
|
71 |
+
results.append(None)
|
72 |
+
continue
|
73 |
+
d = d[0]
|
74 |
+
d = np.clip(d, 0, None)
|
75 |
+
|
76 |
+
x1, y1, x2, y2 = map(int, d[:-1])
|
77 |
+
results.append((x1, y1, x2, y2))
|
78 |
+
|
79 |
+
return results
|
face_detection/detection/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .core import FaceDetector
|
face_detection/detection/core.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import glob
|
3 |
+
from tqdm import tqdm
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import cv2
|
7 |
+
|
8 |
+
|
9 |
+
class FaceDetector(object):
|
10 |
+
"""An abstract class representing a face detector.
|
11 |
+
|
12 |
+
Any other face detection implementation must subclass it. All subclasses
|
13 |
+
must implement ``detect_from_image``, that return a list of detected
|
14 |
+
bounding boxes. Optionally, for speed considerations detect from path is
|
15 |
+
recommended.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, device, verbose):
|
19 |
+
self.device = device
|
20 |
+
self.verbose = verbose
|
21 |
+
|
22 |
+
if verbose:
|
23 |
+
if 'cpu' in device:
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
logger.warning("Detection running on CPU, this may be potentially slow.")
|
26 |
+
|
27 |
+
if 'cpu' not in device and 'cuda' not in device:
|
28 |
+
if verbose:
|
29 |
+
logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
|
30 |
+
raise ValueError
|
31 |
+
|
32 |
+
def detect_from_image(self, tensor_or_path):
|
33 |
+
"""Detects faces in a given image.
|
34 |
+
|
35 |
+
This function detects the faces present in a provided BGR(usually)
|
36 |
+
image. The input can be either the image itself or the path to it.
|
37 |
+
|
38 |
+
Arguments:
|
39 |
+
tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
|
40 |
+
to an image or the image itself.
|
41 |
+
|
42 |
+
Example::
|
43 |
+
|
44 |
+
>>> path_to_image = 'data/image_01.jpg'
|
45 |
+
... detected_faces = detect_from_image(path_to_image)
|
46 |
+
[A list of bounding boxes (x1, y1, x2, y2)]
|
47 |
+
>>> image = cv2.imread(path_to_image)
|
48 |
+
... detected_faces = detect_from_image(image)
|
49 |
+
[A list of bounding boxes (x1, y1, x2, y2)]
|
50 |
+
|
51 |
+
"""
|
52 |
+
raise NotImplementedError
|
53 |
+
|
54 |
+
def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
|
55 |
+
"""Detects faces from all the images present in a given directory.
|
56 |
+
|
57 |
+
Arguments:
|
58 |
+
path {string} -- a string containing a path that points to the folder containing the images
|
59 |
+
|
60 |
+
Keyword Arguments:
|
61 |
+
extensions {list} -- list of string containing the extensions to be
|
62 |
+
consider in the following format: ``.extension_name`` (default:
|
63 |
+
{['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
|
64 |
+
folder recursively (default: {False}) show_progress_bar {bool} --
|
65 |
+
display a progressbar (default: {True})
|
66 |
+
|
67 |
+
Example:
|
68 |
+
>>> directory = 'data'
|
69 |
+
... detected_faces = detect_from_directory(directory)
|
70 |
+
{A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
|
71 |
+
|
72 |
+
"""
|
73 |
+
if self.verbose:
|
74 |
+
logger = logging.getLogger(__name__)
|
75 |
+
|
76 |
+
if len(extensions) == 0:
|
77 |
+
if self.verbose:
|
78 |
+
logger.error("Expected at list one extension, but none was received.")
|
79 |
+
raise ValueError
|
80 |
+
|
81 |
+
if self.verbose:
|
82 |
+
logger.info("Constructing the list of images.")
|
83 |
+
additional_pattern = '/**/*' if recursive else '/*'
|
84 |
+
files = []
|
85 |
+
for extension in extensions:
|
86 |
+
files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
|
87 |
+
|
88 |
+
if self.verbose:
|
89 |
+
logger.info("Finished searching for images. %s images found", len(files))
|
90 |
+
logger.info("Preparing to run the detection.")
|
91 |
+
|
92 |
+
predictions = {}
|
93 |
+
for image_path in tqdm(files, disable=not show_progress_bar):
|
94 |
+
if self.verbose:
|
95 |
+
logger.info("Running the face detector on image: %s", image_path)
|
96 |
+
predictions[image_path] = self.detect_from_image(image_path)
|
97 |
+
|
98 |
+
if self.verbose:
|
99 |
+
logger.info("The detector was successfully run on all %s images", len(files))
|
100 |
+
|
101 |
+
return predictions
|
102 |
+
|
103 |
+
@property
|
104 |
+
def reference_scale(self):
|
105 |
+
raise NotImplementedError
|
106 |
+
|
107 |
+
@property
|
108 |
+
def reference_x_shift(self):
|
109 |
+
raise NotImplementedError
|
110 |
+
|
111 |
+
@property
|
112 |
+
def reference_y_shift(self):
|
113 |
+
raise NotImplementedError
|
114 |
+
|
115 |
+
@staticmethod
|
116 |
+
def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
|
117 |
+
"""Convert path (represented as a string) or torch.tensor to a numpy.ndarray
|
118 |
+
|
119 |
+
Arguments:
|
120 |
+
tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
|
121 |
+
"""
|
122 |
+
if isinstance(tensor_or_path, str):
|
123 |
+
return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
|
124 |
+
elif torch.is_tensor(tensor_or_path):
|
125 |
+
# Call cpu in case its coming from cuda
|
126 |
+
return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
|
127 |
+
elif isinstance(tensor_or_path, np.ndarray):
|
128 |
+
return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
|
129 |
+
else:
|
130 |
+
raise TypeError
|
face_detection/detection/sfd/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .sfd_detector import SFDDetector as FaceDetector
|
face_detection/detection/sfd/bbox.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import cv2
|
5 |
+
import random
|
6 |
+
import datetime
|
7 |
+
import time
|
8 |
+
import math
|
9 |
+
import argparse
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
|
13 |
+
try:
|
14 |
+
from iou import IOU
|
15 |
+
except BaseException:
|
16 |
+
# IOU cython speedup 10x
|
17 |
+
def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
|
18 |
+
sa = abs((ax2 - ax1) * (ay2 - ay1))
|
19 |
+
sb = abs((bx2 - bx1) * (by2 - by1))
|
20 |
+
x1, y1 = max(ax1, bx1), max(ay1, by1)
|
21 |
+
x2, y2 = min(ax2, bx2), min(ay2, by2)
|
22 |
+
w = x2 - x1
|
23 |
+
h = y2 - y1
|
24 |
+
if w < 0 or h < 0:
|
25 |
+
return 0.0
|
26 |
+
else:
|
27 |
+
return 1.0 * w * h / (sa + sb - w * h)
|
28 |
+
|
29 |
+
|
30 |
+
def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
|
31 |
+
xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
|
32 |
+
dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
|
33 |
+
dw, dh = math.log(ww / aww), math.log(hh / ahh)
|
34 |
+
return dx, dy, dw, dh
|
35 |
+
|
36 |
+
|
37 |
+
def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
|
38 |
+
xc, yc = dx * aww + axc, dy * ahh + ayc
|
39 |
+
ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
|
40 |
+
x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
|
41 |
+
return x1, y1, x2, y2
|
42 |
+
|
43 |
+
|
44 |
+
def nms(dets, thresh):
|
45 |
+
if 0 == len(dets):
|
46 |
+
return []
|
47 |
+
x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
|
48 |
+
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
49 |
+
order = scores.argsort()[::-1]
|
50 |
+
|
51 |
+
keep = []
|
52 |
+
while order.size > 0:
|
53 |
+
i = order[0]
|
54 |
+
keep.append(i)
|
55 |
+
xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
|
56 |
+
xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
|
57 |
+
|
58 |
+
w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
|
59 |
+
ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
|
60 |
+
|
61 |
+
inds = np.where(ovr <= thresh)[0]
|
62 |
+
order = order[inds + 1]
|
63 |
+
|
64 |
+
return keep
|
65 |
+
|
66 |
+
|
67 |
+
def encode(matched, priors, variances):
|
68 |
+
"""Encode the variances from the priorbox layers into the ground truth boxes
|
69 |
+
we have matched (based on jaccard overlap) with the prior boxes.
|
70 |
+
Args:
|
71 |
+
matched: (tensor) Coords of ground truth for each prior in point-form
|
72 |
+
Shape: [num_priors, 4].
|
73 |
+
priors: (tensor) Prior boxes in center-offset form
|
74 |
+
Shape: [num_priors,4].
|
75 |
+
variances: (list[float]) Variances of priorboxes
|
76 |
+
Return:
|
77 |
+
encoded boxes (tensor), Shape: [num_priors, 4]
|
78 |
+
"""
|
79 |
+
|
80 |
+
# dist b/t match center and prior's center
|
81 |
+
g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
|
82 |
+
# encode variance
|
83 |
+
g_cxcy /= (variances[0] * priors[:, 2:])
|
84 |
+
# match wh / prior wh
|
85 |
+
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
|
86 |
+
g_wh = torch.log(g_wh) / variances[1]
|
87 |
+
# return target for smooth_l1_loss
|
88 |
+
return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
|
89 |
+
|
90 |
+
|
91 |
+
def decode(loc, priors, variances):
|
92 |
+
"""Decode locations from predictions using priors to undo
|
93 |
+
the encoding we did for offset regression at train time.
|
94 |
+
Args:
|
95 |
+
loc (tensor): location predictions for loc layers,
|
96 |
+
Shape: [num_priors,4]
|
97 |
+
priors (tensor): Prior boxes in center-offset form.
|
98 |
+
Shape: [num_priors,4].
|
99 |
+
variances: (list[float]) Variances of priorboxes
|
100 |
+
Return:
|
101 |
+
decoded bounding box predictions
|
102 |
+
"""
|
103 |
+
|
104 |
+
boxes = torch.cat((
|
105 |
+
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
|
106 |
+
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
|
107 |
+
boxes[:, :2] -= boxes[:, 2:] / 2
|
108 |
+
boxes[:, 2:] += boxes[:, :2]
|
109 |
+
return boxes
|
110 |
+
|
111 |
+
def batch_decode(loc, priors, variances):
|
112 |
+
"""Decode locations from predictions using priors to undo
|
113 |
+
the encoding we did for offset regression at train time.
|
114 |
+
Args:
|
115 |
+
loc (tensor): location predictions for loc layers,
|
116 |
+
Shape: [num_priors,4]
|
117 |
+
priors (tensor): Prior boxes in center-offset form.
|
118 |
+
Shape: [num_priors,4].
|
119 |
+
variances: (list[float]) Variances of priorboxes
|
120 |
+
Return:
|
121 |
+
decoded bounding box predictions
|
122 |
+
"""
|
123 |
+
|
124 |
+
boxes = torch.cat((
|
125 |
+
priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
|
126 |
+
priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
|
127 |
+
boxes[:, :, :2] -= boxes[:, :, 2:] / 2
|
128 |
+
boxes[:, :, 2:] += boxes[:, :, :2]
|
129 |
+
return boxes
|
face_detection/detection/sfd/detect.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import cv2
|
7 |
+
import random
|
8 |
+
import datetime
|
9 |
+
import math
|
10 |
+
import argparse
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
import scipy.io as sio
|
14 |
+
import zipfile
|
15 |
+
from .net_s3fd import s3fd
|
16 |
+
from .bbox import *
|
17 |
+
|
18 |
+
|
19 |
+
def detect(net, img, device):
|
20 |
+
img = img - np.array([104, 117, 123])
|
21 |
+
img = img.transpose(2, 0, 1)
|
22 |
+
img = img.reshape((1,) + img.shape)
|
23 |
+
|
24 |
+
if 'cuda' in device:
|
25 |
+
torch.backends.cudnn.benchmark = True
|
26 |
+
|
27 |
+
img = torch.from_numpy(img).float().to(device)
|
28 |
+
BB, CC, HH, WW = img.size()
|
29 |
+
with torch.no_grad():
|
30 |
+
olist = net(img)
|
31 |
+
|
32 |
+
bboxlist = []
|
33 |
+
for i in range(len(olist) // 2):
|
34 |
+
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
|
35 |
+
olist = [oelem.data.cpu() for oelem in olist]
|
36 |
+
for i in range(len(olist) // 2):
|
37 |
+
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
|
38 |
+
FB, FC, FH, FW = ocls.size() # feature map size
|
39 |
+
stride = 2**(i + 2) # 4,8,16,32,64,128
|
40 |
+
anchor = stride * 4
|
41 |
+
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
|
42 |
+
for Iindex, hindex, windex in poss:
|
43 |
+
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
|
44 |
+
score = ocls[0, 1, hindex, windex]
|
45 |
+
loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
|
46 |
+
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
|
47 |
+
variances = [0.1, 0.2]
|
48 |
+
box = decode(loc, priors, variances)
|
49 |
+
x1, y1, x2, y2 = box[0] * 1.0
|
50 |
+
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
|
51 |
+
bboxlist.append([x1, y1, x2, y2, score])
|
52 |
+
bboxlist = np.array(bboxlist)
|
53 |
+
if 0 == len(bboxlist):
|
54 |
+
bboxlist = np.zeros((1, 5))
|
55 |
+
|
56 |
+
return bboxlist
|
57 |
+
|
58 |
+
def batch_detect(net, imgs, device):
|
59 |
+
imgs = imgs - np.array([104, 117, 123])
|
60 |
+
imgs = imgs.transpose(0, 3, 1, 2)
|
61 |
+
|
62 |
+
if 'cuda' in device:
|
63 |
+
torch.backends.cudnn.benchmark = True
|
64 |
+
|
65 |
+
imgs = torch.from_numpy(imgs).float().to(device)
|
66 |
+
BB, CC, HH, WW = imgs.size()
|
67 |
+
with torch.no_grad():
|
68 |
+
olist = net(imgs)
|
69 |
+
|
70 |
+
bboxlist = []
|
71 |
+
for i in range(len(olist) // 2):
|
72 |
+
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
|
73 |
+
olist = [oelem.data.cpu() for oelem in olist]
|
74 |
+
for i in range(len(olist) // 2):
|
75 |
+
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
|
76 |
+
FB, FC, FH, FW = ocls.size() # feature map size
|
77 |
+
stride = 2**(i + 2) # 4,8,16,32,64,128
|
78 |
+
anchor = stride * 4
|
79 |
+
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
|
80 |
+
for Iindex, hindex, windex in poss:
|
81 |
+
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
|
82 |
+
score = ocls[:, 1, hindex, windex]
|
83 |
+
loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
|
84 |
+
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
|
85 |
+
variances = [0.1, 0.2]
|
86 |
+
box = batch_decode(loc, priors, variances)
|
87 |
+
box = box[:, 0] * 1.0
|
88 |
+
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
|
89 |
+
bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
|
90 |
+
bboxlist = np.array(bboxlist)
|
91 |
+
if 0 == len(bboxlist):
|
92 |
+
bboxlist = np.zeros((1, BB, 5))
|
93 |
+
|
94 |
+
return bboxlist
|
95 |
+
|
96 |
+
def flip_detect(net, img, device):
|
97 |
+
img = cv2.flip(img, 1)
|
98 |
+
b = detect(net, img, device)
|
99 |
+
|
100 |
+
bboxlist = np.zeros(b.shape)
|
101 |
+
bboxlist[:, 0] = img.shape[1] - b[:, 2]
|
102 |
+
bboxlist[:, 1] = b[:, 1]
|
103 |
+
bboxlist[:, 2] = img.shape[1] - b[:, 0]
|
104 |
+
bboxlist[:, 3] = b[:, 3]
|
105 |
+
bboxlist[:, 4] = b[:, 4]
|
106 |
+
return bboxlist
|
107 |
+
|
108 |
+
|
109 |
+
def pts_to_bb(pts):
|
110 |
+
min_x, min_y = np.min(pts, axis=0)
|
111 |
+
max_x, max_y = np.max(pts, axis=0)
|
112 |
+
return np.array([min_x, min_y, max_x, max_y])
|
face_detection/detection/sfd/net_s3fd.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class L2Norm(nn.Module):
|
7 |
+
def __init__(self, n_channels, scale=1.0):
|
8 |
+
super(L2Norm, self).__init__()
|
9 |
+
self.n_channels = n_channels
|
10 |
+
self.scale = scale
|
11 |
+
self.eps = 1e-10
|
12 |
+
self.weight = nn.Parameter(torch.Tensor(self.n_channels))
|
13 |
+
self.weight.data *= 0.0
|
14 |
+
self.weight.data += self.scale
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
|
18 |
+
x = x / norm * self.weight.view(1, -1, 1, 1)
|
19 |
+
return x
|
20 |
+
|
21 |
+
|
22 |
+
class s3fd(nn.Module):
|
23 |
+
def __init__(self):
|
24 |
+
super(s3fd, self).__init__()
|
25 |
+
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
|
26 |
+
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
|
27 |
+
|
28 |
+
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
|
29 |
+
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
|
30 |
+
|
31 |
+
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
|
32 |
+
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
33 |
+
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
34 |
+
|
35 |
+
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
|
36 |
+
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
37 |
+
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
38 |
+
|
39 |
+
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
40 |
+
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
41 |
+
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
42 |
+
|
43 |
+
self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
|
44 |
+
self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
|
45 |
+
|
46 |
+
self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
|
47 |
+
self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
|
48 |
+
|
49 |
+
self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
|
50 |
+
self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
|
51 |
+
|
52 |
+
self.conv3_3_norm = L2Norm(256, scale=10)
|
53 |
+
self.conv4_3_norm = L2Norm(512, scale=8)
|
54 |
+
self.conv5_3_norm = L2Norm(512, scale=5)
|
55 |
+
|
56 |
+
self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
57 |
+
self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
58 |
+
self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
59 |
+
self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
60 |
+
self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
61 |
+
self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
62 |
+
|
63 |
+
self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
|
64 |
+
self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
|
65 |
+
self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
66 |
+
self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
67 |
+
self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
|
68 |
+
self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
h = F.relu(self.conv1_1(x))
|
72 |
+
h = F.relu(self.conv1_2(h))
|
73 |
+
h = F.max_pool2d(h, 2, 2)
|
74 |
+
|
75 |
+
h = F.relu(self.conv2_1(h))
|
76 |
+
h = F.relu(self.conv2_2(h))
|
77 |
+
h = F.max_pool2d(h, 2, 2)
|
78 |
+
|
79 |
+
h = F.relu(self.conv3_1(h))
|
80 |
+
h = F.relu(self.conv3_2(h))
|
81 |
+
h = F.relu(self.conv3_3(h))
|
82 |
+
f3_3 = h
|
83 |
+
h = F.max_pool2d(h, 2, 2)
|
84 |
+
|
85 |
+
h = F.relu(self.conv4_1(h))
|
86 |
+
h = F.relu(self.conv4_2(h))
|
87 |
+
h = F.relu(self.conv4_3(h))
|
88 |
+
f4_3 = h
|
89 |
+
h = F.max_pool2d(h, 2, 2)
|
90 |
+
|
91 |
+
h = F.relu(self.conv5_1(h))
|
92 |
+
h = F.relu(self.conv5_2(h))
|
93 |
+
h = F.relu(self.conv5_3(h))
|
94 |
+
f5_3 = h
|
95 |
+
h = F.max_pool2d(h, 2, 2)
|
96 |
+
|
97 |
+
h = F.relu(self.fc6(h))
|
98 |
+
h = F.relu(self.fc7(h))
|
99 |
+
ffc7 = h
|
100 |
+
h = F.relu(self.conv6_1(h))
|
101 |
+
h = F.relu(self.conv6_2(h))
|
102 |
+
f6_2 = h
|
103 |
+
h = F.relu(self.conv7_1(h))
|
104 |
+
h = F.relu(self.conv7_2(h))
|
105 |
+
f7_2 = h
|
106 |
+
|
107 |
+
f3_3 = self.conv3_3_norm(f3_3)
|
108 |
+
f4_3 = self.conv4_3_norm(f4_3)
|
109 |
+
f5_3 = self.conv5_3_norm(f5_3)
|
110 |
+
|
111 |
+
cls1 = self.conv3_3_norm_mbox_conf(f3_3)
|
112 |
+
reg1 = self.conv3_3_norm_mbox_loc(f3_3)
|
113 |
+
cls2 = self.conv4_3_norm_mbox_conf(f4_3)
|
114 |
+
reg2 = self.conv4_3_norm_mbox_loc(f4_3)
|
115 |
+
cls3 = self.conv5_3_norm_mbox_conf(f5_3)
|
116 |
+
reg3 = self.conv5_3_norm_mbox_loc(f5_3)
|
117 |
+
cls4 = self.fc7_mbox_conf(ffc7)
|
118 |
+
reg4 = self.fc7_mbox_loc(ffc7)
|
119 |
+
cls5 = self.conv6_2_mbox_conf(f6_2)
|
120 |
+
reg5 = self.conv6_2_mbox_loc(f6_2)
|
121 |
+
cls6 = self.conv7_2_mbox_conf(f7_2)
|
122 |
+
reg6 = self.conv7_2_mbox_loc(f7_2)
|
123 |
+
|
124 |
+
# max-out background label
|
125 |
+
chunk = torch.chunk(cls1, 4, 1)
|
126 |
+
bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
|
127 |
+
cls1 = torch.cat([bmax, chunk[3]], dim=1)
|
128 |
+
|
129 |
+
return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
|
face_detection/detection/sfd/s3fd.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:619a31681264d3f7f7fc7a16a42cbbe8b23f31a256f75a366e5a1bcd59b33543
|
3 |
+
size 89843225
|
face_detection/detection/sfd/sfd_detector.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
from torch.utils.model_zoo import load_url
|
4 |
+
|
5 |
+
from ..core import FaceDetector
|
6 |
+
|
7 |
+
from .net_s3fd import s3fd
|
8 |
+
from .bbox import *
|
9 |
+
from .detect import *
|
10 |
+
|
11 |
+
models_urls = {
|
12 |
+
's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
|
13 |
+
}
|
14 |
+
|
15 |
+
|
16 |
+
class SFDDetector(FaceDetector):
|
17 |
+
def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
|
18 |
+
super(SFDDetector, self).__init__(device, verbose)
|
19 |
+
|
20 |
+
# Initialise the face detector
|
21 |
+
if not os.path.isfile(path_to_detector):
|
22 |
+
model_weights = load_url(models_urls['s3fd'])
|
23 |
+
else:
|
24 |
+
model_weights = torch.load(path_to_detector)
|
25 |
+
|
26 |
+
self.face_detector = s3fd()
|
27 |
+
self.face_detector.load_state_dict(model_weights)
|
28 |
+
self.face_detector.to(device)
|
29 |
+
self.face_detector.eval()
|
30 |
+
|
31 |
+
def detect_from_image(self, tensor_or_path):
|
32 |
+
image = self.tensor_or_path_to_ndarray(tensor_or_path)
|
33 |
+
|
34 |
+
bboxlist = detect(self.face_detector, image, device=self.device)
|
35 |
+
keep = nms(bboxlist, 0.3)
|
36 |
+
bboxlist = bboxlist[keep, :]
|
37 |
+
bboxlist = [x for x in bboxlist if x[-1] > 0.5]
|
38 |
+
|
39 |
+
return bboxlist
|
40 |
+
|
41 |
+
def detect_from_batch(self, images):
|
42 |
+
bboxlists = batch_detect(self.face_detector, images, device=self.device)
|
43 |
+
keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
|
44 |
+
bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
|
45 |
+
bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
|
46 |
+
|
47 |
+
return bboxlists
|
48 |
+
|
49 |
+
@property
|
50 |
+
def reference_scale(self):
|
51 |
+
return 195
|
52 |
+
|
53 |
+
@property
|
54 |
+
def reference_x_shift(self):
|
55 |
+
return 0
|
56 |
+
|
57 |
+
@property
|
58 |
+
def reference_y_shift(self):
|
59 |
+
return 0
|
face_detection/models.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import math
|
5 |
+
|
6 |
+
|
7 |
+
def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
|
8 |
+
"3x3 convolution with padding"
|
9 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
|
10 |
+
stride=strd, padding=padding, bias=bias)
|
11 |
+
|
12 |
+
|
13 |
+
class ConvBlock(nn.Module):
|
14 |
+
def __init__(self, in_planes, out_planes):
|
15 |
+
super(ConvBlock, self).__init__()
|
16 |
+
self.bn1 = nn.BatchNorm2d(in_planes)
|
17 |
+
self.conv1 = conv3x3(in_planes, int(out_planes / 2))
|
18 |
+
self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
|
19 |
+
self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
|
20 |
+
self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
|
21 |
+
self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
|
22 |
+
|
23 |
+
if in_planes != out_planes:
|
24 |
+
self.downsample = nn.Sequential(
|
25 |
+
nn.BatchNorm2d(in_planes),
|
26 |
+
nn.ReLU(True),
|
27 |
+
nn.Conv2d(in_planes, out_planes,
|
28 |
+
kernel_size=1, stride=1, bias=False),
|
29 |
+
)
|
30 |
+
else:
|
31 |
+
self.downsample = None
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
residual = x
|
35 |
+
|
36 |
+
out1 = self.bn1(x)
|
37 |
+
out1 = F.relu(out1, True)
|
38 |
+
out1 = self.conv1(out1)
|
39 |
+
|
40 |
+
out2 = self.bn2(out1)
|
41 |
+
out2 = F.relu(out2, True)
|
42 |
+
out2 = self.conv2(out2)
|
43 |
+
|
44 |
+
out3 = self.bn3(out2)
|
45 |
+
out3 = F.relu(out3, True)
|
46 |
+
out3 = self.conv3(out3)
|
47 |
+
|
48 |
+
out3 = torch.cat((out1, out2, out3), 1)
|
49 |
+
|
50 |
+
if self.downsample is not None:
|
51 |
+
residual = self.downsample(residual)
|
52 |
+
|
53 |
+
out3 += residual
|
54 |
+
|
55 |
+
return out3
|
56 |
+
|
57 |
+
|
58 |
+
class Bottleneck(nn.Module):
|
59 |
+
|
60 |
+
expansion = 4
|
61 |
+
|
62 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
63 |
+
super(Bottleneck, self).__init__()
|
64 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
65 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
66 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
67 |
+
padding=1, bias=False)
|
68 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
69 |
+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
70 |
+
self.bn3 = nn.BatchNorm2d(planes * 4)
|
71 |
+
self.relu = nn.ReLU(inplace=True)
|
72 |
+
self.downsample = downsample
|
73 |
+
self.stride = stride
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
residual = x
|
77 |
+
|
78 |
+
out = self.conv1(x)
|
79 |
+
out = self.bn1(out)
|
80 |
+
out = self.relu(out)
|
81 |
+
|
82 |
+
out = self.conv2(out)
|
83 |
+
out = self.bn2(out)
|
84 |
+
out = self.relu(out)
|
85 |
+
|
86 |
+
out = self.conv3(out)
|
87 |
+
out = self.bn3(out)
|
88 |
+
|
89 |
+
if self.downsample is not None:
|
90 |
+
residual = self.downsample(x)
|
91 |
+
|
92 |
+
out += residual
|
93 |
+
out = self.relu(out)
|
94 |
+
|
95 |
+
return out
|
96 |
+
|
97 |
+
|
98 |
+
class HourGlass(nn.Module):
|
99 |
+
def __init__(self, num_modules, depth, num_features):
|
100 |
+
super(HourGlass, self).__init__()
|
101 |
+
self.num_modules = num_modules
|
102 |
+
self.depth = depth
|
103 |
+
self.features = num_features
|
104 |
+
|
105 |
+
self._generate_network(self.depth)
|
106 |
+
|
107 |
+
def _generate_network(self, level):
|
108 |
+
self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
|
109 |
+
|
110 |
+
self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
|
111 |
+
|
112 |
+
if level > 1:
|
113 |
+
self._generate_network(level - 1)
|
114 |
+
else:
|
115 |
+
self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
|
116 |
+
|
117 |
+
self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
|
118 |
+
|
119 |
+
def _forward(self, level, inp):
|
120 |
+
# Upper branch
|
121 |
+
up1 = inp
|
122 |
+
up1 = self._modules['b1_' + str(level)](up1)
|
123 |
+
|
124 |
+
# Lower branch
|
125 |
+
low1 = F.avg_pool2d(inp, 2, stride=2)
|
126 |
+
low1 = self._modules['b2_' + str(level)](low1)
|
127 |
+
|
128 |
+
if level > 1:
|
129 |
+
low2 = self._forward(level - 1, low1)
|
130 |
+
else:
|
131 |
+
low2 = low1
|
132 |
+
low2 = self._modules['b2_plus_' + str(level)](low2)
|
133 |
+
|
134 |
+
low3 = low2
|
135 |
+
low3 = self._modules['b3_' + str(level)](low3)
|
136 |
+
|
137 |
+
up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
|
138 |
+
|
139 |
+
return up1 + up2
|
140 |
+
|
141 |
+
def forward(self, x):
|
142 |
+
return self._forward(self.depth, x)
|
143 |
+
|
144 |
+
|
145 |
+
class FAN(nn.Module):
|
146 |
+
|
147 |
+
def __init__(self, num_modules=1):
|
148 |
+
super(FAN, self).__init__()
|
149 |
+
self.num_modules = num_modules
|
150 |
+
|
151 |
+
# Base part
|
152 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
153 |
+
self.bn1 = nn.BatchNorm2d(64)
|
154 |
+
self.conv2 = ConvBlock(64, 128)
|
155 |
+
self.conv3 = ConvBlock(128, 128)
|
156 |
+
self.conv4 = ConvBlock(128, 256)
|
157 |
+
|
158 |
+
# Stacking part
|
159 |
+
for hg_module in range(self.num_modules):
|
160 |
+
self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
|
161 |
+
self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
|
162 |
+
self.add_module('conv_last' + str(hg_module),
|
163 |
+
nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
164 |
+
self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
|
165 |
+
self.add_module('l' + str(hg_module), nn.Conv2d(256,
|
166 |
+
68, kernel_size=1, stride=1, padding=0))
|
167 |
+
|
168 |
+
if hg_module < self.num_modules - 1:
|
169 |
+
self.add_module(
|
170 |
+
'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
171 |
+
self.add_module('al' + str(hg_module), nn.Conv2d(68,
|
172 |
+
256, kernel_size=1, stride=1, padding=0))
|
173 |
+
|
174 |
+
def forward(self, x):
|
175 |
+
x = F.relu(self.bn1(self.conv1(x)), True)
|
176 |
+
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
|
177 |
+
x = self.conv3(x)
|
178 |
+
x = self.conv4(x)
|
179 |
+
|
180 |
+
previous = x
|
181 |
+
|
182 |
+
outputs = []
|
183 |
+
for i in range(self.num_modules):
|
184 |
+
hg = self._modules['m' + str(i)](previous)
|
185 |
+
|
186 |
+
ll = hg
|
187 |
+
ll = self._modules['top_m_' + str(i)](ll)
|
188 |
+
|
189 |
+
ll = F.relu(self._modules['bn_end' + str(i)]
|
190 |
+
(self._modules['conv_last' + str(i)](ll)), True)
|
191 |
+
|
192 |
+
# Predict heatmaps
|
193 |
+
tmp_out = self._modules['l' + str(i)](ll)
|
194 |
+
outputs.append(tmp_out)
|
195 |
+
|
196 |
+
if i < self.num_modules - 1:
|
197 |
+
ll = self._modules['bl' + str(i)](ll)
|
198 |
+
tmp_out_ = self._modules['al' + str(i)](tmp_out)
|
199 |
+
previous = previous + ll + tmp_out_
|
200 |
+
|
201 |
+
return outputs
|
202 |
+
|
203 |
+
|
204 |
+
class ResNetDepth(nn.Module):
|
205 |
+
|
206 |
+
def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
|
207 |
+
self.inplanes = 64
|
208 |
+
super(ResNetDepth, self).__init__()
|
209 |
+
self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
|
210 |
+
bias=False)
|
211 |
+
self.bn1 = nn.BatchNorm2d(64)
|
212 |
+
self.relu = nn.ReLU(inplace=True)
|
213 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
214 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
215 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
216 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
217 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
218 |
+
self.avgpool = nn.AvgPool2d(7)
|
219 |
+
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
220 |
+
|
221 |
+
for m in self.modules():
|
222 |
+
if isinstance(m, nn.Conv2d):
|
223 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
224 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
225 |
+
elif isinstance(m, nn.BatchNorm2d):
|
226 |
+
m.weight.data.fill_(1)
|
227 |
+
m.bias.data.zero_()
|
228 |
+
|
229 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
230 |
+
downsample = None
|
231 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
232 |
+
downsample = nn.Sequential(
|
233 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
234 |
+
kernel_size=1, stride=stride, bias=False),
|
235 |
+
nn.BatchNorm2d(planes * block.expansion),
|
236 |
+
)
|
237 |
+
|
238 |
+
layers = []
|
239 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
240 |
+
self.inplanes = planes * block.expansion
|
241 |
+
for i in range(1, blocks):
|
242 |
+
layers.append(block(self.inplanes, planes))
|
243 |
+
|
244 |
+
return nn.Sequential(*layers)
|
245 |
+
|
246 |
+
def forward(self, x):
|
247 |
+
x = self.conv1(x)
|
248 |
+
x = self.bn1(x)
|
249 |
+
x = self.relu(x)
|
250 |
+
x = self.maxpool(x)
|
251 |
+
|
252 |
+
x = self.layer1(x)
|
253 |
+
x = self.layer2(x)
|
254 |
+
x = self.layer3(x)
|
255 |
+
x = self.layer4(x)
|
256 |
+
|
257 |
+
x = self.avgpool(x)
|
258 |
+
x = x.view(x.size(0), -1)
|
259 |
+
x = self.fc(x)
|
260 |
+
|
261 |
+
return x
|
face_detection/utils.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
import torch
|
6 |
+
import math
|
7 |
+
import numpy as np
|
8 |
+
import cv2
|
9 |
+
|
10 |
+
|
11 |
+
def _gaussian(
|
12 |
+
size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
|
13 |
+
height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
|
14 |
+
mean_vert=0.5):
|
15 |
+
# handle some defaults
|
16 |
+
if width is None:
|
17 |
+
width = size
|
18 |
+
if height is None:
|
19 |
+
height = size
|
20 |
+
if sigma_horz is None:
|
21 |
+
sigma_horz = sigma
|
22 |
+
if sigma_vert is None:
|
23 |
+
sigma_vert = sigma
|
24 |
+
center_x = mean_horz * width + 0.5
|
25 |
+
center_y = mean_vert * height + 0.5
|
26 |
+
gauss = np.empty((height, width), dtype=np.float32)
|
27 |
+
# generate kernel
|
28 |
+
for i in range(height):
|
29 |
+
for j in range(width):
|
30 |
+
gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
|
31 |
+
sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
|
32 |
+
if normalize:
|
33 |
+
gauss = gauss / np.sum(gauss)
|
34 |
+
return gauss
|
35 |
+
|
36 |
+
|
37 |
+
def draw_gaussian(image, point, sigma):
|
38 |
+
# Check if the gaussian is inside
|
39 |
+
ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
|
40 |
+
br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
|
41 |
+
if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
|
42 |
+
return image
|
43 |
+
size = 6 * sigma + 1
|
44 |
+
g = _gaussian(size)
|
45 |
+
g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
|
46 |
+
g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
|
47 |
+
img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
|
48 |
+
img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
|
49 |
+
assert (g_x[0] > 0 and g_y[1] > 0)
|
50 |
+
image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
|
51 |
+
] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
|
52 |
+
image[image > 1] = 1
|
53 |
+
return image
|
54 |
+
|
55 |
+
|
56 |
+
def transform(point, center, scale, resolution, invert=False):
|
57 |
+
"""Generate and affine transformation matrix.
|
58 |
+
|
59 |
+
Given a set of points, a center, a scale and a targer resolution, the
|
60 |
+
function generates and affine transformation matrix. If invert is ``True``
|
61 |
+
it will produce the inverse transformation.
|
62 |
+
|
63 |
+
Arguments:
|
64 |
+
point {torch.tensor} -- the input 2D point
|
65 |
+
center {torch.tensor or numpy.array} -- the center around which to perform the transformations
|
66 |
+
scale {float} -- the scale of the face/object
|
67 |
+
resolution {float} -- the output resolution
|
68 |
+
|
69 |
+
Keyword Arguments:
|
70 |
+
invert {bool} -- define wherever the function should produce the direct or the
|
71 |
+
inverse transformation matrix (default: {False})
|
72 |
+
"""
|
73 |
+
_pt = torch.ones(3)
|
74 |
+
_pt[0] = point[0]
|
75 |
+
_pt[1] = point[1]
|
76 |
+
|
77 |
+
h = 200.0 * scale
|
78 |
+
t = torch.eye(3)
|
79 |
+
t[0, 0] = resolution / h
|
80 |
+
t[1, 1] = resolution / h
|
81 |
+
t[0, 2] = resolution * (-center[0] / h + 0.5)
|
82 |
+
t[1, 2] = resolution * (-center[1] / h + 0.5)
|
83 |
+
|
84 |
+
if invert:
|
85 |
+
t = torch.inverse(t)
|
86 |
+
|
87 |
+
new_point = (torch.matmul(t, _pt))[0:2]
|
88 |
+
|
89 |
+
return new_point.int()
|
90 |
+
|
91 |
+
|
92 |
+
def crop(image, center, scale, resolution=256.0):
|
93 |
+
"""Center crops an image or set of heatmaps
|
94 |
+
|
95 |
+
Arguments:
|
96 |
+
image {numpy.array} -- an rgb image
|
97 |
+
center {numpy.array} -- the center of the object, usually the same as of the bounding box
|
98 |
+
scale {float} -- scale of the face
|
99 |
+
|
100 |
+
Keyword Arguments:
|
101 |
+
resolution {float} -- the size of the output cropped image (default: {256.0})
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
[type] -- [description]
|
105 |
+
""" # Crop around the center point
|
106 |
+
""" Crops the image around the center. Input is expected to be an np.ndarray """
|
107 |
+
ul = transform([1, 1], center, scale, resolution, True)
|
108 |
+
br = transform([resolution, resolution], center, scale, resolution, True)
|
109 |
+
# pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
|
110 |
+
if image.ndim > 2:
|
111 |
+
newDim = np.array([br[1] - ul[1], br[0] - ul[0],
|
112 |
+
image.shape[2]], dtype=np.int32)
|
113 |
+
newImg = np.zeros(newDim, dtype=np.uint8)
|
114 |
+
else:
|
115 |
+
newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int)
|
116 |
+
newImg = np.zeros(newDim, dtype=np.uint8)
|
117 |
+
ht = image.shape[0]
|
118 |
+
wd = image.shape[1]
|
119 |
+
newX = np.array(
|
120 |
+
[max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
|
121 |
+
newY = np.array(
|
122 |
+
[max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
|
123 |
+
oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
|
124 |
+
oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
|
125 |
+
newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
|
126 |
+
] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
|
127 |
+
newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
|
128 |
+
interpolation=cv2.INTER_LINEAR)
|
129 |
+
return newImg
|
130 |
+
|
131 |
+
|
132 |
+
def get_preds_fromhm(hm, center=None, scale=None):
|
133 |
+
"""Obtain (x,y) coordinates given a set of N heatmaps. If the center
|
134 |
+
and the scale is provided the function will return the points also in
|
135 |
+
the original coordinate frame.
|
136 |
+
|
137 |
+
Arguments:
|
138 |
+
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
139 |
+
|
140 |
+
Keyword Arguments:
|
141 |
+
center {torch.tensor} -- the center of the bounding box (default: {None})
|
142 |
+
scale {float} -- face scale (default: {None})
|
143 |
+
"""
|
144 |
+
max, idx = torch.max(
|
145 |
+
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
146 |
+
idx += 1
|
147 |
+
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
148 |
+
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
149 |
+
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
150 |
+
|
151 |
+
for i in range(preds.size(0)):
|
152 |
+
for j in range(preds.size(1)):
|
153 |
+
hm_ = hm[i, j, :]
|
154 |
+
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
155 |
+
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
156 |
+
diff = torch.FloatTensor(
|
157 |
+
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
158 |
+
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
159 |
+
preds[i, j].add_(diff.sign_().mul_(.25))
|
160 |
+
|
161 |
+
preds.add_(-.5)
|
162 |
+
|
163 |
+
preds_orig = torch.zeros(preds.size())
|
164 |
+
if center is not None and scale is not None:
|
165 |
+
for i in range(hm.size(0)):
|
166 |
+
for j in range(hm.size(1)):
|
167 |
+
preds_orig[i, j] = transform(
|
168 |
+
preds[i, j], center, scale, hm.size(2), True)
|
169 |
+
|
170 |
+
return preds, preds_orig
|
171 |
+
|
172 |
+
def get_preds_fromhm_batch(hm, centers=None, scales=None):
|
173 |
+
"""Obtain (x,y) coordinates given a set of N heatmaps. If the centers
|
174 |
+
and the scales is provided the function will return the points also in
|
175 |
+
the original coordinate frame.
|
176 |
+
|
177 |
+
Arguments:
|
178 |
+
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
179 |
+
|
180 |
+
Keyword Arguments:
|
181 |
+
centers {torch.tensor} -- the centers of the bounding box (default: {None})
|
182 |
+
scales {float} -- face scales (default: {None})
|
183 |
+
"""
|
184 |
+
max, idx = torch.max(
|
185 |
+
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
186 |
+
idx += 1
|
187 |
+
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
188 |
+
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
189 |
+
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
190 |
+
|
191 |
+
for i in range(preds.size(0)):
|
192 |
+
for j in range(preds.size(1)):
|
193 |
+
hm_ = hm[i, j, :]
|
194 |
+
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
195 |
+
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
196 |
+
diff = torch.FloatTensor(
|
197 |
+
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
198 |
+
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
199 |
+
preds[i, j].add_(diff.sign_().mul_(.25))
|
200 |
+
|
201 |
+
preds.add_(-.5)
|
202 |
+
|
203 |
+
preds_orig = torch.zeros(preds.size())
|
204 |
+
if centers is not None and scales is not None:
|
205 |
+
for i in range(hm.size(0)):
|
206 |
+
for j in range(hm.size(1)):
|
207 |
+
preds_orig[i, j] = transform(
|
208 |
+
preds[i, j], centers[i], scales[i], hm.size(2), True)
|
209 |
+
|
210 |
+
return preds, preds_orig
|
211 |
+
|
212 |
+
def shuffle_lr(parts, pairs=None):
|
213 |
+
"""Shuffle the points left-right according to the axis of symmetry
|
214 |
+
of the object.
|
215 |
+
|
216 |
+
Arguments:
|
217 |
+
parts {torch.tensor} -- a 3D or 4D object containing the
|
218 |
+
heatmaps.
|
219 |
+
|
220 |
+
Keyword Arguments:
|
221 |
+
pairs {list of integers} -- [order of the flipped points] (default: {None})
|
222 |
+
"""
|
223 |
+
if pairs is None:
|
224 |
+
pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
|
225 |
+
26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
|
226 |
+
34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
|
227 |
+
40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
|
228 |
+
62, 61, 60, 67, 66, 65]
|
229 |
+
if parts.ndimension() == 3:
|
230 |
+
parts = parts[pairs, ...]
|
231 |
+
else:
|
232 |
+
parts = parts[:, pairs, ...]
|
233 |
+
|
234 |
+
return parts
|
235 |
+
|
236 |
+
|
237 |
+
def flip(tensor, is_label=False):
|
238 |
+
"""Flip an image or a set of heatmaps left-right
|
239 |
+
|
240 |
+
Arguments:
|
241 |
+
tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
|
242 |
+
|
243 |
+
Keyword Arguments:
|
244 |
+
is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False})
|
245 |
+
"""
|
246 |
+
if not torch.is_tensor(tensor):
|
247 |
+
tensor = torch.from_numpy(tensor)
|
248 |
+
|
249 |
+
if is_label:
|
250 |
+
tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
|
251 |
+
else:
|
252 |
+
tensor = tensor.flip(tensor.ndimension() - 1)
|
253 |
+
|
254 |
+
return tensor
|
255 |
+
|
256 |
+
# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
|
257 |
+
|
258 |
+
|
259 |
+
def appdata_dir(appname=None, roaming=False):
|
260 |
+
""" appdata_dir(appname=None, roaming=False)
|
261 |
+
|
262 |
+
Get the path to the application directory, where applications are allowed
|
263 |
+
to write user specific files (e.g. configurations). For non-user specific
|
264 |
+
data, consider using common_appdata_dir().
|
265 |
+
If appname is given, a subdir is appended (and created if necessary).
|
266 |
+
If roaming is True, will prefer a roaming directory (Windows Vista/7).
|
267 |
+
"""
|
268 |
+
|
269 |
+
# Define default user directory
|
270 |
+
userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
|
271 |
+
if userDir is None:
|
272 |
+
userDir = os.path.expanduser('~')
|
273 |
+
if not os.path.isdir(userDir): # pragma: no cover
|
274 |
+
userDir = '/var/tmp' # issue #54
|
275 |
+
|
276 |
+
# Get system app data dir
|
277 |
+
path = None
|
278 |
+
if sys.platform.startswith('win'):
|
279 |
+
path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
|
280 |
+
path = (path2 or path1) if roaming else (path1 or path2)
|
281 |
+
elif sys.platform.startswith('darwin'):
|
282 |
+
path = os.path.join(userDir, 'Library', 'Application Support')
|
283 |
+
# On Linux and as fallback
|
284 |
+
if not (path and os.path.isdir(path)):
|
285 |
+
path = userDir
|
286 |
+
|
287 |
+
# Maybe we should store things local to the executable (in case of a
|
288 |
+
# portable distro or a frozen application that wants to be portable)
|
289 |
+
prefix = sys.prefix
|
290 |
+
if getattr(sys, 'frozen', None):
|
291 |
+
prefix = os.path.abspath(os.path.dirname(sys.executable))
|
292 |
+
for reldir in ('settings', '../settings'):
|
293 |
+
localpath = os.path.abspath(os.path.join(prefix, reldir))
|
294 |
+
if os.path.isdir(localpath): # pragma: no cover
|
295 |
+
try:
|
296 |
+
open(os.path.join(localpath, 'test.write'), 'wb').close()
|
297 |
+
os.remove(os.path.join(localpath, 'test.write'))
|
298 |
+
except IOError:
|
299 |
+
pass # We cannot write in this directory
|
300 |
+
else:
|
301 |
+
path = localpath
|
302 |
+
break
|
303 |
+
|
304 |
+
# Get path specific for this app
|
305 |
+
if appname:
|
306 |
+
if path == userDir:
|
307 |
+
appname = '.' + appname.lstrip('.') # Make it a hidden directory
|
308 |
+
path = os.path.join(path, appname)
|
309 |
+
if not os.path.isdir(path): # pragma: no cover
|
310 |
+
os.mkdir(path)
|
311 |
+
|
312 |
+
# Done
|
313 |
+
return path
|
fete_model.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class Conv2d(nn.Module):
|
7 |
+
def __init__(self, cin, cout, kernel_size, stride, padding, *args, **kwargs):
|
8 |
+
super().__init__(*args, **kwargs)
|
9 |
+
self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout))
|
10 |
+
self.act = nn.ReLU()
|
11 |
+
|
12 |
+
def forward(self, x):
|
13 |
+
out = self.conv_block(x)
|
14 |
+
return self.act(out)
|
15 |
+
|
16 |
+
|
17 |
+
class Conv2d_res(nn.Module):
|
18 |
+
# TensorRT does not support 'if' statement, thus we create independent Conv2d_res for residual block
|
19 |
+
def __init__(self, cin, cout, kernel_size, stride, padding, *args, **kwargs):
|
20 |
+
super().__init__(*args, **kwargs)
|
21 |
+
self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout))
|
22 |
+
self.act = nn.ReLU()
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
out = self.conv_block(x)
|
26 |
+
out += x
|
27 |
+
return self.act(out)
|
28 |
+
|
29 |
+
|
30 |
+
class Conv2dTranspose(nn.Module):
|
31 |
+
def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
|
32 |
+
super().__init__(*args, **kwargs)
|
33 |
+
self.conv_block = nn.Sequential(
|
34 |
+
nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
|
35 |
+
nn.BatchNorm2d(cout),
|
36 |
+
)
|
37 |
+
self.act = nn.ReLU()
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
out = self.conv_block(x)
|
41 |
+
return self.act(out)
|
42 |
+
|
43 |
+
|
44 |
+
class FETE_model(nn.Module):
|
45 |
+
def __init__(self):
|
46 |
+
super(FETE_model, self).__init__()
|
47 |
+
|
48 |
+
self.face_encoder_blocks = nn.ModuleList(
|
49 |
+
[
|
50 |
+
nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=2, padding=3)), # 256,256 -> 128,128
|
51 |
+
nn.Sequential(
|
52 |
+
Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 64,64
|
53 |
+
Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
|
54 |
+
Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
|
55 |
+
),
|
56 |
+
nn.Sequential(
|
57 |
+
Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 32,32
|
58 |
+
Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
|
59 |
+
Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
|
60 |
+
Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
|
61 |
+
),
|
62 |
+
nn.Sequential(
|
63 |
+
Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 16,16
|
64 |
+
Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
|
65 |
+
Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
|
66 |
+
),
|
67 |
+
nn.Sequential(
|
68 |
+
Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 8,8
|
69 |
+
Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
|
70 |
+
Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
|
71 |
+
),
|
72 |
+
nn.Sequential(
|
73 |
+
Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 4,4
|
74 |
+
Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),
|
75 |
+
),
|
76 |
+
nn.Sequential(
|
77 |
+
Conv2d(512, 512, kernel_size=3, stride=2, padding=0), # 1, 1
|
78 |
+
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
|
79 |
+
),
|
80 |
+
]
|
81 |
+
)
|
82 |
+
|
83 |
+
self.audio_encoder = nn.Sequential(
|
84 |
+
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
|
85 |
+
Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
|
86 |
+
Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
|
87 |
+
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
|
88 |
+
Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
|
89 |
+
Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
|
90 |
+
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
|
91 |
+
Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
|
92 |
+
Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
|
93 |
+
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
|
94 |
+
Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
|
95 |
+
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
|
96 |
+
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
|
97 |
+
)
|
98 |
+
|
99 |
+
self.pose_encoder = nn.Sequential(
|
100 |
+
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
|
101 |
+
Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
|
102 |
+
Conv2d(32, 64, kernel_size=3, stride=(1, 2), padding=1),
|
103 |
+
Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
|
104 |
+
Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
|
105 |
+
Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
|
106 |
+
Conv2d(128, 256, kernel_size=3, stride=(1, 2), padding=1),
|
107 |
+
Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
|
108 |
+
Conv2d(256, 512, kernel_size=3, stride=2, padding=0),
|
109 |
+
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
|
110 |
+
)
|
111 |
+
|
112 |
+
self.emotion_encoder = nn.Sequential(
|
113 |
+
Conv2d(1, 32, kernel_size=7, stride=1, padding=1),
|
114 |
+
Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
|
115 |
+
Conv2d(32, 64, kernel_size=3, stride=(1, 2), padding=1),
|
116 |
+
Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
|
117 |
+
Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
|
118 |
+
Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
|
119 |
+
Conv2d(128, 256, kernel_size=3, stride=(1, 2), padding=1),
|
120 |
+
Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
|
121 |
+
Conv2d(256, 512, kernel_size=3, stride=2, padding=0),
|
122 |
+
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
|
123 |
+
)
|
124 |
+
|
125 |
+
self.blink_encoder = nn.Sequential(
|
126 |
+
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
|
127 |
+
Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
|
128 |
+
Conv2d(32, 64, kernel_size=3, stride=(1, 2), padding=1),
|
129 |
+
Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
|
130 |
+
Conv2d(64, 128, kernel_size=3, stride=(1, 2), padding=1),
|
131 |
+
Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
|
132 |
+
Conv2d(128, 256, kernel_size=3, stride=(1, 2), padding=1),
|
133 |
+
Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
|
134 |
+
Conv2d(256, 512, kernel_size=1, stride=(1, 2), padding=0),
|
135 |
+
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
|
136 |
+
)
|
137 |
+
|
138 |
+
self.face_decoder_blocks = nn.ModuleList(
|
139 |
+
[
|
140 |
+
nn.Sequential(
|
141 |
+
Conv2d(2048, 512, kernel_size=1, stride=1, padding=0),
|
142 |
+
),
|
143 |
+
nn.Sequential(
|
144 |
+
Conv2dTranspose(1024, 512, kernel_size=4, stride=1, padding=0), # 4,4
|
145 |
+
Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),
|
146 |
+
),
|
147 |
+
nn.Sequential(
|
148 |
+
Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
|
149 |
+
Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),
|
150 |
+
Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1), # 8,8
|
151 |
+
Self_Attention(512, 512),
|
152 |
+
),
|
153 |
+
nn.Sequential(
|
154 |
+
Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
|
155 |
+
Conv2d_res(384, 384, kernel_size=3, stride=1, padding=1),
|
156 |
+
Conv2d_res(384, 384, kernel_size=3, stride=1, padding=1), # 16, 16
|
157 |
+
Self_Attention(384, 384),
|
158 |
+
),
|
159 |
+
nn.Sequential(
|
160 |
+
Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
|
161 |
+
Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
|
162 |
+
Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1), # 32, 32
|
163 |
+
Self_Attention(256, 256),
|
164 |
+
),
|
165 |
+
nn.Sequential(
|
166 |
+
Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
|
167 |
+
Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
|
168 |
+
Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
|
169 |
+
), # 64, 64
|
170 |
+
nn.Sequential(
|
171 |
+
Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
|
172 |
+
Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
|
173 |
+
Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
|
174 |
+
),
|
175 |
+
]
|
176 |
+
) # 128,128
|
177 |
+
|
178 |
+
# self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
|
179 |
+
# nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
|
180 |
+
# nn.Sigmoid())
|
181 |
+
|
182 |
+
self.output_block = nn.Sequential(
|
183 |
+
Conv2dTranspose(80, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
|
184 |
+
nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
|
185 |
+
nn.Sigmoid(),
|
186 |
+
)
|
187 |
+
|
188 |
+
def forward(
|
189 |
+
self,
|
190 |
+
face_sequences,
|
191 |
+
audio_sequences,
|
192 |
+
pose_sequences,
|
193 |
+
emotion_sequences,
|
194 |
+
blink_sequences,
|
195 |
+
):
|
196 |
+
# audio_sequences = (B, T, 1, 80, 16)
|
197 |
+
B = audio_sequences.size(0)
|
198 |
+
|
199 |
+
# disabled for inference
|
200 |
+
# input_dim_size = len(face_sequences.size())
|
201 |
+
# if input_dim_size > 4:
|
202 |
+
# audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
|
203 |
+
# pose_sequences = torch.cat([pose_sequences[:, i] for i in range(pose_sequences.size(1))], dim=0)
|
204 |
+
# emotion_sequences = torch.cat([emotion_sequences[:, i] for i in range(emotion_sequences.size(1))], dim=0)
|
205 |
+
# blink_sequences = torch.cat([blink_sequences[:, i] for i in range(blink_sequences.size(1))], dim=0)
|
206 |
+
# face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
|
207 |
+
# print(audio_sequences.size(), face_sequences.size(), pose_sequences.size(), emotion_sequences.size())
|
208 |
+
|
209 |
+
audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
|
210 |
+
pose_embedding = self.pose_encoder(pose_sequences) # B, 512, 1, 1
|
211 |
+
emotion_embedding = self.emotion_encoder(emotion_sequences) # B, 512, 1, 1
|
212 |
+
blink_embedding = self.blink_encoder(blink_sequences) # B, 512, 1, 1
|
213 |
+
inputs_embedding = torch.cat((audio_embedding, pose_embedding, emotion_embedding, blink_embedding), dim=1) # B, 1536, 1, 1
|
214 |
+
# print(audio_embedding.size(), pose_embedding.size(), emotion_embedding.size(), inputs_embedding.size())
|
215 |
+
|
216 |
+
feats = []
|
217 |
+
x = face_sequences
|
218 |
+
for f in self.face_encoder_blocks:
|
219 |
+
x = f(x)
|
220 |
+
# print(x.shape)
|
221 |
+
feats.append(x)
|
222 |
+
|
223 |
+
x = inputs_embedding
|
224 |
+
for f in self.face_decoder_blocks:
|
225 |
+
x = f(x)
|
226 |
+
# print(x.shape)
|
227 |
+
|
228 |
+
# try:
|
229 |
+
x = torch.cat((x, feats[-1]), dim=1)
|
230 |
+
# except Exception as e:
|
231 |
+
# print(x.size())
|
232 |
+
# print(feats[-1].size())
|
233 |
+
# raise e
|
234 |
+
feats.pop()
|
235 |
+
|
236 |
+
x = self.output_block(x)
|
237 |
+
|
238 |
+
# if input_dim_size > 4:
|
239 |
+
# x = torch.split(x, B, dim=0) # [(B, C, H, W)]
|
240 |
+
# outputs = torch.stack(x, dim=2) # (B, C, T, H, W)
|
241 |
+
|
242 |
+
# else:
|
243 |
+
outputs = x
|
244 |
+
|
245 |
+
return outputs
|
246 |
+
|
247 |
+
|
248 |
+
class Self_Attention(nn.Module):
|
249 |
+
"""
|
250 |
+
Source-Reference Attention Layer
|
251 |
+
"""
|
252 |
+
|
253 |
+
def __init__(self, in_planes_s, in_planes_r):
|
254 |
+
"""
|
255 |
+
Parameters
|
256 |
+
----------
|
257 |
+
in_planes_s: int
|
258 |
+
Number of input source feature vector channels.
|
259 |
+
in_planes_r: int
|
260 |
+
Number of input reference feature vector channels.
|
261 |
+
"""
|
262 |
+
super(Self_Attention, self).__init__()
|
263 |
+
self.query_conv = nn.Conv2d(in_channels=in_planes_s, out_channels=in_planes_s // 8, kernel_size=1)
|
264 |
+
self.key_conv = nn.Conv2d(in_channels=in_planes_r, out_channels=in_planes_r // 8, kernel_size=1)
|
265 |
+
self.value_conv = nn.Conv2d(in_channels=in_planes_r, out_channels=in_planes_r, kernel_size=1)
|
266 |
+
self.gamma = nn.Parameter(torch.zeros(1))
|
267 |
+
self.softmax = nn.Softmax(dim=-1)
|
268 |
+
|
269 |
+
def forward(self, source):
|
270 |
+
source = source.float() if isinstance(source, torch.cuda.HalfTensor) else source
|
271 |
+
reference = source
|
272 |
+
"""
|
273 |
+
Parameters
|
274 |
+
----------
|
275 |
+
source : torch.Tensor
|
276 |
+
Source feature maps (B x Cs x Ts x Hs x Ws)
|
277 |
+
reference : torch.Tensor
|
278 |
+
Reference feature maps (B x Cr x Tr x Hr x Wr )
|
279 |
+
Returns :
|
280 |
+
torch.Tensor
|
281 |
+
Source-reference attention value added to the input source features
|
282 |
+
torch.Tensor
|
283 |
+
Attention map (B x Ns x Nt) (Ns=Ts*Hs*Ws, Nr=Tr*Hr*Wr)
|
284 |
+
"""
|
285 |
+
s_batchsize, sC, sH, sW = source.size()
|
286 |
+
r_batchsize, rC, rH, rW = reference.size()
|
287 |
+
|
288 |
+
proj_query = self.query_conv(source).view(s_batchsize, -1, sH * sW).permute(0, 2, 1)
|
289 |
+
proj_key = self.key_conv(reference).view(r_batchsize, -1, rW * rH)
|
290 |
+
energy = torch.bmm(proj_query, proj_key)
|
291 |
+
attention = self.softmax(energy)
|
292 |
+
proj_value = self.value_conv(reference).view(r_batchsize, -1, rH * rW)
|
293 |
+
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
|
294 |
+
out = out.view(s_batchsize, sC, sH, sW)
|
295 |
+
out = self.gamma * out + source
|
296 |
+
return out.half() if isinstance(source, torch.cuda.FloatTensor) else out
|
inference_util.py
ADDED
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
# set CUDA_MODULE_LOADING=LAZY to speed up the serverless function
|
4 |
+
os.environ["CUDA_MODULE_LOADING"] = "LAZY"
|
5 |
+
# set SAFETENSORS_FAST_GPU=1 to speed up the serverless function
|
6 |
+
os.environ["SAFETENSORS_FAST_GPU"] = "1"
|
7 |
+
import cv2
|
8 |
+
import torch
|
9 |
+
import time
|
10 |
+
import imageio
|
11 |
+
import numpy as np
|
12 |
+
from tqdm import tqdm
|
13 |
+
import moviepy.editor as mp
|
14 |
+
import torch
|
15 |
+
|
16 |
+
from audio import load_wav, melspectrogram
|
17 |
+
from fete_model import FETE_model
|
18 |
+
from preprocess_videos import face_detect, load_from_npz
|
19 |
+
|
20 |
+
fps = 25
|
21 |
+
mel_idx_multiplier = 80.0 / fps
|
22 |
+
|
23 |
+
mel_step_size = 16
|
24 |
+
batch_size = 64 if torch.cuda.is_available() else 4
|
25 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
26 |
+
print("Using {} for inference.".format(device))
|
27 |
+
use_fp16 = True if torch.cuda.is_available() else False
|
28 |
+
print("Using FP16 for inference.") if use_fp16 else None
|
29 |
+
torch.backends.cudnn.benchmark = True if device == "cuda" else False
|
30 |
+
|
31 |
+
|
32 |
+
def init_model():
|
33 |
+
checkpoint_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints/obama-fp16.safetensors")
|
34 |
+
model = FETE_model()
|
35 |
+
if checkpoint_path.endswith(".pth") or checkpoint_path.endswith(".ckpt"):
|
36 |
+
if device == "cuda":
|
37 |
+
checkpoint = torch.load(checkpoint_path)
|
38 |
+
else:
|
39 |
+
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
|
40 |
+
s = checkpoint["state_dict"]
|
41 |
+
else:
|
42 |
+
from safetensors import safe_open
|
43 |
+
|
44 |
+
s = {}
|
45 |
+
with safe_open(checkpoint_path, framework="pt", device=device) as f:
|
46 |
+
for key in f.keys():
|
47 |
+
s[key] = f.get_tensor(key)
|
48 |
+
new_s = {}
|
49 |
+
for k, v in s.items():
|
50 |
+
new_s[k.replace("module.", "")] = v
|
51 |
+
model.load_state_dict(new_s)
|
52 |
+
|
53 |
+
model = model.to(device)
|
54 |
+
model.eval()
|
55 |
+
print("Model loaded")
|
56 |
+
if use_fp16:
|
57 |
+
for name, module in model.named_modules():
|
58 |
+
if ".query_conv" in name or ".key_conv" in name or ".value_conv" in name:
|
59 |
+
# keep attention layers in full precision to avoid error
|
60 |
+
module.to(torch.float)
|
61 |
+
else:
|
62 |
+
module.to(torch.half)
|
63 |
+
print("Model converted to half precision to accelerate inference")
|
64 |
+
return model
|
65 |
+
|
66 |
+
|
67 |
+
def make_mask(image_size=256, border_size=32):
|
68 |
+
mask_bar = np.linspace(1, 0, border_size).reshape(1, -1).repeat(image_size, axis=0)
|
69 |
+
mask = np.zeros((image_size, image_size), dtype=np.float32)
|
70 |
+
mask[-border_size:, :] += mask_bar.T[::-1]
|
71 |
+
mask[:, :border_size] = mask_bar
|
72 |
+
mask[:, -border_size:] = mask_bar[:, ::-1]
|
73 |
+
mask[-border_size:, :][mask[-border_size:, :] < 0.6] = 0.6
|
74 |
+
mask = np.stack([mask] * 3, axis=-1).astype(np.float32)
|
75 |
+
return mask
|
76 |
+
|
77 |
+
|
78 |
+
face_mask = make_mask()
|
79 |
+
|
80 |
+
|
81 |
+
def blend_images(foreground, background):
|
82 |
+
# Blend the foreground and background images using the mask
|
83 |
+
temp_mask = cv2.resize(face_mask, (foreground.shape[1], foreground.shape[0]))
|
84 |
+
blended = cv2.multiply(foreground.astype(np.float32), temp_mask)
|
85 |
+
blended += cv2.multiply(background.astype(np.float32), 1 - temp_mask)
|
86 |
+
blended = np.clip(blended, 0, 255).astype(np.uint8)
|
87 |
+
return blended
|
88 |
+
|
89 |
+
|
90 |
+
def smooth_coord(last_coord, current_coord, factor=0.4):
|
91 |
+
change = np.array(current_coord) - np.array(last_coord)
|
92 |
+
change = change * factor
|
93 |
+
return (np.array(last_coord) + np.array(change)).astype(int).tolist()
|
94 |
+
|
95 |
+
|
96 |
+
def add_black(imgs):
|
97 |
+
for i in range(len(imgs)):
|
98 |
+
# print('x', imgs[i].shape)
|
99 |
+
imgs[i] = cv2.vconcat(
|
100 |
+
[np.zeros((100, imgs[i].shape[1], 3), dtype=np.uint8), imgs[i], np.zeros((20, imgs[i].shape[1], 3), dtype=np.uint8)]
|
101 |
+
)
|
102 |
+
# imgs[i] = cv2.hconcat([np.zeros((imgs[i].shape[0], 100, 3), dtype=np.uint8), imgs[i], np.zeros((imgs[i].shape[0], 100, 3), dtype=np.uint8)])[:480+150,740-100:-740+100,:]
|
103 |
+
|
104 |
+
# print('xx', imgs[i].shape)
|
105 |
+
return imgs
|
106 |
+
|
107 |
+
|
108 |
+
def remove_black(img):
|
109 |
+
return img[100:-20]
|
110 |
+
|
111 |
+
|
112 |
+
def resize_length(input_attributes, length):
|
113 |
+
input_attributes = np.array(input_attributes)
|
114 |
+
resized_attributes = [input_attributes[int(i_ * (input_attributes.shape[0] / length))] for i_ in range(length)]
|
115 |
+
return np.array(resized_attributes).T
|
116 |
+
|
117 |
+
|
118 |
+
def output_chunks(input_attributes):
|
119 |
+
output_chunks = []
|
120 |
+
len_ = len(input_attributes[0])
|
121 |
+
|
122 |
+
i = 0
|
123 |
+
# print(mel.shape, pose.shape)
|
124 |
+
# (80, 801) (3, 801)
|
125 |
+
while 1:
|
126 |
+
start_idx = int(i * mel_idx_multiplier)
|
127 |
+
if start_idx + mel_step_size > len_:
|
128 |
+
output_chunks.append(input_attributes[:, len_ - mel_step_size :])
|
129 |
+
break
|
130 |
+
output_chunks.append(input_attributes[:, start_idx : start_idx + mel_step_size])
|
131 |
+
i += 1
|
132 |
+
return output_chunks
|
133 |
+
|
134 |
+
|
135 |
+
def prepare_data(face_path, audio_path, pose, emotion, blink, img_size=256, pads=[0, 0, 0, 0]):
|
136 |
+
if os.path.isfile(face_path) and face_path.split(".")[1] in ["jpg", "png", "jpeg"]:
|
137 |
+
static = True
|
138 |
+
full_frames = [cv2.imread(face_path)]
|
139 |
+
else:
|
140 |
+
static = False
|
141 |
+
video_stream = cv2.VideoCapture(face_path)
|
142 |
+
|
143 |
+
# print('Reading video frames...')
|
144 |
+
full_frames = []
|
145 |
+
while 1:
|
146 |
+
still_reading, frame = video_stream.read()
|
147 |
+
if not still_reading:
|
148 |
+
video_stream.release()
|
149 |
+
break
|
150 |
+
full_frames.append(frame)
|
151 |
+
print("Number of frames available for inference: " + str(len(full_frames)))
|
152 |
+
|
153 |
+
wav = load_wav(audio_path, 16000)
|
154 |
+
mel = melspectrogram(wav)
|
155 |
+
# take half
|
156 |
+
len_ = mel.shape[1] # //2
|
157 |
+
mel = mel[:, :len_]
|
158 |
+
# print('>>>', mel.shape)
|
159 |
+
|
160 |
+
pose = resize_length(pose, len_)
|
161 |
+
emotion = resize_length(emotion, len_)
|
162 |
+
blink = resize_length(blink, len_)
|
163 |
+
|
164 |
+
if np.isnan(mel.reshape(-1)).sum() > 0:
|
165 |
+
raise ValueError("Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again")
|
166 |
+
|
167 |
+
mel_chunks = output_chunks(mel)
|
168 |
+
pose_chunks = output_chunks(pose)
|
169 |
+
emotion_chunks = output_chunks(emotion)
|
170 |
+
blink_chunks = output_chunks(blink)
|
171 |
+
|
172 |
+
gen = datagen(face_path, full_frames, mel_chunks, pose_chunks, emotion_chunks, blink_chunks, static=static, img_size=img_size, pads=pads)
|
173 |
+
steps = int(np.ceil(float(len(mel_chunks)) / batch_size))
|
174 |
+
|
175 |
+
return gen, steps
|
176 |
+
|
177 |
+
|
178 |
+
def preprocess_batch(batch):
|
179 |
+
return torch.FloatTensor(np.reshape(batch, [len(batch), 1, batch[0].shape[0], batch[0].shape[1]])).to(device)
|
180 |
+
|
181 |
+
|
182 |
+
def datagen(face_path, frames, mels, poses, emotions, blinks, static=False, img_size=256, pads=[0, 0, 0, 0]):
|
183 |
+
img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, frame_batch, coords_batch = [], [], [], [], [], [], []
|
184 |
+
scale_factor = img_size // 128
|
185 |
+
|
186 |
+
# print("Length of mel chunks: {}".format(len(mel_chunks)))
|
187 |
+
frames = frames[: len(mels)]
|
188 |
+
frames = add_black(frames)
|
189 |
+
try:
|
190 |
+
video_name = os.path.basename(face_path).split(".")[0]
|
191 |
+
coords = load_from_npz(video_name)
|
192 |
+
face_det_results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(frames, coords)]
|
193 |
+
|
194 |
+
except Exception as e:
|
195 |
+
print("No existing coords found, running face detection...", "Error: ", e)
|
196 |
+
if not static:
|
197 |
+
coords = face_detect(frames, pads)
|
198 |
+
face_det_results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(frames, coords)]
|
199 |
+
else:
|
200 |
+
coords = face_detect([frames[0]], pads)
|
201 |
+
face_det_results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(frames, coords)]
|
202 |
+
|
203 |
+
face_det_results = face_det_results[: len(mels)]
|
204 |
+
|
205 |
+
while len(frames) < len(mels):
|
206 |
+
face_det_results = face_det_results + face_det_results[::-1]
|
207 |
+
frames = frames + frames[::-1]
|
208 |
+
else:
|
209 |
+
face_det_results = face_det_results[: len(mels)]
|
210 |
+
frames = frames[: len(mels)]
|
211 |
+
|
212 |
+
for i in range(len(mels)):
|
213 |
+
idx = 0 if static else i % len(frames)
|
214 |
+
frame_to_save = frames[idx].copy()
|
215 |
+
face, coords = face_det_results[idx].copy()
|
216 |
+
face = cv2.resize(face, (img_size, img_size))
|
217 |
+
|
218 |
+
img_batch.append(face)
|
219 |
+
mel_batch.append(mels[i])
|
220 |
+
pose_batch.append(poses[i])
|
221 |
+
emotion_batch.append(emotions[i])
|
222 |
+
blink_batch.append(blinks[i])
|
223 |
+
frame_batch.append(frame_to_save)
|
224 |
+
coords_batch.append(coords)
|
225 |
+
|
226 |
+
# print(m.shape, poses[i].shape)
|
227 |
+
# (80, 16) (3, 16)
|
228 |
+
if len(img_batch) >= batch_size:
|
229 |
+
img_masked = np.asarray(img_batch).copy()
|
230 |
+
|
231 |
+
img_masked[:, 16 * scale_factor : -16 * scale_factor, 16 * scale_factor : -16 * scale_factor] = 0.0
|
232 |
+
|
233 |
+
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
|
234 |
+
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
|
235 |
+
|
236 |
+
mel_batch = preprocess_batch(mel_batch)
|
237 |
+
pose_batch = preprocess_batch(pose_batch)
|
238 |
+
emotion_batch = preprocess_batch(emotion_batch)
|
239 |
+
blink_batch = preprocess_batch(blink_batch)
|
240 |
+
|
241 |
+
if use_fp16:
|
242 |
+
yield (
|
243 |
+
img_batch.half(),
|
244 |
+
mel_batch.half(),
|
245 |
+
pose_batch.half(),
|
246 |
+
emotion_batch.half(),
|
247 |
+
blink_batch.half(),
|
248 |
+
), frame_batch, coords_batch
|
249 |
+
else:
|
250 |
+
yield (img_batch, mel_batch, pose_batch, emotion_batch, blink_batch), frame_batch, coords_batch
|
251 |
+
img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, frame_batch, coords_batch = [], [], [], [], [], [], []
|
252 |
+
|
253 |
+
if len(img_batch) > 0:
|
254 |
+
img_masked = np.asarray(img_batch).copy()
|
255 |
+
|
256 |
+
img_masked[:, 16 * scale_factor : -16 * scale_factor, 16 * scale_factor : -16 * scale_factor] = 0.0
|
257 |
+
|
258 |
+
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
|
259 |
+
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
|
260 |
+
|
261 |
+
mel_batch = preprocess_batch(mel_batch)
|
262 |
+
pose_batch = preprocess_batch(pose_batch)
|
263 |
+
emotion_batch = preprocess_batch(emotion_batch)
|
264 |
+
blink_batch = preprocess_batch(blink_batch)
|
265 |
+
|
266 |
+
if use_fp16:
|
267 |
+
yield (img_batch.half(), mel_batch.half(), pose_batch.half(), emotion_batch.half(), blink_batch.half()), frame_batch, coords_batch
|
268 |
+
else:
|
269 |
+
yield (img_batch, mel_batch, pose_batch, emotion_batch, blink_batch), frame_batch, coords_batch
|
270 |
+
|
271 |
+
|
272 |
+
def infenrece(model, face_path, audio_path, pose, emotion, blink, preview=False):
|
273 |
+
timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime(time.time()))
|
274 |
+
gen, steps = prepare_data(face_path, audio_path, pose, emotion, blink)
|
275 |
+
steps = 1 if preview else steps
|
276 |
+
# duration = librosa.get_duration(filename=audio_path)
|
277 |
+
|
278 |
+
if preview:
|
279 |
+
outfile = "/tmp/{}.jpg".format(timestamp)
|
280 |
+
else:
|
281 |
+
outfile = "/tmp/{}.mp4".format(timestamp)
|
282 |
+
tmp_video = "/tmp/temp_{}.mp4".format(timestamp)
|
283 |
+
writer = imageio.get_writer(tmp_video, fps=fps, codec="libx264", quality=10, pixelformat="yuv420p", macro_block_size=1) if not preview else None
|
284 |
+
# print('Generating frames...', outfile, steps)
|
285 |
+
for inputs, frames, coords in tqdm(gen, total=steps):
|
286 |
+
with torch.no_grad():
|
287 |
+
pred = model(*inputs)
|
288 |
+
|
289 |
+
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.0
|
290 |
+
|
291 |
+
for p, f, c in zip(pred, frames, coords):
|
292 |
+
y1, y2, x1, x2 = c
|
293 |
+
y1, y2, x1, x2 = int(y1), int(y2), int(x1), int(x2)
|
294 |
+
y = round(y2 - y1)
|
295 |
+
x = round(x2 - x1)
|
296 |
+
# print(x, y, p.shape)
|
297 |
+
p = cv2.resize(p.astype(np.uint8), (x, y))
|
298 |
+
|
299 |
+
try:
|
300 |
+
f[y1 : y1 + y, x1 : x1 + x] = blend_images(f[y1 : y1 + y, x1 : x1 + x], p)
|
301 |
+
except Exception as e:
|
302 |
+
print(e)
|
303 |
+
f[y1 : y1 + y, x1 : x1 + x] = p
|
304 |
+
# out.write(f[100:-20])
|
305 |
+
f = remove_black(f)
|
306 |
+
if preview:
|
307 |
+
cv2.imwrite(outfile, f, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
|
308 |
+
return outfile
|
309 |
+
writer.append_data(cv2.cvtColor(f, cv2.COLOR_BGR2RGB))
|
310 |
+
writer.close()
|
311 |
+
video_clip = mp.VideoFileClip(tmp_video)
|
312 |
+
audio_clip = mp.AudioFileClip(audio_path)
|
313 |
+
video_clip = video_clip.set_audio(audio_clip)
|
314 |
+
video_clip.write_videofile(outfile, codec="libx264")
|
315 |
+
|
316 |
+
print("Saved to {}".format(outfile) if os.path.exists(outfile) else "Failed to save {}".format(outfile))
|
317 |
+
try:
|
318 |
+
os.remove(tmp_video)
|
319 |
+
del video_clip
|
320 |
+
del audio_clip
|
321 |
+
del gen
|
322 |
+
except:
|
323 |
+
pass
|
324 |
+
return outfile
|
325 |
+
|
326 |
+
|
327 |
+
if __name__ == "__main__":
|
328 |
+
model = init_model()
|
329 |
+
|
330 |
+
from attributtes_utils import input_pose, input_emotion, input_blink
|
331 |
+
|
332 |
+
pose = input_pose()
|
333 |
+
emotion = input_emotion()
|
334 |
+
blink = input_blink()
|
335 |
+
audio_path = "./assets/sample.wav"
|
336 |
+
face_path = "./assets/sample.mp4"
|
337 |
+
|
338 |
+
infenrece(model, face_path, audio_path, pose, emotion, blink)
|
preprocess_videos.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import face_detection
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
from tqdm import tqdm
|
5 |
+
import torch
|
6 |
+
import glob
|
7 |
+
import os
|
8 |
+
from natsort import natsorted
|
9 |
+
|
10 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
+
|
12 |
+
def get_squre_coords(coords, image, size=None, last_size=None):
|
13 |
+
y1, y2, x1, x2 = coords
|
14 |
+
w, h = x2 - x1, y2 - y1
|
15 |
+
center = (x1 + w // 2, y1 + h // 2)
|
16 |
+
if size is None:
|
17 |
+
size = (w + h) // 2
|
18 |
+
if last_size is not None:
|
19 |
+
size = (w + h) // 2
|
20 |
+
size = (size - last_size) // 5 + last_size
|
21 |
+
x1, y1 = center[0] - size // 2, center[1] - size // 2
|
22 |
+
x2, y2 = x1 + size, y1 + size
|
23 |
+
return size, [y1, y2, x1, x2]
|
24 |
+
|
25 |
+
|
26 |
+
def get_smoothened_boxes(boxes, T):
|
27 |
+
for i in range(len(boxes)):
|
28 |
+
if i + T > len(boxes):
|
29 |
+
window = boxes[len(boxes) - T :]
|
30 |
+
else:
|
31 |
+
window = boxes[i : i + T]
|
32 |
+
boxes[i] = np.mean(window, axis=0)
|
33 |
+
return boxes
|
34 |
+
|
35 |
+
|
36 |
+
def face_detect(images, pads):
|
37 |
+
detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device=device)
|
38 |
+
|
39 |
+
batch_size = 32 if device == "cuda" else 4
|
40 |
+
print("face detect batch size:", batch_size)
|
41 |
+
while 1:
|
42 |
+
predictions = []
|
43 |
+
try:
|
44 |
+
for i in tqdm(range(0, len(images), batch_size)):
|
45 |
+
predictions.extend(detector.get_detections_for_batch(np.array(images[i : i + batch_size])))
|
46 |
+
except RuntimeError:
|
47 |
+
if batch_size == 1:
|
48 |
+
raise RuntimeError("Image too big to run face detection on GPU. Please use the --resize_factor argument")
|
49 |
+
batch_size //= 2
|
50 |
+
print("Recovering from OOM error; New batch size: {}".format(batch_size))
|
51 |
+
continue
|
52 |
+
break
|
53 |
+
|
54 |
+
results = []
|
55 |
+
pady1, pady2, padx1, padx2 = pads
|
56 |
+
for rect, image in zip(predictions, images):
|
57 |
+
if rect is None:
|
58 |
+
cv2.imwrite(".temp/faulty_frame.jpg", image) # check this frame where the face was not detected.
|
59 |
+
raise ValueError("Face not detected! Ensure the video contains a face in all the frames.")
|
60 |
+
|
61 |
+
y1 = max(0, rect[1] - pady1)
|
62 |
+
y2 = min(image.shape[0], rect[3] + pady2)
|
63 |
+
x1 = max(0, rect[0] - padx1)
|
64 |
+
x2 = min(image.shape[1], rect[2] + padx2)
|
65 |
+
# y_gap, x_gap = ((y2 - y1) * 2) // 3, ((x2 - x1) * 2) // 3
|
66 |
+
y_gap, x_gap = (y2 - y1)//2, (x2 - x1)//2
|
67 |
+
coords_ = [y1 - y_gap, y2 + y_gap, x1 - x_gap, x2 + x_gap]
|
68 |
+
|
69 |
+
# smooth the coords
|
70 |
+
_, coords = get_squre_coords(coords_, image, None)
|
71 |
+
|
72 |
+
y1, y2, x1, x2 = coords
|
73 |
+
y1 = max(0, y1)
|
74 |
+
y2 = min(image.shape[0], y2)
|
75 |
+
x1 = max(0, x1)
|
76 |
+
x2 = min(image.shape[1], x2)
|
77 |
+
|
78 |
+
results.append([x1, y1, x2, y2])
|
79 |
+
|
80 |
+
print("Number of frames cropped: {}".format(len(results)))
|
81 |
+
print("First coords: {}".format(results[0]))
|
82 |
+
boxes = np.array(results)
|
83 |
+
boxes = get_smoothened_boxes(boxes, T=5)
|
84 |
+
# results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
|
85 |
+
|
86 |
+
del detector
|
87 |
+
return boxes
|
88 |
+
|
89 |
+
def add_black(imgs):
|
90 |
+
for i in range(len(imgs)):
|
91 |
+
imgs[i] = cv2.vconcat([np.zeros((100, imgs[i].shape[1], 3), dtype=np.uint8), imgs[i], np.zeros((20, imgs[i].shape[1], 3), dtype=np.uint8)])
|
92 |
+
|
93 |
+
return imgs
|
94 |
+
|
95 |
+
def preprocess(video_dir="./assets/videos", save_dir="./assets/coords"):
|
96 |
+
all_videos = natsorted(glob.glob(os.path.join(video_dir, "*.mp4")))
|
97 |
+
for video_path in all_videos:
|
98 |
+
video_stream = cv2.VideoCapture(video_path)
|
99 |
+
|
100 |
+
# print('Reading video frames...')
|
101 |
+
full_frames = []
|
102 |
+
while 1:
|
103 |
+
still_reading, frame = video_stream.read()
|
104 |
+
if not still_reading:
|
105 |
+
video_stream.release()
|
106 |
+
break
|
107 |
+
full_frames.append(frame)
|
108 |
+
print("Number of frames available for inference: " + str(len(full_frames)))
|
109 |
+
full_frames = add_black(full_frames)
|
110 |
+
# print('Face detection running...')
|
111 |
+
coords = face_detect(full_frames, pads=(0, 0, 0, 0))
|
112 |
+
np.savez_compressed(os.path.join(save_dir, os.path.basename(video_path).split(".")[0]), coords=coords)
|
113 |
+
|
114 |
+
|
115 |
+
def load_from_npz(video_name, save_dir="./assets/coords"):
|
116 |
+
npz = np.load(os.path.join(save_dir, video_name + ".npz"))
|
117 |
+
return npz["coords"]
|
118 |
+
|
119 |
+
if __name__ == "__main__":
|
120 |
+
preprocess()
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
opencv-python==4.7.0.72
|
2 |
+
imageio[ffmpeg]==2.31.1
|
3 |
+
librosa==0.8.1
|
4 |
+
moviepy==1.0.3
|
5 |
+
numpy==1.23.5
|
6 |
+
safetensors==0.3.2
|
7 |
+
torchvision==0.15.2
|
8 |
+
gradio==3.41.0
|
9 |
+
natsort
|
10 |
+
tqdm
|