import argparse |
import numpy as np |
import matplotlib.pyplot as plt |
import cv2 |
import json |
import matplotlib.lines as mlines |
import matplotlib.patches as mpatches |
from pycocotools.coco import COCO |
from pycocotools.cocoeval import COCOeval |
import os |
class ColorStyle:
    """Limb and joint colours used when drawing a skeleton.

    Args:
        color: Sequence of RGB triples in the 0-255 range; the i-th triple
            is attached to the i-th entry of ``link_pairs``.
        link_pairs: Sequence of ``[joint_a, joint_b]`` index pairs.
        point_color: Sequence of RGB triples in the 0-255 range for the
            joint markers.

    After construction the instance exposes:
        * ``color`` and ``point_color`` - the arguments, stored as given.
        * ``link_pairs`` - copies of the given pairs; for every index
          covered by ``color`` the pair has a third element, the colour
          scaled to the 0-1 range as a tuple.
        * ``ring_color`` - ``point_color`` scaled to the 0-1 range, as a
          list of tuples.

    Raises:
        IndexError: If ``color`` has more entries than ``link_pairs``.
    """

    def __init__(self, color, link_pairs, point_color):
        self.color = color
        # Work on copies: appending the colour to the caller's own lists
        # would stack a second colour onto every pair as soon as the same
        # ``link_pairs`` object is reused for another style.
        self.link_pairs = [list(pair) for pair in link_pairs]
        self.point_color = point_color

        for index, rgb in enumerate(self.color):
            self.link_pairs[index].append(self._to_unit_rgb(rgb))

        self.ring_color = [self._to_unit_rgb(rgb) for rgb in self.point_color]

    @staticmethod
    def _to_unit_rgb(rgb):
        """Scale one 0-255 colour triple to a tuple of 0-1 values."""
        return tuple(np.array(rgb) / 255.)
# "Xiaochu" palette.
#
# color1 holds one RGB triple (0-255) per entry of link_pairs1, in the same
# order. The joint indices presumably follow the COCO 17-keypoint layout
# (the script visualises COCO person keypoints) - confirm before relying on
# the body-part labels below.
color1 = [
    (179, 0, 0), (228, 26, 28), (255, 255, 51),        # links 0-2
    (49, 163, 84), (0, 109, 45), (255, 255, 51),       # links 3-5
    (240, 2, 127), (240, 2, 127), (240, 2, 127),       # links 6-8
    (240, 2, 127), (240, 2, 127),                      # links 9-10
    (217, 95, 14), (254, 153, 41), (255, 255, 51),     # links 11-13
    (44, 127, 184), (0, 0, 255),                       # links 14-15
]

# Joint index pairs that are connected by a line. ColorStyle attaches the
# matching colour to each pair, so these must stay mutable lists.
link_pairs1 = [
    [15, 13], [13, 11], [11, 5],                       # one leg up to joint 5
    [12, 14], [14, 16], [12, 6],                       # other leg up to joint 6
    [3, 1], [1, 2], [1, 0], [0, 2], [2, 4],            # links among joints 0-4
    [9, 7], [7, 5], [5, 6],                            # one arm plus 5-6 link
    [6, 8], [8, 10],                                   # other arm
]

# Marker colours, looked up by joint index. The list has more entries than
# the 17 joints the plotting code reshapes detections to; the surplus
# entries appear unused in this script.
point_color1 = [
    (240, 2, 127), (240, 2, 127), (240, 2, 127),
    (240, 2, 127), (240, 2, 127),
    (255, 255, 51), (255, 255, 51),
    (254, 153, 41), (44, 127, 184),
    (217, 95, 14), (0, 0, 255),
    (255, 255, 51), (255, 255, 51), (228, 26, 28),
    (49, 163, 84), (252, 176, 243), (0, 176, 240),
    (255, 255, 0), (169, 209, 142),
    (255, 255, 0), (169, 209, 142),
    (255, 255, 0), (169, 209, 142),
]

