MimicMotion / mimicmotion /dwpose /dwpose_detector.py
fffiloni's picture
Upload 25 files
e394497 verified
raw
history blame
2.47 kB
import os
import numpy as np
import torch
from .wholebody import Wholebody
# Import-time process configuration.
# NOTE(review): KMP_DUPLICATE_LIB_OK is an OpenMP-runtime environment switch;
# presumably it is set here so that loading more than one OpenMP runtime in the
# same process does not abort -- confirm it is still required.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class DWposeDetector:
    """
    A pose detect method for image-like data.

    The underlying ``Wholebody`` estimator is not created in the constructor;
    it is built on the first call and can be dropped again with
    ``release_memory()``.

    Parameters:
        model_det: (str) serialized ONNX format model path,
            such as https://huggingface.co./yzd-v/DWPose/blob/main/yolox_l.onnx
        model_pose: (str) serialized ONNX format model path,
            such as https://huggingface.co./yzd-v/DWPose/blob/main/dw-ll_ucoco_384.onnx
        device: (str) 'cpu' or 'cuda:{device_id}'
    """

    def __init__(self, model_det, model_pose, device='cpu'):
        # Only remember the arguments; they are forwarded to Wholebody on first use.
        self.args = model_det, model_pose, device

    def release_memory(self):
        """Drop the cached estimator (if one was created) and run the garbage collector."""
        if hasattr(self, 'pose_estimation'):
            del self.pose_estimation
            import gc
            gc.collect()

    def __call__(self, oriImg):
        """Run pose estimation on one image and split the result by body part.

        Parameters:
            oriImg: array-like image with a three-element ``shape``; the first
                two entries are used as height and width. The input is copied
                before being handed to the estimator.

        Returns:
            dict with the keys
                ``bodies``: dict with
                    ``candidate``: keypoints 0..17 of every detection, flattened
                        to ``(nums * 18, locs)``,
                    ``subset``: array shaped like the body scores holding the row
                        index ``18 * i + j`` into ``candidate`` where the score
                        exceeds 0.3 and ``-1`` elsewhere,
                    ``score``: scores of keypoints 0..17,
                ``faces`` / ``faces_score``: keypoints 24..91 and their scores,
                ``hands`` / ``hands_score``: keypoints 92..112 and 113.. stacked
                    along the first axis, and their scores.
            Components 0 and 1 of every keypoint are divided by the image width
            and height respectively.
        """
        # Build the estimator lazily so constructing a detector stays cheap.
        if not hasattr(self, 'pose_estimation'):
            self.pose_estimation = Wholebody(*self.args)

        oriImg = oriImg.copy()
        # Three-way unpack kept on purpose: it rejects inputs that are not 3-D.
        H, W, _ = oriImg.shape
        with torch.no_grad():
            candidate, score = self.pose_estimation(oriImg)

        nums, _, locs = candidate.shape
        # Normalise coordinates in place by the image size.
        candidate[..., 0] /= float(W)
        candidate[..., 1] /= float(H)

        body = candidate[:, :18].copy().reshape(nums * 18, locs)

        # For each body keypoint store its row in the flattened `body` array when
        # the score is above the visibility threshold, otherwise -1. The result
        # keeps the dtype of `score`, as an in-place update of a copy would.
        body_score = score[:, :18]
        rows, cols = body_score.shape
        flat_index = 18 * np.arange(rows)[:, None] + np.arange(cols)
        subset = np.where(body_score > 0.3, flat_index, -1).astype(body_score.dtype)

        # Keypoints 18..23 are not part of the returned structure.
        faces = candidate[:, 24:92]
        hands = np.vstack([candidate[:, 92:113], candidate[:, 113:]])

        faces_score = score[:, 24:92]
        hands_score = np.vstack([score[:, 92:113], score[:, 113:]])

        bodies = dict(candidate=body, subset=subset, score=score[:, :18])
        pose = dict(bodies=bodies, hands=hands, hands_score=hands_score, faces=faces, faces_score=faces_score)
        return pose
# Shared module-level detector created at import time. The constructor only
# stores its arguments, so the model files are not touched until the first call.
# The model paths are relative, so they depend on the working directory.
# NOTE(review): `device` here is a torch.device object, while the class
# docstring describes the parameter as a string ('cpu' or 'cuda:{device_id}')
# -- confirm that Wholebody handles a torch.device as intended.
dwpose_detector = DWposeDetector(
model_det="models/DWPose/yolox_l.onnx",
model_pose="models/DWPose/dw-ll_ucoco_384.onnx",
device=device)