"""Interactive Pose Anything demo.

The script shows the (padded and resized) support image, lets the user click
keypoints and then connect them into a skeleton, runs the pose network on the
support/query pair and writes the visualisation to the output directory.
"""
import argparse
import copy
import os
import pickle
import random

import cv2
import numpy as np
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.runner import load_checkpoint
from mmpose.core import wrap_fp16_model
from mmpose.models import build_posenet
from torchvision import transforms

# NOTE(review): the star import is kept in its original position; it
# presumably provides ``TopDownGenerateTargetFewShot`` and registers the
# project models with mmpose -- confirm before narrowing it.
from models import *
import torchvision.transforms.functional as F
from tools.visualization import plot_results

# Palette used for successive skeleton links (one entry per link).
COLORS = [
    [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
    [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
    [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],
    [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0]]


class Resize_Pad:
    """Zero-pad a 3-D image tensor towards a square, then resize it.

    Args:
        w: Output width passed to ``F.resize``.
        h: Output height passed to ``F.resize``.
    """

    def __init__(self, w=256, h=256):
        self.w = w
        self.h = h

    def __call__(self, image):
        """Return ``image`` padded (if not ~square) and resized to (h, w).

        ``image`` must have exactly three dimensions; in this script it is the
        output of ``transforms.ToTensor()``, i.e. channels first.
        """
        _, height, width = image.shape
        aspect = height / width
        # Only pad when the aspect ratio differs from 1 at two decimals.
        if round(aspect, 2) != 1:
            if aspect > 1:
                # Taller than wide: pad left and right.
                pad = int(height - width) // 2
                padding = (pad, 0, pad, 0)
            else:
                # Wider than tall: pad top and bottom.
                pad = int(width - height) // 2
                padding = (0, pad, 0, pad)
            # torchvision padding order is (left, top, right, bottom).
            image = F.pad(image, padding, 0, "constant")
        return F.resize(image, [self.h, self.w])


def transform_keypoints_to_pad_and_resize(keypoints, image_size,
                                          target_size=256.):
    """Map keypoints from the raw image into the padded, resized image.

    Args:
        keypoints: Tensor whose column 0 is x and column 1 is y.
        image_size: Sequence starting with ``(height, width)``.
        target_size: Side length of the square output image. Defaults to the
            previously hard-coded 256.

    Returns:
        A new tensor with the padding offset added to the shorter axis and
        every column scaled by ``target_size / max(height, width)``.
    """
    trans_keypoints = keypoints.clone()
    h, w = image_size[:2]
    if w > h:
        # Width is bigger than height - height is padded, shift y.
        pad = int(w - h) // 2
        trans_keypoints[:, 1] = keypoints[:, 1] + pad
        scale = target_size / w
    else:
        # Height is bigger than (or equal to) width - width is padded,
        # shift x. The offset is h - w; the previous w - h was negative.
        pad = int(h - w) // 2
        trans_keypoints[:, 0] = keypoints[:, 0] + pad
        scale = target_size / h
    trans_keypoints *= scale
    return trans_keypoints


def parse_args():
    """Parse the command-line arguments of the demo."""
    parser = argparse.ArgumentParser(description='Pose Anything Demo')
    # These four were always needed in practice: omitting any of them made
    # the script fail later with an unrelated-looking error.
    parser.add_argument('--support', required=True, help='Image file')
    parser.add_argument('--query', required=True, help='Image file')
    parser.add_argument(
        '--config', default=None, required=True,
        help='test config file path')
    parser.add_argument(
        '--checkpoint', default=None, required=True, help='checkpoint file')
    parser.add_argument(
        '--outdir', default='output', help='output directory')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
             'the inference speed')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help='override some settings in the used config, the key-value pair '
             'in xxx=yyy format will be merged into config file. For example, '
             "'--cfg-options model.backbone.depth=18 "
             "model.backbone.with_cp=True'")
    return parser.parse_args()