xiaochu_style = ColorStyle(color1, link_pairs1, point_color1)
# "Chunhua" palette (the script's default style).
#
# color2 holds one RGB triple (0-255) per entry of link_pairs2, in the same
# order. The joint indices presumably follow the COCO 17-keypoint layout -
# confirm before relying on the body-part labels below.
color2 = [
    (252, 176, 243), (252, 176, 243), (252, 176, 243),  # links 0-2
    (0, 176, 240), (0, 176, 240), (0, 176, 240),        # links 3-5
    (240, 2, 127), (240, 2, 127), (240, 2, 127),        # links 6-8
    (240, 2, 127), (240, 2, 127),                       # links 9-10
    (255, 255, 0), (255, 255, 0), (169, 209, 142),      # links 11-13
    (169, 209, 142), (169, 209, 142),                   # links 14-15
]

# Joint index pairs that are connected by a line. ColorStyle attaches the
# matching colour to each pair, so these must stay mutable lists.
link_pairs2 = [
    [15, 13], [13, 11], [11, 5],                        # one leg up to joint 5
    [12, 14], [14, 16], [12, 6],                        # other leg up to joint 6
    [3, 1], [1, 2], [1, 0], [0, 2], [2, 4],             # links among joints 0-4
    [9, 7], [7, 5], [5, 6],                             # one arm plus 5-6 link
    [6, 8], [8, 10],                                    # other arm
]

# Marker colours, looked up by joint index. The list has more entries than
# the 17 joints the plotting code reshapes detections to; the surplus
# entries appear unused in this script.
point_color2 = [
    (240, 2, 127), (240, 2, 127), (240, 2, 127),
    (240, 2, 127), (240, 2, 127),
    (255, 255, 0), (169, 209, 142),
    (255, 255, 0), (169, 209, 142),
    (255, 255, 0), (169, 209, 142),
    (252, 176, 243), (0, 176, 240), (252, 176, 243),
    (0, 176, 240), (252, 176, 243), (0, 176, 240),
    (255, 255, 0), (169, 209, 142),
    (255, 255, 0), (169, 209, 142),
    (255, 255, 0), (169, 209, 142),
]

chunhua_style = ColorStyle(color2, link_pairs2, point_color2)
def parse_args(argv=None):
    """Parse the command-line options of the visualisation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the
            default) ``argparse`` reads ``sys.argv[1:]``, so existing
            ``parse_args()`` calls behave as before.

    Returns:
        argparse.Namespace with ``image_path``, ``gt_anno``, ``save_path``,
        ``prediction`` and ``style``.

    Raises:
        SystemExit: Raised by ``argparse`` when ``--prediction`` is missing
            or ``--style`` is not one of the supported names.
    """
    parser = argparse.ArgumentParser(description='Visualize COCO predictions')
    parser.add_argument('--image-path',
                        help='Path of COCO val images',
                        type=str,
                        default='data/coco/images/val2017/')
    parser.add_argument('--gt-anno',
                        help='Path of COCO val annotation',
                        type=str,
                        default='data/coco/annotations/person_keypoints_val2017.json')
    parser.add_argument('--save-path',
                        help="Path to save the visualizations",
                        type=str,
                        default='visualization/coco/')
    parser.add_argument('--prediction',
                        help="Prediction file to visualize",
                        type=str,
                        required=True)
    # Reject unknown styles while parsing, so a typo is reported with the
    # usage message instead of surfacing later in the run.
    parser.add_argument('--style',
                        help="Style of the visualization: Chunhua style or Xiaochu style",
                        type=str,
                        choices=('chunhua', 'xiaochu'),
                        default='chunhua')
    return parser.parse_args(argv)
def map_joint_dict(joints):
    """Map each joint's row index to its integer ``(x, y)`` position.

    Args:
        joints: Sequence of rows (a 2-D array or a list of lists) whose
            first two elements are the x and y coordinates. Any further
            elements in a row are ignored.

    Returns:
        dict of ``{row_index: (int(x), int(y))}``; the coordinates are
        truncated towards zero by ``int``. An empty input gives ``{}``.
    """
    return {
        index: (int(joint[0]), int(joint[1]))
        for index, joint in enumerate(joints)
    }
def _expand_bbox(bbox):
    """Return corners ``(x0, y0, x1, y1)`` derived from ``bbox = [x, y, w, h]``.

    The corners are ``x - w``, ``y - h``, ``x + 2 * w`` and ``y + 2 * h``.
    """
    x, y, box_w, box_h = bbox[0], bbox[1], bbox[2], bbox[3]
    return x - box_w, y - box_h, x + box_w * 2, y + box_h * 2


