- .gitattributes +35 -0
- .gitignore +3 -0
- hotr/models/hotr.py +1 -1
- hotr/util/misc.py +2 -2
- requirements.txt +5 -0
- tools/vis_tool.py +96 -0
- upload_checkpoint_hugginface.ipynb +123 -0
- visualization.py +227 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
CHANGED
@@ -137,3 +137,6 @@ Makefile
 #datasets
 hico_20160224_det
 v-coco
+
+# *.ipynb
+vis_res
hotr/models/hotr.py
CHANGED
@@ -182,7 +182,7 @@ class HOTR(nn.Module):
 
         H_Pointer_reprs_bag=torch.cat(H_Pointer_reprs_bag,1)
         O_Pointer_reprs_bag=torch.cat(O_Pointer_reprs_bag,1)
-
+        # import pdb;pdb.set_trace()
         outputs_hidx = [(torch.bmm(H_Pointer_repr, inst_repr_all)) / self.tau for H_Pointer_repr in H_Pointer_reprs_bag] #(dec_layer,(1+len(aug))*bs,dec_q,hidden_dim)
         outputs_oidx = [(torch.bmm(O_Pointer_repr, inst_repr_all)) / self.tau for O_Pointer_repr in O_Pointer_reprs_bag]
 
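For reference, the lines around this hunk compute the HOI-pointer-to-instance similarities as temperature-scaled batched matrix products. A minimal shape sketch of that step, with illustrative dimensions that are not taken from the repo config:

import torch

# Illustrative sizes only: batch, HOI queries, instance queries, hidden dim, temperature.
bs, num_hoi_q, num_inst_q, hidden, tau = 2, 16, 100, 256, 0.05

H_Pointer_repr = torch.randn(bs, num_hoi_q, hidden)   # one decoder layer's pointer features
inst_repr_all  = torch.randn(bs, hidden, num_inst_q)  # instance features, transposed for bmm

# Scaled dot-product similarity over instances: (bs, num_hoi_q, num_inst_q)
outputs_hidx = torch.bmm(H_Pointer_repr, inst_repr_all) / tau
print(outputs_hidx.shape)  # torch.Size([2, 16, 100])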
hotr/util/misc.py
CHANGED
@@ -22,7 +22,7 @@ from torch import Tensor
 
 # needed due to empty tensor bug in pytorch and torchvision 0.5
 import torchvision
-if float(torchvision.__version__[
+if float(torchvision.__version__.split('.',2)[1]) < 5:
     from torchvision.ops import _new_empty_tensor
     from torchvision.ops.misc import _output_size
 
@@ -388,7 +388,7 @@ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corne
     This will eventually be supported natively by PyTorch, and this
     class can go away.
     """
-    if float(torchvision.__version__[
+    if float(torchvision.__version__.split('.',2)[1]) < 5:
        if input.numel() > 0:
            return torch.nn.functional.interpolate(
                input, size, scale_factor, mode, align_corners
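The change above replaces a prefix-slice parse of torchvision.__version__ (removed line, truncated in this view) with one that reads the minor version, so double-digit torchvision releases are compared correctly before enabling the legacy empty-tensor workaround. A minimal sketch of what the new condition evaluates, using an example version string not pinned by the repo:

# e.g. torchvision.__version__ == "0.13.1"
version = "0.13.1"
minor = float(version.split('.', 2)[1])  # -> 13.0
use_legacy_shim = minor < 5              # False: the old empty-tensor workaround is skipped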
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+pycocotools
+opencv-python
+wandb
+imageio
+scipy
tools/vis_tool.py
ADDED
@@ -0,0 +1,96 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import cv2
+
+
+vcoco_action_string = {2: 'hold', 3: 'stand', 4: 'sit', 5: 'ride', 6: 'walk',\
+                       7: 'look', 8: 'hit_inst', 9: 'hit_obj', 10: 'eat_obj', \
+                       11: 'eat_inst', 12: 'jump', 13: 'lay', 14: 'talk', 15: \
+                       'carry', 16: 'throw', 17: 'catch', 18: 'cut_inst', 19:'cut_obj', \
+                       20: 'run', 21: 'work_on_comp', 22: 'ski', 23: 'surf', 24: 'skateboard', \
+                       25: 'smile', 26: 'drink', 27: 'kick', 28: 'point', 29: 'read', 30: 'snowboard'}
+def draw_box_on_img(box, img,color=None):
+
+    vis_img = img.copy()
+    box = [int(x) for x in box]
+    cv2.rectangle(vis_img, (box[0], box[1]), (box[2], box[3]), color, 2)
+    draw_point=[int((box[0]+box[2])*1.0/2),int((box[1]+box[3])*1.0/2)]
+
+    return vis_img,color
+
+
+def draw_line_on_img_vcoco(box,line, img, class_index,color):
+
+    vis_img = img.copy()
+    font=cv2.FONT_HERSHEY_SIMPLEX
+    x=int(box[0])+2
+    y=int(box[1])+2
+    f=int(box[1])+2
+    for i in range(len(class_index)):
+
+        font_scale=1
+        font_thickness=2
+
+        text_size, _ = cv2.getTextSize(vcoco_action_string[class_index[i]] , font, font_scale, font_thickness)
+        vis_img=cv2.rectangle(vis_img,(x,y),(x+text_size[0],y+text_size[1]+5),color[1],-1)
+
+
+        vis_img=cv2.putText(vis_img, vcoco_action_string[class_index[i]] ,(x,y + text_size[1] ),font,font_scale,[51,255,153],font_thickness)
+        y=y+text_size[1]+5
+
+    return vis_img,y
+
+
+def draw_img_vcoco(img, output_i, top_k,threshold,color):
+    list_action = []
+    for action in output_i['hoi_prediction']:
+        subject_id = action['subject_id']
+        object_id = action['object_id']
+        category_id = action['category_id']
+        score = action['score']
+        single_out = [subject_id,object_id,category_id,score]
+        list_action.append(single_out)
+    list_action = sorted(list_action, key=lambda x:x[-1], reverse=True)
+    action_dict = []
+    action_cate = []
+    action_color=[]
+    subj_box=[]
+    sb={}
+    sbj=[]
+    for action in list_action[:top_k]:
+
+        subject_id,object_id,category_id,score = action
+        if score<threshold:
+            break
+        subject_obj = output_i['predictions'][subject_id]
+        subject_box = subject_obj['bbox']
+        object_obj = output_i['predictions'][object_id]
+        object_box = object_obj['bbox']
+
+        point_1 = [int((subject_box[0]+subject_box[2])*1.0/2),int((subject_box[1]+subject_box[3])*1.0/2)]
+        point_2 = [int((object_box[0]+object_box[2])*1.0/2),int((object_box[1]+object_box[3])*1.0/2)]
+
+        if [point_1,point_2] not in action_dict:
+
+            img,color_hum = draw_box_on_img(subject_box, img, color[subject_obj['category_id']]['color'])
+
+            img,color_obj = draw_box_on_img(object_box, img, color[object_obj['category_id']]['color'])
+
+            action_dict.append([point_1,point_2])
+            action_color.append([color_hum,color_obj])
+            subj_box.append([int(subject_box[0]),int(subject_box[1])])
+
+            action_cate.append([])
+        action_cate[action_dict.index([point_1,point_2])].append(category_id)
+
+    for i,(action_item,clr) in enumerate(zip(action_dict,action_color)):
+
+        img,offset = draw_line_on_img_vcoco(subj_box[i],action_item,img,action_cate[action_dict.index(action_item)],clr)
+
+        for p in range(i+1,len(subj_box)):
+            if subj_box[p]==subj_box[i]:
+                subj_box[p][1]=offset
+    return img
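draw_img_vcoco expects the dictionary layout produced by change_format in visualization.py: 'predictions' holding boxes with COCO category ids, and 'hoi_prediction' holding subject/object/verb/score entries. A small self-contained sketch of a call with toy inputs; the boxes, category ids, and the flat palette below are made up, and in the repo the palette is builtin_meta.COCO_CATEGORIES:

import numpy as np
from tools.vis_tool import draw_img_vcoco

img = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy BGR canvas
output_i = {
    'predictions': [
        {'bbox': [100, 80, 220, 400], 'category_id': 1},    # person
        {'bbox': [230, 200, 300, 260], 'category_id': 47},  # interacted object
    ],
    'hoi_prediction': [
        {'subject_id': 0, 'object_id': 1, 'category_id': 2, 'score': 0.9},  # 2 -> 'hold'
    ],
}
# Stand-in palette: indexed by category_id, each entry exposing a 'color' field
# the way builtin_meta.COCO_CATEGORIES does.
palette = [{'color': [0, 255, 0]}] * 91
vis = draw_img_vcoco(img, output_i, top_k=5, threshold=0.4, color=palette)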
upload_checkpoint_hugginface.ipynb
ADDED
@@ -0,0 +1,123 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "a19d0964-4c83-4bc9-b59f-f04a57ca020f",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "5dba33c11d7a4f708b5d6a03869ccb30",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "from huggingface_hub import notebook_login\n",
+    " \n",
+    "notebook_login()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "4b0ba848-4799-476d-b1d6-16a7cde1f4ad",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Cloning https://huggingface.co/jhp/hoi into local empty directory.\n"
+     ]
+    }
+   ],
+   "source": [
+    "from huggingface_hub import Repository\n",
+    " \n",
+    "repo = Repository('CPC_HOTR', clone_from='jhp/hoi')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "5d092b5a-99b3-4a4c-9ef4-90724eb665ef",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "b51f83592a414a1ea646717c434e2a23",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "checkpoint.pth: 0%| | 0.00/301M [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'https://huggingface.co/jhp/hoi/tree/main/./'"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from huggingface_hub import HfApi\n",
+    "api = HfApi()\n",
+    "api.upload_folder(\n",
+    "    folder_path=\"./checkpoints/\",\n",
+    "    repo_id=\"jhp/hoi\",\n",
+    "    path_in_repo=\"./\",\n",
+    "    # allow_patterns=\"*.txt\", # Upload all local text files\n",
+    "    # delete_patterns=\"*.txt\", # Delete all remote text files before\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "82eedf6b-21a2-4dad-bc7d-6f3631525ff2",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.17"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
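The notebook above pushes the local ./checkpoints/ folder to the jhp/hoi repo (the progress widget shows checkpoint.pth, about 301M). To fetch a checkpoint back before running visualization.py, something along these lines should work; the in-repo filename and target directory are assumptions, since the upload output does not show the repo's folder layout:

from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="jhp/hoi",
    filename="vcoco/checkpoint.pth",  # hypothetical path inside the repo
    local_dir="./checkpoints",        # mirrors the path visualization.py resumes from
)
print(ckpt_path)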
visualization.py
ADDED
@@ -0,0 +1,227 @@
+import argparse
+import datetime
+import json
+import random
+import time
+import multiprocessing
+from pathlib import Path
+import os
+import cv2
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, DistributedSampler
+import hotr.data.datasets as datasets
+import hotr.util.misc as utils
+from hotr.engine.arg_parser import get_args_parser
+from hotr.data.datasets import build_dataset, get_coco_api_from_dataset
+from hotr.data.datasets.vcoco import make_hoi_transforms
+from PIL import Image
+from hotr.util.logger import print_params, print_args
+
+import copy
+from hotr.data.datasets import builtin_meta
+from PIL import Image
+import requests
+# import mmcv
+from matplotlib import pyplot as plt
+import imageio
+
+from tools.vis_tool import *
+from hotr.models.detr import build
+
+def change_format(results,valid_ids):
+
+    boxes,labels,pair_score =\
+        list(map(lambda x: x.cpu().numpy(), [results['boxes'], results['labels'], results['pair_score']]))
+    output_i={}
+    output_i['predictions']=[]
+    output_i['hoi_prediction']=[]
+
+    h_idx=np.where(labels==1)[0]
+    for box,label in zip(boxes,labels):
+
+        output_i['predictions'].append({'bbox':box.tolist(),'category_id':label})
+
+    for i,verb in enumerate(pair_score):
+        if i in [1,4,10,23,26,5,18]:
+            continue
+        for j,hum in enumerate(h_idx):
+            for k in range(len(boxes)):
+                if verb[j][k]>0:
+                    output_i['hoi_prediction'].append({'subject_id':hum,'object_id':k,'category_id':i+2,'score':verb[j][k]})
+
+    return output_i
+def vis(args,id=294):
+
+    if args.frozen_weights is not None:
+        print("Freeze weights for detector")
+
+    device = torch.device(args.device)
+
+    # fix the seed for reproducibility
+    seed = args.seed + utils.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+    # Data Setup
+    dataset_train = build_dataset(image_set='train', args=args)
+    args.num_classes = dataset_train.num_category()
+    args.num_actions = dataset_train.num_action()
+    args.action_names = dataset_train.get_actions()
+    if args.share_enc: args.hoi_enc_layers = args.enc_layers
+    if args.pretrained_dec: args.hoi_dec_layers = args.dec_layers
+    if args.dataset_file == 'vcoco':
+        # Save V-COCO dataset statistics
+        args.valid_ids = np.array(dataset_train.get_object_label_idx()).nonzero()[0]
+        args.invalid_ids = np.argwhere(np.array(dataset_train.get_object_label_idx()) == 0).squeeze(1)
+        args.human_actions = dataset_train.get_human_action()
+        args.object_actions = dataset_train.get_object_action()
+        args.num_human_act = dataset_train.num_human_act()
+    elif args.dataset_file == 'hico-det':
+        args.valid_obj_ids = dataset_train.get_valid_obj_ids()
+    print_args(args)
+
+    args.HOIDet=True
+    args.eval=True
+    args.pretrained_dec=True
+    args.share_enc=True
+    args.share_dec_param = True
+    if args.dataset_file=='hico-det':
+        args.valid_ids=args.valid_obj_ids
+
+    # Model Setup
+    model, criterion, postprocessors = build(args)
+    model.to(device)
+
+    model_without_ddp = model
+
+    n_parameters = print_params(model)
+
+    param_dicts = [
+        {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
+        {
+            "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
+            "lr": args.lr_backbone,
+        },
+    ]
+
+    output_dir = Path(args.output_dir)
+
+    checkpoint = torch.load(args.resume, map_location='cpu')
+    # modified
+    module_name=list(checkpoint['model'].keys())
+    model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
+
+    # if not args.video_vis:
+    #     url='http://images.cocodataset.org/val2014/COCO_val2014_{}.jpg'.format(str(id).zfill(12))
+    #     req = requests.get(url, stream=True, timeout=1, verify=False).raw
+    req = args.image_dir
+    img = Image.open(req).convert('RGB')
+
+    w,h=img.size
+    orig_size = torch.as_tensor([int(h), int(w)]).unsqueeze(0).to(device)
+
+    transform=make_hoi_transforms('val')
+    sample=img.copy()
+    sample,_=transform(sample,None)
+    sample = sample.unsqueeze(0).to(device)
+    with torch.no_grad():
+        model.eval()
+        out=model(sample)
+        results = postprocessors['hoi'](out, orig_size,dataset=args.dataset_file,args=args)
+        output_i=change_format(results[0],args.valid_ids)
+
+    out_dir = './vis'
+    image = np.asarray(img, dtype=np.uint8)[:,:,::-1]
+    # image = cv2.imdecode(image_nparray, cv2.IMREAD_COLOR)
+
+    vis_img=draw_img_vcoco(image,output_i,top_k=args.topk,threshold=args.threshold,color=builtin_meta.COCO_CATEGORIES)
+    plt.imshow(cv2.cvtColor(vis_img,cv2.COLOR_BGR2RGB))
+    cv2.imwrite('./vis_res/vis1.jpg',vis_img)
+
+    # else:
+    #     frames=[]
+    #     video_file=id
+
+    #     video_reader = mmcv.VideoReader('./vid/'+video_file+'.mp4')
+    #     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    #     video_writer = cv2.VideoWriter(
+    #         './vid/'+video_file+'_vis.mp4', fourcc, video_reader.fps,
+    #         (video_reader.width, video_reader.height))
+
+    #     orig_size = torch.as_tensor([int(video_reader.height), int(video_reader.width)]).unsqueeze(0).to(device)
+    #     transform=make_hoi_transforms('val')
+
+    #     for frame in mmcv.track_iter_progress(video_reader):
+
+    #         frame=mmcv.imread(frame)
+    #         frame=frame.copy()
+
+    #         frame=Image.fromarray(frame,'RGB')
+
+    #         sample,_=transform(frame,None)
+    #         sample=sample.unsqueeze(0).to(device)
+
+    #         with torch.no_grad():
+    #             model.eval()
+    #             out=model(sample)
+    #             results = postprocessors['hoi'](out, orig_size,dataset='vcoco',args=args)
+    #             output_i=change_format(results[0],args.valid_ids)
+
+    #         vis_img=draw_img_vcoco(np.array(frame),output_i,top_k=args.topk,threshold=args.threshold,color=builtin_meta.COCO_CATEGORIES)
+    #         frames.append(vis_img)
+    #         video_writer.write(vis_img)
+
+    #     with imageio.get_writer("smiling.gif", mode="I") as writer:
+    #         for idx, frame in enumerate(frames):
+    #             # print("Adding frame to GIF file: ", idx + 1)
+    #             writer.append_data(frame)
+    #     if video_writer:
+    #         video_writer.release()
+    #     cv2.destroyAllWindows()
+
+
+# def visualization(id, video_vis=False, dataset_file='vcoco', path_id = 0 ,data_path='v-coco', threshold=0.4, topk=10,aug_path = '[]'):
+
+#     parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
+#     checkpoint_dir= './checkpoints/vcoco/checkpoint.pth' if dataset_file=='vcoco' else './checkpoints/hico-det/hico_ft_q16.pth'
+#     with open('./v-coco/data/vcoco_test.ids') as file:
+#         test_idxs = [line.rstrip('\n') for line in file]
+#     if not video_vis:
+#         id = test_idxs[id]
+#     args = parser.parse_args(args=['--dataset_file',dataset_file,'--data_path',data_path,'--resume',checkpoint_dir,'--num_hoi_queries' ,'16','--temperature' ,'0.05', '--augpath_name',aug_path ,'--path_id','{}'.format(path_id)])
+#     args.video_vis=video_vis
+#     args.threshold=threshold
+#     args.topk=topk
+
+#     if args.output_dir:
+#         Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+#     vis(args,id)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser('DETR training and evaluation script', parents=[get_args_parser()])
+    parser.add_argument('--threshold',help='score threshold for visualization', default=0.4, type=float)
+    # parser.add_argument('--path_id',help='index of inference path', default=1, type=int)
+    parser.add_argument('--topk',help='topk prediction', default=5, type=int)
+    parser.add_argument('--video_vis', action='store_true')
+    parser.add_argument('--image_dir', default='', type=str)
+    args = parser.parse_args()
+    # checkpoint_dir= './checkpoints/vcoco/checkpoint.pth' if dataset_file=='vcoco' else './checkpoints/hico-det/hico_ft_q16.pth'
+    args.resume= './checkpoints/vcoco/checkpoint.pth'
+    with open('./v-coco/data/splits/vcoco_test.ids') as file:
+        test_idxs = [line.rstrip('\n') for line in file]
+    # if not video_vis:
+    id = test_idxs[309]
+    # args = parser.parse_args()
+    # args.dataset_file = 'vcoco'
+    # args.data_path = 'v-coco'
+    # args.resume = checkpoint_dir
+    # args.num_hoi_queries = 16
+    # args.temperature = 0.05
+    args.augpath_name = ['p2','p3','p4']
+    # args.path_id = 1
+
+    if args.output_dir:
+        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+    vis(args,id)
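Going by the hard-coded defaults in __main__ above (args.resume, the V-COCO test-id list, and augpath_name are set inside the script), a single-image run would look roughly like the line below. The dataset flags are assumptions carried over from HOTR's get_args_parser as used in the commented-out call earlier in the file, and the image path is a placeholder:

python visualization.py --dataset_file vcoco --data_path v-coco --num_hoi_queries 16 --temperature 0.05 --image_dir ./sample.jpg --topk 5 --threshold 0.4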