File size: 10,067 Bytes
5219368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bc27e7
5219368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bc27e7
 
 
 
 
 
 
5219368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bc27e7
529aa2f
0bc27e7
529aa2f
 
 
5219368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529aa2f
0bc27e7
529aa2f
 
0bc27e7
 
 
 
529aa2f
 
0bc27e7
 
 
 
529aa2f
0bc27e7
 
529aa2f
0bc27e7
 
529aa2f
 
0bc27e7
 
529aa2f
 
0bc27e7
529aa2f
5219368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
import argparse
import datetime
import json
import random
import time
import multiprocessing
from pathlib import Path
import os
import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler
import hotr.data.datasets as datasets
import hotr.util.misc as utils
from hotr.engine.arg_parser import get_args_parser
from hotr.data.datasets import build_dataset, get_coco_api_from_dataset
from hotr.data.datasets.vcoco import make_hoi_transforms
from PIL import Image
from hotr.util.logger import print_params, print_args

import copy
from hotr.data.datasets import builtin_meta
from PIL import Image
import requests
# import mmcv
from matplotlib import pyplot as plt
import imageio

from tools.vis_tool import *
from hotr.models.detr import build

def change_format(results, valid_ids, skip_actions=(1, 4, 10, 23, 26, 5, 18)):
    """Convert one HOI post-processor result into the dict used by ``draw_img_vcoco``.

    Args:
        results: one element of the ``postprocessors['hoi']`` output. It must
            map ``'boxes'``, ``'labels'`` and ``'pair_score'`` to tensors.
            ``pair_score`` is indexed as ``pair_score[action][h][obj]``, where
            ``h`` is the position of a human among the boxes labelled 1 and
            ``obj`` indexes ``boxes`` (presumably shaped
            ``(num_actions, num_humans, num_boxes)`` -- confirm upstream).
        valid_ids: unused; kept so existing callers keep working.
        skip_actions: action indices (first axis of ``pair_score``) that are
            never emitted. The default is the list previously hard-coded here.

    Returns:
        dict with two lists:
          * ``'predictions'``: one ``{'bbox', 'category_id'}`` entry per box,
            in box order.
          * ``'hoi_prediction'``: one ``{'subject_id', 'object_id',
            'category_id', 'score'}`` entry per (action, human, box) whose
            score is strictly positive. ``subject_id``/``object_id`` index
            ``'predictions'``, and ``category_id`` is the action index + 2.
    """
    boxes, labels, pair_score = (
        results[key].cpu().numpy() for key in ('boxes', 'labels', 'pair_score')
    )
    # Frozenset gives O(1) membership instead of scanning a list per action.
    skipped = frozenset(skip_actions)
    # Boxes with label 1 are treated as the subjects (humans) of interactions.
    human_idx = np.where(labels == 1)[0]

    predictions = [
        {'bbox': box.tolist(), 'category_id': label}
        for box, label in zip(boxes, labels)
    ]

    hoi_prediction = []
    for action, scores in enumerate(pair_score):
        if action in skipped:
            continue
        for row, subject in enumerate(human_idx):
            for obj in range(len(boxes)):
                score = scores[row][obj]
                if score > 0:
                    hoi_prediction.append({
                        'subject_id': subject,
                        'object_id': obj,
                        'category_id': action + 2,
                        'score': score,
                    })

    return {'predictions': predictions, 'hoi_prediction': hoi_prediction}