def _overlap_ratio(box_a, box_b):
    """Return overlap area divided by the area of the box enclosing both.

    Both arguments are ``(x0, y0, x1, y1)`` corner tuples. Note that the
    denominator is the enclosing rectangle, not the union of the two boxes.
    """
    ax0, ay0, ax1, ay1 = box_a
    bx0, by0, bx1, by1 = box_b
    # Clamp each extent at zero: for boxes that are disjoint on both axes
    # the two raw extents are negative and their product would look like a
    # positive overlap.
    ol_x = max(0, min(ax1, bx1) - max(ax0, bx0))
    ol_y = max(0, min(ay1, by1) - max(ay0, by0))
    s_x = max(ax1, bx1) - min(ax0, bx0)
    s_y = max(ay1, by1) - min(ay0, by0)
    # np.spacing(1) keeps the division defined for degenerate boxes.
    return (ol_x * ol_y) / (s_x * s_y + np.spacing(1))


def _draw_pose(ax, dt_joints, vg, link_pairs, ring_color, ref,
               joint_thres, img_w, img_h):
    """Add the limb lines and joint circles of one detection to ``ax``.

    ``dt_joints`` holds one row ``(x, y, score)`` per joint, ``vg`` holds
    the third value of every ground-truth keypoint triple, and ``ref`` is
    the size the line widths and circle radii are scaled by.
    """
    joints_dict = map_joint_dict(dt_joints)

    # Limbs.
    for k, link_pair in enumerate(link_pairs):
        first, second = link_pair[0], link_pair[1]
        if first not in joints_dict or second not in joints_dict:
            continue
        # Skip a limb when either end scores below the threshold or has a
        # zero ground-truth flag.
        if dt_joints[first, 2] < joint_thres \
                or dt_joints[second, 2] < joint_thres \
                or vg[first] == 0 \
                or vg[second] == 0:
            continue
        # Links 6-10 get a fixed thin line; the others scale with the box.
        lw = 1 if k in range(6, 11) else ref / 100.
        line = mlines.Line2D(
            np.array([joints_dict[first][0], joints_dict[second][0]]),
            np.array([joints_dict[first][1], joints_dict[second][1]]),
            ls='-', lw=lw, alpha=1, color=link_pair[2])
        line.set_zorder(0)
        ax.add_line(line)

    # Joint markers.
    for k in range(dt_joints.shape[0]):
        # Test this joint's own ground-truth flag. Reusing the variable
        # left over from the limb loop would tie every circle to the last
        # limb's two joints.
        if dt_joints[k, 2] < joint_thres or vg[k] == 0:
            continue
        if dt_joints[k, 0] > img_w or dt_joints[k, 1] > img_h:
            continue
        # Joints 0-4 get a fixed small marker; the others scale with the box.
        radius = 1 if k in range(5) else ref / 100
        circle = mpatches.Circle(tuple(dt_joints[k, :2]),
                                 radius=radius,
                                 ec='black',
                                 fc=ring_color[k],
                                 alpha=1,
                                 linewidth=1)
        circle.set_zorder(1)
        ax.add_patch(circle)


