from __future__ import absolute_import |
from __future__ import division |
from __future__ import print_function |
import argparse |
import csv |
import os |
import shutil |
from PIL import Image |
import torch |
import torch.nn.parallel |
import torch.backends.cudnn as cudnn |
import torch.optim |
import torch.utils.data |
import torch.utils.data.distributed |
import torchvision.transforms as transforms |
import torchvision |
import cv2 |
import numpy as np |
import sys |
sys.path.append("../lib") |
import time |
import models |
from config import cfg |
from config import update_config |
from core.inference import get_final_preds |
from utils.transforms import get_affine_transform |
# Device that both models and their inputs are moved to: the current CUDA
# device when one is available, otherwise the CPU.
CTX = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Index -> name of the 17 COCO person keypoints. main() iterates the values in
# this order to build the "<name>_x", "<name>_y" CSV header columns.
COCO_KEYPOINT_INDEXES = {
    0: 'nose',
    1: 'left_eye',
    2: 'right_eye',
    3: 'left_ear',
    4: 'right_ear',
    5: 'left_shoulder',
    6: 'right_shoulder',
    7: 'left_elbow',
    8: 'right_elbow',
    9: 'left_wrist',
    10: 'right_wrist',
    11: 'left_hip',
    12: 'right_hip',
    13: 'left_knee',
    14: 'right_knee',
    15: 'left_ankle',
    16: 'right_ankle'
}
# Category name for each detector label id (list index == label id).
# Index 0 is the background class and 'N/A' fills ids with no category.
# The person detector keeps only boxes whose label maps to 'person'.
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
def get_person_detection_boxes(model, img, threshold=0.5):
    """Run the object detector on one image and keep confident person boxes.

    Parameters
    ----------
    model : torch.nn.Module
        Detection model called with a list of image tensors; the 'labels',
        'boxes' and 'scores' entries of its first output dict are read.
    img : numpy.ndarray
        Image array accepted by ``PIL.Image.fromarray``. It is not modified.
    threshold : float
        A box is kept only if its score is strictly greater than this value.

    Returns
    -------
    list
        One ``[(x1, y1), (x2, y2)]`` corner pair per kept person detection.
        The coordinates are numpy scalars taken from the detector's
        ``boxes`` output.
    """
    pil_image = Image.fromarray(img)
    transform = transforms.Compose([transforms.ToTensor()])
    transformed_img = transform(pil_image)

    # Pure inference: disable autograd so no graph/activations are retained.
    with torch.no_grad():
        pred = model([transformed_img.to(CTX)])
    detections = pred[0]

    # Translate numeric label ids into category names.
    pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i]
                    for i in detections['labels'].cpu().numpy()]
    pred_boxes = [[(i[0], i[1]), (i[2], i[3])]
                  for i in detections['boxes'].cpu().numpy()]
    pred_scores = detections['scores'].cpu().numpy()

    # Keep only boxes that are both confident enough and labelled 'person'.
    return [
        pred_box
        for pred_class, pred_box, pred_score in zip(pred_classes, pred_boxes, pred_scores)
        if pred_score > threshold and pred_class == 'person'
    ]