def merge_configs(cfg1, cfg2):
    """Merge ``cfg2`` into a shallow copy of ``cfg1`` and return the copy.

    Keys present in both are overwritten by ``cfg2``; entries of ``cfg2``
    whose value is ``None`` are ignored. Either argument may be ``None``.
    """
    merged = {} if cfg1 is None else cfg1.copy()
    overrides = {} if cfg2 is None else cfg2
    for key, value in overrides.items():
        # Only None means "not set"; 0, False and '' are real overrides.
        if value is not None:
            merged[key] = value
    return merged


def _wait_for_keypress():
    """Pump the OpenCV event loop until ``cv2.waitKey`` reports a key."""
    while True:
        if cv2.waitKey(1) > 0:
            break


def main():
    """Run the interactive demo end to end.

    Raises:
        ValueError: If an image cannot be read or no keypoint was marked.
    """
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)

    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.data.test.test_mode = True

    os.makedirs(args.outdir, exist_ok=True)

    # Load data
    support_img = cv2.imread(args.support)
    query_img = cv2.imread(args.query)
    if support_img is None or query_img is None:
        raise ValueError('Fail to read images')

    img_size = cfg.model.encoder_config.img_size

    # Display copy of the support image: padded and resized like the model
    # input so that clicked coordinates live in the model's pixel space.
    display_preprocess = transforms.Compose([
        transforms.ToTensor(),
        Resize_Pad(img_size, img_size)])
    padded_support_img = display_preprocess(
        support_img).cpu().numpy().transpose(1, 2, 0) * 255
    # OpenCV drawing needs a C-contiguous uint8 array.
    clean_frame = np.ascontiguousarray(padded_support_img.astype(np.uint8))
    frame = clean_frame.copy()

    kp_src = []
    skeleton = []
    count = 0
    prev_pt = None
    prev_pt_idx = None
    color_idx = 0

    def selectKP(event, x, y, flags, param):
        """Mouse callback: left click adds a keypoint, right click clears."""
        nonlocal kp_src, frame
        if event == cv2.EVENT_LBUTTONDOWN:
            kp_src.append((x, y))
            cv2.circle(frame, (x, y), 2, (0, 0, 255), 1)
            cv2.imshow("Source", frame)
        if event == cv2.EVENT_RBUTTONDOWN:
            kp_src = []
            # Reset to the padded display image; the raw support image has
            # a different size, so clicks on it would not match.
            frame = clean_frame.copy()
            cv2.imshow("Source", frame)

    def draw_line(event, x, y, flags, param):
        """Mouse callback: pairs of left clicks link the nearest keypoints;
        a right click discards every link drawn so far."""
        nonlocal skeleton, frame, count, prev_pt, prev_pt_idx, color_idx
        if event == cv2.EVENT_LBUTTONDOWN:
            if not kp_src:
                # Nothing to snap to; min() would raise on an empty list.
                return
            closest_point = min(
                kp_src, key=lambda p: (p[0] - x) ** 2 + (p[1] - y) ** 2)
            closest_point_index = kp_src.index(closest_point)
            if color_idx < len(COLORS):
                c = COLORS[color_idx]
            else:
                c = random.choices(range(256), k=3)
            cv2.circle(frame, closest_point, 2, c, 1)
            if count == 0:
                # First endpoint of a link.
                prev_pt = closest_point
                prev_pt_idx = closest_point_index
                count = count + 1
                cv2.imshow("Source", frame)
            else:
                # Second endpoint: draw and record the link.
                cv2.line(frame, prev_pt, closest_point, c, 2)
                cv2.imshow("Source", frame)
                count = 0
                skeleton.append((prev_pt_idx, closest_point_index))
                color_idx = color_idx + 1
        elif event == cv2.EVENT_RBUTTONDOWN:
            frame = copy.deepcopy(marked_frame)
            cv2.imshow("Source", frame)
            count = 0
            color_idx = 0
            skeleton = []
            prev_pt = None
            prev_pt_idx = None

    cv2.namedWindow("Source", cv2.WINDOW_NORMAL)
    cv2.resizeWindow('Source', 800, 600)
    cv2.setMouseCallback("Source", selectKP)
    cv2.imshow("Source", frame)

    # keep looping until points have been selected
    print('Press any key when finished marking the points!! ')
    _wait_for_keypress()

    # Snapshot with keypoints only; draw_line restores it on right click.
    marked_frame = copy.deepcopy(frame)
    cv2.setMouseCallback("Source", draw_line)
    print('Press any key when finished creating skeleton!!')
    _wait_for_keypress()

    cv2.destroyAllWindows()

    if not kp_src:
        raise ValueError('No keypoints were marked on the support image')
    kp_src = torch.tensor(kp_src).float()

    # NOTE(review): cv2.imread returns BGR, and Normalize (constants in the
    # usual ImageNet RGB order) runs before the channel flip below, so the
    # R and B statistics appear swapped. Left unchanged to keep the model
    # input identical -- confirm against the training pipeline.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        Resize_Pad(img_size, img_size)])

    if not skeleton:
        # Fall back to a single degenerate link when none was drawn.
        skeleton = [(0, 0)]

    # flip(0) reverses the channel axis; [None] adds the batch axis.
    support_img = preprocess(support_img).flip(0)[None]
    query_img = preprocess(query_img).flip(0)[None]

    # Create heatmap from keypoints
    genHeatMap = TopDownGenerateTargetFewShot()
    data_cfg = cfg.data_cfg
    data_cfg['image_size'] = np.array([img_size, img_size])
    data_cfg['joint_weights'] = None
    data_cfg['use_different_joint_weights'] = False

    # Append a zero third column to the (x, y) keypoints and their weights.
    zero_col = torch.zeros(kp_src.shape[0], 1)
    kp_src_3d = torch.cat((kp_src, zero_col), dim=-1)
    kp_src_3d_weight = torch.cat((torch.ones_like(kp_src), zero_col), dim=-1)
    target_s, target_weight_s = genHeatMap._msra_generate_target(
        data_cfg, kp_src_3d, kp_src_3d_weight, sigma=1)
    target_s = torch.tensor(target_s).float()[None]
    target_weight_s = torch.tensor(target_weight_s).float()[None]

    kp_center = kp_src.mean(dim=0)
    kp_extent = kp_src.max(dim=0)[0] - kp_src.min(dim=0)[0]
    data = {
        'img_s': [support_img],
        'img_q': query_img,
        'target_s': [target_s],
        'target_weight_s': [target_weight_s],
        'target_q': None,
        'target_weight_q': None,
        'return_loss': False,
        'img_metas': [{
            'sample_skeleton': [skeleton],
            'query_skeleton': skeleton,
            'sample_joints_3d': [kp_src_3d],
            'query_joints_3d': kp_src_3d,
            'sample_center': [kp_center],
            'query_center': kp_center,
            'sample_scale': [kp_extent],
            'query_scale': kp_extent,
            'sample_rotation': [0],
            'query_rotation': 0,
            'sample_bbox_score': [1],
            'query_bbox_score': 1,
            'query_image_file': '',
            'sample_image_file': [''],
        }]
    }

    # Load model
    model = build_posenet(cfg.model)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    load_checkpoint(model, args.checkpoint, map_location='cpu')
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)
    model.eval()

    with torch.no_grad():
        outputs = model(**data)

    # visualize results
    vis_s_weight = target_weight_s[0]
    vis_q_weight = target_weight_s[0]
    vis_s_image = support_img[0].detach().cpu().numpy().transpose(1, 2, 0)
    vis_q_image = query_img[0].detach().cpu().numpy().transpose(1, 2, 0)
    support_kp = kp_src_3d

    plot_results(vis_s_image,
                 vis_q_image,
                 support_kp,
                 vis_s_weight,
                 None,
                 vis_q_weight,
                 skeleton,
                 None,
                 torch.tensor(outputs['points']).squeeze(0),
                 out_dir=args.outdir)


if __name__ == '__main__':
    main()