|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import cv2 |
|
import json |
|
import matplotlib.lines as mlines |
|
import matplotlib.patches as mpatches |
|
from pycocotools.coco import COCO |
|
from pycocotools.cocoeval import COCOeval |
|
import os |
|
|
|
|
|
class ColorStyle:
    """Color scheme for drawing a keypoint skeleton with matplotlib.

    Attributes:
        color: RGB triples (0-255), one per limb in ``link_pairs``.
        link_pairs: New lists ``[joint_a, joint_b, (r, g, b)]`` built from the
            input pairs, each extended with its limb color scaled to 0-1.
            The caller's pair lists are left untouched.
        point_color: RGB triples (0-255), one per joint.
        ring_color: ``point_color`` scaled to 0-1, as tuples.
    """

    def __init__(self, color, link_pairs, point_color):
        """Build the style.

        Args:
            color: Sequence of 0-255 RGB triples, one per link pair.
            link_pairs: Sequence of ``(joint_a, joint_b)`` pairs (lists or
                tuples).
            point_color: Sequence of 0-255 RGB triples, one per joint.

        Raises:
            ValueError: if ``color`` and ``link_pairs`` differ in length.
        """
        if len(color) != len(link_pairs):
            raise ValueError(
                'expected one color per link pair, got {} colors for {} pairs'
                .format(len(color), len(link_pairs)))

        self.color = color
        self.point_color = point_color

        # Build fresh lists instead of appending to the caller's pairs, so a
        # shared pair list cannot accumulate colors across several styles.
        self.link_pairs = [list(pair) + [self._normalize(rgb)]
                           for pair, rgb in zip(link_pairs, color)]
        self.ring_color = [self._normalize(rgb) for rgb in point_color]

    @staticmethod
    def _normalize(rgb):
        """Scale a 0-255 RGB triple to the 0-1 float tuple matplotlib accepts."""
        return tuple(np.array(rgb) / 255.)
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# "Xiaochu" color style.
# Joint indices presumably follow the standard COCO 17-keypoint order
# (0 nose, 1/2 eyes, 3/4 ears, 5/6 shoulders, 7/8 elbows, 9/10 wrists,
# 11/12 hips, 13/14 knees, 15/16 ankles) -- confirm against the annotations.
# ---------------------------------------------------------------------------

# Limb colors (RGB, 0-255); entry i is the color of link_pairs1[i].
color1 = [
    (179, 0, 0), (228, 26, 28), (255, 255, 51),
    (49, 163, 84), (0, 109, 45), (255, 255, 51),
    (240, 2, 127), (240, 2, 127), (240, 2, 127),
    (240, 2, 127), (240, 2, 127),
    (217, 95, 14), (254, 153, 41), (255, 255, 51),
    (44, 127, 184), (0, 0, 255),
]

# Limbs as [joint_a, joint_b] index pairs (mutable lists: ColorStyle may
# append a color to each).
link_pairs1 = [
    [15, 13], [13, 11], [11, 5],
    [12, 14], [14, 16], [12, 6],
    [3, 1], [1, 2], [1, 0], [0, 2], [2, 4],
    [9, 7], [7, 5], [5, 6],
    [6, 8], [8, 10],
]

# Joint fill colors (RGB, 0-255); entry k colors joint k.
point_color1 = [
    # joints 0-4
    (240, 2, 127), (240, 2, 127), (240, 2, 127),
    (240, 2, 127), (240, 2, 127),
    # joints 5-10
    (255, 255, 51), (255, 255, 51),
    (254, 153, 41), (44, 127, 184),
    (217, 95, 14), (0, 0, 255),
    # joints 11-16
    (255, 255, 51), (255, 255, 51), (228, 26, 28),
    (49, 163, 84), (252, 176, 243), (0, 176, 240),
    # entries 17-22: not indexed when drawing 17-joint poses
    (255, 255, 0), (169, 209, 142),
    (255, 255, 0), (169, 209, 142),
    (255, 255, 0), (169, 209, 142),
]

xiaochu_style = ColorStyle(color1, link_pairs1, point_color1)
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# "Chunhua" color style (the default --style).
# Same limb topology as the xiaochu style; joint indices presumably follow
# the standard COCO 17-keypoint order -- confirm against the annotations.
# ---------------------------------------------------------------------------

