[email protected] commited on
Commit
2b34e02
·
1 Parent(s): 1c08139
.gitignore ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ face_detection/face_detection.pb
31
+ assets/audios/*.wav
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98
+ __pypackages__/
99
+
100
+ # Celery stuff
101
+ celerybeat-schedule
102
+ celerybeat.pid
103
+
104
+ # SageMath parsed files
105
+ *.sage.py
106
+
107
+ # Environments
108
+ .env
109
+ .venv
110
+ env/
111
+ venv/
112
+ ENV/
113
+ env.bak/
114
+ venv.bak/
115
+
116
+ # Spyder project settings
117
+ .spyderproject
118
+ .spyproject
119
+
120
+ # Rope project settings
121
+ .ropeproject
122
+
123
+ # mkdocs documentation
124
+ /site
125
+
126
+ # mypy
127
+ .mypy_cache/
128
+ .dmypy.json
129
+ dmypy.json
130
+
131
+ # Pyre type checker
132
+ .pyre/
133
+ .vercel
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Free-View Expressive Talking Head Video Editing
3
- emoji: 🔥
4
- colorFrom: indigo
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 3.41.0
8
  app_file: app.py
 
1
  ---
2
+ title: FETE
3
+ emoji: 👀
4
+ colorFrom: blue
5
+ colorTo: gray
6
  sdk: gradio
7
  sdk_version: 3.41.0
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ from natsort import natsorted
4
+ import gradio as gr
5
+
6
+ from inference_util import init_model, infenrece
7
+ from attributtes_utils import input_pose, input_emotion, input_blink
8
+
9
+ model = init_model()
10
+
11
+
12
+ def process(input_vid, audio_path, pose_select, emotion_select, blink_select):
13
+ pose = input_pose(pose_select)
14
+ emotion = input_emotion(emotion_select)
15
+ blink = input_blink(blink_select)
16
+
17
+ print("input_vid: ", input_vid)
18
+ result = infenrece(model, os.path.join("./assets/videos/", input_vid), os.path.join("./assets/audios/", audio_path), pose, emotion, blink)
19
+ print("result: ", result)
20
+
21
+ print("finished !")
22
+
23
+ return result # , gr.Group.update(visible=True)
24
+
25
+
26
+ available_videos = natsorted(glob.glob("./assets/videos/*.mp4"))
27
+ available_videos = [os.path.basename(x) for x in available_videos]
28
+
29
+ # prepare audio
30
+ for video in available_videos:
31
+ audio = video.replace(".mp4", ".wav")
32
+ if not os.path.exists(os.path.join("./assets/audios/", audio)):
33
+ os.system(f"ffmpeg -i ./assets/videos/{video} -vn -acodec pcm_s16le -ar 16000 -ac 1 ./assets/audios/{audio}")
34
+ available_audios = natsorted(glob.glob("./assets/audios/*.wav"))
35
+ available_audios = [os.path.basename(x) for x in available_audios]
36
+
37
+
38
+
39
+ with gr.Blocks() as demo:
40
+ gr.HTML(
41
+ """
42
+ <h1 style="text-align: center; font-size: 40px; font-family: 'Times New Roman', Times, serif;">
43
+ Free-View Expressive Talking Head Video Editing
44
+ </h1>
45
+ <p style="text-align: center; font-size: 20px; font-family: 'Times New Roman', Times, serif;">
46
+ <a href="https://sky24h.github.io/websites/icassp2023_free-view_video-editing/" target="_blank">
47
+ <b>Project Page</b>
48
+ <br>
49
+ </a>
50
+ <a style="text-align: center; display:inline-block" href="https://huggingface.co/spaces/sky24h/FETE?duplicate=true">
51
+ <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg#center" alt="Duplicate Space">
52
+ </a>
53
+ </p>
54
+ """)
55
+ with gr.Column(elem_id="col-container"):
56
+ with gr.Row():
57
+ with gr.Column():
58
+ # select and preview video from a list of examples
59
+ video_preview = gr.Video(label="Video Preview", elem_id="video-preview", height=360, value="./assets/videos/sample1.mp4")
60
+ video_input = gr.Dropdown(available_videos, label="Input Video", value="sample1.mp4")
61
+ audio_preview = gr.Audio(label="Audio Preview", elem_id="audio-preview", height=360, value="./assets/audios/sample2.wav")
62
+ audio_input = gr.Dropdown(available_audios, label="Input Audio", value="sample2.wav")
63
+ pose_select = gr.Radio(["front", "left_right_shaking"], label="Pose", value="front")
64
+ emotion_select = gr.Radio(["neutral", "happy", "angry", "surprised"], label="Emotion", value="neutral")
65
+ blink_select = gr.Radio(["yes", "no"], label="Blink", value="yes")
66
+ # with gr.Row():
67
+ with gr.Column():
68
+ video_out = gr.Video(label="Video Output", elem_id="video-output", height=360)
69
+ # titile: Free-View Expressive Talking Head Video Editing
70
+
71
+ submit_btn = gr.Button("Generate video")
72
+
73
+ inputs = [video_input, audio_input, pose_select, emotion_select, blink_select]
74
+ outputs = [video_out]
75
+
76
+ video_preview_output = [video_preview]
77
+ audio_preview_output = [audio_preview]
78
+
79
+ video_input.select(lambda x: "./assets/videos/" + x, video_input, video_preview_output)
80
+ audio_input.select(lambda x: "./assets/audios/" + x, audio_input, audio_preview_output)
81
+ submit_btn.click(process, inputs, outputs)
82
+
83
+ demo.queue(max_size=12).launch()
assets/audios/audio files.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
assets/coords/coord files.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
assets/coords/sample1.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5323c1d6ed9dbda978859aed00bd6b4ec2ca122dbc556821edd4add44181c360
3
+ size 647
assets/coords/sample2.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aaf2424b58f2c32b3b55d172a4437bd73687ca2346e4928575c481fa024cf41d
3
+ size 1150
assets/coords/sample3.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb9c284ec586ffe5cb9af9e715907810ce6f07694ef4245c29bb06646a04f3ed
3
+ size 839
assets/coords/sample4.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa3cd8b488cbd368c22b8d10f072d8a55cc73c3c0e8cea6e11420f2743cc00c4
3
+ size 901
assets/coords/sample5.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ee15a8dd3b47bc036a4502ccd60a0a5c29262a5581593e8f97c11e18b389e67
3
+ size 974
assets/videos/sample1.mp4 ADDED
Binary file (184 kB). View file
 
assets/videos/sample2.mp4 ADDED
Binary file (625 kB). View file
 
assets/videos/sample3.mp4 ADDED
Binary file (374 kB). View file
 
assets/videos/sample4.mp4 ADDED
Binary file (562 kB). View file
 
assets/videos/sample5.mp4 ADDED
Binary file (698 kB). View file
 
attributtes_utils.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+
5
+ def input_pose(pose_select="front"):
6
+ step = 1
7
+ if pose_select == "front":
8
+ pose = [[0.0, 0.0, 0.0] for i in range(0, 10, step)]#-20 to 20
9
+ elif pose_select == "left_right_shaking":
10
+ pose = [[-i, 0.0, 0.0] for i in range(0, 20, step)]#0 to -20
11
+ pose += [[i - 20.0, 0.0, 0.0] for i in range(0, 40, step)] # -20 to 20
12
+ pose += [[20.0 - i, 0.0, 0.0] for i in range(0, 20, step)] # 20 to 0
13
+ pose = pose + pose
14
+ pose = pose + pose
15
+ pose = pose + pose
16
+ # pose = pose + pose[::-1]
17
+ else:
18
+ raise ValueError("pose_select Error")
19
+
20
+ return pose
21
+
22
+
23
+ EMOTIONS = ['angry', 'disgust', 'fear', 'happy', 'sad', 'surprise', 'neutral']
24
+ def input_emotion(emotion_select="neutral"):
25
+ sacle_factor = 2
26
+ if emotion_select == "neutral":
27
+ emotion = [[0.0,0.0,0.0,0.0,0.0,0.0,1.0] for _ in range(2)]#((i%50))*0.04
28
+ elif emotion_select == "happy":
29
+ emotion = [[0.0,0.0,0.0,1.0,0.0,0.0,0.0] for _ in range(2)]#((i%50))*0.04
30
+ elif emotion_select == "angry":
31
+ emotion = [[1.0,0.0,0.0,0.0,0.0,0.0,0.0] for _ in range(2)]
32
+ elif emotion_select == "surprised":
33
+ emotion = [[0.0,0.0,0.0,0.0,0.0,1.0,0.0] for _ in range(2)]
34
+ else:
35
+ raise ValueError("emotion_select Error")
36
+
37
+ return emotion * sacle_factor
38
+
39
+
40
+ def input_blink(blink_select="yes"):
41
+ if blink_select == "yes":
42
+ blink = [[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [0.8], [0.6], [0.0], [0.0], [1.0]]
43
+ blink = blink + blink
44
+ else:
45
+ blink = [[1.0] for _ in range(2)]
46
+ return blink
47
+
audio.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import librosa.filters
3
+ import numpy as np
4
+
5
+ # import tensorflow as tf
6
+ from scipy import signal
7
+ from scipy.io import wavfile
8
+
9
+ hp_num_mels = 80
10
+ hp_rescale = True
11
+ hp_rescaling_max = 0.9
12
+ hp_use_lws = False
13
+ hp_n_fft = 800
14
+ hp_hop_size = 200
15
+ hp_win_size = 800
16
+ hp_sample_rate = 16000
17
+ hp_frame_shift_ms = None
18
+ hp_signal_normalization = True
19
+ hp_allow_clipping_in_normalization = True
20
+ hp_symmetric_mels = True
21
+ hp_max_abs_value = 4.0
22
+ hp_preemphasize = True
23
+ hp_preemphasis = 0.97
24
+ hp_min_level_db = -100
25
+ hp_ref_level_db = 20
26
+ hp_fmin = 55
27
+ hp_fmax = 7600
28
+
29
+
30
+ def load_wav(path, sr):
31
+ return librosa.core.load(path, sr=sr)[0]
32
+
33
+
34
+ def save_wav(wav, path, sr):
35
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
36
+ # proposed by @dsmiller
37
+ wavfile.write(path, sr, wav.astype(np.int16))
38
+
39
+
40
+ def save_wavenet_wav(wav, path, sr):
41
+ librosa.output.write_wav(path, wav, sr=sr)
42
+
43
+
44
+ def preemphasis(wav, k, preemphasize=True):
45
+ if preemphasize:
46
+ return signal.lfilter([1, -k], [1], wav)
47
+ return wav
48
+
49
+
50
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
51
+ if inv_preemphasize:
52
+ return signal.lfilter([1], [1, -k], wav)
53
+ return wav
54
+
55
+
56
+ def get_hop_size():
57
+ hop_size = hp_hop_size
58
+ if hop_size is None:
59
+ assert hp_frame_shift_ms is not None
60
+ hop_size = int(hp_frame_shift_ms / 1000 * hp_sample_rate)
61
+ return hop_size
62
+
63
+
64
+ def linearspectrogram(wav):
65
+ D = _stft(preemphasis(wav, hp_preemphasis, hp_preemphasize))
66
+ S = _amp_to_db(np.abs(D)) - hp_ref_level_db
67
+ if hp_signal_normalization:
68
+ return _normalize(S)
69
+ return S
70
+
71
+
72
+ def melspectrogram(wav):
73
+ D = _stft(preemphasis(wav, hp_preemphasis, hp_preemphasize))
74
+ S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp_ref_level_db
75
+ if hp_signal_normalization:
76
+ return _normalize(S)
77
+ return S
78
+
79
+
80
+ def _lws_processor():
81
+ import lws
82
+
83
+ return lws.lws(hp_n_fft, get_hop_size(), fftsize=hp_win_size, mode="speech")
84
+
85
+
86
+ def _stft(y):
87
+ if hp_use_lws:
88
+ return _lws_processor(hp).stft(y).T
89
+ else:
90
+ return librosa.stft(y=y, n_fft=hp_n_fft, hop_length=get_hop_size(), win_length=hp_win_size)
91
+
92
+
93
+ ##########################################################
94
+ # Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
95
+ def num_frames(length, fsize, fshift):
96
+ """Compute number of time frames of spectrogram"""
97
+ pad = fsize - fshift
98
+ if length % fshift == 0:
99
+ M = (length + pad * 2 - fsize) // fshift + 1
100
+ else:
101
+ M = (length + pad * 2 - fsize) // fshift + 2
102
+ return M
103
+
104
+
105
+ def pad_lr(x, fsize, fshift):
106
+ """Compute left and right padding"""
107
+ M = num_frames(len(x), fsize, fshift)
108
+ pad = fsize - fshift
109
+ T = len(x) + 2 * pad
110
+ r = (M - 1) * fshift + fsize - T
111
+ return pad, pad + r
112
+
113
+
114
+ ##########################################################
115
+ # Librosa correct padding
116
+ def librosa_pad_lr(x, fsize, fshift):
117
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
118
+
119
+
120
+ # Conversions
121
+ _mel_basis = None
122
+
123
+
124
+ def _linear_to_mel(spectogram):
125
+ global _mel_basis
126
+ if _mel_basis is None:
127
+ _mel_basis = _build_mel_basis()
128
+ return np.dot(_mel_basis, spectogram)
129
+
130
+
131
+ def _build_mel_basis():
132
+ assert hp_fmax <= hp_sample_rate // 2
133
+ return librosa.filters.mel(hp_sample_rate, hp_n_fft, n_mels=hp_num_mels, fmin=hp_fmin, fmax=hp_fmax)
134
+
135
+
136
+ def _amp_to_db(x):
137
+ min_level = np.exp(hp_min_level_db / 20 * np.log(10))
138
+ return 20 * np.log10(np.maximum(min_level, x))
139
+
140
+
141
+ def _normalize(S):
142
+ if hp_allow_clipping_in_normalization:
143
+ if hp_symmetric_mels:
144
+ return np.clip(
145
+ (2 * hp_max_abs_value) * ((S - hp_min_level_db) / (-hp_min_level_db)) - hp_max_abs_value,
146
+ -hp_max_abs_value,
147
+ hp_max_abs_value,
148
+ )
149
+ else:
150
+ return np.clip(
151
+ hp_max_abs_value * ((S - hp_min_level_db) / (-hp_min_level_db)),
152
+ 0,
153
+ hp_max_abs_value,
154
+ )
155
+
156
+ assert S.max() <= 0 and S.min() - hp_min_level_db >= 0
157
+ if hp_symmetric_mels:
158
+ return (2 * hp_max_abs_value) * ((S - hp_min_level_db) / (-hp_min_level_db)) - hp_max_abs_value
159
+ else:
160
+ return hp_max_abs_value * ((S - hp_min_level_db) / (-hp_min_level_db))
checkpoints/obama-fp16.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a11a77110f573c5d379f7ce0a14b2c85be12c41723bef601c612f06d0d1f7d55
3
+ size 96452948
face_detection/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
face_detection/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ __author__ = """Adrian Bulat"""
4
+ __email__ = '[email protected]'
5
+ __version__ = '1.0.1'
6
+
7
+ from .api import FaceAlignment, LandmarksType, NetworkSize
face_detection/api.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import os
3
+ import torch
4
+ from torch.utils.model_zoo import load_url
5
+ from enum import Enum
6
+ import numpy as np
7
+ import cv2
8
+ try:
9
+ import urllib.request as request_file
10
+ except BaseException:
11
+ import urllib as request_file
12
+
13
+ from .models import FAN, ResNetDepth
14
+ from .utils import *
15
+
16
+
17
+ class LandmarksType(Enum):
18
+ """Enum class defining the type of landmarks to detect.
19
+
20
+ ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
21
+ ``_2halfD`` - this points represent the projection of the 3D points into 3D
22
+ ``_3D`` - detect the points ``(x,y,z)``` in a 3D space
23
+
24
+ """
25
+ _2D = 1
26
+ _2halfD = 2
27
+ _3D = 3
28
+
29
+
30
+ class NetworkSize(Enum):
31
+ # TINY = 1
32
+ # SMALL = 2
33
+ # MEDIUM = 3
34
+ LARGE = 4
35
+
36
+ def __new__(cls, value):
37
+ member = object.__new__(cls)
38
+ member._value_ = value
39
+ return member
40
+
41
+ def __int__(self):
42
+ return self.value
43
+
44
+ ROOT = os.path.dirname(os.path.abspath(__file__))
45
+
46
+ class FaceAlignment:
47
+ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
48
+ device='cuda', flip_input=False, face_detector='sfd', verbose=False):
49
+ self.device = device
50
+ self.flip_input = flip_input
51
+ self.landmarks_type = landmarks_type
52
+ self.verbose = verbose
53
+
54
+ network_size = int(network_size)
55
+
56
+ if 'cuda' in device:
57
+ torch.backends.cudnn.benchmark = True
58
+
59
+ # Get the face detector
60
+ face_detector_module = __import__('face_detection.detection.' + face_detector,
61
+ globals(), locals(), [face_detector], 0)
62
+ self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
63
+
64
+ def get_detections_for_batch(self, images):
65
+ images = images[..., ::-1]
66
+ detected_faces = self.face_detector.detect_from_batch(images.copy())
67
+ results = []
68
+
69
+ for i, d in enumerate(detected_faces):
70
+ if len(d) == 0:
71
+ results.append(None)
72
+ continue
73
+ d = d[0]
74
+ d = np.clip(d, 0, None)
75
+
76
+ x1, y1, x2, y2 = map(int, d[:-1])
77
+ results.append((x1, y1, x2, y2))
78
+
79
+ return results
face_detection/detection/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .core import FaceDetector
face_detection/detection/core.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import glob
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+ import torch
6
+ import cv2
7
+
8
+
9
+ class FaceDetector(object):
10
+ """An abstract class representing a face detector.
11
+
12
+ Any other face detection implementation must subclass it. All subclasses
13
+ must implement ``detect_from_image``, that return a list of detected
14
+ bounding boxes. Optionally, for speed considerations detect from path is
15
+ recommended.
16
+ """
17
+
18
+ def __init__(self, device, verbose):
19
+ self.device = device
20
+ self.verbose = verbose
21
+
22
+ if verbose:
23
+ if 'cpu' in device:
24
+ logger = logging.getLogger(__name__)
25
+ logger.warning("Detection running on CPU, this may be potentially slow.")
26
+
27
+ if 'cpu' not in device and 'cuda' not in device:
28
+ if verbose:
29
+ logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
30
+ raise ValueError
31
+
32
+ def detect_from_image(self, tensor_or_path):
33
+ """Detects faces in a given image.
34
+
35
+ This function detects the faces present in a provided BGR(usually)
36
+ image. The input can be either the image itself or the path to it.
37
+
38
+ Arguments:
39
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
40
+ to an image or the image itself.
41
+
42
+ Example::
43
+
44
+ >>> path_to_image = 'data/image_01.jpg'
45
+ ... detected_faces = detect_from_image(path_to_image)
46
+ [A list of bounding boxes (x1, y1, x2, y2)]
47
+ >>> image = cv2.imread(path_to_image)
48
+ ... detected_faces = detect_from_image(image)
49
+ [A list of bounding boxes (x1, y1, x2, y2)]
50
+
51
+ """
52
+ raise NotImplementedError
53
+
54
+ def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
55
+ """Detects faces from all the images present in a given directory.
56
+
57
+ Arguments:
58
+ path {string} -- a string containing a path that points to the folder containing the images
59
+
60
+ Keyword Arguments:
61
+ extensions {list} -- list of string containing the extensions to be
62
+ consider in the following format: ``.extension_name`` (default:
63
+ {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
64
+ folder recursively (default: {False}) show_progress_bar {bool} --
65
+ display a progressbar (default: {True})
66
+
67
+ Example:
68
+ >>> directory = 'data'
69
+ ... detected_faces = detect_from_directory(directory)
70
+ {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
71
+
72
+ """
73
+ if self.verbose:
74
+ logger = logging.getLogger(__name__)
75
+
76
+ if len(extensions) == 0:
77
+ if self.verbose:
78
+ logger.error("Expected at list one extension, but none was received.")
79
+ raise ValueError
80
+
81
+ if self.verbose:
82
+ logger.info("Constructing the list of images.")
83
+ additional_pattern = '/**/*' if recursive else '/*'
84
+ files = []
85
+ for extension in extensions:
86
+ files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
87
+
88
+ if self.verbose:
89
+ logger.info("Finished searching for images. %s images found", len(files))
90
+ logger.info("Preparing to run the detection.")
91
+
92
+ predictions = {}
93
+ for image_path in tqdm(files, disable=not show_progress_bar):
94
+ if self.verbose:
95
+ logger.info("Running the face detector on image: %s", image_path)
96
+ predictions[image_path] = self.detect_from_image(image_path)
97
+
98
+ if self.verbose:
99
+ logger.info("The detector was successfully run on all %s images", len(files))
100
+
101
+ return predictions
102
+
103
+ @property
104
+ def reference_scale(self):
105
+ raise NotImplementedError
106
+
107
+ @property
108
+ def reference_x_shift(self):
109
+ raise NotImplementedError
110
+
111
+ @property
112
+ def reference_y_shift(self):
113
+ raise NotImplementedError
114
+
115
+ @staticmethod
116
+ def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
117
+ """Convert path (represented as a string) or torch.tensor to a numpy.ndarray
118
+
119
+ Arguments:
120
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
121
+ """
122
+ if isinstance(tensor_or_path, str):
123
+ return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
124
+ elif torch.is_tensor(tensor_or_path):
125
+ # Call cpu in case its coming from cuda
126
+ return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
127
+ elif isinstance(tensor_or_path, np.ndarray):
128
+ return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
129
+ else:
130
+ raise TypeError
face_detection/detection/sfd/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .sfd_detector import SFDDetector as FaceDetector
face_detection/detection/sfd/bbox.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import os
3
+ import sys
4
+ import cv2
5
+ import random
6
+ import datetime
7
+ import time
8
+ import math
9
+ import argparse
10
+ import numpy as np
11
+ import torch
12
+
13
+ try:
14
+ from iou import IOU
15
+ except BaseException:
16
+ # IOU cython speedup 10x
17
+ def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
18
+ sa = abs((ax2 - ax1) * (ay2 - ay1))
19
+ sb = abs((bx2 - bx1) * (by2 - by1))
20
+ x1, y1 = max(ax1, bx1), max(ay1, by1)
21
+ x2, y2 = min(ax2, bx2), min(ay2, by2)
22
+ w = x2 - x1
23
+ h = y2 - y1
24
+ if w < 0 or h < 0:
25
+ return 0.0
26
+ else:
27
+ return 1.0 * w * h / (sa + sb - w * h)
28
+
29
+
30
+ def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
31
+ xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
32
+ dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
33
+ dw, dh = math.log(ww / aww), math.log(hh / ahh)
34
+ return dx, dy, dw, dh
35
+
36
+
37
+ def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
38
+ xc, yc = dx * aww + axc, dy * ahh + ayc
39
+ ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
40
+ x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
41
+ return x1, y1, x2, y2
42
+
43
+
44
+ def nms(dets, thresh):
45
+ if 0 == len(dets):
46
+ return []
47
+ x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
48
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
49
+ order = scores.argsort()[::-1]
50
+
51
+ keep = []
52
+ while order.size > 0:
53
+ i = order[0]
54
+ keep.append(i)
55
+ xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
56
+ xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
57
+
58
+ w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
59
+ ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
60
+
61
+ inds = np.where(ovr <= thresh)[0]
62
+ order = order[inds + 1]
63
+
64
+ return keep
65
+
66
+
67
+ def encode(matched, priors, variances):
68
+ """Encode the variances from the priorbox layers into the ground truth boxes
69
+ we have matched (based on jaccard overlap) with the prior boxes.
70
+ Args:
71
+ matched: (tensor) Coords of ground truth for each prior in point-form
72
+ Shape: [num_priors, 4].
73
+ priors: (tensor) Prior boxes in center-offset form
74
+ Shape: [num_priors,4].
75
+ variances: (list[float]) Variances of priorboxes
76
+ Return:
77
+ encoded boxes (tensor), Shape: [num_priors, 4]
78
+ """
79
+
80
+ # dist b/t match center and prior's center
81
+ g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
82
+ # encode variance
83
+ g_cxcy /= (variances[0] * priors[:, 2:])
84
+ # match wh / prior wh
85
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
86
+ g_wh = torch.log(g_wh) / variances[1]
87
+ # return target for smooth_l1_loss
88
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
89
+
90
+
91
+ def decode(loc, priors, variances):
92
+ """Decode locations from predictions using priors to undo
93
+ the encoding we did for offset regression at train time.
94
+ Args:
95
+ loc (tensor): location predictions for loc layers,
96
+ Shape: [num_priors,4]
97
+ priors (tensor): Prior boxes in center-offset form.
98
+ Shape: [num_priors,4].
99
+ variances: (list[float]) Variances of priorboxes
100
+ Return:
101
+ decoded bounding box predictions
102
+ """
103
+
104
+ boxes = torch.cat((
105
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
106
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
107
+ boxes[:, :2] -= boxes[:, 2:] / 2
108
+ boxes[:, 2:] += boxes[:, :2]
109
+ return boxes
110
+
111
+ def batch_decode(loc, priors, variances):
112
+ """Decode locations from predictions using priors to undo
113
+ the encoding we did for offset regression at train time.
114
+ Args:
115
+ loc (tensor): location predictions for loc layers,
116
+ Shape: [num_priors,4]
117
+ priors (tensor): Prior boxes in center-offset form.
118
+ Shape: [num_priors,4].
119
+ variances: (list[float]) Variances of priorboxes
120
+ Return:
121
+ decoded bounding box predictions
122
+ """
123
+
124
+ boxes = torch.cat((
125
+ priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
126
+ priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
127
+ boxes[:, :, :2] -= boxes[:, :, 2:] / 2
128
+ boxes[:, :, 2:] += boxes[:, :, :2]
129
+ return boxes
face_detection/detection/sfd/detect.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ import os
5
+ import sys
6
+ import cv2
7
+ import random
8
+ import datetime
9
+ import math
10
+ import argparse
11
+ import numpy as np
12
+
13
+ import scipy.io as sio
14
+ import zipfile
15
+ from .net_s3fd import s3fd
16
+ from .bbox import *
17
+
18
+
19
+ def detect(net, img, device):
20
+ img = img - np.array([104, 117, 123])
21
+ img = img.transpose(2, 0, 1)
22
+ img = img.reshape((1,) + img.shape)
23
+
24
+ if 'cuda' in device:
25
+ torch.backends.cudnn.benchmark = True
26
+
27
+ img = torch.from_numpy(img).float().to(device)
28
+ BB, CC, HH, WW = img.size()
29
+ with torch.no_grad():
30
+ olist = net(img)
31
+
32
+ bboxlist = []
33
+ for i in range(len(olist) // 2):
34
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
35
+ olist = [oelem.data.cpu() for oelem in olist]
36
+ for i in range(len(olist) // 2):
37
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
38
+ FB, FC, FH, FW = ocls.size() # feature map size
39
+ stride = 2**(i + 2) # 4,8,16,32,64,128
40
+ anchor = stride * 4
41
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
42
+ for Iindex, hindex, windex in poss:
43
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
44
+ score = ocls[0, 1, hindex, windex]
45
+ loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
46
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
47
+ variances = [0.1, 0.2]
48
+ box = decode(loc, priors, variances)
49
+ x1, y1, x2, y2 = box[0] * 1.0
50
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
51
+ bboxlist.append([x1, y1, x2, y2, score])
52
+ bboxlist = np.array(bboxlist)
53
+ if 0 == len(bboxlist):
54
+ bboxlist = np.zeros((1, 5))
55
+
56
+ return bboxlist
57
+
58
+ def batch_detect(net, imgs, device):
59
+ imgs = imgs - np.array([104, 117, 123])
60
+ imgs = imgs.transpose(0, 3, 1, 2)
61
+
62
+ if 'cuda' in device:
63
+ torch.backends.cudnn.benchmark = True
64
+
65
+ imgs = torch.from_numpy(imgs).float().to(device)
66
+ BB, CC, HH, WW = imgs.size()
67
+ with torch.no_grad():
68
+ olist = net(imgs)
69
+
70
+ bboxlist = []
71
+ for i in range(len(olist) // 2):
72
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
73
+ olist = [oelem.data.cpu() for oelem in olist]
74
+ for i in range(len(olist) // 2):
75
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
76
+ FB, FC, FH, FW = ocls.size() # feature map size
77
+ stride = 2**(i + 2) # 4,8,16,32,64,128
78
+ anchor = stride * 4
79
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
80
+ for Iindex, hindex, windex in poss:
81
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
82
+ score = ocls[:, 1, hindex, windex]
83
+ loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
84
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
85
+ variances = [0.1, 0.2]
86
+ box = batch_decode(loc, priors, variances)
87
+ box = box[:, 0] * 1.0
88
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
89
+ bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
90
+ bboxlist = np.array(bboxlist)
91
+ if 0 == len(bboxlist):
92
+ bboxlist = np.zeros((1, BB, 5))
93
+
94
+ return bboxlist
95
+
96
+ def flip_detect(net, img, device):
97
+ img = cv2.flip(img, 1)
98
+ b = detect(net, img, device)
99
+
100
+ bboxlist = np.zeros(b.shape)
101
+ bboxlist[:, 0] = img.shape[1] - b[:, 2]
102
+ bboxlist[:, 1] = b[:, 1]
103
+ bboxlist[:, 2] = img.shape[1] - b[:, 0]
104
+ bboxlist[:, 3] = b[:, 3]
105
+ bboxlist[:, 4] = b[:, 4]
106
+ return bboxlist
107
+
108
+
109
+ def pts_to_bb(pts):
110
+ min_x, min_y = np.min(pts, axis=0)
111
+ max_x, max_y = np.max(pts, axis=0)
112
+ return np.array([min_x, min_y, max_x, max_y])
face_detection/detection/sfd/net_s3fd.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class L2Norm(nn.Module):
7
+ def __init__(self, n_channels, scale=1.0):
8
+ super(L2Norm, self).__init__()
9
+ self.n_channels = n_channels
10
+ self.scale = scale
11
+ self.eps = 1e-10
12
+ self.weight = nn.Parameter(torch.Tensor(self.n_channels))
13
+ self.weight.data *= 0.0
14
+ self.weight.data += self.scale
15
+
16
+ def forward(self, x):
17
+ norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
18
+ x = x / norm * self.weight.view(1, -1, 1, 1)
19
+ return x
20
+
21
+
22
+ class s3fd(nn.Module):
23
+ def __init__(self):
24
+ super(s3fd, self).__init__()
25
+ self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
26
+ self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
27
+
28
+ self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
29
+ self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
30
+
31
+ self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
32
+ self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
33
+ self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
34
+
35
+ self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
36
+ self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
37
+ self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
38
+
39
+ self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
40
+ self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
41
+ self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
42
+
43
+ self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
44
+ self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
45
+
46
+ self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
47
+ self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
48
+
49
+ self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
50
+ self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
51
+
52
+ self.conv3_3_norm = L2Norm(256, scale=10)
53
+ self.conv4_3_norm = L2Norm(512, scale=8)
54
+ self.conv5_3_norm = L2Norm(512, scale=5)
55
+
56
+ self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
57
+ self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
58
+ self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
59
+ self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
60
+ self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
61
+ self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
62
+
63
+ self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
64
+ self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
65
+ self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
66
+ self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
67
+ self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
68
+ self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
69
+
70
+ def forward(self, x):
71
+ h = F.relu(self.conv1_1(x))
72
+ h = F.relu(self.conv1_2(h))
73
+ h = F.max_pool2d(h, 2, 2)
74
+
75
+ h = F.relu(self.conv2_1(h))
76
+ h = F.relu(self.conv2_2(h))
77
+ h = F.max_pool2d(h, 2, 2)
78
+
79
+ h = F.relu(self.conv3_1(h))
80
+ h = F.relu(self.conv3_2(h))
81
+ h = F.relu(self.conv3_3(h))
82
+ f3_3 = h
83
+ h = F.max_pool2d(h, 2, 2)
84
+
85
+ h = F.relu(self.conv4_1(h))
86
+ h = F.relu(self.conv4_2(h))
87
+ h = F.relu(self.conv4_3(h))
88
+ f4_3 = h
89
+ h = F.max_pool2d(h, 2, 2)
90
+
91
+ h = F.relu(self.conv5_1(h))
92
+ h = F.relu(self.conv5_2(h))
93
+ h = F.relu(self.conv5_3(h))
94
+ f5_3 = h
95
+ h = F.max_pool2d(h, 2, 2)
96
+
97
+ h = F.relu(self.fc6(h))
98
+ h = F.relu(self.fc7(h))
99
+ ffc7 = h
100
+ h = F.relu(self.conv6_1(h))
101
+ h = F.relu(self.conv6_2(h))
102
+ f6_2 = h
103
+ h = F.relu(self.conv7_1(h))
104
+ h = F.relu(self.conv7_2(h))
105
+ f7_2 = h
106
+
107
+ f3_3 = self.conv3_3_norm(f3_3)
108
+ f4_3 = self.conv4_3_norm(f4_3)
109
+ f5_3 = self.conv5_3_norm(f5_3)
110
+
111
+ cls1 = self.conv3_3_norm_mbox_conf(f3_3)
112
+ reg1 = self.conv3_3_norm_mbox_loc(f3_3)
113
+ cls2 = self.conv4_3_norm_mbox_conf(f4_3)
114
+ reg2 = self.conv4_3_norm_mbox_loc(f4_3)
115
+ cls3 = self.conv5_3_norm_mbox_conf(f5_3)
116
+ reg3 = self.conv5_3_norm_mbox_loc(f5_3)
117
+ cls4 = self.fc7_mbox_conf(ffc7)
118
+ reg4 = self.fc7_mbox_loc(ffc7)
119
+ cls5 = self.conv6_2_mbox_conf(f6_2)
120
+ reg5 = self.conv6_2_mbox_loc(f6_2)
121
+ cls6 = self.conv7_2_mbox_conf(f7_2)
122
+ reg6 = self.conv7_2_mbox_loc(f7_2)
123
+
124
+ # max-out background label
125
+ chunk = torch.chunk(cls1, 4, 1)
126
+ bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
127
+ cls1 = torch.cat([bmax, chunk[3]], dim=1)
128
+
129
+ return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
face_detection/detection/sfd/s3fd.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:619a31681264d3f7f7fc7a16a42cbbe8b23f31a256f75a366e5a1bcd59b33543
3
+ size 89843225
face_detection/detection/sfd/sfd_detector.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ from torch.utils.model_zoo import load_url
4
+
5
+ from ..core import FaceDetector
6
+
7
+ from .net_s3fd import s3fd
8
+ from .bbox import *
9
+ from .detect import *
10
+
11
+ models_urls = {
12
+ 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
13
+ }
14
+
15
+
16
+ class SFDDetector(FaceDetector):
17
+ def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
18
+ super(SFDDetector, self).__init__(device, verbose)
19
+
20
+ # Initialise the face detector
21
+ if not os.path.isfile(path_to_detector):
22
+ model_weights = load_url(models_urls['s3fd'])
23
+ else:
24
+ model_weights = torch.load(path_to_detector)
25
+
26
+ self.face_detector = s3fd()
27
+ self.face_detector.load_state_dict(model_weights)
28
+ self.face_detector.to(device)
29
+ self.face_detector.eval()
30
+
31
+ def detect_from_image(self, tensor_or_path):
32
+ image = self.tensor_or_path_to_ndarray(tensor_or_path)
33
+
34
+ bboxlist = detect(self.face_detector, image, device=self.device)
35
+ keep = nms(bboxlist, 0.3)
36
+ bboxlist = bboxlist[keep, :]
37
+ bboxlist = [x for x in bboxlist if x[-1] > 0.5]
38
+
39
+ return bboxlist
40
+
41
+ def detect_from_batch(self, images):
42
+ bboxlists = batch_detect(self.face_detector, images, device=self.device)
43
+ keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
44
+ bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
45
+ bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
46
+
47
+ return bboxlists
48
+
49
+ @property
50
+ def reference_scale(self):
51
+ return 195
52
+
53
+ @property
54
+ def reference_x_shift(self):
55
+ return 0
56
+
57
+ @property
58
+ def reference_y_shift(self):
59
+ return 0
face_detection/models.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+
7
+ def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
8
+ "3x3 convolution with padding"
9
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
10
+ stride=strd, padding=padding, bias=bias)
11
+
12
+
13
+ class ConvBlock(nn.Module):
14
+ def __init__(self, in_planes, out_planes):
15
+ super(ConvBlock, self).__init__()
16
+ self.bn1 = nn.BatchNorm2d(in_planes)
17
+ self.conv1 = conv3x3(in_planes, int(out_planes / 2))
18
+ self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
19
+ self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
20
+ self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
21
+ self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
22
+
23
+ if in_planes != out_planes:
24
+ self.downsample = nn.Sequential(
25
+ nn.BatchNorm2d(in_planes),
26
+ nn.ReLU(True),
27
+ nn.Conv2d(in_planes, out_planes,
28
+ kernel_size=1, stride=1, bias=False),
29
+ )
30
+ else:
31
+ self.downsample = None
32
+
33
+ def forward(self, x):
34
+ residual = x
35
+
36
+ out1 = self.bn1(x)
37
+ out1 = F.relu(out1, True)
38
+ out1 = self.conv1(out1)
39
+
40
+ out2 = self.bn2(out1)
41
+ out2 = F.relu(out2, True)
42
+ out2 = self.conv2(out2)
43
+
44
+ out3 = self.bn3(out2)
45
+ out3 = F.relu(out3, True)
46
+ out3 = self.conv3(out3)
47
+
48
+ out3 = torch.cat((out1, out2, out3), 1)
49
+
50
+ if self.downsample is not None:
51
+ residual = self.downsample(residual)
52
+
53
+ out3 += residual
54
+
55
+ return out3
56
+
57
+
58
+ class Bottleneck(nn.Module):
59
+
60
+ expansion = 4
61
+
62
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
63
+ super(Bottleneck, self).__init__()
64
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
65
+ self.bn1 = nn.BatchNorm2d(planes)
66
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
67
+ padding=1, bias=False)
68
+ self.bn2 = nn.BatchNorm2d(planes)
69
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
70
+ self.bn3 = nn.BatchNorm2d(planes * 4)
71
+ self.relu = nn.ReLU(inplace=True)
72
+ self.downsample = downsample
73
+ self.stride = stride
74
+
75
+ def forward(self, x):
76
+ residual = x
77
+
78
+ out = self.conv1(x)
79
+ out = self.bn1(out)
80
+ out = self.relu(out)
81
+
82
+ out = self.conv2(out)
83
+ out = self.bn2(out)
84
+ out = self.relu(out)
85
+
86
+ out = self.conv3(out)
87
+ out = self.bn3(out)
88
+
89
+ if self.downsample is not None:
90
+ residual = self.downsample(x)
91
+
92
+ out += residual
93
+ out = self.relu(out)
94
+
95
+ return out
96
+
97
+
98
+ class HourGlass(nn.Module):
99
+ def __init__(self, num_modules, depth, num_features):
100
+ super(HourGlass, self).__init__()
101
+ self.num_modules = num_modules
102
+ self.depth = depth
103
+ self.features = num_features
104
+
105
+ self._generate_network(self.depth)
106
+
107
+ def _generate_network(self, level):
108
+ self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
109
+
110
+ self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
111
+
112
+ if level > 1:
113
+ self._generate_network(level - 1)
114
+ else:
115
+ self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
116
+
117
+ self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
118
+
119
+ def _forward(self, level, inp):
120
+ # Upper branch
121
+ up1 = inp
122
+ up1 = self._modules['b1_' + str(level)](up1)
123
+
124
+ # Lower branch
125
+ low1 = F.avg_pool2d(inp, 2, stride=2)
126
+ low1 = self._modules['b2_' + str(level)](low1)
127
+
128
+ if level > 1:
129
+ low2 = self._forward(level - 1, low1)
130
+ else:
131
+ low2 = low1
132
+ low2 = self._modules['b2_plus_' + str(level)](low2)
133
+
134
+ low3 = low2
135
+ low3 = self._modules['b3_' + str(level)](low3)
136
+
137
+ up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
138
+
139
+ return up1 + up2
140
+
141
+ def forward(self, x):
142
+ return self._forward(self.depth, x)
143
+
144
+
145
+ class FAN(nn.Module):
146
+
147
+ def __init__(self, num_modules=1):
148
+ super(FAN, self).__init__()
149
+ self.num_modules = num_modules
150
+
151
+ # Base part
152
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
153
+ self.bn1 = nn.BatchNorm2d(64)
154
+ self.conv2 = ConvBlock(64, 128)
155
+ self.conv3 = ConvBlock(128, 128)
156
+ self.conv4 = ConvBlock(128, 256)
157
+
158
+ # Stacking part
159
+ for hg_module in range(self.num_modules):
160
+ self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
161
+ self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
162
+ self.add_module('conv_last' + str(hg_module),
163
+ nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
164
+ self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
165
+ self.add_module('l' + str(hg_module), nn.Conv2d(256,
166
+ 68, kernel_size=1, stride=1, padding=0))
167
+
168
+ if hg_module < self.num_modules - 1:
169
+ self.add_module(
170
+ 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
171
+ self.add_module('al' + str(hg_module), nn.Conv2d(68,
172
+ 256, kernel_size=1, stride=1, padding=0))
173
+
174
+ def forward(self, x):
175
+ x = F.relu(self.bn1(self.conv1(x)), True)
176
+ x = F.avg_pool2d(self.conv2(x), 2, stride=2)
177
+ x = self.conv3(x)
178
+ x = self.conv4(x)
179
+
180
+ previous = x
181
+
182
+ outputs = []
183
+ for i in range(self.num_modules):
184
+ hg = self._modules['m' + str(i)](previous)
185
+
186
+ ll = hg
187
+ ll = self._modules['top_m_' + str(i)](ll)
188
+
189
+ ll = F.relu(self._modules['bn_end' + str(i)]
190
+ (self._modules['conv_last' + str(i)](ll)), True)
191
+
192
+ # Predict heatmaps
193
+ tmp_out = self._modules['l' + str(i)](ll)
194
+ outputs.append(tmp_out)
195
+
196
+ if i < self.num_modules - 1:
197
+ ll = self._modules['bl' + str(i)](ll)
198
+ tmp_out_ = self._modules['al' + str(i)](tmp_out)
199
+ previous = previous + ll + tmp_out_
200
+
201
+ return outputs
202
+
203
+
204
+ class ResNetDepth(nn.Module):
205
+
206
+ def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
207
+ self.inplanes = 64
208
+ super(ResNetDepth, self).__init__()
209
+ self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
210
+ bias=False)
211
+ self.bn1 = nn.BatchNorm2d(64)
212
+ self.relu = nn.ReLU(inplace=True)
213
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
214
+ self.layer1 = self._make_layer(block, 64, layers[0])
215
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
216
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
217
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
218
+ self.avgpool = nn.AvgPool2d(7)
219
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
220
+
221
+ for m in self.modules():
222
+ if isinstance(m, nn.Conv2d):
223
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
224
+ m.weight.data.normal_(0, math.sqrt(2. / n))
225
+ elif isinstance(m, nn.BatchNorm2d):
226
+ m.weight.data.fill_(1)
227
+ m.bias.data.zero_()
228
+
229
+ def _make_layer(self, block, planes, blocks, stride=1):
230
+ downsample = None
231
+ if stride != 1 or self.inplanes != planes * block.expansion:
232
+ downsample = nn.Sequential(
233
+ nn.Conv2d(self.inplanes, planes * block.expansion,
234
+ kernel_size=1, stride=stride, bias=False),
235
+ nn.BatchNorm2d(planes * block.expansion),
236
+ )
237
+
238
+ layers = []
239
+ layers.append(block(self.inplanes, planes, stride, downsample))
240
+ self.inplanes = planes * block.expansion
241
+ for i in range(1, blocks):
242
+ layers.append(block(self.inplanes, planes))
243
+
244
+ return nn.Sequential(*layers)
245
+
246
+ def forward(self, x):
247
+ x = self.conv1(x)
248
+ x = self.bn1(x)
249
+ x = self.relu(x)
250
+ x = self.maxpool(x)
251
+
252
+ x = self.layer1(x)
253
+ x = self.layer2(x)
254
+ x = self.layer3(x)
255
+ x = self.layer4(x)
256
+
257
+ x = self.avgpool(x)
258
+ x = x.view(x.size(0), -1)
259
+ x = self.fc(x)
260
+
261
+ return x
face_detection/utils.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function
2
+ import os
3
+ import sys
4
+ import time
5
+ import torch
6
+ import math
7
+ import numpy as np
8
+ import cv2
9
+
10
+
11
+ def _gaussian(
12
+ size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
13
+ height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
14
+ mean_vert=0.5):
15
+ # handle some defaults
16
+ if width is None:
17
+ width = size
18
+ if height is None:
19
+ height = size
20
+ if sigma_horz is None:
21
+ sigma_horz = sigma
22
+ if sigma_vert is None:
23
+ sigma_vert = sigma
24
+ center_x = mean_horz * width + 0.5
25
+ center_y = mean_vert * height + 0.5
26
+ gauss = np.empty((height, width), dtype=np.float32)
27
+ # generate kernel
28
+ for i in range(height):
29
+ for j in range(width):
30
+ gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
31
+ sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
32
+ if normalize:
33
+ gauss = gauss / np.sum(gauss)
34
+ return gauss
35
+
36
+
37
+ def draw_gaussian(image, point, sigma):
38
+ # Check if the gaussian is inside
39
+ ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
40
+ br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
41
+ if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
42
+ return image
43
+ size = 6 * sigma + 1
44
+ g = _gaussian(size)
45
+ g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
46
+ g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
47
+ img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
48
+ img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
49
+ assert (g_x[0] > 0 and g_y[1] > 0)
50
+ image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
51
+ ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
52
+ image[image > 1] = 1
53
+ return image
54
+
55
+
56
+ def transform(point, center, scale, resolution, invert=False):
57
+ """Generate and affine transformation matrix.
58
+
59
+ Given a set of points, a center, a scale and a targer resolution, the
60
+ function generates and affine transformation matrix. If invert is ``True``
61
+ it will produce the inverse transformation.
62
+
63
+ Arguments:
64
+ point {torch.tensor} -- the input 2D point
65
+ center {torch.tensor or numpy.array} -- the center around which to perform the transformations
66
+ scale {float} -- the scale of the face/object
67
+ resolution {float} -- the output resolution
68
+
69
+ Keyword Arguments:
70
+ invert {bool} -- define wherever the function should produce the direct or the
71
+ inverse transformation matrix (default: {False})
72
+ """
73
+ _pt = torch.ones(3)
74
+ _pt[0] = point[0]
75
+ _pt[1] = point[1]
76
+
77
+ h = 200.0 * scale
78
+ t = torch.eye(3)
79
+ t[0, 0] = resolution / h
80
+ t[1, 1] = resolution / h
81
+ t[0, 2] = resolution * (-center[0] / h + 0.5)
82
+ t[1, 2] = resolution * (-center[1] / h + 0.5)
83
+
84
+ if invert:
85
+ t = torch.inverse(t)
86
+
87
+ new_point = (torch.matmul(t, _pt))[0:2]
88
+
89
+ return new_point.int()
90
+
91
+
92
+ def crop(image, center, scale, resolution=256.0):
93
+ """Center crops an image or set of heatmaps
94
+
95
+ Arguments:
96
+ image {numpy.array} -- an rgb image
97
+ center {numpy.array} -- the center of the object, usually the same as of the bounding box
98
+ scale {float} -- scale of the face
99
+
100
+ Keyword Arguments:
101
+ resolution {float} -- the size of the output cropped image (default: {256.0})
102
+
103
+ Returns:
104
+ [type] -- [description]
105
+ """ # Crop around the center point
106
+ """ Crops the image around the center. Input is expected to be an np.ndarray """
107
+ ul = transform([1, 1], center, scale, resolution, True)
108
+ br = transform([resolution, resolution], center, scale, resolution, True)
109
+ # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
110
+ if image.ndim > 2:
111
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0],
112
+ image.shape[2]], dtype=np.int32)
113
+ newImg = np.zeros(newDim, dtype=np.uint8)
114
+ else:
115
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int)
116
+ newImg = np.zeros(newDim, dtype=np.uint8)
117
+ ht = image.shape[0]
118
+ wd = image.shape[1]
119
+ newX = np.array(
120
+ [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
121
+ newY = np.array(
122
+ [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
123
+ oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
124
+ oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
125
+ newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
126
+ ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
127
+ newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
128
+ interpolation=cv2.INTER_LINEAR)
129
+ return newImg
130
+
131
+
132
+ def get_preds_fromhm(hm, center=None, scale=None):
133
+ """Obtain (x,y) coordinates given a set of N heatmaps. If the center
134
+ and the scale is provided the function will return the points also in
135
+ the original coordinate frame.
136
+
137
+ Arguments:
138
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
139
+
140
+ Keyword Arguments:
141
+ center {torch.tensor} -- the center of the bounding box (default: {None})
142
+ scale {float} -- face scale (default: {None})
143
+ """
144
+ max, idx = torch.max(
145
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
146
+ idx += 1
147
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
148
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
149
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
150
+
151
+ for i in range(preds.size(0)):
152
+ for j in range(preds.size(1)):
153
+ hm_ = hm[i, j, :]
154
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
155
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
156
+ diff = torch.FloatTensor(
157
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
158
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
159
+ preds[i, j].add_(diff.sign_().mul_(.25))
160
+
161
+ preds.add_(-.5)
162
+
163
+ preds_orig = torch.zeros(preds.size())
164
+ if center is not None and scale is not None:
165
+ for i in range(hm.size(0)):
166
+ for j in range(hm.size(1)):
167
+ preds_orig[i, j] = transform(
168
+ preds[i, j], center, scale, hm.size(2), True)
169
+
170
+ return preds, preds_orig
171
+
172
+ def get_preds_fromhm_batch(hm, centers=None, scales=None):
173
+ """Obtain (x,y) coordinates given a set of N heatmaps. If the centers
174
+ and the scales is provided the function will return the points also in
175
+ the original coordinate frame.
176
+
177
+ Arguments:
178
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
179
+
180
+ Keyword Arguments:
181
+ centers {torch.tensor} -- the centers of the bounding box (default: {None})
182
+ scales {float} -- face scales (default: {None})
183
+ """
184
+ max, idx = torch.max(
185
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
186
+ idx += 1
187
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
188
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
189
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
190
+
191
+ for i in range(preds.size(0)):
192
+ for j in range(preds.size(1)):
193
+ hm_ = hm[i, j, :]
194
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
195
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
196
+ diff = torch.FloatTensor(
197
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
198
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
199
+ preds[i, j].add_(diff.sign_().mul_(.25))
200
+
201
+ preds.add_(-.5)
202
+
203
+ preds_orig = torch.zeros(preds.size())
204
+ if centers is not None and scales is not None:
205
+ for i in range(hm.size(0)):
206
+ for j in range(hm.size(1)):
207
+ preds_orig[i, j] = transform(
208
+ preds[i, j], centers[i], scales[i], hm.size(2), True)
209
+
210
+ return preds, preds_orig
211
+
212
+ def shuffle_lr(parts, pairs=None):
213
+ """Shuffle the points left-right according to the axis of symmetry
214
+ of the object.
215
+
216
+ Arguments:
217
+ parts {torch.tensor} -- a 3D or 4D object containing the
218
+ heatmaps.
219
+
220
+ Keyword Arguments:
221
+ pairs {list of integers} -- [order of the flipped points] (default: {None})
222
+ """
223
+ if pairs is None:
224
+ pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
225
+ 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
226
+ 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
227
+ 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
228
+ 62, 61, 60, 67, 66, 65]
229
+ if parts.ndimension() == 3:
230
+ parts = parts[pairs, ...]
231
+ else:
232
+ parts = parts[:, pairs, ...]
233
+
234
+ return parts
235
+
236
+
237
+ def flip(tensor, is_label=False):
238
+ """Flip an image or a set of heatmaps left-right
239
+
240
+ Arguments:
241
+ tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
242
+
243
+ Keyword Arguments:
244
+ is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False})
245
+ """
246
+ if not torch.is_tensor(tensor):
247
+ tensor = torch.from_numpy(tensor)
248
+
249
+ if is_label:
250
+ tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
251
+ else:
252
+ tensor = tensor.flip(tensor.ndimension() - 1)
253
+
254
+ return tensor
255
+
256
+ # From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
257
+
258
+
259
+ def appdata_dir(appname=None, roaming=False):
260
+ """ appdata_dir(appname=None, roaming=False)
261
+
262
+ Get the path to the application directory, where applications are allowed
263
+ to write user specific files (e.g. configurations). For non-user specific
264
+ data, consider using common_appdata_dir().
265
+ If appname is given, a subdir is appended (and created if necessary).
266
+ If roaming is True, will prefer a roaming directory (Windows Vista/7).
267
+ """
268
+
269
+ # Define default user directory
270
+ userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
271
+ if userDir is None:
272
+ userDir = os.path.expanduser('~')
273
+ if not os.path.isdir(userDir): # pragma: no cover
274
+ userDir = '/var/tmp' # issue #54
275
+
276
+ # Get system app data dir
277
+ path = None
278
+ if sys.platform.startswith('win'):
279
+ path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
280
+ path = (path2 or path1) if roaming else (path1 or path2)
281
+ elif sys.platform.startswith('darwin'):
282
+ path = os.path.join(userDir, 'Library', 'Application Support')
283
+ # On Linux and as fallback
284
+ if not (path and os.path.isdir(path)):
285
+ path = userDir
286
+
287
+ # Maybe we should store things local to the executable (in case of a
288
+ # portable distro or a frozen application that wants to be portable)
289
+ prefix = sys.prefix
290
+ if getattr(sys, 'frozen', None):
291
+ prefix = os.path.abspath(os.path.dirname(sys.executable))
292
+ for reldir in ('settings', '../settings'):
293
+ localpath = os.path.abspath(os.path.join(prefix, reldir))
294
+ if os.path.isdir(localpath): # pragma: no cover
295
+ try:
296
+ open(os.path.join(localpath, 'test.write'), 'wb').close()
297
+ os.remove(os.path.join(localpath, 'test.write'))
298
+ except IOError:
299
+ pass # We cannot write in this directory
300
+ else:
301
+ path = localpath
302
+ break
303
+
304
+ # Get path specific for this app
305
+ if appname:
306
+ if path == userDir:
307
+ appname = '.' + appname.lstrip('.') # Make it a hidden directory
308
+ path = os.path.join(path, appname)
309
+ if not os.path.isdir(path): # pragma: no cover
310
+ os.mkdir(path)
311
+
312
+ # Done
313
+ return path
fete_model.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+
6
+ class Conv2d(nn.Module):
7
+ def __init__(self, cin, cout, kernel_size, stride, padding, *args, **kwargs):
8
+ super().__init__(*args, **kwargs)
9
+ self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout))
10
+ self.act = nn.ReLU()
11
+
12
+ def forward(self, x):
13
+ out = self.conv_block(x)
14
+ return self.act(out)
15
+
16
+
17
+ class Conv2d_res(nn.Module):
18
+ # TensorRT does not support 'if' statement, thus we create independent Conv2d_res for residual block
19
+ def __init__(self, cin, cout, kernel_size, stride, padding, *args, **kwargs):
20
+ super().__init__(*args, **kwargs)
21
+ self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout))
22
+ self.act = nn.ReLU()
23
+
24
+ def forward(self, x):
25
+ out = self.conv_block(x)
26
+ out += x
27
+ return self.act(out)
28
+
29
+
30
+ class Conv2dTranspose(nn.Module):
31
+ def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
32
+ super().__init__(*args, **kwargs)
33
+ self.conv_block = nn.Sequential(
34
+ nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
35
+ nn.BatchNorm2d(cout),
36
+ )
37
+ self.act = nn.ReLU()
38
+
39
+ def forward(self, x):
40
+ out = self.conv_block(x)
41
+ return self.act(out)
42
+
43
+
44
+ class FETE_model(nn.Module):
45
+ def __init__(self):
46
+ super(FETE_model, self).__init__()
47
+
48
+ self.face_encoder_blocks = nn.ModuleList(
49
+ [
50
+ nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=2, padding=3)), # 256,256 -> 128,128
51
+ nn.Sequential(
52
+ Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 64,64
53
+ Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
54
+ Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
55
+ ),
56
+ nn.Sequential(
57
+ Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 32,32
58
+ Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
59
+ Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
60
+ Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
61
+ ),
62
+ nn.Sequential(
63
+ Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 16,16
64
+ Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
65
+ Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
66
+ ),
67
+ nn.Sequential(
68
+ Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 8,8
69
+ Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
70
+ Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
71
+ ),
72
+ nn.Sequential(
73
+ Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 4,4
74
+ Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),
75
+ ),
76
+ nn.Sequential(
77
+ Conv2d(512, 512, kernel_size=3, stride=2, padding=0), # 1, 1
78
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
79
+ ),
80
+ ]
81
+ )
82
+
83
+ self.audio_encoder = nn.Sequential(
84
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
85
+ Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
86
+ Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
87
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
88
+ Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
89
+ Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
90
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
91
+ Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
92
+ Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
93
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
94
+ Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
95
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
96
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
97
+ )
98
+
99
+ self.pose_encoder = nn.Sequential(
100
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
101
+ Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
102
+ Conv2d(32, 64, kernel_size=3, stride=(1, 2), padding=1),
103
+ Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
104
+ Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
105
+ Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
106
+ Conv2d(128, 256, kernel_size=3, stride=(1, 2), padding=1),
107
+ Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
108
+ Conv2d(256, 512, kernel_size=3, stride=2, padding=0),
109
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
110
+ )
111
+
112
+ self.emotion_encoder = nn.Sequential(
113
+ Conv2d(1, 32, kernel_size=7, stride=1, padding=1),
114
+ Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
115
+ Conv2d(32, 64, kernel_size=3, stride=(1, 2), padding=1),
116
+ Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
117
+ Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
118
+ Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
119
+ Conv2d(128, 256, kernel_size=3, stride=(1, 2), padding=1),
120
+ Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
121
+ Conv2d(256, 512, kernel_size=3, stride=2, padding=0),
122
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
123
+ )
124
+
125
+ self.blink_encoder = nn.Sequential(
126
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
127
+ Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
128
+ Conv2d(32, 64, kernel_size=3, stride=(1, 2), padding=1),
129
+ Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
130
+ Conv2d(64, 128, kernel_size=3, stride=(1, 2), padding=1),
131
+ Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
132
+ Conv2d(128, 256, kernel_size=3, stride=(1, 2), padding=1),
133
+ Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
134
+ Conv2d(256, 512, kernel_size=1, stride=(1, 2), padding=0),
135
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
136
+ )
137
+
138
+ self.face_decoder_blocks = nn.ModuleList(
139
+ [
140
+ nn.Sequential(
141
+ Conv2d(2048, 512, kernel_size=1, stride=1, padding=0),
142
+ ),
143
+ nn.Sequential(
144
+ Conv2dTranspose(1024, 512, kernel_size=4, stride=1, padding=0), # 4,4
145
+ Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),
146
+ ),
147
+ nn.Sequential(
148
+ Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
149
+ Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),
150
+ Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1), # 8,8
151
+ Self_Attention(512, 512),
152
+ ),
153
+ nn.Sequential(
154
+ Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
155
+ Conv2d_res(384, 384, kernel_size=3, stride=1, padding=1),
156
+ Conv2d_res(384, 384, kernel_size=3, stride=1, padding=1), # 16, 16
157
+ Self_Attention(384, 384),
158
+ ),
159
+ nn.Sequential(
160
+ Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
161
+ Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
162
+ Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1), # 32, 32
163
+ Self_Attention(256, 256),
164
+ ),
165
+ nn.Sequential(
166
+ Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
167
+ Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
168
+ Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
169
+ ), # 64, 64
170
+ nn.Sequential(
171
+ Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
172
+ Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
173
+ Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
174
+ ),
175
+ ]
176
+ ) # 128,128
177
+
178
+ # self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
179
+ # nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
180
+ # nn.Sigmoid())
181
+
182
+ self.output_block = nn.Sequential(
183
+ Conv2dTranspose(80, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
184
+ nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
185
+ nn.Sigmoid(),
186
+ )
187
+
188
+ def forward(
189
+ self,
190
+ face_sequences,
191
+ audio_sequences,
192
+ pose_sequences,
193
+ emotion_sequences,
194
+ blink_sequences,
195
+ ):
196
+ # audio_sequences = (B, T, 1, 80, 16)
197
+ B = audio_sequences.size(0)
198
+
199
+ # disabled for inference
200
+ # input_dim_size = len(face_sequences.size())
201
+ # if input_dim_size > 4:
202
+ # audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
203
+ # pose_sequences = torch.cat([pose_sequences[:, i] for i in range(pose_sequences.size(1))], dim=0)
204
+ # emotion_sequences = torch.cat([emotion_sequences[:, i] for i in range(emotion_sequences.size(1))], dim=0)
205
+ # blink_sequences = torch.cat([blink_sequences[:, i] for i in range(blink_sequences.size(1))], dim=0)
206
+ # face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
207
+ # print(audio_sequences.size(), face_sequences.size(), pose_sequences.size(), emotion_sequences.size())
208
+
209
+ audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
210
+ pose_embedding = self.pose_encoder(pose_sequences) # B, 512, 1, 1
211
+ emotion_embedding = self.emotion_encoder(emotion_sequences) # B, 512, 1, 1
212
+ blink_embedding = self.blink_encoder(blink_sequences) # B, 512, 1, 1
213
+ inputs_embedding = torch.cat((audio_embedding, pose_embedding, emotion_embedding, blink_embedding), dim=1) # B, 1536, 1, 1
214
+ # print(audio_embedding.size(), pose_embedding.size(), emotion_embedding.size(), inputs_embedding.size())
215
+
216
+ feats = []
217
+ x = face_sequences
218
+ for f in self.face_encoder_blocks:
219
+ x = f(x)
220
+ # print(x.shape)
221
+ feats.append(x)
222
+
223
+ x = inputs_embedding
224
+ for f in self.face_decoder_blocks:
225
+ x = f(x)
226
+ # print(x.shape)
227
+
228
+ # try:
229
+ x = torch.cat((x, feats[-1]), dim=1)
230
+ # except Exception as e:
231
+ # print(x.size())
232
+ # print(feats[-1].size())
233
+ # raise e
234
+ feats.pop()
235
+
236
+ x = self.output_block(x)
237
+
238
+ # if input_dim_size > 4:
239
+ # x = torch.split(x, B, dim=0) # [(B, C, H, W)]
240
+ # outputs = torch.stack(x, dim=2) # (B, C, T, H, W)
241
+
242
+ # else:
243
+ outputs = x
244
+
245
+ return outputs
246
+
247
+
248
+ class Self_Attention(nn.Module):
249
+ """
250
+ Source-Reference Attention Layer
251
+ """
252
+
253
+ def __init__(self, in_planes_s, in_planes_r):
254
+ """
255
+ Parameters
256
+ ----------
257
+ in_planes_s: int
258
+ Number of input source feature vector channels.
259
+ in_planes_r: int
260
+ Number of input reference feature vector channels.
261
+ """
262
+ super(Self_Attention, self).__init__()
263
+ self.query_conv = nn.Conv2d(in_channels=in_planes_s, out_channels=in_planes_s // 8, kernel_size=1)
264
+ self.key_conv = nn.Conv2d(in_channels=in_planes_r, out_channels=in_planes_r // 8, kernel_size=1)
265
+ self.value_conv = nn.Conv2d(in_channels=in_planes_r, out_channels=in_planes_r, kernel_size=1)
266
+ self.gamma = nn.Parameter(torch.zeros(1))
267
+ self.softmax = nn.Softmax(dim=-1)
268
+
269
+ def forward(self, source):
270
+ source = source.float() if isinstance(source, torch.cuda.HalfTensor) else source
271
+ reference = source
272
+ """
273
+ Parameters
274
+ ----------
275
+ source : torch.Tensor
276
+ Source feature maps (B x Cs x Ts x Hs x Ws)
277
+ reference : torch.Tensor
278
+ Reference feature maps (B x Cr x Tr x Hr x Wr )
279
+ Returns :
280
+ torch.Tensor
281
+ Source-reference attention value added to the input source features
282
+ torch.Tensor
283
+ Attention map (B x Ns x Nt) (Ns=Ts*Hs*Ws, Nr=Tr*Hr*Wr)
284
+ """
285
+ s_batchsize, sC, sH, sW = source.size()
286
+ r_batchsize, rC, rH, rW = reference.size()
287
+
288
+ proj_query = self.query_conv(source).view(s_batchsize, -1, sH * sW).permute(0, 2, 1)
289
+ proj_key = self.key_conv(reference).view(r_batchsize, -1, rW * rH)
290
+ energy = torch.bmm(proj_query, proj_key)
291
+ attention = self.softmax(energy)
292
+ proj_value = self.value_conv(reference).view(r_batchsize, -1, rH * rW)
293
+ out = torch.bmm(proj_value, attention.permute(0, 2, 1))
294
+ out = out.view(s_batchsize, sC, sH, sW)
295
+ out = self.gamma * out + source
296
+ return out.half() if isinstance(source, torch.cuda.FloatTensor) else out
inference_util.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # set CUDA_MODULE_LOADING=LAZY to speed up the serverless function
4
+ os.environ["CUDA_MODULE_LOADING"] = "LAZY"
5
+ # set SAFETENSORS_FAST_GPU=1 to speed up the serverless function
6
+ os.environ["SAFETENSORS_FAST_GPU"] = "1"
7
+ import cv2
8
+ import torch
9
+ import time
10
+ import imageio
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+ import moviepy.editor as mp
14
+ import torch
15
+
16
+ from audio import load_wav, melspectrogram
17
+ from fete_model import FETE_model
18
+ from preprocess_videos import face_detect, load_from_npz
19
+
20
+ fps = 25
21
+ mel_idx_multiplier = 80.0 / fps
22
+
23
+ mel_step_size = 16
24
+ batch_size = 64 if torch.cuda.is_available() else 4
25
+ device = "cuda" if torch.cuda.is_available() else "cpu"
26
+ print("Using {} for inference.".format(device))
27
+ use_fp16 = True if torch.cuda.is_available() else False
28
+ print("Using FP16 for inference.") if use_fp16 else None
29
+ torch.backends.cudnn.benchmark = True if device == "cuda" else False
30
+
31
+
32
+ def init_model():
33
+ checkpoint_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "checkpoints/obama-fp16.safetensors")
34
+ model = FETE_model()
35
+ if checkpoint_path.endswith(".pth") or checkpoint_path.endswith(".ckpt"):
36
+ if device == "cuda":
37
+ checkpoint = torch.load(checkpoint_path)
38
+ else:
39
+ checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
40
+ s = checkpoint["state_dict"]
41
+ else:
42
+ from safetensors import safe_open
43
+
44
+ s = {}
45
+ with safe_open(checkpoint_path, framework="pt", device=device) as f:
46
+ for key in f.keys():
47
+ s[key] = f.get_tensor(key)
48
+ new_s = {}
49
+ for k, v in s.items():
50
+ new_s[k.replace("module.", "")] = v
51
+ model.load_state_dict(new_s)
52
+
53
+ model = model.to(device)
54
+ model.eval()
55
+ print("Model loaded")
56
+ if use_fp16:
57
+ for name, module in model.named_modules():
58
+ if ".query_conv" in name or ".key_conv" in name or ".value_conv" in name:
59
+ # keep attention layers in full precision to avoid error
60
+ module.to(torch.float)
61
+ else:
62
+ module.to(torch.half)
63
+ print("Model converted to half precision to accelerate inference")
64
+ return model
65
+
66
+
67
+ def make_mask(image_size=256, border_size=32):
68
+ mask_bar = np.linspace(1, 0, border_size).reshape(1, -1).repeat(image_size, axis=0)
69
+ mask = np.zeros((image_size, image_size), dtype=np.float32)
70
+ mask[-border_size:, :] += mask_bar.T[::-1]
71
+ mask[:, :border_size] = mask_bar
72
+ mask[:, -border_size:] = mask_bar[:, ::-1]
73
+ mask[-border_size:, :][mask[-border_size:, :] < 0.6] = 0.6
74
+ mask = np.stack([mask] * 3, axis=-1).astype(np.float32)
75
+ return mask
76
+
77
+
78
+ face_mask = make_mask()
79
+
80
+
81
+ def blend_images(foreground, background):
82
+ # Blend the foreground and background images using the mask
83
+ temp_mask = cv2.resize(face_mask, (foreground.shape[1], foreground.shape[0]))
84
+ blended = cv2.multiply(foreground.astype(np.float32), temp_mask)
85
+ blended += cv2.multiply(background.astype(np.float32), 1 - temp_mask)
86
+ blended = np.clip(blended, 0, 255).astype(np.uint8)
87
+ return blended
88
+
89
+
90
+ def smooth_coord(last_coord, current_coord, factor=0.4):
91
+ change = np.array(current_coord) - np.array(last_coord)
92
+ change = change * factor
93
+ return (np.array(last_coord) + np.array(change)).astype(int).tolist()
94
+
95
+
96
+ def add_black(imgs):
97
+ for i in range(len(imgs)):
98
+ # print('x', imgs[i].shape)
99
+ imgs[i] = cv2.vconcat(
100
+ [np.zeros((100, imgs[i].shape[1], 3), dtype=np.uint8), imgs[i], np.zeros((20, imgs[i].shape[1], 3), dtype=np.uint8)]
101
+ )
102
+ # imgs[i] = cv2.hconcat([np.zeros((imgs[i].shape[0], 100, 3), dtype=np.uint8), imgs[i], np.zeros((imgs[i].shape[0], 100, 3), dtype=np.uint8)])[:480+150,740-100:-740+100,:]
103
+
104
+ # print('xx', imgs[i].shape)
105
+ return imgs
106
+
107
+
108
+ def remove_black(img):
109
+ return img[100:-20]
110
+
111
+
112
+ def resize_length(input_attributes, length):
113
+ input_attributes = np.array(input_attributes)
114
+ resized_attributes = [input_attributes[int(i_ * (input_attributes.shape[0] / length))] for i_ in range(length)]
115
+ return np.array(resized_attributes).T
116
+
117
+
118
+ def output_chunks(input_attributes):
119
+ output_chunks = []
120
+ len_ = len(input_attributes[0])
121
+
122
+ i = 0
123
+ # print(mel.shape, pose.shape)
124
+ # (80, 801) (3, 801)
125
+ while 1:
126
+ start_idx = int(i * mel_idx_multiplier)
127
+ if start_idx + mel_step_size > len_:
128
+ output_chunks.append(input_attributes[:, len_ - mel_step_size :])
129
+ break
130
+ output_chunks.append(input_attributes[:, start_idx : start_idx + mel_step_size])
131
+ i += 1
132
+ return output_chunks
133
+
134
+
135
+ def prepare_data(face_path, audio_path, pose, emotion, blink, img_size=256, pads=[0, 0, 0, 0]):
136
+ if os.path.isfile(face_path) and face_path.split(".")[1] in ["jpg", "png", "jpeg"]:
137
+ static = True
138
+ full_frames = [cv2.imread(face_path)]
139
+ else:
140
+ static = False
141
+ video_stream = cv2.VideoCapture(face_path)
142
+
143
+ # print('Reading video frames...')
144
+ full_frames = []
145
+ while 1:
146
+ still_reading, frame = video_stream.read()
147
+ if not still_reading:
148
+ video_stream.release()
149
+ break
150
+ full_frames.append(frame)
151
+ print("Number of frames available for inference: " + str(len(full_frames)))
152
+
153
+ wav = load_wav(audio_path, 16000)
154
+ mel = melspectrogram(wav)
155
+ # take half
156
+ len_ = mel.shape[1] # //2
157
+ mel = mel[:, :len_]
158
+ # print('>>>', mel.shape)
159
+
160
+ pose = resize_length(pose, len_)
161
+ emotion = resize_length(emotion, len_)
162
+ blink = resize_length(blink, len_)
163
+
164
+ if np.isnan(mel.reshape(-1)).sum() > 0:
165
+ raise ValueError("Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again")
166
+
167
+ mel_chunks = output_chunks(mel)
168
+ pose_chunks = output_chunks(pose)
169
+ emotion_chunks = output_chunks(emotion)
170
+ blink_chunks = output_chunks(blink)
171
+
172
+ gen = datagen(face_path, full_frames, mel_chunks, pose_chunks, emotion_chunks, blink_chunks, static=static, img_size=img_size, pads=pads)
173
+ steps = int(np.ceil(float(len(mel_chunks)) / batch_size))
174
+
175
+ return gen, steps
176
+
177
+
178
+ def preprocess_batch(batch):
179
+ return torch.FloatTensor(np.reshape(batch, [len(batch), 1, batch[0].shape[0], batch[0].shape[1]])).to(device)
180
+
181
+
182
+ def datagen(face_path, frames, mels, poses, emotions, blinks, static=False, img_size=256, pads=[0, 0, 0, 0]):
183
+ img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, frame_batch, coords_batch = [], [], [], [], [], [], []
184
+ scale_factor = img_size // 128
185
+
186
+ # print("Length of mel chunks: {}".format(len(mel_chunks)))
187
+ frames = frames[: len(mels)]
188
+ frames = add_black(frames)
189
+ try:
190
+ video_name = os.path.basename(face_path).split(".")[0]
191
+ coords = load_from_npz(video_name)
192
+ face_det_results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(frames, coords)]
193
+
194
+ except Exception as e:
195
+ print("No existing coords found, running face detection...", "Error: ", e)
196
+ if not static:
197
+ coords = face_detect(frames, pads)
198
+ face_det_results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(frames, coords)]
199
+ else:
200
+ coords = face_detect([frames[0]], pads)
201
+ face_det_results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(frames, coords)]
202
+
203
+ face_det_results = face_det_results[: len(mels)]
204
+
205
+ while len(frames) < len(mels):
206
+ face_det_results = face_det_results + face_det_results[::-1]
207
+ frames = frames + frames[::-1]
208
+ else:
209
+ face_det_results = face_det_results[: len(mels)]
210
+ frames = frames[: len(mels)]
211
+
212
+ for i in range(len(mels)):
213
+ idx = 0 if static else i % len(frames)
214
+ frame_to_save = frames[idx].copy()
215
+ face, coords = face_det_results[idx].copy()
216
+ face = cv2.resize(face, (img_size, img_size))
217
+
218
+ img_batch.append(face)
219
+ mel_batch.append(mels[i])
220
+ pose_batch.append(poses[i])
221
+ emotion_batch.append(emotions[i])
222
+ blink_batch.append(blinks[i])
223
+ frame_batch.append(frame_to_save)
224
+ coords_batch.append(coords)
225
+
226
+ # print(m.shape, poses[i].shape)
227
+ # (80, 16) (3, 16)
228
+ if len(img_batch) >= batch_size:
229
+ img_masked = np.asarray(img_batch).copy()
230
+
231
+ img_masked[:, 16 * scale_factor : -16 * scale_factor, 16 * scale_factor : -16 * scale_factor] = 0.0
232
+
233
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
234
+ img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
235
+
236
+ mel_batch = preprocess_batch(mel_batch)
237
+ pose_batch = preprocess_batch(pose_batch)
238
+ emotion_batch = preprocess_batch(emotion_batch)
239
+ blink_batch = preprocess_batch(blink_batch)
240
+
241
+ if use_fp16:
242
+ yield (
243
+ img_batch.half(),
244
+ mel_batch.half(),
245
+ pose_batch.half(),
246
+ emotion_batch.half(),
247
+ blink_batch.half(),
248
+ ), frame_batch, coords_batch
249
+ else:
250
+ yield (img_batch, mel_batch, pose_batch, emotion_batch, blink_batch), frame_batch, coords_batch
251
+ img_batch, mel_batch, pose_batch, emotion_batch, blink_batch, frame_batch, coords_batch = [], [], [], [], [], [], []
252
+
253
+ if len(img_batch) > 0:
254
+ img_masked = np.asarray(img_batch).copy()
255
+
256
+ img_masked[:, 16 * scale_factor : -16 * scale_factor, 16 * scale_factor : -16 * scale_factor] = 0.0
257
+
258
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
259
+ img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
260
+
261
+ mel_batch = preprocess_batch(mel_batch)
262
+ pose_batch = preprocess_batch(pose_batch)
263
+ emotion_batch = preprocess_batch(emotion_batch)
264
+ blink_batch = preprocess_batch(blink_batch)
265
+
266
+ if use_fp16:
267
+ yield (img_batch.half(), mel_batch.half(), pose_batch.half(), emotion_batch.half(), blink_batch.half()), frame_batch, coords_batch
268
+ else:
269
+ yield (img_batch, mel_batch, pose_batch, emotion_batch, blink_batch), frame_batch, coords_batch
270
+
271
+
272
+ def infenrece(model, face_path, audio_path, pose, emotion, blink, preview=False):
273
+ timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime(time.time()))
274
+ gen, steps = prepare_data(face_path, audio_path, pose, emotion, blink)
275
+ steps = 1 if preview else steps
276
+ # duration = librosa.get_duration(filename=audio_path)
277
+
278
+ if preview:
279
+ outfile = "/tmp/{}.jpg".format(timestamp)
280
+ else:
281
+ outfile = "/tmp/{}.mp4".format(timestamp)
282
+ tmp_video = "/tmp/temp_{}.mp4".format(timestamp)
283
+ writer = imageio.get_writer(tmp_video, fps=fps, codec="libx264", quality=10, pixelformat="yuv420p", macro_block_size=1) if not preview else None
284
+ # print('Generating frames...', outfile, steps)
285
+ for inputs, frames, coords in tqdm(gen, total=steps):
286
+ with torch.no_grad():
287
+ pred = model(*inputs)
288
+
289
+ pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.0
290
+
291
+ for p, f, c in zip(pred, frames, coords):
292
+ y1, y2, x1, x2 = c
293
+ y1, y2, x1, x2 = int(y1), int(y2), int(x1), int(x2)
294
+ y = round(y2 - y1)
295
+ x = round(x2 - x1)
296
+ # print(x, y, p.shape)
297
+ p = cv2.resize(p.astype(np.uint8), (x, y))
298
+
299
+ try:
300
+ f[y1 : y1 + y, x1 : x1 + x] = blend_images(f[y1 : y1 + y, x1 : x1 + x], p)
301
+ except Exception as e:
302
+ print(e)
303
+ f[y1 : y1 + y, x1 : x1 + x] = p
304
+ # out.write(f[100:-20])
305
+ f = remove_black(f)
306
+ if preview:
307
+ cv2.imwrite(outfile, f, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
308
+ return outfile
309
+ writer.append_data(cv2.cvtColor(f, cv2.COLOR_BGR2RGB))
310
+ writer.close()
311
+ video_clip = mp.VideoFileClip(tmp_video)
312
+ audio_clip = mp.AudioFileClip(audio_path)
313
+ video_clip = video_clip.set_audio(audio_clip)
314
+ video_clip.write_videofile(outfile, codec="libx264")
315
+
316
+ print("Saved to {}".format(outfile) if os.path.exists(outfile) else "Failed to save {}".format(outfile))
317
+ try:
318
+ os.remove(tmp_video)
319
+ del video_clip
320
+ del audio_clip
321
+ del gen
322
+ except:
323
+ pass
324
+ return outfile
325
+
326
+
327
+ if __name__ == "__main__":
328
+ model = init_model()
329
+
330
+ from attributtes_utils import input_pose, input_emotion, input_blink
331
+
332
+ pose = input_pose()
333
+ emotion = input_emotion()
334
+ blink = input_blink()
335
+ audio_path = "./assets/sample.wav"
336
+ face_path = "./assets/sample.mp4"
337
+
338
+ infenrece(model, face_path, audio_path, pose, emotion, blink)
preprocess_videos.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import face_detection
2
+ import numpy as np
3
+ import cv2
4
+ from tqdm import tqdm
5
+ import torch
6
+ import glob
7
+ import os
8
+ from natsort import natsorted
9
+
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
+ def get_squre_coords(coords, image, size=None, last_size=None):
13
+ y1, y2, x1, x2 = coords
14
+ w, h = x2 - x1, y2 - y1
15
+ center = (x1 + w // 2, y1 + h // 2)
16
+ if size is None:
17
+ size = (w + h) // 2
18
+ if last_size is not None:
19
+ size = (w + h) // 2
20
+ size = (size - last_size) // 5 + last_size
21
+ x1, y1 = center[0] - size // 2, center[1] - size // 2
22
+ x2, y2 = x1 + size, y1 + size
23
+ return size, [y1, y2, x1, x2]
24
+
25
+
26
+ def get_smoothened_boxes(boxes, T):
27
+ for i in range(len(boxes)):
28
+ if i + T > len(boxes):
29
+ window = boxes[len(boxes) - T :]
30
+ else:
31
+ window = boxes[i : i + T]
32
+ boxes[i] = np.mean(window, axis=0)
33
+ return boxes
34
+
35
+
36
+ def face_detect(images, pads):
37
+ detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device=device)
38
+
39
+ batch_size = 32 if device == "cuda" else 4
40
+ print("face detect batch size:", batch_size)
41
+ while 1:
42
+ predictions = []
43
+ try:
44
+ for i in tqdm(range(0, len(images), batch_size)):
45
+ predictions.extend(detector.get_detections_for_batch(np.array(images[i : i + batch_size])))
46
+ except RuntimeError:
47
+ if batch_size == 1:
48
+ raise RuntimeError("Image too big to run face detection on GPU. Please use the --resize_factor argument")
49
+ batch_size //= 2
50
+ print("Recovering from OOM error; New batch size: {}".format(batch_size))
51
+ continue
52
+ break
53
+
54
+ results = []
55
+ pady1, pady2, padx1, padx2 = pads
56
+ for rect, image in zip(predictions, images):
57
+ if rect is None:
58
+ cv2.imwrite(".temp/faulty_frame.jpg", image) # check this frame where the face was not detected.
59
+ raise ValueError("Face not detected! Ensure the video contains a face in all the frames.")
60
+
61
+ y1 = max(0, rect[1] - pady1)
62
+ y2 = min(image.shape[0], rect[3] + pady2)
63
+ x1 = max(0, rect[0] - padx1)
64
+ x2 = min(image.shape[1], rect[2] + padx2)
65
+ # y_gap, x_gap = ((y2 - y1) * 2) // 3, ((x2 - x1) * 2) // 3
66
+ y_gap, x_gap = (y2 - y1)//2, (x2 - x1)//2
67
+ coords_ = [y1 - y_gap, y2 + y_gap, x1 - x_gap, x2 + x_gap]
68
+
69
+ # smooth the coords
70
+ _, coords = get_squre_coords(coords_, image, None)
71
+
72
+ y1, y2, x1, x2 = coords
73
+ y1 = max(0, y1)
74
+ y2 = min(image.shape[0], y2)
75
+ x1 = max(0, x1)
76
+ x2 = min(image.shape[1], x2)
77
+
78
+ results.append([x1, y1, x2, y2])
79
+
80
+ print("Number of frames cropped: {}".format(len(results)))
81
+ print("First coords: {}".format(results[0]))
82
+ boxes = np.array(results)
83
+ boxes = get_smoothened_boxes(boxes, T=5)
84
+ # results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
85
+
86
+ del detector
87
+ return boxes
88
+
89
+ def add_black(imgs):
90
+ for i in range(len(imgs)):
91
+ imgs[i] = cv2.vconcat([np.zeros((100, imgs[i].shape[1], 3), dtype=np.uint8), imgs[i], np.zeros((20, imgs[i].shape[1], 3), dtype=np.uint8)])
92
+
93
+ return imgs
94
+
95
+ def preprocess(video_dir="./assets/videos", save_dir="./assets/coords"):
96
+ all_videos = natsorted(glob.glob(os.path.join(video_dir, "*.mp4")))
97
+ for video_path in all_videos:
98
+ video_stream = cv2.VideoCapture(video_path)
99
+
100
+ # print('Reading video frames...')
101
+ full_frames = []
102
+ while 1:
103
+ still_reading, frame = video_stream.read()
104
+ if not still_reading:
105
+ video_stream.release()
106
+ break
107
+ full_frames.append(frame)
108
+ print("Number of frames available for inference: " + str(len(full_frames)))
109
+ full_frames = add_black(full_frames)
110
+ # print('Face detection running...')
111
+ coords = face_detect(full_frames, pads=(0, 0, 0, 0))
112
+ np.savez_compressed(os.path.join(save_dir, os.path.basename(video_path).split(".")[0]), coords=coords)
113
+
114
+
115
+ def load_from_npz(video_name, save_dir="./assets/coords"):
116
+ npz = np.load(os.path.join(save_dir, video_name + ".npz"))
117
+ return npz["coords"]
118
+
119
+ if __name__ == "__main__":
120
+ preprocess()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ opencv-python==4.7.0.72
2
+ imageio[ffmpeg]==2.31.1
3
+ librosa==0.8.1
4
+ moviepy==1.0.3
5
+ numpy==1.23.5
6
+ safetensors==0.3.2
7
+ torchvision==0.15.2
8
+ gradio==3.41.0
9
+ natsort
10
+ tqdm