import argparse
import copy
import os
import pickle
import random
import cv2
import numpy as np
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.runner import load_checkpoint
from mmpose.core import wrap_fp16_model
from mmpose.models import build_posenet
from torchvision import transforms
from models import *
import torchvision.transforms.functional as F
from tools.visualization import plot_results
# Fixed palette of 19 colour triplets, indexed in order by the number of
# skeleton edges drawn so far in the interactive demo. Note that the last
# entry repeats the first.
COLORS = [
    [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
    [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
    [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],
    [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0]]
class Resize_Pad:
    """Zero-pad an image tensor to a square, then resize it to ``(h, w)``.

    The shorter spatial side is padded so the content keeps its aspect ratio
    before the final resize. The leading edge (left or top) receives
    ``diff // 2`` pixels and the trailing edge the remainder, so the padded
    image is exactly square even when the size difference is odd.

    Args:
        w: Output width in pixels.
        h: Output height in pixels.
    """

    def __init__(self, w=256, h=256):
        self.w = w
        self.h = h

    def __call__(self, image):
        """Pad ``image`` to a square and resize it to ``[self.h, self.w]``.

        Args:
            image: Tensor whose last two dims are (height, width), e.g. the
                CHW output of ``transforms.ToTensor()``. Leading batch dims are
                allowed.

        Returns:
            The padded and resized tensor.
        """
        height, width = image.shape[-2:]
        # Skip padding when the aspect ratio already rounds to 1.00.
        if round(height / width, 2) != 1:
            diff = abs(int(height) - int(width))
            before = diff // 2
            # Odd leftover pixel goes to the trailing edge so the result is
            # exactly square. F.pad order for 4 values: left, top, right, bottom.
            after = diff - before
            if height > width:
                # Taller than wide: pad left/right to widen the image.
                image = F.pad(image, (before, 0, after, 0), 0, "constant")
            else:
                # Wider than tall: pad top/bottom to heighten the image.
                image = F.pad(image, (0, before, 0, after), 0, "constant")
        return F.resize(image, [self.h, self.w])
def transform_keypoints_to_pad_and_resize(keypoints, image_size, out_size=256):
    """Map keypoints from original-image pixels into the padded, resized frame.

    Matches the square padding done by ``Resize_Pad``. The shorter side is
    padded by ``diff // 2`` on its leading edge (top or left), and the square
    result is scaled to ``out_size``.

    Args:
        keypoints: Tensor with x in column 0 and y in column 1. The whole
            tensor is multiplied by the scale factor.
        image_size: ``(height, width, ...)`` of the original image, e.g.
            ``img.shape`` of an HWC array.
        out_size: Side length of the square output image (default 256).

    Returns:
        A new tensor of transformed keypoints; ``keypoints`` is not modified.
    """
    trans_keypoints = keypoints.clone()
    h, w = image_size[:2]
    if w > h:
        # Wider than tall: the height was padded, so shift y down.
        pad = int(w - h) // 2
        trans_keypoints[:, 1] = keypoints[:, 1] + pad
        longest = w
    else:
        # Taller than (or as tall as) wide: the width was padded, so shift x
        # right. The offset must be h - w (non-negative) in this branch.
        pad = int(h - w) // 2
        trans_keypoints[:, 0] = keypoints[:, 0] + pad
        longest = h
    trans_keypoints *= (out_size / longest)
    return trans_keypoints
def parse_args():
    """Parse command-line arguments for the Pose Anything demo.

    Returns:
        argparse.Namespace with ``support``, ``query``, ``config``,
        ``checkpoint``, ``outdir``, ``fuse_conv_bn`` and ``cfg_options``.
    """
    parser = argparse.ArgumentParser(description='Pose Anything Demo')
    parser.add_argument('--support', help='support image file')
    parser.add_argument('--query', help='query image file')
    parser.add_argument('--config', default=None, help='test config file path')
    parser.add_argument('--checkpoint', default=None, help='checkpoint file')
    parser.add_argument('--outdir', default='output', help='output directory')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    # Parsed into a dict by mmcv's DictAction; main() only applies it when
    # it is not None.
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default=None,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    return parser.parse_args()
def merge_configs(cfg1, cfg2):
    """Return a copy of ``cfg1`` with the entries of ``cfg2`` merged on top.

    Keys present in both take ``cfg2``'s value, except that ``None`` values in
    ``cfg2`` are skipped so they never clear a setting from ``cfg1``. Either
    argument may be ``None`` and is then treated as empty. The merge is
    shallow; nested mappings are replaced, not merged.

    Returns:
        ``cfg1.copy()`` (or a new dict) with the overrides applied; neither
        input is modified.
    """
    merged = {} if cfg1 is None else cfg1.copy()
    if cfg2 is not None:
        for key, value in cfg2.items():
            # Only None means "unset": falsy values such as 0, False or ''
            # are legitimate overrides and must be applied.
            if value is not None:
                merged[key] = value
    return merged
# Interactive Pose Anything demo.
#
# Flow visible below:
#   1. Load the config and the support/query images.
#   2. The user clicks keypoints on a padded/resized copy of the support image.
#   3. The user clicks pairs of those keypoints to define skeleton edges.
#   4. Support heatmaps are built from the clicked points.
#   5. The pose model is run on (support, query) without gradients.
#   6. Arrays are prepared for visualization.
#
# NOTE(review): this block is whitespace-mangled and several lines appear to be
# missing: if/else bodies, loop `break`s, the closing brackets of `data`, and
# the final visualization call. The notes below flag each spot; confirm against
# the upstream file before trusting the control flow as shown.
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
# NOTE(review): the body of this `if` is not visible. Presumably it merges
# args.cfg_options into cfg (e.g. via cfg.merge_from_dict); confirm.
if args.cfg_options is not None:
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
cfg.data.test.test_mode = True
os.makedirs(args.outdir, exist_ok=True)
# Load data
support_img = cv2.imread(args.support)
query_img = cv2.imread(args.query)
# cv2.imread returns None (it does not raise) when a file cannot be read.
if support_img is None or query_img is None:
raise ValueError('Fail to read images')
# Display-only preprocessing: pad to square, then resize to the encoder input size.
# NOTE(review): Resize_Pad unpacks image.shape into 3 values and calls torchvision F.pad.
# The next lines also call .cpu() and scale by 255. Together these suggest a
# transforms.ToTensor() step is missing from this Compose, since cv2.imread returns a
# numpy array. Confirm.
preprocess = transforms.Compose([
Resize_Pad(cfg.model.encoder_config.img_size, cfg.model.encoder_config.img_size)])
# frame = copy.deepcopy(support_img)
# Convert CHW back to HWC numpy and rescale (assumes the tensor is in [0, 1]; confirm).
padded_support_img = preprocess(support_img).cpu().numpy().transpose(1, 2, 0) * 255
# The user clicks on this padded/resized canvas, not on the original support image.
frame = copy.deepcopy(padded_support_img.astype(np.uint8).copy())
# State shared with the mouse callbacks below (rebound there via `nonlocal`):
# kp_src holds clicked (x, y) points; skeleton holds (start_idx, end_idx) edges;
# count / prev_pt / prev_pt_idx track the first endpoint of an edge in progress;
# color_idx counts edges drawn, to pick a colour from COLORS.
kp_src = []
skeleton = []
count = 0
prev_pt = None
prev_pt_idx = None
color_idx = 0
# Mouse callback for phase 1 (marking keypoints).
def selectKP(event, x, y, flags, param):
nonlocal kp_src, frame
# Left click: record the (x, y) click location as a keypoint and draw a
# small circle there.
# Right click: clear all keypoints and redraw the canvas.
if event == cv2.EVENT_LBUTTONDOWN:
kp_src.append((x, y))
cv2.circle(frame, (x, y), 2, (0, 0, 255), 1)
cv2.imshow("Source", frame)
if event == cv2.EVENT_RBUTTONDOWN:
kp_src = []
# NOTE(review): this restores the ORIGINAL support_img, but `frame` was
# initialised from the padded/resized image. After a reset, clicks land in a
# different coordinate system. Likely this should rebuild from
# padded_support_img.astype(np.uint8).
frame = copy.deepcopy(support_img)
cv2.imshow("Source", frame)
# Mouse callback for phase 2 (drawing skeleton edges between marked keypoints).
def draw_line(event, x, y, flags, param):
nonlocal skeleton, kp_src, frame, count, prev_pt, prev_pt_idx, marked_frame, color_idx
if event == cv2.EVENT_LBUTTONDOWN:
# Snap the click to the nearest marked keypoint (squared Euclidean distance).
# min() raises ValueError if no keypoints were marked in phase 1.
closest_point = min(kp_src, key=lambda p: (p[0] - x) ** 2 + (p[1] - y) ** 2)
closest_point_index = kp_src.index(closest_point)
# NOTE(review): as shown, `c` is assigned from COLORS and then unconditionally
# overwritten with a random colour. An `else:` before the random.choices line
# appears to be missing. `color` is never used in the visible code.
if color_idx < len(COLORS):
c = COLORS[color_idx]
c = random.choices(range(256), k=3)
color = color_idx
cv2.circle(frame, closest_point, 2, c, 1)
# The first click of a pair stores the start point; the second click draws the
# edge and appends (start_idx, end_idx) to skeleton.
# NOTE(review): an `else:` separating these two cases appears to be missing
# before the cv2.line call.
if count == 0:
prev_pt = closest_point
prev_pt_idx = closest_point_index
count = count + 1
cv2.imshow("Source", frame)
cv2.line(frame, prev_pt, closest_point, c, 2)
cv2.imshow("Source", frame)
count = 0
skeleton.append((prev_pt_idx, closest_point_index))
color_idx = color_idx + 1
elif event == cv2.EVENT_RBUTTONDOWN:
# Right click: discard all edges and restore the canvas as it was after phase 1.
frame = copy.deepcopy(marked_frame)
cv2.imshow("Source", frame)
count = 0
color_idx = 0
skeleton = []
prev_pt = None
cv2.namedWindow("Source", cv2.WINDOW_NORMAL)
cv2.resizeWindow('Source', 800, 600)
cv2.setMouseCallback("Source", selectKP)
cv2.imshow("Source", frame)
# keep looping until points have been selected
print('Press any key when finished marking the points!! ')
# NOTE(review): no `break` is visible in this loop (or the next one). As shown,
# a key press never exits it; a `break` appears to be missing.
while True:
if cv2.waitKey(1) > 0:
# Snapshot of the canvas with keypoints only, used by draw_line's right-click reset.
marked_frame = copy.deepcopy(frame)
cv2.setMouseCallback("Source", draw_line)
print('Press any key when finished creating skeleton!!')
while True:
if cv2.waitKey(1) > 0:
# (K, 2) float tensor of clicked (x, y) points. An empty selection gives
# shape (0,), which will likely fail in the concatenations below.
kp_src = torch.tensor(kp_src).float()
# Model preprocessing: ImageNet mean/std normalisation, then pad and resize.
# NOTE(review): as with the earlier Compose, a ToTensor() step appears to be
# missing here; confirm.
preprocess = transforms.Compose([
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
Resize_Pad(cfg.model.encoder_config.img_size, cfg.model.encoder_config.img_size)])
# With no edges drawn, fall back to a single self-loop on keypoint 0 as a placeholder skeleton.
if len(skeleton) == 0:
skeleton = [(0, 0)]
# flip(0) reverses dim 0 (the channel dim for a CHW tensor), presumably to turn
# OpenCV's BGR order into RGB; [None] adds a batch dim.
# NOTE(review): the flip happens AFTER Normalize, so the RGB mean/std are applied
# to BGR-ordered channels. Consider flipping before normalising.
support_img = preprocess(support_img).flip(0)[None]
query_img = preprocess(query_img).flip(0)[None]
# Create heatmap from keypoints
genHeatMap = TopDownGenerateTargetFewShot()
# NOTE(review): data_cfg is the same object as cfg.data_cfg, so the item
# assignments below mutate the loaded config in place.
data_cfg = cfg.data_cfg
data_cfg['image_size'] = np.array([cfg.model.encoder_config.img_size, cfg.model.encoder_config.img_size])
data_cfg['joint_weights'] = None
data_cfg['use_different_joint_weights'] = False
# Append a zero third coordinate to each (x, y) point. The weights are 1 for
# x and y and 0 for the third column.
kp_src_3d = torch.concatenate((kp_src, torch.zeros(kp_src.shape[0], 1)), dim=-1)
kp_src_3d_weight = torch.concatenate((torch.ones_like(kp_src), torch.zeros(kp_src.shape[0], 1)), dim=-1)
# Relies on a private (underscore-prefixed) helper of TopDownGenerateTargetFewShot.
target_s, target_weight_s = genHeatMap._msra_generate_target(data_cfg, kp_src_3d, kp_src_3d_weight, sigma=1)
# Convert to float tensors and add a leading batch dim.
target_s = torch.tensor(target_s).float()[None]
target_weight_s = torch.tensor(target_weight_s).float()[None]
# Model inputs. Support entries are wrapped in one-element lists (a single support shot).
# The query metadata reuses the support keypoints and skeleton, presumably as
# placeholders because the query is unannotated in this demo; confirm.
data = {
'img_s': [support_img],
'img_q': query_img,
'target_s': [target_s],
'target_weight_s': [target_weight_s],
'target_q': None,
'target_weight_q': None,
'return_loss': False,
'img_metas': [{'sample_skeleton': [skeleton],
'query_skeleton': skeleton,
'sample_joints_3d': [kp_src_3d],
'query_joints_3d': kp_src_3d,
# Center = mean keypoint; scale = per-axis (max - min) extent of the keypoints.
'sample_center': [kp_src.mean(dim=0)],
'query_center': kp_src.mean(dim=0),
'sample_scale': [kp_src.max(dim=0)[0] - kp_src.min(dim=0)[0]],
'query_scale': kp_src.max(dim=0)[0] - kp_src.min(dim=0)[0],
'sample_rotation': [0],
'query_rotation': 0,
'sample_bbox_score': [1],
'query_bbox_score': 1,
'query_image_file': '',
'sample_image_file': [''],
# NOTE(review): the closing brackets of the img_metas dict/list and of `data` are not visible here.
# Load model
model = build_posenet(cfg.model)
fp16_cfg = cfg.get('fp16', None)
# NOTE(review): as shown, load_checkpoint is the only statement under the fp16 check.
# wrap_fp16_model is imported but unused in this view, so presumably
# `wrap_fp16_model(model)` is the missing body here and load_checkpoint should
# run unconditionally. Confirm.
if fp16_cfg is not None:
load_checkpoint(model, args.checkpoint, map_location='cpu')
if args.fuse_conv_bn:
model = fuse_conv_bn(model)
# Inference with autograd disabled.
# NOTE(review): model.eval() is not called in the visible code. A freshly built
# nn.Module is in training mode, which affects BatchNorm/Dropout. Confirm whether
# this is intended.
with torch.no_grad():
outputs = model(**data)
# visualize results
# The query reuses the support target weights (same keypoint set).
vis_s_weight = target_weight_s[0]
vis_q_weight = target_weight_s[0]
# Drop the batch dim and convert CHW -> HWC numpy. These are the normalised,
# channel-flipped model inputs, not the raw images.
vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)
vis_q_image = query_img[0].detach().cpu().numpy().transpose(1, 2, 0)
support_kp = kp_src_3d
# NOTE(review): the block ends here in this view. `outputs` and the vis_* arrays are
# not consumed by any visible line; the visualization call (plot_results is imported)
# appears to be missing.
# Script entry-point guard.
# NOTE(review): the guarded body is not visible in this view. Presumably it is
# `main()`; confirm it exists, otherwise this line is a syntax error.
if __name__ == '__main__':