Spaces:
Build error
Build error
Podtekatel
commited on
Commit
β’
046b3c9
1
Parent(s):
4425d8c
Initial commit for arcane
Browse files- .gitattributes +0 -1
- README.md +4 -4
- app.py +66 -0
- demo/IMG1.jpg +0 -0
- demo/IMG2.png +0 -0
- demo/IMG3.jpg +0 -0
- hf_download.py +18 -0
- inference/__init__.py +0 -0
- inference/box_utils.py +31 -0
- inference/center_crop.py +24 -0
- inference/face_detector.py +121 -0
- inference/model_pipeline.py +110 -0
- inference/onnx_model.py +14 -0
- packages.txt +1 -0
- requirements.txt +5 -0
.gitattributes
CHANGED
@@ -2,7 +2,6 @@
|
|
2 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
*.h5 filter=lfs diff=lfs merge=lfs -text
|
|
|
2 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
|
|
5 |
*.ftz filter=lfs diff=lfs merge=lfs -text
|
6 |
*.gz filter=lfs diff=lfs merge=lfs -text
|
7 |
*.h5 filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
---
|
2 |
title: Arcane Style Transfer
|
3 |
-
emoji:
|
4 |
colorFrom: blue
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
-
pinned:
|
10 |
license: bsd-3-clause
|
11 |
---
|
12 |
|
|
|
1 |
---
|
2 |
title: Arcane Style Transfer
|
3 |
+
emoji: π©π»βπ§ππ
|
4 |
colorFrom: blue
|
5 |
+
colorTo: pink
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.8.2
|
8 |
app_file: app.py
|
9 |
+
pinned: true
|
10 |
license: bsd-3-clause
|
11 |
---
|
12 |
|
app.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
from huggingface_hub import hf_hub_url, cached_download
|
8 |
+
|
9 |
+
from inference.face_detector import StatRetinaFaceDetector
|
10 |
+
from inference.model_pipeline import VSNetModelPipeline
|
11 |
+
from inference.onnx_model import ONNXModel
|
12 |
+
|
13 |
+
logging.basicConfig(
|
14 |
+
format='%(asctime)s %(levelname)-8s %(message)s',
|
15 |
+
level=logging.INFO,
|
16 |
+
datefmt='%Y-%m-%d %H:%M:%S')
|
17 |
+
|
18 |
+
MODEL_IMG_SIZE = 256
|
19 |
+
def load_model():
|
20 |
+
REPO_ID = "Podtekatel/ARCNEGAN"
|
21 |
+
FILENAME = "arcane_exp_203_ep_281.onnx"
|
22 |
+
|
23 |
+
global model
|
24 |
+
global pipeline
|
25 |
+
|
26 |
+
model_path = cached_download(
|
27 |
+
hf_hub_url(REPO_ID, FILENAME), use_auth_token=os.getenv('HF_TOKEN')
|
28 |
+
)
|
29 |
+
model = ONNXModel(model_path)
|
30 |
+
|
31 |
+
pipeline = VSNetModelPipeline(model, StatRetinaFaceDetector(MODEL_IMG_SIZE), background_resize=1024, no_detected_resize=1024)
|
32 |
+
return model
|
33 |
+
|
34 |
+
load_model()
|
35 |
+
|
36 |
+
def inference(img):
|
37 |
+
img = np.array(img)
|
38 |
+
out_img = pipeline(img)
|
39 |
+
out_img = Image.fromarray(out_img)
|
40 |
+
return out_img
|
41 |
+
|
42 |
+
|
43 |
+
title = "JJStyleTransfer"
|
44 |
+
description = "Gradio Demo for Arcane Season 1 style transfer. To use it, simply upload your image, or click one of the examples to load them."
|
45 |
+
article = "This is one of my successful experiments on style transfer. I've built my own pipeline, generator model and private dataset to train this model<br>" \
|
46 |
+
"" \
|
47 |
+
"" \
|
48 |
+
"" \
|
49 |
+
"Model pipeline which used in project is improved CartoonGAN.<br>" \
|
50 |
+
"This model was trained on RTX 2080 Ti 1.5 days with batch size 7.<br>" \
|
51 |
+
"Model weights 64 MB in ONNX fp32 format, infers 25 ms on GPU and 150 ms on CPU at 256x256 resolution.<br>" \
|
52 |
+
"If you want to use this app or integrate this model into yours, please contact me at email '[email protected]'."
|
53 |
+
|
54 |
+
imgs_folder = 'demo'
|
55 |
+
examples = [[os.path.join(imgs_folder, img_filename)] for img_filename in sorted(os.listdir(imgs_folder))]
|
56 |
+
|
57 |
+
demo = gr.Interface(
|
58 |
+
fn=inference,
|
59 |
+
inputs=[gr.inputs.Image(type="pil")],
|
60 |
+
outputs=gr.outputs.Image(type="pil"),
|
61 |
+
title=title,
|
62 |
+
description=description,
|
63 |
+
article=article,
|
64 |
+
examples=examples)
|
65 |
+
|
66 |
+
demo.launch()
|
demo/IMG1.jpg
ADDED
demo/IMG2.png
ADDED
demo/IMG3.jpg
ADDED
hf_download.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from huggingface_hub import hf_hub_url, cached_download
|
3 |
+
import joblib
|
4 |
+
|
5 |
+
REPO_ID = "MalchuL/JJBAGAN"
|
6 |
+
FILENAME = "198_jjba_8_k_2_099_ep.onnx"
|
7 |
+
|
8 |
+
model = cached_download(
|
9 |
+
hf_hub_url(REPO_ID, FILENAME)
|
10 |
+
)
|
11 |
+
print(model)
|
12 |
+
|
13 |
+
import onnxruntime
|
14 |
+
ort_session = onnxruntime.InferenceSession(str(model))
|
15 |
+
input_name = ort_session.get_inputs()[0].name
|
16 |
+
ort_inputs = {input_name: np.random.randn(1, 3, 256, 256).astype(dtype=np.float32)}
|
17 |
+
ort_outs = ort_session.run(None, ort_inputs)
|
18 |
+
print(ort_outs)
|
inference/__init__.py
ADDED
File without changes
|
inference/box_utils.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def convert_to_square(bboxes):
|
5 |
+
"""Convert bounding boxes to a square form.
|
6 |
+
Arguments:
|
7 |
+
bboxes: a float numpy array of shape [n, 4].
|
8 |
+
Returns:
|
9 |
+
a float numpy array of shape [4],
|
10 |
+
squared bounding boxes.
|
11 |
+
"""
|
12 |
+
|
13 |
+
square_bboxes = np.zeros_like(bboxes)
|
14 |
+
x1, y1, x2, y2 = bboxes
|
15 |
+
h = y2 - y1 + 1.0
|
16 |
+
w = x2 - x1 + 1.0
|
17 |
+
max_side = np.maximum(h, w)
|
18 |
+
square_bboxes[0] = x1 + w * 0.5 - max_side * 0.5
|
19 |
+
square_bboxes[1] = y1 + h * 0.5 - max_side * 0.5
|
20 |
+
square_bboxes[2] = square_bboxes[0] + max_side - 1.0
|
21 |
+
square_bboxes[3] = square_bboxes[1] + max_side - 1.0
|
22 |
+
return square_bboxes
|
23 |
+
|
24 |
+
|
25 |
+
def scale_box(box, scale):
|
26 |
+
x1, y1, x2, y2 = box
|
27 |
+
center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
|
28 |
+
w, h = x2 - x1, y2 - y1
|
29 |
+
new_w, new_h = w * scale, h * scale
|
30 |
+
y1, y2, x1, x2 = center_y - new_h / 2, center_y + new_h / 2, center_x - new_w / 2, center_x + new_w / 2,
|
31 |
+
return np.array((x1, y1, x2, y2))
|
inference/center_crop.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
# From albumentations
|
5 |
+
def center_crop(img: np.ndarray, crop_height: int, crop_width: int):
|
6 |
+
height, width = img.shape[:2]
|
7 |
+
if height < crop_height or width < crop_width:
|
8 |
+
raise ValueError(
|
9 |
+
"Requested crop size ({crop_height}, {crop_width}) is "
|
10 |
+
"larger than the image size ({height}, {width})".format(
|
11 |
+
crop_height=crop_height, crop_width=crop_width, height=height, width=width
|
12 |
+
)
|
13 |
+
)
|
14 |
+
x1, y1, x2, y2 = get_center_crop_coords(height, width, crop_height, crop_width)
|
15 |
+
img = img[y1:y2, x1:x2]
|
16 |
+
return img
|
17 |
+
|
18 |
+
|
19 |
+
def get_center_crop_coords(height: int, width: int, crop_height: int, crop_width: int):
|
20 |
+
y1 = (height - crop_height) // 2
|
21 |
+
y2 = y1 + crop_height
|
22 |
+
x1 = (width - crop_width) // 2
|
23 |
+
x2 = x1 + crop_width
|
24 |
+
return x1, y1, x2, y2
|
inference/face_detector.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
from typing import List
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
from retinaface import RetinaFace
|
8 |
+
from retinaface.model import retinaface_model
|
9 |
+
|
10 |
+
from .box_utils import convert_to_square
|
11 |
+
|
12 |
+
|
13 |
+
class FaceDetector(ABC):
|
14 |
+
def __init__(self, target_size):
|
15 |
+
self.target_size = target_size
|
16 |
+
@abstractmethod
|
17 |
+
def detect_crops(self, img, *args, **kwargs) -> List[np.ndarray]:
|
18 |
+
"""
|
19 |
+
Img is a numpy ndarray in range [0..255], uint8 dtype, RGB type
|
20 |
+
Returns ndarray with [x1, y1, x2, y2] in row
|
21 |
+
"""
|
22 |
+
pass
|
23 |
+
|
24 |
+
@abstractmethod
|
25 |
+
def postprocess_crops(self, crops, *args, **kwargs) -> List[np.ndarray]:
|
26 |
+
return crops
|
27 |
+
|
28 |
+
def sort_faces(self, crops):
|
29 |
+
sorted_faces = sorted(crops, key=lambda x: -(x[2] - x[0]) * (x[3] - x[1]))
|
30 |
+
sorted_faces = np.stack(sorted_faces, axis=0)
|
31 |
+
return sorted_faces
|
32 |
+
|
33 |
+
def fix_range_crops(self, img, crops):
|
34 |
+
H, W, _ = img.shape
|
35 |
+
final_crops = []
|
36 |
+
for crop in crops:
|
37 |
+
x1, y1, x2, y2 = crop
|
38 |
+
x1 = max(min(round(x1), W), 0)
|
39 |
+
y1 = max(min(round(y1), H), 0)
|
40 |
+
x2 = max(min(round(x2), W), 0)
|
41 |
+
y2 = max(min(round(y2), H), 0)
|
42 |
+
new_crop = [x1, y1, x2, y2]
|
43 |
+
final_crops.append(new_crop)
|
44 |
+
final_crops = np.array(final_crops, dtype=np.int)
|
45 |
+
return final_crops
|
46 |
+
|
47 |
+
def crop_faces(self, img, crops) -> List[np.ndarray]:
|
48 |
+
cropped_faces = []
|
49 |
+
for crop in crops:
|
50 |
+
x1, y1, x2, y2 = crop
|
51 |
+
face_crop = img[y1:y2, x1:x2, :]
|
52 |
+
cropped_faces.append(face_crop)
|
53 |
+
return cropped_faces
|
54 |
+
|
55 |
+
def unify_and_merge(self, cropped_images):
|
56 |
+
return cropped_images
|
57 |
+
|
58 |
+
def __call__(self, img):
|
59 |
+
return self.detect_faces(img)
|
60 |
+
|
61 |
+
def detect_faces(self, img):
|
62 |
+
crops = self.detect_crops(img)
|
63 |
+
if crops is None or len(crops) == 0:
|
64 |
+
return [], []
|
65 |
+
crops = self.sort_faces(crops)
|
66 |
+
updated_crops = self.postprocess_crops(crops)
|
67 |
+
updated_crops = self.fix_range_crops(img, updated_crops)
|
68 |
+
cropped_faces = self.crop_faces(img, updated_crops)
|
69 |
+
unified_faces = self.unify_and_merge(cropped_faces)
|
70 |
+
return unified_faces, updated_crops
|
71 |
+
|
72 |
+
|
73 |
+
class StatRetinaFaceDetector(FaceDetector):
|
74 |
+
def __init__(self, target_size=None):
|
75 |
+
super().__init__(target_size)
|
76 |
+
self.model = retinaface_model.build_model()
|
77 |
+
#self.relative_offsets = [0.3258, 0.5225, 0.3258, 0.1290]
|
78 |
+
self.relative_offsets = [0.3619, 0.5830, 0.3619, 0.1909]
|
79 |
+
|
80 |
+
def postprocess_crops(self, crops, *args, **kwargs) -> np.ndarray:
|
81 |
+
final_crops = []
|
82 |
+
x1_offset, y1_offset, x2_offset, y2_offset = self.relative_offsets
|
83 |
+
for crop in crops:
|
84 |
+
x1, y1, x2, y2 = crop
|
85 |
+
w, h = x2 - x1, y2 - y1
|
86 |
+
x1 -= w * x1_offset
|
87 |
+
y1 -= h * y1_offset
|
88 |
+
x2 += w * x2_offset
|
89 |
+
y2 += h * y2_offset
|
90 |
+
crop = np.array([x1, y1, x2, y2], dtype=crop.dtype)
|
91 |
+
crop = convert_to_square(crop)
|
92 |
+
final_crops.append(crop)
|
93 |
+
final_crops = np.stack(final_crops, axis=0)
|
94 |
+
return final_crops
|
95 |
+
|
96 |
+
def detect_crops(self, img, *args, **kwargs):
|
97 |
+
faces = RetinaFace.detect_faces(img, model=self.model)
|
98 |
+
crops = []
|
99 |
+
if isinstance(faces, tuple):
|
100 |
+
faces = {}
|
101 |
+
for name, face in faces.items():
|
102 |
+
x1, y1, x2, y2 = face['facial_area']
|
103 |
+
crop = np.array([x1, y1, x2, y2])
|
104 |
+
crops.append(crop)
|
105 |
+
if len(crops) > 0:
|
106 |
+
crops = np.stack(crops, axis=0)
|
107 |
+
return crops
|
108 |
+
|
109 |
+
def unify_and_merge(self, cropped_images):
|
110 |
+
if self.target_size is None:
|
111 |
+
return cropped_images
|
112 |
+
else:
|
113 |
+
resized_images = []
|
114 |
+
for cropped_image in cropped_images:
|
115 |
+
resized_image = cv2.resize(cropped_image, (self.target_size, self.target_size),
|
116 |
+
interpolation=cv2.INTER_LINEAR)
|
117 |
+
resized_images.append(resized_image)
|
118 |
+
|
119 |
+
resized_images = np.stack(resized_images, axis=0)
|
120 |
+
return resized_images
|
121 |
+
|
inference/model_pipeline.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import time
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from .center_crop import center_crop
|
8 |
+
from .face_detector import FaceDetector
|
9 |
+
|
10 |
+
|
11 |
+
class VSNetModelPipeline:
|
12 |
+
def __init__(self, model, face_detector: FaceDetector, background_resize=720, no_detected_resize=256):
|
13 |
+
self.background_resize = background_resize
|
14 |
+
self.no_detected_resize = no_detected_resize
|
15 |
+
self.model = model
|
16 |
+
self.face_detector = face_detector
|
17 |
+
self.mask = self.create_circular_mask(face_detector.target_size, face_detector.target_size)
|
18 |
+
|
19 |
+
@staticmethod
|
20 |
+
def create_circular_mask(h, w, power=None, clipping_coef=0.85):
|
21 |
+
center = (int(w / 2), int(h / 2))
|
22 |
+
|
23 |
+
Y, X = np.ogrid[:h, :w]
|
24 |
+
dist_from_center = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)
|
25 |
+
print(dist_from_center.max(), dist_from_center.min())
|
26 |
+
clipping_radius = min((h - center[0]), (w - center[1])) * clipping_coef
|
27 |
+
max_size = max((h - center[0]), (w - center[1]))
|
28 |
+
dist_from_center[dist_from_center < clipping_radius] = clipping_radius
|
29 |
+
dist_from_center[dist_from_center > max_size] = max_size
|
30 |
+
max_distance, min_distance = np.max(dist_from_center), np.min(dist_from_center)
|
31 |
+
dist_from_center = 1 - (dist_from_center - min_distance) / (max_distance - min_distance)
|
32 |
+
if power is not None:
|
33 |
+
dist_from_center = np.power(dist_from_center, power)
|
34 |
+
dist_from_center = np.stack([dist_from_center] * 3, axis=2)
|
35 |
+
# mask = dist_from_center <= radius
|
36 |
+
return dist_from_center
|
37 |
+
|
38 |
+
|
39 |
+
@staticmethod
|
40 |
+
def resize_size(image, size=720, always_apply=True):
|
41 |
+
h, w, c = np.shape(image)
|
42 |
+
if min(h, w) > size or always_apply:
|
43 |
+
if h < w:
|
44 |
+
h, w = int(size * h / w), size
|
45 |
+
else:
|
46 |
+
h, w = size, int(size * w / h)
|
47 |
+
image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
|
48 |
+
return image
|
49 |
+
|
50 |
+
def normalize(self, img):
|
51 |
+
img = img.astype(np.float32) / 255 * 2 - 1
|
52 |
+
return img
|
53 |
+
|
54 |
+
def denormalize(self, img):
|
55 |
+
return (img + 1) / 2
|
56 |
+
|
57 |
+
def divide_crop(self, img, must_divided=32):
|
58 |
+
h, w, _ = img.shape
|
59 |
+
h = h // must_divided * must_divided
|
60 |
+
w = w // must_divided * must_divided
|
61 |
+
|
62 |
+
img = center_crop(img, h, w)
|
63 |
+
return img
|
64 |
+
|
65 |
+
def merge_crops(self, faces_imgs, crops, full_image):
|
66 |
+
for face, crop in zip(faces_imgs, crops):
|
67 |
+
x1, y1, x2, y2 = crop
|
68 |
+
W, H = x2 - x1, y2 - y1
|
69 |
+
result_face = cv2.resize(face, (W, H), interpolation=cv2.INTER_LINEAR)
|
70 |
+
face_mask = cv2.resize(self.mask, (W, H), interpolation=cv2.INTER_LINEAR)
|
71 |
+
input_face = full_image[y1: y2, x1: x2]
|
72 |
+
full_image[y1: y2, x1: x2] = (result_face * face_mask + input_face * (1 - face_mask)).astype(np.uint8)
|
73 |
+
return full_image
|
74 |
+
|
75 |
+
def __call__(self, img):
|
76 |
+
return self.process_image(img)
|
77 |
+
|
78 |
+
def process_image(self, img):
|
79 |
+
img = self.resize_size(img, size=self.background_resize)
|
80 |
+
img = self.divide_crop(img)
|
81 |
+
|
82 |
+
face_crops, coords = self.face_detector(img)
|
83 |
+
|
84 |
+
if len(face_crops) > 0:
|
85 |
+
start_time = time.time()
|
86 |
+
faces = self.normalize(face_crops)
|
87 |
+
faces = faces.transpose(0, 3, 1, 2)
|
88 |
+
out_faces = self.model(faces)
|
89 |
+
out_faces = self.denormalize(out_faces)
|
90 |
+
out_faces = out_faces.transpose(0, 2, 3, 1)
|
91 |
+
out_faces = np.clip(out_faces * 255, 0, 255).astype(np.uint8)
|
92 |
+
end_time = time.time()
|
93 |
+
logging.info(f'Face FPS {1 / (end_time - start_time)}')
|
94 |
+
else:
|
95 |
+
out_faces = []
|
96 |
+
img = self.resize_size(img, size=self.no_detected_resize)
|
97 |
+
img = self.divide_crop(img)
|
98 |
+
|
99 |
+
start_time = time.time()
|
100 |
+
full_image = self.normalize(img)
|
101 |
+
full_image = np.expand_dims(full_image, 0).transpose(0, 3, 1, 2)
|
102 |
+
full_image = self.model(full_image)
|
103 |
+
full_image = self.denormalize(full_image)
|
104 |
+
full_image = full_image.transpose(0, 2, 3, 1)
|
105 |
+
full_image = np.clip(full_image * 255, 0, 255).astype(np.uint8)
|
106 |
+
end_time = time.time()
|
107 |
+
logging.info(f'Background FPS {1 / (end_time - start_time)}')
|
108 |
+
|
109 |
+
result_image = self.merge_crops(out_faces, coords, full_image[0])
|
110 |
+
return result_image
|
inference/onnx_model.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import onnxruntime
|
3 |
+
|
4 |
+
|
5 |
+
class ONNXModel:
|
6 |
+
def __init__(self, onnx_mode_path):
|
7 |
+
self.path = onnx_mode_path
|
8 |
+
self.ort_session = onnxruntime.InferenceSession(str(self.path))
|
9 |
+
self.input_name = self.ort_session.get_inputs()[0].name
|
10 |
+
|
11 |
+
def __call__(self, img):
|
12 |
+
ort_inputs = {self.input_name: img.astype(dtype=np.float32)}
|
13 |
+
ort_outs = self.ort_session.run(None, ort_inputs)[0]
|
14 |
+
return ort_outs
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
python3-opencv
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
huggingface_hub
|
2 |
+
onnxruntime
|
3 |
+
numpy
|
4 |
+
gradio
|
5 |
+
retina-face
|