# Limb colors (RGB, 0-255); entry i is the color of link_pairs2[i].
color2 = [
    (252, 176, 243), (252, 176, 243), (252, 176, 243),
    (0, 176, 240), (0, 176, 240), (0, 176, 240),
    (240, 2, 127), (240, 2, 127), (240, 2, 127),
    (240, 2, 127), (240, 2, 127),
    (255, 255, 0), (255, 255, 0), (169, 209, 142),
    (169, 209, 142), (169, 209, 142),
]

# Limbs as [joint_a, joint_b] index pairs (mutable lists: ColorStyle may
# append a color to each).
link_pairs2 = [
    [15, 13], [13, 11], [11, 5],
    [12, 14], [14, 16], [12, 6],
    [3, 1], [1, 2], [1, 0], [0, 2], [2, 4],
    [9, 7], [7, 5], [5, 6],
    [6, 8], [8, 10],
]

# Joint fill colors (RGB, 0-255); entry k colors joint k.
point_color2 = [
    # joints 0-4
    (240, 2, 127), (240, 2, 127), (240, 2, 127),
    (240, 2, 127), (240, 2, 127),
    # joints 5-10: alternating yellow / green
    (255, 255, 0), (169, 209, 142),
    (255, 255, 0), (169, 209, 142),
    (255, 255, 0), (169, 209, 142),
    # joints 11-16: alternating pink / blue
    (252, 176, 243), (0, 176, 240), (252, 176, 243),
    (0, 176, 240), (252, 176, 243), (0, 176, 240),
    # entries 17-22: not indexed when drawing 17-joint poses
    (255, 255, 0), (169, 209, 142),
    (255, 255, 0), (169, 209, 142),
    (255, 255, 0), (169, 209, 142),
]

