Jhp committed
Commit 5219368
1 Parent(s): b5fd524
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
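
These are the stock Hugging Face Hub LFS rules: each pattern routes matching files through the lfs clean/smudge filter and marks them binary (-text) for diff and merge. As a rough sketch, not part of the commit, of which files these rules would capture (stdlib fnmatch only approximates gitattributes globbing):

    # Sketch, not part of the commit: approximate LFS routing for a path.
    # fnmatch globbing is close to, but not identical to, gitattributes matching.
    from fnmatch import fnmatch

    lfs_patterns = ["*.bin", "*.ckpt", "*.pt", "*.pth", "*.safetensors",
                    "saved_model/**/*", "*tfevents*"]  # subset of the rules above

    def tracked_by_lfs(path: str) -> bool:
        return any(fnmatch(path, pat) for pat in lfs_patterns)

    print(tracked_by_lfs("checkpoints/vcoco/checkpoint.pth"))  # True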
.gitignore CHANGED
@@ -137,3 +137,6 @@ Makefile
  #datasets
  hico_20160224_det
  v-coco
+
+ # *.ipynb
+ vis_res
hotr/models/hotr.py CHANGED
@@ -182,7 +182,7 @@ class HOTR(nn.Module):
 
      H_Pointer_reprs_bag=torch.cat(H_Pointer_reprs_bag,1)
      O_Pointer_reprs_bag=torch.cat(O_Pointer_reprs_bag,1)
-
+     # import pdb;pdb.set_trace()
      outputs_hidx = [(torch.bmm(H_Pointer_repr, inst_repr_all)) / self.tau for H_Pointer_repr in H_Pointer_reprs_bag] #(dec_layer,(1+len(aug))*bs,dec_q,hidden_dim)
      outputs_oidx = [(torch.bmm(O_Pointer_repr, inst_repr_all)) / self.tau for O_Pointer_repr in O_Pointer_reprs_bag]
 
hotr/util/misc.py CHANGED
@@ -22,7 +22,7 @@ from torch import Tensor
 
  # needed due to empty tensor bug in pytorch and torchvision 0.5
  import torchvision
- if float(torchvision.__version__[:3]) < 0.7:
+ if float(torchvision.__version__.split('.',2)[1]) < 5:
      from torchvision.ops import _new_empty_tensor
      from torchvision.ops.misc import _output_size
 
@@ -388,7 +388,7 @@ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
      This will eventually be supported natively by PyTorch, and this
      class can go away.
      """
-     if float(torchvision.__version__[:3]) < 0.7:
+     if float(torchvision.__version__.split('.',2)[1]) < 5:
          if input.numel() > 0:
              return torch.nn.functional.interpolate(
                  input, size, scale_factor, mode, align_corners
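
The two changes above replace a fragile version parse: slicing torchvision.__version__[:3] turns "0.10.1" into "0.1", so any two-digit minor release was misread as older than 0.7. The new check compares the minor component directly (it also tightens the threshold to minor < 5, i.e. torchvision < 0.5, matching the "empty tensor bug ... torchvision 0.5" comment). A small sketch, not part of the commit, contrasting the two parses:

    # Sketch, not part of the commit: how each check parses torchvision.__version__.
    for v in ["0.4.2", "0.7.0", "0.10.1"]:
        old_check = float(v[:3]) < 0.7              # "0.10.1"[:3] == "0.1" -> misfires
        new_check = float(v.split('.', 2)[1]) < 5   # compares the minor number itself
        print(v, old_check, new_check)
    # 0.4.2  -> True  True
    # 0.7.0  -> False False
    # 0.10.1 -> True  False   (the old check wrongly fires on two-digit minors)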
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ pycocotools
+ opencv-python
+ wandb
+ imageio
+ scipy
tools/vis_tool.py ADDED
@@ -0,0 +1,96 @@
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+ 
+ import numpy as np
+ import cv2
+ 
+ 
+ vcoco_action_string = {2: 'hold', 3: 'stand', 4: 'sit', 5: 'ride', 6: 'walk',
+                        7: 'look', 8: 'hit_inst', 9: 'hit_obj', 10: 'eat_obj',
+                        11: 'eat_inst', 12: 'jump', 13: 'lay', 14: 'talk',
+                        15: 'carry', 16: 'throw', 17: 'catch', 18: 'cut_inst', 19: 'cut_obj',
+                        20: 'run', 21: 'work_on_comp', 22: 'ski', 23: 'surf', 24: 'skateboard',
+                        25: 'smile', 26: 'drink', 27: 'kick', 28: 'point', 29: 'read', 30: 'snowboard'}
+ 
+ def draw_box_on_img(box, img, color=None):
+     vis_img = img.copy()
+     box = [int(x) for x in box]
+     cv2.rectangle(vis_img, (box[0], box[1]), (box[2], box[3]), color, 2)
+     draw_point = [int((box[0]+box[2])*1.0/2), int((box[1]+box[3])*1.0/2)]
+     return vis_img, color
+ 
+ 
+ def draw_line_on_img_vcoco(box, line, img, class_index, color):
+     vis_img = img.copy()
+     font = cv2.FONT_HERSHEY_SIMPLEX
+     x = int(box[0]) + 2
+     y = int(box[1]) + 2
+     f = int(box[1]) + 2
+     for i in range(len(class_index)):
+         font_scale = 1
+         font_thickness = 2
+         text_size, _ = cv2.getTextSize(vcoco_action_string[class_index[i]], font, font_scale, font_thickness)
+         vis_img = cv2.rectangle(vis_img, (x, y), (x+text_size[0], y+text_size[1]+5), color[1], -1)
+         vis_img = cv2.putText(vis_img, vcoco_action_string[class_index[i]], (x, y+text_size[1]), font, font_scale, [51, 255, 153], font_thickness)
+         y = y + text_size[1] + 5
+     return vis_img, y
+ 
+ 
+ def draw_img_vcoco(img, output_i, top_k, threshold, color):
+     list_action = []
+     for action in output_i['hoi_prediction']:
+         subject_id = action['subject_id']
+         object_id = action['object_id']
+         category_id = action['category_id']
+         score = action['score']
+         single_out = [subject_id, object_id, category_id, score]
+         list_action.append(single_out)
+     list_action = sorted(list_action, key=lambda x: x[-1], reverse=True)
+     action_dict = []
+     action_cate = []
+     action_color = []
+     subj_box = []
+     sb = {}
+     sbj = []
+     for action in list_action[:top_k]:
+         subject_id, object_id, category_id, score = action
+         if score < threshold:
+             break
+         subject_obj = output_i['predictions'][subject_id]
+         subject_box = subject_obj['bbox']
+         object_obj = output_i['predictions'][object_id]
+         object_box = object_obj['bbox']
+ 
+         point_1 = [int((subject_box[0]+subject_box[2])*1.0/2), int((subject_box[1]+subject_box[3])*1.0/2)]
+         point_2 = [int((object_box[0]+object_box[2])*1.0/2), int((object_box[1]+object_box[3])*1.0/2)]
+ 
+         if [point_1, point_2] not in action_dict:
+             img, color_hum = draw_box_on_img(subject_box, img, color[subject_obj['category_id']]['color'])
+             img, color_obj = draw_box_on_img(object_box, img, color[object_obj['category_id']]['color'])
+             action_dict.append([point_1, point_2])
+             action_color.append([color_hum, color_obj])
+             subj_box.append([int(subject_box[0]), int(subject_box[1])])
+             action_cate.append([])
+         action_cate[action_dict.index([point_1, point_2])].append(category_id)
+ 
+     for i, (action_item, clr) in enumerate(zip(action_dict, action_color)):
+         img, offset = draw_line_on_img_vcoco(subj_box[i], action_item, img, action_cate[action_dict.index(action_item)], clr)
+         for p in range(i+1, len(subj_box)):
+             if subj_box[p] == subj_box[i]:
+                 subj_box[p][1] = offset
+     return img
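
draw_img_vcoco consumes the dict layout that change_format in visualization.py (below) emits: a 'predictions' list of boxes with category ids, and an 'hoi_prediction' list of (subject, object, verb, score) records. A minimal usage sketch, not part of the commit, with made-up boxes and scores:

    # Sketch: feeding draw_img_vcoco hand-built predictions (all values made up).
    import numpy as np
    from hotr.data.datasets import builtin_meta
    from tools.vis_tool import draw_img_vcoco

    canvas = np.zeros((480, 640, 3), dtype=np.uint8)
    output_i = {
        'predictions': [
            {'bbox': [100, 100, 220, 400], 'category_id': 1},  # label 1 == human here
            {'bbox': [240, 180, 340, 260], 'category_id': 2},
        ],
        'hoi_prediction': [
            {'subject_id': 0, 'object_id': 1, 'category_id': 2, 'score': 0.9},  # 2 -> 'hold'
        ],
    }
    vis = draw_img_vcoco(canvas, output_i, top_k=5, threshold=0.4,
                         color=builtin_meta.COCO_CATEGORIES)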
upload_checkpoint_hugginface.ipynb ADDED
@@ -0,0 +1,123 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "id": "a19d0964-4c83-4bc9-b59f-f04a57ca020f",
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "application/vnd.jupyter.widget-view+json": {
+        "model_id": "5dba33c11d7a4f708b5d6a03869ccb30",
+        "version_major": 2,
+        "version_minor": 0
+       },
+       "text/plain": [
+        "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
+       ]
+      },
+      "metadata": {},
+      "output_type": "display_data"
+     }
+    ],
+    "source": [
+     "from huggingface_hub import notebook_login\n",
+     " \n",
+     "notebook_login()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 4,
+    "id": "4b0ba848-4799-476d-b1d6-16a7cde1f4ad",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "Cloning https://huggingface.co/jhp/hoi into local empty directory.\n"
+      ]
+     }
+    ],
+    "source": [
+     "from huggingface_hub import Repository\n",
+     " \n",
+     "repo = Repository('CPC_HOTR', clone_from='jhp/hoi')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 9,
+    "id": "5d092b5a-99b3-4a4c-9ef4-90724eb665ef",
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "application/vnd.jupyter.widget-view+json": {
+        "model_id": "b51f83592a414a1ea646717c434e2a23",
+        "version_major": 2,
+        "version_minor": 0
+       },
+       "text/plain": [
+        "checkpoint.pth: 0%| | 0.00/301M [00:00<?, ?B/s]"
+       ]
+      },
+      "metadata": {},
+      "output_type": "display_data"
+     },
+     {
+      "data": {
+       "text/plain": [
+        "'https://huggingface.co/jhp/hoi/tree/main/./'"
+       ]
+      },
+      "execution_count": 9,
+      "metadata": {},
+      "output_type": "execute_result"
+     }
+    ],
+    "source": [
+     "from huggingface_hub import HfApi\n",
+     "api = HfApi()\n",
+     "api.upload_folder(\n",
+     "    folder_path=\"./checkpoints/\",\n",
+     "    repo_id=\"jhp/hoi\",\n",
+     "    path_in_repo=\"./\",\n",
+     "    # allow_patterns=\"*.txt\", # Upload all local text files\n",
+     "    # delete_patterns=\"*.txt\", # Delete all remote text files before\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "82eedf6b-21a2-4dad-bc7d-6f3631525ff2",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3 (ipykernel)",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.9.17"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
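
The last cell pushes the whole ./checkpoints/ folder. For a single checkpoint, HfApi.upload_file is the one-file counterpart; a sketch, not from the notebook, with a hypothetical local path:

    # Sketch: single-file variant of the upload_folder() call above.
    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="./checkpoints/checkpoint.pth",  # hypothetical local path
        path_in_repo="checkpoint.pth",
        repo_id="jhp/hoi",
    )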
visualization.py ADDED
@@ -0,0 +1,227 @@
+ import argparse
+ import datetime
+ import json
+ import random
+ import time
+ import multiprocessing
+ from pathlib import Path
+ import os
+ import cv2
+ import numpy as np
+ import torch
+ from torch.utils.data import DataLoader, DistributedSampler
+ import hotr.data.datasets as datasets
+ import hotr.util.misc as utils
+ from hotr.engine.arg_parser import get_args_parser
+ from hotr.data.datasets import build_dataset, get_coco_api_from_dataset
+ from hotr.data.datasets.vcoco import make_hoi_transforms
+ from PIL import Image
+ from hotr.util.logger import print_params, print_args
+ 
+ import copy
+ from hotr.data.datasets import builtin_meta
+ from PIL import Image
+ import requests
+ # import mmcv
+ from matplotlib import pyplot as plt
+ import imageio
+ 
+ from tools.vis_tool import *
+ from hotr.models.detr import build
+ 
+ def change_format(results, valid_ids):
+     boxes, labels, pair_score = \
+         list(map(lambda x: x.cpu().numpy(), [results['boxes'], results['labels'], results['pair_score']]))
+     output_i = {}
+     output_i['predictions'] = []
+     output_i['hoi_prediction'] = []
+ 
+     h_idx = np.where(labels == 1)[0]
+     for box, label in zip(boxes, labels):
+         output_i['predictions'].append({'bbox': box.tolist(), 'category_id': label})
+ 
+     for i, verb in enumerate(pair_score):
+         if i in [1, 4, 10, 23, 26, 5, 18]:
+             continue
+         for j, hum in enumerate(h_idx):
+             for k in range(len(boxes)):
+                 if verb[j][k] > 0:
+                     output_i['hoi_prediction'].append({'subject_id': hum, 'object_id': k, 'category_id': i+2, 'score': verb[j][k]})
+ 
+     return output_i
+ 
+ def vis(args, id=294):
+     if args.frozen_weights is not None:
+         print("Freeze weights for detector")
+ 
+     device = torch.device(args.device)
+ 
+     # fix the seed for reproducibility
+     seed = args.seed + utils.get_rank()
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+ 
+     # Data Setup
+     dataset_train = build_dataset(image_set='train', args=args)
+     args.num_classes = dataset_train.num_category()
+     args.num_actions = dataset_train.num_action()
+     args.action_names = dataset_train.get_actions()
+     if args.share_enc: args.hoi_enc_layers = args.enc_layers
+     if args.pretrained_dec: args.hoi_dec_layers = args.dec_layers
+     if args.dataset_file == 'vcoco':
+         # Save V-COCO dataset statistics
+         args.valid_ids = np.array(dataset_train.get_object_label_idx()).nonzero()[0]
+         args.invalid_ids = np.argwhere(np.array(dataset_train.get_object_label_idx()) == 0).squeeze(1)
+         args.human_actions = dataset_train.get_human_action()
+         args.object_actions = dataset_train.get_object_action()
+         args.num_human_act = dataset_train.num_human_act()
+     elif args.dataset_file == 'hico-det':
+         args.valid_obj_ids = dataset_train.get_valid_obj_ids()
+     print_args(args)
+ 
+     args.HOIDet = True
+     args.eval = True
+     args.pretrained_dec = True
+     args.share_enc = True
+     args.share_dec_param = True
+     if args.dataset_file == 'hico-det':
+         args.valid_ids = args.valid_obj_ids
+ 
+     # Model Setup
+     model, criterion, postprocessors = build(args)
+     model.to(device)
+ 
+     model_without_ddp = model
+ 
+     n_parameters = print_params(model)
+ 
+     param_dicts = [
+         {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
+         {
+             "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
+             "lr": args.lr_backbone,
+         },
+     ]
+ 
+     output_dir = Path(args.output_dir)
+ 
+     checkpoint = torch.load(args.resume, map_location='cpu')
+     # modified
+     module_name = list(checkpoint['model'].keys())
+     model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
+ 
+     # if not args.video_vis:
+     #     url = 'http://images.cocodataset.org/val2014/COCO_val2014_{}.jpg'.format(str(id).zfill(12))
+     #     req = requests.get(url, stream=True, timeout=1, verify=False).raw
+     req = args.image_dir
+     img = Image.open(req).convert('RGB')
+ 
+     w, h = img.size
+     orig_size = torch.as_tensor([int(h), int(w)]).unsqueeze(0).to(device)
+ 
+     transform = make_hoi_transforms('val')
+     sample = img.copy()
+     sample, _ = transform(sample, None)
+     sample = sample.unsqueeze(0).to(device)
+     with torch.no_grad():
+         model.eval()
+         out = model(sample)
+         results = postprocessors['hoi'](out, orig_size, dataset=args.dataset_file, args=args)
+         output_i = change_format(results[0], args.valid_ids)
+ 
+     out_dir = './vis'
+     image = np.asarray(img, dtype=np.uint8)[:, :, ::-1]
+     # image = cv2.imdecode(image_nparray, cv2.IMREAD_COLOR)
+ 
+     vis_img = draw_img_vcoco(image, output_i, top_k=args.topk, threshold=args.threshold, color=builtin_meta.COCO_CATEGORIES)
+     plt.imshow(cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB))
+     cv2.imwrite('./vis_res/vis1.jpg', vis_img)
+ 
+     # else:
+     #     frames = []
+     #     video_file = id
+ 
+     #     video_reader = mmcv.VideoReader('./vid/'+video_file+'.mp4')
+     #     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     #     video_writer = cv2.VideoWriter(
+     #         './vid/'+video_file+'_vis.mp4', fourcc, video_reader.fps,
+     #         (video_reader.width, video_reader.height))
+ 
+     #     orig_size = torch.as_tensor([int(video_reader.height), int(video_reader.width)]).unsqueeze(0).to(device)
+     #     transform = make_hoi_transforms('val')
+ 
+     #     for frame in mmcv.track_iter_progress(video_reader):
+     #         frame = mmcv.imread(frame)
+     #         frame = frame.copy()
+     #         frame = Image.fromarray(frame, 'RGB')
+ 
+     #         sample, _ = transform(frame, None)
+     #         sample = sample.unsqueeze(0).to(device)
+ 
+     #         with torch.no_grad():
+     #             model.eval()
+     #             out = model(sample)
+     #             results = postprocessors['hoi'](out, orig_size, dataset='vcoco', args=args)
+     #             output_i = change_format(results[0], args.valid_ids)
+ 
+     #         vis_img = draw_img_vcoco(np.array(frame), output_i, top_k=args.topk, threshold=args.threshold, color=builtin_meta.COCO_CATEGORIES)
+     #         frames.append(vis_img)
+     #         video_writer.write(vis_img)
+ 
+     #     with imageio.get_writer("smiling.gif", mode="I") as writer:
+     #         for idx, frame in enumerate(frames):
+     #             # print("Adding frame to GIF file: ", idx + 1)
+     #             writer.append_data(frame)
+     #     if video_writer:
+     #         video_writer.release()
+     #     cv2.destroyAllWindows()
+ 
+ 
+ # def visualization(id, video_vis=False, dataset_file='vcoco', path_id=0, data_path='v-coco', threshold=0.4, topk=10, aug_path='[]'):
+ #     parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
+ #     checkpoint_dir = './checkpoints/vcoco/checkpoint.pth' if dataset_file=='vcoco' else './checkpoints/hico-det/hico_ft_q16.pth'
+ #     with open('./v-coco/data/vcoco_test.ids') as file:
+ #         test_idxs = [line.rstrip('\n') for line in file]
+ #     if not video_vis:
+ #         id = test_idxs[id]
+ #     args = parser.parse_args(args=['--dataset_file', dataset_file, '--data_path', data_path, '--resume', checkpoint_dir, '--num_hoi_queries', '16', '--temperature', '0.05', '--augpath_name', aug_path, '--path_id', '{}'.format(path_id)])
+ #     args.video_vis = video_vis
+ #     args.threshold = threshold
+ #     args.topk = topk
+ 
+ #     if args.output_dir:
+ #         Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+ #     vis(args, id)
+ 
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
+     parser.add_argument('--threshold', help='score threshold for visualization', default=0.4, type=float)
+     # parser.add_argument('--path_id', help='index of inference path', default=1, type=int)
+     parser.add_argument('--topk', help='topk prediction', default=5, type=int)
+     parser.add_argument('--video_vis', action='store_true')
+     parser.add_argument('--image_dir', default='', type=str)
+     args = parser.parse_args()
+     # checkpoint_dir = './checkpoints/vcoco/checkpoint.pth' if dataset_file=='vcoco' else './checkpoints/hico-det/hico_ft_q16.pth'
+     args.resume = './checkpoints/vcoco/checkpoint.pth'
+     with open('./v-coco/data/splits/vcoco_test.ids') as file:
+         test_idxs = [line.rstrip('\n') for line in file]
+     # if not video_vis:
+     id = test_idxs[309]
+     # args = parser.parse_args()
+     # args.dataset_file = 'vcoco'
+     # args.data_path = 'v-coco'
+     # args.resume = checkpoint_dir
+     # args.num_hoi_queries = 16
+     # args.temperature = 0.05
+     args.augpath_name = ['p2', 'p3', 'p4']
+     # args.path_id = 1
+ 
+     if args.output_dir:
+         Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+     vis(args, id)
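
For reference, vis() can also be driven from Python rather than the CLI. A sketch, not part of the commit, reusing the flag values from the commented-out visualization() helper above (the image path is hypothetical, and some checkpoints may need extra flags such as --path_id):

    # Sketch: programmatic invocation of visualization.vis() (paths are hypothetical).
    import argparse
    import visualization
    from hotr.engine.arg_parser import get_args_parser

    parser = argparse.ArgumentParser('DETR training and evaluation script',
                                     parents=[get_args_parser()])
    parser.add_argument('--threshold', default=0.4, type=float)
    parser.add_argument('--topk', default=5, type=int)
    parser.add_argument('--video_vis', action='store_true')
    parser.add_argument('--image_dir', default='', type=str)
    args = parser.parse_args(['--dataset_file', 'vcoco', '--data_path', 'v-coco',
                              '--num_hoi_queries', '16', '--temperature', '0.05',
                              '--image_dir', './vis_res/input.jpg'])
    args.resume = './checkpoints/vcoco/checkpoint.pth'
    args.augpath_name = ['p2', 'p3', 'p4']
    visualization.vis(args, id=None)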