import argparse
import datetime
import json
import random
import time
import multiprocessing
from pathlib import Path
import os

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler

import hotr.data.datasets as datasets
import hotr.util.misc as utils
from hotr.engine.arg_parser import get_args_parser
from hotr.data.datasets import build_dataset, get_coco_api_from_dataset
from hotr.data.datasets.vcoco import make_hoi_transforms
from PIL import Image
from hotr.util.logger import print_params, print_args

import copy
from hotr.data.datasets import builtin_meta
from PIL import Image
import requests
# import mmcv
from matplotlib import pyplot as plt
import imageio

from tools.vis_tool import *
from hotr.models.detr import build


# Action indices (rows of ``pair_score``) that change_format() never emits as
# HOI triplets.
# NOTE(review): presumably V-COCO actions without an object role; confirm
# against the dataset's action list.
_SKIPPED_ACTION_IDS = frozenset({1, 4, 5, 10, 18, 23, 26})


def change_format(results, valid_ids, skip_actions=None):
    """Convert one postprocessed HOTR result into the dict used for drawing.

    Args:
        results: dict holding tensors under ``'boxes'``, ``'labels'`` and
            ``'pair_score'``; each is moved to CPU and converted to numpy.
            ``pair_score[a][j][k]`` is read as the score of action ``a`` for
            the j-th human box (ordinal among label-1 boxes) and box ``k``.
        valid_ids: unused; kept for backward compatibility.
        skip_actions: action indices to omit. Defaults to
            ``_SKIPPED_ACTION_IDS``.

    Returns:
        dict with
          * ``'predictions'``: one ``{'bbox', 'category_id'}`` entry per box;
          * ``'hoi_prediction'``: one ``{'subject_id', 'object_id',
            'category_id', 'score'}`` entry per (action, human, box) whose
            score is > 0, ordered by action, then human, then box.
            ``category_id`` is the action index + 2.
    """
    if skip_actions is None:
        skip_actions = _SKIPPED_ACTION_IDS

    boxes, labels, pair_score = (
        results[key].cpu().numpy() for key in ('boxes', 'labels', 'pair_score')
    )

    # Boxes with label 1 are used as interaction subjects (presumably persons).
    human_idx = np.where(labels == 1)[0]
    n_humans, n_boxes = len(human_idx), len(boxes)

    predictions = [
        {'bbox': box.tolist(), 'category_id': label}
        for box, label in zip(boxes, labels)
    ]

    hoi_prediction = []
    for action, verb in enumerate(pair_score):
        if action in skip_actions:
            continue
        # argwhere yields hits in row-major order, i.e. the same
        # human-then-box order as a nested j/k loop.
        for j, k in np.argwhere(verb[:n_humans, :n_boxes] > 0):
            hoi_prediction.append({
                'subject_id': human_idx[j],
                'object_id': int(k),
                'category_id': action + 2,
                'score': verb[j][k],
            })

    return {'predictions': predictions, 'hoi_prediction': hoi_prediction}


def _apply_dataset_stats(args):
    """Build the training split and copy its statistics onto *args* (in place)."""
    dataset_train = build_dataset(image_set='train', args=args)
    args.num_classes = dataset_train.num_category()
    args.num_actions = dataset_train.num_action()
    args.action_names = dataset_train.get_actions()
    if args.share_enc:
        args.hoi_enc_layers = args.enc_layers
    if args.pretrained_dec:
        args.hoi_dec_layers = args.dec_layers

    if args.dataset_file == 'vcoco':
        # Save V-COCO dataset statistics
        object_label_idx = np.array(dataset_train.get_object_label_idx())
        args.valid_ids = object_label_idx.nonzero()[0]
        args.invalid_ids = np.argwhere(object_label_idx == 0).squeeze(1)
        args.human_actions = dataset_train.get_human_action()
        args.object_actions = dataset_train.get_object_action()
        args.num_human_act = dataset_train.num_human_act()
    elif args.dataset_file == 'hico-det':
        args.valid_obj_ids = dataset_train.get_valid_obj_ids()
    print_args(args)


def _build_model(args, device):
    """Build the HOTR model on *device* and load weights from ``args.resume``.

    Returns:
        ``(model, postprocessors)`` as produced by ``build(args)``.
    """
    model, _criterion, postprocessors = build(args)
    model.to(device)
    print_params(model)

    # torch.load unpickles the file: only point args.resume at trusted checkpoints.
    checkpoint = torch.load(args.resume, map_location='cpu')
    # strict=False: checkpoint keys that do not match the model are ignored.
    model.load_state_dict(checkpoint['model'], strict=False)
    return model, postprocessors