chunhua_style = ColorStyle(color2, link_pairs2, point_color2)
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the COCO visualizer.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``image_path``, ``gt_anno``, ``save_path``,
        ``prediction`` and ``style``. ``style`` is restricted to the styles
        this script defines, so a typo is reported by argparse up front.
    """
    parser = argparse.ArgumentParser(description='Visualize COCO predictions')

    parser.add_argument('--image-path',
                        help='Path of COCO val images',
                        type=str,
                        default='data/coco/images/val2017/')

    parser.add_argument('--gt-anno',
                        help='Path of COCO val annotation',
                        type=str,
                        default='data/coco/annotations/person_keypoints_val2017.json')

    parser.add_argument('--save-path',
                        help="Path to save the visualizations",
                        type=str,
                        default='visualization/coco/')

    parser.add_argument('--prediction',
                        help="Prediction file to visualize",
                        type=str,
                        required=True)

    # choices= rejects unknown styles at parse time instead of letting the
    # script fail later with a generic exception.
    parser.add_argument('--style',
                        help="Style of the visualization: Chunhua style or Xiaochu style",
                        type=str,
                        default='chunhua',
                        choices=['chunhua', 'xiaochu'])

    return parser.parse_args(argv)
|
|
|
|
|
def map_joint_dict(joints):
    """Map each joint (row) index to its integer pixel position.

    Args:
        joints: Array-like of shape ``(num_joints, >=2)`` whose first two
            columns are x and y.

    Returns:
        dict ``{row_index: (int(x), int(y))}``. ``int`` truncates toward
        zero, so e.g. 1.7 -> 1 and -0.5 -> 0.
    """
    # np.asarray lets plain nested lists work as well as ndarrays.
    joints = np.asarray(joints)
    return {idx: (int(row[0]), int(row[1])) for idx, row in enumerate(joints)}
|
|
|
def _expanded_box(bbox):
    """Grow an ``[x, y, w, h]`` box by its own width/height on every side.

    Returns ``(x0, y0, x1, y1)`` spanning three times the original extent.
    """
    x, y, bw, bh = bbox[0], bbox[1], bbox[2], bbox[3]
    return x - bw, y - bh, x + bw * 2, y + bh * 2


def _overlap_ratio(box_a, box_b):
    """Return the intersection area of two ``(x0, y0, x1, y1)`` boxes divided
    by the area of the smallest rectangle enclosing both.

    Disjoint boxes score 0. The enclosing rectangle is at least as large as
    the union, so the result never exceeds the true IoU.
    """
    # Clamp at 0: otherwise two boxes separated on BOTH axes multiply two
    # negative extents into a spurious positive "overlap".
    ol_x = max(0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    ol_y = max(0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    s_x = max(box_a[2], box_b[2]) - min(box_a[0], box_b[0])
    s_y = max(box_a[3], box_b[3]) - min(box_a[1], box_b[1])
    return (ol_x * ol_y) / (s_x * s_y + np.spacing(1))


def _draw_pose(ax, dt_joints, vg, link_pairs, ring_color, ref, w, h,
               joint_thres):
    """Draw one predicted skeleton (limbs, then joints) onto *ax*.

    Args:
        ax: Matplotlib axes showing the image.
        dt_joints: Array with one row per joint whose columns 0, 1, 2 are
            x, y and the joint score.
        vg: Visibility flags of the matched ground-truth person. A joint with
            flag 0, or a limb touching one, is not drawn.
        link_pairs: ``[joint_a, joint_b, color]`` limbs.
        ring_color: Fill color per joint.
        ref: Smaller side of the expanded detection box; scales line widths
            and joint radii.
        w, h: Image width and height; joints outside the image are skipped.
        joint_thres: Minimum joint score for a joint or limb to be drawn.
    """
    joints_dict = map_joint_dict(dt_joints)

    for k, link_pair in enumerate(link_pairs):
        a, b = link_pair[0], link_pair[1]
        if a not in joints_dict or b not in joints_dict:
            continue
        if (dt_joints[a, 2] < joint_thres or dt_joints[b, 2] < joint_thres
                or vg[a] == 0 or vg[b] == 0):
            continue
        # Limbs 6-10 get a fixed thin line; the rest scale with box size.
        lw = 1 if 6 <= k < 11 else ref / 100.
        line = mlines.Line2D(
            np.array([joints_dict[a][0], joints_dict[b][0]]),
            np.array([joints_dict[a][1], joints_dict[b][1]]),
            ls='-', lw=lw, alpha=1, color=link_pair[2])
        line.set_zorder(0)
        ax.add_line(line)

    for k in range(dt_joints.shape[0]):
        # Bug fix: gate on THIS joint's GT visibility (vg[k]); the old code
        # reused the last limb's endpoints from the loop above.
        if dt_joints[k, 2] < joint_thres or vg[k] == 0:
            continue
        x, y = dt_joints[k, 0], dt_joints[k, 1]
        if x < 0 or y < 0 or x > w or y > h:
            continue
        # Joints 0-4 get a fixed small radius; the rest scale with box size.
        radius = 1 if k < 5 else ref / 100
        circle = mpatches.Circle(tuple(dt_joints[k, :2]),
                                 radius=radius,
                                 ec='black',
                                 fc=ring_color[k],
                                 alpha=1,
                                 linewidth=1)
        circle.set_zorder(1)
        ax.add_patch(circle)


def _save_figure(fig, save_path, img_id, img_name, avg_score):
    """Write *fig* as a score-tagged PNG and an id-named PDF in *save_path*."""
    # bbox_inches='tight' with pad_inches=0 completes the borderless setup
    # (NullLocator, subplots_adjust, margins) done by the caller. The old
    # misspelled 'bbox_inckes' was ignored or, on newer matplotlib, a TypeError.
    png_name = ('score_' + str(int(avg_score)) + '_id_' + str(img_id)
                + '_' + img_name + '.png')
    fig.savefig(os.path.join(save_path, png_name), format='png',
                bbox_inches='tight', pad_inches=0, dpi=100)
    fig.savefig(os.path.join(save_path, 'id_' + str(img_id) + '.pdf'),
                format='pdf', bbox_inches='tight', pad_inches=0, dpi=100)


def plot(data, gt_file, img_path, save_path,
         link_pairs, ring_color, save=True, *,
         threshold=0.3, joint_thres=0.2, max_images=5000):
    """Draw predicted keypoint skeletons over their COCO images and save them.

    GT and predictions are grouped per (image, category) by
    ``COCOeval._prepare()``. Each prediction is drawn once for every GT
    person whose expanded box it overlaps by at least 0.1 (see
    ``_overlap_ratio``), masked by that person's GT keypoint visibility.

    Args:
        data: Predictions accepted by ``COCO.loadRes``.
        gt_file: Path to the COCO keypoint annotation file.
        img_path: Directory holding the images, named ``<12-digit id>.jpg``.
        save_path: Directory receiving the PNG and PDF renderings.
        link_pairs: ``[joint_a, joint_b, color]`` limbs.
        ring_color: Per-joint fill colors.
        save: When False, figures are built and closed without being written.
        threshold: Minimum detection score for a prediction to be drawn.
        joint_thres: Minimum per-joint score for a joint or limb to be drawn.
        max_images: Only the first ``max_images`` image ids are visited.
    """
    coco = COCO(gt_file)
    coco_dt = coco.loadRes(data)
    coco_eval = COCOeval(coco, coco_dt, 'keypoints')
    # _prepare() fills _gts/_dts, which are indexed by (imgId, catId) below.
    coco_eval._prepare()
    gts_ = coco_eval._gts
    dts_ = coco_eval._dts

    p = coco_eval.params
    p.imgIds = list(np.unique(p.imgIds))
    if p.useCats:
        p.catIds = list(np.unique(p.catIds))
    p.maxDets = sorted(p.maxDets)

    cat_ids = p.catIds if p.useCats else [-1]
    for cat_id in cat_ids:
        for img_id in p.imgIds[:max_images]:
            gts = gts_[img_id, cat_id]
            dts = dts_[img_id, cat_id]
            # Highest-scoring detections first, capped at the largest maxDets.
            order = np.argsort([-d['score'] for d in dts], kind='mergesort')
            dts = [dts[i] for i in order][:p.maxDets[-1]]
            if len(gts) == 0 or len(dts) == 0:
                continue

            img_name = str(img_id).zfill(12)
            img_file = os.path.join(img_path, img_name + '.jpg')
            data_numpy = cv2.imread(
                img_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
            if data_numpy is None:
                # cv2.imread reports a missing/unreadable file by returning None.
                print('Cannot read {}, skipping'.format(img_file))
                continue
            h, w = data_numpy.shape[0], data_numpy.shape[1]

            fig, ax = plt.subplots(figsize=(w / 100, h / 100), dpi=100)
            try:
                # OpenCV loads BGR; reverse the channels for matplotlib (RGB).
                bk = ax.imshow(data_numpy[:, :, ::-1])
                bk.set_zorder(-1)
                print(img_name)

                sum_score = 0
                num_box = 0
                for gt in gts:
                    gt_box = _expanded_box(gt['bbox'])
                    # GT keypoints are flat (x, y, v) triples; keep the v flags.
                    vg = np.array(gt['keypoints'])[2::3]

                    for dt in dts:
                        dt_box = _expanded_box(dt['bbox'])
                        iou = _overlap_ratio(gt_box, dt_box)
                        score = dt['score']
                        if iou < 0.1 or score < threshold:
                            continue
                        print('iou: ', iou)
                        ref = min(dt_box[2] - dt_box[0], dt_box[3] - dt_box[1])
                        num_box += 1
                        sum_score += score
                        dt_joints = np.array(dt['keypoints']).reshape(17, -1)
                        _draw_pose(ax, dt_joints, vg, link_pairs, ring_color,
                                   ref, w, h, joint_thres)

                # Mean score of drawn detections, x1000 for the file name.
                avg_score = (sum_score / (num_box + np.spacing(1))) * 1000

                ax.xaxis.set_major_locator(plt.NullLocator())
                ax.yaxis.set_major_locator(plt.NullLocator())
                ax.axis('off')
                fig.subplots_adjust(top=1, bottom=0, left=0, right=1,
                                    hspace=0, wspace=0)
                ax.margins(0, 0)
                if save:
                    _save_figure(fig, save_path, img_id, img_name, avg_score)
            finally:
                # Always release the figure, even if drawing/saving failed.
                plt.close(fig)
|
|
|
if __name__ == '__main__':
    # Script entry point: parse options, pick a color style, render images.
    args = parse_args()

    styles = {'xiaochu': xiaochu_style, 'chunhua': chunhua_style}
    if args.style not in styles:
        raise ValueError('Invalid color style')
    colorstyle = styles[args.style]

    save_path = args.save_path
    img_path = args.image_path
    # Best effort, as before: report a failure to create the output
    # directory instead of aborting, but only for OS errors and with the
    # cause shown. exist_ok avoids a check-then-create race.
    try:
        os.makedirs(save_path, exist_ok=True)
    except OSError as err:
        print('Fail to make {}'.format(save_path), err)

    # Read bytes so json detects the encoding itself (UTF-8/16/32) rather
    # than relying on the platform's default text encoding.
    with open(args.prediction, 'rb') as f:
        data = json.load(f)
    gt_file = args.gt_anno
    plot(data, gt_file, img_path, save_path,
         colorstyle.link_pairs, colorstyle.ring_color, save=True)
|
|
|
|