Spaces:
Runtime error
Runtime error
import argparse | |
import datetime | |
import json | |
import random | |
import time | |
import multiprocessing | |
from pathlib import Path | |
import os | |
import cv2 | |
import numpy as np | |
import torch | |
from torch.utils.data import DataLoader, DistributedSampler | |
import hotr.data.datasets as datasets | |
import hotr.util.misc as utils | |
from hotr.engine.arg_parser import get_args_parser | |
from hotr.data.datasets import build_dataset, get_coco_api_from_dataset | |
from hotr.data.datasets.vcoco import make_hoi_transforms | |
from PIL import Image | |
from hotr.util.logger import print_params, print_args | |
import copy | |
from hotr.data.datasets import builtin_meta | |
from PIL import Image | |
import requests | |
# import mmcv | |
from matplotlib import pyplot as plt | |
import imageio | |
from tools.vis_tool import * | |
from hotr.models.detr import build | |
def change_format(results, valid_ids):
    """Convert HOI post-processor output into the dict layout consumed by ``draw_img_vcoco``.

    Args:
        results: mapping whose ``'boxes'``, ``'labels'`` and ``'pair_score'``
            entries are tensors; each is moved to CPU and converted to NumPy.
            ``pair_score`` is indexed as ``[action][human_row][box]`` below.
        valid_ids: accepted for interface compatibility; not used here.

    Returns:
        dict with two lists, in this key order:
          ``'predictions'``    -- one ``{'bbox', 'category_id'}`` entry per box;
          ``'hoi_prediction'`` -- one ``{'subject_id', 'object_id', 'category_id',
                                  'score'}`` entry per (action, human, box)
                                  triple whose score is strictly positive.
    """
    boxes, labels, pair_score = (
        results[key].cpu().numpy() for key in ('boxes', 'labels', 'pair_score')
    )

    # Action indices that are never emitted as HOI triplets (hard-coded list;
    # their meaning is not visible from this file).
    skipped_actions = {1, 4, 5, 10, 18, 23, 26}
    # Box indices labelled 1; these act as the subjects ("humans") of each pair.
    human_indices = np.where(labels == 1)[0]

    predictions = [
        {'bbox': box.tolist(), 'category_id': label}
        for box, label in zip(boxes, labels)
    ]

    num_boxes = len(boxes)
    hoi_prediction = []
    for action_idx, action_scores in enumerate(pair_score):
        if action_idx in skipped_actions:
            continue
        for row, human in enumerate(human_indices):
            for obj in range(num_boxes):
                score = action_scores[row][obj]
                if score > 0:
                    hoi_prediction.append({
                        'subject_id': human,
                        'object_id': obj,
                        # Action index is shifted by 2 for the drawing code.
                        'category_id': action_idx + 2,
                        'score': score,
                    })

    return {'predictions': predictions, 'hoi_prediction': hoi_prediction}
def _configure_args_from_dataset(args):
    """Copy dataset statistics onto ``args`` and force the HOI inference flags.

    The training split is built only to read metadata (category/action counts,
    action names, object-label tables); no samples are read from it here.
    """
    dataset_train = build_dataset(image_set='train', args=args)
    args.num_classes = dataset_train.num_category()
    args.num_actions = dataset_train.num_action()
    args.action_names = dataset_train.get_actions()
    # These two checks see the flags as parsed; both flags are forced to True
    # further down, after the checks (order kept from the original code).
    if args.share_enc:
        args.hoi_enc_layers = args.enc_layers
    if args.pretrained_dec:
        args.hoi_dec_layers = args.dec_layers
    if args.dataset_file == 'vcoco':
        # Save V-COCO dataset statistics.
        object_label_idx = np.array(dataset_train.get_object_label_idx())
        args.valid_ids = object_label_idx.nonzero()[0]
        args.invalid_ids = np.argwhere(object_label_idx == 0).squeeze(1)
        args.human_actions = dataset_train.get_human_action()
        args.object_actions = dataset_train.get_object_action()
        args.num_human_act = dataset_train.num_human_act()
    elif args.dataset_file == 'hico-det':
        args.valid_obj_ids = dataset_train.get_valid_obj_ids()
    print_args(args)
    args.HOIDet = True
    args.eval = True
    args.pretrained_dec = True
    args.share_enc = True
    args.share_dec_param = True
    if args.dataset_file == 'hico-det':
        # change_format() is always handed args.valid_ids.
        args.valid_ids = args.valid_obj_ids


