raoulritter
committed
Commit · 331a0e7 · 1 Parent(s): 82fe504
Provide examples on GitHub and checkpoints to get help from the Edge Impulse forums
Browse files
- .gitattributes +4 -0
- Test_ONNX_Convert.ipynb +259 -0
- ckpt_e09.pth.tar +3 -0
- ckpt_e10.pth.tar +3 -0
- ckpt_e49.pth.tar +3 -0
- ckpt_pytorch_1_11_e00.pth.tar +3 -0
- model.onnx +3 -0
- model.pth +3 -0
- model_2.onnx +3 -0
- model_5x.onnx +3 -0
- model_checkpoint_5x_50ep.onnx +3 -0
- onnxrun.py +112 -0
- requirements.txt +5 -0
- run.py +133 -0
- test_convert.py +36 -0
.gitattributes
CHANGED
@@ -32,3 +32,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ckpt_e09.pth.tar filter=lfs diff=lfs merge=lfs -text
+ckpt_e10.pth.tar filter=lfs diff=lfs merge=lfs -text
+ckpt_e49.pth.tar filter=lfs diff=lfs merge=lfs -text
+ckpt_pytorch_1_11_e00.pth.tar filter=lfs diff=lfs merge=lfs -text
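The four new .gitattributes entries route the checkpoint tarballs through Git LFS, so a fresh clone without `git lfs pull` contains small text pointers rather than the weights themselves. A minimal sketch for parsing such a pointer in Python (the `read_lfs_pointer` helper is illustrative, not part of this commit):

def read_lfs_pointer(path):
    # A Git LFS pointer is a short key/value text file: "version", "oid", "size"
    with open(path) as f:
        return dict(line.strip().split(" ", 1) for line in f if " " in line)

ptr = read_lfs_pointer("ckpt_e09.pth.tar")
print(ptr["oid"], ptr["size"])  # e.g. sha256:4b50... 149374251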
Test_ONNX_Convert.ipynb
ADDED
@@ -0,0 +1,259 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d0cd0fce",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-15T16:14:15.749744Z",
     "start_time": "2023-05-15T16:14:15.540642Z"
    }
   },
   "outputs": [],
   "source": [
    "import onnx\n",
    "\n",
    "\n",
    "onnx_model = onnx.load(\"ckpt/model.onnx\")\n",
    "onnx.checker.check_model(onnx_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "27f06a8c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-15T16:14:15.751689Z",
     "start_time": "2023-05-15T16:14:15.748975Z"
    }
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f9167299",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-15T16:14:15.777873Z",
     "start_time": "2023-05-15T16:14:15.753825Z"
    }
   },
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'onnx_tf'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[10], line 8\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mautograd\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Variable\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01monnx\u001b[39;00m\n\u001b[0;32m----> 8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01monnx_tf\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbackend\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m prepare\n\u001b[1;32m 10\u001b[0m model \u001b[38;5;241m=\u001b[39m onnx\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mckpt/model_5x.onnx\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 11\u001b[0m tf_rep \u001b[38;5;241m=\u001b[39m prepare(model)\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'onnx_tf'"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "from torch.autograd import Variable\n",
    "import onnx\n",
    "from onnx_tf.backend import prepare\n",
    "\n",
    "model = onnx.load('ckpt/model_5x.onnx')\n",
    "tf_rep = prepare(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d1db936",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4eb7d23",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.onnx\n",
    "import onnx\n",
    "\n",
    "\n",
    "import torch.nn as nn\n",
    "\n",
    "from models.model import STBVMM\n",
    "\n",
    "# # Initialize model with checkpointing enabled\n",
    "# model = STBVMM(img_size=384, patch_size=1, in_chans=3,\n",
    "#                embed_dim=48, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],\n",
    "#                window_size=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,\n",
    "#                drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n",
    "#                norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n",
    "#                use_checkpoint=True, img_range=1., resi_connection='1conv',\n",
    "#                manipulator_num_resblk = 1)\n",
    "\n",
    "model = STBVMM(img_size=384, patch_size=1, in_chans=3,\n",
    "               embed_dim=192, depths=[6, 6, 6, 6, 6, 6], num_heads=[6, 6, 6, 6, 6, 6],\n",
    "               window_size=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,\n",
    "               drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n",
    "               norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n",
    "               use_checkpoint=False, img_range=1., resi_connection='1conv',\n",
    "               manipulator_num_resblk=1)\n",
    "\n",
    "# Load pretrained weights from checkpoint\n",
    "checkpoint = torch.load('ckpt/ckpt_e49.pth.tar')\n",
    "# print(checkpoint.keys())\n",
    "\n",
    "# print(checkpoint['state_dict'])\n",
    "\n",
    "model.load_state_dict(checkpoint['state_dict'], strict=False)\n",
    "\n",
    "# Set the model to eval mode\n",
    "model.eval()\n",
    "\n",
    "# Export model to ONNX\n",
    "inputs = (torch.randn(1, 3, 384, 384), torch.randn(1, 3, 384, 384), 5)\n",
    "input_names = [\"a\", \"b\", \"amp\"]\n",
    "output_names = [\"output\"]\n",
    "dynamic_axes = {\"a\": {0: \"batch_size\", 2: \"height\", 3: \"width\"},\n",
    "                \"b\": {0: \"batch_size\", 2: \"height\", 3: \"width\"},\n",
    "                \"output\": {0: \"batch_size\", 2: \"height\", 3: \"width\"}}\n",
    "torch.onnx.export(model, inputs, \"model_checkpoint_5x_50ep.onnx\", input_names=input_names, output_names=output_names,\n",
    "                  dynamic_axes=dynamic_axes, opset_version=11)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8d70c0e9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T04:50:57.341575Z",
     "start_time": "2023-05-16T04:50:57.144003Z"
    }
   },
   "outputs": [],
   "source": [
    "import onnx\n",
    "\n",
    "onnx_model = onnx.load(\"ckpt/model_checkpoint_5x_50ep.onnx\")\n",
    "onnx.checker.check_model(onnx_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec9bacfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnxruntime as ort\n",
    "import numpy as np\n",
    "import cv2\n",
    "x, y = test_data[0][0], test_data[0][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1bea7afb",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T05:23:43.230764Z",
     "start_time": "2023-05-16T05:23:29.950857Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: cpu\r\n",
      "demo_video/STB-VMM_Freezer_x20_mag\r\n",
      "processing sample: 0\r\n",
      "Traceback (most recent call last):\r\n",
      "  File \"/Users/raoulritter/STB-VMM/onnxrun.py\", line 112, in <module>\r\n",
      "    main(args)\r\n",
      "  File \"/Users/raoulritter/STB-VMM/onnxrun.py\", line 53, in main\r\n",
      "    ort_outs = ort_session.run(None, ort_inputs)\r\n",
      "  File \"/opt/anaconda3/envs/afstudeer/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py\", line 200, in run\r\n",
      "    return self._sess.run(output_names, input_feed, run_options)\r\n",
      "RuntimeError: Input must be a list of dictionaries or a single numpy array for input 'a'.\r\n"
     ]
    }
   ],
   "source": [
    "!python onnxrun.py -j4 -b1 --load_ckpt ckpt/model_checkpoint_5x_50ep.onnx --save_dir demo_video/STB-VMM_Freezer_x20_mag -m 5 --video_path demo_video/STB-VMM_Freezer_x20_original/frame --num_data 6644 --mode static\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c77312e7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-05-16T05:14:51.342430Z",
     "start_time": "2023-05-16T05:14:48.496955Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.11.0\n",
      "0.12.0\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "print(torch.__version__)\n",
    "print(torchvision.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "702d9d85",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:afstudeer]",
   "language": "python",
   "name": "conda-env-afstudeer-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
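The f9167299 cell fails because the onnx-tf package is not installed in the afstudeer environment. Assuming the intent is the standard onnx-tf conversion path, a sketch of the working sequence after `pip install onnx-tf` (which also pulls in TensorFlow) would be the following; note that onnx-tf must support every operator in the graph, and Swin-style models often hit unsupported ops, which may be worth mentioning when asking for help on the forums:

import onnx
from onnx_tf.backend import prepare  # provided by the onnx-tf package

model = onnx.load('ckpt/model_5x.onnx')
tf_rep = prepare(model)              # wrap the ONNX graph in a TensorFlow backend
tf_rep.export_graph('model_5x_tf')   # write a TensorFlow SavedModel directory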
ckpt_e09.pth.tar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4b503d280322ad5a257fd760878447c3e13e80800baa06b8288ee37bd79173ce
size 149374251
ckpt_e10.pth.tar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b63b6fdfbd487d5482c0e4b821040df85e315407a2205e755b426cf9c94492ce
size 149368279
ckpt_e49.pth.tar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f1df7bebba895be14728293138812826c1affeb4777f76be960e8eb100ed362
size 149368983
ckpt_pytorch_1_11_e00.pth.tar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:449bb57e7b0c3a17580f5512cde77397abf6178ea30f28a726d700eac2343920
size 149368087
model.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:de4e1e8b51f1cf371159c53a9efdc93cdd879c8f4406941e20301d05d3718c67
size 137258146
model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:32c2cea3c5ef96e308f5c8b8a0b6418d8e00cacf1c6c5b3e388e796f00ccb079
size 149306703
model_2.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e72d64cec630d95ba0bc4e49ee7c7e2f1a4a71dcef83ab3b2b1e24ab75fd7a9c
size 136977868
model_5x.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:de4e1e8b51f1cf371159c53a9efdc93cdd879c8f4406941e20301d05d3718c67
size 137258146
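model_5x.onnx carries the same oid (sha256:de4e1e8b...) and size (137258146 bytes) as model.onnx above, so the two LFS pointers refer to byte-identical files. A quick local check, assuming both files have been pulled from LFS:

import hashlib

def sha256_of(path):
    # Stream the file in 1 MiB chunks so a 137 MB model never sits in memory whole
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

print(sha256_of("model.onnx") == sha256_of("model_5x.onnx"))  # expected: True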
model_checkpoint_5x_50ep.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7681377d1db055c8b4f4a345052830beb20a027526235fcc6c4c2232edb876bb
size 136977913
onnxrun.py
ADDED
@@ -0,0 +1,112 @@
import argparse
import os

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image

from utils.data_loader import ImageFromFolderTest
import onnxruntime as ort


def main(args):
    # Device choice (auto)
    if args.device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device

    print(f'Using device: {device}')

    # Create ONNX Runtime inference session
    ort_session = ort.InferenceSession(args.load_ckpt)

    # Check saving directory
    save_dir = args.save_dir
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    print(save_dir)

    # Data loader
    dataset_mag = ImageFromFolderTest(
        args.video_path, mag=args.mag, mode=args.mode, num_data=args.num_data, preprocessing=False)
    data_loader = data.DataLoader(dataset_mag,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.workers,
                                  pin_memory=False)

    # Magnification
    for i, (xa, xb, mag_factor) in enumerate(data_loader):
        if i % args.print_freq == 0:
            print('processing sample: %d' % i)

        xa = xa.to(device)
        xb = xb.to(device)

        # Infer using the ONNX model. ONNX Runtime accepts numpy arrays, not
        # torch tensors, so convert at the session boundary (feeding tensors
        # raises "Input must be a list of dictionaries or a single numpy array
        # for input 'a'", as recorded in the notebook above)
        mag_factor = np.array([[args.mag]], dtype=np.float32)  # constant magnification factor
        ort_inputs = {ort_session.get_inputs()[0].name: xa.cpu().numpy(),
                      ort_session.get_inputs()[1].name: xb.cpu().numpy(),
                      ort_session.get_inputs()[2].name: mag_factor}
        ort_outs = ort_session.run(None, ort_inputs)
        y_hat = ort_outs[0]

        if i == 0:
            # Back to image scale (0-255)
            tmp = xa.permute(0, 2, 3, 1).cpu().detach().numpy()
            tmp = np.clip(tmp, -1.0, 1.0)
            tmp = ((tmp + 1.0) * 127.5).astype(np.uint8)

            # Save first frame
            fn = os.path.join(save_dir, 'STBVMM_%s_%06d.png' % (args.mode, i))
            im = Image.fromarray(np.concatenate(tmp, 0))
            im.save(fn)

        # Back to image scale (0-255); y_hat is already a numpy array (NCHW -> NHWC)
        y_hat = np.transpose(y_hat, (0, 2, 3, 1))
        y_hat = np.clip(y_hat, -1.0, 1.0)
        y_hat = ((y_hat + 1.0) * 127.5).astype(np.uint8)

        # Save frames
        fn = os.path.join(save_dir, 'STBVMM_%s_%06d.png' % (args.mode, i+1))
        im = Image.fromarray(np.concatenate(y_hat, 0))
        im.save(fn)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Swin Transformer Based Video Motion Magnification')

    # Application parameters
    parser.add_argument('-i', '--video_path', type=str, metavar='PATH', required=True,
                        help='path to video input frames')
    parser.add_argument('-c', '--load_ckpt', type=str, metavar='PATH', required=True,
                        help='path to load ONNX model')
    parser.add_argument('-o', '--save_dir', default='demo', type=str, metavar='PATH',
                        help='path to save generated frames (default: demo)')
    parser.add_argument('-m', '--mag', metavar='N', default=20.0, type=float,
                        help='magnification factor (default: 20.0)')
    parser.add_argument('--mode', default='static', type=str, choices=['static', 'dynamic'],
                        help='magnification mode (static, dynamic)')
    parser.add_argument('-n', '--num_data', type=int, metavar='N', required=True,
                        help='number of frames')

    # Execute parameters
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('-b', '--batch_size', default=1, type=int,
                        metavar='N', help='batch size (default: 1)')
    parser.add_argument('-p', '--print_freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')

    # Device
    parser.add_argument('--device', type=str, default='auto',
                        choices=['auto', 'cpu', 'cuda', 'mps', 'xla'],
                        help='select device [auto/cpu/cuda] (default: auto)')
    args = parser.parse_args()
    main(args)
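The RuntimeError recorded in the notebook came from feeding torch tensors straight into ort_session.run; the numpy conversion above addresses that. When an input feed is still rejected (for example a shape mismatch on amp), printing the session's input metadata shows exactly what the exported graph expects (sketch, using the checkpoint exported earlier):

import onnxruntime as ort

sess = ort.InferenceSession("ckpt/model_checkpoint_5x_50ep.onnx")
for inp in sess.get_inputs():
    # Each NodeArg exposes the graph's declared input name, shape, and element type
    print(inp.name, inp.shape, inp.type)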
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch==2.0
Pillow==9.3
torchvision
torchaudio
numpy
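Note that the files added in this commit also import onnx, onnxruntime, cv2, and onnx_tf, none of which are listed here, so reproducing the issue from requirements.txt alone will hit import errors (the notebook's ModuleNotFoundError is one instance). A small environment sanity check; the module list below is gathered from the imports in this commit, not from requirements.txt:

import importlib

for mod in ("torch", "torchvision", "numpy", "PIL", "onnx", "onnxruntime", "cv2", "onnx_tf"):
    try:
        m = importlib.import_module(mod)
        print(mod, getattr(m, "__version__", "?"))
    except ImportError as err:
        print(mod, "MISSING:", err)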
run.py
ADDED
@@ -0,0 +1,133 @@
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.datasets as datasets
from PIL import Image

from utils.data_loader import ImageFromFolderTest
from models.model import STBVMM


def main(args):
    # Device choice (auto)
    if args.device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = args.device

    print(f'Using device: {device}')

    # Create model
    model = STBVMM(img_size=384, patch_size=1, in_chans=3,
                   embed_dim=192, depths=[6, 6, 6, 6, 6, 6], num_heads=[6, 6, 6, 6, 6, 6],
                   window_size=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,
                   drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                   norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                   use_checkpoint=False, img_range=1., resi_connection='1conv',
                   manipulator_num_resblk=1).to(device)

    # Load checkpoint
    if os.path.isfile(args.load_ckpt):
        print("=> loading checkpoint '{}'".format(args.load_ckpt))
        checkpoint = torch.load(args.load_ckpt)
        args.start_epoch = checkpoint['epoch']

        model.load_state_dict(checkpoint['state_dict'])

        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.load_ckpt, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.load_ckpt))
        assert False

    # Check saving directory
    save_dir = args.save_dir
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    print(save_dir)

    # Data loader
    dataset_mag = ImageFromFolderTest(
        args.video_path, mag=args.mag, mode=args.mode, num_data=args.num_data, preprocessing=False)
    data_loader = data.DataLoader(dataset_mag,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.workers,
                                  pin_memory=False)

    # Generate frames
    model.eval()

    # Magnification
    for i, (xa, xb, mag_factor) in enumerate(data_loader):
        if i % args.print_freq == 0:
            print('processing sample: %d' % i)

        mag_factor = mag_factor.unsqueeze(1).unsqueeze(1).unsqueeze(1)

        xa = xa.to(device)
        xb = xb.to(device)
        mag_factor = mag_factor.to(device)

        y_hat, _, _, _ = model(xa, xb, mag_factor)

        if i == 0:
            # Back to image scale (0-255)
            tmp = xa.permute(0, 2, 3, 1).cpu().detach().numpy()
            tmp = np.clip(tmp, -1.0, 1.0)
            tmp = ((tmp + 1.0) * 127.5).astype(np.uint8)

            # Save first frame
            fn = os.path.join(save_dir, 'STBVMM_%s_%06d.png' % (args.mode, i))
            im = Image.fromarray(np.concatenate(tmp, 0))
            im.save(fn)

        # Back to image scale (0-255)
        y_hat = y_hat.permute(0, 2, 3, 1).cpu().detach().numpy()
        y_hat = np.clip(y_hat, -1.0, 1.0)
        y_hat = ((y_hat + 1.0) * 127.5).astype(np.uint8)

        # Save frames
        fn = os.path.join(save_dir, 'STBVMM_%s_%06d.png' % (args.mode, i+1))
        im = Image.fromarray(np.concatenate(y_hat, 0))
        im.save(fn)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Swin Transformer Based Video Motion Magnification')

    # Application parameters
    parser.add_argument('-i', '--video_path', type=str, metavar='PATH', required=True,
                        help='path to video input frames')
    parser.add_argument('-c', '--load_ckpt', type=str, metavar='PATH', required=True,
                        help='path to load checkpoint')
    parser.add_argument('-o', '--save_dir', default='demo', type=str, metavar='PATH',
                        help='path to save generated frames (default: demo)')
    parser.add_argument('-m', '--mag', metavar='N', default=20.0, type=float,
                        help='magnification factor (default: 20.0)')
    parser.add_argument('--mode', default='static', type=str, choices=['static', 'dynamic'],
                        help='magnification mode (static, dynamic)')
    parser.add_argument('-n', '--num_data', type=int, metavar='N', required=True,
                        help='number of frames')

    # Execute parameters
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('-b', '--batch_size', default=1, type=int,
                        metavar='N', help='batch size (default: 1)')
    parser.add_argument('-p', '--print_freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')

    # Device
    parser.add_argument('--device', type=str, default='auto',
                        choices=['auto', 'cpu', 'cuda', 'mps', 'xla'],
                        help='select device [auto/cpu/cuda] (default: auto)')

    args = parser.parse_args()

    main(args)
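Since run.py and onnxrun.py are meant to produce the same frames, a direct parity check between the PyTorch model and the exported graph is the quickest way to validate the conversion. A sketch, assuming both the checkpoint and the ONNX file load, and using the (1, 1, 1, 1)-shaped amp that run.py builds via its unsqueeze chain; the export cell traced a plain int 5 for amp, so this shape mismatch may itself be part of the problem and the exported graph may reject the feed:

import numpy as np
import onnxruntime as ort
import torch

def onnx_parity_check(model, onnx_path, atol=1e-3):
    # Compare one forward pass of the PyTorch model against the exported ONNX graph
    xa = torch.randn(1, 3, 384, 384)
    xb = torch.randn(1, 3, 384, 384)
    amp = torch.full((1, 1, 1, 1), 5.0)  # shape produced by run.py's unsqueeze chain
    model.eval()
    with torch.no_grad():
        y_torch = model(xa, xb, amp)[0].numpy()  # model returns (y_hat, _, _, _)
    sess = ort.InferenceSession(onnx_path)
    y_onnx = sess.run(None, {"a": xa.numpy(), "b": xb.numpy(), "amp": amp.numpy()})[0]
    print("max abs diff:", np.abs(y_torch - y_onnx).max())
    return np.allclose(y_torch, y_onnx, atol=atol)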
test_convert.py
ADDED
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn  # needed for nn.LayerNorm below (missing in the original)

from models.model import STBVMM


model = STBVMM(img_size=384, patch_size=1, in_chans=3,
               embed_dim=192, depths=[6, 6, 6, 6, 6, 6], num_heads=[6, 6, 6, 6, 6, 6],
               window_size=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,
               drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
               norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
               use_checkpoint=False, img_range=1., resi_connection='1conv',
               manipulator_num_resblk=1).to("cpu")


checkpoint = torch.load('ckpt/ckpt_e10.pth.tar')
# print(checkpoint.keys())

print(checkpoint['state_dict'])

model.load_state_dict(checkpoint['state_dict'], strict=False)
# Get the keys in the checkpoint's state_dict
checkpoint_keys = set(checkpoint['state_dict'].keys())

# Get the keys in the current model's state_dict
model_keys = set(model.state_dict().keys())

# Find the difference between the keys
keys_only_in_checkpoint = checkpoint_keys - model_keys
keys_only_in_model = model_keys - checkpoint_keys

# Print the results
print("Keys only in the checkpoint's state_dict:")
print(keys_only_in_checkpoint)
print("Keys only in the model's state_dict:")
print(keys_only_in_model)
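With strict=False, load_state_dict already returns the mismatch information that the set arithmetic above recomputes; capturing its return value is a more direct check (sketch):

result = model.load_state_dict(checkpoint['state_dict'], strict=False)
print("missing keys:", result.missing_keys)        # in the model but not the checkpoint
print("unexpected keys:", result.unexpected_keys)  # in the checkpoint but not the model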