from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import csv
import os
import shutil

from PIL import Image
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision
import cv2
import numpy as np
import time

import _init_paths
import models
from config import cfg
from config import update_config
from core.function import get_final_preds
from utils.transforms import get_affine_transform

import sys, os, distutils.core

# os.system('python -m pip install pyyaml==5.3.1')
# dist = distutils.core.run_setup("./detectron2/setup.py")
# temp = ' '.join([f"'{x}'" for x in dist.install_requires])
# cmd = "python -m pip install {0}".format(temp)
# os.system(cmd)
sys.path.insert(0, os.path.abspath('./detectron2'))

import detectron2
# from detectron2.modeling import build_model
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.data.datasets import register_coco_instances


# NOTE: despite the COCO_* names, these are the 16 MPII-style joints that the
# MPII-trained pose network predicts.
COCO_KEYPOINT_INDEXES = {
    0: 'r ankle',
    1: 'r knee',
    2: 'r hip',
    3: 'l hip',
    4: 'l knee',
    5: 'l ankle',
    6: 'pelvis',
    7: 'thorax',
    8: 'upper neck',
    9: 'head',
    10: 'r wrist',
    11: 'r elbow',
    12: 'r shoulder',
    13: 'l shoulder',
    14: 'l elbow',
    15: 'l wrist',
}

COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

SKELETON = [
    [1, 2], [1, 0], [2, 6], [3, 6], [4, 5], [3, 4], [6, 7], [7, 8], [9, 8], [7, 12],
    [7, 13], [11, 12], [13, 14], [14, 15], [10, 11]
]

CocoColors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]

NUM_KPTS = 16

CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
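
# SKELETON entries are index pairs into COCO_KEYPOINT_INDEXES; draw_pose() below uses
# CocoColors[i] for the i-th limb, so the color list only needs >= len(SKELETON) entries.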


def draw_pose(keypoints, img):
    """Draw the keypoints and the skeleton edges on img.

    :param keypoints: array of shape [NUM_KPTS, 2] with (x, y) in image coordinates
    :param img: BGR image to draw on (modified in place)
    """
    assert keypoints.shape == (NUM_KPTS, 2)
    for i in range(len(SKELETON)):
        kpt_a, kpt_b = SKELETON[i][0], SKELETON[i][1]
        x_a, y_a = keypoints[kpt_a][0], keypoints[kpt_a][1]
        x_b, y_b = keypoints[kpt_b][0], keypoints[kpt_b][1]
        cv2.circle(img, (int(x_a), int(y_a)), 6, CocoColors[i], -1)
        cv2.circle(img, (int(x_b), int(y_b)), 6, CocoColors[i], -1)
        cv2.line(img, (int(x_a), int(y_a)), (int(x_b), int(y_b)), CocoColors[i], 2)


def draw_bbox(box, img):
    """Draw a detected bounding box on the image.

    :param box: box as [x1, y1, x2, y2] (any indexable, e.g. a detectron2 Boxes row)
    :param img: BGR image to draw on (modified in place)
    """
    # cv2.rectangle(img, (int(box[0][0]), int(box[0][1])), (int(box[1][0]), int(box[1][1])), (0, 255, 0), 3)
    cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 255, 0), 3)
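
# Illustrative calls for the drawing helpers (placeholder values, not part of the pipeline):
#   draw_bbox([10, 20, 200, 300], image_bgr)
#   draw_pose(np.zeros((NUM_KPTS, 2)), image_bgr)   # all joints drawn at the origin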


def get_person_detection_boxes(model, img, threshold=0.5):
    """Run a torchvision-style detector and return boxes of the 'person' class."""
    pred = model(img)
    pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i]
                    for i in list(pred[0]['labels'].cpu().numpy())]  # label ids -> class names
    pred_boxes = [[(i[0], i[1]), (i[2], i[3])]
                  for i in list(pred[0]['boxes'].detach().cpu().numpy())]  # [(x1, y1), (x2, y2)]
    pred_score = list(pred[0]['scores'].detach().cpu().numpy())
    if not pred_score or max(pred_score) < threshold:
        print("pred_score didn't make threshold")
        return []
    print(max(pred_score))
    # Scores come back sorted in descending order, so keep everything up to the
    # last index whose score is still above the threshold.
    pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
    pred_boxes = pred_boxes[:pred_t + 1]
    pred_classes = pred_classes[:pred_t + 1]
    print('pred_boxes', pred_boxes)
    person_boxes = []
    for idx, box in enumerate(pred_boxes):
        if pred_classes[idx] == 'person':
            person_boxes.append(box)
    print("person_boxes", person_boxes)
    return person_boxes
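
# Usage sketch for the detector path above (uses the torchvision detector that the
# `box_model` lines in main() set up; illustrative only):
#   box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True).to(CTX).eval()
#   rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
#   inputs = [torch.from_numpy(rgb / 255.).permute(2, 0, 1).float().to(CTX)]
#   person_boxes = get_person_detection_boxes(box_model, inputs, threshold=0.5)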


def get_pose_estimation_prediction(pose_model, image, center, scale):
    rotation = 0

    # pose estimation transformation
    trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    print("scale:", scale)
    print("center:", center)
    print("trans:", trans)
    # Warp the person crop to the network input size from the config; the affine
    # transform above was computed for exactly this output size.
    model_input = cv2.warpAffine(
        image,
        trans,
        (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
        flags=cv2.INTER_LINEAR)

    # pose estimation inference
    model_input = transform(model_input).unsqueeze(0)

    # switch to evaluate mode
    pose_model.eval()
    with torch.no_grad():
        # compute output heatmap
        output = pose_model(model_input)

    preds, _ = get_final_preds(
        cfg,
        output.clone().cpu().numpy(),
        np.asarray([center]),
        np.asarray([scale]))

    return preds
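
# Note: get_final_preds maps the heatmap peaks back to original-image coordinates,
# so the returned array has shape (1, NUM_KPTS, 2) for a single crop and each row
# can be handed straight to draw_pose().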


def box_to_center_scale(box, model_image_width, model_image_height):
    """Convert a box to the center/scale information required for the pose transform.

    Parameters
    ----------
    box : a detectron2 Boxes row (tensor of 4 values) or any indexable
        [x1, y1, x2, y2] / [(x1, y1), (x2, y2)] in pixel coordinates
    model_image_width : int
    model_image_height : int

    Returns
    -------
    (numpy array, numpy array)
        Two numpy arrays: the coordinates of the box center and the scale of the box
    """
    center = np.zeros((2), dtype=np.float32)

    # Accept either a tensor (detectron2) or plain corner values/tuples
    # (the torchvision detector path returns the latter).
    if torch.is_tensor(box):
        box = box.detach().cpu().numpy()
    box = np.asarray(box, dtype=np.float32).reshape(-1)

    top_left_corner = (box[0], box[1])
    bottom_right_corner = (box[2], box[3])
    box_width = bottom_right_corner[0] - top_left_corner[0]
    box_height = bottom_right_corner[1] - top_left_corner[1]
    center[0] = top_left_corner[0] + box_width * 0.5
    center[1] = top_left_corner[1] + box_height * 0.5

    aspect_ratio = model_image_width * 1.0 / model_image_height
    pixel_std = 200

    if box_width > aspect_ratio * box_height:
        box_height = box_width * 1.0 / aspect_ratio
    elif box_width < aspect_ratio * box_height:
        box_width = box_height * aspect_ratio
    scale = np.array(
        [box_width * 1.0 / pixel_std, box_height * 1.0 / pixel_std],
        dtype=np.float32)
    if center[0] != -1:
        scale = scale * 1.25

    return center, scale
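
# Worked example (illustrative numbers): with a 256x256 model input (aspect_ratio = 1.0),
# a 100x300 px box is first widened to 300x300 to match the model aspect ratio, giving
# scale = [300/200, 300/200] = [1.5, 1.5], then padded by the 1.25 factor to
# [1.875, 1.875]; the center stays at the middle of the original box.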


def parse_args():
    parser = argparse.ArgumentParser(description='Pose estimation demo')
    # general
    parser.add_argument('--cfg', type=str, default='demo/inference-config.yaml')
    parser.add_argument('--video', type=str)
    parser.add_argument('--webcam', action='store_true')
    parser.add_argument('--image', type=str)
    parser.add_argument('--write', action='store_true')
    parser.add_argument('--showFps', action='store_true')

    parser.add_argument('opts',
                        help='Modify config options using the command-line',
                        default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()

    # args expected by supporting codebase
    args.modelDir = ''
    args.logDir = ''
    args.dataDir = ''
    args.prevModelDir = ''
    return args
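
# Example invocation (script and checkpoint paths are placeholders; trailing key/value
# pairs are passed through `opts` to update_config, which is the usual way to point
# TEST.MODEL_FILE at the pose checkpoint):
#   python inference.py --cfg demo/inference-config.yaml --image frames/0001.jpg --write \
#       TEST.MODEL_FILE models/pose_hrnet_w32_256x256.pth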


def main():
    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    args = parse_args()
    update_config(cfg, args)

    # torchvision person detector; the webcam/video branch below relies on it
    box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    box_model.to(CTX)
    box_model.eval()

    # detectron2 detector fine-tuned for the diver class (used for the single-image branch)
    cfgg = get_cfg()
    cfgg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    cfgg.OUTPUT_DIR = "./output/diver/"
    cfgg.MODEL.WEIGHTS = os.path.join(cfgg.OUTPUT_DIR, "model_final.pth")  # path to the model we just trained
    cfgg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7  # set a custom testing threshold
    cfgg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # the "RoIHead batch size"; 128 is faster and good enough for this small dataset (default: 512)
    cfgg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # only one class (diver); see https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets

    predictor = DefaultPredictor(cfgg)
    # register_coco_instances("diver_vals", {}, "./coco_annotations/diver/val.json", "../data/ExPose/Olympics2012_Diving_2570")

    # splash_metadata = MetadataCatalog.get('splash_vals')
    # dataset_dicts = DatasetCatalog.get("splash_vals")

    pose_model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(
        cfg, is_train=False
    )

    if cfg.TEST.MODEL_FILE:
        print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)
    else:
        print('expected model defined in config at TEST.MODEL_FILE')

    pose_model = torch.nn.DataParallel(pose_model, device_ids=cfg.GPUS)
    pose_model.to(CTX)
    pose_model.eval()

    # Load a video, an image, or the webcam
    if args.webcam:
        vidcap = cv2.VideoCapture(0)
    elif args.video:
        vidcap = cv2.VideoCapture(args.video)
    elif args.image:
        image_bgr = cv2.imread(args.image)
    else:
        print('please use --video or --webcam or --image to define the input.')
        return

    if args.webcam or args.video:
        if args.write:
            save_path = 'output.avi'
            fourcc = cv2.VideoWriter_fourcc(*'XVID')
            out = cv2.VideoWriter(save_path, fourcc, 24.0, (int(vidcap.get(3)), int(vidcap.get(4))))
        while True:
            ret, image_bgr = vidcap.read()
            if ret:
                last_time = time.time()
                image = image_bgr[:, :, [2, 1, 0]]

                inputs = []
                img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
                img_tensor = torch.from_numpy(img / 255.).permute(2, 0, 1).float().to(CTX)
                inputs.append(img_tensor)

                # object detection box
                pred_boxes = get_person_detection_boxes(box_model, inputs, threshold=0.5)

                # pose estimation
                if len(pred_boxes) >= 1:
                    for box in pred_boxes:
                        center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
                        image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
                        pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
                        if len(pose_preds) >= 1:
                            for kpt in pose_preds:
                                draw_pose(kpt, image_bgr)  # draw the poses

                if args.showFps:
                    fps = 1 / (time.time() - last_time)
                    img = cv2.putText(image_bgr, 'fps: ' + "%.2f" % (fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2)

                if args.write:
                    out.write(image_bgr)

                cv2.imshow('demo', image_bgr)
                if cv2.waitKey(1) & 0XFF == ord('q'):
                    break
            else:
                print('cannot load the video.')
                break

        cv2.destroyAllWindows()
        vidcap.release()
        if args.write:
            print('video has been saved as {}'.format(save_path))
            out.release()

    else:
        # estimate on the image
        last_time = time.time()
        image = image_bgr[:, :, [2, 1, 0]]

        # object detection box (the detectron2 predictor consumes the raw BGR image,
        # so no tensor preprocessing is needed in this branch)
        # pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.5)
        outputs = predictor(image_bgr)
        scores = outputs['instances'].scores
        pred_boxes = []
        if len(scores) > 0:
            pred_boxes = outputs['instances'].pred_boxes
            # max_instance = torch.argmax(scores)
            # pred_boxes = pred_boxes[max_instance]
        print("pred_boxes", pred_boxes)

        # pose estimation
        if len(pred_boxes) >= 1:
            for box in pred_boxes:
                center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
                # center, scale = box_to_center_scale(box, 360, 640)
                image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
                pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
                print("pose_preds", pose_preds)
                draw_bbox(box, image_bgr)
                if len(pose_preds) >= 1:
                    print('drawing preds')
                    for kpt in pose_preds:
                        draw_pose(kpt, image_bgr)  # draw the poses
                break  # only use the first box, i.e. the one with the highest confidence score

        if args.showFps:
            fps = 1 / (time.time() - last_time)
            img = cv2.putText(image_bgr, 'fps: ' + "%.2f" % (fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2)

        if args.write:
            out_folder = './output/pose-estimator/FINAWorldChampionships2019_Women10m_final_r1_0'
            save_path = '{}/{}'.format(out_folder, args.image.split('/')[-1])
            if not os.path.exists(out_folder):
                os.makedirs(out_folder)
            cv2.imwrite(save_path, image_bgr)
            print('the result image has been saved as {}'.format(save_path))

        # cv2.imshow('demo', image_bgr)
        # if cv2.waitKey(0) & 0XFF == ord('q'):
        #     cv2.destroyAllWindows()


if __name__ == '__main__':
    main()