def get_pose_estimation_prediction(pose_model, image, centers, scales, transform):
    """Estimate keypoints for every person region described by (center, scale).

    Each (center, scale) pair is turned into an affine warp that crops the
    region out of ``image`` at the model input size ``cfg.MODEL.IMAGE_SIZE``.
    All crops are run through ``pose_model`` as a single batch. The raw
    output is decoded back to image coordinates by ``get_final_preds``.

    Parameters
    ----------
    pose_model : torch.nn.Module
        Pose network. It is called once on the stacked batch of crops.
    image : numpy.ndarray
        Full frame to crop from. It is only read, never modified.
    centers, scales : sequence of numpy.ndarray
        Per-person values as produced by ``box_to_center_scale``. These must
        be non-empty, because ``torch.stack`` rejects an empty list.
    transform : callable
        Converts one cropped image into a tensor (for example ToTensor followed
        by Normalize).

    Returns
    -------
    numpy.ndarray
        The coordinates returned by ``get_final_preds``; its second return
        value is discarded. Presumably the shape is (num_people, num_joints, 2);
        TODO confirm against core.inference.
    """
    rotation = 0

    # Crop and normalise each person region at the model's input resolution.
    model_inputs = []
    for center, scale in zip(centers, scales):
        trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
        model_input = cv2.warpAffine(
            image,
            trans,
            (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
            flags=cv2.INTER_LINEAR)
        model_inputs.append(transform(model_input))
    model_inputs = torch.stack(model_inputs)

    # Inference only: avoid building an autograd graph for the forward pass.
    with torch.no_grad():
        output = pose_model(model_inputs.to(CTX))

    coords, _ = get_final_preds(
        cfg,
        output.cpu().numpy(),
        np.asarray(centers),
        np.asarray(scales))
    return coords
def box_to_center_scale(box, model_image_width, model_image_height):
    """Turn a two-corner box into the (center, scale) pair used for cropping.

    The box is enlarged along one side so that its width/height ratio matches
    the model input ratio ``model_image_width / model_image_height``. Its size
    is then expressed in units of 200 pixels and enlarged by 25%. The 25%
    enlargement is skipped when the computed center x coordinate is exactly -1.

    Parameters
    ----------
    box : sequence of two (x, y) pairs
        ``box[0]`` is the corner with the smaller coordinates and ``box[1]``
        the opposite corner. Width and height are computed as
        ``box[1] - box[0]``.
    model_image_width : int
        Width of the pose model input.
    model_image_height : int
        Height of the pose model input.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        float32 arrays ``center = [cx, cy]`` and ``scale = [sx, sy]``.
    """
    x_min, y_min = box[0][0], box[0][1]
    x_max, y_max = box[1][0], box[1][1]
    width = x_max - x_min
    height = y_max - y_min

    center = np.array([x_min + width * 0.5, y_min + height * 0.5],
                      dtype=np.float32)

    # Grow the shorter side (relative to the target ratio) to match it.
    target_ratio = model_image_width * 1.0 / model_image_height
    if width > target_ratio * height:
        height = width * 1.0 / target_ratio
    elif width < target_ratio * height:
        width = height * target_ratio

    unit = 200  # the scale is expressed in multiples of this many pixels
    scale = np.array([width * 1.0 / unit, height * 1.0 / unit],
                     dtype=np.float32)

    if center[0] == -1:
        return center, scale
    return center, scale * 1.25
def prepare_output_dirs(prefix='/output/'):
    """Return ``<prefix>/pose`` as a freshly created, empty directory.

    If that path is already a directory, it is deleted with all of its
    contents. The directory is then (re)created, along with any missing
    parent directories.
    """
    target = os.path.join(prefix, "pose")
    # os.path.isdir is False for missing paths, so no separate exists() check.
    if os.path.isdir(target):
        shutil.rmtree(target)
    os.makedirs(target, exist_ok=True)
    return target
def parse_args():
    """Parse the command-line options for the video pose-estimation demo.

    The returned namespace also gets ``modelDir``, ``logDir``, ``dataDir`` and
    ``prevModelDir`` set to empty strings. Presumably ``update_config``
    expects these attributes to exist (TODO: confirm in the config module).

    Exits with a usage error when ``--inferenceFps`` is not positive.
    """
    parser = argparse.ArgumentParser(
        description='Run person detection and keypoint estimation on a video')
    parser.add_argument('--cfg', type=str, required=True,
                        help='experiment configuration file')
    parser.add_argument('--videoFile', type=str, required=True,
                        help='input video to process')
    parser.add_argument('--outputDir', type=str, default='/output/',
                        help='directory for annotated frames, video and CSV')
    parser.add_argument('--inferenceFps', type=int, default=10,
                        help='how many frames per second of video to process')
    parser.add_argument('--writeBoxFrames', action='store_true',
                        help='draw detected person boxes on the output frames')
    parser.add_argument('opts',
                        help='Modify config options using the command-line',
                        default=None,
                        nargs=argparse.REMAINDER)
    args = parser.parse_args()

    # main() divides the video fps by this value to get the frame stride.
    if args.inferenceFps <= 0:
        parser.error('--inferenceFps must be a positive integer')

    args.modelDir = ''
    args.logDir = ''
    args.dataDir = ''
    args.prevModelDir = ''
    return args
def _load_pose_model():
    """Build the pose network named by cfg.MODEL.NAME, load weights, set eval mode."""
    # Resolve models.<cfg.MODEL.NAME>.get_pose_net via attribute lookup instead
    # of eval(), so a config value cannot execute arbitrary code.
    module = models
    for part in cfg.MODEL.NAME.split('.'):
        module = getattr(module, part)
    pose_model = module.get_pose_net(cfg, is_train=False)

    if cfg.TEST.MODEL_FILE:
        print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        state_dict = torch.load(cfg.TEST.MODEL_FILE, map_location=CTX)
        # strict=False: checkpoint keys that do not match are not an error.
        pose_model.load_state_dict(state_dict, strict=False)
    else:
        print('expected model defined in config at TEST.MODEL_FILE')

    pose_model.to(CTX)
    pose_model.eval()
    return pose_model


def _write_pose_csv(csv_path, rows):
    """Write the keypoint rows to csv_path under a 'frame' + '<kp>_x/_y' header."""
    csv_headers = ['frame']
    for keypoint in COCO_KEYPOINT_INDEXES.values():
        csv_headers.extend([keypoint + '_x', keypoint + '_y'])
    with open(csv_path, 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(csv_headers)
        csvwriter.writerows(rows)


def main():
    """Detect people and their keypoints in a video and save the results.

    Every ``round(video_fps / --inferenceFps)``-th frame is processed. For each
    processed frame that contains at least one person, this function:

    * writes an annotated JPEG to ``<outputDir>/pose/``;
    * appends the frame to ``<outputDir>/<video name>_pose.avi``;
    * adds one row to ``<outputDir>/pose-data.csv``.

    Pressing ``q`` in the preview window stops early.
    """
    pose_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    args = parse_args()
    update_config(cfg, args)

    # Read the cudnn flags only after update_config; otherwise the values from
    # the --cfg file would be ignored in favour of the defaults.
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    pose_dir = prepare_output_dirs(args.outputDir)
    csv_output_rows = []

    box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    box_model.to(CTX)
    box_model.eval()

    pose_model = _load_pose_model()

    vidcap = cv2.VideoCapture(args.videoFile)
    if not vidcap.isOpened():
        print('could not open video file ' + args.videoFile)
        sys.exit(1)
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    if fps < args.inferenceFps:
        print('desired inference fps is '+str(args.inferenceFps)+' but video fps is '+str(fps))
        vidcap.release()
        sys.exit(1)
    skip_frame_cnt = round(fps / args.inferenceFps)

    frame_width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    video_name = os.path.splitext(os.path.basename(args.videoFile))[0]
    outcap = cv2.VideoWriter(
        os.path.join(args.outputDir, '{}_pose.avi'.format(video_name)),
        cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
        # Only every skip_frame_cnt-th frame is written, so play the output
        # back at the sampled rate rather than at skip_frame_cnt fps.
        fps / skip_frame_cnt,
        (frame_width, frame_height))

    count = 0  # 1-based index of the last frame read from the video
    try:
        while vidcap.isOpened():
            total_now = time.time()
            ret, image_bgr = vidcap.read()
            if not ret:
                # End of stream or read failure. isOpened() stays True after
                # the last frame, so 'continue' here would spin forever.
                break
            count += 1
            if count % skip_frame_cnt != 0:
                continue

            # Feed RGB or BGR to the models depending on the dataset config.
            # Both models only read this frame, so no copy is needed.
            if cfg.DATASET.COLOR_RGB:
                model_image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
            else:
                model_image = image_bgr
            # Separate copy that the annotations are drawn on.
            image_debug = image_bgr.copy()

            now = time.time()
            pred_boxes = get_person_detection_boxes(box_model, model_image, threshold=0.9)
            then = time.time()
            print("Find person bbox in: {} sec".format(then - now))

            if not pred_boxes:
                # Nobody in this frame; count has already been advanced above.
                continue

            if args.writeBoxFrames:
                for box in pred_boxes:
                    # OpenCV drawing functions require integer pixel coordinates.
                    top_left = tuple(int(v) for v in box[0])
                    bottom_right = tuple(int(v) for v in box[1])
                    cv2.rectangle(image_debug, top_left, bottom_right,
                                  color=(0, 255, 0), thickness=3)

            centers = []
            scales = []
            for box in pred_boxes:
                center, scale = box_to_center_scale(
                    box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
                centers.append(center)
                scales.append(scale)

            now = time.time()
            pose_preds = get_pose_estimation_prediction(
                pose_model, model_image, centers, scales, transform=pose_transform)
            then = time.time()
            print("Find person pose in: {} sec".format(then - now))

            # The first column is the frame number, matching the 'frame' header.
            # NOTE(review): the keypoints of every detected person are appended to
            # the same row, but the header only describes one person.
            new_csv_row = [count]
            for coords in pose_preds:
                for coord in coords:
                    x_coord, y_coord = int(coord[0]), int(coord[1])
                    cv2.circle(image_debug, (x_coord, y_coord), 4, (255, 0, 0), 2)
                    new_csv_row.extend([x_coord, y_coord])

            total_then = time.time()
            text = "{:03.2f} sec".format(total_then - total_now)
            cv2.putText(image_debug, text, (100, 50), cv2.FONT_HERSHEY_SIMPLEX,
                        1, (0, 0, 255), 2, cv2.LINE_AA)

            cv2.imshow("pos", image_debug)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

            csv_output_rows.append(new_csv_row)
            img_file = os.path.join(pose_dir, 'pose_{:08d}.jpg'.format(count))
            cv2.imwrite(img_file, image_debug)
            outcap.write(image_debug)
    finally:
        # Always release the capture/writer and close windows, even on error.
        vidcap.release()
        outcap.release()
        cv2.destroyAllWindows()

    _write_pose_csv(os.path.join(args.outputDir, 'pose-data.csv'), csv_output_rows)


if __name__ == '__main__':
    main()