def vis(args, input_img=None, id=294, return_img=False):
    """Run HOTR on a single image and draw its human-object interactions.

    Args:
        args: namespace from ``get_args_parser()`` (plus ``threshold``,
            ``topk`` and, when ``input_img`` is None, ``image_dir``). It is
            mutated in place with dataset statistics and inference flags.
        input_img: PIL image to run on. When None, the image is read from
            ``args.image_dir``. Either way it is converted to RGB.
        id: unused; retained for backward compatibility.
        return_img: when True, return the visualization as a PIL RGB image.
            Otherwise write it to ``./vis_res/vis1.jpg`` and return None.

    Raises:
        OSError: if the visualization cannot be written to disk.
    """
    if args.frozen_weights is not None:
        print("Freeze weights for detector")

    if not torch.cuda.is_available():
        args.device = 'cpu'
    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    _apply_dataset_stats(args)

    # Inference-time model configuration.
    # NOTE(review): share_enc / pretrained_dec are forced True only *after*
    # hoi_enc_layers / hoi_dec_layers were derived from their incoming values
    # in _apply_dataset_stats; confirm this ordering is intended.
    args.HOIDet = True
    args.eval = True
    args.pretrained_dec = True
    args.share_enc = True
    args.share_dec_param = True
    if args.dataset_file == 'hico-det':
        args.valid_ids = args.valid_obj_ids

    model, postprocessors = _build_model(args, device)

    if input_img is None:
        img = Image.open(args.image_dir).convert('RGB')
    else:
        # Normalize to 3-channel RGB like the file branch: the channel flip
        # below assumes exactly three channels (RGBA/grayscale would break it).
        img = input_img.convert('RGB')

    w, h = img.size
    orig_size = torch.as_tensor([int(h), int(w)]).unsqueeze(0).to(device)
    transform = make_hoi_transforms('val')
    sample, _ = transform(img.copy(), None)
    sample = sample.unsqueeze(0).to(device)

    with torch.no_grad():
        model.eval()
        out = model(sample)
        results = postprocessors['hoi'](out, orig_size, dataset=args.dataset_file, args=args)
        output_i = change_format(results[0], args.valid_ids)

    # The image is handed to draw_img_vcoco in BGR order; its output is
    # treated as BGR below (cvtColor BGR2RGB / cv2.imwrite).
    image = np.asarray(img, dtype=np.uint8)[:, :, ::-1]
    vis_img = draw_img_vcoco(image, output_i, top_k=args.topk, threshold=args.threshold,
                             color=builtin_meta.COCO_CATEGORIES)
    plt.imshow(cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB))

    if return_img:
        # Back to RGB channel order for PIL.
        return Image.fromarray(vis_img[:, :, ::-1])

    out_path = './vis_res/vis1.jpg'
    # cv2.imwrite neither creates directories nor raises on failure (it only
    # returns False), so create the directory and check the result.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    if not cv2.imwrite(out_path, vis_img):
        raise OSError(f"cv2.imwrite failed to write {out_path}")
# Default V-COCO checkpoint used by both the demo entry point and the CLI.
# (HICO-DET checkpoint: './checkpoints/hico-det/hico_ft_q16.pth')
VCOCO_CHECKPOINT = './checkpoints/vcoco/checkpoint.pth'
# Inference paths passed to the model as ``args.augpath_name``.
DEFAULT_AUGPATH_NAME = ['p2', 'p3', 'p4']


# 230727 for huggingface
def visualization(input_img, threshold, topk):
    """Visualize human-object interactions on one image (Hugging Face demo).

    Args:
        input_img: PIL image to run on.
        threshold: score threshold forwarded to the drawing routine.
        topk: number of top predictions to draw; coerced to int.

    Returns:
        The visualization as a PIL RGB image.
    """
    parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
    # Parse an empty argv so the hosting process's command line is ignored.
    args = parser.parse_args(args=[])
    args.threshold = threshold
    args.topk = int(topk)

    args.resume = VCOCO_CHECKPOINT
    args.dataset_file = 'vcoco'
    args.data_path = 'v-coco'
    args.num_hoi_queries = 16
    args.temperature = 0.05
    args.augpath_name = list(DEFAULT_AUGPATH_NAME)

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    return vis(args, input_img=input_img, return_img=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
    parser.add_argument('--threshold', help='score threshold for visualization', default=0.4, type=float)
    # parser.add_argument('--path_id', help='index of inference path', default=1, type=int)
    parser.add_argument('--topk', help='topk prediction', default=5, type=int)
    parser.add_argument('--video_vis', action='store_true')
    parser.add_argument('--image_dir', default='', type=str)
    args = parser.parse_args()

    # vis() reads the image from args.image_dir when no image is passed in.
    if not args.image_dir:
        parser.error('--image_dir is required: path of the image to visualize')

    args.resume = VCOCO_CHECKPOINT
    with open('./v-coco/data/splits/vcoco_test.ids', encoding='utf-8') as file:
        test_idxs = [line.rstrip('\n') for line in file]
    # Named test_image_id (not `id`) to avoid shadowing the builtin module-wide.
    test_image_id = test_idxs[309]

    args.augpath_name = list(DEFAULT_AUGPATH_NAME)
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # Pass the id by keyword: vis()'s second positional parameter is input_img,
    # and a str there would be treated as an image.
    vis(args, id=test_image_id)