def plot(data, gt_file, img_path, save_path,
         link_pairs, ring_color, save=True):
    """Draw predicted skeletons over their images and optionally save them.

    For each of the first 5000 image ids and each category, the detections
    are sorted by descending score and capped at the largest ``maxDets``
    value. Every detection whose score is at least 0.3 and whose enlarged
    box overlaps an enlarged ground-truth box by at least 0.1 (see
    ``_overlap_ratio``) is drawn. Images with no ground truth or no
    detections are skipped.

    Args:
        data: Detection results, passed unchanged to ``COCO.loadRes``.
        gt_file: Ground-truth annotation file, passed to ``COCO``.
        img_path: Directory holding ``<12-digit zero-padded id>.jpg`` images.
        save_path: Directory the output files are written to.
        link_pairs: Entries of ``[joint_a, joint_b, colour]``.
        ring_color: Marker fill colours, indexed by joint.
        save: When true, write one PNG and one PDF per processed image.

    Raises:
        FileNotFoundError: If an image cannot be read by ``cv2.imread``.
    """
    coco = COCO(gt_file)
    coco_dt = coco.loadRes(data)
    coco_eval = COCOeval(coco, coco_dt, 'keypoints')
    coco_eval._prepare()
    gts_ = coco_eval._gts
    dts_ = coco_eval._dts

    p = coco_eval.params
    p.imgIds = list(np.unique(p.imgIds))
    if p.useCats:
        p.catIds = list(np.unique(p.catIds))
    p.maxDets = sorted(p.maxDets)

    catIds = p.catIds if p.useCats else [-1]
    threshold = 0.3      # minimum detection score
    joint_thres = 0.2    # minimum per-joint score
    for catId in catIds:
        for imgId in p.imgIds[:5000]:
            gts = gts_[imgId, catId]
            dts = dts_[imgId, catId]
            # Stable sort by descending score, then keep the top maxDets.
            inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
            dts = [dts[i] for i in inds][:p.maxDets[-1]]
            if len(gts) == 0 or len(dts) == 0:
                continue

            sum_score = 0
            num_box = 0
            img_name = str(imgId).zfill(12)

            img_file = os.path.join(img_path, img_name + '.jpg')
            data_numpy = cv2.imread(
                img_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
            if data_numpy is None:
                # cv2.imread signals failure by returning None.
                raise FileNotFoundError(
                    'Cannot read image {}'.format(img_file))
            h, w = data_numpy.shape[0], data_numpy.shape[1]

            # One figure pixel per image pixel at dpi=100.
            fig = plt.figure(figsize=(w / 100, h / 100), dpi=100)
            ax = plt.subplot(1, 1, 1)
            # Reverse the channel axis (OpenCV loads BGR) for display.
            bk = plt.imshow(data_numpy[:, :, ::-1])
            bk.set_zorder(-1)
            print(img_name)

            for gt in gts:
                gt_box = _expand_bbox(gt['bbox'])
                # Every third value of the flat keypoint list.
                vg = np.array(gt['keypoints'])[2::3]

                for dt in dts:
                    dt_box = _expand_bbox(dt['bbox'])
                    iou = _overlap_ratio(gt_box, dt_box)
                    score = dt['score']
                    if iou < 0.1 or score < threshold:
                        continue

                    print('iou: ', iou)
                    # Shorter side of the enlarged detection box.
                    ref = min(dt_box[2] - dt_box[0], dt_box[3] - dt_box[1])
                    num_box += 1
                    sum_score += score
                    dt_joints = np.array(dt['keypoints']).reshape(17, -1)
                    _draw_pose(ax, dt_joints, vg, link_pairs, ring_color,
                               ref, joint_thres, w, h)

            avg_score = (sum_score / (num_box + np.spacing(1))) * 1000

            # Strip ticks, axes and margins so the image fills the figure.
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.axis('off')
            plt.subplots_adjust(top=1, bottom=0, left=0, right=1,
                                hspace=0, wspace=0)
            plt.margins(0, 0)
            if save:
                # The misspelled ``bbox_inckes`` keyword is dropped: current
                # Matplotlib rejects unknown savefig keywords, and leaving
                # it out keeps the full-figure output size.
                png_name = ('score_' + str(int(avg_score)) +
                            '_id_' + str(imgId) +
                            '_' + img_name + '.png')
                plt.savefig(os.path.join(save_path, png_name),
                            format='png', dpi=100)
                plt.savefig(os.path.join(save_path,
                                         'id_' + str(imgId) + '.pdf'),
                            format='pdf', dpi=100)
            plt.close(fig)
if __name__ == '__main__':
    # Script entry point: parse options, pick the palette, make sure the
    # output directory exists, load the predictions and draw them.
    args = parse_args()

    styles = {
        'xiaochu': xiaochu_style,
        'chunhua': chunhua_style,
    }
    if args.style not in styles:
        raise ValueError('Invalid color style')
    colorstyle = styles[args.style]

    save_path = args.save_path
    img_path = args.image_path
    try:
        # exist_ok avoids the race between a separate existence check and
        # the directory creation.
        os.makedirs(save_path, exist_ok=True)
    except OSError:
        # Without the directory every later save would fail, so report the
        # problem and stop here rather than after the annotations load.
        print('Fail to make {}'.format(save_path))
        raise

    with open(args.prediction) as f:
        data = json.load(f)
    gt_file = args.gt_anno
    plot(data, gt_file, img_path, save_path,
         colorstyle.link_pairs, colorstyle.ring_color, save=True)