def _build_hoi_model(args, device):
    """Build the HOTR model on ``device``, load ``args.resume`` and return ``(model, postprocessors)``."""
    model, _criterion, postprocessors = build(args)
    model.to(device)
    print_params(model)
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if this
    # checkpoint pickles non-tensor objects, loading may fail there -- confirm
    # against the torch version actually deployed.
    checkpoint = torch.load(args.resume, map_location='cpu')
    # strict=False tolerates key mismatches; report them instead of ignoring them silently.
    incompatible = model.load_state_dict(checkpoint['model'], strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print('Checkpoint {}: {} missing / {} unexpected keys'.format(
            args.resume, len(incompatible.missing_keys), len(incompatible.unexpected_keys)))
    model.eval()
    return model, postprocessors


def _load_input_image(args, input_img):
    """Return the image to run inference on as an RGB ``PIL.Image``.

    Uses ``input_img`` when given (a PIL image, or a NumPy array wrapped with
    ``Image.fromarray``); otherwise opens the file at ``args.image_dir``.
    """
    if input_img is None:
        with Image.open(args.image_dir) as src:
            return src.convert('RGB')
    if isinstance(input_img, np.ndarray):
        input_img = Image.fromarray(input_img)
    # Force 3 channels so the [:, :, ::-1] channel flip in vis() also works for
    # grayscale / RGBA inputs.
    return input_img.convert('RGB')


def vis(args, input_img=None, id=294, return_img=False):
    """Run HOTR HOI detection on a single image and draw the predictions.

    Args:
        args: parsed argument namespace (from ``get_args_parser``). It must also
            carry ``resume`` (checkpoint path), ``threshold`` and ``topk``, plus
            ``image_dir`` when ``input_img`` is None. It is mutated in place with
            dataset statistics and inference flags.
        input_img: PIL image or NumPy image array (assumed RGB); when None the
            image is read from ``args.image_dir``.
        id: currently unused (the COCO-URL download path is commented out);
            kept for backward compatibility.
        return_img: if True, return the drawing as a PIL image; otherwise write
            it to ``./vis_res/vis1.jpg``.

    Returns:
        ``PIL.Image`` when ``return_img`` is True, else None.
    """
    if args.frozen_weights is not None:
        print("Freeze weights for detector")
    if not torch.cuda.is_available():
        args.device = 'cpu'
    device = torch.device(args.device)

    # Fix the seed for reproducibility (offset by the process rank).
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    _configure_args_from_dataset(args)
    model, postprocessors = _build_hoi_model(args, device)

    img = _load_input_image(args, input_img)
    w, h = img.size
    orig_size = torch.as_tensor([int(h), int(w)]).unsqueeze(0).to(device)
    transform = make_hoi_transforms('val')
    sample, _ = transform(img.copy(), None)
    sample = sample.unsqueeze(0).to(device)
    with torch.no_grad():
        out = model(sample)
        results = postprocessors['hoi'](out, orig_size, dataset=args.dataset_file, args=args)
    output_i = change_format(results[0], args.valid_ids)

    # Flip PIL's RGB to BGR channel order for the OpenCV-style drawing code.
    image = np.asarray(img, dtype=np.uint8)[:, :, ::-1]
    vis_img = draw_img_vcoco(image, output_i, top_k=args.topk, threshold=args.threshold,
                             color=builtin_meta.COCO_CATEGORIES)
    if return_img:
        # Back to RGB for PIL.
        return Image.fromarray(vis_img[:, :, ::-1])

    # cv2.imwrite does not create directories and reports failure only via its
    # return value, so make sure the target exists and check the result.
    Path('./vis_res').mkdir(parents=True, exist_ok=True)
    if not cv2.imwrite('./vis_res/vis1.jpg', vis_img):
        print('Warning: failed to write ./vis_res/vis1.jpg')
# else: | |
# frames=[] | |
# video_file=id | |
# video_reader = mmcv.VideoReader('./vid/'+video_file+'.mp4') | |
# fourcc = cv2.VideoWriter_fourcc(*'mp4v') | |
# video_writer = cv2.VideoWriter( | |
# './vid/'+video_file+'_vis.mp4', fourcc, video_reader.fps, | |
# (video_reader.width, video_reader.height)) | |
# orig_size = torch.as_tensor([int(video_reader.height), int(video_reader.width)]).unsqueeze(0).to(device) | |
# transform=make_hoi_transforms('val') | |
# for frame in mmcv.track_iter_progress(video_reader): | |
# frame=mmcv.imread(frame) | |
# frame=frame.copy() | |
# frame=Image.fromarray(frame,'RGB') | |
# sample,_=transform(frame,None) | |
# sample=sample.unsqueeze(0).to(device) | |
# with torch.no_grad(): | |
# model.eval() | |
# out=model(sample) | |
# results = postprocessors['hoi'](out, orig_size,dataset='vcoco',args=args) | |
# output_i=change_format(results[0],args.valid_ids) | |
# vis_img=draw_img_vcoco(np.array(frame),output_i,top_k=args.topk,threshold=args.threshold,color=builtin_meta.COCO_CATEGORIES) | |
# frames.append(vis_img) | |
# video_writer.write(vis_img) | |
# with imageio.get_writer("smiling.gif", mode="I") as writer: | |
# for idx, frame in enumerate(frames): | |
# # print("Adding frame to GIF file: ", idx + 1) | |
# writer.append_data(frame) | |
# if video_writer: | |
# video_writer.release() | |
# cv2.destroyAllWindows() | |
# def visualization(id, video_vis=False, dataset_file='vcoco', path_id = 0 ,data_path='v-coco', threshold=0.4, topk=10,aug_path = '[]'): | |
# parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()]) | |
# checkpoint_dir= './checkpoints/vcoco/checkpoint.pth' if dataset_file=='vcoco' else './checkpoints/hico-det/hico_ft_q16.pth' | |
# with open('./v-coco/data/vcoco_test.ids') as file: | |
# test_idxs = [line.rstrip('\n') for line in file] | |
# if not video_vis: | |
# id = test_idxs[id] | |
# args = parser.parse_args(args=['--dataset_file',dataset_file,'--data_path',data_path,'--resume',checkpoint_dir,'--num_hoi_queries' ,'16','--temperature' ,'0.05', '--augpath_name',aug_path ,'--path_id','{}'.format(path_id)]) | |
# args.video_vis=video_vis | |
# args.threshold=threshold | |
# args.topk=topk | |
# if args.output_dir: | |
# Path(args.output_dir).mkdir(parents=True, exist_ok=True) | |
# vis(args,id) | |
# 230727 for huggingface | |
def visualization(input_img, threshold, topk):
    """Hugging Face demo entry point: run the V-COCO HOI visualization on ``input_img``.

    Starts from the parser defaults (no CLI arguments are read), applies the
    fixed V-COCO demo configuration and delegates to ``vis`` with
    ``return_img=True``.

    Args:
        input_img: image forwarded to ``vis`` as ``input_img``.
        threshold: score threshold stored on ``args.threshold``.
        topk: number of predictions to draw; converted with ``int``.

    Returns:
        Whatever ``vis`` returns with ``return_img=True``.
    """
    demo_parser = argparse.ArgumentParser(
        'DETR training and evaluation script', parents=[get_args_parser()]
    )
    demo_args = demo_parser.parse_args(args=[])

    # Fixed demo configuration, applied in this order.
    overrides = (
        ('threshold', threshold),
        ('topk', int(topk)),
        ('resume', './checkpoints/vcoco/checkpoint.pth'),
        ('dataset_file', 'vcoco'),
        ('data_path', 'v-coco'),
        ('num_hoi_queries', 16),
        ('temperature', 0.05),
        ('augpath_name', ['p2', 'p3', 'p4']),
    )
    for attr_name, attr_value in overrides:
        setattr(demo_args, attr_name, attr_value)

    if demo_args.output_dir:
        Path(demo_args.output_dir).mkdir(parents=True, exist_ok=True)
    return vis(demo_args, input_img=input_img, return_img=True)
if __name__ == '__main__':
    # CLI entry point: draw HOI predictions for the image at --image_dir and
    # write the result to ./vis_res/vis1.jpg (see vis()).
    parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
    parser.add_argument('--threshold', help='score threshold for visualization', default=0.4, type=float)
    parser.add_argument('--topk', help='topk prediction', default=5, type=int)
    parser.add_argument('--video_vis', action='store_true')
    parser.add_argument('--image_dir', default='', type=str)
    args = parser.parse_args()
    # vis() reads the input image from args.image_dir; fail fast before the
    # (slow) dataset/model setup if it was not given.
    if not args.image_dir:
        parser.error('--image_dir is required')
    # Fall back to the bundled V-COCO checkpoint only when --resume was not given,
    # instead of silently overriding a user-supplied path.
    if not getattr(args, 'resume', ''):
        args.resume = './checkpoints/vcoco/checkpoint.pth'
    with open('./v-coco/data/splits/vcoco_test.ids', encoding='utf-8') as file:
        test_idxs = [line.rstrip('\n') for line in file]
    image_id = test_idxs[309]
    args.augpath_name = ['p2', 'p3', 'p4']
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # Pass the test id by keyword: positionally it would land in vis()'s
    # input_img parameter and be treated as an image.
    vis(args, id=image_id)