def vis(args,input_img=None,id=294,return_img=False):
    """Run HOTR inference on one image and draw the predicted HOI triplets.

    The model is built from ``args`` after dataset statistics are read from
    the ``train`` split (``build`` reads them from ``args``). Weights are then
    loaded from ``args.resume``.

    Args:
        args: parsed HOTR argument namespace. This function also writes
            several fields on it (``num_classes``, ``num_actions``,
            ``valid_ids``, ``HOIDet``, ``eval``, ...).
        input_img: a PIL image, or a numpy array accepted by
            ``Image.fromarray``. When ``None`` the image is read from the path
            in ``args.image_dir``.
        id: unused; retained for backward compatibility with older callers.
        return_img: when True, return the visualization as an RGB PIL image.
            Otherwise write it (BGR, via OpenCV) to ``./vis_res/vis1.jpg``.

    Returns:
        ``PIL.Image.Image`` when ``return_img`` is True, else ``None``.

    Raises:
        OSError: if the visualization cannot be written to ``./vis_res``.

    NOTE: ``args.video_vis`` is not read here; only single images are handled.
    """
    if args.frozen_weights is not None:
        print("Freeze weights for detector")

    device = torch.device(args.device)

    # Fix the seed for reproducibility (offset by the process rank).
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Data setup: the model builder reads these dataset statistics from args.
    dataset_train = build_dataset(image_set='train', args=args)
    args.num_classes = dataset_train.num_category()
    args.num_actions = dataset_train.num_action()
    args.action_names = dataset_train.get_actions()
    if args.share_enc: args.hoi_enc_layers = args.enc_layers
    if args.pretrained_dec: args.hoi_dec_layers = args.dec_layers
    if args.dataset_file == 'vcoco':
        # Save V-COCO dataset statistics
        args.valid_ids = np.array(dataset_train.get_object_label_idx()).nonzero()[0]
        args.invalid_ids = np.argwhere(np.array(dataset_train.get_object_label_idx()) == 0).squeeze(1)
        args.human_actions = dataset_train.get_human_action()
        args.object_actions = dataset_train.get_object_action()
        args.num_human_act = dataset_train.num_human_act()
    elif args.dataset_file == 'hico-det':
        args.valid_obj_ids = dataset_train.get_valid_obj_ids()
    print_args(args)

    # Inference-time overrides. They are applied after print_args, so they do
    # not appear in the printed configuration.
    args.HOIDet = True
    args.eval = True
    args.pretrained_dec = True
    args.share_enc = True
    args.share_dec_param = True
    if args.dataset_file == 'hico-det':
        args.valid_ids = args.valid_obj_ids

    # Model setup
    model, _criterion, postprocessors = build(args)
    model.to(device)
    print_params(model)

    # NOTE(review): torch.load unpickles the file; only point args.resume at
    # trusted checkpoints.
    checkpoint = torch.load(args.resume, map_location='cpu')
    # strict=False: checkpoint keys that do not match the model are ignored
    # silently rather than raising.
    model.load_state_dict(checkpoint['model'], strict=False)

    # Normalise the input to an RGB PIL image.
    if input_img is None:
        # Context manager closes the underlying file; convert() makes a copy.
        with Image.open(args.image_dir) as src:
            img = src.convert('RGB')
    elif isinstance(input_img, np.ndarray):
        img = Image.fromarray(input_img).convert('RGB')
    elif input_img.mode != 'RGB':
        img = input_img.convert('RGB')
    else:
        img = input_img

    w, h = img.size
    orig_size = torch.as_tensor([int(h), int(w)]).unsqueeze(0).to(device)

    transform = make_hoi_transforms('val')
    sample, _ = transform(img.copy(), None)
    sample = sample.unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        out = model(sample)
        results = postprocessors['hoi'](out, orig_size, dataset=args.dataset_file, args=args)
        output_i = change_format(results[0], args.valid_ids)

    # Reverse the channels so drawing happens in OpenCV's BGR order. The
    # [::-1] view has negative strides and np.asarray(PIL) may be read-only,
    # so take a contiguous copy in case draw_img_vcoco draws in place with cv2.
    image = np.ascontiguousarray(np.asarray(img, dtype=np.uint8)[:, :, ::-1])

    vis_img = draw_img_vcoco(image, output_i, top_k=args.topk, threshold=args.threshold, color=builtin_meta.COCO_CATEGORIES)
    vis_rgb = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)
    plt.imshow(vis_rgb)
    if return_img:
        # PIL expects RGB; vis_img is BGR, so return the converted copy.
        return Image.fromarray(vis_rgb)

    # cv2.imwrite does not create directories and only signals failure via its
    # return value, so create the directory and check the result.
    out_dir = Path('./vis_res')
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / 'vis1.jpg'
    if not cv2.imwrite(str(out_path), vis_img):
        raise OSError('failed to write visualization to {}'.format(out_path))
 

# def visualization(id, video_vis=False, dataset_file='vcoco', path_id = 0 ,data_path='v-coco', threshold=0.4, topk=10,aug_path = '[]'):

#     parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
#     checkpoint_dir= './checkpoints/vcoco/checkpoint.pth' if dataset_file=='vcoco' else './checkpoints/hico-det/hico_ft_q16.pth'
#     with open('./v-coco/data/vcoco_test.ids') as file:
#       test_idxs = [line.rstrip('\n') for line in file]
#     if not video_vis:
#       id = test_idxs[id]
#     args = parser.parse_args(args=['--dataset_file',dataset_file,'--data_path',data_path,'--resume',checkpoint_dir,'--num_hoi_queries' ,'16','--temperature' ,'0.05', '--augpath_name',aug_path ,'--path_id','{}'.format(path_id)])
#     args.video_vis=video_vis
#     args.threshold=threshold
#     args.topk=topk
    
#     if args.output_dir:
#         Path(args.output_dir).mkdir(parents=True, exist_ok=True)
#     vis(args,id)

# 230727 for huggingface
def visualization(input_img,threshold,topk):
    """Visualize V-COCO HOI predictions for one image (Hugging Face entry point).

    Builds an argument namespace with the parser defaults, then overrides the
    V-COCO checkpoint, dataset and inference settings before calling ``vis``.

    Args:
        input_img: image passed straight to ``vis`` (a PIL image, or a numpy
            array accepted by ``Image.fromarray``).
        threshold: score threshold for drawn interactions (coerced to float).
        topk: maximum number of predictions to draw (coerced to int).

    Returns:
        The RGB ``PIL.Image.Image`` produced by ``vis(..., return_img=True)``.
    """
    parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args(args=[])
    args.threshold = float(threshold)
    args.topk = int(topk)

    args.resume = './checkpoints/vcoco/checkpoint.pth'
    args.dataset_file = 'vcoco'
    args.data_path = 'v-coco'
    args.num_hoi_queries = 16
    args.temperature = 0.05
    args.augpath_name = ['p2', 'p3', 'p4']
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # Previously the rendered image was discarded and the function returned None.
    return vis(args, input_img=input_img, return_img=True)

if __name__ == '__main__':
    # CLI: visualize HOI predictions for the image at --image_dir and write
    # the result to ./vis_res/vis1.jpg.
    parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
    parser.add_argument('--threshold',help='score threshold for visualization', default=0.4, type=float)
    parser.add_argument('--topk',help='topk prediction', default=5, type=int)
    parser.add_argument('--video_vis', action='store_true')
    parser.add_argument('--image_dir', default='', type=str)
    args = parser.parse_args()
    if not args.image_dir:
        parser.error('--image_dir is required: path of the image to visualize')

    # Fall back to the bundled V-COCO checkpoint only when --resume was not given.
    if not getattr(args, 'resume', None):
        args.resume = './checkpoints/vcoco/checkpoint.pth'

    with open('./v-coco/data/splits/vcoco_test.ids', encoding='utf-8') as file:
        test_idxs = [line.rstrip('\n') for line in file]
    image_id = test_idxs[309]
    args.augpath_name = ['p2', 'p3', 'p4']

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # Pass the id by keyword: positionally it would land in `input_img` and
    # vis would try to treat the id string as an image.
    vis(args, id=image_id)