Hila commited on
Commit
7754b29
·
1 Parent(s): 9f7f854

init commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +1 -0
  2. CLS2IDX.py +1000 -0
  3. README.md +124 -13
  4. RobustViT.ipynb +0 -0
  5. SegmentationTest/data/Imagenet.py +74 -0
  6. SegmentationTest/data/VOC.py +372 -0
  7. SegmentationTest/data/__init__.py +0 -0
  8. SegmentationTest/data/imagenet_utils.py +1002 -0
  9. SegmentationTest/data/transforms.py +442 -0
  10. SegmentationTest/imagenet_seg_eval.py +319 -0
  11. SegmentationTest/utils/__init__.py +0 -0
  12. SegmentationTest/utils/confusionmatrix.py +88 -0
  13. SegmentationTest/utils/iou.py +93 -0
  14. SegmentationTest/utils/metric.py +12 -0
  15. SegmentationTest/utils/metrices.py +208 -0
  16. SegmentationTest/utils/parallel.py +260 -0
  17. SegmentationTest/utils/render.py +266 -0
  18. SegmentationTest/utils/saver.py +34 -0
  19. SegmentationTest/utils/summaries.py +11 -0
  20. ViT/ViT.py +308 -0
  21. ViT_new.py → ViT/ViT_new.py +0 -0
  22. ViT/__init__.py +0 -0
  23. ViT/explainer.py +71 -0
  24. ViT/helpers.py +295 -0
  25. ViT/layer_helpers.py +21 -0
  26. ViT/weight_init.py +60 -0
  27. imagenet_ablation_gt.py +590 -0
  28. imagenet_classes.json +1002 -0
  29. imagenet_eval_robustness.py +337 -0
  30. imagenet_eval_robustness_per_class.py +343 -0
  31. imagenet_finetune.py +567 -0
  32. imagenet_finetune_gradmask.py +586 -0
  33. imagenet_finetune_rrr.py +570 -0
  34. imagenet_finetune_tokencut.py +577 -0
  35. label_str_to_imagenet_classes.py +133 -0
  36. objectnet_dataset.py +117 -0
  37. robustness_dataset.py +66 -0
  38. robustness_dataset_per_class.py +65 -0
  39. samples/augreg_base/1_in.png +0 -0
  40. samples/augreg_base/2_in.png +0 -0
  41. samples/augreg_base/3_in.png +0 -0
  42. samples/augreg_base/a.png +0 -0
  43. samples/augreg_base/a_2.png +0 -0
  44. samples/augreg_base/a_3.png +0 -0
  45. samples/catdog.png +0 -0
  46. samples/deit_base/1_in.png +0 -0
  47. samples/deit_base/2_in.png +0 -0
  48. samples/deit_base/3_in.png +0 -0
  49. samples/deit_base/a.png +0 -0
  50. samples/deit_base/a_2.png +0 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .idea
CLS2IDX.py ADDED
@@ -0,0 +1,1000 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CLS2IDX = {0: 'tench, Tinca tinca',
2
+ 1: 'goldfish, Carassius auratus',
3
+ 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
4
+ 3: 'tiger shark, Galeocerdo cuvieri',
5
+ 4: 'hammerhead, hammerhead shark',
6
+ 5: 'electric ray, crampfish, numbfish, torpedo',
7
+ 6: 'stingray',
8
+ 7: 'cock',
9
+ 8: 'hen',
10
+ 9: 'ostrich, Struthio camelus',
11
+ 10: 'brambling, Fringilla montifringilla',
12
+ 11: 'goldfinch, Carduelis carduelis',
13
+ 12: 'house finch, linnet, Carpodacus mexicanus',
14
+ 13: 'junco, snowbird',
15
+ 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
16
+ 15: 'robin, American robin, Turdus migratorius',
17
+ 16: 'bulbul',
18
+ 17: 'jay',
19
+ 18: 'magpie',
20
+ 19: 'chickadee',
21
+ 20: 'water ouzel, dipper',
22
+ 21: 'kite',
23
+ 22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
24
+ 23: 'vulture',
25
+ 24: 'great grey owl, great gray owl, Strix nebulosa',
26
+ 25: 'European fire salamander, Salamandra salamandra',
27
+ 26: 'common newt, Triturus vulgaris',
28
+ 27: 'eft',
29
+ 28: 'spotted salamander, Ambystoma maculatum',
30
+ 29: 'axolotl, mud puppy, Ambystoma mexicanum',
31
+ 30: 'bullfrog, Rana catesbeiana',
32
+ 31: 'tree frog, tree-frog',
33
+ 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
34
+ 33: 'loggerhead, loggerhead turtle, Caretta caretta',
35
+ 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
36
+ 35: 'mud turtle',
37
+ 36: 'terrapin',
38
+ 37: 'box turtle, box tortoise',
39
+ 38: 'banded gecko',
40
+ 39: 'common iguana, iguana, Iguana iguana',
41
+ 40: 'American chameleon, anole, Anolis carolinensis',
42
+ 41: 'whiptail, whiptail lizard',
43
+ 42: 'agama',
44
+ 43: 'frilled lizard, Chlamydosaurus kingi',
45
+ 44: 'alligator lizard',
46
+ 45: 'Gila monster, Heloderma suspectum',
47
+ 46: 'green lizard, Lacerta viridis',
48
+ 47: 'African chameleon, Chamaeleo chamaeleon',
49
+ 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
50
+ 49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
51
+ 50: 'American alligator, Alligator mississipiensis',
52
+ 51: 'triceratops',
53
+ 52: 'thunder snake, worm snake, Carphophis amoenus',
54
+ 53: 'ringneck snake, ring-necked snake, ring snake',
55
+ 54: 'hognose snake, puff adder, sand viper',
56
+ 55: 'green snake, grass snake',
57
+ 56: 'king snake, kingsnake',
58
+ 57: 'garter snake, grass snake',
59
+ 58: 'water snake',
60
+ 59: 'vine snake',
61
+ 60: 'night snake, Hypsiglena torquata',
62
+ 61: 'boa constrictor, Constrictor constrictor',
63
+ 62: 'rock python, rock snake, Python sebae',
64
+ 63: 'Indian cobra, Naja naja',
65
+ 64: 'green mamba',
66
+ 65: 'sea snake',
67
+ 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
68
+ 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
69
+ 68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
70
+ 69: 'trilobite',
71
+ 70: 'harvestman, daddy longlegs, Phalangium opilio',
72
+ 71: 'scorpion',
73
+ 72: 'black and gold garden spider, Argiope aurantia',
74
+ 73: 'barn spider, Araneus cavaticus',
75
+ 74: 'garden spider, Aranea diademata',
76
+ 75: 'black widow, Latrodectus mactans',
77
+ 76: 'tarantula',
78
+ 77: 'wolf spider, hunting spider',
79
+ 78: 'tick',
80
+ 79: 'centipede',
81
+ 80: 'black grouse',
82
+ 81: 'ptarmigan',
83
+ 82: 'ruffed grouse, partridge, Bonasa umbellus',
84
+ 83: 'prairie chicken, prairie grouse, prairie fowl',
85
+ 84: 'peacock',
86
+ 85: 'quail',
87
+ 86: 'partridge',
88
+ 87: 'African grey, African gray, Psittacus erithacus',
89
+ 88: 'macaw',
90
+ 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
91
+ 90: 'lorikeet',
92
+ 91: 'coucal',
93
+ 92: 'bee eater',
94
+ 93: 'hornbill',
95
+ 94: 'hummingbird',
96
+ 95: 'jacamar',
97
+ 96: 'toucan',
98
+ 97: 'drake',
99
+ 98: 'red-breasted merganser, Mergus serrator',
100
+ 99: 'goose',
101
+ 100: 'black swan, Cygnus atratus',
102
+ 101: 'tusker',
103
+ 102: 'echidna, spiny anteater, anteater',
104
+ 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
105
+ 104: 'wallaby, brush kangaroo',
106
+ 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
107
+ 106: 'wombat',
108
+ 107: 'jellyfish',
109
+ 108: 'sea anemone, anemone',
110
+ 109: 'brain coral',
111
+ 110: 'flatworm, platyhelminth',
112
+ 111: 'nematode, nematode worm, roundworm',
113
+ 112: 'conch',
114
+ 113: 'snail',
115
+ 114: 'slug',
116
+ 115: 'sea slug, nudibranch',
117
+ 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
118
+ 117: 'chambered nautilus, pearly nautilus, nautilus',
119
+ 118: 'Dungeness crab, Cancer magister',
120
+ 119: 'rock crab, Cancer irroratus',
121
+ 120: 'fiddler crab',
122
+ 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
123
+ 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
124
+ 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
125
+ 124: 'crayfish, crawfish, crawdad, crawdaddy',
126
+ 125: 'hermit crab',
127
+ 126: 'isopod',
128
+ 127: 'white stork, Ciconia ciconia',
129
+ 128: 'black stork, Ciconia nigra',
130
+ 129: 'spoonbill',
131
+ 130: 'flamingo',
132
+ 131: 'little blue heron, Egretta caerulea',
133
+ 132: 'American egret, great white heron, Egretta albus',
134
+ 133: 'bittern',
135
+ 134: 'crane',
136
+ 135: 'limpkin, Aramus pictus',
137
+ 136: 'European gallinule, Porphyrio porphyrio',
138
+ 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
139
+ 138: 'bustard',
140
+ 139: 'ruddy turnstone, Arenaria interpres',
141
+ 140: 'red-backed sandpiper, dunlin, Erolia alpina',
142
+ 141: 'redshank, Tringa totanus',
143
+ 142: 'dowitcher',
144
+ 143: 'oystercatcher, oyster catcher',
145
+ 144: 'pelican',
146
+ 145: 'king penguin, Aptenodytes patagonica',
147
+ 146: 'albatross, mollymawk',
148
+ 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
149
+ 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
150
+ 149: 'dugong, Dugong dugon',
151
+ 150: 'sea lion',
152
+ 151: 'Chihuahua',
153
+ 152: 'Japanese spaniel',
154
+ 153: 'Maltese dog, Maltese terrier, Maltese',
155
+ 154: 'Pekinese, Pekingese, Peke',
156
+ 155: 'Shih-Tzu',
157
+ 156: 'Blenheim spaniel',
158
+ 157: 'papillon',
159
+ 158: 'toy terrier',
160
+ 159: 'Rhodesian ridgeback',
161
+ 160: 'Afghan hound, Afghan',
162
+ 161: 'basset, basset hound',
163
+ 162: 'beagle',
164
+ 163: 'bloodhound, sleuthhound',
165
+ 164: 'bluetick',
166
+ 165: 'black-and-tan coonhound',
167
+ 166: 'Walker hound, Walker foxhound',
168
+ 167: 'English foxhound',
169
+ 168: 'redbone',
170
+ 169: 'borzoi, Russian wolfhound',
171
+ 170: 'Irish wolfhound',
172
+ 171: 'Italian greyhound',
173
+ 172: 'whippet',
174
+ 173: 'Ibizan hound, Ibizan Podenco',
175
+ 174: 'Norwegian elkhound, elkhound',
176
+ 175: 'otterhound, otter hound',
177
+ 176: 'Saluki, gazelle hound',
178
+ 177: 'Scottish deerhound, deerhound',
179
+ 178: 'Weimaraner',
180
+ 179: 'Staffordshire bullterrier, Staffordshire bull terrier',
181
+ 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
182
+ 181: 'Bedlington terrier',
183
+ 182: 'Border terrier',
184
+ 183: 'Kerry blue terrier',
185
+ 184: 'Irish terrier',
186
+ 185: 'Norfolk terrier',
187
+ 186: 'Norwich terrier',
188
+ 187: 'Yorkshire terrier',
189
+ 188: 'wire-haired fox terrier',
190
+ 189: 'Lakeland terrier',
191
+ 190: 'Sealyham terrier, Sealyham',
192
+ 191: 'Airedale, Airedale terrier',
193
+ 192: 'cairn, cairn terrier',
194
+ 193: 'Australian terrier',
195
+ 194: 'Dandie Dinmont, Dandie Dinmont terrier',
196
+ 195: 'Boston bull, Boston terrier',
197
+ 196: 'miniature schnauzer',
198
+ 197: 'giant schnauzer',
199
+ 198: 'standard schnauzer',
200
+ 199: 'Scotch terrier, Scottish terrier, Scottie',
201
+ 200: 'Tibetan terrier, chrysanthemum dog',
202
+ 201: 'silky terrier, Sydney silky',
203
+ 202: 'soft-coated wheaten terrier',
204
+ 203: 'West Highland white terrier',
205
+ 204: 'Lhasa, Lhasa apso',
206
+ 205: 'flat-coated retriever',
207
+ 206: 'curly-coated retriever',
208
+ 207: 'golden retriever',
209
+ 208: 'Labrador retriever',
210
+ 209: 'Chesapeake Bay retriever',
211
+ 210: 'German short-haired pointer',
212
+ 211: 'vizsla, Hungarian pointer',
213
+ 212: 'English setter',
214
+ 213: 'Irish setter, red setter',
215
+ 214: 'Gordon setter',
216
+ 215: 'Brittany spaniel',
217
+ 216: 'clumber, clumber spaniel',
218
+ 217: 'English springer, English springer spaniel',
219
+ 218: 'Welsh springer spaniel',
220
+ 219: 'cocker spaniel, English cocker spaniel, cocker',
221
+ 220: 'Sussex spaniel',
222
+ 221: 'Irish water spaniel',
223
+ 222: 'kuvasz',
224
+ 223: 'schipperke',
225
+ 224: 'groenendael',
226
+ 225: 'malinois',
227
+ 226: 'briard',
228
+ 227: 'kelpie',
229
+ 228: 'komondor',
230
+ 229: 'Old English sheepdog, bobtail',
231
+ 230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
232
+ 231: 'collie',
233
+ 232: 'Border collie',
234
+ 233: 'Bouvier des Flandres, Bouviers des Flandres',
235
+ 234: 'Rottweiler',
236
+ 235: 'German shepherd, German shepherd dog, German police dog, alsatian',
237
+ 236: 'Doberman, Doberman pinscher',
238
+ 237: 'miniature pinscher',
239
+ 238: 'Greater Swiss Mountain dog',
240
+ 239: 'Bernese mountain dog',
241
+ 240: 'Appenzeller',
242
+ 241: 'EntleBucher',
243
+ 242: 'boxer',
244
+ 243: 'bull mastiff',
245
+ 244: 'Tibetan mastiff',
246
+ 245: 'French bulldog',
247
+ 246: 'Great Dane',
248
+ 247: 'Saint Bernard, St Bernard',
249
+ 248: 'Eskimo dog, husky',
250
+ 249: 'malamute, malemute, Alaskan malamute',
251
+ 250: 'Siberian husky',
252
+ 251: 'dalmatian, coach dog, carriage dog',
253
+ 252: 'affenpinscher, monkey pinscher, monkey dog',
254
+ 253: 'basenji',
255
+ 254: 'pug, pug-dog',
256
+ 255: 'Leonberg',
257
+ 256: 'Newfoundland, Newfoundland dog',
258
+ 257: 'Great Pyrenees',
259
+ 258: 'Samoyed, Samoyede',
260
+ 259: 'Pomeranian',
261
+ 260: 'chow, chow chow',
262
+ 261: 'keeshond',
263
+ 262: 'Brabancon griffon',
264
+ 263: 'Pembroke, Pembroke Welsh corgi',
265
+ 264: 'Cardigan, Cardigan Welsh corgi',
266
+ 265: 'toy poodle',
267
+ 266: 'miniature poodle',
268
+ 267: 'standard poodle',
269
+ 268: 'Mexican hairless',
270
+ 269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
271
+ 270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
272
+ 271: 'red wolf, maned wolf, Canis rufus, Canis niger',
273
+ 272: 'coyote, prairie wolf, brush wolf, Canis latrans',
274
+ 273: 'dingo, warrigal, warragal, Canis dingo',
275
+ 274: 'dhole, Cuon alpinus',
276
+ 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
277
+ 276: 'hyena, hyaena',
278
+ 277: 'red fox, Vulpes vulpes',
279
+ 278: 'kit fox, Vulpes macrotis',
280
+ 279: 'Arctic fox, white fox, Alopex lagopus',
281
+ 280: 'grey fox, gray fox, Urocyon cinereoargenteus',
282
+ 281: 'tabby, tabby cat',
283
+ 282: 'tiger cat',
284
+ 283: 'Persian cat',
285
+ 284: 'Siamese cat, Siamese',
286
+ 285: 'Egyptian cat',
287
+ 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
288
+ 287: 'lynx, catamount',
289
+ 288: 'leopard, Panthera pardus',
290
+ 289: 'snow leopard, ounce, Panthera uncia',
291
+ 290: 'jaguar, panther, Panthera onca, Felis onca',
292
+ 291: 'lion, king of beasts, Panthera leo',
293
+ 292: 'tiger, Panthera tigris',
294
+ 293: 'cheetah, chetah, Acinonyx jubatus',
295
+ 294: 'brown bear, bruin, Ursus arctos',
296
+ 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
297
+ 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
298
+ 297: 'sloth bear, Melursus ursinus, Ursus ursinus',
299
+ 298: 'mongoose',
300
+ 299: 'meerkat, mierkat',
301
+ 300: 'tiger beetle',
302
+ 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
303
+ 302: 'ground beetle, carabid beetle',
304
+ 303: 'long-horned beetle, longicorn, longicorn beetle',
305
+ 304: 'leaf beetle, chrysomelid',
306
+ 305: 'dung beetle',
307
+ 306: 'rhinoceros beetle',
308
+ 307: 'weevil',
309
+ 308: 'fly',
310
+ 309: 'bee',
311
+ 310: 'ant, emmet, pismire',
312
+ 311: 'grasshopper, hopper',
313
+ 312: 'cricket',
314
+ 313: 'walking stick, walkingstick, stick insect',
315
+ 314: 'cockroach, roach',
316
+ 315: 'mantis, mantid',
317
+ 316: 'cicada, cicala',
318
+ 317: 'leafhopper',
319
+ 318: 'lacewing, lacewing fly',
320
+ 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
321
+ 320: 'damselfly',
322
+ 321: 'admiral',
323
+ 322: 'ringlet, ringlet butterfly',
324
+ 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
325
+ 324: 'cabbage butterfly',
326
+ 325: 'sulphur butterfly, sulfur butterfly',
327
+ 326: 'lycaenid, lycaenid butterfly',
328
+ 327: 'starfish, sea star',
329
+ 328: 'sea urchin',
330
+ 329: 'sea cucumber, holothurian',
331
+ 330: 'wood rabbit, cottontail, cottontail rabbit',
332
+ 331: 'hare',
333
+ 332: 'Angora, Angora rabbit',
334
+ 333: 'hamster',
335
+ 334: 'porcupine, hedgehog',
336
+ 335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
337
+ 336: 'marmot',
338
+ 337: 'beaver',
339
+ 338: 'guinea pig, Cavia cobaya',
340
+ 339: 'sorrel',
341
+ 340: 'zebra',
342
+ 341: 'hog, pig, grunter, squealer, Sus scrofa',
343
+ 342: 'wild boar, boar, Sus scrofa',
344
+ 343: 'warthog',
345
+ 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
346
+ 345: 'ox',
347
+ 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
348
+ 347: 'bison',
349
+ 348: 'ram, tup',
350
+ 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
351
+ 350: 'ibex, Capra ibex',
352
+ 351: 'hartebeest',
353
+ 352: 'impala, Aepyceros melampus',
354
+ 353: 'gazelle',
355
+ 354: 'Arabian camel, dromedary, Camelus dromedarius',
356
+ 355: 'llama',
357
+ 356: 'weasel',
358
+ 357: 'mink',
359
+ 358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
360
+ 359: 'black-footed ferret, ferret, Mustela nigripes',
361
+ 360: 'otter',
362
+ 361: 'skunk, polecat, wood pussy',
363
+ 362: 'badger',
364
+ 363: 'armadillo',
365
+ 364: 'three-toed sloth, ai, Bradypus tridactylus',
366
+ 365: 'orangutan, orang, orangutang, Pongo pygmaeus',
367
+ 366: 'gorilla, Gorilla gorilla',
368
+ 367: 'chimpanzee, chimp, Pan troglodytes',
369
+ 368: 'gibbon, Hylobates lar',
370
+ 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
371
+ 370: 'guenon, guenon monkey',
372
+ 371: 'patas, hussar monkey, Erythrocebus patas',
373
+ 372: 'baboon',
374
+ 373: 'macaque',
375
+ 374: 'langur',
376
+ 375: 'colobus, colobus monkey',
377
+ 376: 'proboscis monkey, Nasalis larvatus',
378
+ 377: 'marmoset',
379
+ 378: 'capuchin, ringtail, Cebus capucinus',
380
+ 379: 'howler monkey, howler',
381
+ 380: 'titi, titi monkey',
382
+ 381: 'spider monkey, Ateles geoffroyi',
383
+ 382: 'squirrel monkey, Saimiri sciureus',
384
+ 383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
385
+ 384: 'indri, indris, Indri indri, Indri brevicaudatus',
386
+ 385: 'Indian elephant, Elephas maximus',
387
+ 386: 'African elephant, Loxodonta africana',
388
+ 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
389
+ 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
390
+ 389: 'barracouta, snoek',
391
+ 390: 'eel',
392
+ 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
393
+ 392: 'rock beauty, Holocanthus tricolor',
394
+ 393: 'anemone fish',
395
+ 394: 'sturgeon',
396
+ 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
397
+ 396: 'lionfish',
398
+ 397: 'puffer, pufferfish, blowfish, globefish',
399
+ 398: 'abacus',
400
+ 399: 'abaya',
401
+ 400: "academic gown, academic robe, judge's robe",
402
+ 401: 'accordion, piano accordion, squeeze box',
403
+ 402: 'acoustic guitar',
404
+ 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
405
+ 404: 'airliner',
406
+ 405: 'airship, dirigible',
407
+ 406: 'altar',
408
+ 407: 'ambulance',
409
+ 408: 'amphibian, amphibious vehicle',
410
+ 409: 'analog clock',
411
+ 410: 'apiary, bee house',
412
+ 411: 'apron',
413
+ 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
414
+ 413: 'assault rifle, assault gun',
415
+ 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
416
+ 415: 'bakery, bakeshop, bakehouse',
417
+ 416: 'balance beam, beam',
418
+ 417: 'balloon',
419
+ 418: 'ballpoint, ballpoint pen, ballpen, Biro',
420
+ 419: 'Band Aid',
421
+ 420: 'banjo',
422
+ 421: 'bannister, banister, balustrade, balusters, handrail',
423
+ 422: 'barbell',
424
+ 423: 'barber chair',
425
+ 424: 'barbershop',
426
+ 425: 'barn',
427
+ 426: 'barometer',
428
+ 427: 'barrel, cask',
429
+ 428: 'barrow, garden cart, lawn cart, wheelbarrow',
430
+ 429: 'baseball',
431
+ 430: 'basketball',
432
+ 431: 'bassinet',
433
+ 432: 'bassoon',
434
+ 433: 'bathing cap, swimming cap',
435
+ 434: 'bath towel',
436
+ 435: 'bathtub, bathing tub, bath, tub',
437
+ 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
438
+ 437: 'beacon, lighthouse, beacon light, pharos',
439
+ 438: 'beaker',
440
+ 439: 'bearskin, busby, shako',
441
+ 440: 'beer bottle',
442
+ 441: 'beer glass',
443
+ 442: 'bell cote, bell cot',
444
+ 443: 'bib',
445
+ 444: 'bicycle-built-for-two, tandem bicycle, tandem',
446
+ 445: 'bikini, two-piece',
447
+ 446: 'binder, ring-binder',
448
+ 447: 'binoculars, field glasses, opera glasses',
449
+ 448: 'birdhouse',
450
+ 449: 'boathouse',
451
+ 450: 'bobsled, bobsleigh, bob',
452
+ 451: 'bolo tie, bolo, bola tie, bola',
453
+ 452: 'bonnet, poke bonnet',
454
+ 453: 'bookcase',
455
+ 454: 'bookshop, bookstore, bookstall',
456
+ 455: 'bottlecap',
457
+ 456: 'bow',
458
+ 457: 'bow tie, bow-tie, bowtie',
459
+ 458: 'brass, memorial tablet, plaque',
460
+ 459: 'brassiere, bra, bandeau',
461
+ 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
462
+ 461: 'breastplate, aegis, egis',
463
+ 462: 'broom',
464
+ 463: 'bucket, pail',
465
+ 464: 'buckle',
466
+ 465: 'bulletproof vest',
467
+ 466: 'bullet train, bullet',
468
+ 467: 'butcher shop, meat market',
469
+ 468: 'cab, hack, taxi, taxicab',
470
+ 469: 'caldron, cauldron',
471
+ 470: 'candle, taper, wax light',
472
+ 471: 'cannon',
473
+ 472: 'canoe',
474
+ 473: 'can opener, tin opener',
475
+ 474: 'cardigan',
476
+ 475: 'car mirror',
477
+ 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
478
+ 477: "carpenter's kit, tool kit",
479
+ 478: 'carton',
480
+ 479: 'car wheel',
481
+ 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
482
+ 481: 'cassette',
483
+ 482: 'cassette player',
484
+ 483: 'castle',
485
+ 484: 'catamaran',
486
+ 485: 'CD player',
487
+ 486: 'cello, violoncello',
488
+ 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
489
+ 488: 'chain',
490
+ 489: 'chainlink fence',
491
+ 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
492
+ 491: 'chain saw, chainsaw',
493
+ 492: 'chest',
494
+ 493: 'chiffonier, commode',
495
+ 494: 'chime, bell, gong',
496
+ 495: 'china cabinet, china closet',
497
+ 496: 'Christmas stocking',
498
+ 497: 'church, church building',
499
+ 498: 'cinema, movie theater, movie theatre, movie house, picture palace',
500
+ 499: 'cleaver, meat cleaver, chopper',
501
+ 500: 'cliff dwelling',
502
+ 501: 'cloak',
503
+ 502: 'clog, geta, patten, sabot',
504
+ 503: 'cocktail shaker',
505
+ 504: 'coffee mug',
506
+ 505: 'coffeepot',
507
+ 506: 'coil, spiral, volute, whorl, helix',
508
+ 507: 'combination lock',
509
+ 508: 'computer keyboard, keypad',
510
+ 509: 'confectionery, confectionary, candy store',
511
+ 510: 'container ship, containership, container vessel',
512
+ 511: 'convertible',
513
+ 512: 'corkscrew, bottle screw',
514
+ 513: 'cornet, horn, trumpet, trump',
515
+ 514: 'cowboy boot',
516
+ 515: 'cowboy hat, ten-gallon hat',
517
+ 516: 'cradle',
518
+ 517: 'crane',
519
+ 518: 'crash helmet',
520
+ 519: 'crate',
521
+ 520: 'crib, cot',
522
+ 521: 'Crock Pot',
523
+ 522: 'croquet ball',
524
+ 523: 'crutch',
525
+ 524: 'cuirass',
526
+ 525: 'dam, dike, dyke',
527
+ 526: 'desk',
528
+ 527: 'desktop computer',
529
+ 528: 'dial telephone, dial phone',
530
+ 529: 'diaper, nappy, napkin',
531
+ 530: 'digital clock',
532
+ 531: 'digital watch',
533
+ 532: 'dining table, board',
534
+ 533: 'dishrag, dishcloth',
535
+ 534: 'dishwasher, dish washer, dishwashing machine',
536
+ 535: 'disk brake, disc brake',
537
+ 536: 'dock, dockage, docking facility',
538
+ 537: 'dogsled, dog sled, dog sleigh',
539
+ 538: 'dome',
540
+ 539: 'doormat, welcome mat',
541
+ 540: 'drilling platform, offshore rig',
542
+ 541: 'drum, membranophone, tympan',
543
+ 542: 'drumstick',
544
+ 543: 'dumbbell',
545
+ 544: 'Dutch oven',
546
+ 545: 'electric fan, blower',
547
+ 546: 'electric guitar',
548
+ 547: 'electric locomotive',
549
+ 548: 'entertainment center',
550
+ 549: 'envelope',
551
+ 550: 'espresso maker',
552
+ 551: 'face powder',
553
+ 552: 'feather boa, boa',
554
+ 553: 'file, file cabinet, filing cabinet',
555
+ 554: 'fireboat',
556
+ 555: 'fire engine, fire truck',
557
+ 556: 'fire screen, fireguard',
558
+ 557: 'flagpole, flagstaff',
559
+ 558: 'flute, transverse flute',
560
+ 559: 'folding chair',
561
+ 560: 'football helmet',
562
+ 561: 'forklift',
563
+ 562: 'fountain',
564
+ 563: 'fountain pen',
565
+ 564: 'four-poster',
566
+ 565: 'freight car',
567
+ 566: 'French horn, horn',
568
+ 567: 'frying pan, frypan, skillet',
569
+ 568: 'fur coat',
570
+ 569: 'garbage truck, dustcart',
571
+ 570: 'gasmask, respirator, gas helmet',
572
+ 571: 'gas pump, gasoline pump, petrol pump, island dispenser',
573
+ 572: 'goblet',
574
+ 573: 'go-kart',
575
+ 574: 'golf ball',
576
+ 575: 'golfcart, golf cart',
577
+ 576: 'gondola',
578
+ 577: 'gong, tam-tam',
579
+ 578: 'gown',
580
+ 579: 'grand piano, grand',
581
+ 580: 'greenhouse, nursery, glasshouse',
582
+ 581: 'grille, radiator grille',
583
+ 582: 'grocery store, grocery, food market, market',
584
+ 583: 'guillotine',
585
+ 584: 'hair slide',
586
+ 585: 'hair spray',
587
+ 586: 'half track',
588
+ 587: 'hammer',
589
+ 588: 'hamper',
590
+ 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
591
+ 590: 'hand-held computer, hand-held microcomputer',
592
+ 591: 'handkerchief, hankie, hanky, hankey',
593
+ 592: 'hard disc, hard disk, fixed disk',
594
+ 593: 'harmonica, mouth organ, harp, mouth harp',
595
+ 594: 'harp',
596
+ 595: 'harvester, reaper',
597
+ 596: 'hatchet',
598
+ 597: 'holster',
599
+ 598: 'home theater, home theatre',
600
+ 599: 'honeycomb',
601
+ 600: 'hook, claw',
602
+ 601: 'hoopskirt, crinoline',
603
+ 602: 'horizontal bar, high bar',
604
+ 603: 'horse cart, horse-cart',
605
+ 604: 'hourglass',
606
+ 605: 'iPod',
607
+ 606: 'iron, smoothing iron',
608
+ 607: "jack-o'-lantern",
609
+ 608: 'jean, blue jean, denim',
610
+ 609: 'jeep, landrover',
611
+ 610: 'jersey, T-shirt, tee shirt',
612
+ 611: 'jigsaw puzzle',
613
+ 612: 'jinrikisha, ricksha, rickshaw',
614
+ 613: 'joystick',
615
+ 614: 'kimono',
616
+ 615: 'knee pad',
617
+ 616: 'knot',
618
+ 617: 'lab coat, laboratory coat',
619
+ 618: 'ladle',
620
+ 619: 'lampshade, lamp shade',
621
+ 620: 'laptop, laptop computer',
622
+ 621: 'lawn mower, mower',
623
+ 622: 'lens cap, lens cover',
624
+ 623: 'letter opener, paper knife, paperknife',
625
+ 624: 'library',
626
+ 625: 'lifeboat',
627
+ 626: 'lighter, light, igniter, ignitor',
628
+ 627: 'limousine, limo',
629
+ 628: 'liner, ocean liner',
630
+ 629: 'lipstick, lip rouge',
631
+ 630: 'Loafer',
632
+ 631: 'lotion',
633
+ 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
634
+ 633: "loupe, jeweler's loupe",
635
+ 634: 'lumbermill, sawmill',
636
+ 635: 'magnetic compass',
637
+ 636: 'mailbag, postbag',
638
+ 637: 'mailbox, letter box',
639
+ 638: 'maillot',
640
+ 639: 'maillot, tank suit',
641
+ 640: 'manhole cover',
642
+ 641: 'maraca',
643
+ 642: 'marimba, xylophone',
644
+ 643: 'mask',
645
+ 644: 'matchstick',
646
+ 645: 'maypole',
647
+ 646: 'maze, labyrinth',
648
+ 647: 'measuring cup',
649
+ 648: 'medicine chest, medicine cabinet',
650
+ 649: 'megalith, megalithic structure',
651
+ 650: 'microphone, mike',
652
+ 651: 'microwave, microwave oven',
653
+ 652: 'military uniform',
654
+ 653: 'milk can',
655
+ 654: 'minibus',
656
+ 655: 'miniskirt, mini',
657
+ 656: 'minivan',
658
+ 657: 'missile',
659
+ 658: 'mitten',
660
+ 659: 'mixing bowl',
661
+ 660: 'mobile home, manufactured home',
662
+ 661: 'Model T',
663
+ 662: 'modem',
664
+ 663: 'monastery',
665
+ 664: 'monitor',
666
+ 665: 'moped',
667
+ 666: 'mortar',
668
+ 667: 'mortarboard',
669
+ 668: 'mosque',
670
+ 669: 'mosquito net',
671
+ 670: 'motor scooter, scooter',
672
+ 671: 'mountain bike, all-terrain bike, off-roader',
673
+ 672: 'mountain tent',
674
+ 673: 'mouse, computer mouse',
675
+ 674: 'mousetrap',
676
+ 675: 'moving van',
677
+ 676: 'muzzle',
678
+ 677: 'nail',
679
+ 678: 'neck brace',
680
+ 679: 'necklace',
681
+ 680: 'nipple',
682
+ 681: 'notebook, notebook computer',
683
+ 682: 'obelisk',
684
+ 683: 'oboe, hautboy, hautbois',
685
+ 684: 'ocarina, sweet potato',
686
+ 685: 'odometer, hodometer, mileometer, milometer',
687
+ 686: 'oil filter',
688
+ 687: 'organ, pipe organ',
689
+ 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
690
+ 689: 'overskirt',
691
+ 690: 'oxcart',
692
+ 691: 'oxygen mask',
693
+ 692: 'packet',
694
+ 693: 'paddle, boat paddle',
695
+ 694: 'paddlewheel, paddle wheel',
696
+ 695: 'padlock',
697
+ 696: 'paintbrush',
698
+ 697: "pajama, pyjama, pj's, jammies",
699
+ 698: 'palace',
700
+ 699: 'panpipe, pandean pipe, syrinx',
701
+ 700: 'paper towel',
702
+ 701: 'parachute, chute',
703
+ 702: 'parallel bars, bars',
704
+ 703: 'park bench',
705
+ 704: 'parking meter',
706
+ 705: 'passenger car, coach, carriage',
707
+ 706: 'patio, terrace',
708
+ 707: 'pay-phone, pay-station',
709
+ 708: 'pedestal, plinth, footstall',
710
+ 709: 'pencil box, pencil case',
711
+ 710: 'pencil sharpener',
712
+ 711: 'perfume, essence',
713
+ 712: 'Petri dish',
714
+ 713: 'photocopier',
715
+ 714: 'pick, plectrum, plectron',
716
+ 715: 'pickelhaube',
717
+ 716: 'picket fence, paling',
718
+ 717: 'pickup, pickup truck',
719
+ 718: 'pier',
720
+ 719: 'piggy bank, penny bank',
721
+ 720: 'pill bottle',
722
+ 721: 'pillow',
723
+ 722: 'ping-pong ball',
724
+ 723: 'pinwheel',
725
+ 724: 'pirate, pirate ship',
726
+ 725: 'pitcher, ewer',
727
+ 726: "plane, carpenter's plane, woodworking plane",
728
+ 727: 'planetarium',
729
+ 728: 'plastic bag',
730
+ 729: 'plate rack',
731
+ 730: 'plow, plough',
732
+ 731: "plunger, plumber's helper",
733
+ 732: 'Polaroid camera, Polaroid Land camera',
734
+ 733: 'pole',
735
+ 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
736
+ 735: 'poncho',
737
+ 736: 'pool table, billiard table, snooker table',
738
+ 737: 'pop bottle, soda bottle',
739
+ 738: 'pot, flowerpot',
740
+ 739: "potter's wheel",
741
+ 740: 'power drill',
742
+ 741: 'prayer rug, prayer mat',
743
+ 742: 'printer',
744
+ 743: 'prison, prison house',
745
+ 744: 'projectile, missile',
746
+ 745: 'projector',
747
+ 746: 'puck, hockey puck',
748
+ 747: 'punching bag, punch bag, punching ball, punchball',
749
+ 748: 'purse',
750
+ 749: 'quill, quill pen',
751
+ 750: 'quilt, comforter, comfort, puff',
752
+ 751: 'racer, race car, racing car',
753
+ 752: 'racket, racquet',
754
+ 753: 'radiator',
755
+ 754: 'radio, wireless',
756
+ 755: 'radio telescope, radio reflector',
757
+ 756: 'rain barrel',
758
+ 757: 'recreational vehicle, RV, R.V.',
759
+ 758: 'reel',
760
+ 759: 'reflex camera',
761
+ 760: 'refrigerator, icebox',
762
+ 761: 'remote control, remote',
763
+ 762: 'restaurant, eating house, eating place, eatery',
764
+ 763: 'revolver, six-gun, six-shooter',
765
+ 764: 'rifle',
766
+ 765: 'rocking chair, rocker',
767
+ 766: 'rotisserie',
768
+ 767: 'rubber eraser, rubber, pencil eraser',
769
+ 768: 'rugby ball',
770
+ 769: 'rule, ruler',
771
+ 770: 'running shoe',
772
+ 771: 'safe',
773
+ 772: 'safety pin',
774
+ 773: 'saltshaker, salt shaker',
775
+ 774: 'sandal',
776
+ 775: 'sarong',
777
+ 776: 'sax, saxophone',
778
+ 777: 'scabbard',
779
+ 778: 'scale, weighing machine',
780
+ 779: 'school bus',
781
+ 780: 'schooner',
782
+ 781: 'scoreboard',
783
+ 782: 'screen, CRT screen',
784
+ 783: 'screw',
785
+ 784: 'screwdriver',
786
+ 785: 'seat belt, seatbelt',
787
+ 786: 'sewing machine',
788
+ 787: 'shield, buckler',
789
+ 788: 'shoe shop, shoe-shop, shoe store',
790
+ 789: 'shoji',
791
+ 790: 'shopping basket',
792
+ 791: 'shopping cart',
793
+ 792: 'shovel',
794
+ 793: 'shower cap',
795
+ 794: 'shower curtain',
796
+ 795: 'ski',
797
+ 796: 'ski mask',
798
+ 797: 'sleeping bag',
799
+ 798: 'slide rule, slipstick',
800
+ 799: 'sliding door',
801
+ 800: 'slot, one-armed bandit',
802
+ 801: 'snorkel',
803
+ 802: 'snowmobile',
804
+ 803: 'snowplow, snowplough',
805
+ 804: 'soap dispenser',
806
+ 805: 'soccer ball',
807
+ 806: 'sock',
808
+ 807: 'solar dish, solar collector, solar furnace',
809
+ 808: 'sombrero',
810
+ 809: 'soup bowl',
811
+ 810: 'space bar',
812
+ 811: 'space heater',
813
+ 812: 'space shuttle',
814
+ 813: 'spatula',
815
+ 814: 'speedboat',
816
+ 815: "spider web, spider's web",
817
+ 816: 'spindle',
818
+ 817: 'sports car, sport car',
819
+ 818: 'spotlight, spot',
820
+ 819: 'stage',
821
+ 820: 'steam locomotive',
822
+ 821: 'steel arch bridge',
823
+ 822: 'steel drum',
824
+ 823: 'stethoscope',
825
+ 824: 'stole',
826
+ 825: 'stone wall',
827
+ 826: 'stopwatch, stop watch',
828
+ 827: 'stove',
829
+ 828: 'strainer',
830
+ 829: 'streetcar, tram, tramcar, trolley, trolley car',
831
+ 830: 'stretcher',
832
+ 831: 'studio couch, day bed',
833
+ 832: 'stupa, tope',
834
+ 833: 'submarine, pigboat, sub, U-boat',
835
+ 834: 'suit, suit of clothes',
836
+ 835: 'sundial',
837
+ 836: 'sunglass',
838
+ 837: 'sunglasses, dark glasses, shades',
839
+ 838: 'sunscreen, sunblock, sun blocker',
840
+ 839: 'suspension bridge',
841
+ 840: 'swab, swob, mop',
842
+ 841: 'sweatshirt',
843
+ 842: 'swimming trunks, bathing trunks',
844
+ 843: 'swing',
845
+ 844: 'switch, electric switch, electrical switch',
846
+ 845: 'syringe',
847
+ 846: 'table lamp',
848
+ 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
849
+ 848: 'tape player',
850
+ 849: 'teapot',
851
+ 850: 'teddy, teddy bear',
852
+ 851: 'television, television system',
853
+ 852: 'tennis ball',
854
+ 853: 'thatch, thatched roof',
855
+ 854: 'theater curtain, theatre curtain',
856
+ 855: 'thimble',
857
+ 856: 'thresher, thrasher, threshing machine',
858
+ 857: 'throne',
859
+ 858: 'tile roof',
860
+ 859: 'toaster',
861
+ 860: 'tobacco shop, tobacconist shop, tobacconist',
862
+ 861: 'toilet seat',
863
+ 862: 'torch',
864
+ 863: 'totem pole',
865
+ 864: 'tow truck, tow car, wrecker',
866
+ 865: 'toyshop',
867
+ 866: 'tractor',
868
+ 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
869
+ 868: 'tray',
870
+ 869: 'trench coat',
871
+ 870: 'tricycle, trike, velocipede',
872
+ 871: 'trimaran',
873
+ 872: 'tripod',
874
+ 873: 'triumphal arch',
875
+ 874: 'trolleybus, trolley coach, trackless trolley',
876
+ 875: 'trombone',
877
+ 876: 'tub, vat',
878
+ 877: 'turnstile',
879
+ 878: 'typewriter keyboard',
880
+ 879: 'umbrella',
881
+ 880: 'unicycle, monocycle',
882
+ 881: 'upright, upright piano',
883
+ 882: 'vacuum, vacuum cleaner',
884
+ 883: 'vase',
885
+ 884: 'vault',
886
+ 885: 'velvet',
887
+ 886: 'vending machine',
888
+ 887: 'vestment',
889
+ 888: 'viaduct',
890
+ 889: 'violin, fiddle',
891
+ 890: 'volleyball',
892
+ 891: 'waffle iron',
893
+ 892: 'wall clock',
894
+ 893: 'wallet, billfold, notecase, pocketbook',
895
+ 894: 'wardrobe, closet, press',
896
+ 895: 'warplane, military plane',
897
+ 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
898
+ 897: 'washer, automatic washer, washing machine',
899
+ 898: 'water bottle',
900
+ 899: 'water jug',
901
+ 900: 'water tower',
902
+ 901: 'whiskey jug',
903
+ 902: 'whistle',
904
+ 903: 'wig',
905
+ 904: 'window screen',
906
+ 905: 'window shade',
907
+ 906: 'Windsor tie',
908
+ 907: 'wine bottle',
909
+ 908: 'wing',
910
+ 909: 'wok',
911
+ 910: 'wooden spoon',
912
+ 911: 'wool, woolen, woollen',
913
+ 912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
914
+ 913: 'wreck',
915
+ 914: 'yawl',
916
+ 915: 'yurt',
917
+ 916: 'web site, website, internet site, site',
918
+ 917: 'comic book',
919
+ 918: 'crossword puzzle, crossword',
920
+ 919: 'street sign',
921
+ 920: 'traffic light, traffic signal, stoplight',
922
+ 921: 'book jacket, dust cover, dust jacket, dust wrapper',
923
+ 922: 'menu',
924
+ 923: 'plate',
925
+ 924: 'guacamole',
926
+ 925: 'consomme',
927
+ 926: 'hot pot, hotpot',
928
+ 927: 'trifle',
929
+ 928: 'ice cream, icecream',
930
+ 929: 'ice lolly, lolly, lollipop, popsicle',
931
+ 930: 'French loaf',
932
+ 931: 'bagel, beigel',
933
+ 932: 'pretzel',
934
+ 933: 'cheeseburger',
935
+ 934: 'hotdog, hot dog, red hot',
936
+ 935: 'mashed potato',
937
+ 936: 'head cabbage',
938
+ 937: 'broccoli',
939
+ 938: 'cauliflower',
940
+ 939: 'zucchini, courgette',
941
+ 940: 'spaghetti squash',
942
+ 941: 'acorn squash',
943
+ 942: 'butternut squash',
944
+ 943: 'cucumber, cuke',
945
+ 944: 'artichoke, globe artichoke',
946
+ 945: 'bell pepper',
947
+ 946: 'cardoon',
948
+ 947: 'mushroom',
949
+ 948: 'Granny Smith',
950
+ 949: 'strawberry',
951
+ 950: 'orange',
952
+ 951: 'lemon',
953
+ 952: 'fig',
954
+ 953: 'pineapple, ananas',
955
+ 954: 'banana',
956
+ 955: 'jackfruit, jak, jack',
957
+ 956: 'custard apple',
958
+ 957: 'pomegranate',
959
+ 958: 'hay',
960
+ 959: 'carbonara',
961
+ 960: 'chocolate sauce, chocolate syrup',
962
+ 961: 'dough',
963
+ 962: 'meat loaf, meatloaf',
964
+ 963: 'pizza, pizza pie',
965
+ 964: 'potpie',
966
+ 965: 'burrito',
967
+ 966: 'red wine',
968
+ 967: 'espresso',
969
+ 968: 'cup',
970
+ 969: 'eggnog',
971
+ 970: 'alp',
972
+ 971: 'bubble',
973
+ 972: 'cliff, drop, drop-off',
974
+ 973: 'coral reef',
975
+ 974: 'geyser',
976
+ 975: 'lakeside, lakeshore',
977
+ 976: 'promontory, headland, head, foreland',
978
+ 977: 'sandbar, sand bar',
979
+ 978: 'seashore, coast, seacoast, sea-coast',
980
+ 979: 'valley, vale',
981
+ 980: 'volcano',
982
+ 981: 'ballplayer, baseball player',
983
+ 982: 'groom, bridegroom',
984
+ 983: 'scuba diver',
985
+ 984: 'rapeseed',
986
+ 985: 'daisy',
987
+ 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
988
+ 987: 'corn',
989
+ 988: 'acorn',
990
+ 989: 'hip, rose hip, rosehip',
991
+ 990: 'buckeye, horse chestnut, conker',
992
+ 991: 'coral fungus',
993
+ 992: 'agaric',
994
+ 993: 'gyromitra',
995
+ 994: 'stinkhorn, carrion fungus',
996
+ 995: 'earthstar',
997
+ 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
998
+ 997: 'bolete',
999
+ 998: 'ear, spike, capitulum',
1000
+ 999: 'toilet tissue, toilet paper, bathroom tissue'}
README.md CHANGED
@@ -1,13 +1,124 @@
1
- ---
2
- title: RobustViT
3
- emoji:
4
- colorFrom: red
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 3.0.11
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RobustViT
2
+
3
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hila-chefer/RobustViT/blob/master/RobustViT.ipynb)
4
+
5
+ Official PyTorch implementation of **Optimizing Relevance Maps of Vision Transformers Improves Robustness**. This code allows to
6
+ finetune the explainability maps of Vision Transformers to enhance robustness.
7
+
8
+ The method employs loss functions directly to the explainability maps to ensure that the model is focused mostly on the foreground of the image:
9
+ <p align="center">
10
+ <img width="500" height="400" src="teaser.png">
11
+ </p>
12
+ Using a short finetuning process with only 3 labeled examples from 500 classes, our method imrpoves robustness of ViT models across different model sizes and training techniques, even when data augmentations/ regularization are applied.
13
+
14
+ ## Producing Segmenataion Data
15
+ ### Using ImageNet-S
16
+ To use the ImageNet-S labeled data, [download the `ImageNetS919` dataset](https://github.com/UnsupervisedSemanticSegmentation/ImageNet-S)
17
+
18
+ ### Using TokenCut for unsupervised segmentation
19
+ 1. Clone the TokenCut project
20
+ ```
21
+ git clone https://github.com/YangtaoWANG95/TokenCut.git
22
+ ```
23
+ 2. Install the dependencies
24
+ Python 3.7, PyTorch 1.7.1 and CUDA 11.2. Please refer to the official installation. If CUDA 10.2 has been properly installed:
25
+ ```
26
+ pip install torch==1.7.1 torchvision==0.8.2
27
+ ```
28
+ Followed by
29
+ ```
30
+ pip install -r TokenCut/requirements.txt
31
+
32
+ 3. Use the following command to extract the segmentation maps:
33
+ ```
34
+ python tokencut_generate_segmentation.py --img_path <PATH_TO_IMAGE> --out_dir <PATH_TO_OUTPUT_DIRECTORY>
35
+ ```
36
+
37
+
38
+ ## Finetuning ViT models
39
+
40
+ To finetune a pretrained ViT model use the `imagenet_finetune.py` script. Notice to uncomment the import line containing the pretrained model you
41
+ wish to finetune.
42
+
43
+ Usage example:
44
+
45
+ ```bash
46
+ python imagenet_finetune.py --seg_data <PATH_TO_SEGMENTATION_DATA> --data <PATH_TO_IMAGENET> --gpu 0 --lr <LR> --lambda_seg <SEG> --lambda_acc <ACC> --lambda_background <BACK> --lambda_foreground <FORE>
47
+ ```
48
+
49
+ Notes:
50
+
51
+ * For all models we use :
52
+ * `lambda_seg=0.8`
53
+ * `lambda_acc=0.2`
54
+ * `lambda_background=2`
55
+ * `lambda_foreground=0.3`
56
+ * For **DeiT** models, a temprature is required as follows:
57
+ * `temprature=0.65` for DeiT-B
58
+ * `temprature=0.55` for DeiT-S
59
+ * The learning rates per model are:
60
+ * ViT-B: 3e-6
61
+ * ViT-L: 9e-7
62
+ * AR-S: 2e-6
63
+ * AR-B: 6e-7
64
+ * AR-L: 9e-7
65
+ * DeiT-S: 1e-6
66
+ * DeiT-B: 8e-7
67
+
68
+ ## Baseline methods
69
+ Notice to uncomment the import line containing the pretrained model you wish to finetune in the code.
70
+
71
+ ### GradMask
72
+ Run the following command:
73
+ ```bash
74
+ python imagenet_finetune_gradmask.py --seg_data <PATH_TO_SEGMENTATION_DATA> --data <PATH_TO_IMAGENET> --gpu 0 --lr <LR> --lambda_seg <SEG> --lambda_acc <ACC>
75
+ ```
76
+ All hyperparameters for the different models can be found in section D of the supplementary material.
77
+
78
+ ### Right for the Right Reasons
79
+ Run the following command:
80
+ ```bash
81
+ python imagenet_finetune_rrr.py --seg_data <PATH_TO_SEGMENTATION_DATA> --data <PATH_TO_IMAGENET> --gpu 0 --lr <LR> --lambda_seg <SEG> --lambda_acc <ACC>
82
+ ```
83
+ All hyperparameters for the different models can be found in section D of the supplementary material.
84
+
85
+ ## Evaluation
86
+
87
+ ### Robustness Evaluation
88
+
89
+ 1. Download the evaluation datasets:
90
+ * [INet-A](https://github.com/hendrycks/natural-adv-examples)
91
+ * [INet-R](https://github.com/hendrycks/imagenet-r)
92
+ * [INet-v2](https://github.com/modestyachts/ImageNetV2)
93
+ * [ObjectNet](https://objectnet.dev/)
94
+ * [SI-Score](https://github.com/google-research/si-score)
95
+
96
+ 2. Run the following script to evaluate:
97
+
98
+ ```bash
99
+ python imagenet_eval_robustness.py --data <PATH_TO_ROBUSTNESS_DATASET> --batch-size <BATCH_SIZE> --evaluate --checkpoint <PATH_TO_FINETUNED_CHECKPOINT>
100
+ ```
101
+ * Notice to uncomment the import line containing the pretrained model you wish to evaluate in the code.
102
+ * To evaluate the original model simply omit the `checkpoint` parameter.
103
+ * For the INet-v2 dataset add `--isV2`.
104
+ * For the ObjectNet dataset add `--isObjectNet`.
105
+ * For the SI datasets add `--isSI`.
106
+
107
+ ### Segmentation Evaluation
108
+ Our segmentation tests are based on the test in the official implementation of [Transformer Interpretability Beyond Attention Visualization](https://github.com/hila-chefer/Transformer-Explainability).
109
+ 1. [Download the ImageNet segmentation test set](https://github.com/hila-chefer/Transformer-Explainability#section-a-segmentation-results).
110
+ 2. Run the following script to evaluate:
111
+
112
+ ```bash
113
+ PYTHONPATH=./:$PYTHONPATH python SegmentationTest/imagenet_seg_eval.py --imagenet-seg-path <PATH_TO_gtsegs_ijcv.mat>
114
+ ```
115
+ * Notice to uncomment the import line containing the pretrained model you wish to evaluate in the code.
116
+
117
+ ### Credits
118
+ * The TokenCut code is built on top of [LOST](https://github.com/valeoai/LOST), [DINO](https://github.com/facebookresearch/dino), [Segswap](https://github.com/XiSHEN0220/SegSwap), and [Bilateral_Sovlver](https://github.com/poolio/bilateral_solver).
119
+ * Our ViT code is based on the [pytorch-image-models](https://github.com/rwightman/pytorch-image-models) repository.
120
+ * Our ImageNet finetuning code is based on [code from the official PyTorch repo](https://github.com/pytorch/examples/blob/main/imagenet/main.py).
121
+ * The code to convert ObjectNet classes to ImageNet classes was taken from [the torchprune repo](https://github.com/lucaslie/torchprune/blob/b753745b773c3ed259bf819d193ce8573d89efbb/src/torchprune/torchprune/util/datasets/objectnet.py).
122
+ * The code to convert SI-Score classes to ImageNet classes was taken from [the official implementation](https://github.com/google-research/si-score).
123
+
124
+ We would like to sincerely thank the authors for their great works.
RobustViT.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
SegmentationTest/data/Imagenet.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.utils.data as data
4
+ import numpy as np
5
+
6
+ from PIL import Image
7
+ import h5py
8
+
9
+ __all__ = ['ImagenetResults']
10
+
11
+
12
+ class Imagenet_Segmentation(data.Dataset):
13
+ CLASSES = 2
14
+
15
+ def __init__(self,
16
+ path,
17
+ transform=None,
18
+ target_transform=None):
19
+ self.path = path
20
+ self.transform = transform
21
+ self.target_transform = target_transform
22
+ self.h5py = None
23
+ tmp = h5py.File(path, 'r')
24
+ self.data_length = len(tmp['/value/img'])
25
+ tmp.close()
26
+ del tmp
27
+
28
+ def __getitem__(self, index):
29
+
30
+ if self.h5py is None:
31
+ self.h5py = h5py.File(self.path, 'r')
32
+
33
+ img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0))
34
+ target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0))
35
+
36
+ img = Image.fromarray(img).convert('RGB')
37
+ target = Image.fromarray(target)
38
+
39
+ if self.transform is not None:
40
+ img = self.transform(img)
41
+
42
+ if self.target_transform is not None:
43
+ target = np.array(self.target_transform(target)).astype('int32')
44
+ target = torch.from_numpy(target).long()
45
+
46
+ return img, target
47
+
48
+ def __len__(self):
49
+ return self.data_length
50
+
51
+
52
+ class ImagenetResults(data.Dataset):
53
+ def __init__(self, path):
54
+ super(ImagenetResults, self).__init__()
55
+
56
+ self.path = os.path.join(path, 'results.hdf5')
57
+ self.data = None
58
+
59
+ print('Reading dataset length...')
60
+ with h5py.File(self.path, 'r') as f:
61
+ self.data_length = len(f['/image'])
62
+
63
+ def __len__(self):
64
+ return self.data_length
65
+
66
+ def __getitem__(self, item):
67
+ if self.data is None:
68
+ self.data = h5py.File(self.path, 'r')
69
+
70
+ image = torch.tensor(self.data['image'][item])
71
+ vis = torch.tensor(self.data['vis'][item])
72
+ target = torch.tensor(self.data['target'][item]).long()
73
+
74
+ return image, vis, target
SegmentationTest/data/VOC.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tarfile
3
+ import torch
4
+ import torch.utils.data as data
5
+ import numpy as np
6
+ import h5py
7
+
8
+ from PIL import Image
9
+ from scipy import io
10
+ from torchvision.datasets.utils import download_url
11
+
12
+ DATASET_YEAR_DICT = {
13
+ '2012': {
14
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
15
+ 'filename': 'VOCtrainval_11-May-2012.tar',
16
+ 'md5': '6cd6e144f989b92b3379bac3b3de84fd',
17
+ 'base_dir': 'VOCdevkit/VOC2012'
18
+ },
19
+ '2011': {
20
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
21
+ 'filename': 'VOCtrainval_25-May-2011.tar',
22
+ 'md5': '6c3384ef61512963050cb5d687e5bf1e',
23
+ 'base_dir': 'TrainVal/VOCdevkit/VOC2011'
24
+ },
25
+ '2010': {
26
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
27
+ 'filename': 'VOCtrainval_03-May-2010.tar',
28
+ 'md5': 'da459979d0c395079b5c75ee67908abb',
29
+ 'base_dir': 'VOCdevkit/VOC2010'
30
+ },
31
+ '2009': {
32
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
33
+ 'filename': 'VOCtrainval_11-May-2009.tar',
34
+ 'md5': '59065e4b188729180974ef6572f6a212',
35
+ 'base_dir': 'VOCdevkit/VOC2009'
36
+ },
37
+ '2008': {
38
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
39
+ 'filename': 'VOCtrainval_11-May-2012.tar',
40
+ 'md5': '2629fa636546599198acfcfbfcf1904a',
41
+ 'base_dir': 'VOCdevkit/VOC2008'
42
+ },
43
+ '2007': {
44
+ 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
45
+ 'filename': 'VOCtrainval_06-Nov-2007.tar',
46
+ 'md5': 'c52e279531787c972589f7e41ab4ae64',
47
+ 'base_dir': 'VOCdevkit/VOC2007'
48
+ }
49
+ }
50
+
51
+
52
+ class VOCSegmentation(data.Dataset):
53
+ """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
54
+
55
+ Args:
56
+ root (string): Root directory of the VOC Dataset.
57
+ year (string, optional): The dataset year, supports years 2007 to 2012.
58
+ image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
59
+ download (bool, optional): If true, downloads the dataset from the internet and
60
+ puts it in root directory. If dataset is already downloaded, it is not
61
+ downloaded again.
62
+ transform (callable, optional): A function/transform that takes in an PIL image
63
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
64
+ target_transform (callable, optional): A function/transform that takes in the
65
+ target and transforms it.
66
+ """
67
+
68
+ CLASSES = 20
69
+ CLASSES_NAMES = [
70
+ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
71
+ 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
72
+ 'motorbike', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
73
+ 'tvmonitor', 'ambigious'
74
+ ]
75
+
76
+ def __init__(self,
77
+ root,
78
+ year='2012',
79
+ image_set='train',
80
+ download=False,
81
+ transform=None,
82
+ target_transform=None):
83
+ self.root = os.path.expanduser(root)
84
+ self.year = year
85
+ self.url = DATASET_YEAR_DICT[year]['url']
86
+ self.filename = DATASET_YEAR_DICT[year]['filename']
87
+ self.md5 = DATASET_YEAR_DICT[year]['md5']
88
+ self.transform = transform
89
+ self.target_transform = target_transform
90
+ self.image_set = image_set
91
+ base_dir = DATASET_YEAR_DICT[year]['base_dir']
92
+ voc_root = os.path.join(self.root, base_dir)
93
+ image_dir = os.path.join(voc_root, 'JPEGImages')
94
+ mask_dir = os.path.join(voc_root, 'SegmentationClass')
95
+
96
+ if download:
97
+ download_extract(self.url, self.root, self.filename, self.md5)
98
+
99
+ if not os.path.isdir(voc_root):
100
+ raise RuntimeError('Dataset not found or corrupted.' +
101
+ ' You can use download=True to download it')
102
+
103
+ splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
104
+
105
+ split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
106
+
107
+ if not os.path.exists(split_f):
108
+ raise ValueError(
109
+ 'Wrong image_set entered! Please use image_set="train" '
110
+ 'or image_set="trainval" or image_set="val"')
111
+
112
+ with open(os.path.join(split_f), "r") as f:
113
+ file_names = [x.strip() for x in f.readlines()]
114
+
115
+ self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
116
+ self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
117
+ assert (len(self.images) == len(self.masks))
118
+
119
+ def __getitem__(self, index):
120
+ """
121
+ Args:
122
+ index (int): Index
123
+
124
+ Returns:
125
+ tuple: (image, target) where target is the image segmentation.
126
+ """
127
+ img = Image.open(self.images[index]).convert('RGB')
128
+ target = Image.open(self.masks[index])
129
+
130
+ if self.transform is not None:
131
+ img = self.transform(img)
132
+
133
+ if self.target_transform is not None:
134
+ target = np.array(self.target_transform(target)).astype('int32')
135
+ target[target == 255] = -1
136
+ target = torch.from_numpy(target).long()
137
+
138
+ return img, target
139
+
140
+ @staticmethod
141
+ def _mask_transform(mask):
142
+ target = np.array(mask).astype('int32')
143
+ target[target == 255] = -1
144
+ return torch.from_numpy(target).long()
145
+
146
+ def __len__(self):
147
+ return len(self.images)
148
+
149
+ @property
150
+ def pred_offset(self):
151
+ return 0
152
+
153
+
154
+ class VOCClassification(data.Dataset):
155
+ """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
156
+
157
+ Args:
158
+ root (string): Root directory of the VOC Dataset.
159
+ year (string, optional): The dataset year, supports years 2007 to 2012.
160
+ image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
161
+ download (bool, optional): If true, downloads the dataset from the internet and
162
+ puts it in root directory. If dataset is already downloaded, it is not
163
+ downloaded again.
164
+ transform (callable, optional): A function/transform that takes in an PIL image
165
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
166
+ """
167
+ CLASSES = 20
168
+
169
+ def __init__(self,
170
+ root,
171
+ year='2012',
172
+ image_set='train',
173
+ download=False,
174
+ transform=None):
175
+ self.root = os.path.expanduser(root)
176
+ self.year = year
177
+ self.url = DATASET_YEAR_DICT[year]['url']
178
+ self.filename = DATASET_YEAR_DICT[year]['filename']
179
+ self.md5 = DATASET_YEAR_DICT[year]['md5']
180
+ self.transform = transform
181
+ self.image_set = image_set
182
+ base_dir = DATASET_YEAR_DICT[year]['base_dir']
183
+ voc_root = os.path.join(self.root, base_dir)
184
+ image_dir = os.path.join(voc_root, 'JPEGImages')
185
+ mask_dir = os.path.join(voc_root, 'SegmentationClass')
186
+
187
+ if download:
188
+ download_extract(self.url, self.root, self.filename, self.md5)
189
+
190
+ if not os.path.isdir(voc_root):
191
+ raise RuntimeError('Dataset not found or corrupted.' +
192
+ ' You can use download=True to download it')
193
+
194
+ splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
195
+
196
+ split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
197
+
198
+ if not os.path.exists(split_f):
199
+ raise ValueError(
200
+ 'Wrong image_set entered! Please use image_set="train" '
201
+ 'or image_set="trainval" or image_set="val"')
202
+
203
+ with open(os.path.join(split_f), "r") as f:
204
+ file_names = [x.strip() for x in f.readlines()]
205
+
206
+ self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
207
+ self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
208
+ assert (len(self.images) == len(self.masks))
209
+
210
+ def __getitem__(self, index):
211
+ """
212
+ Args:
213
+ index (int): Index
214
+
215
+ Returns:
216
+ tuple: (image, target) where target is the image segmentation.
217
+ """
218
+ img = Image.open(self.images[index]).convert('RGB')
219
+ target = Image.open(self.masks[index])
220
+
221
+ # if self.transform is not None:
222
+ # img = self.transform(img)
223
+ if self.transform is not None:
224
+ img, target = self.transform(img, target)
225
+
226
+ visible_classes = np.unique(target)
227
+ labels = torch.zeros(self.CLASSES)
228
+ for id in visible_classes:
229
+ if id not in (0, 255):
230
+ labels[id - 1].fill_(1)
231
+
232
+ return img, labels
233
+
234
+ def __len__(self):
235
+ return len(self.images)
236
+
237
+
238
+ class VOCSBDClassification(data.Dataset):
239
+ """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
240
+
241
+ Args:
242
+ root (string): Root directory of the VOC Dataset.
243
+ year (string, optional): The dataset year, supports years 2007 to 2012.
244
+ image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
245
+ download (bool, optional): If true, downloads the dataset from the internet and
246
+ puts it in root directory. If dataset is already downloaded, it is not
247
+ downloaded again.
248
+ transform (callable, optional): A function/transform that takes in an PIL image
249
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
250
+ """
251
+ CLASSES = 20
252
+
253
+ def __init__(self,
254
+ root,
255
+ sbd_root,
256
+ year='2012',
257
+ image_set='train',
258
+ download=False,
259
+ transform=None):
260
+ self.root = os.path.expanduser(root)
261
+ self.sbd_root = os.path.expanduser(sbd_root)
262
+ self.year = year
263
+ self.url = DATASET_YEAR_DICT[year]['url']
264
+ self.filename = DATASET_YEAR_DICT[year]['filename']
265
+ self.md5 = DATASET_YEAR_DICT[year]['md5']
266
+ self.transform = transform
267
+ self.image_set = image_set
268
+ base_dir = DATASET_YEAR_DICT[year]['base_dir']
269
+ voc_root = os.path.join(self.root, base_dir)
270
+ image_dir = os.path.join(voc_root, 'JPEGImages')
271
+ mask_dir = os.path.join(voc_root, 'SegmentationClass')
272
+ sbd_image_dir = os.path.join(sbd_root, 'img')
273
+ sbd_mask_dir = os.path.join(sbd_root, 'cls')
274
+
275
+ if download:
276
+ download_extract(self.url, self.root, self.filename, self.md5)
277
+
278
+ if not os.path.isdir(voc_root):
279
+ raise RuntimeError('Dataset not found or corrupted.' +
280
+ ' You can use download=True to download it')
281
+
282
+ splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
283
+
284
+ split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
285
+ sbd_split = os.path.join(sbd_root, 'train.txt')
286
+
287
+ if not os.path.exists(split_f):
288
+ raise ValueError(
289
+ 'Wrong image_set entered! Please use image_set="train" '
290
+ 'or image_set="trainval" or image_set="val"')
291
+
292
+ with open(os.path.join(split_f), "r") as f:
293
+ voc_file_names = [x.strip() for x in f.readlines()]
294
+
295
+ with open(os.path.join(sbd_split), "r") as f:
296
+ sbd_file_names = [x.strip() for x in f.readlines()]
297
+
298
+ self.images = [os.path.join(image_dir, x + ".jpg") for x in voc_file_names]
299
+ self.images += [os.path.join(sbd_image_dir, x + ".jpg") for x in sbd_file_names]
300
+ self.masks = [os.path.join(mask_dir, x + ".png") for x in voc_file_names]
301
+ self.masks += [os.path.join(sbd_mask_dir, x + ".mat") for x in sbd_file_names]
302
+ assert (len(self.images) == len(self.masks))
303
+
304
+ def __getitem__(self, index):
305
+ """
306
+ Args:
307
+ index (int): Index
308
+
309
+ Returns:
310
+ tuple: (image, target) where target is the image segmentation.
311
+ """
312
+ img = Image.open(self.images[index]).convert('RGB')
313
+ mask_path = self.masks[index]
314
+ if mask_path[-3:] == 'mat':
315
+ target = io.loadmat(mask_path, struct_as_record=False, squeeze_me=True)['GTcls'].Segmentation
316
+ target = Image.fromarray(target, mode='P')
317
+ else:
318
+ target = Image.open(self.masks[index])
319
+
320
+ if self.transform is not None:
321
+ img, target = self.transform(img, target)
322
+
323
+ visible_classes = np.unique(target)
324
+ labels = torch.zeros(self.CLASSES)
325
+ for id in visible_classes:
326
+ if id not in (0, 255):
327
+ labels[id - 1].fill_(1)
328
+
329
+ return img, labels
330
+
331
+ def __len__(self):
332
+ return len(self.images)
333
+
334
+
335
+ def download_extract(url, root, filename, md5):
336
+ download_url(url, root, filename, md5)
337
+ with tarfile.open(os.path.join(root, filename), "r") as tar:
338
+ tar.extractall(path=root)
339
+
340
+
341
+ class VOCResults(data.Dataset):
342
+ CLASSES = 20
343
+ CLASSES_NAMES = [
344
+ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
345
+ 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
346
+ 'motorbike', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
347
+ 'tvmonitor', 'ambigious'
348
+ ]
349
+
350
+ def __init__(self, path):
351
+ super(VOCResults, self).__init__()
352
+
353
+ self.path = os.path.join(path, 'results.hdf5')
354
+ self.data = None
355
+
356
+ print('Reading dataset length...')
357
+ with h5py.File(self.path , 'r') as f:
358
+ self.data_length = len(f['/image'])
359
+
360
+ def __len__(self):
361
+ return self.data_length
362
+
363
+ def __getitem__(self, item):
364
+ if self.data is None:
365
+ self.data = h5py.File(self.path, 'r')
366
+
367
+ image = torch.tensor(self.data['image'][item])
368
+ vis = torch.tensor(self.data['vis'][item])
369
+ target = torch.tensor(self.data['target'][item])
370
+ class_pred = torch.tensor(self.data['class_pred'][item])
371
+
372
+ return image, vis, target, class_pred
SegmentationTest/data/__init__.py ADDED
File without changes
SegmentationTest/data/imagenet_utils.py ADDED
@@ -0,0 +1,1002 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CLS2IDX = {
2
+ 0: 'tench, Tinca tinca',
3
+ 1: 'goldfish, Carassius auratus',
4
+ 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
5
+ 3: 'tiger shark, Galeocerdo cuvieri',
6
+ 4: 'hammerhead, hammerhead shark',
7
+ 5: 'electric ray, crampfish, numbfish, torpedo',
8
+ 6: 'stingray',
9
+ 7: 'cock',
10
+ 8: 'hen',
11
+ 9: 'ostrich, Struthio camelus',
12
+ 10: 'brambling, Fringilla montifringilla',
13
+ 11: 'goldfinch, Carduelis carduelis',
14
+ 12: 'house finch, linnet, Carpodacus mexicanus',
15
+ 13: 'junco, snowbird',
16
+ 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
17
+ 15: 'robin, American robin, Turdus migratorius',
18
+ 16: 'bulbul',
19
+ 17: 'jay',
20
+ 18: 'magpie',
21
+ 19: 'chickadee',
22
+ 20: 'water ouzel, dipper',
23
+ 21: 'kite',
24
+ 22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
25
+ 23: 'vulture',
26
+ 24: 'great grey owl, great gray owl, Strix nebulosa',
27
+ 25: 'European fire salamander, Salamandra salamandra',
28
+ 26: 'common newt, Triturus vulgaris',
29
+ 27: 'eft',
30
+ 28: 'spotted salamander, Ambystoma maculatum',
31
+ 29: 'axolotl, mud puppy, Ambystoma mexicanum',
32
+ 30: 'bullfrog, Rana catesbeiana',
33
+ 31: 'tree frog, tree-frog',
34
+ 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
35
+ 33: 'loggerhead, loggerhead turtle, Caretta caretta',
36
+ 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
37
+ 35: 'mud turtle',
38
+ 36: 'terrapin',
39
+ 37: 'box turtle, box tortoise',
40
+ 38: 'banded gecko',
41
+ 39: 'common iguana, iguana, Iguana iguana',
42
+ 40: 'American chameleon, anole, Anolis carolinensis',
43
+ 41: 'whiptail, whiptail lizard',
44
+ 42: 'agama',
45
+ 43: 'frilled lizard, Chlamydosaurus kingi',
46
+ 44: 'alligator lizard',
47
+ 45: 'Gila monster, Heloderma suspectum',
48
+ 46: 'green lizard, Lacerta viridis',
49
+ 47: 'African chameleon, Chamaeleo chamaeleon',
50
+ 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
51
+ 49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
52
+ 50: 'American alligator, Alligator mississipiensis',
53
+ 51: 'triceratops',
54
+ 52: 'thunder snake, worm snake, Carphophis amoenus',
55
+ 53: 'ringneck snake, ring-necked snake, ring snake',
56
+ 54: 'hognose snake, puff adder, sand viper',
57
+ 55: 'green snake, grass snake',
58
+ 56: 'king snake, kingsnake',
59
+ 57: 'garter snake, grass snake',
60
+ 58: 'water snake',
61
+ 59: 'vine snake',
62
+ 60: 'night snake, Hypsiglena torquata',
63
+ 61: 'boa constrictor, Constrictor constrictor',
64
+ 62: 'rock python, rock snake, Python sebae',
65
+ 63: 'Indian cobra, Naja naja',
66
+ 64: 'green mamba',
67
+ 65: 'sea snake',
68
+ 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
69
+ 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
70
+ 68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
71
+ 69: 'trilobite',
72
+ 70: 'harvestman, daddy longlegs, Phalangium opilio',
73
+ 71: 'scorpion',
74
+ 72: 'black and gold garden spider, Argiope aurantia',
75
+ 73: 'barn spider, Araneus cavaticus',
76
+ 74: 'garden spider, Aranea diademata',
77
+ 75: 'black widow, Latrodectus mactans',
78
+ 76: 'tarantula',
79
+ 77: 'wolf spider, hunting spider',
80
+ 78: 'tick',
81
+ 79: 'centipede',
82
+ 80: 'black grouse',
83
+ 81: 'ptarmigan',
84
+ 82: 'ruffed grouse, partridge, Bonasa umbellus',
85
+ 83: 'prairie chicken, prairie grouse, prairie fowl',
86
+ 84: 'peacock',
87
+ 85: 'quail',
88
+ 86: 'partridge',
89
+ 87: 'African grey, African gray, Psittacus erithacus',
90
+ 88: 'macaw',
91
+ 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
92
+ 90: 'lorikeet',
93
+ 91: 'coucal',
94
+ 92: 'bee eater',
95
+ 93: 'hornbill',
96
+ 94: 'hummingbird',
97
+ 95: 'jacamar',
98
+ 96: 'toucan',
99
+ 97: 'drake',
100
+ 98: 'red-breasted merganser, Mergus serrator',
101
+ 99: 'goose',
102
+ 100: 'black swan, Cygnus atratus',
103
+ 101: 'tusker',
104
+ 102: 'echidna, spiny anteater, anteater',
105
+ 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
106
+ 104: 'wallaby, brush kangaroo',
107
+ 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
108
+ 106: 'wombat',
109
+ 107: 'jellyfish',
110
+ 108: 'sea anemone, anemone',
111
+ 109: 'brain coral',
112
+ 110: 'flatworm, platyhelminth',
113
+ 111: 'nematode, nematode worm, roundworm',
114
+ 112: 'conch',
115
+ 113: 'snail',
116
+ 114: 'slug',
117
+ 115: 'sea slug, nudibranch',
118
+ 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
119
+ 117: 'chambered nautilus, pearly nautilus, nautilus',
120
+ 118: 'Dungeness crab, Cancer magister',
121
+ 119: 'rock crab, Cancer irroratus',
122
+ 120: 'fiddler crab',
123
+ 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
124
+ 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
125
+ 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
126
+ 124: 'crayfish, crawfish, crawdad, crawdaddy',
127
+ 125: 'hermit crab',
128
+ 126: 'isopod',
129
+ 127: 'white stork, Ciconia ciconia',
130
+ 128: 'black stork, Ciconia nigra',
131
+ 129: 'spoonbill',
132
+ 130: 'flamingo',
133
+ 131: 'little blue heron, Egretta caerulea',
134
+ 132: 'American egret, great white heron, Egretta albus',
135
+ 133: 'bittern',
136
+ 134: 'crane',
137
+ 135: 'limpkin, Aramus pictus',
138
+ 136: 'European gallinule, Porphyrio porphyrio',
139
+ 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
140
+ 138: 'bustard',
141
+ 139: 'ruddy turnstone, Arenaria interpres',
142
+ 140: 'red-backed sandpiper, dunlin, Erolia alpina',
143
+ 141: 'redshank, Tringa totanus',
144
+ 142: 'dowitcher',
145
+ 143: 'oystercatcher, oyster catcher',
146
+ 144: 'pelican',
147
+ 145: 'king penguin, Aptenodytes patagonica',
148
+ 146: 'albatross, mollymawk',
149
+ 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
150
+ 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
151
+ 149: 'dugong, Dugong dugon',
152
+ 150: 'sea lion',
153
+ 151: 'Chihuahua',
154
+ 152: 'Japanese spaniel',
155
+ 153: 'Maltese dog, Maltese terrier, Maltese',
156
+ 154: 'Pekinese, Pekingese, Peke',
157
+ 155: 'Shih-Tzu',
158
+ 156: 'Blenheim spaniel',
159
+ 157: 'papillon',
160
+ 158: 'toy terrier',
161
+ 159: 'Rhodesian ridgeback',
162
+ 160: 'Afghan hound, Afghan',
163
+ 161: 'basset, basset hound',
164
+ 162: 'beagle',
165
+ 163: 'bloodhound, sleuthhound',
166
+ 164: 'bluetick',
167
+ 165: 'black-and-tan coonhound',
168
+ 166: 'Walker hound, Walker foxhound',
169
+ 167: 'English foxhound',
170
+ 168: 'redbone',
171
+ 169: 'borzoi, Russian wolfhound',
172
+ 170: 'Irish wolfhound',
173
+ 171: 'Italian greyhound',
174
+ 172: 'whippet',
175
+ 173: 'Ibizan hound, Ibizan Podenco',
176
+ 174: 'Norwegian elkhound, elkhound',
177
+ 175: 'otterhound, otter hound',
178
+ 176: 'Saluki, gazelle hound',
179
+ 177: 'Scottish deerhound, deerhound',
180
+ 178: 'Weimaraner',
181
+ 179: 'Staffordshire bullterrier, Staffordshire bull terrier',
182
+ 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
183
+ 181: 'Bedlington terrier',
184
+ 182: 'Border terrier',
185
+ 183: 'Kerry blue terrier',
186
+ 184: 'Irish terrier',
187
+ 185: 'Norfolk terrier',
188
+ 186: 'Norwich terrier',
189
+ 187: 'Yorkshire terrier',
190
+ 188: 'wire-haired fox terrier',
191
+ 189: 'Lakeland terrier',
192
+ 190: 'Sealyham terrier, Sealyham',
193
+ 191: 'Airedale, Airedale terrier',
194
+ 192: 'cairn, cairn terrier',
195
+ 193: 'Australian terrier',
196
+ 194: 'Dandie Dinmont, Dandie Dinmont terrier',
197
+ 195: 'Boston bull, Boston terrier',
198
+ 196: 'miniature schnauzer',
199
+ 197: 'giant schnauzer',
200
+ 198: 'standard schnauzer',
201
+ 199: 'Scotch terrier, Scottish terrier, Scottie',
202
+ 200: 'Tibetan terrier, chrysanthemum dog',
203
+ 201: 'silky terrier, Sydney silky',
204
+ 202: 'soft-coated wheaten terrier',
205
+ 203: 'West Highland white terrier',
206
+ 204: 'Lhasa, Lhasa apso',
207
+ 205: 'flat-coated retriever',
208
+ 206: 'curly-coated retriever',
209
+ 207: 'golden retriever',
210
+ 208: 'Labrador retriever',
211
+ 209: 'Chesapeake Bay retriever',
212
+ 210: 'German short-haired pointer',
213
+ 211: 'vizsla, Hungarian pointer',
214
+ 212: 'English setter',
215
+ 213: 'Irish setter, red setter',
216
+ 214: 'Gordon setter',
217
+ 215: 'Brittany spaniel',
218
+ 216: 'clumber, clumber spaniel',
219
+ 217: 'English springer, English springer spaniel',
220
+ 218: 'Welsh springer spaniel',
221
+ 219: 'cocker spaniel, English cocker spaniel, cocker',
222
+ 220: 'Sussex spaniel',
223
+ 221: 'Irish water spaniel',
224
+ 222: 'kuvasz',
225
+ 223: 'schipperke',
226
+ 224: 'groenendael',
227
+ 225: 'malinois',
228
+ 226: 'briard',
229
+ 227: 'kelpie',
230
+ 228: 'komondor',
231
+ 229: 'Old English sheepdog, bobtail',
232
+ 230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
233
+ 231: 'collie',
234
+ 232: 'Border collie',
235
+ 233: 'Bouvier des Flandres, Bouviers des Flandres',
236
+ 234: 'Rottweiler',
237
+ 235: 'German shepherd, German shepherd dog, German police dog, alsatian',
238
+ 236: 'Doberman, Doberman pinscher',
239
+ 237: 'miniature pinscher',
240
+ 238: 'Greater Swiss Mountain dog',
241
+ 239: 'Bernese mountain dog',
242
+ 240: 'Appenzeller',
243
+ 241: 'EntleBucher',
244
+ 242: 'boxer',
245
+ 243: 'bull mastiff',
246
+ 244: 'Tibetan mastiff',
247
+ 245: 'French bulldog',
248
+ 246: 'Great Dane',
249
+ 247: 'Saint Bernard, St Bernard',
250
+ 248: 'Eskimo dog, husky',
251
+ 249: 'malamute, malemute, Alaskan malamute',
252
+ 250: 'Siberian husky',
253
+ 251: 'dalmatian, coach dog, carriage dog',
254
+ 252: 'affenpinscher, monkey pinscher, monkey dog',
255
+ 253: 'basenji',
256
+ 254: 'pug, pug-dog',
257
+ 255: 'Leonberg',
258
+ 256: 'Newfoundland, Newfoundland dog',
259
+ 257: 'Great Pyrenees',
260
+ 258: 'Samoyed, Samoyede',
261
+ 259: 'Pomeranian',
262
+ 260: 'chow, chow chow',
263
+ 261: 'keeshond',
264
+ 262: 'Brabancon griffon',
265
+ 263: 'Pembroke, Pembroke Welsh corgi',
266
+ 264: 'Cardigan, Cardigan Welsh corgi',
267
+ 265: 'toy poodle',
268
+ 266: 'miniature poodle',
269
+ 267: 'standard poodle',
270
+ 268: 'Mexican hairless',
271
+ 269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
272
+ 270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
273
+ 271: 'red wolf, maned wolf, Canis rufus, Canis niger',
274
+ 272: 'coyote, prairie wolf, brush wolf, Canis latrans',
275
+ 273: 'dingo, warrigal, warragal, Canis dingo',
276
+ 274: 'dhole, Cuon alpinus',
277
+ 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
278
+ 276: 'hyena, hyaena',
279
+ 277: 'red fox, Vulpes vulpes',
280
+ 278: 'kit fox, Vulpes macrotis',
281
+ 279: 'Arctic fox, white fox, Alopex lagopus',
282
+ 280: 'grey fox, gray fox, Urocyon cinereoargenteus',
283
+ 281: 'tabby, tabby cat',
284
+ 282: 'tiger cat',
285
+ 283: 'Persian cat',
286
+ 284: 'Siamese cat, Siamese',
287
+ 285: 'Egyptian cat',
288
+ 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
289
+ 287: 'lynx, catamount',
290
+ 288: 'leopard, Panthera pardus',
291
+ 289: 'snow leopard, ounce, Panthera uncia',
292
+ 290: 'jaguar, panther, Panthera onca, Felis onca',
293
+ 291: 'lion, king of beasts, Panthera leo',
294
+ 292: 'tiger, Panthera tigris',
295
+ 293: 'cheetah, chetah, Acinonyx jubatus',
296
+ 294: 'brown bear, bruin, Ursus arctos',
297
+ 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
298
+ 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
299
+ 297: 'sloth bear, Melursus ursinus, Ursus ursinus',
300
+ 298: 'mongoose',
301
+ 299: 'meerkat, mierkat',
302
+ 300: 'tiger beetle',
303
+ 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
304
+ 302: 'ground beetle, carabid beetle',
305
+ 303: 'long-horned beetle, longicorn, longicorn beetle',
306
+ 304: 'leaf beetle, chrysomelid',
307
+ 305: 'dung beetle',
308
+ 306: 'rhinoceros beetle',
309
+ 307: 'weevil',
310
+ 308: 'fly',
311
+ 309: 'bee',
312
+ 310: 'ant, emmet, pismire',
313
+ 311: 'grasshopper, hopper',
314
+ 312: 'cricket',
315
+ 313: 'walking stick, walkingstick, stick insect',
316
+ 314: 'cockroach, roach',
317
+ 315: 'mantis, mantid',
318
+ 316: 'cicada, cicala',
319
+ 317: 'leafhopper',
320
+ 318: 'lacewing, lacewing fly',
321
+ 319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
322
+ 320: 'damselfly',
323
+ 321: 'admiral',
324
+ 322: 'ringlet, ringlet butterfly',
325
+ 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
326
+ 324: 'cabbage butterfly',
327
+ 325: 'sulphur butterfly, sulfur butterfly',
328
+ 326: 'lycaenid, lycaenid butterfly',
329
+ 327: 'starfish, sea star',
330
+ 328: 'sea urchin',
331
+ 329: 'sea cucumber, holothurian',
332
+ 330: 'wood rabbit, cottontail, cottontail rabbit',
333
+ 331: 'hare',
334
+ 332: 'Angora, Angora rabbit',
335
+ 333: 'hamster',
336
+ 334: 'porcupine, hedgehog',
337
+ 335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
338
+ 336: 'marmot',
339
+ 337: 'beaver',
340
+ 338: 'guinea pig, Cavia cobaya',
341
+ 339: 'sorrel',
342
+ 340: 'zebra',
343
+ 341: 'hog, pig, grunter, squealer, Sus scrofa',
344
+ 342: 'wild boar, boar, Sus scrofa',
345
+ 343: 'warthog',
346
+ 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
347
+ 345: 'ox',
348
+ 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
349
+ 347: 'bison',
350
+ 348: 'ram, tup',
351
+ 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
352
+ 350: 'ibex, Capra ibex',
353
+ 351: 'hartebeest',
354
+ 352: 'impala, Aepyceros melampus',
355
+ 353: 'gazelle',
356
+ 354: 'Arabian camel, dromedary, Camelus dromedarius',
357
+ 355: 'llama',
358
+ 356: 'weasel',
359
+ 357: 'mink',
360
+ 358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
361
+ 359: 'black-footed ferret, ferret, Mustela nigripes',
362
+ 360: 'otter',
363
+ 361: 'skunk, polecat, wood pussy',
364
+ 362: 'badger',
365
+ 363: 'armadillo',
366
+ 364: 'three-toed sloth, ai, Bradypus tridactylus',
367
+ 365: 'orangutan, orang, orangutang, Pongo pygmaeus',
368
+ 366: 'gorilla, Gorilla gorilla',
369
+ 367: 'chimpanzee, chimp, Pan troglodytes',
370
+ 368: 'gibbon, Hylobates lar',
371
+ 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
372
+ 370: 'guenon, guenon monkey',
373
+ 371: 'patas, hussar monkey, Erythrocebus patas',
374
+ 372: 'baboon',
375
+ 373: 'macaque',
376
+ 374: 'langur',
377
+ 375: 'colobus, colobus monkey',
378
+ 376: 'proboscis monkey, Nasalis larvatus',
379
+ 377: 'marmoset',
380
+ 378: 'capuchin, ringtail, Cebus capucinus',
381
+ 379: 'howler monkey, howler',
382
+ 380: 'titi, titi monkey',
383
+ 381: 'spider monkey, Ateles geoffroyi',
384
+ 382: 'squirrel monkey, Saimiri sciureus',
385
+ 383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
386
+ 384: 'indri, indris, Indri indri, Indri brevicaudatus',
387
+ 385: 'Indian elephant, Elephas maximus',
388
+ 386: 'African elephant, Loxodonta africana',
389
+ 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
390
+ 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
391
+ 389: 'barracouta, snoek',
392
+ 390: 'eel',
393
+ 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
394
+ 392: 'rock beauty, Holocanthus tricolor',
395
+ 393: 'anemone fish',
396
+ 394: 'sturgeon',
397
+ 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
398
+ 396: 'lionfish',
399
+ 397: 'puffer, pufferfish, blowfish, globefish',
400
+ 398: 'abacus',
401
+ 399: 'abaya',
402
+ 400: "academic gown, academic robe, judge's robe",
403
+ 401: 'accordion, piano accordion, squeeze box',
404
+ 402: 'acoustic guitar',
405
+ 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
406
+ 404: 'airliner',
407
+ 405: 'airship, dirigible',
408
+ 406: 'altar',
409
+ 407: 'ambulance',
410
+ 408: 'amphibian, amphibious vehicle',
411
+ 409: 'analog clock',
412
+ 410: 'apiary, bee house',
413
+ 411: 'apron',
414
+ 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
415
+ 413: 'assault rifle, assault gun',
416
+ 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
417
+ 415: 'bakery, bakeshop, bakehouse',
418
+ 416: 'balance beam, beam',
419
+ 417: 'balloon',
420
+ 418: 'ballpoint, ballpoint pen, ballpen, Biro',
421
+ 419: 'Band Aid',
422
+ 420: 'banjo',
423
+ 421: 'bannister, banister, balustrade, balusters, handrail',
424
+ 422: 'barbell',
425
+ 423: 'barber chair',
426
+ 424: 'barbershop',
427
+ 425: 'barn',
428
+ 426: 'barometer',
429
+ 427: 'barrel, cask',
430
+ 428: 'barrow, garden cart, lawn cart, wheelbarrow',
431
+ 429: 'baseball',
432
+ 430: 'basketball',
433
+ 431: 'bassinet',
434
+ 432: 'bassoon',
435
+ 433: 'bathing cap, swimming cap',
436
+ 434: 'bath towel',
437
+ 435: 'bathtub, bathing tub, bath, tub',
438
+ 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
439
+ 437: 'beacon, lighthouse, beacon light, pharos',
440
+ 438: 'beaker',
441
+ 439: 'bearskin, busby, shako',
442
+ 440: 'beer bottle',
443
+ 441: 'beer glass',
444
+ 442: 'bell cote, bell cot',
445
+ 443: 'bib',
446
+ 444: 'bicycle-built-for-two, tandem bicycle, tandem',
447
+ 445: 'bikini, two-piece',
448
+ 446: 'binder, ring-binder',
449
+ 447: 'binoculars, field glasses, opera glasses',
450
+ 448: 'birdhouse',
451
+ 449: 'boathouse',
452
+ 450: 'bobsled, bobsleigh, bob',
453
+ 451: 'bolo tie, bolo, bola tie, bola',
454
+ 452: 'bonnet, poke bonnet',
455
+ 453: 'bookcase',
456
+ 454: 'bookshop, bookstore, bookstall',
457
+ 455: 'bottlecap',
458
+ 456: 'bow',
459
+ 457: 'bow tie, bow-tie, bowtie',
460
+ 458: 'brass, memorial tablet, plaque',
461
+ 459: 'brassiere, bra, bandeau',
462
+ 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
463
+ 461: 'breastplate, aegis, egis',
464
+ 462: 'broom',
465
+ 463: 'bucket, pail',
466
+ 464: 'buckle',
467
+ 465: 'bulletproof vest',
468
+ 466: 'bullet train, bullet',
469
+ 467: 'butcher shop, meat market',
470
+ 468: 'cab, hack, taxi, taxicab',
471
+ 469: 'caldron, cauldron',
472
+ 470: 'candle, taper, wax light',
473
+ 471: 'cannon',
474
+ 472: 'canoe',
475
+ 473: 'can opener, tin opener',
476
+ 474: 'cardigan',
477
+ 475: 'car mirror',
478
+ 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
479
+ 477: "carpenter's kit, tool kit",
480
+ 478: 'carton',
481
+ 479: 'car wheel',
482
+ 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
483
+ 481: 'cassette',
484
+ 482: 'cassette player',
485
+ 483: 'castle',
486
+ 484: 'catamaran',
487
+ 485: 'CD player',
488
+ 486: 'cello, violoncello',
489
+ 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
490
+ 488: 'chain',
491
+ 489: 'chainlink fence',
492
+ 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
493
+ 491: 'chain saw, chainsaw',
494
+ 492: 'chest',
495
+ 493: 'chiffonier, commode',
496
+ 494: 'chime, bell, gong',
497
+ 495: 'china cabinet, china closet',
498
+ 496: 'Christmas stocking',
499
+ 497: 'church, church building',
500
+ 498: 'cinema, movie theater, movie theatre, movie house, picture palace',
501
+ 499: 'cleaver, meat cleaver, chopper',
502
+ 500: 'cliff dwelling',
503
+ 501: 'cloak',
504
+ 502: 'clog, geta, patten, sabot',
505
+ 503: 'cocktail shaker',
506
+ 504: 'coffee mug',
507
+ 505: 'coffeepot',
508
+ 506: 'coil, spiral, volute, whorl, helix',
509
+ 507: 'combination lock',
510
+ 508: 'computer keyboard, keypad',
511
+ 509: 'confectionery, confectionary, candy store',
512
+ 510: 'container ship, containership, container vessel',
513
+ 511: 'convertible',
514
+ 512: 'corkscrew, bottle screw',
515
+ 513: 'cornet, horn, trumpet, trump',
516
+ 514: 'cowboy boot',
517
+ 515: 'cowboy hat, ten-gallon hat',
518
+ 516: 'cradle',
519
+ 517: 'crane',
520
+ 518: 'crash helmet',
521
+ 519: 'crate',
522
+ 520: 'crib, cot',
523
+ 521: 'Crock Pot',
524
+ 522: 'croquet ball',
525
+ 523: 'crutch',
526
+ 524: 'cuirass',
527
+ 525: 'dam, dike, dyke',
528
+ 526: 'desk',
529
+ 527: 'desktop computer',
530
+ 528: 'dial telephone, dial phone',
531
+ 529: 'diaper, nappy, napkin',
532
+ 530: 'digital clock',
533
+ 531: 'digital watch',
534
+ 532: 'dining table, board',
535
+ 533: 'dishrag, dishcloth',
536
+ 534: 'dishwasher, dish washer, dishwashing machine',
537
+ 535: 'disk brake, disc brake',
538
+ 536: 'dock, dockage, docking facility',
539
+ 537: 'dogsled, dog sled, dog sleigh',
540
+ 538: 'dome',
541
+ 539: 'doormat, welcome mat',
542
+ 540: 'drilling platform, offshore rig',
543
+ 541: 'drum, membranophone, tympan',
544
+ 542: 'drumstick',
545
+ 543: 'dumbbell',
546
+ 544: 'Dutch oven',
547
+ 545: 'electric fan, blower',
548
+ 546: 'electric guitar',
549
+ 547: 'electric locomotive',
550
+ 548: 'entertainment center',
551
+ 549: 'envelope',
552
+ 550: 'espresso maker',
553
+ 551: 'face powder',
554
+ 552: 'feather boa, boa',
555
+ 553: 'file, file cabinet, filing cabinet',
556
+ 554: 'fireboat',
557
+ 555: 'fire engine, fire truck',
558
+ 556: 'fire screen, fireguard',
559
+ 557: 'flagpole, flagstaff',
560
+ 558: 'flute, transverse flute',
561
+ 559: 'folding chair',
562
+ 560: 'football helmet',
563
+ 561: 'forklift',
564
+ 562: 'fountain',
565
+ 563: 'fountain pen',
566
+ 564: 'four-poster',
567
+ 565: 'freight car',
568
+ 566: 'French horn, horn',
569
+ 567: 'frying pan, frypan, skillet',
570
+ 568: 'fur coat',
571
+ 569: 'garbage truck, dustcart',
572
+ 570: 'gasmask, respirator, gas helmet',
573
+ 571: 'gas pump, gasoline pump, petrol pump, island dispenser',
574
+ 572: 'goblet',
575
+ 573: 'go-kart',
576
+ 574: 'golf ball',
577
+ 575: 'golfcart, golf cart',
578
+ 576: 'gondola',
579
+ 577: 'gong, tam-tam',
580
+ 578: 'gown',
581
+ 579: 'grand piano, grand',
582
+ 580: 'greenhouse, nursery, glasshouse',
583
+ 581: 'grille, radiator grille',
584
+ 582: 'grocery store, grocery, food market, market',
585
+ 583: 'guillotine',
586
+ 584: 'hair slide',
587
+ 585: 'hair spray',
588
+ 586: 'half track',
589
+ 587: 'hammer',
590
+ 588: 'hamper',
591
+ 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
592
+ 590: 'hand-held computer, hand-held microcomputer',
593
+ 591: 'handkerchief, hankie, hanky, hankey',
594
+ 592: 'hard disc, hard disk, fixed disk',
595
+ 593: 'harmonica, mouth organ, harp, mouth harp',
596
+ 594: 'harp',
597
+ 595: 'harvester, reaper',
598
+ 596: 'hatchet',
599
+ 597: 'holster',
600
+ 598: 'home theater, home theatre',
601
+ 599: 'honeycomb',
602
+ 600: 'hook, claw',
603
+ 601: 'hoopskirt, crinoline',
604
+ 602: 'horizontal bar, high bar',
605
+ 603: 'horse cart, horse-cart',
606
+ 604: 'hourglass',
607
+ 605: 'iPod',
608
+ 606: 'iron, smoothing iron',
609
+ 607: "jack-o'-lantern",
610
+ 608: 'jean, blue jean, denim',
611
+ 609: 'jeep, landrover',
612
+ 610: 'jersey, T-shirt, tee shirt',
613
+ 611: 'jigsaw puzzle',
614
+ 612: 'jinrikisha, ricksha, rickshaw',
615
+ 613: 'joystick',
616
+ 614: 'kimono',
617
+ 615: 'knee pad',
618
+ 616: 'knot',
619
+ 617: 'lab coat, laboratory coat',
620
+ 618: 'ladle',
621
+ 619: 'lampshade, lamp shade',
622
+ 620: 'laptop, laptop computer',
623
+ 621: 'lawn mower, mower',
624
+ 622: 'lens cap, lens cover',
625
+ 623: 'letter opener, paper knife, paperknife',
626
+ 624: 'library',
627
+ 625: 'lifeboat',
628
+ 626: 'lighter, light, igniter, ignitor',
629
+ 627: 'limousine, limo',
630
+ 628: 'liner, ocean liner',
631
+ 629: 'lipstick, lip rouge',
632
+ 630: 'Loafer',
633
+ 631: 'lotion',
634
+ 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
635
+ 633: "loupe, jeweler's loupe",
636
+ 634: 'lumbermill, sawmill',
637
+ 635: 'magnetic compass',
638
+ 636: 'mailbag, postbag',
639
+ 637: 'mailbox, letter box',
640
+ 638: 'maillot',
641
+ 639: 'maillot, tank suit',
642
+ 640: 'manhole cover',
643
+ 641: 'maraca',
644
+ 642: 'marimba, xylophone',
645
+ 643: 'mask',
646
+ 644: 'matchstick',
647
+ 645: 'maypole',
648
+ 646: 'maze, labyrinth',
649
+ 647: 'measuring cup',
650
+ 648: 'medicine chest, medicine cabinet',
651
+ 649: 'megalith, megalithic structure',
652
+ 650: 'microphone, mike',
653
+ 651: 'microwave, microwave oven',
654
+ 652: 'military uniform',
655
+ 653: 'milk can',
656
+ 654: 'minibus',
657
+ 655: 'miniskirt, mini',
658
+ 656: 'minivan',
659
+ 657: 'missile',
660
+ 658: 'mitten',
661
+ 659: 'mixing bowl',
662
+ 660: 'mobile home, manufactured home',
663
+ 661: 'Model T',
664
+ 662: 'modem',
665
+ 663: 'monastery',
666
+ 664: 'monitor',
667
+ 665: 'moped',
668
+ 666: 'mortar',
669
+ 667: 'mortarboard',
670
+ 668: 'mosque',
671
+ 669: 'mosquito net',
672
+ 670: 'motor scooter, scooter',
673
+ 671: 'mountain bike, all-terrain bike, off-roader',
674
+ 672: 'mountain tent',
675
+ 673: 'mouse, computer mouse',
676
+ 674: 'mousetrap',
677
+ 675: 'moving van',
678
+ 676: 'muzzle',
679
+ 677: 'nail',
680
+ 678: 'neck brace',
681
+ 679: 'necklace',
682
+ 680: 'nipple',
683
+ 681: 'notebook, notebook computer',
684
+ 682: 'obelisk',
685
+ 683: 'oboe, hautboy, hautbois',
686
+ 684: 'ocarina, sweet potato',
687
+ 685: 'odometer, hodometer, mileometer, milometer',
688
+ 686: 'oil filter',
689
+ 687: 'organ, pipe organ',
690
+ 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
691
+ 689: 'overskirt',
692
+ 690: 'oxcart',
693
+ 691: 'oxygen mask',
694
+ 692: 'packet',
695
+ 693: 'paddle, boat paddle',
696
+ 694: 'paddlewheel, paddle wheel',
697
+ 695: 'padlock',
698
+ 696: 'paintbrush',
699
+ 697: "pajama, pyjama, pj's, jammies",
700
+ 698: 'palace',
701
+ 699: 'panpipe, pandean pipe, syrinx',
702
+ 700: 'paper towel',
703
+ 701: 'parachute, chute',
704
+ 702: 'parallel bars, bars',
705
+ 703: 'park bench',
706
+ 704: 'parking meter',
707
+ 705: 'passenger car, coach, carriage',
708
+ 706: 'patio, terrace',
709
+ 707: 'pay-phone, pay-station',
710
+ 708: 'pedestal, plinth, footstall',
711
+ 709: 'pencil box, pencil case',
712
+ 710: 'pencil sharpener',
713
+ 711: 'perfume, essence',
714
+ 712: 'Petri dish',
715
+ 713: 'photocopier',
716
+ 714: 'pick, plectrum, plectron',
717
+ 715: 'pickelhaube',
718
+ 716: 'picket fence, paling',
719
+ 717: 'pickup, pickup truck',
720
+ 718: 'pier',
721
+ 719: 'piggy bank, penny bank',
722
+ 720: 'pill bottle',
723
+ 721: 'pillow',
724
+ 722: 'ping-pong ball',
725
+ 723: 'pinwheel',
726
+ 724: 'pirate, pirate ship',
727
+ 725: 'pitcher, ewer',
728
+ 726: "plane, carpenter's plane, woodworking plane",
729
+ 727: 'planetarium',
730
+ 728: 'plastic bag',
731
+ 729: 'plate rack',
732
+ 730: 'plow, plough',
733
+ 731: "plunger, plumber's helper",
734
+ 732: 'Polaroid camera, Polaroid Land camera',
735
+ 733: 'pole',
736
+ 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
737
+ 735: 'poncho',
738
+ 736: 'pool table, billiard table, snooker table',
739
+ 737: 'pop bottle, soda bottle',
740
+ 738: 'pot, flowerpot',
741
+ 739: "potter's wheel",
742
+ 740: 'power drill',
743
+ 741: 'prayer rug, prayer mat',
744
+ 742: 'printer',
745
+ 743: 'prison, prison house',
746
+ 744: 'projectile, missile',
747
+ 745: 'projector',
748
+ 746: 'puck, hockey puck',
749
+ 747: 'punching bag, punch bag, punching ball, punchball',
750
+ 748: 'purse',
751
+ 749: 'quill, quill pen',
752
+ 750: 'quilt, comforter, comfort, puff',
753
+ 751: 'racer, race car, racing car',
754
+ 752: 'racket, racquet',
755
+ 753: 'radiator',
756
+ 754: 'radio, wireless',
757
+ 755: 'radio telescope, radio reflector',
758
+ 756: 'rain barrel',
759
+ 757: 'recreational vehicle, RV, R.V.',
760
+ 758: 'reel',
761
+ 759: 'reflex camera',
762
+ 760: 'refrigerator, icebox',
763
+ 761: 'remote control, remote',
764
+ 762: 'restaurant, eating house, eating place, eatery',
765
+ 763: 'revolver, six-gun, six-shooter',
766
+ 764: 'rifle',
767
+ 765: 'rocking chair, rocker',
768
+ 766: 'rotisserie',
769
+ 767: 'rubber eraser, rubber, pencil eraser',
770
+ 768: 'rugby ball',
771
+ 769: 'rule, ruler',
772
+ 770: 'running shoe',
773
+ 771: 'safe',
774
+ 772: 'safety pin',
775
+ 773: 'saltshaker, salt shaker',
776
+ 774: 'sandal',
777
+ 775: 'sarong',
778
+ 776: 'sax, saxophone',
779
+ 777: 'scabbard',
780
+ 778: 'scale, weighing machine',
781
+ 779: 'school bus',
782
+ 780: 'schooner',
783
+ 781: 'scoreboard',
784
+ 782: 'screen, CRT screen',
785
+ 783: 'screw',
786
+ 784: 'screwdriver',
787
+ 785: 'seat belt, seatbelt',
788
+ 786: 'sewing machine',
789
+ 787: 'shield, buckler',
790
+ 788: 'shoe shop, shoe-shop, shoe store',
791
+ 789: 'shoji',
792
+ 790: 'shopping basket',
793
+ 791: 'shopping cart',
794
+ 792: 'shovel',
795
+ 793: 'shower cap',
796
+ 794: 'shower curtain',
797
+ 795: 'ski',
798
+ 796: 'ski mask',
799
+ 797: 'sleeping bag',
800
+ 798: 'slide rule, slipstick',
801
+ 799: 'sliding door',
802
+ 800: 'slot, one-armed bandit',
803
+ 801: 'snorkel',
804
+ 802: 'snowmobile',
805
+ 803: 'snowplow, snowplough',
806
+ 804: 'soap dispenser',
807
+ 805: 'soccer ball',
808
+ 806: 'sock',
809
+ 807: 'solar dish, solar collector, solar furnace',
810
+ 808: 'sombrero',
811
+ 809: 'soup bowl',
812
+ 810: 'space bar',
813
+ 811: 'space heater',
814
+ 812: 'space shuttle',
815
+ 813: 'spatula',
816
+ 814: 'speedboat',
817
+ 815: "spider web, spider's web",
818
+ 816: 'spindle',
819
+ 817: 'sports car, sport car',
820
+ 818: 'spotlight, spot',
821
+ 819: 'stage',
822
+ 820: 'steam locomotive',
823
+ 821: 'steel arch bridge',
824
+ 822: 'steel drum',
825
+ 823: 'stethoscope',
826
+ 824: 'stole',
827
+ 825: 'stone wall',
828
+ 826: 'stopwatch, stop watch',
829
+ 827: 'stove',
830
+ 828: 'strainer',
831
+ 829: 'streetcar, tram, tramcar, trolley, trolley car',
832
+ 830: 'stretcher',
833
+ 831: 'studio couch, day bed',
834
+ 832: 'stupa, tope',
835
+ 833: 'submarine, pigboat, sub, U-boat',
836
+ 834: 'suit, suit of clothes',
837
+ 835: 'sundial',
838
+ 836: 'sunglass',
839
+ 837: 'sunglasses, dark glasses, shades',
840
+ 838: 'sunscreen, sunblock, sun blocker',
841
+ 839: 'suspension bridge',
842
+ 840: 'swab, swob, mop',
843
+ 841: 'sweatshirt',
844
+ 842: 'swimming trunks, bathing trunks',
845
+ 843: 'swing',
846
+ 844: 'switch, electric switch, electrical switch',
847
+ 845: 'syringe',
848
+ 846: 'table lamp',
849
+ 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
850
+ 848: 'tape player',
851
+ 849: 'teapot',
852
+ 850: 'teddy, teddy bear',
853
+ 851: 'television, television system',
854
+ 852: 'tennis ball',
855
+ 853: 'thatch, thatched roof',
856
+ 854: 'theater curtain, theatre curtain',
857
+ 855: 'thimble',
858
+ 856: 'thresher, thrasher, threshing machine',
859
+ 857: 'throne',
860
+ 858: 'tile roof',
861
+ 859: 'toaster',
862
+ 860: 'tobacco shop, tobacconist shop, tobacconist',
863
+ 861: 'toilet seat',
864
+ 862: 'torch',
865
+ 863: 'totem pole',
866
+ 864: 'tow truck, tow car, wrecker',
867
+ 865: 'toyshop',
868
+ 866: 'tractor',
869
+ 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
870
+ 868: 'tray',
871
+ 869: 'trench coat',
872
+ 870: 'tricycle, trike, velocipede',
873
+ 871: 'trimaran',
874
+ 872: 'tripod',
875
+ 873: 'triumphal arch',
876
+ 874: 'trolleybus, trolley coach, trackless trolley',
877
+ 875: 'trombone',
878
+ 876: 'tub, vat',
879
+ 877: 'turnstile',
880
+ 878: 'typewriter keyboard',
881
+ 879: 'umbrella',
882
+ 880: 'unicycle, monocycle',
883
+ 881: 'upright, upright piano',
884
+ 882: 'vacuum, vacuum cleaner',
885
+ 883: 'vase',
886
+ 884: 'vault',
887
+ 885: 'velvet',
888
+ 886: 'vending machine',
889
+ 887: 'vestment',
890
+ 888: 'viaduct',
891
+ 889: 'violin, fiddle',
892
+ 890: 'volleyball',
893
+ 891: 'waffle iron',
894
+ 892: 'wall clock',
895
+ 893: 'wallet, billfold, notecase, pocketbook',
896
+ 894: 'wardrobe, closet, press',
897
+ 895: 'warplane, military plane',
898
+ 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
899
+ 897: 'washer, automatic washer, washing machine',
900
+ 898: 'water bottle',
901
+ 899: 'water jug',
902
+ 900: 'water tower',
903
+ 901: 'whiskey jug',
904
+ 902: 'whistle',
905
+ 903: 'wig',
906
+ 904: 'window screen',
907
+ 905: 'window shade',
908
+ 906: 'Windsor tie',
909
+ 907: 'wine bottle',
910
+ 908: 'wing',
911
+ 909: 'wok',
912
+ 910: 'wooden spoon',
913
+ 911: 'wool, woolen, woollen',
914
+ 912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
915
+ 913: 'wreck',
916
+ 914: 'yawl',
917
+ 915: 'yurt',
918
+ 916: 'web site, website, internet site, site',
919
+ 917: 'comic book',
920
+ 918: 'crossword puzzle, crossword',
921
+ 919: 'street sign',
922
+ 920: 'traffic light, traffic signal, stoplight',
923
+ 921: 'book jacket, dust cover, dust jacket, dust wrapper',
924
+ 922: 'menu',
925
+ 923: 'plate',
926
+ 924: 'guacamole',
927
+ 925: 'consomme',
928
+ 926: 'hot pot, hotpot',
929
+ 927: 'trifle',
930
+ 928: 'ice cream, icecream',
931
+ 929: 'ice lolly, lolly, lollipop, popsicle',
932
+ 930: 'French loaf',
933
+ 931: 'bagel, beigel',
934
+ 932: 'pretzel',
935
+ 933: 'cheeseburger',
936
+ 934: 'hotdog, hot dog, red hot',
937
+ 935: 'mashed potato',
938
+ 936: 'head cabbage',
939
+ 937: 'broccoli',
940
+ 938: 'cauliflower',
941
+ 939: 'zucchini, courgette',
942
+ 940: 'spaghetti squash',
943
+ 941: 'acorn squash',
944
+ 942: 'butternut squash',
945
+ 943: 'cucumber, cuke',
946
+ 944: 'artichoke, globe artichoke',
947
+ 945: 'bell pepper',
948
+ 946: 'cardoon',
949
+ 947: 'mushroom',
950
+ 948: 'Granny Smith',
951
+ 949: 'strawberry',
952
+ 950: 'orange',
953
+ 951: 'lemon',
954
+ 952: 'fig',
955
+ 953: 'pineapple, ananas',
956
+ 954: 'banana',
957
+ 955: 'jackfruit, jak, jack',
958
+ 956: 'custard apple',
959
+ 957: 'pomegranate',
960
+ 958: 'hay',
961
+ 959: 'carbonara',
962
+ 960: 'chocolate sauce, chocolate syrup',
963
+ 961: 'dough',
964
+ 962: 'meat loaf, meatloaf',
965
+ 963: 'pizza, pizza pie',
966
+ 964: 'potpie',
967
+ 965: 'burrito',
968
+ 966: 'red wine',
969
+ 967: 'espresso',
970
+ 968: 'cup',
971
+ 969: 'eggnog',
972
+ 970: 'alp',
973
+ 971: 'bubble',
974
+ 972: 'cliff, drop, drop-off',
975
+ 973: 'coral reef',
976
+ 974: 'geyser',
977
+ 975: 'lakeside, lakeshore',
978
+ 976: 'promontory, headland, head, foreland',
979
+ 977: 'sandbar, sand bar',
980
+ 978: 'seashore, coast, seacoast, sea-coast',
981
+ 979: 'valley, vale',
982
+ 980: 'volcano',
983
+ 981: 'ballplayer, baseball player',
984
+ 982: 'groom, bridegroom',
985
+ 983: 'scuba diver',
986
+ 984: 'rapeseed',
987
+ 985: 'daisy',
988
+ 986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
989
+ 987: 'corn',
990
+ 988: 'acorn',
991
+ 989: 'hip, rose hip, rosehip',
992
+ 990: 'buckeye, horse chestnut, conker',
993
+ 991: 'coral fungus',
994
+ 992: 'agaric',
995
+ 993: 'gyromitra',
996
+ 994: 'stinkhorn, carrion fungus',
997
+ 995: 'earthstar',
998
+ 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
999
+ 997: 'bolete',
1000
+ 998: 'ear, spike, capitulum',
1001
+ 999: 'toilet tissue, toilet paper, bathroom tissue'
1002
+ }
SegmentationTest/data/transforms.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import division
2
+ import sys
3
+ import random
4
+ from PIL import Image
5
+
6
+ try:
7
+ import accimage
8
+ except ImportError:
9
+ accimage = None
10
+ import numbers
11
+ import collections
12
+
13
+ from torchvision.transforms import functional as F
14
+
15
+ if sys.version_info < (3, 3):
16
+ Sequence = collections.Sequence
17
+ Iterable = collections.Iterable
18
+ else:
19
+ Sequence = collections.abc.Sequence
20
+ Iterable = collections.abc.Iterable
21
+
22
+ _pil_interpolation_to_str = {
23
+ Image.NEAREST: 'PIL.Image.NEAREST',
24
+ Image.BILINEAR: 'PIL.Image.BILINEAR',
25
+ Image.BICUBIC: 'PIL.Image.BICUBIC',
26
+ Image.LANCZOS: 'PIL.Image.LANCZOS',
27
+ Image.HAMMING: 'PIL.Image.HAMMING',
28
+ Image.BOX: 'PIL.Image.BOX',
29
+ }
30
+
31
+
32
+ class Compose(object):
33
+ """Composes several transforms together.
34
+
35
+ Args:
36
+ transforms (list of ``Transform`` objects): list of transforms to compose.
37
+
38
+ Example:
39
+ >>> transforms.Compose([
40
+ >>> transforms.CenterCrop(10),
41
+ >>> transforms.ToTensor(),
42
+ >>> ])
43
+ """
44
+
45
+ def __init__(self, transforms):
46
+ self.transforms = transforms
47
+
48
+ def __call__(self, img, tgt):
49
+ for t in self.transforms:
50
+ img, tgt = t(img, tgt)
51
+ return img, tgt
52
+
53
+ def __repr__(self):
54
+ format_string = self.__class__.__name__ + '('
55
+ for t in self.transforms:
56
+ format_string += '\n'
57
+ format_string += ' {0}'.format(t)
58
+ format_string += '\n)'
59
+ return format_string
60
+
61
+
62
+ class Resize(object):
63
+ """Resize the input PIL Image to the given size.
64
+
65
+ Args:
66
+ size (sequence or int): Desired output size. If size is a sequence like
67
+ (h, w), output size will be matched to this. If size is an int,
68
+ smaller edge of the image will be matched to this number.
69
+ i.e, if height > width, then image will be rescaled to
70
+ (size * height / width, size)
71
+ interpolation (int, optional): Desired interpolation. Default is
72
+ ``PIL.Image.BILINEAR``
73
+ """
74
+
75
+ def __init__(self, size, interpolation=Image.BILINEAR):
76
+ assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
77
+ self.size = size
78
+ self.interpolation = interpolation
79
+
80
+ def __call__(self, img, tgt):
81
+ """
82
+ Args:
83
+ img (PIL Image): Image to be scaled.
84
+
85
+ Returns:
86
+ PIL Image: Rescaled image.
87
+ """
88
+ return F.resize(img, self.size, self.interpolation), F.resize(tgt, self.size, Image.NEAREST)
89
+
90
+ def __repr__(self):
91
+ interpolate_str = _pil_interpolation_to_str[self.interpolation]
92
+ return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
93
+
94
+
95
+ class CenterCrop(object):
96
+ """Crops the given PIL Image at the center.
97
+
98
+ Args:
99
+ size (sequence or int): Desired output size of the crop. If size is an
100
+ int instead of sequence like (h, w), a square crop (size, size) is
101
+ made.
102
+ """
103
+
104
+ def __init__(self, size):
105
+ if isinstance(size, numbers.Number):
106
+ self.size = (int(size), int(size))
107
+ else:
108
+ self.size = size
109
+
110
+ def __call__(self, img, tgt):
111
+ """
112
+ Args:
113
+ img (PIL Image): Image to be cropped.
114
+
115
+ Returns:
116
+ PIL Image: Cropped image.
117
+ """
118
+ return F.center_crop(img, self.size), F.center_crop(tgt, self.size)
119
+
120
+ def __repr__(self):
121
+ return self.__class__.__name__ + '(size={0})'.format(self.size)
122
+
123
+
124
+ class RandomCrop(object):
125
+ """Crop the given PIL Image at a random location.
126
+
127
+ Args:
128
+ size (sequence or int): Desired output size of the crop. If size is an
129
+ int instead of sequence like (h, w), a square crop (size, size) is
130
+ made.
131
+ padding (int or sequence, optional): Optional padding on each border
132
+ of the image. Default is None, i.e no padding. If a sequence of length
133
+ 4 is provided, it is used to pad left, top, right, bottom borders
134
+ respectively. If a sequence of length 2 is provided, it is used to
135
+ pad left/right, top/bottom borders, respectively.
136
+ pad_if_needed (boolean): It will pad the image if smaller than the
137
+ desired size to avoid raising an exception.
138
+ fill: Pixel fill value for constant fill. Default is 0. If a tuple of
139
+ length 3, it is used to fill R, G, B channels respectively.
140
+ This value is only used when the padding_mode is constant
141
+ padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
142
+
143
+ - constant: pads with a constant value, this value is specified with fill
144
+
145
+ - edge: pads with the last value on the edge of the image
146
+
147
+ - reflect: pads with reflection of image (without repeating the last value on the edge)
148
+
149
+ padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
150
+ will result in [3, 2, 1, 2, 3, 4, 3, 2]
151
+
152
+ - symmetric: pads with reflection of image (repeating the last value on the edge)
153
+
154
+ padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
155
+ will result in [2, 1, 1, 2, 3, 4, 4, 3]
156
+
157
+ """
158
+
159
+ def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
160
+ if isinstance(size, numbers.Number):
161
+ self.size = (int(size), int(size))
162
+ else:
163
+ self.size = size
164
+ self.padding = padding
165
+ self.pad_if_needed = pad_if_needed
166
+ self.fill = fill
167
+ self.padding_mode = padding_mode
168
+
169
+ @staticmethod
170
+ def get_params(img, output_size):
171
+ """Get parameters for ``crop`` for a random crop.
172
+
173
+ Args:
174
+ img (PIL Image): Image to be cropped.
175
+ output_size (tuple): Expected output size of the crop.
176
+
177
+ Returns:
178
+ tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
179
+ """
180
+ w, h = img.size
181
+ th, tw = output_size
182
+ if w == tw and h == th:
183
+ return 0, 0, h, w
184
+
185
+ i = random.randint(0, h - th)
186
+ j = random.randint(0, w - tw)
187
+ return i, j, th, tw
188
+
189
+ def __call__(self, img, tgt):
190
+ """
191
+ Args:
192
+ img (PIL Image): Image to be cropped.
193
+
194
+ Returns:
195
+ PIL Image: Cropped image.
196
+ """
197
+ if self.padding is not None:
198
+ img = F.pad(img, self.padding, self.fill, self.padding_mode)
199
+ tgt = F.pad(tgt, self.padding, self.fill, self.padding_mode)
200
+
201
+ # pad the width if needed
202
+ if self.pad_if_needed and img.size[0] < self.size[1]:
203
+ img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
204
+ tgt = F.pad(tgt, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
205
+ # pad the height if needed
206
+ if self.pad_if_needed and img.size[1] < self.size[0]:
207
+ img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
208
+ tgt = F.pad(tgt, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
209
+
210
+ i, j, h, w = self.get_params(img, self.size)
211
+
212
+ return F.crop(img, i, j, h, w), F.crop(tgt, i, j, h, w)
213
+
214
+ def __repr__(self):
215
+ return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
216
+
217
+
218
+ class RandomHorizontalFlip(object):
219
+ """Horizontally flip the given PIL Image randomly with a given probability.
220
+
221
+ Args:
222
+ p (float): probability of the image being flipped. Default value is 0.5
223
+ """
224
+
225
+ def __init__(self, p=0.5):
226
+ self.p = p
227
+
228
+ def __call__(self, img, tgt):
229
+ """
230
+ Args:
231
+ img (PIL Image): Image to be flipped.
232
+
233
+ Returns:
234
+ PIL Image: Randomly flipped image.
235
+ """
236
+ if random.random() < self.p:
237
+ return F.hflip(img), F.hflip(tgt)
238
+
239
+ return img, tgt
240
+
241
+ def __repr__(self):
242
+ return self.__class__.__name__ + '(p={})'.format(self.p)
243
+
244
+
245
+ class RandomVerticalFlip(object):
246
+ """Vertically flip the given PIL Image randomly with a given probability.
247
+
248
+ Args:
249
+ p (float): probability of the image being flipped. Default value is 0.5
250
+ """
251
+
252
+ def __init__(self, p=0.5):
253
+ self.p = p
254
+
255
+ def __call__(self, img, tgt):
256
+ """
257
+ Args:
258
+ img (PIL Image): Image to be flipped.
259
+
260
+ Returns:
261
+ PIL Image: Randomly flipped image.
262
+ """
263
+ if random.random() < self.p:
264
+ return F.vflip(img), F.vflip(tgt)
265
+ return img, tgt
266
+
267
+ def __repr__(self):
268
+ return self.__class__.__name__ + '(p={})'.format(self.p)
269
+
270
+
271
+ class Lambda(object):
272
+ """Apply a user-defined lambda as a transform.
273
+
274
+ Args:
275
+ lambd (function): Lambda/function to be used for transform.
276
+ """
277
+
278
+ def __init__(self, lambd):
279
+ assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
280
+ self.lambd = lambd
281
+
282
+ def __call__(self, img, tgt):
283
+ return self.lambd(img, tgt)
284
+
285
+ def __repr__(self):
286
+ return self.__class__.__name__ + '()'
287
+
288
+
289
+ class ColorJitter(object):
290
+ """Randomly change the brightness, contrast and saturation of an image.
291
+
292
+ Args:
293
+ brightness (float or tuple of float (min, max)): How much to jitter brightness.
294
+ brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
295
+ or the given [min, max]. Should be non negative numbers.
296
+ contrast (float or tuple of float (min, max)): How much to jitter contrast.
297
+ contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
298
+ or the given [min, max]. Should be non negative numbers.
299
+ saturation (float or tuple of float (min, max)): How much to jitter saturation.
300
+ saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
301
+ or the given [min, max]. Should be non negative numbers.
302
+ hue (float or tuple of float (min, max)): How much to jitter hue.
303
+ hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
304
+ Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
305
+ """
306
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
307
+ self.brightness = self._check_input(brightness, 'brightness')
308
+ self.contrast = self._check_input(contrast, 'contrast')
309
+ self.saturation = self._check_input(saturation, 'saturation')
310
+ self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
311
+ clip_first_on_zero=False)
312
+
313
+ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
314
+ if isinstance(value, numbers.Number):
315
+ if value < 0:
316
+ raise ValueError("If {} is a single number, it must be non negative.".format(name))
317
+ value = [center - value, center + value]
318
+ if clip_first_on_zero:
319
+ value[0] = max(value[0], 0)
320
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
321
+ if not bound[0] <= value[0] <= value[1] <= bound[1]:
322
+ raise ValueError("{} values should be between {}".format(name, bound))
323
+ else:
324
+ raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))
325
+
326
+ # if value is 0 or (1., 1.) for brightness/contrast/saturation
327
+ # or (0., 0.) for hue, do nothing
328
+ if value[0] == value[1] == center:
329
+ value = None
330
+ return value
331
+
332
+ @staticmethod
333
+ def get_params(brightness, contrast, saturation, hue):
334
+ """Get a randomized transform to be applied on image.
335
+
336
+ Arguments are same as that of __init__.
337
+
338
+ Returns:
339
+ Transform which randomly adjusts brightness, contrast and
340
+ saturation in a random order.
341
+ """
342
+ transforms = []
343
+
344
+ if brightness is not None:
345
+ brightness_factor = random.uniform(brightness[0], brightness[1])
346
+ transforms.append(Lambda(lambda img, tgt: (F.adjust_brightness(img, brightness_factor), tgt)))
347
+
348
+ if contrast is not None:
349
+ contrast_factor = random.uniform(contrast[0], contrast[1])
350
+ transforms.append(Lambda(lambda img, tgt: (F.adjust_contrast(img, contrast_factor), tgt)))
351
+
352
+ if saturation is not None:
353
+ saturation_factor = random.uniform(saturation[0], saturation[1])
354
+ transforms.append(Lambda(lambda img, tgt: (F.adjust_saturation(img, saturation_factor), tgt)))
355
+
356
+ if hue is not None:
357
+ hue_factor = random.uniform(hue[0], hue[1])
358
+ transforms.append(Lambda(lambda img, tgt: (F.adjust_hue(img, hue_factor), tgt)))
359
+
360
+ random.shuffle(transforms)
361
+ transform = Compose(transforms)
362
+
363
+ return transform
364
+
365
+ def __call__(self, img, tgt):
366
+ """
367
+ Args:
368
+ img (PIL Image): Input image.
369
+
370
+ Returns:
371
+ PIL Image: Color jittered image.
372
+ """
373
+ transform = self.get_params(self.brightness, self.contrast,
374
+ self.saturation, self.hue)
375
+ return transform(img, tgt)
376
+
377
+ def __repr__(self):
378
+ format_string = self.__class__.__name__ + '('
379
+ format_string += 'brightness={0}'.format(self.brightness)
380
+ format_string += ', contrast={0}'.format(self.contrast)
381
+ format_string += ', saturation={0}'.format(self.saturation)
382
+ format_string += ', hue={0})'.format(self.hue)
383
+ return format_string
384
+
385
+
386
+ class Normalize(object):
387
+ """Normalize a tensor image with mean and standard deviation.
388
+ Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
389
+ will normalize each channel of the input ``torch.*Tensor`` i.e.
390
+ ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
391
+
392
+ .. note::
393
+ This transform acts out of place, i.e., it does not mutates the input tensor.
394
+
395
+ Args:
396
+ mean (sequence): Sequence of means for each channel.
397
+ std (sequence): Sequence of standard deviations for each channel.
398
+ """
399
+
400
+ def __init__(self, mean, std, inplace=False):
401
+ self.mean = mean
402
+ self.std = std
403
+ self.inplace = inplace
404
+
405
+ def __call__(self, img, tgt):
406
+ """
407
+ Args:
408
+ tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
409
+
410
+ Returns:
411
+ Tensor: Normalized Tensor image.
412
+ """
413
+ # return F.normalize(img, self.mean, self.std, self.inplace), tgt
414
+ return F.normalize(img, self.mean, self.std), tgt
415
+
416
+ def __repr__(self):
417
+ return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
418
+
419
+
420
+ class ToTensor(object):
421
+ """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
422
+
423
+ Converts a PIL Image or numpy.ndarray (H x W x C) in the range
424
+ [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
425
+ if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
426
+ or if the numpy.ndarray has dtype = np.uint8
427
+
428
+ In the other cases, tensors are returned without scaling.
429
+ """
430
+
431
+ def __call__(self, img, tgt):
432
+ """
433
+ Args:
434
+ pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
435
+
436
+ Returns:
437
+ Tensor: Converted image.
438
+ """
439
+ return F.to_tensor(img), tgt
440
+
441
+ def __repr__(self):
442
+ return self.__class__.__name__ + '()'
SegmentationTest/imagenet_seg_eval.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from torch.utils.data import DataLoader
5
+ from numpy import *
6
+ import argparse
7
+ from PIL import Image
8
+ import imageio
9
+ import os
10
+ from tqdm import tqdm
11
+ from SegmentationTest.utils.metrices import *
12
+
13
+ from SegmentationTest.utils import render
14
+ from SegmentationTest.utils.saver import Saver
15
+ from SegmentationTest.utils.iou import IoU
16
+
17
+ from SegmentationTest.data.Imagenet import Imagenet_Segmentation
18
+
19
+ # Uncomment the expected model below
20
+
21
+ # ViT
22
+ from ViT.ViT import vit_base_patch16_224 as vit
23
+ # from ViT.ViT import vit_large_patch16_224 as vit
24
+
25
+ # ViT-AugReg
26
+ # from ViT.ViT_new import vit_small_patch16_224 as vit
27
+ # from ViT.ViT_new import vit_base_patch16_224 as vit
28
+ # from ViT.ViT_new import vit_large_patch16_224 as vit
29
+
30
+ # DeiT
31
+ # from ViT.ViT import deit_base_patch16_224 as vit
32
+ # from ViT.ViT import deit_small_patch16_224 as vit
33
+
34
+
35
+ from ViT.explainer import generate_relevance, get_image_with_relevance
36
+
37
+ from sklearn.metrics import precision_recall_curve
38
+ import matplotlib.pyplot as plt
39
+
40
+ import torch.nn.functional as F
41
+
42
+ import warnings
43
+ warnings.filterwarnings("ignore")
44
+
45
+ plt.switch_backend('agg')
46
+
47
+ # hyperparameters
48
+ num_workers = 0
49
+ batch_size = 1
50
+
51
+ cls = ['airplane',
52
+ 'bicycle',
53
+ 'bird',
54
+ 'boat',
55
+ 'bottle',
56
+ 'bus',
57
+ 'car',
58
+ 'cat',
59
+ 'chair',
60
+ 'cow',
61
+ 'dining table',
62
+ 'dog',
63
+ 'horse',
64
+ 'motobike',
65
+ 'person',
66
+ 'potted plant',
67
+ 'sheep',
68
+ 'sofa',
69
+ 'train',
70
+ 'tv'
71
+ ]
72
+
73
+ # Args
74
+ parser = argparse.ArgumentParser(description='Training multi-class classifier')
75
+ parser.add_argument('--arc', type=str, default='vgg', metavar='N',
76
+ help='Model architecture')
77
+ parser.add_argument('--train_dataset', type=str, default='imagenet', metavar='N',
78
+ help='Testing Dataset')
79
+ parser.add_argument('--method', type=str,
80
+ default='grad_rollout',
81
+ choices=['rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'lrp_last_layer',
82
+ 'attn_last_layer', 'attn_gradcam'],
83
+ help='')
84
+ parser.add_argument('--thr', type=float, default=0.,
85
+ help='threshold')
86
+ parser.add_argument('--K', type=int, default=1,
87
+ help='new - top K results')
88
+ parser.add_argument('--save-img', action='store_true',
89
+ default=False,
90
+ help='')
91
+ parser.add_argument('--no-ia', action='store_true',
92
+ default=False,
93
+ help='')
94
+ parser.add_argument('--no-fx', action='store_true',
95
+ default=False,
96
+ help='')
97
+ parser.add_argument('--no-fgx', action='store_true',
98
+ default=False,
99
+ help='')
100
+ parser.add_argument('--no-m', action='store_true',
101
+ default=False,
102
+ help='')
103
+ parser.add_argument('--no-reg', action='store_true',
104
+ default=False,
105
+ help='')
106
+ parser.add_argument('--is-ablation', type=bool,
107
+ default=False,
108
+ help='')
109
+ parser.add_argument('--imagenet-seg-path', type=str, required=True)
110
+ parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
111
+ help='path to latest checkpoint (default: none)')
112
+ args = parser.parse_args()
113
+
114
+ args.checkname = args.method + '_' + args.arc
115
+
116
+ alpha = 2
117
+
118
+ cuda = torch.cuda.is_available()
119
+ device = torch.device("cuda" if cuda else "cpu")
120
+
121
+ # Define Saver
122
+ saver = Saver(args)
123
+ saver.results_dir = os.path.join(saver.experiment_dir, 'results')
124
+ if not os.path.exists(saver.results_dir):
125
+ os.makedirs(saver.results_dir)
126
+ if not os.path.exists(os.path.join(saver.results_dir, 'input')):
127
+ os.makedirs(os.path.join(saver.results_dir, 'input'))
128
+ if not os.path.exists(os.path.join(saver.results_dir, 'explain')):
129
+ os.makedirs(os.path.join(saver.results_dir, 'explain'))
130
+
131
+ args.exp_img_path = os.path.join(saver.results_dir, 'explain/img')
132
+ if not os.path.exists(args.exp_img_path):
133
+ os.makedirs(args.exp_img_path)
134
+ args.exp_np_path = os.path.join(saver.results_dir, 'explain/np')
135
+ if not os.path.exists(args.exp_np_path):
136
+ os.makedirs(args.exp_np_path)
137
+
138
+ # Data
139
+ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
140
+ test_img_trans = transforms.Compose([
141
+ transforms.Resize((224, 224)),
142
+ transforms.ToTensor(),
143
+ normalize,
144
+ ])
145
+ test_lbl_trans = transforms.Compose([
146
+ transforms.Resize((224, 224), Image.NEAREST),
147
+ ])
148
+
149
+ ds = Imagenet_Segmentation(args.imagenet_seg_path,
150
+ transform=test_img_trans, target_transform=test_lbl_trans)
151
+ dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=False)
152
+
153
+ # Model
154
+ if args.checkpoint:
155
+ print(f"loading model from checkpoint {args.checkpoint}")
156
+ model = vit().cuda()
157
+ checkpoint = torch.load(args.checkpoint)
158
+ model.load_state_dict(checkpoint['state_dict'])
159
+ else:
160
+ model = vit(pretrained=True).cuda()
161
+
162
+ metric = IoU(2, ignore_index=-1)
163
+
164
+ iterator = tqdm(dl)
165
+
166
+ model.eval()
167
+
168
+
169
+ def compute_pred(output):
170
+ pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
171
+ # pred[0, 0] = 282
172
+ # print('Pred cls : ' + str(pred))
173
+ T = pred.squeeze().cpu().numpy()
174
+ T = np.expand_dims(T, 0)
175
+ T = (T[:, np.newaxis] == np.arange(1000)) * 1.0
176
+ T = torch.from_numpy(T).type(torch.FloatTensor)
177
+ Tt = T.cuda()
178
+
179
+ return Tt
180
+
181
+
182
+ def eval_batch(image, labels, evaluator, index):
183
+ evaluator.zero_grad()
184
+ # Save input image
185
+ if args.save_img:
186
+ img = image[0].permute(1, 2, 0).data.cpu().numpy()
187
+ img = 255 * (img - img.min()) / (img.max() - img.min())
188
+ img = img.astype('uint8')
189
+ Image.fromarray(img, 'RGB').save(os.path.join(saver.results_dir, 'input/{}_input.png'.format(index)))
190
+ Image.fromarray((labels.repeat(3, 1, 1).permute(1, 2, 0).data.cpu().numpy() * 255).astype('uint8'), 'RGB').save(
191
+ os.path.join(saver.results_dir, 'input/{}_mask.png'.format(index)))
192
+
193
+ image.requires_grad = True
194
+
195
+ image = image.requires_grad_()
196
+ predictions = evaluator(image)
197
+ Res = generate_relevance(model, image.cuda())
198
+
199
+ # threshold between FG and BG is the mean
200
+ Res = (Res - Res.min()) / (Res.max() - Res.min())
201
+
202
+ ret = Res.mean()
203
+
204
+ Res_1 = Res.gt(ret).type(Res.type())
205
+ Res_0 = Res.le(ret).type(Res.type())
206
+
207
+ Res_1_AP = Res
208
+ Res_0_AP = 1 - Res
209
+
210
+ Res_1[Res_1 != Res_1] = 0
211
+ Res_0[Res_0 != Res_0] = 0
212
+ Res_1_AP[Res_1_AP != Res_1_AP] = 0
213
+ Res_0_AP[Res_0_AP != Res_0_AP] = 0
214
+
215
+ # TEST
216
+ pred = Res.clamp(min=args.thr) / Res.max()
217
+ pred = pred.view(-1).data.cpu().numpy()
218
+ target = labels.view(-1).data.cpu().numpy()
219
+ # print("target", target.shape)
220
+
221
+ output = torch.cat((Res_0, Res_1), 1)
222
+ output_AP = torch.cat((Res_0_AP, Res_1_AP), 1)
223
+
224
+ if args.save_img:
225
+ # Save predicted mask
226
+ mask = F.interpolate(Res_1, [64, 64], mode='bilinear')
227
+ mask = mask[0].squeeze().data.cpu().numpy()
228
+ # mask = Res_1[0].squeeze().data.cpu().numpy()
229
+ mask = 255 * mask
230
+ mask = mask.astype('uint8')
231
+ imageio.imsave(os.path.join(args.exp_img_path, 'mask_' + str(index) + '.jpg'), mask)
232
+
233
+ relevance = F.interpolate(Res, [64, 64], mode='bilinear')
234
+ relevance = relevance[0].permute(1, 2, 0).data.cpu().numpy()
235
+ # relevance = Res[0].permute(1, 2, 0).data.cpu().numpy()
236
+ hm = np.sum(relevance, axis=-1)
237
+ maps = (render.hm_to_rgb(hm, scaling=3, sigma=1, cmap='seismic') * 255).astype(np.uint8)
238
+ imageio.imsave(os.path.join(args.exp_img_path, 'heatmap_' + str(index) + '.jpg'), maps)
239
+
240
+ # Evaluate Segmentation
241
+ batch_inter, batch_union, batch_correct, batch_label = 0, 0, 0, 0
242
+ batch_ap, batch_f1 = 0, 0
243
+
244
+ # Segmentation resutls
245
+ correct, labeled = batch_pix_accuracy(output[0].data.cpu(), labels[0])
246
+ inter, union = batch_intersection_union(output[0].data.cpu(), labels[0], 2)
247
+ batch_correct += correct
248
+ batch_label += labeled
249
+ batch_inter += inter
250
+ batch_union += union
251
+ # print("output", output.shape)
252
+ # print("ap labels", labels.shape)
253
+ # ap = np.nan_to_num(get_ap_scores(output, labels))
254
+ ap = np.nan_to_num(get_ap_scores(output_AP, labels))
255
+ # f1 = np.nan_to_num(get_f1_scores(output[0, 1].data.cpu(), labels[0]))
256
+ batch_ap += ap
257
+ # batch_f1 += f1
258
+
259
+ # return batch_correct, batch_label, batch_inter, batch_union, batch_ap, batch_f1, pred, target
260
+ return batch_correct, batch_label, batch_inter, batch_union, batch_ap, pred, target
261
+
262
+
263
+ total_inter, total_union, total_correct, total_label = np.int64(0), np.int64(0), np.int64(0), np.int64(0)
264
+ total_ap, total_f1 = [], []
265
+
266
+ predictions, targets = [], []
267
+ for batch_idx, (image, labels) in enumerate(iterator):
268
+
269
+ if args.method == "blur":
270
+ images = (image[0].cuda(), image[1].cuda())
271
+ else:
272
+ images = image.cuda()
273
+ labels = labels.cuda()
274
+ # print("image", image.shape)
275
+ # print("lables", labels.shape)
276
+
277
+ # correct, labeled, inter, union, ap, f1, pred, target = eval_batch(images, labels, model, batch_idx)
278
+ correct, labeled, inter, union, ap, pred, target = eval_batch(images, labels, model, batch_idx)
279
+
280
+ predictions.append(pred)
281
+ targets.append(target)
282
+
283
+ total_correct += correct.astype('int64')
284
+ total_label += labeled.astype('int64')
285
+ total_inter += inter.astype('int64')
286
+ total_union += union.astype('int64')
287
+ total_ap += [ap]
288
+ # total_f1 += [f1]
289
+ pixAcc = np.float64(1.0) * total_correct / (np.spacing(1, dtype=np.float64) + total_label)
290
+ IoU = np.float64(1.0) * total_inter / (np.spacing(1, dtype=np.float64) + total_union)
291
+ mIoU = IoU.mean()
292
+ mAp = np.mean(total_ap)
293
+ # mF1 = np.mean(total_f1)
294
+ # iterator.set_description('pixAcc: %.4f, mIoU: %.4f, mAP: %.4f, mF1: %.4f' % (pixAcc, mIoU, mAp, mF1))
295
+ iterator.set_description('pixAcc: %.4f, mIoU: %.4f, mAP: %.4f' % (pixAcc, mIoU, mAp))
296
+
297
+ predictions = np.concatenate(predictions)
298
+ targets = np.concatenate(targets)
299
+ pr, rc, thr = precision_recall_curve(targets, predictions)
300
+ np.save(os.path.join(saver.experiment_dir, 'precision.npy'), pr)
301
+ np.save(os.path.join(saver.experiment_dir, 'recall.npy'), rc)
302
+
303
+ plt.figure()
304
+ plt.plot(rc, pr)
305
+ plt.savefig(os.path.join(saver.experiment_dir, 'PR_curve_{}.png'.format(args.method)))
306
+
307
+ txtfile = os.path.join(saver.experiment_dir, 'result_mIoU_%.4f.txt' % mIoU)
308
+ # txtfile = 'result_mIoU_%.4f.txt' % mIoU
309
+ fh = open(txtfile, 'w')
310
+ print("Mean IoU over %d classes: %.4f\n" % (2, mIoU))
311
+ print("Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100))
312
+ print("Mean AP over %d classes: %.4f\n" % (2, mAp))
313
+ # print("Mean F1 over %d classes: %.4f\n" % (2, mF1))
314
+
315
+ fh.write("Mean IoU over %d classes: %.4f\n" % (2, mIoU))
316
+ fh.write("Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100))
317
+ fh.write("Mean AP over %d classes: %.4f\n" % (2, mAp))
318
+ # fh.write("Mean F1 over %d classes: %.4f\n" % (2, mF1))
319
+ fh.close()
SegmentationTest/utils/__init__.py ADDED
File without changes
SegmentationTest/utils/confusionmatrix.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from . import metric
4
+
5
+
6
+ class ConfusionMatrix(metric.Metric):
7
+ """Constructs a confusion matrix for a multi-class classification problems.
8
+ Does not support multi-label, multi-class problems.
9
+ Keyword arguments:
10
+ - num_classes (int): number of classes in the classification problem.
11
+ - normalized (boolean, optional): Determines whether or not the confusion
12
+ matrix is normalized or not. Default: False.
13
+ Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py
14
+ """
15
+
16
+ def __init__(self, num_classes, normalized=False):
17
+ super().__init__()
18
+
19
+ self.conf = np.ndarray((num_classes, num_classes), dtype=np.int32)
20
+ self.normalized = normalized
21
+ self.num_classes = num_classes
22
+ self.reset()
23
+
24
+ def reset(self):
25
+ self.conf.fill(0)
26
+
27
+ def add(self, predicted, target):
28
+ """Computes the confusion matrix
29
+ The shape of the confusion matrix is K x K, where K is the number
30
+ of classes.
31
+ Keyword arguments:
32
+ - predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of
33
+ predicted scores obtained from the model for N examples and K classes,
34
+ or an N-tensor/array of integer values between 0 and K-1.
35
+ - target (Tensor or numpy.ndarray): Can be an N x K tensor/array of
36
+ ground-truth classes for N examples and K classes, or an N-tensor/array
37
+ of integer values between 0 and K-1.
38
+ """
39
+ # If target and/or predicted are tensors, convert them to numpy arrays
40
+ if torch.is_tensor(predicted):
41
+ predicted = predicted.cpu().numpy()
42
+ if torch.is_tensor(target):
43
+ target = target.cpu().numpy()
44
+
45
+ assert predicted.shape[0] == target.shape[0], \
46
+ 'number of targets and predicted outputs do not match'
47
+
48
+ if np.ndim(predicted) != 1:
49
+ assert predicted.shape[1] == self.num_classes, \
50
+ 'number of predictions does not match size of confusion matrix'
51
+ predicted = np.argmax(predicted, 1)
52
+ else:
53
+ assert (predicted.max() < self.num_classes) and (predicted.min() >= 0), \
54
+ 'predicted values are not between 0 and k-1'
55
+
56
+ if np.ndim(target) != 1:
57
+ assert target.shape[1] == self.num_classes, \
58
+ 'Onehot target does not match size of confusion matrix'
59
+ assert (target >= 0).all() and (target <= 1).all(), \
60
+ 'in one-hot encoding, target values should be 0 or 1'
61
+ assert (target.sum(1) == 1).all(), \
62
+ 'multi-label setting is not supported'
63
+ target = np.argmax(target, 1)
64
+ else:
65
+ assert (target.max() < self.num_classes) and (target.min() >= 0), \
66
+ 'target values are not between 0 and k-1'
67
+
68
+ # hack for bincounting 2 arrays together
69
+ x = predicted + self.num_classes * target
70
+ bincount_2d = np.bincount(
71
+ x.astype(np.int32), minlength=self.num_classes**2)
72
+ assert bincount_2d.size == self.num_classes**2
73
+ conf = bincount_2d.reshape((self.num_classes, self.num_classes))
74
+
75
+ self.conf += conf
76
+
77
+ def value(self):
78
+ """
79
+ Returns:
80
+ Confustion matrix of K rows and K columns, where rows corresponds
81
+ to ground-truth targets and columns corresponds to predicted
82
+ targets.
83
+ """
84
+ if self.normalized:
85
+ conf = self.conf.astype(np.float32)
86
+ return conf / conf.sum(1).clip(min=1e-12)[:, None]
87
+ else:
88
+ return self.conf
SegmentationTest/utils/iou.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from . import metric
4
+ from .confusionmatrix import ConfusionMatrix
5
+
6
+
7
+ class IoU(metric.Metric):
8
+ """Computes the intersection over union (IoU) per class and corresponding
9
+ mean (mIoU).
10
+
11
+ Intersection over union (IoU) is a common evaluation metric for semantic
12
+ segmentation. The predictions are first accumulated in a confusion matrix
13
+ and the IoU is computed from it as follows:
14
+
15
+ IoU = true_positive / (true_positive + false_positive + false_negative).
16
+
17
+ Keyword arguments:
18
+ - num_classes (int): number of classes in the classification problem
19
+ - normalized (boolean, optional): Determines whether or not the confusion
20
+ matrix is normalized or not. Default: False.
21
+ - ignore_index (int or iterable, optional): Index of the classes to ignore
22
+ when computing the IoU. Can be an int, or any iterable of ints.
23
+ """
24
+
25
+ def __init__(self, num_classes, normalized=False, ignore_index=None):
26
+ super().__init__()
27
+ self.conf_metric = ConfusionMatrix(num_classes, normalized)
28
+
29
+ if ignore_index is None:
30
+ self.ignore_index = None
31
+ elif isinstance(ignore_index, int):
32
+ self.ignore_index = (ignore_index,)
33
+ else:
34
+ try:
35
+ self.ignore_index = tuple(ignore_index)
36
+ except TypeError:
37
+ raise ValueError("'ignore_index' must be an int or iterable")
38
+
39
+ def reset(self):
40
+ self.conf_metric.reset()
41
+
42
+ def add(self, predicted, target):
43
+ """Adds the predicted and target pair to the IoU metric.
44
+
45
+ Keyword arguments:
46
+ - predicted (Tensor): Can be a (N, K, H, W) tensor of
47
+ predicted scores obtained from the model for N examples and K classes,
48
+ or (N, H, W) tensor of integer values between 0 and K-1.
49
+ - target (Tensor): Can be a (N, K, H, W) tensor of
50
+ target scores for N examples and K classes, or (N, H, W) tensor of
51
+ integer values between 0 and K-1.
52
+
53
+ """
54
+ # Dimensions check
55
+ assert predicted.size(0) == target.size(0), \
56
+ 'number of targets and predicted outputs do not match'
57
+ assert predicted.dim() == 3 or predicted.dim() == 4, \
58
+ "predictions must be of dimension (N, H, W) or (N, K, H, W)"
59
+ assert target.dim() == 3 or target.dim() == 4, \
60
+ "targets must be of dimension (N, H, W) or (N, K, H, W)"
61
+
62
+ # If the tensor is in categorical format convert it to integer format
63
+ if predicted.dim() == 4:
64
+ _, predicted = predicted.max(1)
65
+ if target.dim() == 4:
66
+ _, target = target.max(1)
67
+
68
+ self.conf_metric.add(predicted.view(-1), target.view(-1))
69
+
70
+ def value(self):
71
+ """Computes the IoU and mean IoU.
72
+
73
+ The mean computation ignores NaN elements of the IoU array.
74
+
75
+ Returns:
76
+ Tuple: (IoU, mIoU). The first output is the per class IoU,
77
+ for K classes it's numpy.ndarray with K elements. The second output,
78
+ is the mean IoU.
79
+ """
80
+ conf_matrix = self.conf_metric.value()
81
+ if self.ignore_index is not None:
82
+ for index in self.ignore_index:
83
+ conf_matrix[:, self.ignore_index] = 0
84
+ conf_matrix[self.ignore_index, :] = 0
85
+ true_positive = np.diag(conf_matrix)
86
+ false_positive = np.sum(conf_matrix, 0) - true_positive
87
+ false_negative = np.sum(conf_matrix, 1) - true_positive
88
+
89
+ # Just in case we get a division by 0, ignore/hide the error
90
+ with np.errstate(divide='ignore', invalid='ignore'):
91
+ iou = true_positive / (true_positive + false_positive + false_negative)
92
+
93
+ return iou, np.nanmean(iou)
SegmentationTest/utils/metric.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Metric(object):
2
+ """Base class for all metrics.
3
+ From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py
4
+ """
5
+ def reset(self):
6
+ pass
7
+
8
+ def add(self):
9
+ pass
10
+
11
+ def value(self):
12
+ pass
SegmentationTest/utils/metrices.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from sklearn.metrics import f1_score, average_precision_score
4
+ from sklearn.metrics import precision_recall_curve, roc_curve
5
+
6
+ SMOOTH = 1e-6
7
+ __all__ = ['get_f1_scores', 'get_ap_scores', 'batch_pix_accuracy', 'batch_intersection_union', 'get_iou', 'get_pr',
8
+ 'get_roc', 'get_ap_multiclass']
9
+
10
+
11
+ def get_iou(outputs: torch.Tensor, labels: torch.Tensor):
12
+ # You can comment out this line if you are passing tensors of equal shape
13
+ # But if you are passing output from UNet or something it will most probably
14
+ # be with the BATCH x 1 x H x W shape
15
+ outputs = outputs.squeeze(1) # BATCH x 1 x H x W => BATCH x H x W
16
+ labels = labels.squeeze(1) # BATCH x 1 x H x W => BATCH x H x W
17
+
18
+ intersection = (outputs & labels).float().sum((1, 2)) # Will be zero if Truth=0 or Prediction=0
19
+ union = (outputs | labels).float().sum((1, 2)) # Will be zzero if both are 0
20
+
21
+ iou = (intersection + SMOOTH) / (union + SMOOTH) # We smooth our devision to avoid 0/0
22
+
23
+ return iou.cpu().numpy()
24
+
25
+
26
+ def get_f1_scores(predict, target, ignore_index=-1):
27
+ # Tensor process
28
+ batch_size = predict.shape[0]
29
+ predict = predict.data.cpu().numpy().reshape(-1)
30
+ target = target.data.cpu().numpy().reshape(-1)
31
+ pb = predict[target != ignore_index].reshape(batch_size, -1)
32
+ tb = target[target != ignore_index].reshape(batch_size, -1)
33
+
34
+ total = []
35
+ for p, t in zip(pb, tb):
36
+ total.append(np.nan_to_num(f1_score(t, p)))
37
+
38
+ return total
39
+
40
+
41
+ def get_roc(predict, target, ignore_index=-1):
42
+ target_expand = target.unsqueeze(1).expand_as(predict)
43
+ target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)
44
+ # Tensor process
45
+ x = torch.zeros_like(target_expand)
46
+ t = target.unsqueeze(1).clamp(min=0)
47
+ target_1hot = x.scatter_(1, t, 1)
48
+ batch_size = predict.shape[0]
49
+ predict = predict.data.cpu().numpy().reshape(-1)
50
+ target = target_1hot.data.cpu().numpy().reshape(-1)
51
+ pb = predict[target_expand_numpy != ignore_index].reshape(batch_size, -1)
52
+ tb = target[target_expand_numpy != ignore_index].reshape(batch_size, -1)
53
+
54
+ total = []
55
+ for p, t in zip(pb, tb):
56
+ total.append(roc_curve(t, p))
57
+
58
+ return total
59
+
60
+
61
+ def get_pr(predict, target, ignore_index=-1):
62
+ target_expand = target.unsqueeze(1).expand_as(predict)
63
+ target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)
64
+ # Tensor process
65
+ x = torch.zeros_like(target_expand)
66
+ t = target.unsqueeze(1).clamp(min=0)
67
+ target_1hot = x.scatter_(1, t, 1)
68
+ batch_size = predict.shape[0]
69
+ predict = predict.data.cpu().numpy().reshape(-1)
70
+ target = target_1hot.data.cpu().numpy().reshape(-1)
71
+ pb = predict[target_expand_numpy != ignore_index].reshape(batch_size, -1)
72
+ tb = target[target_expand_numpy != ignore_index].reshape(batch_size, -1)
73
+
74
+ total = []
75
+ for p, t in zip(pb, tb):
76
+ total.append(precision_recall_curve(t, p))
77
+
78
+ return total
79
+
80
+
81
+ def get_ap_scores(predict, target, ignore_index=-1):
82
+ total = []
83
+ for pred, tgt in zip(predict, target):
84
+ target_expand = tgt.unsqueeze(0).expand_as(pred)
85
+ target_expand_numpy = target_expand.data.cpu().numpy().reshape(-1)
86
+
87
+ # Tensor process
88
+ x = torch.zeros_like(target_expand)
89
+ t = tgt.unsqueeze(0).clamp(min=0).long()
90
+ target_1hot = x.scatter_(0, t, 1)
91
+ predict_flat = pred.data.cpu().numpy().reshape(-1)
92
+ target_flat = target_1hot.data.cpu().numpy().reshape(-1)
93
+
94
+ p = predict_flat[target_expand_numpy != ignore_index]
95
+ t = target_flat[target_expand_numpy != ignore_index]
96
+
97
+ total.append(np.nan_to_num(average_precision_score(t, p)))
98
+
99
+ return total
100
+
101
+
102
+ def get_ap_multiclass(predict, target):
103
+ total = []
104
+ for pred, tgt in zip(predict, target):
105
+ predict_flat = pred.data.cpu().numpy().reshape(-1)
106
+ target_flat = tgt.data.cpu().numpy().reshape(-1)
107
+
108
+ total.append(np.nan_to_num(average_precision_score(target_flat, predict_flat)))
109
+
110
+ return total
111
+
112
+
113
+ def batch_precision_recall(predict, target, thr=0.5):
114
+ """Batch Precision Recall
115
+ Args:
116
+ predict: input 4D tensor
117
+ target: label 4D tensor
118
+ """
119
+ # _, predict = torch.max(predict, 1)
120
+
121
+ predict = predict > thr
122
+ predict = predict.data.cpu().numpy() + 1
123
+ target = target.data.cpu().numpy() + 1
124
+
125
+ tp = np.sum(((predict == 2) * (target == 2)) * (target > 0))
126
+ fp = np.sum(((predict == 2) * (target == 1)) * (target > 0))
127
+ fn = np.sum(((predict == 1) * (target == 2)) * (target > 0))
128
+
129
+ precision = float(np.nan_to_num(tp / (tp + fp)))
130
+ recall = float(np.nan_to_num(tp / (tp + fn)))
131
+
132
+ return precision, recall
133
+
134
+
135
+ def batch_pix_accuracy(predict, target):
136
+ """Batch Pixel Accuracy
137
+ Args:
138
+ predict: input 3D tensor
139
+ target: label 3D tensor
140
+ """
141
+
142
+ # for thr in np.linspace(0, 1, slices):
143
+
144
+ _, predict = torch.max(predict, 0)
145
+ predict = predict.cpu().numpy() + 1
146
+ target = target.cpu().numpy() + 1
147
+ pixel_labeled = np.sum(target > 0)
148
+ pixel_correct = np.sum((predict == target) * (target > 0))
149
+ assert pixel_correct <= pixel_labeled, \
150
+ "Correct area should be smaller than Labeled"
151
+ return pixel_correct, pixel_labeled
152
+
153
+
154
+ def batch_intersection_union(predict, target, nclass):
155
+ """Batch Intersection of Union
156
+ Args:
157
+ predict: input 3D tensor
158
+ target: label 3D tensor
159
+ nclass: number of categories (int)
160
+ """
161
+ _, predict = torch.max(predict, 0)
162
+ mini = 1
163
+ maxi = nclass
164
+ nbins = nclass
165
+ predict = predict.cpu().numpy() + 1
166
+ target = target.cpu().numpy() + 1
167
+
168
+ predict = predict * (target > 0).astype(predict.dtype)
169
+ intersection = predict * (predict == target)
170
+ # areas of intersection and union
171
+ area_inter, _ = np.histogram(intersection, bins=nbins, range=(mini, maxi))
172
+ area_pred, _ = np.histogram(predict, bins=nbins, range=(mini, maxi))
173
+ area_lab, _ = np.histogram(target, bins=nbins, range=(mini, maxi))
174
+ area_union = area_pred + area_lab - area_inter
175
+ assert (area_inter <= area_union).all(), \
176
+ "Intersection area should be smaller than Union area"
177
+ return area_inter, area_union
178
+
179
+
180
+ # ref https://github.com/CSAILVision/sceneparsing/blob/master/evaluationCode/utils_eval.py
181
+ def pixel_accuracy(im_pred, im_lab):
182
+ im_pred = np.asarray(im_pred)
183
+ im_lab = np.asarray(im_lab)
184
+
185
+ # Remove classes from unlabeled pixels in gt image.
186
+ # We should not penalize detections in unlabeled portions of the image.
187
+ pixel_labeled = np.sum(im_lab > 0)
188
+ pixel_correct = np.sum((im_pred == im_lab) * (im_lab > 0))
189
+ # pixel_accuracy = 1.0 * pixel_correct / pixel_labeled
190
+ return pixel_correct, pixel_labeled
191
+
192
+
193
+ def intersection_and_union(im_pred, im_lab, num_class):
194
+ im_pred = np.asarray(im_pred)
195
+ im_lab = np.asarray(im_lab)
196
+ # Remove classes from unlabeled pixels in gt image.
197
+ im_pred = im_pred * (im_lab > 0)
198
+ # Compute area intersection:
199
+ intersection = im_pred * (im_pred == im_lab)
200
+ area_inter, _ = np.histogram(intersection, bins=num_class - 1,
201
+ range=(1, num_class - 1))
202
+ # Compute area union:
203
+ area_pred, _ = np.histogram(im_pred, bins=num_class - 1,
204
+ range=(1, num_class - 1))
205
+ area_lab, _ = np.histogram(im_lab, bins=num_class - 1,
206
+ range=(1, num_class - 1))
207
+ area_union = area_pred + area_lab - area_inter
208
+ return area_inter, area_union
SegmentationTest/utils/parallel.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2
+ ## Created by: Hang Zhang
3
+ ## ECE Department, Rutgers University
4
+ ## Email: [email protected]
5
+ ## Copyright (c) 2017
6
+ ##
7
+ ## This source code is licensed under the MIT-style license found in the
8
+ ## LICENSE file in the root directory of this source tree
9
+ ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10
+
11
+ """Encoding Data Parallel"""
12
+ import threading
13
+ import functools
14
+ import torch
15
+ from torch.autograd import Variable, Function
16
+ import torch.cuda.comm as comm
17
+ from torch.nn.parallel.data_parallel import DataParallel
18
+ from torch.nn.parallel.parallel_apply import get_a_var
19
+ from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
20
+
21
+ torch_ver = torch.__version__[:3]
22
+
23
+ __all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
24
+ 'patch_replication_callback']
25
+
26
+ def allreduce(*inputs):
27
+ """Cross GPU all reduce autograd operation for calculate mean and
28
+ variance in SyncBN.
29
+ """
30
+ return AllReduce.apply(*inputs)
31
+
32
+ class AllReduce(Function):
33
+ @staticmethod
34
+ def forward(ctx, num_inputs, *inputs):
35
+ ctx.num_inputs = num_inputs
36
+ ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
37
+ inputs = [inputs[i:i + num_inputs]
38
+ for i in range(0, len(inputs), num_inputs)]
39
+ # sort before reduce sum
40
+ inputs = sorted(inputs, key=lambda i: i[0].get_device())
41
+ results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
42
+ outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
43
+ return tuple([t for tensors in outputs for t in tensors])
44
+
45
+ @staticmethod
46
+ def backward(ctx, *inputs):
47
+ inputs = [i.data for i in inputs]
48
+ inputs = [inputs[i:i + ctx.num_inputs]
49
+ for i in range(0, len(inputs), ctx.num_inputs)]
50
+ results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
51
+ outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
52
+ return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])
53
+
54
+
55
+ class Reduce(Function):
56
+ @staticmethod
57
+ def forward(ctx, *inputs):
58
+ ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
59
+ inputs = sorted(inputs, key=lambda i: i.get_device())
60
+ return comm.reduce_add(inputs)
61
+
62
+ @staticmethod
63
+ def backward(ctx, gradOutput):
64
+ return Broadcast.apply(ctx.target_gpus, gradOutput)
65
+
66
+
67
+ class DataParallelModel(DataParallel):
68
+ """Implements data parallelism at the module level.
69
+
70
+ This container parallelizes the application of the given module by
71
+ splitting the input across the specified devices by chunking in the
72
+ batch dimension.
73
+ In the forward pass, the module is replicated on each device,
74
+ and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module.
75
+ Note that the outputs are not gathered, please use compatible
76
+ :class:`encoding.parallel.DataParallelCriterion`.
77
+
78
+ The batch size should be larger than the number of GPUs used. It should
79
+ also be an integer multiple of the number of GPUs so that each chunk is
80
+ the same size (so that each GPU processes the same number of samples).
81
+
82
+ Args:
83
+ module: module to be parallelized
84
+ device_ids: CUDA devices (default: all devices)
85
+
86
+ Reference:
87
+ Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
88
+ Amit Agrawal. “Context Encoding for Semantic Segmentation.
89
+ *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
90
+
91
+ Example::
92
+
93
+ >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
94
+ >>> y = net(x)
95
+ """
96
+ def gather(self, outputs, output_device):
97
+ return outputs
98
+
99
+ def replicate(self, module, device_ids):
100
+ modules = super(DataParallelModel, self).replicate(module, device_ids)
101
+ execute_replication_callbacks(modules)
102
+ return modules
103
+
104
+
105
+ class DataParallelCriterion(DataParallel):
106
+ """
107
+ Calculate loss in multiple-GPUs, which balance the memory usage for
108
+ Semantic Segmentation.
109
+
110
+ The targets are splitted across the specified devices by chunking in
111
+ the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.
112
+
113
+ Reference:
114
+ Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
115
+ Amit Agrawal. “Context Encoding for Semantic Segmentation.
116
+ *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
117
+
118
+ Example::
119
+
120
+ >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
121
+ >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
122
+ >>> y = net(x)
123
+ >>> loss = criterion(y, target)
124
+ """
125
+ def forward(self, inputs, *targets, **kwargs):
126
+ # input should be already scatterd
127
+ # scattering the targets instead
128
+ if not self.device_ids:
129
+ return self.module(inputs, *targets, **kwargs)
130
+ targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
131
+ if len(self.device_ids) == 1:
132
+ return self.module(inputs, *targets[0], **kwargs[0])
133
+ replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
134
+ outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
135
+ return Reduce.apply(*outputs) / len(outputs)
136
+ #return self.gather(outputs, self.output_device).mean()
137
+
138
+
139
+ def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
140
+ assert len(modules) == len(inputs)
141
+ assert len(targets) == len(inputs)
142
+ if kwargs_tup:
143
+ assert len(modules) == len(kwargs_tup)
144
+ else:
145
+ kwargs_tup = ({},) * len(modules)
146
+ if devices is not None:
147
+ assert len(modules) == len(devices)
148
+ else:
149
+ devices = [None] * len(modules)
150
+
151
+ lock = threading.Lock()
152
+ results = {}
153
+ if torch_ver != "0.3":
154
+ grad_enabled = torch.is_grad_enabled()
155
+
156
+ def _worker(i, module, input, target, kwargs, device=None):
157
+ if torch_ver != "0.3":
158
+ torch.set_grad_enabled(grad_enabled)
159
+ if device is None:
160
+ device = get_a_var(input).get_device()
161
+ try:
162
+ with torch.cuda.device(device):
163
+ # this also avoids accidental slicing of `input` if it is a Tensor
164
+ if not isinstance(input, (list, tuple)):
165
+ input = (input,)
166
+ if type(input) != type(target):
167
+ if isinstance(target, tuple):
168
+ input = tuple(input)
169
+ elif isinstance(target, list):
170
+ input = list(input)
171
+ else:
172
+ raise Exception("Types problem")
173
+
174
+ output = module(*(input + target), **kwargs)
175
+ with lock:
176
+ results[i] = output
177
+ except Exception as e:
178
+ with lock:
179
+ results[i] = e
180
+
181
+ if len(modules) > 1:
182
+ threads = [threading.Thread(target=_worker,
183
+ args=(i, module, input, target,
184
+ kwargs, device),)
185
+ for i, (module, input, target, kwargs, device) in
186
+ enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
187
+
188
+ for thread in threads:
189
+ thread.start()
190
+ for thread in threads:
191
+ thread.join()
192
+ else:
193
+ _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
194
+
195
+ outputs = []
196
+ for i in range(len(inputs)):
197
+ output = results[i]
198
+ if isinstance(output, Exception):
199
+ raise output
200
+ outputs.append(output)
201
+ return outputs
202
+
203
+
204
+ ###########################################################################
205
+ # Adapted from Synchronized-BatchNorm-PyTorch.
206
+ # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
207
+ #
208
+ class CallbackContext(object):
209
+ pass
210
+
211
+
212
+ def execute_replication_callbacks(modules):
213
+ """
214
+ Execute an replication callback `__data_parallel_replicate__` on each module created
215
+ by original replication.
216
+
217
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
218
+
219
+ Note that, as all modules are isomorphism, we assign each sub-module with a context
220
+ (shared among multiple copies of this module on different devices).
221
+ Through this context, different copies can share some information.
222
+
223
+ We guarantee that the callback on the master copy (the first copy) will be called ahead
224
+ of calling the callback of any slave copies.
225
+ """
226
+ master_copy = modules[0]
227
+ nr_modules = len(list(master_copy.modules()))
228
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
229
+
230
+ for i, module in enumerate(modules):
231
+ for j, m in enumerate(module.modules()):
232
+ if hasattr(m, '__data_parallel_replicate__'):
233
+ m.__data_parallel_replicate__(ctxs[j], i)
234
+
235
+
236
+ def patch_replication_callback(data_parallel):
237
+ """
238
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
239
+ Useful when you have customized `DataParallel` implementation.
240
+
241
+ Examples:
242
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
243
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
244
+ > patch_replication_callback(sync_bn)
245
+ # this is equivalent to
246
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
247
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
248
+ """
249
+
250
+ assert isinstance(data_parallel, DataParallel)
251
+
252
+ old_replicate = data_parallel.replicate
253
+
254
+ @functools.wraps(old_replicate)
255
+ def new_replicate(module, device_ids):
256
+ modules = old_replicate(module, device_ids)
257
+ execute_replication_callbacks(modules)
258
+ return modules
259
+
260
+ data_parallel.replicate = new_replicate
SegmentationTest/utils/render.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.cm
3
+ import skimage.io
4
+ import skimage.feature
5
+ import skimage.filters
6
+
7
+
8
+ def vec2im(V, shape=()):
9
+ '''
10
+ Transform an array V into a specified shape - or if no shape is given assume a square output format.
11
+
12
+ Parameters
13
+ ----------
14
+
15
+ V : numpy.ndarray
16
+ an array either representing a matrix or vector to be reshaped into an two-dimensional image
17
+
18
+ shape : tuple or list
19
+ optional. containing the shape information for the output array if not given, the output is assumed to be square
20
+
21
+ Returns
22
+ -------
23
+
24
+ W : numpy.ndarray
25
+ with W.shape = shape or W.shape = [np.sqrt(V.size)]*2
26
+
27
+ '''
28
+
29
+ if len(shape) < 2:
30
+ shape = [np.sqrt(V.size)] * 2
31
+ shape = map(int, shape)
32
+ return np.reshape(V, shape)
33
+
34
+
35
+ def enlarge_image(img, scaling=3):
36
+ '''
37
+ Enlarges a given input matrix by replicating each pixel value scaling times in horizontal and vertical direction.
38
+
39
+ Parameters
40
+ ----------
41
+
42
+ img : numpy.ndarray
43
+ array of shape [H x W] OR [H x W x D]
44
+
45
+ scaling : int
46
+ positive integer value > 0
47
+
48
+ Returns
49
+ -------
50
+
51
+ out : numpy.ndarray
52
+ two-dimensional array of shape [scaling*H x scaling*W]
53
+ OR
54
+ three-dimensional array of shape [scaling*H x scaling*W x D]
55
+ depending on the dimensionality of the input
56
+ '''
57
+
58
+ if scaling < 1 or not isinstance(scaling, int):
59
+ print('scaling factor needs to be an int >= 1')
60
+
61
+ if len(img.shape) == 2:
62
+ H, W = img.shape
63
+
64
+ out = np.zeros((scaling * H, scaling * W))
65
+ for h in range(H):
66
+ fh = scaling * h
67
+ for w in range(W):
68
+ fw = scaling * w
69
+ out[fh:fh + scaling, fw:fw + scaling] = img[h, w]
70
+
71
+ elif len(img.shape) == 3:
72
+ H, W, D = img.shape
73
+
74
+ out = np.zeros((scaling * H, scaling * W, D))
75
+ for h in range(H):
76
+ fh = scaling * h
77
+ for w in range(W):
78
+ fw = scaling * w
79
+ out[fh:fh + scaling, fw:fw + scaling, :] = img[h, w, :]
80
+
81
+ return out
82
+
83
+
84
+ def repaint_corner_pixels(rgbimg, scaling=3):
85
+ '''
86
+ DEPRECATED/OBSOLETE.
87
+
88
+ Recolors the top left and bottom right pixel (groups) with the average rgb value of its three neighboring pixel (groups).
89
+ The recoloring visually masks the opposing pixel values which are a product of stabilizing the scaling.
90
+ Assumes those image ares will pretty much never show evidence.
91
+
92
+ Parameters
93
+ ----------
94
+
95
+ rgbimg : numpy.ndarray
96
+ array of shape [H x W x 3]
97
+
98
+ scaling : int
99
+ positive integer value > 0
100
+
101
+ Returns
102
+ -------
103
+
104
+ rgbimg : numpy.ndarray
105
+ three-dimensional array of shape [scaling*H x scaling*W x 3]
106
+ '''
107
+
108
+ # top left corner.
109
+ rgbimg[0:scaling, 0:scaling, :] = (rgbimg[0, scaling, :] + rgbimg[scaling, 0, :] + rgbimg[scaling, scaling,
110
+ :]) / 3.0
111
+ # bottom right corner
112
+ rgbimg[-scaling:, -scaling:, :] = (rgbimg[-1, -1 - scaling, :] + rgbimg[-1 - scaling, -1, :] + rgbimg[-1 - scaling,
113
+ -1 - scaling,
114
+ :]) / 3.0
115
+ return rgbimg
116
+
117
+
118
+ def digit_to_rgb(X, scaling=3, shape=(), cmap='binary'):
119
+ '''
120
+ Takes as input an intensity array and produces a rgb image due to some color map
121
+
122
+ Parameters
123
+ ----------
124
+
125
+ X : numpy.ndarray
126
+ intensity matrix as array of shape [M x N]
127
+
128
+ scaling : int
129
+ optional. positive integer value > 0
130
+
131
+ shape: tuple or list of its , length = 2
132
+ optional. if not given, X is reshaped to be square.
133
+
134
+ cmap : str
135
+ name of color map of choice. default is 'binary'
136
+
137
+ Returns
138
+ -------
139
+
140
+ image : numpy.ndarray
141
+ three-dimensional array of shape [scaling*H x scaling*W x 3] , where H*W == M*N
142
+ '''
143
+
144
+ # create color map object from name string
145
+ cmap = eval('matplotlib.cm.{}'.format(cmap))
146
+
147
+ image = enlarge_image(vec2im(X, shape), scaling) # enlarge
148
+ image = cmap(image.flatten())[..., 0:3].reshape([image.shape[0], image.shape[1], 3]) # colorize, reshape
149
+
150
+ return image
151
+
152
+
153
+ def hm_to_rgb(R, X=None, scaling=3, shape=(), sigma=2, cmap='bwr', normalize=True):
154
+ '''
155
+ Takes as input an intensity array and produces a rgb image for the represented heatmap.
156
+ optionally draws the outline of another input on top of it.
157
+
158
+ Parameters
159
+ ----------
160
+
161
+ R : numpy.ndarray
162
+ the heatmap to be visualized, shaped [M x N]
163
+
164
+ X : numpy.ndarray
165
+ optional. some input, usually the data point for which the heatmap R is for, which shall serve
166
+ as a template for a black outline to be drawn on top of the image
167
+ shaped [M x N]
168
+
169
+ scaling: int
170
+ factor, on how to enlarge the heatmap (to control resolution and as a inverse way to control outline thickness)
171
+ after reshaping it using shape.
172
+
173
+ shape: tuple or list, length = 2
174
+ optional. if not given, X is reshaped to be square.
175
+
176
+ sigma : double
177
+ optional. sigma-parameter for the canny algorithm used for edge detection. the found edges are drawn as outlines.
178
+
179
+ cmap : str
180
+ optional. color map of choice
181
+
182
+ normalize : bool
183
+ optional. whether to normalize the heatmap to [-1 1] prior to colorization or not.
184
+
185
+ Returns
186
+ -------
187
+
188
+ rgbimg : numpy.ndarray
189
+ three-dimensional array of shape [scaling*H x scaling*W x 3] , where H*W == M*N
190
+ '''
191
+
192
+ # create color map object from name string
193
+ cmap = eval('matplotlib.cm.{}'.format(cmap))
194
+
195
+ if normalize:
196
+ R = R / np.max(np.abs(R)) # normalize to [-1,1] wrt to max relevance magnitude
197
+ R = (R + 1.) / 2. # shift/normalize to [0,1] for color mapping
198
+
199
+ R = enlarge_image(R, scaling)
200
+ rgb = cmap(R.flatten())[..., 0:3].reshape([R.shape[0], R.shape[1], 3])
201
+ # rgb = repaint_corner_pixels(rgb, scaling) #obsolete due to directly calling the color map with [0,1]-normalized inputs
202
+
203
+ if not X is None: # compute the outline of the input
204
+ # X = enlarge_image(vec2im(X,shape), scaling)
205
+ xdims = X.shape
206
+ Rdims = R.shape
207
+
208
+ # if not np.all(xdims == Rdims):
209
+ # print 'transformed heatmap and data dimension mismatch. data dimensions differ?'
210
+ # print 'R.shape = ',Rdims, 'X.shape = ', xdims
211
+ # print 'skipping drawing of outline\n'
212
+ # else:
213
+ # #edges = skimage.filters.canny(X, sigma=sigma)
214
+ # edges = skimage.feature.canny(X, sigma=sigma)
215
+ # edges = np.invert(np.dstack([edges]*3))*1.0
216
+ # rgb *= edges # set outline pixels to black color
217
+
218
+ return rgb
219
+
220
+
221
+ def save_image(rgb_images, path, gap=2):
222
+ '''
223
+ Takes as input a list of rgb images, places them next to each other with a gap and writes out the result.
224
+
225
+ Parameters
226
+ ----------
227
+
228
+ rgb_images : list , tuple, collection. such stuff
229
+ each item in the collection is expected to be an rgb image of dimensions [H x _ x 3]
230
+ where the width is variable
231
+
232
+ path : str
233
+ the output path of the assembled image
234
+
235
+ gap : int
236
+ optional. sets the width of a black area of pixels realized as an image shaped [H x gap x 3] in between the input images
237
+
238
+ Returns
239
+ -------
240
+
241
+ image : numpy.ndarray
242
+ the assembled image as written out to path
243
+ '''
244
+
245
+ sz = []
246
+ image = []
247
+ for i in range(len(rgb_images)):
248
+ if not sz:
249
+ sz = rgb_images[i].shape
250
+ image = rgb_images[i]
251
+ gap = np.zeros((sz[0], gap, sz[2]))
252
+ continue
253
+ if not sz[0] == rgb_images[i].shape[0] and sz[1] == rgb_images[i].shape[2]:
254
+ print('image', i, 'differs in size. unable to perform horizontal alignment')
255
+ print('expected: Hx_xD = {0}x_x{1}'.format(sz[0], sz[1]))
256
+ print('got : Hx_xD = {0}x_x{1}'.format(rgb_images[i].shape[0], rgb_images[i].shape[1]))
257
+ print('skipping image\n')
258
+ else:
259
+ image = np.hstack((image, gap, rgb_images[i]))
260
+
261
+ image *= 255
262
+ image = image.astype(np.uint8)
263
+
264
+ print('saving image to ', path)
265
+ skimage.io.imsave(path, image)
266
+ return image
SegmentationTest/utils/saver.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ import glob
5
+
6
+
7
+ class Saver(object):
8
+
9
+ def __init__(self, args):
10
+ self.args = args
11
+ self.directory = os.path.join('run', args.train_dataset, args.checkname)
12
+ self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*')))
13
+ run_id = int(self.runs[-1].split('_')[-1]) + 1 if self.runs else 0
14
+
15
+ self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)))
16
+ if not os.path.exists(self.experiment_dir):
17
+ os.makedirs(self.experiment_dir)
18
+
19
+ def save_checkpoint(self, state, filename='checkpoint.pth.tar'):
20
+ """Saves checkpoint to disk"""
21
+ filename = os.path.join(self.experiment_dir, filename)
22
+ torch.save(state, filename)
23
+
24
+ def save_experiment_config(self):
25
+ logfile = os.path.join(self.experiment_dir, 'parameters.txt')
26
+ log_file = open(logfile, 'w')
27
+ p = OrderedDict()
28
+ p['train_dataset'] = self.args.train_dataset
29
+ p['lr'] = self.args.lr
30
+ p['epoch'] = self.args.epochs
31
+
32
+ for key, val in p.items():
33
+ log_file.write(key + ':' + str(val) + '\n')
34
+ log_file.close()
SegmentationTest/utils/summaries.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from torch.utils.tensorboard import SummaryWriter
3
+
4
+
5
+ class TensorboardSummary(object):
6
+ def __init__(self, directory):
7
+ self.directory = directory
8
+ self.writer = SummaryWriter(log_dir=os.path.join(self.directory))
9
+
10
+ def add_scalar(self, *args):
11
+ self.writer.add_scalar(*args)
ViT/ViT.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Vision Transformer (ViT) in PyTorch
2
+ Hacked together by / Copyright 2020 Ross Wightman
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ from functools import partial
7
+ from einops import rearrange
8
+
9
+ from ViT.helpers import load_pretrained
10
+ from ViT.weight_init import trunc_normal_
11
+ from ViT.layer_helpers import to_2tuple
12
+
13
+
14
+ def _cfg(url='', **kwargs):
15
+ return {
16
+ 'url': url,
17
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
18
+ 'crop_pct': .9, 'interpolation': 'bicubic',
19
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
20
+ **kwargs
21
+ }
22
+
23
+
24
+ default_cfgs = {
25
+ # patch models
26
+ 'vit_small_patch16_224': _cfg(
27
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
28
+ ),
29
+ 'vit_base_patch16_224': _cfg(
30
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
31
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
32
+ ),
33
+ 'vit_large_patch16_224': _cfg(
34
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
35
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
36
+
37
+ # deit models (FB weights)
38
+ 'deit_tiny_patch16_224': _cfg(
39
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
40
+ 'deit_small_patch16_224': _cfg(
41
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
42
+ 'deit_base_patch16_224': _cfg(
43
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', ),
44
+ 'deit_base_patch16_384': _cfg(
45
+ url='', # no weights yet
46
+ input_size=(3, 384, 384)),
47
+ }
48
+
49
+ class Mlp(nn.Module):
50
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
51
+ super().__init__()
52
+ out_features = out_features or in_features
53
+ hidden_features = hidden_features or in_features
54
+ self.fc1 = nn.Linear(in_features, hidden_features)
55
+ self.act = act_layer()
56
+ self.fc2 = nn.Linear(hidden_features, out_features)
57
+ self.drop = nn.Dropout(drop)
58
+
59
+ def forward(self, x):
60
+ x = self.fc1(x)
61
+ x = self.act(x)
62
+ x = self.drop(x)
63
+ x = self.fc2(x)
64
+ x = self.drop(x)
65
+ return x
66
+
67
+
68
+ class Attention(nn.Module):
69
+ def __init__(self, dim, num_heads=8, qkv_bias=False,attn_drop=0., proj_drop=0.):
70
+ super().__init__()
71
+ self.num_heads = num_heads
72
+ head_dim = dim // num_heads
73
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
74
+ self.scale = head_dim ** -0.5
75
+
76
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
77
+ self.attn_drop = nn.Dropout(attn_drop)
78
+ self.proj = nn.Linear(dim, dim)
79
+ self.proj_drop = nn.Dropout(proj_drop)
80
+
81
+ self.attn_gradients = None
82
+ self.attention_map = None
83
+
84
+ def save_attn_gradients(self, attn_gradients):
85
+ self.attn_gradients = attn_gradients
86
+
87
+ def get_attn_gradients(self):
88
+ return self.attn_gradients
89
+
90
+ def save_attention_map(self, attention_map):
91
+ self.attention_map = attention_map
92
+
93
+ def get_attention_map(self):
94
+ return self.attention_map
95
+
96
+ def forward(self, x, register_hook=False, return_attentions=False):
97
+ b, n, _, h = *x.shape, self.num_heads
98
+
99
+ qkv = self.qkv(x)
100
+ q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)
101
+
102
+ dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
103
+
104
+ attn = dots.softmax(dim=-1)
105
+ attn = self.attn_drop(attn)
106
+
107
+ out = torch.einsum('bhij,bhjd->bhid', attn, v)
108
+
109
+ self.save_attention_map(attn)
110
+ if register_hook:
111
+ attn.register_hook(self.save_attn_gradients)
112
+
113
+ out = rearrange(out, 'b h n d -> b n (h d)')
114
+ out = self.proj(out)
115
+ out = self.proj_drop(out)
116
+ if not return_attentions:
117
+ return out
118
+ else:
119
+ return out, attn
120
+
121
+
122
+ class Block(nn.Module):
123
+
124
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
125
+ super().__init__()
126
+ self.norm1 = norm_layer(dim)
127
+ self.attn = Attention(
128
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
129
+ self.norm2 = norm_layer(dim)
130
+ mlp_hidden_dim = int(dim * mlp_ratio)
131
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
132
+
133
+ def forward(self, x, register_hook=False, return_attentions=False):
134
+ if not return_attentions:
135
+ x = x + self.attn(self.norm1(x), register_hook=register_hook)
136
+ else:
137
+ attn_res, attn = self.attn(self.norm1(x), register_hook=register_hook, return_attentions=True)
138
+ x = x + attn_res
139
+ x = x + self.mlp(self.norm2(x))
140
+ if not return_attentions:
141
+ return x
142
+ else:
143
+ return x, attn
144
+
145
+
146
+ class PatchEmbed(nn.Module):
147
+ """ Image to Patch Embedding
148
+ """
149
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
150
+ super().__init__()
151
+ img_size = to_2tuple(img_size)
152
+ patch_size = to_2tuple(patch_size)
153
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
154
+ self.img_size = img_size
155
+ self.patch_size = patch_size
156
+ self.num_patches = num_patches
157
+
158
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
159
+
160
+ def forward(self, x):
161
+ B, C, H, W = x.shape
162
+ # FIXME look at relaxing size constraints
163
+ assert H == self.img_size[0] and W == self.img_size[1], \
164
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
165
+ x = self.proj(x).flatten(2).transpose(1, 2)
166
+ return x
167
+
168
+ class VisionTransformer(nn.Module):
169
+ """ Vision Transformer
170
+ """
171
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
172
+ num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., norm_layer=nn.LayerNorm):
173
+ super().__init__()
174
+ self.num_classes = num_classes
175
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
176
+ self.patch_embed = PatchEmbed(
177
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
178
+ num_patches = self.patch_embed.num_patches
179
+
180
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
181
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
182
+ self.pos_drop = nn.Dropout(p=drop_rate)
183
+
184
+ self.blocks = nn.ModuleList([
185
+ Block(
186
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
187
+ drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer)
188
+ for i in range(depth)])
189
+ self.norm = norm_layer(embed_dim)
190
+
191
+ # Classifier head
192
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
193
+
194
+ trunc_normal_(self.pos_embed, std=.02)
195
+ trunc_normal_(self.cls_token, std=.02)
196
+ self.apply(self._init_weights)
197
+
198
+ def _init_weights(self, m):
199
+ if isinstance(m, nn.Linear):
200
+ trunc_normal_(m.weight, std=.02)
201
+ if isinstance(m, nn.Linear) and m.bias is not None:
202
+ nn.init.constant_(m.bias, 0)
203
+ elif isinstance(m, nn.LayerNorm):
204
+ nn.init.constant_(m.bias, 0)
205
+ nn.init.constant_(m.weight, 1.0)
206
+
207
+ @torch.jit.ignore
208
+ def no_weight_decay(self):
209
+ return {'pos_embed', 'cls_token'}
210
+
211
+ def forward(self, x, register_hook=False, return_attentions=False):
212
+ if return_attentions:
213
+ attentions = []
214
+
215
+ B = x.shape[0]
216
+ x = self.patch_embed(x)
217
+
218
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
219
+ x = torch.cat((cls_tokens, x), dim=1)
220
+ x = x + self.pos_embed
221
+ x = self.pos_drop(x)
222
+
223
+ for blk in self.blocks:
224
+ if not return_attentions:
225
+ x = blk(x, register_hook=register_hook)
226
+ else:
227
+ x, attn = blk(x, register_hook=register_hook, return_attentions=True)
228
+ attentions.append(attn)
229
+
230
+ x = self.norm(x)
231
+ x = x[:, 0]
232
+ x = self.head(x)
233
+
234
+ if not return_attentions:
235
+ return x
236
+ else:
237
+ return x, torch.cat(attentions).unsqueeze(0)
238
+
239
+
240
+ def _conv_filter(state_dict, patch_size=16):
241
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
242
+ out_dict = {}
243
+ for k, v in state_dict.items():
244
+ if 'patch_embed.proj.weight' in k:
245
+ v = v.reshape((v.shape[0], 3, patch_size, patch_size))
246
+ out_dict[k] = v
247
+ return out_dict
248
+
249
+
250
+ def vit_base_patch16_224(pretrained=False, **kwargs):
251
+ model = VisionTransformer(
252
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
253
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
254
+ model.default_cfg = default_cfgs['vit_base_patch16_224']
255
+ if pretrained:
256
+ load_pretrained(
257
+ model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
258
+ return model
259
+
260
+
261
+ def vit_base_finetuned_patch16_224(pretrained=False, **kwargs):
262
+ model = VisionTransformer(
263
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
264
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
265
+ model.default_cfg = default_cfgs['vit_base_finetuned_patch16_224']
266
+ if pretrained:
267
+ load_pretrained(
268
+ model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
269
+ return model
270
+
271
+ def vit_large_patch16_224(pretrained=False, **kwargs):
272
+ model = VisionTransformer(
273
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
274
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
275
+ model.default_cfg = default_cfgs['vit_large_patch16_224']
276
+ if pretrained:
277
+ load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
278
+ return model
279
+
280
+ def deit_tiny_patch16_224(pretrained=False, **kwargs):
281
+ model = VisionTransformer(
282
+ patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
283
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
284
+ model.default_cfg = default_cfgs['deit_tiny_patch16_224']
285
+ if pretrained:
286
+ load_pretrained(
287
+ model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
288
+ return model
289
+
290
+ def deit_small_patch16_224(pretrained=False, **kwargs):
291
+ model = VisionTransformer(
292
+ patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
293
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
294
+ model.default_cfg = default_cfgs['deit_small_patch16_224']
295
+ if pretrained:
296
+ load_pretrained(
297
+ model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
298
+ return model
299
+
300
+ def deit_base_patch16_224(pretrained=False, **kwargs):
301
+ model = VisionTransformer(
302
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
303
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
304
+ model.default_cfg = default_cfgs['deit_base_patch16_224']
305
+ if pretrained:
306
+ load_pretrained(
307
+ model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
308
+ return model
ViT_new.py → ViT/ViT_new.py RENAMED
File without changes
ViT/__init__.py ADDED
File without changes
ViT/explainer.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+
5
+ # rule 5 from paper
6
+ def avg_heads(cam, grad):
7
+ cam = cam.reshape(-1, cam.shape[-3], cam.shape[-2], cam.shape[-1])
8
+ grad = grad.reshape(-1, cam.shape[-3], grad.shape[-2], grad.shape[-1])
9
+ cam = grad * cam
10
+ cam = cam.clamp(min=0).mean(dim=1)
11
+ return cam
12
+
13
+ # rule 6 from paper
14
+ def apply_self_attention_rules(R_ss, cam_ss):
15
+ R_ss_addition = torch.matmul(cam_ss, R_ss)
16
+ return R_ss_addition
17
+
18
+ def upscale_relevance(relevance):
19
+ relevance = relevance.reshape(-1, 1, 14, 14)
20
+ relevance = torch.nn.functional.interpolate(relevance, scale_factor=16, mode='bilinear')
21
+
22
+ # normalize between 0 and 1
23
+ relevance = relevance.reshape(relevance.shape[0], -1)
24
+ min = relevance.min(1, keepdim=True)[0]
25
+ max = relevance.max(1, keepdim=True)[0]
26
+ relevance = (relevance - min) / (max - min)
27
+
28
+ relevance = relevance.reshape(-1, 1, 224, 224)
29
+ return relevance
30
+
31
+ def generate_relevance(model, input, index=None):
32
+ # a batch of samples
33
+ batch_size = input.shape[0]
34
+ output = model(input, register_hook=True)
35
+ if index == None:
36
+ index = np.argmax(output.cpu().data.numpy(), axis=-1)
37
+ index = torch.tensor(index)
38
+
39
+ one_hot = np.zeros((batch_size, output.shape[-1]), dtype=np.float32)
40
+ one_hot[torch.arange(batch_size), index.data.cpu().numpy()] = 1
41
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
42
+ one_hot = torch.sum(one_hot.to(input.device) * output)
43
+ model.zero_grad()
44
+
45
+ num_tokens = model.blocks[0].attn.get_attention_map().shape[-1]
46
+ R = torch.eye(num_tokens, num_tokens).cuda()
47
+ R = R.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
48
+ for i, blk in enumerate(model.blocks):
49
+ grad = torch.autograd.grad(one_hot, [blk.attn.attention_map], retain_graph=True)[0]
50
+ cam = blk.attn.get_attention_map()
51
+ cam = avg_heads(cam, grad)
52
+ R = R + apply_self_attention_rules(R, cam)
53
+ relevance = R[:, 0, 1:]
54
+ return upscale_relevance(relevance)
55
+
56
+ # create heatmap from mask on image
57
+ def show_cam_on_image(img, mask):
58
+ heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
59
+ heatmap = np.float32(heatmap) / 255
60
+ cam = heatmap + np.float32(img)
61
+ cam = cam / np.max(cam)
62
+ return cam
63
+
64
+
65
+ def get_image_with_relevance(image, relevance):
66
+ image = image.permute(1, 2, 0)
67
+ relevance = relevance.permute(1, 2, 0)
68
+ image = (image - image.min()) / (image.max() - image.min())
69
+ image = 255 * image
70
+ vis = image * relevance
71
+ return vis.data.cpu().numpy()
ViT/helpers.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Model creation / weight loading / state_dict helpers
2
+
3
+ Hacked together by / Copyright 2020 Ross Wightman
4
+ """
5
+ import logging
6
+ import os
7
+ import math
8
+ from collections import OrderedDict
9
+ from copy import deepcopy
10
+ from typing import Callable
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.utils.model_zoo as model_zoo
15
+
16
+ _logger = logging.getLogger(__name__)
17
+
18
+
19
+ def load_state_dict(checkpoint_path, use_ema=False):
20
+ if checkpoint_path and os.path.isfile(checkpoint_path):
21
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
22
+ state_dict_key = 'state_dict'
23
+ if isinstance(checkpoint, dict):
24
+ if use_ema and 'state_dict_ema' in checkpoint:
25
+ state_dict_key = 'state_dict_ema'
26
+ if state_dict_key and state_dict_key in checkpoint:
27
+ new_state_dict = OrderedDict()
28
+ for k, v in checkpoint[state_dict_key].items():
29
+ # strip `module.` prefix
30
+ name = k[7:] if k.startswith('module') else k
31
+ new_state_dict[name] = v
32
+ state_dict = new_state_dict
33
+ else:
34
+ state_dict = checkpoint
35
+ _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
36
+ return state_dict
37
+ else:
38
+ _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
39
+ raise FileNotFoundError()
40
+
41
+
42
+ def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
43
+ state_dict = load_state_dict(checkpoint_path, use_ema)
44
+ model.load_state_dict(state_dict, strict=strict)
45
+
46
+
47
+ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
48
+ resume_epoch = None
49
+ if os.path.isfile(checkpoint_path):
50
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
51
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
52
+ if log_info:
53
+ _logger.info('Restoring model state from checkpoint...')
54
+ new_state_dict = OrderedDict()
55
+ for k, v in checkpoint['state_dict'].items():
56
+ name = k[7:] if k.startswith('module') else k
57
+ new_state_dict[name] = v
58
+ model.load_state_dict(new_state_dict)
59
+
60
+ if optimizer is not None and 'optimizer' in checkpoint:
61
+ if log_info:
62
+ _logger.info('Restoring optimizer state from checkpoint...')
63
+ optimizer.load_state_dict(checkpoint['optimizer'])
64
+
65
+ if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
66
+ if log_info:
67
+ _logger.info('Restoring AMP loss scaler state from checkpoint...')
68
+ loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
69
+
70
+ if 'epoch' in checkpoint:
71
+ resume_epoch = checkpoint['epoch']
72
+ if 'version' in checkpoint and checkpoint['version'] > 1:
73
+ resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
74
+
75
+ if log_info:
76
+ _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
77
+ else:
78
+ model.load_state_dict(checkpoint)
79
+ if log_info:
80
+ _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
81
+ return resume_epoch
82
+ else:
83
+ _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
84
+ raise FileNotFoundError()
85
+
86
+
87
+ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True):
88
+ if cfg is None:
89
+ cfg = getattr(model, 'default_cfg')
90
+ if cfg is None or 'url' not in cfg or not cfg['url']:
91
+ _logger.warning("Pretrained model URL is invalid, using random initialization.")
92
+ return
93
+
94
+ state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
95
+
96
+ if filter_fn is not None:
97
+ state_dict = filter_fn(state_dict)
98
+
99
+ if in_chans == 1:
100
+ conv1_name = cfg['first_conv']
101
+ _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
102
+ conv1_weight = state_dict[conv1_name + '.weight']
103
+ # Some weights are in torch.half, ensure it's float for sum on CPU
104
+ conv1_type = conv1_weight.dtype
105
+ conv1_weight = conv1_weight.float()
106
+ O, I, J, K = conv1_weight.shape
107
+ if I > 3:
108
+ assert conv1_weight.shape[1] % 3 == 0
109
+ # For models with space2depth stems
110
+ conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
111
+ conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
112
+ else:
113
+ conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
114
+ conv1_weight = conv1_weight.to(conv1_type)
115
+ state_dict[conv1_name + '.weight'] = conv1_weight
116
+ elif in_chans != 3:
117
+ conv1_name = cfg['first_conv']
118
+ conv1_weight = state_dict[conv1_name + '.weight']
119
+ conv1_type = conv1_weight.dtype
120
+ conv1_weight = conv1_weight.float()
121
+ O, I, J, K = conv1_weight.shape
122
+ if I != 3:
123
+ _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
124
+ del state_dict[conv1_name + '.weight']
125
+ strict = False
126
+ else:
127
+ # NOTE this strategy should be better than random init, but there could be other combinations of
128
+ # the original RGB input layer weights that'd work better for specific cases.
129
+ _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
130
+ repeat = int(math.ceil(in_chans / 3))
131
+ conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
132
+ conv1_weight *= (3 / float(in_chans))
133
+ conv1_weight = conv1_weight.to(conv1_type)
134
+ state_dict[conv1_name + '.weight'] = conv1_weight
135
+
136
+ classifier_name = cfg['classifier']
137
+ if num_classes == 1000 and cfg['num_classes'] == 1001:
138
+ # special case for imagenet trained models with extra background class in pretrained weights
139
+ classifier_weight = state_dict[classifier_name + '.weight']
140
+ state_dict[classifier_name + '.weight'] = classifier_weight[1:]
141
+ classifier_bias = state_dict[classifier_name + '.bias']
142
+ state_dict[classifier_name + '.bias'] = classifier_bias[1:]
143
+ elif num_classes != cfg['num_classes']:
144
+ # completely discard fully connected for all other differences between pretrained and created model
145
+ del state_dict[classifier_name + '.weight']
146
+ del state_dict[classifier_name + '.bias']
147
+ strict = False
148
+
149
+ model.load_state_dict(state_dict, strict=strict)
150
+
151
+
152
+ def extract_layer(model, layer):
153
+ layer = layer.split('.')
154
+ module = model
155
+ if hasattr(model, 'module') and layer[0] != 'module':
156
+ module = model.module
157
+ if not hasattr(model, 'module') and layer[0] == 'module':
158
+ layer = layer[1:]
159
+ for l in layer:
160
+ if hasattr(module, l):
161
+ if not l.isdigit():
162
+ module = getattr(module, l)
163
+ else:
164
+ module = module[int(l)]
165
+ else:
166
+ return module
167
+ return module
168
+
169
+
170
+ def set_layer(model, layer, val):
171
+ layer = layer.split('.')
172
+ module = model
173
+ if hasattr(model, 'module') and layer[0] != 'module':
174
+ module = model.module
175
+ lst_index = 0
176
+ module2 = module
177
+ for l in layer:
178
+ if hasattr(module2, l):
179
+ if not l.isdigit():
180
+ module2 = getattr(module2, l)
181
+ else:
182
+ module2 = module2[int(l)]
183
+ lst_index += 1
184
+ lst_index -= 1
185
+ for l in layer[:lst_index]:
186
+ if not l.isdigit():
187
+ module = getattr(module, l)
188
+ else:
189
+ module = module[int(l)]
190
+ l = layer[lst_index]
191
+ setattr(module, l, val)
192
+
193
+
194
+ def adapt_model_from_string(parent_module, model_string):
195
+ separator = '***'
196
+ state_dict = {}
197
+ lst_shape = model_string.split(separator)
198
+ for k in lst_shape:
199
+ k = k.split(':')
200
+ key = k[0]
201
+ shape = k[1][1:-1].split(',')
202
+ if shape[0] != '':
203
+ state_dict[key] = [int(i) for i in shape]
204
+
205
+ new_module = deepcopy(parent_module)
206
+ for n, m in parent_module.named_modules():
207
+ old_module = extract_layer(parent_module, n)
208
+ if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
209
+ if isinstance(old_module, Conv2dSame):
210
+ conv = Conv2dSame
211
+ else:
212
+ conv = nn.Conv2d
213
+ s = state_dict[n + '.weight']
214
+ in_channels = s[1]
215
+ out_channels = s[0]
216
+ g = 1
217
+ if old_module.groups > 1:
218
+ in_channels = out_channels
219
+ g = in_channels
220
+ new_conv = conv(
221
+ in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
222
+ bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
223
+ groups=g, stride=old_module.stride)
224
+ set_layer(new_module, n, new_conv)
225
+ if isinstance(old_module, nn.BatchNorm2d):
226
+ new_bn = nn.BatchNorm2d(
227
+ num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
228
+ affine=old_module.affine, track_running_stats=True)
229
+ set_layer(new_module, n, new_bn)
230
+ if isinstance(old_module, nn.Linear):
231
+ # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
232
+ num_features = state_dict[n + '.weight'][1]
233
+ new_fc = nn.Linear(
234
+ in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
235
+ set_layer(new_module, n, new_fc)
236
+ if hasattr(new_module, 'num_features'):
237
+ new_module.num_features = num_features
238
+ new_module.eval()
239
+ parent_module.eval()
240
+
241
+ return new_module
242
+
243
+
244
+ def adapt_model_from_file(parent_module, model_variant):
245
+ adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt')
246
+ with open(adapt_file, 'r') as f:
247
+ return adapt_model_from_string(parent_module, f.read().strip())
248
+
249
+
250
+ def build_model_with_cfg(
251
+ model_cls: Callable,
252
+ variant: str,
253
+ pretrained: bool,
254
+ default_cfg: dict,
255
+ model_cfg: dict = None,
256
+ feature_cfg: dict = None,
257
+ pretrained_strict: bool = True,
258
+ pretrained_filter_fn: Callable = None,
259
+ **kwargs):
260
+ pruned = kwargs.pop('pruned', False)
261
+ features = False
262
+ feature_cfg = feature_cfg or {}
263
+
264
+ if kwargs.pop('features_only', False):
265
+ features = True
266
+ feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
267
+ if 'out_indices' in kwargs:
268
+ feature_cfg['out_indices'] = kwargs.pop('out_indices')
269
+
270
+ model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
271
+ model.default_cfg = deepcopy(default_cfg)
272
+
273
+ if pruned:
274
+ model = adapt_model_from_file(model, variant)
275
+
276
+ if pretrained:
277
+ load_pretrained(
278
+ model,
279
+ num_classes=kwargs.get('num_classes', 0),
280
+ in_chans=kwargs.get('in_chans', 3),
281
+ filter_fn=pretrained_filter_fn, strict=pretrained_strict)
282
+
283
+ if features:
284
+ feature_cls = FeatureListNet
285
+ if 'feature_cls' in feature_cfg:
286
+ feature_cls = feature_cfg.pop('feature_cls')
287
+ if isinstance(feature_cls, str):
288
+ feature_cls = feature_cls.lower()
289
+ if 'hook' in feature_cls:
290
+ feature_cls = FeatureHookNet
291
+ else:
292
+ assert False, f'Unknown feature class {feature_cls}'
293
+ model = feature_cls(model, **feature_cfg)
294
+
295
+ return model
ViT/layer_helpers.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Layer/Module Helpers
2
+ Hacked together by / Copyright 2020 Ross Wightman
3
+ """
4
+ from itertools import repeat
5
+ import collections.abc
6
+
7
+
8
+ # From PyTorch internals
9
+ def _ntuple(n):
10
+ def parse(x):
11
+ if isinstance(x, collections.abc.Iterable):
12
+ return x
13
+ return tuple(repeat(x, n))
14
+ return parse
15
+
16
+
17
+ to_1tuple = _ntuple(1)
18
+ to_2tuple = _ntuple(2)
19
+ to_3tuple = _ntuple(3)
20
+ to_4tuple = _ntuple(4)
21
+ to_ntuple = _ntuple
ViT/weight_init.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import warnings
4
+
5
+
6
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
7
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
8
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
9
+ def norm_cdf(x):
10
+ # Computes standard normal cumulative distribution function
11
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
12
+
13
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
14
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
15
+ "The distribution of values may be incorrect.",
16
+ stacklevel=2)
17
+
18
+ with torch.no_grad():
19
+ # Values are generated by using a truncated uniform distribution and
20
+ # then using the inverse CDF for the normal distribution.
21
+ # Get upper and lower cdf values
22
+ l = norm_cdf((a - mean) / std)
23
+ u = norm_cdf((b - mean) / std)
24
+
25
+ # Uniformly fill tensor with values from [l, u], then translate to
26
+ # [2l-1, 2u-1].
27
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
28
+
29
+ # Use inverse cdf transform for normal distribution to get truncated
30
+ # standard normal
31
+ tensor.erfinv_()
32
+
33
+ # Transform to proper mean, std
34
+ tensor.mul_(std * math.sqrt(2.))
35
+ tensor.add_(mean)
36
+
37
+ # Clamp to ensure it's in the proper range
38
+ tensor.clamp_(min=a, max=b)
39
+ return tensor
40
+
41
+
42
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
43
+ # type: (Tensor, float, float, float, float) -> Tensor
44
+ r"""Fills the input Tensor with values drawn from a truncated
45
+ normal distribution. The values are effectively drawn from the
46
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
47
+ with values outside :math:`[a, b]` redrawn until they are within
48
+ the bounds. The method used for generating the random values works
49
+ best when :math:`a \leq \text{mean} \leq b`.
50
+ Args:
51
+ tensor: an n-dimensional `torch.Tensor`
52
+ mean: the mean of the normal distribution
53
+ std: the standard deviation of the normal distribution
54
+ a: the minimum cutoff value
55
+ b: the maximum cutoff value
56
+ Examples:
57
+ >>> w = torch.empty(3, 5)
58
+ >>> nn.init.trunc_normal_(w)
59
+ """
60
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
imagenet_ablation_gt.py ADDED
@@ -0,0 +1,590 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+ import shutil
5
+ import time
6
+ import warnings
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.parallel
11
+ import torch.backends.cudnn as cudnn
12
+ import torch.distributed as dist
13
+ import torch.optim
14
+ import torch.multiprocessing as mp
15
+ import torch.utils.data
16
+ import torch.utils.data.distributed
17
+ import torchvision.transforms as transforms
18
+ import torchvision.datasets as datasets
19
+ import torchvision.models as models
20
+ from segmentation_dataset import SegmentationDataset, VAL_PARTITION, TRAIN_PARTITION
21
+
22
+ # Uncomment the expected model below
23
+
24
+ # ViT
25
+ from ViT.ViT import vit_base_patch16_224 as vit
26
+ # from ViT.ViT import vit_large_patch16_224 as vit
27
+
28
+ # ViT-AugReg
29
+ # from ViT.ViT_new import vit_small_patch16_224 as vit
30
+ # from ViT.ViT_new import vit_base_patch16_224 as vit
31
+ # from ViT.ViT_new import vit_large_patch16_224 as vit
32
+
33
+ # DeiT
34
+ # from ViT.ViT import deit_base_patch16_224 as vit
35
+ # from ViT.ViT import deit_small_patch16_224 as vit
36
+
37
+ from ViT.explainer import generate_relevance, get_image_with_relevance
38
+ import torchvision
39
+ import cv2
40
+ from torch.utils.tensorboard import SummaryWriter
41
+ import json
42
+
43
+ model_names = sorted(name for name in models.__dict__
44
+ if name.islower() and not name.startswith("__")
45
+ and callable(models.__dict__[name]))
46
+ model_names.append("vit")
47
+
48
+ parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
49
+ parser.add_argument('--data', metavar='DATA',
50
+ help='path to dataset')
51
+ parser.add_argument('--seg_data', metavar='SEG_DATA',
52
+ help='path to segmentation dataset')
53
+ parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
54
+ choices=model_names,
55
+ help='model architecture: ' +
56
+ ' | '.join(model_names) +
57
+ ' (default: resnet18)')
58
+ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
59
+ help='number of data loading workers (default: 4)')
60
+ parser.add_argument('--epochs', default=150, type=int, metavar='N',
61
+ help='number of total epochs to run')
62
+ parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
63
+ help='manual epoch number (useful on restarts)')
64
+ parser.add_argument('-b', '--batch-size', default=8, type=int,
65
+ metavar='N',
66
+ help='mini-batch size (default: 256), this is the total '
67
+ 'batch size of all GPUs on the current node when '
68
+ 'using Data Parallel or Distributed Data Parallel')
69
+ parser.add_argument('--lr', '--learning-rate', default=3e-6, type=float,
70
+ metavar='LR', help='initial learning rate', dest='lr')
71
+ parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
72
+ help='momentum')
73
+ parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
74
+ metavar='W', help='weight decay (default: 1e-4)',
75
+ dest='weight_decay')
76
+ parser.add_argument('-p', '--print-freq', default=10, type=int,
77
+ metavar='N', help='print frequency (default: 10)')
78
+ parser.add_argument('--resume', default='', type=str, metavar='PATH',
79
+ help='path to latest checkpoint (default: none)')
80
+ parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
81
+ help='evaluate model on validation set')
82
+ parser.add_argument('--pretrained', dest='pretrained', action='store_true',
83
+ help='use pre-trained model')
84
+ parser.add_argument('--world-size', default=-1, type=int,
85
+ help='number of nodes for distributed training')
86
+ parser.add_argument('--rank', default=-1, type=int,
87
+ help='node rank for distributed training')
88
+ parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
89
+ help='url used to set up distributed training')
90
+ parser.add_argument('--dist-backend', default='nccl', type=str,
91
+ help='distributed backend')
92
+ parser.add_argument('--seed', default=None, type=int,
93
+ help='seed for initializing training. ')
94
+ parser.add_argument('--gpu', default=None, type=int,
95
+ help='GPU id to use.')
96
+ parser.add_argument('--save_interval', default=20, type=int,
97
+ help='interval to save segmentation results.')
98
+ parser.add_argument('--num_samples', default=3, type=int,
99
+ help='number of samples per class for training')
100
+ parser.add_argument('--multiprocessing-distributed', action='store_true',
101
+ help='Use multi-processing distributed training to launch '
102
+ 'N processes per node, which has N GPUs. This is the '
103
+ 'fastest way to use PyTorch for either single node or '
104
+ 'multi node data parallel training')
105
+ parser.add_argument('--lambda_seg', default=0.8, type=float,
106
+ help='influence of segmentation loss.')
107
+ parser.add_argument('--lambda_acc', default=0.2, type=float,
108
+ help='influence of accuracy loss.')
109
+ parser.add_argument('--experiment_folder', default=None, type=str,
110
+ help='path to folder to use for experiment.')
111
+ parser.add_argument('--dilation', default=0, type=float,
112
+ help='Use dilation on the segmentation maps.')
113
+ parser.add_argument('--lambda_background', default=2, type=float,
114
+ help='coefficient of loss for segmentation background.')
115
+ parser.add_argument('--lambda_foreground', default=0.3, type=float,
116
+ help='coefficient of loss for segmentation foreground.')
117
+ parser.add_argument('--num_classes', default=500, type=int,
118
+ help='coefficient of loss for segmentation foreground.')
119
+ parser.add_argument('--temperature', default=1, type=float,
120
+ help='temperature for softmax (mostly for DeiT).')
121
+
122
+ best_loss = float('inf')
123
+
124
+ def main():
125
+ args = parser.parse_args()
126
+
127
+ if args.experiment_folder is None:
128
+ args.experiment_folder = f'experiment/' \
129
+ f'lr_{args.lr}_seg_{args.lambda_seg}_acc_{args.lambda_acc}' \
130
+ f'_bckg_{args.lambda_background}_fgd_{args.lambda_foreground}'
131
+ if args.temperature != 1:
132
+ args.experiment_folder = args.experiment_folder + f'_tempera_{args.temperature}'
133
+ if args.batch_size != 8:
134
+ args.experiment_folder = args.experiment_folder + f'_bs_{args.batch_size}'
135
+ if args.num_classes != 500:
136
+ args.experiment_folder = args.experiment_folder + f'_num_classes_{args.num_classes}'
137
+ if args.num_samples != 3:
138
+ args.experiment_folder = args.experiment_folder + f'_num_samples_{args.num_samples}'
139
+ if args.epochs != 150:
140
+ args.experiment_folder = args.experiment_folder + f'_num_epochs_{args.epochs}'
141
+
142
+ if os.path.exists(args.experiment_folder):
143
+ raise Exception(f"Experiment path {args.experiment_folder} already exists!")
144
+ os.mkdir(args.experiment_folder)
145
+ os.mkdir(f'{args.experiment_folder}/train_samples')
146
+ os.mkdir(f'{args.experiment_folder}/val_samples')
147
+
148
+ with open(f'{args.experiment_folder}/commandline_args.txt', 'w') as f:
149
+ json.dump(args.__dict__, f, indent=2)
150
+
151
+ if args.seed is not None:
152
+ random.seed(args.seed)
153
+ torch.manual_seed(args.seed)
154
+ cudnn.deterministic = True
155
+ warnings.warn('You have chosen to seed training. '
156
+ 'This will turn on the CUDNN deterministic setting, '
157
+ 'which can slow down your training considerably! '
158
+ 'You may see unexpected behavior when restarting '
159
+ 'from checkpoints.')
160
+
161
+ if args.gpu is not None:
162
+ warnings.warn('You have chosen a specific GPU. This will completely '
163
+ 'disable data parallelism.')
164
+
165
+ if args.dist_url == "env://" and args.world_size == -1:
166
+ args.world_size = int(os.environ["WORLD_SIZE"])
167
+
168
+ args.distributed = args.world_size > 1 or args.multiprocessing_distributed
169
+
170
+ ngpus_per_node = torch.cuda.device_count()
171
+ if args.multiprocessing_distributed:
172
+ # Since we have ngpus_per_node processes per node, the total world_size
173
+ # needs to be adjusted accordingly
174
+ args.world_size = ngpus_per_node * args.world_size
175
+ # Use torch.multiprocessing.spawn to launch distributed processes: the
176
+ # main_worker process function
177
+ mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
178
+ else:
179
+ # Simply call main_worker function
180
+ main_worker(args.gpu, ngpus_per_node, args)
181
+
182
+
183
+ def main_worker(gpu, ngpus_per_node, args):
184
+ global best_loss
185
+ args.gpu = gpu
186
+
187
+ if args.gpu is not None:
188
+ print("Use GPU: {} for training".format(args.gpu))
189
+
190
+ if args.distributed:
191
+ if args.dist_url == "env://" and args.rank == -1:
192
+ args.rank = int(os.environ["RANK"])
193
+ if args.multiprocessing_distributed:
194
+ # For multiprocessing distributed training, rank needs to be the
195
+ # global rank among all the processes
196
+ args.rank = args.rank * ngpus_per_node + gpu
197
+ dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
198
+ world_size=args.world_size, rank=args.rank)
199
+ # create model
200
+ if args.pretrained:
201
+ print("=> using pre-trained model '{}'".format(args.arch))
202
+ model = models.__dict__[args.arch](pretrained=True)
203
+ else:
204
+ print("=> creating model '{}'".format(args.arch))
205
+ #model = models.__dict__[args.arch]()
206
+ model = vit(pretrained=True).cuda()
207
+ model.train()
208
+ print("done")
209
+
210
+ if not torch.cuda.is_available():
211
+ print('using CPU, this will be slow')
212
+ elif args.distributed:
213
+ # For multiprocessing distributed, DistributedDataParallel constructor
214
+ # should always set the single device scope, otherwise,
215
+ # DistributedDataParallel will use all available devices.
216
+ if args.gpu is not None:
217
+ torch.cuda.set_device(args.gpu)
218
+ model.cuda(args.gpu)
219
+ # When using a single GPU per process and per
220
+ # DistributedDataParallel, we need to divide the batch size
221
+ # ourselves based on the total number of GPUs we have
222
+ args.batch_size = int(args.batch_size / ngpus_per_node)
223
+ args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
224
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
225
+ else:
226
+ model.cuda()
227
+ # DistributedDataParallel will divide and allocate batch_size to all
228
+ # available GPUs if device_ids are not set
229
+ model = torch.nn.parallel.DistributedDataParallel(model)
230
+ elif args.gpu is not None:
231
+ torch.cuda.set_device(args.gpu)
232
+ model = model.cuda(args.gpu)
233
+ else:
234
+ # DataParallel will divide and allocate batch_size to all available GPUs
235
+ if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
236
+ model.features = torch.nn.DataParallel(model.features)
237
+ model.cuda()
238
+ else:
239
+ print("start")
240
+ model = torch.nn.DataParallel(model).cuda()
241
+
242
+ # define loss function (criterion) and optimizer
243
+ criterion = nn.CrossEntropyLoss().cuda(args.gpu)
244
+ optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=args.weight_decay)
245
+
246
+ # optionally resume from a checkpoint
247
+ if args.resume:
248
+ if os.path.isfile(args.resume):
249
+ print("=> loading checkpoint '{}'".format(args.resume))
250
+ if args.gpu is None:
251
+ checkpoint = torch.load(args.resume)
252
+ else:
253
+ # Map model to be loaded to specified single gpu.
254
+ loc = 'cuda:{}'.format(args.gpu)
255
+ checkpoint = torch.load(args.resume, map_location=loc)
256
+ args.start_epoch = checkpoint['epoch']
257
+ best_loss = checkpoint['best_loss']
258
+ if args.gpu is not None:
259
+ # best_loss may be from a checkpoint from a different GPU
260
+ best_loss = best_loss.to(args.gpu)
261
+ model.load_state_dict(checkpoint['state_dict'])
262
+ optimizer.load_state_dict(checkpoint['optimizer'])
263
+ print("=> loaded checkpoint '{}' (epoch {})"
264
+ .format(args.resume, checkpoint['epoch']))
265
+ else:
266
+ print("=> no checkpoint found at '{}'".format(args.resume))
267
+
268
+ cudnn.benchmark = True
269
+
270
+ train_dataset = SegmentationDataset(args.seg_data, args.data, partition=TRAIN_PARTITION, train_classes=args.num_classes,
271
+ num_samples=args.num_samples)
272
+
273
+ if args.distributed:
274
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
275
+ else:
276
+ train_sampler = None
277
+
278
+ train_loader = torch.utils.data.DataLoader(
279
+ train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
280
+ num_workers=args.workers, pin_memory=True, sampler=train_sampler)
281
+
282
+ val_dataset = SegmentationDataset(args.seg_data, args.data, partition=VAL_PARTITION, train_classes=args.num_classes,
283
+ num_samples=1)
284
+
285
+ val_loader = torch.utils.data.DataLoader(
286
+ val_dataset, batch_size=10, shuffle=False,
287
+ num_workers=args.workers, pin_memory=True)
288
+
289
+ if args.evaluate:
290
+ validate(val_loader, model, criterion, 0, args)
291
+ return
292
+
293
+ for epoch in range(args.start_epoch, args.epochs):
294
+ if args.distributed:
295
+ train_sampler.set_epoch(epoch)
296
+ adjust_learning_rate(optimizer, epoch, args)
297
+
298
+ log_dir = os.path.join(args.experiment_folder, 'logs')
299
+ logger = SummaryWriter(log_dir=log_dir)
300
+ args.logger = logger
301
+
302
+ # train for one epoch
303
+ train(train_loader, model, criterion, optimizer, epoch, args)
304
+
305
+ # evaluate on validation set
306
+ loss1 = validate(val_loader, model, criterion, epoch, args)
307
+
308
+ # remember best acc@1 and save checkpoint
309
+ is_best = loss1 <= best_loss
310
+ best_loss = min(loss1, best_loss)
311
+
312
+ if not args.multiprocessing_distributed or (args.multiprocessing_distributed
313
+ and args.rank % ngpus_per_node == 0):
314
+ save_checkpoint({
315
+ 'epoch': epoch + 1,
316
+ 'arch': args.arch,
317
+ 'state_dict': model.state_dict(),
318
+ 'best_loss': best_loss,
319
+ 'optimizer' : optimizer.state_dict(),
320
+ }, is_best, folder=args.experiment_folder)
321
+
322
+
323
+ def train(train_loader, model, criterion, optimizer, epoch, args):
324
+ mse_criterion = torch.nn.MSELoss(reduction='mean')
325
+
326
+ losses = AverageMeter('Loss', ':.4e')
327
+ top1 = AverageMeter('Acc@1', ':6.2f')
328
+ top5 = AverageMeter('Acc@5', ':6.2f')
329
+ orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
330
+ orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
331
+ progress = ProgressMeter(
332
+ len(train_loader),
333
+ [losses, top1, top5, orig_top1, orig_top5],
334
+ prefix="Epoch: [{}]".format(epoch))
335
+
336
+ orig_model = vit(pretrained=True).cuda()
337
+ orig_model.eval()
338
+
339
+ # switch to train mode
340
+ model.train()
341
+
342
+ for i, (seg_map, image_ten, class_name) in enumerate(train_loader):
343
+ if torch.cuda.is_available():
344
+ image_ten = image_ten.cuda(args.gpu, non_blocking=True)
345
+ seg_map = seg_map.cuda(args.gpu, non_blocking=True)
346
+ class_name = class_name.cuda(args.gpu, non_blocking=True)
347
+
348
+ # segmentation loss
349
+ relevance = generate_relevance(model, image_ten, index=class_name)
350
+
351
+ reverse_seg_map = seg_map.clone()
352
+ reverse_seg_map[reverse_seg_map == 1] = -1
353
+ reverse_seg_map[reverse_seg_map == 0] = 1
354
+ reverse_seg_map[reverse_seg_map == -1] = 0
355
+ background_loss = mse_criterion(relevance * reverse_seg_map, torch.zeros_like(relevance))
356
+ foreground_loss = mse_criterion(relevance * seg_map, seg_map)
357
+ segmentation_loss = args.lambda_background * background_loss
358
+ segmentation_loss += args.lambda_foreground * foreground_loss
359
+
360
+ # classification loss
361
+ output = model(image_ten)
362
+ with torch.no_grad():
363
+ output_orig = orig_model(image_ten)
364
+
365
+ _, pred = output.topk(1, 1, True, True)
366
+ pred = pred.flatten()
367
+
368
+ if args.temperature != 1:
369
+ output = output / args.temperature
370
+ classification_loss = criterion(output, class_name.flatten())
371
+
372
+ loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss
373
+
374
+ # debugging output
375
+ if i % args.save_interval == 0:
376
+ orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
377
+ for j in range(image_ten.shape[0]):
378
+ image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
379
+ new_vis = get_image_with_relevance(image_ten[j], relevance[j])
380
+ old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
381
+ gt = get_image_with_relevance(image_ten[j], seg_map[j])
382
+ h_img = cv2.hconcat([image, gt, old_vis, new_vis])
383
+ cv2.imwrite(f'{args.experiment_folder}/train_samples/res_{i}_{j}.jpg', h_img)
384
+
385
+ # measure accuracy and record loss
386
+ acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
387
+ losses.update(loss.item(), image_ten.size(0))
388
+ top1.update(acc1[0], image_ten.size(0))
389
+ top5.update(acc5[0], image_ten.size(0))
390
+
391
+ # metrics for original vit
392
+ acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
393
+ orig_top1.update(acc1_orig[0], image_ten.size(0))
394
+ orig_top5.update(acc5_orig[0], image_ten.size(0))
395
+
396
+ # compute gradient and do SGD step
397
+ optimizer.zero_grad()
398
+ loss.backward()
399
+ optimizer.step()
400
+
401
+ if i % args.print_freq == 0:
402
+ progress.display(i)
403
+ args.logger.add_scalar('{}/{}'.format('train', 'segmentation_loss'), segmentation_loss,
404
+ epoch*len(train_loader)+i)
405
+ args.logger.add_scalar('{}/{}'.format('train', 'classification_loss'), classification_loss,
406
+ epoch * len(train_loader) + i)
407
+ args.logger.add_scalar('{}/{}'.format('train', 'orig_top1'), acc1_orig,
408
+ epoch * len(train_loader) + i)
409
+ args.logger.add_scalar('{}/{}'.format('train', 'top1'), acc1,
410
+ epoch * len(train_loader) + i)
411
+ args.logger.add_scalar('{}/{}'.format('train', 'orig_top5'), acc5_orig,
412
+ epoch * len(train_loader) + i)
413
+ args.logger.add_scalar('{}/{}'.format('train', 'top5'), acc5,
414
+ epoch * len(train_loader) + i)
415
+ args.logger.add_scalar('{}/{}'.format('train', 'tot_loss'), loss,
416
+ epoch * len(train_loader) + i)
417
+
418
+
419
+ def validate(val_loader, model, criterion, epoch, args):
420
+ mse_criterion = torch.nn.MSELoss(reduction='mean')
421
+
422
+ losses = AverageMeter('Loss', ':.4e')
423
+ top1 = AverageMeter('Acc@1', ':6.2f')
424
+ top5 = AverageMeter('Acc@5', ':6.2f')
425
+ orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
426
+ orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
427
+ progress = ProgressMeter(
428
+ len(val_loader),
429
+ [losses, top1, top5, orig_top1, orig_top5],
430
+ prefix="Epoch: [{}]".format(val_loader))
431
+
432
+ # switch to evaluate mode
433
+ model.eval()
434
+
435
+ orig_model = vit(pretrained=True).cuda()
436
+ orig_model.eval()
437
+
438
+ with torch.no_grad():
439
+ for i, (seg_map, image_ten, class_name) in enumerate(val_loader):
440
+ if args.gpu is not None:
441
+ image_ten = image_ten.cuda(args.gpu, non_blocking=True)
442
+ if torch.cuda.is_available():
443
+ seg_map = seg_map.cuda(args.gpu, non_blocking=True)
444
+ class_name = class_name.cuda(args.gpu, non_blocking=True)
445
+
446
+ # segmentation loss
447
+ with torch.enable_grad():
448
+ relevance = generate_relevance(model, image_ten, index=class_name)
449
+
450
+ reverse_seg_map = seg_map.clone()
451
+ reverse_seg_map[reverse_seg_map == 1] = -1
452
+ reverse_seg_map[reverse_seg_map == 0] = 1
453
+ reverse_seg_map[reverse_seg_map == -1] = 0
454
+ background_loss = mse_criterion(relevance * reverse_seg_map, torch.zeros_like(relevance))
455
+ foreground_loss = mse_criterion(relevance * seg_map, seg_map)
456
+ segmentation_loss = args.lambda_background * background_loss
457
+ segmentation_loss += args.lambda_foreground * foreground_loss
458
+
459
+ # classification loss
460
+ with torch.no_grad():
461
+ output = model(image_ten)
462
+ output_orig = orig_model(image_ten)
463
+
464
+ _, pred = output.topk(1, 1, True, True)
465
+ pred = pred.flatten()
466
+ if args.temperature != 1:
467
+ output = output / args.temperature
468
+ classification_loss = criterion(output, class_name.flatten())
469
+
470
+ loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss
471
+
472
+ # save results
473
+ if i % args.save_interval == 0:
474
+ with torch.enable_grad():
475
+ orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
476
+ for j in range(image_ten.shape[0]):
477
+ image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
478
+ new_vis = get_image_with_relevance(image_ten[j], relevance[j])
479
+ old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
480
+ gt = get_image_with_relevance(image_ten[j], seg_map[j])
481
+ h_img = cv2.hconcat([image, gt, old_vis, new_vis])
482
+ cv2.imwrite(f'{args.experiment_folder}/val_samples/res_{i}_{j}.jpg', h_img)
483
+
484
+ # measure accuracy and record loss
485
+ acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
486
+ losses.update(loss.item(), image_ten.size(0))
487
+ top1.update(acc1[0], image_ten.size(0))
488
+ top5.update(acc5[0], image_ten.size(0))
489
+
490
+ # metrics for original vit
491
+ acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
492
+ orig_top1.update(acc1_orig[0], image_ten.size(0))
493
+ orig_top5.update(acc5_orig[0], image_ten.size(0))
494
+
495
+ if i % args.print_freq == 0:
496
+ progress.display(i)
497
+ args.logger.add_scalar('{}/{}'.format('val', 'segmentation_loss'), segmentation_loss,
498
+ epoch * len(val_loader) + i)
499
+ args.logger.add_scalar('{}/{}'.format('val', 'classification_loss'), classification_loss,
500
+ epoch * len(val_loader) + i)
501
+ args.logger.add_scalar('{}/{}'.format('val', 'orig_top1'), acc1_orig,
502
+ epoch * len(val_loader) + i)
503
+ args.logger.add_scalar('{}/{}'.format('val', 'top1'), acc1,
504
+ epoch * len(val_loader) + i)
505
+ args.logger.add_scalar('{}/{}'.format('val', 'orig_top5'), acc5_orig,
506
+ epoch * len(val_loader) + i)
507
+ args.logger.add_scalar('{}/{}'.format('val', 'top5'), acc5,
508
+ epoch * len(val_loader) + i)
509
+ args.logger.add_scalar('{}/{}'.format('val', 'tot_loss'), loss,
510
+ epoch * len(val_loader) + i)
511
+
512
+ # TODO: this should also be done with the ProgressMeter
513
+ print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
514
+ .format(top1=top1, top5=top5))
515
+
516
+ return losses.avg
517
+
518
+
519
+ def save_checkpoint(state, is_best, folder, filename='checkpoint.pth.tar'):
520
+ torch.save(state, f'{folder}/{filename}')
521
+ if is_best:
522
+ shutil.copyfile(f'{folder}/{filename}', f'{folder}/model_best.pth.tar')
523
+
524
+
525
+ class AverageMeter(object):
526
+ """Computes and stores the average and current value"""
527
+ def __init__(self, name, fmt=':f'):
528
+ self.name = name
529
+ self.fmt = fmt
530
+ self.reset()
531
+
532
+ def reset(self):
533
+ self.val = 0
534
+ self.avg = 0
535
+ self.sum = 0
536
+ self.count = 0
537
+
538
+ def update(self, val, n=1):
539
+ self.val = val
540
+ self.sum += val * n
541
+ self.count += n
542
+ self.avg = self.sum / self.count
543
+
544
+ def __str__(self):
545
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
546
+ return fmtstr.format(**self.__dict__)
547
+
548
+
549
+ class ProgressMeter(object):
550
+ def __init__(self, num_batches, meters, prefix=""):
551
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
552
+ self.meters = meters
553
+ self.prefix = prefix
554
+
555
+ def display(self, batch):
556
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
557
+ entries += [str(meter) for meter in self.meters]
558
+ print('\t'.join(entries))
559
+
560
+ def _get_batch_fmtstr(self, num_batches):
561
+ num_digits = len(str(num_batches // 1))
562
+ fmt = '{:' + str(num_digits) + 'd}'
563
+ return '[' + fmt + '/' + fmt.format(num_batches) + ']'
564
+
565
+ def adjust_learning_rate(optimizer, epoch, args):
566
+ """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
567
+ lr = args.lr * (0.85 ** (epoch // 2))
568
+ for param_group in optimizer.param_groups:
569
+ param_group['lr'] = lr
570
+
571
+
572
+ def accuracy(output, target, topk=(1,)):
573
+ """Computes the accuracy over the k top predictions for the specified values of k"""
574
+ with torch.no_grad():
575
+ maxk = max(topk)
576
+ batch_size = target.size(0)
577
+
578
+ _, pred = output.topk(maxk, 1, True, True)
579
+ pred = pred.t()
580
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
581
+
582
+ res = []
583
+ for k in topk:
584
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
585
+ res.append(correct_k.mul_(100.0 / batch_size))
586
+ return res
587
+
588
+
589
+ if __name__ == '__main__':
590
+ main()
imagenet_classes.json ADDED
@@ -0,0 +1,1002 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "n01440764": 0,
3
+ "n01443537": 1,
4
+ "n01484850": 2,
5
+ "n01491361": 3,
6
+ "n01494475": 4,
7
+ "n01496331": 5,
8
+ "n01498041": 6,
9
+ "n01514668": 7,
10
+ "n01514859": 8,
11
+ "n01518878": 9,
12
+ "n01530575": 10,
13
+ "n01531178": 11,
14
+ "n01532829": 12,
15
+ "n01534433": 13,
16
+ "n01537544": 14,
17
+ "n01558993": 15,
18
+ "n01560419": 16,
19
+ "n01580077": 17,
20
+ "n01582220": 18,
21
+ "n01592084": 19,
22
+ "n01601694": 20,
23
+ "n01608432": 21,
24
+ "n01614925": 22,
25
+ "n01616318": 23,
26
+ "n01622779": 24,
27
+ "n01629819": 25,
28
+ "n01630670": 26,
29
+ "n01631663": 27,
30
+ "n01632458": 28,
31
+ "n01632777": 29,
32
+ "n01641577": 30,
33
+ "n01644373": 31,
34
+ "n01644900": 32,
35
+ "n01664065": 33,
36
+ "n01665541": 34,
37
+ "n01667114": 35,
38
+ "n01667778": 36,
39
+ "n01669191": 37,
40
+ "n01675722": 38,
41
+ "n01677366": 39,
42
+ "n01682714": 40,
43
+ "n01685808": 41,
44
+ "n01687978": 42,
45
+ "n01688243": 43,
46
+ "n01689811": 44,
47
+ "n01692333": 45,
48
+ "n01693334": 46,
49
+ "n01694178": 47,
50
+ "n01695060": 48,
51
+ "n01697457": 49,
52
+ "n01698640": 50,
53
+ "n01704323": 51,
54
+ "n01728572": 52,
55
+ "n01728920": 53,
56
+ "n01729322": 54,
57
+ "n01729977": 55,
58
+ "n01734418": 56,
59
+ "n01735189": 57,
60
+ "n01737021": 58,
61
+ "n01739381": 59,
62
+ "n01740131": 60,
63
+ "n01742172": 61,
64
+ "n01744401": 62,
65
+ "n01748264": 63,
66
+ "n01749939": 64,
67
+ "n01751748": 65,
68
+ "n01753488": 66,
69
+ "n01755581": 67,
70
+ "n01756291": 68,
71
+ "n01768244": 69,
72
+ "n01770081": 70,
73
+ "n01770393": 71,
74
+ "n01773157": 72,
75
+ "n01773549": 73,
76
+ "n01773797": 74,
77
+ "n01774384": 75,
78
+ "n01774750": 76,
79
+ "n01775062": 77,
80
+ "n01776313": 78,
81
+ "n01784675": 79,
82
+ "n01795545": 80,
83
+ "n01796340": 81,
84
+ "n01797886": 82,
85
+ "n01798484": 83,
86
+ "n01806143": 84,
87
+ "n01806567": 85,
88
+ "n01807496": 86,
89
+ "n01817953": 87,
90
+ "n01818515": 88,
91
+ "n01819313": 89,
92
+ "n01820546": 90,
93
+ "n01824575": 91,
94
+ "n01828970": 92,
95
+ "n01829413": 93,
96
+ "n01833805": 94,
97
+ "n01843065": 95,
98
+ "n01843383": 96,
99
+ "n01847000": 97,
100
+ "n01855032": 98,
101
+ "n01855672": 99,
102
+ "n01860187": 100,
103
+ "n01871265": 101,
104
+ "n01872401": 102,
105
+ "n01873310": 103,
106
+ "n01877812": 104,
107
+ "n01882714": 105,
108
+ "n01883070": 106,
109
+ "n01910747": 107,
110
+ "n01914609": 108,
111
+ "n01917289": 109,
112
+ "n01924916": 110,
113
+ "n01930112": 111,
114
+ "n01943899": 112,
115
+ "n01944390": 113,
116
+ "n01945685": 114,
117
+ "n01950731": 115,
118
+ "n01955084": 116,
119
+ "n01968897": 117,
120
+ "n01978287": 118,
121
+ "n01978455": 119,
122
+ "n01980166": 120,
123
+ "n01981276": 121,
124
+ "n01983481": 122,
125
+ "n01984695": 123,
126
+ "n01985128": 124,
127
+ "n01986214": 125,
128
+ "n01990800": 126,
129
+ "n02002556": 127,
130
+ "n02002724": 128,
131
+ "n02006656": 129,
132
+ "n02007558": 130,
133
+ "n02009229": 131,
134
+ "n02009912": 132,
135
+ "n02011460": 133,
136
+ "n02012849": 134,
137
+ "n02013706": 135,
138
+ "n02017213": 136,
139
+ "n02018207": 137,
140
+ "n02018795": 138,
141
+ "n02025239": 139,
142
+ "n02027492": 140,
143
+ "n02028035": 141,
144
+ "n02033041": 142,
145
+ "n02037110": 143,
146
+ "n02051845": 144,
147
+ "n02056570": 145,
148
+ "n02058221": 146,
149
+ "n02066245": 147,
150
+ "n02071294": 148,
151
+ "n02074367": 149,
152
+ "n02077923": 150,
153
+ "n02085620": 151,
154
+ "n02085782": 152,
155
+ "n02085936": 153,
156
+ "n02086079": 154,
157
+ "n02086240": 155,
158
+ "n02086646": 156,
159
+ "n02086910": 157,
160
+ "n02087046": 158,
161
+ "n02087394": 159,
162
+ "n02088094": 160,
163
+ "n02088238": 161,
164
+ "n02088364": 162,
165
+ "n02088466": 163,
166
+ "n02088632": 164,
167
+ "n02089078": 165,
168
+ "n02089867": 166,
169
+ "n02089973": 167,
170
+ "n02090379": 168,
171
+ "n02090622": 169,
172
+ "n02090721": 170,
173
+ "n02091032": 171,
174
+ "n02091134": 172,
175
+ "n02091244": 173,
176
+ "n02091467": 174,
177
+ "n02091635": 175,
178
+ "n02091831": 176,
179
+ "n02092002": 177,
180
+ "n02092339": 178,
181
+ "n02093256": 179,
182
+ "n02093428": 180,
183
+ "n02093647": 181,
184
+ "n02093754": 182,
185
+ "n02093859": 183,
186
+ "n02093991": 184,
187
+ "n02094114": 185,
188
+ "n02094258": 186,
189
+ "n02094433": 187,
190
+ "n02095314": 188,
191
+ "n02095570": 189,
192
+ "n02095889": 190,
193
+ "n02096051": 191,
194
+ "n02096177": 192,
195
+ "n02096294": 193,
196
+ "n02096437": 194,
197
+ "n02096585": 195,
198
+ "n02097047": 196,
199
+ "n02097130": 197,
200
+ "n02097209": 198,
201
+ "n02097298": 199,
202
+ "n02097474": 200,
203
+ "n02097658": 201,
204
+ "n02098105": 202,
205
+ "n02098286": 203,
206
+ "n02098413": 204,
207
+ "n02099267": 205,
208
+ "n02099429": 206,
209
+ "n02099601": 207,
210
+ "n02099712": 208,
211
+ "n02099849": 209,
212
+ "n02100236": 210,
213
+ "n02100583": 211,
214
+ "n02100735": 212,
215
+ "n02100877": 213,
216
+ "n02101006": 214,
217
+ "n02101388": 215,
218
+ "n02101556": 216,
219
+ "n02102040": 217,
220
+ "n02102177": 218,
221
+ "n02102318": 219,
222
+ "n02102480": 220,
223
+ "n02102973": 221,
224
+ "n02104029": 222,
225
+ "n02104365": 223,
226
+ "n02105056": 224,
227
+ "n02105162": 225,
228
+ "n02105251": 226,
229
+ "n02105412": 227,
230
+ "n02105505": 228,
231
+ "n02105641": 229,
232
+ "n02105855": 230,
233
+ "n02106030": 231,
234
+ "n02106166": 232,
235
+ "n02106382": 233,
236
+ "n02106550": 234,
237
+ "n02106662": 235,
238
+ "n02107142": 236,
239
+ "n02107312": 237,
240
+ "n02107574": 238,
241
+ "n02107683": 239,
242
+ "n02107908": 240,
243
+ "n02108000": 241,
244
+ "n02108089": 242,
245
+ "n02108422": 243,
246
+ "n02108551": 244,
247
+ "n02108915": 245,
248
+ "n02109047": 246,
249
+ "n02109525": 247,
250
+ "n02109961": 248,
251
+ "n02110063": 249,
252
+ "n02110185": 250,
253
+ "n02110341": 251,
254
+ "n02110627": 252,
255
+ "n02110806": 253,
256
+ "n02110958": 254,
257
+ "n02111129": 255,
258
+ "n02111277": 256,
259
+ "n02111500": 257,
260
+ "n02111889": 258,
261
+ "n02112018": 259,
262
+ "n02112137": 260,
263
+ "n02112350": 261,
264
+ "n02112706": 262,
265
+ "n02113023": 263,
266
+ "n02113186": 264,
267
+ "n02113624": 265,
268
+ "n02113712": 266,
269
+ "n02113799": 267,
270
+ "n02113978": 268,
271
+ "n02114367": 269,
272
+ "n02114548": 270,
273
+ "n02114712": 271,
274
+ "n02114855": 272,
275
+ "n02115641": 273,
276
+ "n02115913": 274,
277
+ "n02116738": 275,
278
+ "n02117135": 276,
279
+ "n02119022": 277,
280
+ "n02119789": 278,
281
+ "n02120079": 279,
282
+ "n02120505": 280,
283
+ "n02123045": 281,
284
+ "n02123159": 282,
285
+ "n02123394": 283,
286
+ "n02123597": 284,
287
+ "n02124075": 285,
288
+ "n02125311": 286,
289
+ "n02127052": 287,
290
+ "n02128385": 288,
291
+ "n02128757": 289,
292
+ "n02128925": 290,
293
+ "n02129165": 291,
294
+ "n02129604": 292,
295
+ "n02130308": 293,
296
+ "n02132136": 294,
297
+ "n02133161": 295,
298
+ "n02134084": 296,
299
+ "n02134418": 297,
300
+ "n02137549": 298,
301
+ "n02138441": 299,
302
+ "n02165105": 300,
303
+ "n02165456": 301,
304
+ "n02167151": 302,
305
+ "n02168699": 303,
306
+ "n02169497": 304,
307
+ "n02172182": 305,
308
+ "n02174001": 306,
309
+ "n02177972": 307,
310
+ "n02190166": 308,
311
+ "n02206856": 309,
312
+ "n02219486": 310,
313
+ "n02226429": 311,
314
+ "n02229544": 312,
315
+ "n02231487": 313,
316
+ "n02233338": 314,
317
+ "n02236044": 315,
318
+ "n02256656": 316,
319
+ "n02259212": 317,
320
+ "n02264363": 318,
321
+ "n02268443": 319,
322
+ "n02268853": 320,
323
+ "n02276258": 321,
324
+ "n02277742": 322,
325
+ "n02279972": 323,
326
+ "n02280649": 324,
327
+ "n02281406": 325,
328
+ "n02281787": 326,
329
+ "n02317335": 327,
330
+ "n02319095": 328,
331
+ "n02321529": 329,
332
+ "n02325366": 330,
333
+ "n02326432": 331,
334
+ "n02328150": 332,
335
+ "n02342885": 333,
336
+ "n02346627": 334,
337
+ "n02356798": 335,
338
+ "n02361337": 336,
339
+ "n02363005": 337,
340
+ "n02364673": 338,
341
+ "n02389026": 339,
342
+ "n02391049": 340,
343
+ "n02395406": 341,
344
+ "n02396427": 342,
345
+ "n02397096": 343,
346
+ "n02398521": 344,
347
+ "n02403003": 345,
348
+ "n02408429": 346,
349
+ "n02410509": 347,
350
+ "n02412080": 348,
351
+ "n02415577": 349,
352
+ "n02417914": 350,
353
+ "n02422106": 351,
354
+ "n02422699": 352,
355
+ "n02423022": 353,
356
+ "n02437312": 354,
357
+ "n02437616": 355,
358
+ "n02441942": 356,
359
+ "n02442845": 357,
360
+ "n02443114": 358,
361
+ "n02443484": 359,
362
+ "n02444819": 360,
363
+ "n02445715": 361,
364
+ "n02447366": 362,
365
+ "n02454379": 363,
366
+ "n02457408": 364,
367
+ "n02480495": 365,
368
+ "n02480855": 366,
369
+ "n02481823": 367,
370
+ "n02483362": 368,
371
+ "n02483708": 369,
372
+ "n02484975": 370,
373
+ "n02486261": 371,
374
+ "n02486410": 372,
375
+ "n02487347": 373,
376
+ "n02488291": 374,
377
+ "n02488702": 375,
378
+ "n02489166": 376,
379
+ "n02490219": 377,
380
+ "n02492035": 378,
381
+ "n02492660": 379,
382
+ "n02493509": 380,
383
+ "n02493793": 381,
384
+ "n02494079": 382,
385
+ "n02497673": 383,
386
+ "n02500267": 384,
387
+ "n02504013": 385,
388
+ "n02504458": 386,
389
+ "n02509815": 387,
390
+ "n02510455": 388,
391
+ "n02514041": 389,
392
+ "n02526121": 390,
393
+ "n02536864": 391,
394
+ "n02606052": 392,
395
+ "n02607072": 393,
396
+ "n02640242": 394,
397
+ "n02641379": 395,
398
+ "n02643566": 396,
399
+ "n02655020": 397,
400
+ "n02666196": 398,
401
+ "n02667093": 399,
402
+ "n02669723": 400,
403
+ "n02672831": 401,
404
+ "n02676566": 402,
405
+ "n02687172": 403,
406
+ "n02690373": 404,
407
+ "n02692877": 405,
408
+ "n02699494": 406,
409
+ "n02701002": 407,
410
+ "n02704792": 408,
411
+ "n02708093": 409,
412
+ "n02727426": 410,
413
+ "n02730930": 411,
414
+ "n02747177": 412,
415
+ "n02749479": 413,
416
+ "n02769748": 414,
417
+ "n02776631": 415,
418
+ "n02777292": 416,
419
+ "n02782093": 417,
420
+ "n02783161": 418,
421
+ "n02786058": 419,
422
+ "n02787622": 420,
423
+ "n02788148": 421,
424
+ "n02790996": 422,
425
+ "n02791124": 423,
426
+ "n02791270": 424,
427
+ "n02793495": 425,
428
+ "n02794156": 426,
429
+ "n02795169": 427,
430
+ "n02797295": 428,
431
+ "n02799071": 429,
432
+ "n02802426": 430,
433
+ "n02804414": 431,
434
+ "n02804610": 432,
435
+ "n02807133": 433,
436
+ "n02808304": 434,
437
+ "n02808440": 435,
438
+ "n02814533": 436,
439
+ "n02814860": 437,
440
+ "n02815834": 438,
441
+ "n02817516": 439,
442
+ "n02823428": 440,
443
+ "n02823750": 441,
444
+ "n02825657": 442,
445
+ "n02834397": 443,
446
+ "n02835271": 444,
447
+ "n02837789": 445,
448
+ "n02840245": 446,
449
+ "n02841315": 447,
450
+ "n02843684": 448,
451
+ "n02859443": 449,
452
+ "n02860847": 450,
453
+ "n02865351": 451,
454
+ "n02869837": 452,
455
+ "n02870880": 453,
456
+ "n02871525": 454,
457
+ "n02877765": 455,
458
+ "n02879718": 456,
459
+ "n02883205": 457,
460
+ "n02892201": 458,
461
+ "n02892767": 459,
462
+ "n02894605": 460,
463
+ "n02895154": 461,
464
+ "n02906734": 462,
465
+ "n02909870": 463,
466
+ "n02910353": 464,
467
+ "n02916936": 465,
468
+ "n02917067": 466,
469
+ "n02927161": 467,
470
+ "n02930766": 468,
471
+ "n02939185": 469,
472
+ "n02948072": 470,
473
+ "n02950826": 471,
474
+ "n02951358": 472,
475
+ "n02951585": 473,
476
+ "n02963159": 474,
477
+ "n02965783": 475,
478
+ "n02966193": 476,
479
+ "n02966687": 477,
480
+ "n02971356": 478,
481
+ "n02974003": 479,
482
+ "n02977058": 480,
483
+ "n02978881": 481,
484
+ "n02979186": 482,
485
+ "n02980441": 483,
486
+ "n02981792": 484,
487
+ "n02988304": 485,
488
+ "n02992211": 486,
489
+ "n02992529": 487,
490
+ "n02999410": 488,
491
+ "n03000134": 489,
492
+ "n03000247": 490,
493
+ "n03000684": 491,
494
+ "n03014705": 492,
495
+ "n03016953": 493,
496
+ "n03017168": 494,
497
+ "n03018349": 495,
498
+ "n03026506": 496,
499
+ "n03028079": 497,
500
+ "n03032252": 498,
501
+ "n03041632": 499,
502
+ "n03042490": 500,
503
+ "n03045698": 501,
504
+ "n03047690": 502,
505
+ "n03062245": 503,
506
+ "n03063599": 504,
507
+ "n03063689": 505,
508
+ "n03065424": 506,
509
+ "n03075370": 507,
510
+ "n03085013": 508,
511
+ "n03089624": 509,
512
+ "n03095699": 510,
513
+ "n03100240": 511,
514
+ "n03109150": 512,
515
+ "n03110669": 513,
516
+ "n03124043": 514,
517
+ "n03124170": 515,
518
+ "n03125729": 516,
519
+ "n03126707": 517,
520
+ "n03127747": 518,
521
+ "n03127925": 519,
522
+ "n03131574": 520,
523
+ "n03133878": 521,
524
+ "n03134739": 522,
525
+ "n03141823": 523,
526
+ "n03146219": 524,
527
+ "n03160309": 525,
528
+ "n03179701": 526,
529
+ "n03180011": 527,
530
+ "n03187595": 528,
531
+ "n03188531": 529,
532
+ "n03196217": 530,
533
+ "n03197337": 531,
534
+ "n03201208": 532,
535
+ "n03207743": 533,
536
+ "n03207941": 534,
537
+ "n03208938": 535,
538
+ "n03216828": 536,
539
+ "n03218198": 537,
540
+ "n03220513": 538,
541
+ "n03223299": 539,
542
+ "n03240683": 540,
543
+ "n03249569": 541,
544
+ "n03250847": 542,
545
+ "n03255030": 543,
546
+ "n03259280": 544,
547
+ "n03271574": 545,
548
+ "n03272010": 546,
549
+ "n03272562": 547,
550
+ "n03290653": 548,
551
+ "n03291819": 549,
552
+ "n03297495": 550,
553
+ "n03314780": 551,
554
+ "n03325584": 552,
555
+ "n03337140": 553,
556
+ "n03344393": 554,
557
+ "n03345487": 555,
558
+ "n03347037": 556,
559
+ "n03355925": 557,
560
+ "n03372029": 558,
561
+ "n03376595": 559,
562
+ "n03379051": 560,
563
+ "n03384352": 561,
564
+ "n03388043": 562,
565
+ "n03388183": 563,
566
+ "n03388549": 564,
567
+ "n03393912": 565,
568
+ "n03394916": 566,
569
+ "n03400231": 567,
570
+ "n03404251": 568,
571
+ "n03417042": 569,
572
+ "n03424325": 570,
573
+ "n03425413": 571,
574
+ "n03443371": 572,
575
+ "n03444034": 573,
576
+ "n03445777": 574,
577
+ "n03445924": 575,
578
+ "n03447447": 576,
579
+ "n03447721": 577,
580
+ "n03450230": 578,
581
+ "n03452741": 579,
582
+ "n03457902": 580,
583
+ "n03459775": 581,
584
+ "n03461385": 582,
585
+ "n03467068": 583,
586
+ "n03476684": 584,
587
+ "n03476991": 585,
588
+ "n03478589": 586,
589
+ "n03481172": 587,
590
+ "n03482405": 588,
591
+ "n03483316": 589,
592
+ "n03485407": 590,
593
+ "n03485794": 591,
594
+ "n03492542": 592,
595
+ "n03494278": 593,
596
+ "n03495258": 594,
597
+ "n03496892": 595,
598
+ "n03498962": 596,
599
+ "n03527444": 597,
600
+ "n03529860": 598,
601
+ "n03530642": 599,
602
+ "n03532672": 600,
603
+ "n03534580": 601,
604
+ "n03535780": 602,
605
+ "n03538406": 603,
606
+ "n03544143": 604,
607
+ "n03584254": 605,
608
+ "n03584829": 606,
609
+ "n03590841": 607,
610
+ "n03594734": 608,
611
+ "n03594945": 609,
612
+ "n03595614": 610,
613
+ "n03598930": 611,
614
+ "n03599486": 612,
615
+ "n03602883": 613,
616
+ "n03617480": 614,
617
+ "n03623198": 615,
618
+ "n03627232": 616,
619
+ "n03630383": 617,
620
+ "n03633091": 618,
621
+ "n03637318": 619,
622
+ "n03642806": 620,
623
+ "n03649909": 621,
624
+ "n03657121": 622,
625
+ "n03658185": 623,
626
+ "n03661043": 624,
627
+ "n03662601": 625,
628
+ "n03666591": 626,
629
+ "n03670208": 627,
630
+ "n03673027": 628,
631
+ "n03676483": 629,
632
+ "n03680355": 630,
633
+ "n03690938": 631,
634
+ "n03691459": 632,
635
+ "n03692522": 633,
636
+ "n03697007": 634,
637
+ "n03706229": 635,
638
+ "n03709823": 636,
639
+ "n03710193": 637,
640
+ "n03710637": 638,
641
+ "n03710721": 639,
642
+ "n03717622": 640,
643
+ "n03720891": 641,
644
+ "n03721384": 642,
645
+ "n03724870": 643,
646
+ "n03729826": 644,
647
+ "n03733131": 645,
648
+ "n03733281": 646,
649
+ "n03733805": 647,
650
+ "n03742115": 648,
651
+ "n03743016": 649,
652
+ "n03759954": 650,
653
+ "n03761084": 651,
654
+ "n03763968": 652,
655
+ "n03764736": 653,
656
+ "n03769881": 654,
657
+ "n03770439": 655,
658
+ "n03770679": 656,
659
+ "n03773504": 657,
660
+ "n03775071": 658,
661
+ "n03775546": 659,
662
+ "n03776460": 660,
663
+ "n03777568": 661,
664
+ "n03777754": 662,
665
+ "n03781244": 663,
666
+ "n03782006": 664,
667
+ "n03785016": 665,
668
+ "n03786901": 666,
669
+ "n03787032": 667,
670
+ "n03788195": 668,
671
+ "n03788365": 669,
672
+ "n03791053": 670,
673
+ "n03792782": 671,
674
+ "n03792972": 672,
675
+ "n03793489": 673,
676
+ "n03794056": 674,
677
+ "n03796401": 675,
678
+ "n03803284": 676,
679
+ "n03804744": 677,
680
+ "n03814639": 678,
681
+ "n03814906": 679,
682
+ "n03825788": 680,
683
+ "n03832673": 681,
684
+ "n03837869": 682,
685
+ "n03838899": 683,
686
+ "n03840681": 684,
687
+ "n03841143": 685,
688
+ "n03843555": 686,
689
+ "n03854065": 687,
690
+ "n03857828": 688,
691
+ "n03866082": 689,
692
+ "n03868242": 690,
693
+ "n03868863": 691,
694
+ "n03871628": 692,
695
+ "n03873416": 693,
696
+ "n03874293": 694,
697
+ "n03874599": 695,
698
+ "n03876231": 696,
699
+ "n03877472": 697,
700
+ "n03877845": 698,
701
+ "n03884397": 699,
702
+ "n03887697": 700,
703
+ "n03888257": 701,
704
+ "n03888605": 702,
705
+ "n03891251": 703,
706
+ "n03891332": 704,
707
+ "n03895866": 705,
708
+ "n03899768": 706,
709
+ "n03902125": 707,
710
+ "n03903868": 708,
711
+ "n03908618": 709,
712
+ "n03908714": 710,
713
+ "n03916031": 711,
714
+ "n03920288": 712,
715
+ "n03924679": 713,
716
+ "n03929660": 714,
717
+ "n03929855": 715,
718
+ "n03930313": 716,
719
+ "n03930630": 717,
720
+ "n03933933": 718,
721
+ "n03935335": 719,
722
+ "n03937543": 720,
723
+ "n03938244": 721,
724
+ "n03942813": 722,
725
+ "n03944341": 723,
726
+ "n03947888": 724,
727
+ "n03950228": 725,
728
+ "n03954731": 726,
729
+ "n03956157": 727,
730
+ "n03958227": 728,
731
+ "n03961711": 729,
732
+ "n03967562": 730,
733
+ "n03970156": 731,
734
+ "n03976467": 732,
735
+ "n03976657": 733,
736
+ "n03977966": 734,
737
+ "n03980874": 735,
738
+ "n03982430": 736,
739
+ "n03983396": 737,
740
+ "n03991062": 738,
741
+ "n03992509": 739,
742
+ "n03995372": 740,
743
+ "n03998194": 741,
744
+ "n04004767": 742,
745
+ "n04005630": 743,
746
+ "n04008634": 744,
747
+ "n04009552": 745,
748
+ "n04019541": 746,
749
+ "n04023962": 747,
750
+ "n04026417": 748,
751
+ "n04033901": 749,
752
+ "n04033995": 750,
753
+ "n04037443": 751,
754
+ "n04039381": 752,
755
+ "n04040759": 753,
756
+ "n04041544": 754,
757
+ "n04044716": 755,
758
+ "n04049303": 756,
759
+ "n04065272": 757,
760
+ "n04067472": 758,
761
+ "n04069434": 759,
762
+ "n04070727": 760,
763
+ "n04074963": 761,
764
+ "n04081281": 762,
765
+ "n04086273": 763,
766
+ "n04090263": 764,
767
+ "n04099969": 765,
768
+ "n04111531": 766,
769
+ "n04116512": 767,
770
+ "n04118538": 768,
771
+ "n04118776": 769,
772
+ "n04120489": 770,
773
+ "n04125021": 771,
774
+ "n04127249": 772,
775
+ "n04131690": 773,
776
+ "n04133789": 774,
777
+ "n04136333": 775,
778
+ "n04141076": 776,
779
+ "n04141327": 777,
780
+ "n04141975": 778,
781
+ "n04146614": 779,
782
+ "n04147183": 780,
783
+ "n04149813": 781,
784
+ "n04152593": 782,
785
+ "n04153751": 783,
786
+ "n04154565": 784,
787
+ "n04162706": 785,
788
+ "n04179913": 786,
789
+ "n04192698": 787,
790
+ "n04200800": 788,
791
+ "n04201297": 789,
792
+ "n04204238": 790,
793
+ "n04204347": 791,
794
+ "n04208210": 792,
795
+ "n04209133": 793,
796
+ "n04209239": 794,
797
+ "n04228054": 795,
798
+ "n04229816": 796,
799
+ "n04235860": 797,
800
+ "n04238763": 798,
801
+ "n04239074": 799,
802
+ "n04243546": 800,
803
+ "n04251144": 801,
804
+ "n04252077": 802,
805
+ "n04252225": 803,
806
+ "n04254120": 804,
807
+ "n04254680": 805,
808
+ "n04254777": 806,
809
+ "n04258138": 807,
810
+ "n04259630": 808,
811
+ "n04263257": 809,
812
+ "n04264628": 810,
813
+ "n04265275": 811,
814
+ "n04266014": 812,
815
+ "n04270147": 813,
816
+ "n04273569": 814,
817
+ "n04275548": 815,
818
+ "n04277352": 816,
819
+ "n04285008": 817,
820
+ "n04286575": 818,
821
+ "n04296562": 819,
822
+ "n04310018": 820,
823
+ "n04311004": 821,
824
+ "n04311174": 822,
825
+ "n04317175": 823,
826
+ "n04325704": 824,
827
+ "n04326547": 825,
828
+ "n04328186": 826,
829
+ "n04330267": 827,
830
+ "n04332243": 828,
831
+ "n04335435": 829,
832
+ "n04336792": 830,
833
+ "n04344873": 831,
834
+ "n04346328": 832,
835
+ "n04347754": 833,
836
+ "n04350905": 834,
837
+ "n04355338": 835,
838
+ "n04355933": 836,
839
+ "n04356056": 837,
840
+ "n04357314": 838,
841
+ "n04366367": 839,
842
+ "n04367480": 840,
843
+ "n04370456": 841,
844
+ "n04371430": 842,
845
+ "n04371774": 843,
846
+ "n04372370": 844,
847
+ "n04376876": 845,
848
+ "n04380533": 846,
849
+ "n04389033": 847,
850
+ "n04392985": 848,
851
+ "n04398044": 849,
852
+ "n04399382": 850,
853
+ "n04404412": 851,
854
+ "n04409515": 852,
855
+ "n04417672": 853,
856
+ "n04418357": 854,
857
+ "n04423845": 855,
858
+ "n04428191": 856,
859
+ "n04429376": 857,
860
+ "n04435653": 858,
861
+ "n04442312": 859,
862
+ "n04443257": 860,
863
+ "n04447861": 861,
864
+ "n04456115": 862,
865
+ "n04458633": 863,
866
+ "n04461696": 864,
867
+ "n04462240": 865,
868
+ "n04465501": 866,
869
+ "n04467665": 867,
870
+ "n04476259": 868,
871
+ "n04479046": 869,
872
+ "n04482393": 870,
873
+ "n04483307": 871,
874
+ "n04485082": 872,
875
+ "n04486054": 873,
876
+ "n04487081": 874,
877
+ "n04487394": 875,
878
+ "n04493381": 876,
879
+ "n04501370": 877,
880
+ "n04505470": 878,
881
+ "n04507155": 879,
882
+ "n04509417": 880,
883
+ "n04515003": 881,
884
+ "n04517823": 882,
885
+ "n04522168": 883,
886
+ "n04523525": 884,
887
+ "n04525038": 885,
888
+ "n04525305": 886,
889
+ "n04532106": 887,
890
+ "n04532670": 888,
891
+ "n04536866": 889,
892
+ "n04540053": 890,
893
+ "n04542943": 891,
894
+ "n04548280": 892,
895
+ "n04548362": 893,
896
+ "n04550184": 894,
897
+ "n04552348": 895,
898
+ "n04553703": 896,
899
+ "n04554684": 897,
900
+ "n04557648": 898,
901
+ "n04560804": 899,
902
+ "n04562935": 900,
903
+ "n04579145": 901,
904
+ "n04579432": 902,
905
+ "n04584207": 903,
906
+ "n04589890": 904,
907
+ "n04590129": 905,
908
+ "n04591157": 906,
909
+ "n04591713": 907,
910
+ "n04592741": 908,
911
+ "n04596742": 909,
912
+ "n04597913": 910,
913
+ "n04599235": 911,
914
+ "n04604644": 912,
915
+ "n04606251": 913,
916
+ "n04612504": 914,
917
+ "n04613696": 915,
918
+ "n06359193": 916,
919
+ "n06596364": 917,
920
+ "n06785654": 918,
921
+ "n06794110": 919,
922
+ "n06874185": 920,
923
+ "n07248320": 921,
924
+ "n07565083": 922,
925
+ "n07579787": 923,
926
+ "n07583066": 924,
927
+ "n07584110": 925,
928
+ "n07590611": 926,
929
+ "n07613480": 927,
930
+ "n07614500": 928,
931
+ "n07615774": 929,
932
+ "n07684084": 930,
933
+ "n07693725": 931,
934
+ "n07695742": 932,
935
+ "n07697313": 933,
936
+ "n07697537": 934,
937
+ "n07711569": 935,
938
+ "n07714571": 936,
939
+ "n07714990": 937,
940
+ "n07715103": 938,
941
+ "n07716358": 939,
942
+ "n07716906": 940,
943
+ "n07717410": 941,
944
+ "n07717556": 942,
945
+ "n07718472": 943,
946
+ "n07718747": 944,
947
+ "n07720875": 945,
948
+ "n07730033": 946,
949
+ "n07734744": 947,
950
+ "n07742313": 948,
951
+ "n07745940": 949,
952
+ "n07747607": 950,
953
+ "n07749582": 951,
954
+ "n07753113": 952,
955
+ "n07753275": 953,
956
+ "n07753592": 954,
957
+ "n07754684": 955,
958
+ "n07760859": 956,
959
+ "n07768694": 957,
960
+ "n07802026": 958,
961
+ "n07831146": 959,
962
+ "n07836838": 960,
963
+ "n07860988": 961,
964
+ "n07871810": 962,
965
+ "n07873807": 963,
966
+ "n07875152": 964,
967
+ "n07880968": 965,
968
+ "n07892512": 966,
969
+ "n07920052": 967,
970
+ "n07930864": 968,
971
+ "n07932039": 969,
972
+ "n09193705": 970,
973
+ "n09229709": 971,
974
+ "n09246464": 972,
975
+ "n09256479": 973,
976
+ "n09288635": 974,
977
+ "n09332890": 975,
978
+ "n09399592": 976,
979
+ "n09421951": 977,
980
+ "n09428293": 978,
981
+ "n09468604": 979,
982
+ "n09472597": 980,
983
+ "n09835506": 981,
984
+ "n10148035": 982,
985
+ "n10565667": 983,
986
+ "n11879895": 984,
987
+ "n11939491": 985,
988
+ "n12057211": 986,
989
+ "n12144580": 987,
990
+ "n12267677": 988,
991
+ "n12620546": 989,
992
+ "n12768682": 990,
993
+ "n12985857": 991,
994
+ "n12998815": 992,
995
+ "n13037406": 993,
996
+ "n13040303": 994,
997
+ "n13044778": 995,
998
+ "n13052670": 996,
999
+ "n13054560": 997,
1000
+ "n13133613": 998,
1001
+ "n15075141": 999
1002
+ }
imagenet_eval_robustness.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+ import shutil
5
+ import time
6
+ import warnings
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.parallel
11
+ import torch.backends.cudnn as cudnn
12
+ import torch.distributed as dist
13
+ import torch.optim
14
+ import torch.multiprocessing as mp
15
+ import torch.utils.data
16
+ import torch.utils.data.distributed
17
+ import torchvision.transforms as transforms
18
+ import torchvision.datasets as datasets
19
+ import torchvision.models as models
20
+
21
+ # Uncomment the expected model below
22
+
23
+ # ViT
24
+ from ViT.ViT import vit_base_patch16_224 as vit
25
+ # from ViT.ViT import vit_large_patch16_224 as vit
26
+
27
+ # ViT-AugReg
28
+ # from ViT.ViT_new import vit_small_patch16_224 as vit
29
+ # from ViT.ViT_new import vit_base_patch16_224 as vit
30
+ # from ViT.ViT_new import vit_large_patch16_224 as vit
31
+
32
+ # DeiT
33
+ # from ViT.ViT import deit_base_patch16_224 as vit
34
+ # from ViT.ViT import deit_small_patch16_224 as vit
35
+
36
+ from robustness_dataset import RobustnessDataset
37
+ from objectnet_dataset import ObjectNetDataset
38
+ model_names = sorted(name for name in models.__dict__
39
+ if name.islower() and not name.startswith("__")
40
+ and callable(models.__dict__[name]))
41
+ model_names.append("vit")
42
+
43
+ parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
44
+ parser.add_argument('--data', metavar='DIR',
45
+ help='path to dataset')
46
+ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
47
+ help='number of data loading workers (default: 4)')
48
+ parser.add_argument('--epochs', default=150, type=int, metavar='N',
49
+ help='number of total epochs to run')
50
+ parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
51
+ help='manual epoch number (useful on restarts)')
52
+ parser.add_argument('-b', '--batch-size', default=256, type=int,
53
+ metavar='N',
54
+ help='mini-batch size (default: 256), this is the total '
55
+ 'batch size of all GPUs on the current node when '
56
+ 'using Data Parallel or Distributed Data Parallel')
57
+ parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,
58
+ metavar='LR', help='initial learning rate', dest='lr')
59
+ parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
60
+ help='momentum')
61
+ parser.add_argument('--wd', '--weight-decay', default=0.05, type=float,
62
+ metavar='W', help='weight decay (default: 1e-4)',
63
+ dest='weight_decay')
64
+ parser.add_argument('-p', '--print-freq', default=10, type=int,
65
+ metavar='N', help='print frequency (default: 10)')
66
+ parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
67
+ help='path to latest checkpoint (default: none)')
68
+ parser.add_argument('--resume', default='', type=str, metavar='PATH',
69
+ help='path to resume checkpoint (default: none)')
70
+ parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
71
+ help='evaluate model on validation set')
72
+ parser.add_argument('--pretrained', dest='pretrained', action='store_true',
73
+ help='use pre-trained model')
74
+ parser.add_argument('--world-size', default=-1, type=int,
75
+ help='number of nodes for distributed training')
76
+ parser.add_argument('--rank', default=-1, type=int,
77
+ help='node rank for distributed training')
78
+ parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
79
+ help='url used to set up distributed training')
80
+ parser.add_argument('--dist-backend', default='nccl', type=str,
81
+ help='distributed backend')
82
+ parser.add_argument('--seed', default=None, type=int,
83
+ help='seed for initializing training. ')
84
+ parser.add_argument('--gpu', default=None, type=int,
85
+ help='GPU id to use.')
86
+ parser.add_argument('--multiprocessing-distributed', action='store_true',
87
+ help='Use multi-processing distributed training to launch '
88
+ 'N processes per node, which has N GPUs. This is the '
89
+ 'fastest way to use PyTorch for either single node or '
90
+ 'multi node data parallel training')
91
+ parser.add_argument("--isV2", default=False, action='store_true',
92
+ help='is dataset imagenet V2.')
93
+ parser.add_argument("--isSI", default=False, action='store_true',
94
+ help='is dataset SI-score.')
95
+ parser.add_argument("--isObjectNet", default=False, action='store_true',
96
+ help='is dataset SI-score.')
97
+
98
+
99
+ def main():
100
+ args = parser.parse_args()
101
+
102
+ if args.seed is not None:
103
+ random.seed(args.seed)
104
+ torch.manual_seed(args.seed)
105
+ cudnn.deterministic = True
106
+ warnings.warn('You have chosen to seed training. '
107
+ 'This will turn on the CUDNN deterministic setting, '
108
+ 'which can slow down your training considerably! '
109
+ 'You may see unexpected behavior when restarting '
110
+ 'from checkpoints.')
111
+
112
+ if args.gpu is not None:
113
+ warnings.warn('You have chosen a specific GPU. This will completely '
114
+ 'disable data parallelism.')
115
+
116
+ if args.dist_url == "env://" and args.world_size == -1:
117
+ args.world_size = int(os.environ["WORLD_SIZE"])
118
+
119
+ args.distributed = args.world_size > 1 or args.multiprocessing_distributed
120
+
121
+ ngpus_per_node = torch.cuda.device_count()
122
+ if args.multiprocessing_distributed:
123
+ # Since we have ngpus_per_node processes per node, the total world_size
124
+ # needs to be adjusted accordingly
125
+ args.world_size = ngpus_per_node * args.world_size
126
+ # Use torch.multiprocessing.spawn to launch distributed processes: the
127
+ # main_worker process function
128
+ mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
129
+ else:
130
+ # Simply call main_worker function
131
+ main_worker(args.gpu, ngpus_per_node, args)
132
+
133
+
134
+ def main_worker(gpu, ngpus_per_node, args):
135
+ global best_acc1
136
+ args.gpu = gpu
137
+
138
+ if args.gpu is not None:
139
+ print("Use GPU: {} for training".format(args.gpu))
140
+
141
+ if args.distributed:
142
+ if args.dist_url == "env://" and args.rank == -1:
143
+ args.rank = int(os.environ["RANK"])
144
+ if args.multiprocessing_distributed:
145
+ # For multiprocessing distributed training, rank needs to be the
146
+ # global rank among all the processes
147
+ args.rank = args.rank * ngpus_per_node + gpu
148
+ dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
149
+ world_size=args.world_size, rank=args.rank)
150
+ # create model
151
+ print("=> creating model")
152
+ if args.checkpoint:
153
+ model = vit().cuda()
154
+ checkpoint = torch.load(args.checkpoint)
155
+ model.load_state_dict(checkpoint['state_dict'])
156
+ else:
157
+ model = vit(pretrained=True).cuda()
158
+ print("done")
159
+
160
+ if not torch.cuda.is_available():
161
+ print('using CPU, this will be slow')
162
+ elif args.distributed:
163
+ # For multiprocessing distributed, DistributedDataParallel constructor
164
+ # should always set the single device scope, otherwise,
165
+ # DistributedDataParallel will use all available devices.
166
+ if args.gpu is not None:
167
+ torch.cuda.set_device(args.gpu)
168
+ model.cuda(args.gpu)
169
+ # When using a single GPU per process and per
170
+ # DistributedDataParallel, we need to divide the batch size
171
+ # ourselves based on the total number of GPUs we have
172
+ args.batch_size = int(args.batch_size / ngpus_per_node)
173
+ args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
174
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
175
+ else:
176
+ model.cuda()
177
+ # DistributedDataParallel will divide and allocate batch_size to all
178
+ # available GPUs if device_ids are not set
179
+ model = torch.nn.parallel.DistributedDataParallel(model)
180
+ elif args.gpu is not None:
181
+ torch.cuda.set_device(args.gpu)
182
+ model = model.cuda(args.gpu)
183
+ else:
184
+ print("start")
185
+ model = torch.nn.DataParallel(model).cuda()
186
+
187
+ # optionally resume from a checkpoint
188
+ if args.resume:
189
+ if os.path.isfile(args.resume):
190
+ print("=> loading checkpoint '{}'".format(args.resume))
191
+ if args.gpu is None:
192
+ checkpoint = torch.load(args.resume)
193
+ else:
194
+ # Map model to be loaded to specified single gpu.
195
+ loc = 'cuda:{}'.format(args.gpu)
196
+ checkpoint = torch.load(args.resume, map_location=loc)
197
+ args.start_epoch = checkpoint['epoch']
198
+ best_acc1 = checkpoint['best_acc1']
199
+ if args.gpu is not None:
200
+ # best_acc1 may be from a checkpoint from a different GPU
201
+ best_acc1 = best_acc1.to(args.gpu)
202
+ model.load_state_dict(checkpoint['state_dict'])
203
+ print("=> loaded checkpoint '{}' (epoch {})"
204
+ .format(args.resume, checkpoint['epoch']))
205
+ else:
206
+ print("=> no checkpoint found at '{}'".format(args.resume))
207
+
208
+ cudnn.benchmark = True
209
+
210
+ if args.isObjectNet:
211
+ val_dataset = ObjectNetDataset(args.data)
212
+ else:
213
+ val_dataset = RobustnessDataset(args.data, isV2=args.isV2, isSI=args.isSI)
214
+
215
+ val_loader = torch.utils.data.DataLoader(
216
+ val_dataset, batch_size=args.batch_size, shuffle=False,
217
+ num_workers=args.workers, pin_memory=True)
218
+
219
+ if args.evaluate:
220
+ validate(val_loader, model, args)
221
+ return
222
+
223
+ def validate(val_loader, model, args):
224
+ batch_time = AverageMeter('Time', ':6.3f')
225
+ losses = AverageMeter('Loss', ':.4e')
226
+ top1 = AverageMeter('Acc@1', ':6.2f')
227
+ top5 = AverageMeter('Acc@5', ':6.2f')
228
+ progress = ProgressMeter(
229
+ len(val_loader),
230
+ [batch_time, losses, top1, top5],
231
+ prefix='Test: ')
232
+
233
+ # switch to evaluate mode
234
+ model.eval()
235
+
236
+ with torch.no_grad():
237
+ end = time.time()
238
+ for i, (images, target) in enumerate(val_loader):
239
+ if args.gpu is not None:
240
+ images = images.cuda(args.gpu, non_blocking=True)
241
+ if torch.cuda.is_available():
242
+ target = target.cuda(args.gpu, non_blocking=True)
243
+
244
+ # compute output
245
+ output = model(images)
246
+
247
+ # measure accuracy and record loss
248
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
249
+ top1.update(acc1[0], images.size(0))
250
+ top5.update(acc5[0], images.size(0))
251
+
252
+ # measure elapsed time
253
+ batch_time.update(time.time() - end)
254
+ end = time.time()
255
+
256
+ if i % args.print_freq == 0:
257
+ progress.display(i)
258
+
259
+ # TODO: this should also be done with the ProgressMeter
260
+ print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
261
+ .format(top1=top1, top5=top5))
262
+
263
+ return top1.avg
264
+
265
+
266
+ def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
267
+ torch.save(state, filename)
268
+ if is_best:
269
+ shutil.copyfile(filename, 'model_best.pth.tar')
270
+
271
+
272
+ class AverageMeter(object):
273
+ """Computes and stores the average and current value"""
274
+ def __init__(self, name, fmt=':f'):
275
+ self.name = name
276
+ self.fmt = fmt
277
+ self.reset()
278
+
279
+ def reset(self):
280
+ self.val = 0
281
+ self.avg = 0
282
+ self.sum = 0
283
+ self.count = 0
284
+
285
+ def update(self, val, n=1):
286
+ self.val = val
287
+ self.sum += val * n
288
+ self.count += n
289
+ self.avg = self.sum / self.count
290
+
291
+ def __str__(self):
292
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
293
+ return fmtstr.format(**self.__dict__)
294
+
295
+
296
+ class ProgressMeter(object):
297
+ def __init__(self, num_batches, meters, prefix=""):
298
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
299
+ self.meters = meters
300
+ self.prefix = prefix
301
+
302
+ def display(self, batch):
303
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
304
+ entries += [str(meter) for meter in self.meters]
305
+ print('\t'.join(entries))
306
+
307
+ def _get_batch_fmtstr(self, num_batches):
308
+ num_digits = len(str(num_batches // 1))
309
+ fmt = '{:' + str(num_digits) + 'd}'
310
+ return '[' + fmt + '/' + fmt.format(num_batches) + ']'
311
+
312
+ def adjust_learning_rate(optimizer, epoch, args):
313
+ """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
314
+ lr = args.lr * (0.85 ** (epoch // 2))
315
+ for param_group in optimizer.param_groups:
316
+ param_group['lr'] = lr
317
+
318
+
319
+ def accuracy(output, target, topk=(1,)):
320
+ """Computes the accuracy over the k top predictions for the specified values of k"""
321
+ with torch.no_grad():
322
+ maxk = max(topk)
323
+ batch_size = target.size(0)
324
+
325
+ _, pred = output.topk(maxk, 1, True, True)
326
+ pred = pred.t()
327
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
328
+
329
+ res = []
330
+ for k in topk:
331
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
332
+ res.append(correct_k.mul_(100.0 / batch_size))
333
+ return res
334
+
335
+
336
+ if __name__ == '__main__':
337
+ main()
imagenet_eval_robustness_per_class.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+ import shutil
5
+ import time
6
+ import warnings
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.parallel
11
+ import torch.backends.cudnn as cudnn
12
+ import torch.distributed as dist
13
+ import torch.optim
14
+ import torch.multiprocessing as mp
15
+ import torch.utils.data
16
+ import torch.utils.data.distributed
17
+ import torchvision.transforms as transforms
18
+ import torchvision.datasets as datasets
19
+ import torchvision.models as models
20
+
21
+ # Uncomment the expected model below
22
+
23
+ # ViT
24
+ from ViT.ViT import vit_base_patch16_224 as vit
25
+ # from ViT.ViT import vit_large_patch16_224 as vit
26
+
27
+ # ViT-AugReg
28
+ # from ViT.ViT_new import vit_small_patch16_224 as vit
29
+ # from ViT.ViT_new import vit_base_patch16_224 as vit
30
+ # from ViT.ViT_new import vit_large_patch16_224 as vit
31
+
32
+ # DeiT
33
+ # from ViT.ViT import deit_base_patch16_224 as vit
34
+ # from ViT.ViT import deit_small_patch16_224 as vit
35
+
36
+ from robustness_dataset_per_class import RobustnessDataset
37
+ from objectnet_dataset import ObjectNetDataset
38
+ model_names = sorted(name for name in models.__dict__
39
+ if name.islower() and not name.startswith("__")
40
+ and callable(models.__dict__[name]))
41
+ model_names.append("vit")
42
+
43
+ parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
44
+ parser.add_argument('--data', metavar='DIR',
45
+ help='path to dataset')
46
+ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
47
+ help='number of data loading workers (default: 4)')
48
+ parser.add_argument('--epochs', default=150, type=int, metavar='N',
49
+ help='number of total epochs to run')
50
+ parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
51
+ help='manual epoch number (useful on restarts)')
52
+ parser.add_argument('-b', '--batch-size', default=256, type=int,
53
+ metavar='N',
54
+ help='mini-batch size (default: 256), this is the total '
55
+ 'batch size of all GPUs on the current node when '
56
+ 'using Data Parallel or Distributed Data Parallel')
57
+ parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,
58
+ metavar='LR', help='initial learning rate', dest='lr')
59
+ parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
60
+ help='momentum')
61
+ parser.add_argument('--wd', '--weight-decay', default=0.05, type=float,
62
+ metavar='W', help='weight decay (default: 1e-4)',
63
+ dest='weight_decay')
64
+ parser.add_argument('-p', '--print-freq', default=10, type=int,
65
+ metavar='N', help='print frequency (default: 10)')
66
+ parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
67
+ help='path to latest checkpoint (default: none)')
68
+ parser.add_argument('--resume', default='', type=str, metavar='PATH',
69
+ help='path to resume checkpoint (default: none)')
70
+ parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
71
+ help='evaluate model on validation set')
72
+ parser.add_argument('--pretrained', dest='pretrained', action='store_true',
73
+ help='use pre-trained model')
74
+ parser.add_argument('--world-size', default=-1, type=int,
75
+ help='number of nodes for distributed training')
76
+ parser.add_argument('--rank', default=-1, type=int,
77
+ help='node rank for distributed training')
78
+ parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
79
+ help='url used to set up distributed training')
80
+ parser.add_argument('--dist-backend', default='nccl', type=str,
81
+ help='distributed backend')
82
+ parser.add_argument('--seed', default=None, type=int,
83
+ help='seed for initializing training. ')
84
+ parser.add_argument('--gpu', default=None, type=int,
85
+ help='GPU id to use.')
86
+ parser.add_argument('--multiprocessing-distributed', action='store_true',
87
+ help='Use multi-processing distributed training to launch '
88
+ 'N processes per node, which has N GPUs. This is the '
89
+ 'fastest way to use PyTorch for either single node or '
90
+ 'multi node data parallel training')
91
+ parser.add_argument("--isV2", default=False, action='store_true',
92
+ help='is dataset imagenet V2.')
93
+ parser.add_argument("--isSI", default=False, action='store_true',
94
+ help='is dataset SI-score.')
95
+ parser.add_argument("--isObjectNet", default=False, action='store_true',
96
+ help='is dataset SI-score.')
97
+
98
+
99
+ def main():
100
+ args = parser.parse_args()
101
+
102
+ if args.seed is not None:
103
+ random.seed(args.seed)
104
+ torch.manual_seed(args.seed)
105
+ cudnn.deterministic = True
106
+ warnings.warn('You have chosen to seed training. '
107
+ 'This will turn on the CUDNN deterministic setting, '
108
+ 'which can slow down your training considerably! '
109
+ 'You may see unexpected behavior when restarting '
110
+ 'from checkpoints.')
111
+
112
+ if args.gpu is not None:
113
+ warnings.warn('You have chosen a specific GPU. This will completely '
114
+ 'disable data parallelism.')
115
+
116
+ if args.dist_url == "env://" and args.world_size == -1:
117
+ args.world_size = int(os.environ["WORLD_SIZE"])
118
+
119
+ args.distributed = args.world_size > 1 or args.multiprocessing_distributed
120
+
121
+ ngpus_per_node = torch.cuda.device_count()
122
+ if args.multiprocessing_distributed:
123
+ # Since we have ngpus_per_node processes per node, the total world_size
124
+ # needs to be adjusted accordingly
125
+ args.world_size = ngpus_per_node * args.world_size
126
+ # Use torch.multiprocessing.spawn to launch distributed processes: the
127
+ # main_worker process function
128
+ mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
129
+ else:
130
+ # Simply call main_worker function
131
+ main_worker(args.gpu, ngpus_per_node, args)
132
+
133
+
134
+ def main_worker(gpu, ngpus_per_node, args):
135
+ global best_acc1
136
+ args.gpu = gpu
137
+
138
+ if args.gpu is not None:
139
+ print("Use GPU: {} for training".format(args.gpu))
140
+
141
+ if args.distributed:
142
+ if args.dist_url == "env://" and args.rank == -1:
143
+ args.rank = int(os.environ["RANK"])
144
+ if args.multiprocessing_distributed:
145
+ # For multiprocessing distributed training, rank needs to be the
146
+ # global rank among all the processes
147
+ args.rank = args.rank * ngpus_per_node + gpu
148
+ dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
149
+ world_size=args.world_size, rank=args.rank)
150
+ # create model
151
+ print("=> creating model")
152
+ if args.checkpoint:
153
+ model = vit().cuda()
154
+ checkpoint = torch.load(args.checkpoint)
155
+ model.load_state_dict(checkpoint['state_dict'])
156
+ else:
157
+ model = vit(pretrained=True).cuda()
158
+ print("done")
159
+
160
+ if not torch.cuda.is_available():
161
+ print('using CPU, this will be slow')
162
+ elif args.distributed:
163
+ # For multiprocessing distributed, DistributedDataParallel constructor
164
+ # should always set the single device scope, otherwise,
165
+ # DistributedDataParallel will use all available devices.
166
+ if args.gpu is not None:
167
+ torch.cuda.set_device(args.gpu)
168
+ model.cuda(args.gpu)
169
+ # When using a single GPU per process and per
170
+ # DistributedDataParallel, we need to divide the batch size
171
+ # ourselves based on the total number of GPUs we have
172
+ args.batch_size = int(args.batch_size / ngpus_per_node)
173
+ args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
174
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
175
+ else:
176
+ model.cuda()
177
+ # DistributedDataParallel will divide and allocate batch_size to all
178
+ # available GPUs if device_ids are not set
179
+ model = torch.nn.parallel.DistributedDataParallel(model)
180
+ elif args.gpu is not None:
181
+ torch.cuda.set_device(args.gpu)
182
+ model = model.cuda(args.gpu)
183
+ else:
184
+ # DataParallel will divide and allocate batch_size to all available GPUs
185
+ print("start")
186
+ model = torch.nn.DataParallel(model).cuda()
187
+
188
+ # optionally resume from a checkpoint
189
+ if args.resume:
190
+ if os.path.isfile(args.resume):
191
+ print("=> loading checkpoint '{}'".format(args.resume))
192
+ if args.gpu is None:
193
+ checkpoint = torch.load(args.resume)
194
+ else:
195
+ # Map model to be loaded to specified single gpu.
196
+ loc = 'cuda:{}'.format(args.gpu)
197
+ checkpoint = torch.load(args.resume, map_location=loc)
198
+ args.start_epoch = checkpoint['epoch']
199
+ best_acc1 = checkpoint['best_acc1']
200
+ if args.gpu is not None:
201
+ # best_acc1 may be from a checkpoint from a different GPU
202
+ best_acc1 = best_acc1.to(args.gpu)
203
+ model.load_state_dict(checkpoint['state_dict'])
204
+ print("=> loaded checkpoint '{}' (epoch {})"
205
+ .format(args.resume, checkpoint['epoch']))
206
+ else:
207
+ print("=> no checkpoint found at '{}'".format(args.resume))
208
+
209
+ cudnn.benchmark = True
210
+
211
+ # Data loading code
212
+
213
+ top1_per_class = {}
214
+ top5_per_class = {}
215
+ for folder in os.listdir(args.data):
216
+ val_dataset = RobustnessDataset(args.data, folder=folder, isV2=args.isV2, isSI=args.isSI)
217
+ print("len: ", len(val_dataset))
218
+ val_loader = torch.utils.data.DataLoader(
219
+ val_dataset, batch_size=args.batch_size, shuffle=False,
220
+ num_workers=args.workers, pin_memory=True)
221
+ class_name = val_dataset.get_classname()
222
+ top1, top5 = validate(val_loader, model, args)
223
+ top1_per_class[class_name] = top1.item()
224
+ top5_per_class[class_name] = top5.item()
225
+
226
+ print("overall top1 per class: ", top1_per_class)
227
+ print("overall top5 per class: ", top5_per_class)
228
+
229
+ def validate(val_loader, model, args):
230
+ batch_time = AverageMeter('Time', ':6.3f')
231
+ losses = AverageMeter('Loss', ':.4e')
232
+ top1 = AverageMeter('Acc@1', ':6.2f')
233
+ top5 = AverageMeter('Acc@5', ':6.2f')
234
+ progress = ProgressMeter(
235
+ len(val_loader),
236
+ [batch_time, losses, top1, top5],
237
+ prefix='Test: ')
238
+
239
+ # switch to evaluate mode
240
+ model.eval()
241
+
242
+ with torch.no_grad():
243
+ end = time.time()
244
+ for i, (images, target) in enumerate(val_loader):
245
+ if args.gpu is not None:
246
+ images = images.cuda(args.gpu, non_blocking=True)
247
+ if torch.cuda.is_available():
248
+ target = target.cuda(args.gpu, non_blocking=True)
249
+
250
+ # compute output
251
+ output = model(images)
252
+
253
+ # measure accuracy and record loss
254
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
255
+ top1.update(acc1[0], images.size(0))
256
+ top5.update(acc5[0], images.size(0))
257
+
258
+ # measure elapsed time
259
+ batch_time.update(time.time() - end)
260
+ end = time.time()
261
+
262
+ if i % args.print_freq == 0:
263
+ progress.display(i)
264
+
265
+ # TODO: this should also be done with the ProgressMeter
266
+ print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
267
+ .format(top1=top1, top5=top5))
268
+
269
+ return top1.avg, top5.avg
270
+
271
+
272
+ def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
273
+ torch.save(state, filename)
274
+ if is_best:
275
+ shutil.copyfile(filename, 'model_best.pth.tar')
276
+
277
+
278
+ class AverageMeter(object):
279
+ """Computes and stores the average and current value"""
280
+ def __init__(self, name, fmt=':f'):
281
+ self.name = name
282
+ self.fmt = fmt
283
+ self.reset()
284
+
285
+ def reset(self):
286
+ self.val = 0
287
+ self.avg = 0
288
+ self.sum = 0
289
+ self.count = 0
290
+
291
+ def update(self, val, n=1):
292
+ self.val = val
293
+ self.sum += val * n
294
+ self.count += n
295
+ self.avg = self.sum / self.count
296
+
297
+ def __str__(self):
298
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
299
+ return fmtstr.format(**self.__dict__)
300
+
301
+
302
+ class ProgressMeter(object):
303
+ def __init__(self, num_batches, meters, prefix=""):
304
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
305
+ self.meters = meters
306
+ self.prefix = prefix
307
+
308
+ def display(self, batch):
309
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
310
+ entries += [str(meter) for meter in self.meters]
311
+ print('\t'.join(entries))
312
+
313
+ def _get_batch_fmtstr(self, num_batches):
314
+ num_digits = len(str(num_batches // 1))
315
+ fmt = '{:' + str(num_digits) + 'd}'
316
+ return '[' + fmt + '/' + fmt.format(num_batches) + ']'
317
+
318
+ def adjust_learning_rate(optimizer, epoch, args):
319
+ """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
320
+ lr = args.lr * (0.85 ** (epoch // 2))
321
+ for param_group in optimizer.param_groups:
322
+ param_group['lr'] = lr
323
+
324
+
325
+ def accuracy(output, target, topk=(1,)):
326
+ """Computes the accuracy over the k top predictions for the specified values of k"""
327
+ with torch.no_grad():
328
+ maxk = max(topk)
329
+ batch_size = target.size(0)
330
+
331
+ _, pred = output.topk(maxk, 1, True, True)
332
+ pred = pred.t()
333
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
334
+
335
+ res = []
336
+ for k in topk:
337
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
338
+ res.append(correct_k.mul_(100.0 / batch_size))
339
+ return res
340
+
341
+
342
+ if __name__ == '__main__':
343
+ main()
imagenet_finetune.py ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+ import shutil
5
+ import time
6
+ import warnings
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.parallel
11
+ import torch.backends.cudnn as cudnn
12
+ import torch.distributed as dist
13
+ import torch.optim
14
+ import torch.multiprocessing as mp
15
+ import torch.utils.data
16
+ import torch.utils.data.distributed
17
+ import torchvision.transforms as transforms
18
+ import torchvision.datasets as datasets
19
+ import torchvision.models as models
20
+ from segmentation_dataset import SegmentationDataset, VAL_PARTITION, TRAIN_PARTITION
21
+
22
+ # Uncomment the expected model below
23
+
24
+ # ViT
25
+ from ViT.ViT import vit_base_patch16_224 as vit
26
+ # from ViT.ViT import vit_large_patch16_224 as vit
27
+
28
+ # ViT-AugReg
29
+ # from ViT.ViT_new import vit_small_patch16_224 as vit
30
+ # from ViT.ViT_new import vit_base_patch16_224 as vit
31
+ # from ViT.ViT_new import vit_large_patch16_224 as vit
32
+
33
+ # DeiT
34
+ # from ViT.ViT import deit_base_patch16_224 as vit
35
+ # from ViT.ViT import deit_small_patch16_224 as vit
36
+
37
+ from ViT.explainer import generate_relevance, get_image_with_relevance
38
+ import torchvision
39
+ import cv2
40
+ from torch.utils.tensorboard import SummaryWriter
41
+ import json
42
+
43
+ model_names = sorted(name for name in models.__dict__
44
+ if name.islower() and not name.startswith("__")
45
+ and callable(models.__dict__[name]))
46
+ model_names.append("vit")
47
+
48
+ parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
49
+ parser.add_argument('--data', metavar='DATA',
50
+ help='path to dataset')
51
+ parser.add_argument('--seg_data', metavar='SEG_DATA',
52
+ help='path to segmentation dataset')
53
+ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
54
+ help='number of data loading workers (default: 4)')
55
+ parser.add_argument('--epochs', default=50, type=int, metavar='N',
56
+ help='number of total epochs to run')
57
+ parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
58
+ help='manual epoch number (useful on restarts)')
59
+ parser.add_argument('-b', '--batch-size', default=8, type=int,
60
+ metavar='N',
61
+ help='mini-batch size (default: 256), this is the total '
62
+ 'batch size of all GPUs on the current node when '
63
+ 'using Data Parallel or Distributed Data Parallel')
64
+ parser.add_argument('--lr', '--learning-rate', default=3e-6, type=float,
65
+ metavar='LR', help='initial learning rate', dest='lr')
66
+ parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
67
+ help='momentum')
68
+ parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
69
+ metavar='W', help='weight decay (default: 1e-4)',
70
+ dest='weight_decay')
71
+ parser.add_argument('-p', '--print-freq', default=10, type=int,
72
+ metavar='N', help='print frequency (default: 10)')
73
+ parser.add_argument('--resume', default='', type=str, metavar='PATH',
74
+ help='path to latest checkpoint (default: none)')
75
+ parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
76
+ help='evaluate model on validation set')
77
+ parser.add_argument('--pretrained', dest='pretrained', action='store_true',
78
+ help='use pre-trained model')
79
+ parser.add_argument('--world-size', default=-1, type=int,
80
+ help='number of nodes for distributed training')
81
+ parser.add_argument('--rank', default=-1, type=int,
82
+ help='node rank for distributed training')
83
+ parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
84
+ help='url used to set up distributed training')
85
+ parser.add_argument('--dist-backend', default='nccl', type=str,
86
+ help='distributed backend')
87
+ parser.add_argument('--gpu', default=None, type=int,
88
+ help='GPU id to use.')
89
+ parser.add_argument('--save_interval', default=20, type=int,
90
+ help='interval to save segmentation results.')
91
+ parser.add_argument('--num_samples', default=3, type=int,
92
+ help='number of samples per class for training')
93
+ parser.add_argument('--multiprocessing-distributed', action='store_true',
94
+ help='Use multi-processing distributed training to launch '
95
+ 'N processes per node, which has N GPUs. This is the '
96
+ 'fastest way to use PyTorch for either single node or '
97
+ 'multi node data parallel training')
98
+ parser.add_argument('--lambda_seg', default=0.8, type=float,
99
+ help='influence of segmentation loss.')
100
+ parser.add_argument('--lambda_acc', default=0.2, type=float,
101
+ help='influence of accuracy loss.')
102
+ parser.add_argument('--experiment_folder', default=None, type=str,
103
+ help='path to folder to use for experiment.')
104
+ parser.add_argument('--dilation', default=0, type=float,
105
+ help='Use dilation on the segmentation maps.')
106
+ parser.add_argument('--lambda_background', default=2, type=float,
107
+ help='coefficient of loss for segmentation background.')
108
+ parser.add_argument('--lambda_foreground', default=0.3, type=float,
109
+ help='coefficient of loss for segmentation foreground.')
110
+ parser.add_argument('--num_classes', default=500, type=int,
111
+ help='coefficient of loss for segmentation foreground.')
112
+ parser.add_argument('--temperature', default=1, type=float,
113
+ help='temperature for softmax (mostly for DeiT).')
114
+ parser.add_argument('--class_seed', default=None, type=int,
115
+ help='seed to randomly shuffle classes chosen for training.')
116
+
117
+ best_loss = float('inf')
118
+
119
+ def main():
120
+ args = parser.parse_args()
121
+
122
+ if args.experiment_folder is None:
123
+ args.experiment_folder = f'experiment/' \
124
+ f'lr_{args.lr}_seg_{args.lambda_seg}_acc_{args.lambda_acc}' \
125
+ f'_bckg_{args.lambda_background}_fgd_{args.lambda_foreground}'
126
+ if args.temperature != 1:
127
+ args.experiment_folder = args.experiment_folder + f'_tempera_{args.temperature}'
128
+ if args.batch_size != 8:
129
+ args.experiment_folder = args.experiment_folder + f'_bs_{args.batch_size}'
130
+ if args.num_classes != 500:
131
+ args.experiment_folder = args.experiment_folder + f'_num_classes_{args.num_classes}'
132
+ if args.num_samples != 3:
133
+ args.experiment_folder = args.experiment_folder + f'_num_samples_{args.num_samples}'
134
+ if args.epochs != 150:
135
+ args.experiment_folder = args.experiment_folder + f'_num_epochs_{args.epochs}'
136
+ if args.class_seed is not None:
137
+ args.experiment_folder = args.experiment_folder + f'_seed_{args.class_seed}'
138
+
139
+ if os.path.exists(args.experiment_folder):
140
+ raise Exception(f"Experiment path {args.experiment_folder} already exists!")
141
+ os.mkdir(args.experiment_folder)
142
+ os.mkdir(f'{args.experiment_folder}/train_samples')
143
+ os.mkdir(f'{args.experiment_folder}/val_samples')
144
+
145
+ with open(f'{args.experiment_folder}/commandline_args.txt', 'w') as f:
146
+ json.dump(args.__dict__, f, indent=2)
147
+
148
+ if args.gpu is not None:
149
+ warnings.warn('You have chosen a specific GPU. This will completely '
150
+ 'disable data parallelism.')
151
+
152
+ if args.dist_url == "env://" and args.world_size == -1:
153
+ args.world_size = int(os.environ["WORLD_SIZE"])
154
+
155
+ args.distributed = args.world_size > 1 or args.multiprocessing_distributed
156
+
157
+ ngpus_per_node = torch.cuda.device_count()
158
+ if args.multiprocessing_distributed:
159
+ # Since we have ngpus_per_node processes per node, the total world_size
160
+ # needs to be adjusted accordingly
161
+ args.world_size = ngpus_per_node * args.world_size
162
+ # Use torch.multiprocessing.spawn to launch distributed processes: the
163
+ # main_worker process function
164
+ mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
165
+ else:
166
+ # Simply call main_worker function
167
+ main_worker(args.gpu, ngpus_per_node, args)
168
+
169
+
170
+ def main_worker(gpu, ngpus_per_node, args):
171
+ global best_loss
172
+ args.gpu = gpu
173
+
174
+ if args.gpu is not None:
175
+ print("Use GPU: {} for training".format(args.gpu))
176
+
177
+ if args.distributed:
178
+ if args.dist_url == "env://" and args.rank == -1:
179
+ args.rank = int(os.environ["RANK"])
180
+ if args.multiprocessing_distributed:
181
+ # For multiprocessing distributed training, rank needs to be the
182
+ # global rank among all the processes
183
+ args.rank = args.rank * ngpus_per_node + gpu
184
+ dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
185
+ world_size=args.world_size, rank=args.rank)
186
+ # create model
187
+ print("=> creating model")
188
+ model = vit(pretrained=True).cuda()
189
+ model.train()
190
+ print("done")
191
+
192
+ if not torch.cuda.is_available():
193
+ print('using CPU, this will be slow')
194
+ elif args.distributed:
195
+ # For multiprocessing distributed, DistributedDataParallel constructor
196
+ # should always set the single device scope, otherwise,
197
+ # DistributedDataParallel will use all available devices.
198
+ if args.gpu is not None:
199
+ torch.cuda.set_device(args.gpu)
200
+ model.cuda(args.gpu)
201
+ # When using a single GPU per process and per
202
+ # DistributedDataParallel, we need to divide the batch size
203
+ # ourselves based on the total number of GPUs we have
204
+ args.batch_size = int(args.batch_size / ngpus_per_node)
205
+ args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
206
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
207
+ else:
208
+ model.cuda()
209
+ # DistributedDataParallel will divide and allocate batch_size to all
210
+ # available GPUs if device_ids are not set
211
+ model = torch.nn.parallel.DistributedDataParallel(model)
212
+ elif args.gpu is not None:
213
+ torch.cuda.set_device(args.gpu)
214
+ model = model.cuda(args.gpu)
215
+ else:
216
+ # DataParallel will divide and allocate batch_size to all available GPUs
217
+ print("start")
218
+ model = torch.nn.DataParallel(model).cuda()
219
+
220
+ # define loss function (criterion) and optimizer
221
+ criterion = nn.CrossEntropyLoss().cuda(args.gpu)
222
+ optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=args.weight_decay)
223
+
224
+ # optionally resume from a checkpoint
225
+ if args.resume:
226
+ if os.path.isfile(args.resume):
227
+ print("=> loading checkpoint '{}'".format(args.resume))
228
+ if args.gpu is None:
229
+ checkpoint = torch.load(args.resume)
230
+ else:
231
+ # Map model to be loaded to specified single gpu.
232
+ loc = 'cuda:{}'.format(args.gpu)
233
+ checkpoint = torch.load(args.resume, map_location=loc)
234
+ args.start_epoch = checkpoint['epoch']
235
+ best_loss = checkpoint['best_loss']
236
+ if args.gpu is not None:
237
+ # best_loss may be from a checkpoint from a different GPU
238
+ best_loss = best_loss.to(args.gpu)
239
+ model.load_state_dict(checkpoint['state_dict'])
240
+ optimizer.load_state_dict(checkpoint['optimizer'])
241
+ print("=> loaded checkpoint '{}' (epoch {})"
242
+ .format(args.resume, checkpoint['epoch']))
243
+ else:
244
+ print("=> no checkpoint found at '{}'".format(args.resume))
245
+
246
+ cudnn.benchmark = True
247
+
248
+ train_dataset = SegmentationDataset(args.seg_data, args.data, partition=TRAIN_PARTITION, train_classes=args.num_classes,
249
+ num_samples=args.num_samples, seed=args.class_seed)
250
+
251
+ if args.distributed:
252
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
253
+ else:
254
+ train_sampler = None
255
+
256
+ train_loader = torch.utils.data.DataLoader(
257
+ train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
258
+ num_workers=args.workers, pin_memory=True, sampler=train_sampler)
259
+
260
+ val_dataset = SegmentationDataset(args.seg_data, args.data, partition=VAL_PARTITION, train_classes=args.num_classes,
261
+ num_samples=1, seed=args.class_seed)
262
+
263
+ val_loader = torch.utils.data.DataLoader(
264
+ val_dataset, batch_size=10, shuffle=False,
265
+ num_workers=args.workers, pin_memory=True)
266
+
267
+ if args.evaluate:
268
+ validate(val_loader, model, criterion, 0, args)
269
+ return
270
+
271
+ for epoch in range(args.start_epoch, args.epochs):
272
+ if args.distributed:
273
+ train_sampler.set_epoch(epoch)
274
+ adjust_learning_rate(optimizer, epoch, args)
275
+
276
+ log_dir = os.path.join(args.experiment_folder, 'logs')
277
+ logger = SummaryWriter(log_dir=log_dir)
278
+ args.logger = logger
279
+
280
+ # train for one epoch
281
+ train(train_loader, model, criterion, optimizer, epoch, args)
282
+
283
+ # evaluate on validation set
284
+ loss1 = validate(val_loader, model, criterion, epoch, args)
285
+
286
+ # remember best acc@1 and save checkpoint
287
+ is_best = loss1 <= best_loss
288
+ best_loss = min(loss1, best_loss)
289
+
290
+ if not args.multiprocessing_distributed or (args.multiprocessing_distributed
291
+ and args.rank % ngpus_per_node == 0):
292
+ save_checkpoint({
293
+ 'epoch': epoch + 1,
294
+ 'state_dict': model.state_dict(),
295
+ 'best_loss': best_loss,
296
+ 'optimizer' : optimizer.state_dict(),
297
+ }, is_best, folder=args.experiment_folder)
298
+
299
+
300
+ def train(train_loader, model, criterion, optimizer, epoch, args):
301
+ mse_criterion = torch.nn.MSELoss(reduction='mean')
302
+
303
+ losses = AverageMeter('Loss', ':.4e')
304
+ top1 = AverageMeter('Acc@1', ':6.2f')
305
+ top5 = AverageMeter('Acc@5', ':6.2f')
306
+ orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
307
+ orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
308
+ progress = ProgressMeter(
309
+ len(train_loader),
310
+ [losses, top1, top5, orig_top1, orig_top5],
311
+ prefix="Epoch: [{}]".format(epoch))
312
+
313
+ orig_model = vit(pretrained=True).cuda()
314
+ orig_model.eval()
315
+
316
+ # switch to train mode
317
+ model.train()
318
+
319
+ for i, (seg_map, image_ten, class_name) in enumerate(train_loader):
320
+ if torch.cuda.is_available():
321
+ image_ten = image_ten.cuda(args.gpu, non_blocking=True)
322
+ seg_map = seg_map.cuda(args.gpu, non_blocking=True)
323
+ class_name = class_name.cuda(args.gpu, non_blocking=True)
324
+
325
+ # segmentation loss
326
+ relevance = generate_relevance(model, image_ten, index=class_name)
327
+
328
+ reverse_seg_map = seg_map.clone()
329
+ reverse_seg_map[reverse_seg_map == 1] = -1
330
+ reverse_seg_map[reverse_seg_map == 0] = 1
331
+ reverse_seg_map[reverse_seg_map == -1] = 0
332
+ background_loss = mse_criterion(relevance * reverse_seg_map, torch.zeros_like(relevance))
333
+ foreground_loss = mse_criterion(relevance * seg_map, seg_map)
334
+ segmentation_loss = args.lambda_background * background_loss
335
+ segmentation_loss += args.lambda_foreground * foreground_loss
336
+
337
+ # classification loss
338
+ output = model(image_ten)
339
+ with torch.no_grad():
340
+ output_orig = orig_model(image_ten)
341
+
342
+ _, pred = output.topk(1, 1, True, True)
343
+ pred = pred.flatten()
344
+
345
+ if args.temperature != 1:
346
+ output = output / args.temperature
347
+ classification_loss = criterion(output, pred)
348
+
349
+ loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss
350
+
351
+ # debugging output
352
+ if i % args.save_interval == 0:
353
+ orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
354
+ for j in range(image_ten.shape[0]):
355
+ image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
356
+ new_vis = get_image_with_relevance(image_ten[j], relevance[j])
357
+ old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
358
+ gt = get_image_with_relevance(image_ten[j], seg_map[j])
359
+ h_img = cv2.hconcat([image, gt, old_vis, new_vis])
360
+ cv2.imwrite(f'{args.experiment_folder}/train_samples/res_{i}_{j}.jpg', h_img)
361
+
362
+ # measure accuracy and record loss
363
+ acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
364
+ losses.update(loss.item(), image_ten.size(0))
365
+ top1.update(acc1[0], image_ten.size(0))
366
+ top5.update(acc5[0], image_ten.size(0))
367
+
368
+ # metrics for original vit
369
+ acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
370
+ orig_top1.update(acc1_orig[0], image_ten.size(0))
371
+ orig_top5.update(acc5_orig[0], image_ten.size(0))
372
+
373
+ # compute gradient and do SGD step
374
+ optimizer.zero_grad()
375
+ loss.backward()
376
+ optimizer.step()
377
+
378
+ if i % args.print_freq == 0:
379
+ progress.display(i)
380
+ args.logger.add_scalar('{}/{}'.format('train', 'segmentation_loss'), segmentation_loss,
381
+ epoch*len(train_loader)+i)
382
+ args.logger.add_scalar('{}/{}'.format('train', 'classification_loss'), classification_loss,
383
+ epoch * len(train_loader) + i)
384
+ args.logger.add_scalar('{}/{}'.format('train', 'orig_top1'), acc1_orig,
385
+ epoch * len(train_loader) + i)
386
+ args.logger.add_scalar('{}/{}'.format('train', 'top1'), acc1,
387
+ epoch * len(train_loader) + i)
388
+ args.logger.add_scalar('{}/{}'.format('train', 'orig_top5'), acc5_orig,
389
+ epoch * len(train_loader) + i)
390
+ args.logger.add_scalar('{}/{}'.format('train', 'top5'), acc5,
391
+ epoch * len(train_loader) + i)
392
+ args.logger.add_scalar('{}/{}'.format('train', 'tot_loss'), loss,
393
+ epoch * len(train_loader) + i)
394
+
395
+
396
+ def validate(val_loader, model, criterion, epoch, args):
397
+ mse_criterion = torch.nn.MSELoss(reduction='mean')
398
+
399
+ losses = AverageMeter('Loss', ':.4e')
400
+ top1 = AverageMeter('Acc@1', ':6.2f')
401
+ top5 = AverageMeter('Acc@5', ':6.2f')
402
+ orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
403
+ orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
404
+ progress = ProgressMeter(
405
+ len(val_loader),
406
+ [losses, top1, top5, orig_top1, orig_top5],
407
+ prefix="Epoch: [{}]".format(val_loader))
408
+
409
+ # switch to evaluate mode
410
+ model.eval()
411
+
412
+ orig_model = vit(pretrained=True).cuda()
413
+ orig_model.eval()
414
+
415
+ with torch.no_grad():
416
+ for i, (seg_map, image_ten, class_name) in enumerate(val_loader):
417
+ if args.gpu is not None:
418
+ image_ten = image_ten.cuda(args.gpu, non_blocking=True)
419
+ if torch.cuda.is_available():
420
+ seg_map = seg_map.cuda(args.gpu, non_blocking=True)
421
+ class_name = class_name.cuda(args.gpu, non_blocking=True)
422
+
423
+ # segmentation loss
424
+ with torch.enable_grad():
425
+ relevance = generate_relevance(model, image_ten, index=class_name)
426
+
427
+ reverse_seg_map = seg_map.clone()
428
+ reverse_seg_map[reverse_seg_map == 1] = -1
429
+ reverse_seg_map[reverse_seg_map == 0] = 1
430
+ reverse_seg_map[reverse_seg_map == -1] = 0
431
+ background_loss = mse_criterion(relevance * reverse_seg_map, torch.zeros_like(relevance))
432
+ foreground_loss = mse_criterion(relevance * seg_map, seg_map)
433
+ segmentation_loss = args.lambda_background * background_loss
434
+ segmentation_loss += args.lambda_foreground * foreground_loss
435
+
436
+ # classification loss
437
+ with torch.no_grad():
438
+ output = model(image_ten)
439
+ output_orig = orig_model(image_ten)
440
+
441
+ _, pred = output.topk(1, 1, True, True)
442
+ pred = pred.flatten()
443
+ if args.temperature != 1:
444
+ output = output / args.temperature
445
+ classification_loss = criterion(output, pred)
446
+
447
+ loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss
448
+
449
+ # save results
450
+ if i % args.save_interval == 0:
451
+ with torch.enable_grad():
452
+ orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
453
+ for j in range(image_ten.shape[0]):
454
+ image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
455
+ new_vis = get_image_with_relevance(image_ten[j], relevance[j])
456
+ old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
457
+ gt = get_image_with_relevance(image_ten[j], seg_map[j])
458
+ h_img = cv2.hconcat([image, gt, old_vis, new_vis])
459
+ cv2.imwrite(f'{args.experiment_folder}/val_samples/res_{i}_{j}.jpg', h_img)
460
+
461
+ # measure accuracy and record loss
462
+ acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
463
+ losses.update(loss.item(), image_ten.size(0))
464
+ top1.update(acc1[0], image_ten.size(0))
465
+ top5.update(acc5[0], image_ten.size(0))
466
+
467
+ # metrics for original vit
468
+ acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
469
+ orig_top1.update(acc1_orig[0], image_ten.size(0))
470
+ orig_top5.update(acc5_orig[0], image_ten.size(0))
471
+
472
+ if i % args.print_freq == 0:
473
+ progress.display(i)
474
+ args.logger.add_scalar('{}/{}'.format('val', 'segmentation_loss'), segmentation_loss,
475
+ epoch * len(val_loader) + i)
476
+ args.logger.add_scalar('{}/{}'.format('val', 'classification_loss'), classification_loss,
477
+ epoch * len(val_loader) + i)
478
+ args.logger.add_scalar('{}/{}'.format('val', 'orig_top1'), acc1_orig,
479
+ epoch * len(val_loader) + i)
480
+ args.logger.add_scalar('{}/{}'.format('val', 'top1'), acc1,
481
+ epoch * len(val_loader) + i)
482
+ args.logger.add_scalar('{}/{}'.format('val', 'orig_top5'), acc5_orig,
483
+ epoch * len(val_loader) + i)
484
+ args.logger.add_scalar('{}/{}'.format('val', 'top5'), acc5,
485
+ epoch * len(val_loader) + i)
486
+ args.logger.add_scalar('{}/{}'.format('val', 'tot_loss'), loss,
487
+ epoch * len(val_loader) + i)
488
+
489
+ # TODO: this should also be done with the ProgressMeter
490
+ print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
491
+ .format(top1=top1, top5=top5))
492
+
493
+ return losses.avg
494
+
495
+
496
+ def save_checkpoint(state, is_best, folder, filename='checkpoint.pth.tar'):
497
+ torch.save(state, f'{folder}/{filename}')
498
+ if is_best:
499
+ shutil.copyfile(f'{folder}/{filename}', f'{folder}/model_best.pth.tar')
500
+
501
+
502
+ class AverageMeter(object):
503
+ """Computes and stores the average and current value"""
504
+ def __init__(self, name, fmt=':f'):
505
+ self.name = name
506
+ self.fmt = fmt
507
+ self.reset()
508
+
509
+ def reset(self):
510
+ self.val = 0
511
+ self.avg = 0
512
+ self.sum = 0
513
+ self.count = 0
514
+
515
+ def update(self, val, n=1):
516
+ self.val = val
517
+ self.sum += val * n
518
+ self.count += n
519
+ self.avg = self.sum / self.count
520
+
521
+ def __str__(self):
522
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
523
+ return fmtstr.format(**self.__dict__)
524
+
525
+
526
+ class ProgressMeter(object):
527
+ def __init__(self, num_batches, meters, prefix=""):
528
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
529
+ self.meters = meters
530
+ self.prefix = prefix
531
+
532
+ def display(self, batch):
533
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
534
+ entries += [str(meter) for meter in self.meters]
535
+ print('\t'.join(entries))
536
+
537
+ def _get_batch_fmtstr(self, num_batches):
538
+ num_digits = len(str(num_batches // 1))
539
+ fmt = '{:' + str(num_digits) + 'd}'
540
+ return '[' + fmt + '/' + fmt.format(num_batches) + ']'
541
+
542
+ def adjust_learning_rate(optimizer, epoch, args):
543
+ """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
544
+ lr = args.lr * (0.85 ** (epoch // 2))
545
+ for param_group in optimizer.param_groups:
546
+ param_group['lr'] = lr
547
+
548
+
549
+ def accuracy(output, target, topk=(1,)):
550
+ """Computes the accuracy over the k top predictions for the specified values of k"""
551
+ with torch.no_grad():
552
+ maxk = max(topk)
553
+ batch_size = target.size(0)
554
+
555
+ _, pred = output.topk(maxk, 1, True, True)
556
+ pred = pred.t()
557
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
558
+
559
+ res = []
560
+ for k in topk:
561
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
562
+ res.append(correct_k.mul_(100.0 / batch_size))
563
+ return res
564
+
565
+
566
+ if __name__ == '__main__':
567
+ main()
imagenet_finetune_gradmask.py ADDED
@@ -0,0 +1,586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+ import shutil
5
+ import time
6
+ import warnings
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.parallel
11
+ import torch.backends.cudnn as cudnn
12
+ import torch.distributed as dist
13
+ import torch.optim
14
+ import torch.multiprocessing as mp
15
+ import torch.utils.data
16
+ import torch.utils.data.distributed
17
+ import torchvision.transforms as transforms
18
+ import torchvision.datasets as datasets
19
+ import torchvision.models as models
20
+ from segmentation_dataset import SegmentationDataset, VAL_PARTITION, TRAIN_PARTITION
21
+ import numpy as np
22
+
23
+ # Uncomment the expected model below
24
+
25
+ # ViT
26
+ from ViT.ViT import vit_base_patch16_224 as vit
27
+ # from ViT.ViT import vit_large_patch16_224 as vit
28
+
29
+ # ViT-AugReg
30
+ # from ViT.ViT_new import vit_small_patch16_224 as vit
31
+ # from ViT.ViT_new import vit_base_patch16_224 as vit
32
+ # from ViT.ViT_new import vit_large_patch16_224 as vit
33
+
34
+ # DeiT
35
+ # from ViT.ViT import deit_base_patch16_224 as vit
36
+ # from ViT.ViT import deit_small_patch16_224 as vit
37
+
38
+ from ViT.explainer import generate_relevance, get_image_with_relevance
39
+ import torchvision
40
+ import cv2
41
+ from torch.utils.tensorboard import SummaryWriter
42
+ import json
43
+
44
+ model_names = sorted(name for name in models.__dict__
45
+ if name.islower() and not name.startswith("__")
46
+ and callable(models.__dict__[name]))
47
+ model_names.append("vit")
48
+
49
+ parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
50
+ parser.add_argument('--data', metavar='DATA',
51
+ help='path to dataset')
52
+ parser.add_argument('--seg_data', metavar='SEG_DATA',
53
+ help='path to segmentation dataset')
54
+ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
55
+ help='number of data loading workers (default: 4)')
56
+ parser.add_argument('--epochs', default=50, type=int, metavar='N',
57
+ help='number of total epochs to run')
58
+ parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
59
+ help='manual epoch number (useful on restarts)')
60
+ parser.add_argument('-b', '--batch-size', default=8, type=int,
61
+ metavar='N',
62
+ help='mini-batch size (default: 256), this is the total '
63
+ 'batch size of all GPUs on the current node when '
64
+ 'using Data Parallel or Distributed Data Parallel')
65
+ parser.add_argument('--lr', '--learning-rate', default=3e-6, type=float,
66
+ metavar='LR', help='initial learning rate', dest='lr')
67
+ parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
68
+ help='momentum')
69
+ parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
70
+ metavar='W', help='weight decay (default: 1e-4)',
71
+ dest='weight_decay')
72
+ parser.add_argument('-p', '--print-freq', default=10, type=int,
73
+ metavar='N', help='print frequency (default: 10)')
74
+ parser.add_argument('--resume', default='', type=str, metavar='PATH',
75
+ help='path to latest checkpoint (default: none)')
76
+ parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
77
+ help='evaluate model on validation set')
78
+ parser.add_argument('--pretrained', dest='pretrained', action='store_true',
79
+ help='use pre-trained model')
80
+ parser.add_argument('--world-size', default=-1, type=int,
81
+ help='number of nodes for distributed training')
82
+ parser.add_argument('--rank', default=-1, type=int,
83
+ help='node rank for distributed training')
84
+ parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
85
+ help='url used to set up distributed training')
86
+ parser.add_argument('--dist-backend', default='nccl', type=str,
87
+ help='distributed backend')
88
+ parser.add_argument('--seed', default=None, type=int,
89
+ help='seed for initializing training. ')
90
+ parser.add_argument('--gpu', default=None, type=int,
91
+ help='GPU id to use.')
92
+ parser.add_argument('--save_interval', default=20, type=int,
93
+ help='interval to save segmentation results.')
94
+ parser.add_argument('--num_samples', default=3, type=int,
95
+ help='number of samples per class for training')
96
+ parser.add_argument('--multiprocessing-distributed', action='store_true',
97
+ help='Use multi-processing distributed training to launch '
98
+ 'N processes per node, which has N GPUs. This is the '
99
+ 'fastest way to use PyTorch for either single node or '
100
+ 'multi node data parallel training')
101
+ parser.add_argument('--lambda_seg', default=0.8, type=float,
102
+ help='influence of segmentation loss.')
103
+ parser.add_argument('--lambda_acc', default=0.2, type=float,
104
+ help='influence of accuracy loss.')
105
+ parser.add_argument('--experiment_folder', default=None, type=str,
106
+ help='path to folder to use for experiment.')
107
+ parser.add_argument('--num_classes', default=500, type=int,
108
+ help='coefficient of loss for segmentation foreground.')
109
+ parser.add_argument('--temperature', default=1, type=float,
110
+ help='temperature for softmax (mostly for DeiT).')
111
+
112
+ best_loss = float('inf')
113
+
114
+ def main():
115
+ args = parser.parse_args()
116
+
117
+ if args.experiment_folder is None:
118
+ args.experiment_folder = f'experiment/' \
119
+ f'lr_{args.lr}_seg_{args.lambda_seg}_acc_{args.lambda_acc}'
120
+ if args.temperature != 1:
121
+ args.experiment_folder = args.experiment_folder + f'_tempera_{args.temperature}'
122
+ if args.batch_size != 10:
123
+ args.experiment_folder = args.experiment_folder + f'_bs_{args.batch_size}'
124
+ if args.num_classes != 500:
125
+ args.experiment_folder = args.experiment_folder + f'_num_classes_{args.num_classes}'
126
+ if args.num_samples != 3:
127
+ args.experiment_folder = args.experiment_folder + f'_num_samples_{args.num_samples}'
128
+ if args.epochs != 150:
129
+ args.experiment_folder = args.experiment_folder + f'_num_epochs_{args.epochs}'
130
+
131
+ if os.path.exists(args.experiment_folder):
132
+ raise Exception(f"Experiment path {args.experiment_folder} already exists!")
133
+ os.mkdir(args.experiment_folder)
134
+ os.mkdir(f'{args.experiment_folder}/train_samples')
135
+ os.mkdir(f'{args.experiment_folder}/val_samples')
136
+
137
+ with open(f'{args.experiment_folder}/commandline_args.txt', 'w') as f:
138
+ json.dump(args.__dict__, f, indent=2)
139
+
140
+ if args.seed is not None:
141
+ random.seed(args.seed)
142
+ torch.manual_seed(args.seed)
143
+ cudnn.deterministic = True
144
+ warnings.warn('You have chosen to seed training. '
145
+ 'This will turn on the CUDNN deterministic setting, '
146
+ 'which can slow down your training considerably! '
147
+ 'You may see unexpected behavior when restarting '
148
+ 'from checkpoints.')
149
+
150
+ if args.gpu is not None:
151
+ warnings.warn('You have chosen a specific GPU. This will completely '
152
+ 'disable data parallelism.')
153
+
154
+ if args.dist_url == "env://" and args.world_size == -1:
155
+ args.world_size = int(os.environ["WORLD_SIZE"])
156
+
157
+ args.distributed = args.world_size > 1 or args.multiprocessing_distributed
158
+
159
+ ngpus_per_node = torch.cuda.device_count()
160
+ if args.multiprocessing_distributed:
161
+ # Since we have ngpus_per_node processes per node, the total world_size
162
+ # needs to be adjusted accordingly
163
+ args.world_size = ngpus_per_node * args.world_size
164
+ # Use torch.multiprocessing.spawn to launch distributed processes: the
165
+ # main_worker process function
166
+ mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
167
+ else:
168
+ # Simply call main_worker function
169
+ main_worker(args.gpu, ngpus_per_node, args)
170
+
171
+
172
+ def main_worker(gpu, ngpus_per_node, args):
173
+ global best_loss
174
+ args.gpu = gpu
175
+
176
+ if args.gpu is not None:
177
+ print("Use GPU: {} for training".format(args.gpu))
178
+
179
+ if args.distributed:
180
+ if args.dist_url == "env://" and args.rank == -1:
181
+ args.rank = int(os.environ["RANK"])
182
+ if args.multiprocessing_distributed:
183
+ # For multiprocessing distributed training, rank needs to be the
184
+ # global rank among all the processes
185
+ args.rank = args.rank * ngpus_per_node + gpu
186
+ dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
187
+ world_size=args.world_size, rank=args.rank)
188
+ # create model
189
+ print("=> creating model")
190
+ model = vit(pretrained=True).cuda()
191
+ model.train()
192
+ print("done")
193
+
194
+ if not torch.cuda.is_available():
195
+ print('using CPU, this will be slow')
196
+ elif args.distributed:
197
+ # For multiprocessing distributed, DistributedDataParallel constructor
198
+ # should always set the single device scope, otherwise,
199
+ # DistributedDataParallel will use all available devices.
200
+ if args.gpu is not None:
201
+ torch.cuda.set_device(args.gpu)
202
+ model.cuda(args.gpu)
203
+ # When using a single GPU per process and per
204
+ # DistributedDataParallel, we need to divide the batch size
205
+ # ourselves based on the total number of GPUs we have
206
+ args.batch_size = int(args.batch_size / ngpus_per_node)
207
+ args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
208
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
209
+ else:
210
+ model.cuda()
211
+ # DistributedDataParallel will divide and allocate batch_size to all
212
+ # available GPUs if device_ids are not set
213
+ model = torch.nn.parallel.DistributedDataParallel(model)
214
+ elif args.gpu is not None:
215
+ torch.cuda.set_device(args.gpu)
216
+ model = model.cuda(args.gpu)
217
+ else:
218
+ # DataParallel will divide and allocate batch_size to all available GPUs
219
+ print("start")
220
+ model = torch.nn.DataParallel(model).cuda()
221
+
222
+ # define loss function (criterion) and optimizer
223
+ criterion = nn.CrossEntropyLoss().cuda(args.gpu)
224
+ optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=args.weight_decay)
225
+
226
+ # optionally resume from a checkpoint
227
+ if args.resume:
228
+ if os.path.isfile(args.resume):
229
+ print("=> loading checkpoint '{}'".format(args.resume))
230
+ if args.gpu is None:
231
+ checkpoint = torch.load(args.resume)
232
+ else:
233
+ # Map model to be loaded to specified single gpu.
234
+ loc = 'cuda:{}'.format(args.gpu)
235
+ checkpoint = torch.load(args.resume, map_location=loc)
236
+ args.start_epoch = checkpoint['epoch']
237
+ best_loss = checkpoint['best_loss']
238
+ if args.gpu is not None:
239
+ # best_loss may be from a checkpoint from a different GPU
240
+ best_loss = best_loss.to(args.gpu)
241
+ model.load_state_dict(checkpoint['state_dict'])
242
+ optimizer.load_state_dict(checkpoint['optimizer'])
243
+ print("=> loaded checkpoint '{}' (epoch {})"
244
+ .format(args.resume, checkpoint['epoch']))
245
+ else:
246
+ print("=> no checkpoint found at '{}'".format(args.resume))
247
+
248
+ cudnn.benchmark = True
249
+
250
+ train_dataset = SegmentationDataset(args.seg_data, args.data, partition=TRAIN_PARTITION, train_classes=args.num_classes,
251
+ num_samples=args.num_samples)
252
+
253
+ if args.distributed:
254
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
255
+ else:
256
+ train_sampler = None
257
+
258
+ train_loader = torch.utils.data.DataLoader(
259
+ train_dataset, batch_size=args.batch_size, shuffle=False,
260
+ num_workers=args.workers, pin_memory=True, sampler=train_sampler)
261
+
262
+ val_dataset = SegmentationDataset(args.seg_data, args.data, partition=VAL_PARTITION, train_classes=args.num_classes,
263
+ num_samples=1)
264
+
265
+ val_loader = torch.utils.data.DataLoader(
266
+ val_dataset, batch_size=5, shuffle=False,
267
+ num_workers=args.workers, pin_memory=True)
268
+
269
+ if args.evaluate:
270
+ validate(val_loader, model, criterion, 0, args)
271
+ return
272
+
273
+ for epoch in range(args.start_epoch, args.epochs):
274
+ if args.distributed:
275
+ train_sampler.set_epoch(epoch)
276
+ adjust_learning_rate(optimizer, epoch, args)
277
+
278
+ log_dir = os.path.join(args.experiment_folder, 'logs')
279
+ logger = SummaryWriter(log_dir=log_dir)
280
+ args.logger = logger
281
+
282
+ # train for one epoch
283
+ train(train_loader, model, criterion, optimizer, epoch, args)
284
+
285
+ # evaluate on validation set
286
+ loss1 = validate(val_loader, model, criterion, epoch, args)
287
+
288
+ # remember best acc@1 and save checkpoint
289
+ is_best = loss1 < best_loss
290
+ best_loss = min(loss1, best_loss)
291
+
292
+ if not args.multiprocessing_distributed or (args.multiprocessing_distributed
293
+ and args.rank % ngpus_per_node == 0):
294
+ save_checkpoint({
295
+ 'epoch': epoch + 1,
296
+ 'state_dict': model.state_dict(),
297
+ 'best_loss': best_loss,
298
+ 'optimizer' : optimizer.state_dict(),
299
+ }, is_best, folder=args.experiment_folder)
300
+
301
+ def train(train_loader, model, criterion, optimizer, epoch, args):
302
+ mse_criterion = torch.nn.MSELoss(reduction='mean')
303
+
304
+ losses = AverageMeter('Loss', ':.4e')
305
+ top1 = AverageMeter('Acc@1', ':6.2f')
306
+ top5 = AverageMeter('Acc@5', ':6.2f')
307
+ orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
308
+ orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
309
+ progress = ProgressMeter(
310
+ len(train_loader),
311
+ [losses, top1, top5, orig_top1, orig_top5],
312
+ prefix="Epoch: [{}]".format(epoch))
313
+
314
+ orig_model = vit(pretrained=True).cuda()
315
+ orig_model.eval()
316
+
317
+ # switch to train mode
318
+ model.train()
319
+
320
+ for i, (seg_map, image_ten, class_name) in enumerate(train_loader):
321
+ if torch.cuda.is_available():
322
+ image_ten = image_ten.cuda(args.gpu, non_blocking=True)
323
+ seg_map = seg_map.cuda(args.gpu, non_blocking=True)
324
+ class_name = class_name.cuda(args.gpu, non_blocking=True)
325
+
326
+
327
+ image_ten.requires_grad = True
328
+ output = model(image_ten)
329
+
330
+ # segmentation loss
331
+ batch_size = image_ten.shape[0]
332
+ index = class_name
333
+ if index == None:
334
+ index = np.argmax(output.cpu().data.numpy(), axis=-1)
335
+ index = torch.tensor(index)
336
+
337
+ one_hot = np.zeros((batch_size, output.shape[-1]), dtype=np.float32)
338
+ one_hot[torch.arange(batch_size), index.data.cpu().numpy()] = 1
339
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
340
+ one_hot = torch.sum(one_hot.to(image_ten.device) * output)
341
+ model.zero_grad()
342
+
343
+ relevance = torch.autograd.grad(one_hot, image_ten, retain_graph=True)[0]
344
+
345
+ reverse_seg_map = seg_map.clone()
346
+ reverse_seg_map[reverse_seg_map == 1] = -1
347
+ reverse_seg_map[reverse_seg_map == 0] = 1
348
+ reverse_seg_map[reverse_seg_map == -1] = 0
349
+ grad_loss = mse_criterion(relevance * reverse_seg_map, torch.zeros_like(relevance))
350
+ segmentation_loss = grad_loss
351
+
352
+ # classification loss
353
+ with torch.no_grad():
354
+ output_orig = orig_model(image_ten)
355
+ if args.temperature != 1:
356
+ output = output / args.temperature
357
+ classification_loss = criterion(output, class_name.flatten())
358
+
359
+ loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss
360
+
361
+ # debugging output
362
+ if i % args.save_interval == 0:
363
+ orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
364
+ for j in range(image_ten.shape[0]):
365
+ image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
366
+ new_vis = get_image_with_relevance(image_ten[j]*relevance[j], torch.ones_like(image_ten[j]))
367
+ old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
368
+ gt = get_image_with_relevance(image_ten[j], seg_map[j])
369
+ h_img = cv2.hconcat([image, gt, old_vis, new_vis])
370
+ cv2.imwrite(f'{args.experiment_folder}/train_samples/res_{i}_{j}.jpg', h_img)
371
+
372
+ # measure accuracy and record loss
373
+ acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
374
+ losses.update(loss.item(), image_ten.size(0))
375
+ top1.update(acc1[0], image_ten.size(0))
376
+ top5.update(acc5[0], image_ten.size(0))
377
+
378
+ # metrics for original vit
379
+ acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
380
+ orig_top1.update(acc1_orig[0], image_ten.size(0))
381
+ orig_top5.update(acc5_orig[0], image_ten.size(0))
382
+
383
+ # compute gradient and do SGD step
384
+ optimizer.zero_grad()
385
+ loss.backward()
386
+ optimizer.step()
387
+
388
+ if i % args.print_freq == 0:
389
+ progress.display(i)
390
+ args.logger.add_scalar('{}/{}'.format('train', 'segmentation_loss'), segmentation_loss,
391
+ epoch*len(train_loader)+i)
392
+ args.logger.add_scalar('{}/{}'.format('train', 'classification_loss'), classification_loss,
393
+ epoch * len(train_loader) + i)
394
+ args.logger.add_scalar('{}/{}'.format('train', 'orig_top1'), acc1_orig,
395
+ epoch * len(train_loader) + i)
396
+ args.logger.add_scalar('{}/{}'.format('train', 'top1'), acc1,
397
+ epoch * len(train_loader) + i)
398
+ args.logger.add_scalar('{}/{}'.format('train', 'orig_top5'), acc5_orig,
399
+ epoch * len(train_loader) + i)
400
+ args.logger.add_scalar('{}/{}'.format('train', 'top5'), acc5,
401
+ epoch * len(train_loader) + i)
402
+ args.logger.add_scalar('{}/{}'.format('train', 'tot_loss'), loss,
403
+ epoch * len(train_loader) + i)
404
+
405
+
406
+ def validate(val_loader, model, criterion, epoch, args):
407
+ mse_criterion = torch.nn.MSELoss(reduction='mean')
408
+
409
+ losses = AverageMeter('Loss', ':.4e')
410
+ top1 = AverageMeter('Acc@1', ':6.2f')
411
+ top5 = AverageMeter('Acc@5', ':6.2f')
412
+ orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
413
+ orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
414
+ progress = ProgressMeter(
415
+ len(val_loader),
416
+ [losses, top1, top5, orig_top1, orig_top5],
417
+ prefix="Epoch: [{}]".format(val_loader))
418
+
419
+ # switch to evaluate mode
420
+ model.eval()
421
+
422
+ orig_model = vit(pretrained=True).cuda()
423
+ orig_model.eval()
424
+
425
+ with torch.no_grad():
426
+ for i, (seg_map, image_ten, class_name) in enumerate(val_loader):
427
+ if args.gpu is not None:
428
+ image_ten = image_ten.cuda(args.gpu, non_blocking=True)
429
+ if torch.cuda.is_available():
430
+ seg_map = seg_map.cuda(args.gpu, non_blocking=True)
431
+ class_name = class_name.cuda(args.gpu, non_blocking=True)
432
+
433
+ with torch.enable_grad():
434
+ image_ten.requires_grad = True
435
+ output = model(image_ten)
436
+
437
+ # segmentation loss
438
+ batch_size = image_ten.shape[0]
439
+ index = class_name
440
+ if index == None:
441
+ index = np.argmax(output.cpu().data.numpy(), axis=-1)
442
+ index = torch.tensor(index)
443
+
444
+ one_hot = np.zeros((batch_size, output.shape[-1]), dtype=np.float32)
445
+ one_hot[torch.arange(batch_size), index.data.cpu().numpy()] = 1
446
+ one_hot = torch.from_numpy(one_hot).requires_grad_(True)
447
+ one_hot = torch.sum(one_hot.to(image_ten.device) * output)
448
+ model.zero_grad()
449
+ relevance = torch.autograd.grad(one_hot, image_ten)[0]
450
+
451
+ reverse_seg_map = seg_map.clone()
452
+ reverse_seg_map[reverse_seg_map == 1] = -1
453
+ reverse_seg_map[reverse_seg_map == 0] = 1
454
+ reverse_seg_map[reverse_seg_map == -1] = 0
455
+ grad_loss = mse_criterion(relevance * reverse_seg_map, torch.zeros_like(relevance))
456
+ segmentation_loss = grad_loss
457
+
458
+ # classification loss
459
+ output = model(image_ten)
460
+ with torch.no_grad():
461
+ output_orig = orig_model(image_ten)
462
+ if args.temperature != 1:
463
+ output = output / args.temperature
464
+ classification_loss = criterion(output, class_name.flatten())
465
+
466
+ loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss
467
+
468
+ # save results
469
+ if i % args.save_interval == 0:
470
+ with torch.enable_grad():
471
+ orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
472
+ for j in range(image_ten.shape[0]):
473
+ image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
474
+ new_vis = get_image_with_relevance(image_ten[j]*relevance[j], torch.ones_like(image_ten[j]))
475
+ old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
476
+ gt = get_image_with_relevance(image_ten[j], seg_map[j])
477
+ h_img = cv2.hconcat([image, gt, old_vis, new_vis])
478
+ cv2.imwrite(f'{args.experiment_folder}/val_samples/res_{i}_{j}.jpg', h_img)
479
+
480
+ # measure accuracy and record loss
481
+ acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
482
+ losses.update(loss.item(), image_ten.size(0))
483
+ top1.update(acc1[0], image_ten.size(0))
484
+ top5.update(acc5[0], image_ten.size(0))
485
+
486
+ # metrics for original vit
487
+ acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
488
+ orig_top1.update(acc1_orig[0], image_ten.size(0))
489
+ orig_top5.update(acc5_orig[0], image_ten.size(0))
490
+
491
+ if i % args.print_freq == 0:
492
+ progress.display(i)
493
+ args.logger.add_scalar('{}/{}'.format('val', 'segmentation_loss'), segmentation_loss,
494
+ epoch * len(val_loader) + i)
495
+ args.logger.add_scalar('{}/{}'.format('val', 'classification_loss'), classification_loss,
496
+ epoch * len(val_loader) + i)
497
+ args.logger.add_scalar('{}/{}'.format('val', 'orig_top1'), acc1_orig,
498
+ epoch * len(val_loader) + i)
499
+ args.logger.add_scalar('{}/{}'.format('val', 'top1'), acc1,
500
+ epoch * len(val_loader) + i)
501
+ args.logger.add_scalar('{}/{}'.format('val', 'orig_top5'), acc5_orig,
502
+ epoch * len(val_loader) + i)
503
+ args.logger.add_scalar('{}/{}'.format('val', 'top5'), acc5,
504
+ epoch * len(val_loader) + i)
505
+ args.logger.add_scalar('{}/{}'.format('val', 'tot_loss'), loss,
506
+ epoch * len(val_loader) + i)
507
+
508
+ # TODO: this should also be done with the ProgressMeter
509
+ print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
510
+ .format(top1=top1, top5=top5))
511
+
512
+ return losses.avg
513
+
514
+
515
+ def save_checkpoint(state, is_best, folder, filename='checkpoint.pth.tar'):
516
+ torch.save(state, f'{folder}/{filename}')
517
+ if is_best:
518
+ shutil.copyfile(f'{folder}/{filename}', f'{folder}/model_best.pth.tar')
519
+
520
+
521
+ class AverageMeter(object):
522
+ """Computes and stores the average and current value"""
523
+ def __init__(self, name, fmt=':f'):
524
+ self.name = name
525
+ self.fmt = fmt
526
+ self.reset()
527
+
528
+ def reset(self):
529
+ self.val = 0
530
+ self.avg = 0
531
+ self.sum = 0
532
+ self.count = 0
533
+
534
+ def update(self, val, n=1):
535
+ self.val = val
536
+ self.sum += val * n
537
+ self.count += n
538
+ self.avg = self.sum / self.count
539
+
540
+ def __str__(self):
541
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
542
+ return fmtstr.format(**self.__dict__)
543
+
544
+
545
+ class ProgressMeter(object):
546
+ def __init__(self, num_batches, meters, prefix=""):
547
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
548
+ self.meters = meters
549
+ self.prefix = prefix
550
+
551
+ def display(self, batch):
552
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
553
+ entries += [str(meter) for meter in self.meters]
554
+ print('\t'.join(entries))
555
+
556
+ def _get_batch_fmtstr(self, num_batches):
557
+ num_digits = len(str(num_batches // 1))
558
+ fmt = '{:' + str(num_digits) + 'd}'
559
+ return '[' + fmt + '/' + fmt.format(num_batches) + ']'
560
+
561
+ def adjust_learning_rate(optimizer, epoch, args):
562
+ """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
563
+ lr = args.lr * (0.85 ** (epoch // 2))
564
+ for param_group in optimizer.param_groups:
565
+ param_group['lr'] = lr
566
+
567
+
568
+ def accuracy(output, target, topk=(1,)):
569
+ """Computes the accuracy over the k top predictions for the specified values of k"""
570
+ with torch.no_grad():
571
+ maxk = max(topk)
572
+ batch_size = target.size(0)
573
+
574
+ _, pred = output.topk(maxk, 1, True, True)
575
+ pred = pred.t()
576
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
577
+
578
+ res = []
579
+ for k in topk:
580
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
581
+ res.append(correct_k.mul_(100.0 / batch_size))
582
+ return res
583
+
584
+
585
+ if __name__ == '__main__':
586
+ main()
imagenet_finetune_rrr.py ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+ import shutil
5
+ import time
6
+ import warnings
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.parallel
11
+ import torch.backends.cudnn as cudnn
12
+ import torch.distributed as dist
13
+ import torch.optim
14
+ import torch.multiprocessing as mp
15
+ import torch.utils.data
16
+ import torch.utils.data.distributed
17
+ import torchvision.transforms as transforms
18
+ import torchvision.datasets as datasets
19
+ import torchvision.models as models
20
+ import torch.nn.functional as F
21
+ from segmentation_dataset import SegmentationDataset, VAL_PARTITION, TRAIN_PARTITION
22
+ import numpy as np
23
+
24
+ # Uncomment the expected model below
25
+
26
+ # ViT
27
+ from ViT.ViT import vit_base_patch16_224 as vit
28
+ # from ViT.ViT import vit_large_patch16_224 as vit
29
+
30
+ # ViT-AugReg
31
+ # from ViT.ViT_new import vit_small_patch16_224 as vit
32
+ # from ViT.ViT_new import vit_base_patch16_224 as vit
33
+ # from ViT.ViT_new import vit_large_patch16_224 as vit
34
+
35
+ # DeiT
36
+ # from ViT.ViT import deit_base_patch16_224 as vit
37
+ # from ViT.ViT import deit_small_patch16_224 as vit
38
+
39
+ from ViT.explainer import generate_relevance, get_image_with_relevance
40
+ import torchvision
41
+ import cv2
42
+ from torch.utils.tensorboard import SummaryWriter
43
+ import json
44
+
45
+ model_names = sorted(name for name in models.__dict__
46
+ if name.islower() and not name.startswith("__")
47
+ and callable(models.__dict__[name]))
48
+ model_names.append("vit")
49
+
50
+ parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
51
+ parser.add_argument('--data', metavar='DATA',
52
+ help='path to dataset')
53
+ parser.add_argument('--seg_data', metavar='SEG_DATA',
54
+ help='path to segmentation dataset')
55
+ parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
56
+ choices=model_names,
57
+ help='model architecture: ' +
58
+ ' | '.join(model_names) +
59
+ ' (default: resnet18)')
60
+ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
61
+ help='number of data loading workers (default: 4)')
62
+ parser.add_argument('--epochs', default=50, type=int, metavar='N',
63
+ help='number of total epochs to run')
64
+ parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
65
+ help='manual epoch number (useful on restarts)')
66
+ parser.add_argument('-b', '--batch-size', default=8, type=int,
67
+ metavar='N',
68
+ help='mini-batch size (default: 256), this is the total '
69
+ 'batch size of all GPUs on the current node when '
70
+ 'using Data Parallel or Distributed Data Parallel')
71
+ parser.add_argument('--lr', '--learning-rate', default=3e-6, type=float,
72
+ metavar='LR', help='initial learning rate', dest='lr')
73
+ parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
74
+ help='momentum')
75
+ parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
76
+ metavar='W', help='weight decay (default: 1e-4)',
77
+ dest='weight_decay')
78
+ parser.add_argument('-p', '--print-freq', default=10, type=int,
79
+ metavar='N', help='print frequency (default: 10)')
80
+ parser.add_argument('--resume', default='', type=str, metavar='PATH',
81
+ help='path to latest checkpoint (default: none)')
82
+ parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
83
+ help='evaluate model on validation set')
84
+ parser.add_argument('--pretrained', dest='pretrained', action='store_true',
85
+ help='use pre-trained model')
86
+ parser.add_argument('--world-size', default=-1, type=int,
87
+ help='number of nodes for distributed training')
88
+ parser.add_argument('--rank', default=-1, type=int,
89
+ help='node rank for distributed training')
90
+ parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
91
+ help='url used to set up distributed training')
92
+ parser.add_argument('--dist-backend', default='nccl', type=str,
93
+ help='distributed backend')
94
+ parser.add_argument('--seed', default=None, type=int,
95
+ help='seed for initializing training. ')
96
+ parser.add_argument('--gpu', default=None, type=int,
97
+ help='GPU id to use.')
98
+ parser.add_argument('--save_interval', default=20, type=int,
99
+ help='interval to save segmentation results.')
100
+ parser.add_argument('--num_samples', default=3, type=int,
101
+ help='number of samples per class for training')
102
+ parser.add_argument('--multiprocessing-distributed', action='store_true',
103
+ help='Use multi-processing distributed training to launch '
104
+ 'N processes per node, which has N GPUs. This is the '
105
+ 'fastest way to use PyTorch for either single node or '
106
+ 'multi node data parallel training')
107
+ parser.add_argument('--lambda_seg', default=0.8, type=float,
108
+ help='influence of segmentation loss.')
109
+ parser.add_argument('--lambda_acc', default=0.2, type=float,
110
+ help='influence of accuracy loss.')
111
+ parser.add_argument('--experiment_folder', default=None, type=str,
112
+ help='path to folder to use for experiment.')
113
+ parser.add_argument('--num_classes', default=500, type=int,
114
+ help='coefficient of loss for segmentation foreground.')
115
+ parser.add_argument('--temperature', default=1, type=float,
116
+ help='temperature for softmax (mostly for DeiT).')
117
+
118
+ best_loss = float('inf')
119
+
120
+ def main():
121
+ args = parser.parse_args()
122
+
123
+ if args.experiment_folder is None:
124
+ args.experiment_folder = f'experiment/' \
125
+ f'lr_{args.lr}_seg_{args.lambda_seg}_acc_{args.lambda_acc}'
126
+ if args.temperature != 1:
127
+ args.experiment_folder = args.experiment_folder + f'_tempera_{args.temperature}'
128
+ if args.batch_size != 8:
129
+ args.experiment_folder = args.experiment_folder + f'_bs_{args.batch_size}'
130
+ if args.num_classes != 500:
131
+ args.experiment_folder = args.experiment_folder + f'_num_classes_{args.num_classes}'
132
+ if args.num_samples != 3:
133
+ args.experiment_folder = args.experiment_folder + f'_num_samples_{args.num_samples}'
134
+ if args.epochs != 150:
135
+ args.experiment_folder = args.experiment_folder + f'_num_epochs_{args.epochs}'
136
+
137
+ if os.path.exists(args.experiment_folder):
138
+ raise Exception(f"Experiment path {args.experiment_folder} already exists!")
139
+ os.mkdir(args.experiment_folder)
140
+ os.mkdir(f'{args.experiment_folder}/train_samples')
141
+ os.mkdir(f'{args.experiment_folder}/val_samples')
142
+
143
+ with open(f'{args.experiment_folder}/commandline_args.txt', 'w') as f:
144
+ json.dump(args.__dict__, f, indent=2)
145
+
146
+ if args.seed is not None:
147
+ random.seed(args.seed)
148
+ torch.manual_seed(args.seed)
149
+ cudnn.deterministic = True
150
+ warnings.warn('You have chosen to seed training. '
151
+ 'This will turn on the CUDNN deterministic setting, '
152
+ 'which can slow down your training considerably! '
153
+ 'You may see unexpected behavior when restarting '
154
+ 'from checkpoints.')
155
+
156
+ if args.gpu is not None:
157
+ warnings.warn('You have chosen a specific GPU. This will completely '
158
+ 'disable data parallelism.')
159
+
160
+ if args.dist_url == "env://" and args.world_size == -1:
161
+ args.world_size = int(os.environ["WORLD_SIZE"])
162
+
163
+ args.distributed = args.world_size > 1 or args.multiprocessing_distributed
164
+
165
+ ngpus_per_node = torch.cuda.device_count()
166
+ if args.multiprocessing_distributed:
167
+ # Since we have ngpus_per_node processes per node, the total world_size
168
+ # needs to be adjusted accordingly
169
+ args.world_size = ngpus_per_node * args.world_size
170
+ # Use torch.multiprocessing.spawn to launch distributed processes: the
171
+ # main_worker process function
172
+ mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
173
+ else:
174
+ # Simply call main_worker function
175
+ main_worker(args.gpu, ngpus_per_node, args)
176
+
177
+
178
+ def main_worker(gpu, ngpus_per_node, args):
179
+ global best_loss
180
+ args.gpu = gpu
181
+
182
+ if args.gpu is not None:
183
+ print("Use GPU: {} for training".format(args.gpu))
184
+
185
+ if args.distributed:
186
+ if args.dist_url == "env://" and args.rank == -1:
187
+ args.rank = int(os.environ["RANK"])
188
+ if args.multiprocessing_distributed:
189
+ # For multiprocessing distributed training, rank needs to be the
190
+ # global rank among all the processes
191
+ args.rank = args.rank * ngpus_per_node + gpu
192
+ dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
193
+ world_size=args.world_size, rank=args.rank)
194
+ # create model
195
+ print("=> creating model")
196
+ model = vit(pretrained=True).cuda()
197
+ model.train()
198
+ print("done")
199
+
200
+ if not torch.cuda.is_available():
201
+ print('using CPU, this will be slow')
202
+ elif args.distributed:
203
+ # For multiprocessing distributed, DistributedDataParallel constructor
204
+ # should always set the single device scope, otherwise,
205
+ # DistributedDataParallel will use all available devices.
206
+ if args.gpu is not None:
207
+ torch.cuda.set_device(args.gpu)
208
+ model.cuda(args.gpu)
209
+ # When using a single GPU per process and per
210
+ # DistributedDataParallel, we need to divide the batch size
211
+ # ourselves based on the total number of GPUs we have
212
+ args.batch_size = int(args.batch_size / ngpus_per_node)
213
+ args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
214
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
215
+ else:
216
+ model.cuda()
217
+ # DistributedDataParallel will divide and allocate batch_size to all
218
+ # available GPUs if device_ids are not set
219
+ model = torch.nn.parallel.DistributedDataParallel(model)
220
+ elif args.gpu is not None:
221
+ torch.cuda.set_device(args.gpu)
222
+ model = model.cuda(args.gpu)
223
+ else:
224
+ # DataParallel will divide and allocate batch_size to all available GPUs
225
+ print("start")
226
+ model = torch.nn.DataParallel(model).cuda()
227
+
228
+ # define loss function (criterion) and optimizer
229
+ criterion = nn.CrossEntropyLoss().cuda(args.gpu)
230
+ optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=args.weight_decay)
231
+
232
+ # optionally resume from a checkpoint
233
+ if args.resume:
234
+ if os.path.isfile(args.resume):
235
+ print("=> loading checkpoint '{}'".format(args.resume))
236
+ if args.gpu is None:
237
+ checkpoint = torch.load(args.resume)
238
+ else:
239
+ # Map model to be loaded to specified single gpu.
240
+ loc = 'cuda:{}'.format(args.gpu)
241
+ checkpoint = torch.load(args.resume, map_location=loc)
242
+ args.start_epoch = checkpoint['epoch']
243
+ best_loss = checkpoint['best_loss']
244
+ if args.gpu is not None:
245
+ # best_loss may be from a checkpoint from a different GPU
246
+ best_loss = best_loss.to(args.gpu)
247
+ model.load_state_dict(checkpoint['state_dict'])
248
+ optimizer.load_state_dict(checkpoint['optimizer'])
249
+ print("=> loaded checkpoint '{}' (epoch {})"
250
+ .format(args.resume, checkpoint['epoch']))
251
+ else:
252
+ print("=> no checkpoint found at '{}'".format(args.resume))
253
+
254
+ cudnn.benchmark = True
255
+
256
+ train_dataset = SegmentationDataset(args.seg_data, args.data, partition=TRAIN_PARTITION, train_classes=args.num_classes,
257
+ num_samples=args.num_samples)
258
+
259
+ if args.distributed:
260
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
261
+ else:
262
+ train_sampler = None
263
+
264
+ train_loader = torch.utils.data.DataLoader(
265
+ train_dataset, batch_size=args.batch_size, shuffle=False,
266
+ num_workers=args.workers, pin_memory=True, sampler=train_sampler)
267
+
268
+ val_dataset = SegmentationDataset(args.seg_data, args.data, partition=VAL_PARTITION, train_classes=args.num_classes,
269
+ num_samples=1)
270
+
271
+ val_loader = torch.utils.data.DataLoader(
272
+ val_dataset, batch_size=5, shuffle=False,
273
+ num_workers=args.workers, pin_memory=True)
274
+
275
+ if args.evaluate:
276
+ validate(val_loader, model, criterion, 0, args)
277
+ return
278
+
279
+ for epoch in range(args.start_epoch, args.epochs):
280
+ if args.distributed:
281
+ train_sampler.set_epoch(epoch)
282
+ adjust_learning_rate(optimizer, epoch, args)
283
+
284
+ log_dir = os.path.join(args.experiment_folder, 'logs')
285
+ logger = SummaryWriter(log_dir=log_dir)
286
+ args.logger = logger
287
+
288
+ # train for one epoch
289
+ train(train_loader, model, criterion, optimizer, epoch, args)
290
+
291
+ # evaluate on validation set
292
+ loss1 = validate(val_loader, model, criterion, epoch, args)
293
+
294
+ # remember best acc@1 and save checkpoint
295
+ is_best = loss1 < best_loss
296
+ best_loss = min(loss1, best_loss)
297
+
298
+ if not args.multiprocessing_distributed or (args.multiprocessing_distributed
299
+ and args.rank % ngpus_per_node == 0):
300
+ save_checkpoint({
301
+ 'epoch': epoch + 1,
302
+ 'state_dict': model.state_dict(),
303
+ 'best_loss': best_loss,
304
+ 'optimizer' : optimizer.state_dict(),
305
+ }, is_best, folder=args.experiment_folder)
306
+
307
+ def train(train_loader, model, criterion, optimizer, epoch, args):
308
+ losses = AverageMeter('Loss', ':.4e')
309
+ top1 = AverageMeter('Acc@1', ':6.2f')
310
+ top5 = AverageMeter('Acc@5', ':6.2f')
311
+ orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
312
+ orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
313
+ progress = ProgressMeter(
314
+ len(train_loader),
315
+ [losses, top1, top5, orig_top1, orig_top5],
316
+ prefix="Epoch: [{}]".format(epoch))
317
+
318
+ orig_model = vit(pretrained=True).cuda()
319
+ orig_model.eval()
320
+
321
+ # switch to train mode
322
+ model.train()
323
+
324
+ for i, (seg_map, image_ten, class_name) in enumerate(train_loader):
325
+ if torch.cuda.is_available():
326
+ image_ten = image_ten.cuda(args.gpu, non_blocking=True)
327
+ seg_map = seg_map.cuda(args.gpu, non_blocking=True)
328
+ class_name = class_name.cuda(args.gpu, non_blocking=True)
329
+
330
+
331
+ image_ten.requires_grad = True
332
+ output = model(image_ten)
333
+
334
+ # segmentation loss
335
+ EPS = 10e-12
336
+ y_pred = torch.sum(torch.log(F.softmax(output, dim=1) + EPS))
337
+ relevance = torch.autograd.grad(y_pred, image_ten, retain_graph=True)[0]
338
+ reverse_seg_map = seg_map.clone()
339
+ reverse_seg_map[reverse_seg_map == 1] = -1
340
+ reverse_seg_map[reverse_seg_map == 0] = 1
341
+ reverse_seg_map[reverse_seg_map == -1] = 0
342
+ rrr_loss = (relevance * reverse_seg_map)**2
343
+ segmentation_loss = rrr_loss.sum()
344
+
345
+ # classification loss
346
+ with torch.no_grad():
347
+ output_orig = orig_model(image_ten)
348
+ if args.temperature != 1:
349
+ output = output / args.temperature
350
+ classification_loss = criterion(output, class_name.flatten())
351
+
352
+ loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss
353
+
354
+ # debugging output
355
+ if i % args.save_interval == 0:
356
+ orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
357
+ for j in range(image_ten.shape[0]):
358
+ image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
359
+ new_vis = get_image_with_relevance(image_ten[j]*relevance[j], torch.ones_like(image_ten[j]))
360
+ old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
361
+ gt = get_image_with_relevance(image_ten[j], seg_map[j])
362
+ h_img = cv2.hconcat([image, gt, old_vis, new_vis])
363
+ cv2.imwrite(f'{args.experiment_folder}/train_samples/res_{i}_{j}.jpg', h_img)
364
+
365
+ # measure accuracy and record loss
366
+ acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
367
+ losses.update(loss.item(), image_ten.size(0))
368
+ top1.update(acc1[0], image_ten.size(0))
369
+ top5.update(acc5[0], image_ten.size(0))
370
+
371
+ # metrics for original vit
372
+ acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
373
+ orig_top1.update(acc1_orig[0], image_ten.size(0))
374
+ orig_top5.update(acc5_orig[0], image_ten.size(0))
375
+
376
+ # compute gradient and do SGD step
377
+ optimizer.zero_grad()
378
+ loss.backward()
379
+ optimizer.step()
380
+
381
+ if i % args.print_freq == 0:
382
+ progress.display(i)
383
+ args.logger.add_scalar('{}/{}'.format('train', 'segmentation_loss'), segmentation_loss,
384
+ epoch*len(train_loader)+i)
385
+ args.logger.add_scalar('{}/{}'.format('train', 'classification_loss'), classification_loss,
386
+ epoch * len(train_loader) + i)
387
+ args.logger.add_scalar('{}/{}'.format('train', 'orig_top1'), acc1_orig,
388
+ epoch * len(train_loader) + i)
389
+ args.logger.add_scalar('{}/{}'.format('train', 'top1'), acc1,
390
+ epoch * len(train_loader) + i)
391
+ args.logger.add_scalar('{}/{}'.format('train', 'orig_top5'), acc5_orig,
392
+ epoch * len(train_loader) + i)
393
+ args.logger.add_scalar('{}/{}'.format('train', 'top5'), acc5,
394
+ epoch * len(train_loader) + i)
395
+ args.logger.add_scalar('{}/{}'.format('train', 'tot_loss'), loss,
396
+ epoch * len(train_loader) + i)
397
+
398
+
399
+ def validate(val_loader, model, criterion, epoch, args):
400
+ mse_criterion = torch.nn.MSELoss(reduction='mean')
401
+
402
+ losses = AverageMeter('Loss', ':.4e')
403
+ top1 = AverageMeter('Acc@1', ':6.2f')
404
+ top5 = AverageMeter('Acc@5', ':6.2f')
405
+ orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
406
+ orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
407
+ progress = ProgressMeter(
408
+ len(val_loader),
409
+ [losses, top1, top5, orig_top1, orig_top5],
410
+ prefix="Epoch: [{}]".format(val_loader))
411
+
412
+ # switch to evaluate mode
413
+ model.eval()
414
+
415
+ orig_model = vit(pretrained=True).cuda()
416
+ orig_model.eval()
417
+
418
+ with torch.no_grad():
419
+ for i, (seg_map, image_ten, class_name) in enumerate(val_loader):
420
+ if args.gpu is not None:
421
+ image_ten = image_ten.cuda(args.gpu, non_blocking=True)
422
+ if torch.cuda.is_available():
423
+ seg_map = seg_map.cuda(args.gpu, non_blocking=True)
424
+ class_name = class_name.cuda(args.gpu, non_blocking=True)
425
+
426
+ with torch.enable_grad():
427
+ image_ten.requires_grad = True
428
+ output = model(image_ten)
429
+
430
+ # segmentation loss
431
+ EPS = 10e-12
432
+ y_pred = torch.sum(torch.log(F.softmax(output, dim=1) + EPS))
433
+ relevance = torch.autograd.grad(y_pred, image_ten, retain_graph=True)[0]
434
+
435
+ reverse_seg_map = seg_map.clone()
436
+ reverse_seg_map[reverse_seg_map == 1] = -1
437
+ reverse_seg_map[reverse_seg_map == 0] = 1
438
+ reverse_seg_map[reverse_seg_map == -1] = 0
439
+ rrr_loss = (relevance * reverse_seg_map) ** 2
440
+ segmentation_loss = rrr_loss.sum()
441
+
442
+ # classification loss
443
+ output = model(image_ten)
444
+ with torch.no_grad():
445
+ output_orig = orig_model(image_ten)
446
+ if args.temperature != 1:
447
+ output = output / args.temperature
448
+ classification_loss = criterion(output, class_name.flatten())
449
+
450
+ loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss
451
+
452
+ # save results
453
+ if i % args.save_interval == 0:
454
+ with torch.enable_grad():
455
+ orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
456
+ for j in range(image_ten.shape[0]):
457
+ image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
458
+ new_vis = get_image_with_relevance(image_ten[j]*relevance[j], torch.ones_like(image_ten[j]))
459
+ old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
460
+ gt = get_image_with_relevance(image_ten[j], seg_map[j])
461
+ h_img = cv2.hconcat([image, gt, old_vis, new_vis])
462
+ cv2.imwrite(f'{args.experiment_folder}/val_samples/res_{i}_{j}.jpg', h_img)
463
+
464
+ # measure accuracy and record loss
465
+ acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
466
+ losses.update(loss.item(), image_ten.size(0))
467
+ top1.update(acc1[0], image_ten.size(0))
468
+ top5.update(acc5[0], image_ten.size(0))
469
+
470
+ # metrics for original vit
471
+ acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
472
+ orig_top1.update(acc1_orig[0], image_ten.size(0))
473
+ orig_top5.update(acc5_orig[0], image_ten.size(0))
474
+
475
+ if i % args.print_freq == 0:
476
+ progress.display(i)
477
+ args.logger.add_scalar('{}/{}'.format('val', 'segmentation_loss'), segmentation_loss,
478
+ epoch * len(val_loader) + i)
479
+ args.logger.add_scalar('{}/{}'.format('val', 'classification_loss'), classification_loss,
480
+ epoch * len(val_loader) + i)
481
+ args.logger.add_scalar('{}/{}'.format('val', 'orig_top1'), acc1_orig,
482
+ epoch * len(val_loader) + i)
483
+ args.logger.add_scalar('{}/{}'.format('val', 'top1'), acc1,
484
+ epoch * len(val_loader) + i)
485
+ args.logger.add_scalar('{}/{}'.format('val', 'orig_top5'), acc5_orig,
486
+ epoch * len(val_loader) + i)
487
+ args.logger.add_scalar('{}/{}'.format('val', 'top5'), acc5,
488
+ epoch * len(val_loader) + i)
489
+ args.logger.add_scalar('{}/{}'.format('val', 'tot_loss'), loss,
490
+ epoch * len(val_loader) + i)
491
+
492
+ # TODO: this should also be done with the ProgressMeter
493
+ print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
494
+ .format(top1=top1, top5=top5))
495
+
496
+ return losses.avg
497
+
498
+
499
+ def save_checkpoint(state, is_best, folder, filename='checkpoint.pth.tar'):
500
+ torch.save(state, f'{folder}/{filename}')
501
+ if is_best:
502
+ shutil.copyfile(f'{folder}/{filename}', f'{folder}/model_best.pth.tar')
503
+
504
+
505
+ class AverageMeter(object):
506
+ """Computes and stores the average and current value"""
507
+ def __init__(self, name, fmt=':f'):
508
+ self.name = name
509
+ self.fmt = fmt
510
+ self.reset()
511
+
512
+ def reset(self):
513
+ self.val = 0
514
+ self.avg = 0
515
+ self.sum = 0
516
+ self.count = 0
517
+
518
+ def update(self, val, n=1):
519
+ self.val = val
520
+ self.sum += val * n
521
+ self.count += n
522
+ self.avg = self.sum / self.count
523
+
524
+ def __str__(self):
525
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
526
+ return fmtstr.format(**self.__dict__)
527
+
528
+
529
+ class ProgressMeter(object):
530
+ def __init__(self, num_batches, meters, prefix=""):
531
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
532
+ self.meters = meters
533
+ self.prefix = prefix
534
+
535
+ def display(self, batch):
536
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
537
+ entries += [str(meter) for meter in self.meters]
538
+ print('\t'.join(entries))
539
+
540
+ def _get_batch_fmtstr(self, num_batches):
541
+ num_digits = len(str(num_batches // 1))
542
+ fmt = '{:' + str(num_digits) + 'd}'
543
+ return '[' + fmt + '/' + fmt.format(num_batches) + ']'
544
+
545
+ def adjust_learning_rate(optimizer, epoch, args):
546
+ """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
547
+ lr = args.lr * (0.85 ** (epoch // 2))
548
+ for param_group in optimizer.param_groups:
549
+ param_group['lr'] = lr
550
+
551
+
552
+ def accuracy(output, target, topk=(1,)):
553
+ """Computes the accuracy over the k top predictions for the specified values of k"""
554
+ with torch.no_grad():
555
+ maxk = max(topk)
556
+ batch_size = target.size(0)
557
+
558
+ _, pred = output.topk(maxk, 1, True, True)
559
+ pred = pred.t()
560
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
561
+
562
+ res = []
563
+ for k in topk:
564
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
565
+ res.append(correct_k.mul_(100.0 / batch_size))
566
+ return res
567
+
568
+
569
+ if __name__ == '__main__':
570
+ main()
imagenet_finetune_tokencut.py ADDED
@@ -0,0 +1,577 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+ import shutil
5
+ import time
6
+ import warnings
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.parallel
11
+ import torch.backends.cudnn as cudnn
12
+ import torch.distributed as dist
13
+ import torch.optim
14
+ import torch.multiprocessing as mp
15
+ import torch.utils.data
16
+ import torch.utils.data.distributed
17
+ import torchvision.transforms as transforms
18
+ import torchvision.datasets as datasets
19
+ import torchvision.models as models
20
+ from tokencut_dataset import SegmentationDataset, VAL_PARTITION, TRAIN_PARTITION
21
+
22
+ # Uncomment the expected model below
23
+
24
+ # ViT
25
+ from ViT.ViT import vit_base_patch16_224 as vit
26
+ # from ViT.ViT import vit_large_patch16_224 as vit
27
+
28
+ # ViT-AugReg
29
+ # from ViT.ViT_new import vit_small_patch16_224 as vit
30
+ # from ViT.ViT_new import vit_base_patch16_224 as vit
31
+ # from ViT.ViT_new import vit_large_patch16_224 as vit
32
+
33
+ # DeiT
34
+ # from ViT.ViT import deit_base_patch16_224 as vit
35
+ # from ViT.ViT import deit_small_patch16_224 as vit
36
+
37
+ from ViT.explainer import generate_relevance, get_image_with_relevance
38
+ import torchvision
39
+ import cv2
40
+ from torch.utils.tensorboard import SummaryWriter
41
+ import json
42
+
43
+ model_names = sorted(name for name in models.__dict__
44
+ if name.islower() and not name.startswith("__")
45
+ and callable(models.__dict__[name]))
46
+ model_names.append("vit")
47
+
48
+ parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
49
+ parser.add_argument('--data', metavar='DATA',
50
+ help='path to dataset')
51
+ parser.add_argument('--seg_data', metavar='SEG_DATA',
52
+ help='path to segmentation dataset')
53
+ parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
54
+ help='number of data loading workers (default: 4)')
55
+ parser.add_argument('--epochs', default=150, type=int, metavar='N',
56
+ help='number of total epochs to run')
57
+ parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
58
+ help='manual epoch number (useful on restarts)')
59
+ parser.add_argument('-b', '--batch-size', default=10, type=int,
60
+ metavar='N',
61
+ help='mini-batch size (default: 256), this is the total '
62
+ 'batch size of all GPUs on the current node when '
63
+ 'using Data Parallel or Distributed Data Parallel')
64
+ parser.add_argument('--lr', '--learning-rate', default=3e-6, type=float,
65
+ metavar='LR', help='initial learning rate', dest='lr')
66
+ parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
67
+ help='momentum')
68
+ parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
69
+ metavar='W', help='weight decay (default: 1e-4)',
70
+ dest='weight_decay')
71
+ parser.add_argument('-p', '--print-freq', default=10, type=int,
72
+ metavar='N', help='print frequency (default: 10)')
73
+ parser.add_argument('--resume', default='', type=str, metavar='PATH',
74
+ help='path to latest checkpoint (default: none)')
75
+ parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
76
+ help='evaluate model on validation set')
77
+ parser.add_argument('--pretrained', dest='pretrained', action='store_true',
78
+ help='use pre-trained model')
79
+ parser.add_argument('--world-size', default=-1, type=int,
80
+ help='number of nodes for distributed training')
81
+ parser.add_argument('--rank', default=-1, type=int,
82
+ help='node rank for distributed training')
83
+ parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
84
+ help='url used to set up distributed training')
85
+ parser.add_argument('--dist-backend', default='nccl', type=str,
86
+ help='distributed backend')
87
+ parser.add_argument('--seed', default=None, type=int,
88
+ help='seed for initializing training. ')
89
+ parser.add_argument('--gpu', default=None, type=int,
90
+ help='GPU id to use.')
91
+ parser.add_argument('--save_interval', default=20, type=int,
92
+ help='interval to save segmentation results.')
93
+ parser.add_argument('--num_samples', default=3, type=int,
94
+ help='number of samples per class for training')
95
+ parser.add_argument('--multiprocessing-distributed', action='store_true',
96
+ help='Use multi-processing distributed training to launch '
97
+ 'N processes per node, which has N GPUs. This is the '
98
+ 'fastest way to use PyTorch for either single node or '
99
+ 'multi node data parallel training')
100
+ parser.add_argument('--lambda_seg', default=0.1, type=float,
101
+ help='influence of segmentation loss.')
102
+ parser.add_argument('--lambda_acc', default=1, type=float,
103
+ help='influence of accuracy loss.')
104
+ parser.add_argument('--experiment_folder', default=None, type=str,
105
+ help='path to folder to use for experiment.')
106
+ parser.add_argument('--dilation', default=0, type=float,
107
+ help='Use dilation on the segmentation maps.')
108
+ parser.add_argument('--lambda_background', default=1, type=float,
109
+ help='coefficient of loss for segmentation background.')
110
+ parser.add_argument('--lambda_foreground', default=0.3, type=float,
111
+ help='coefficient of loss for segmentation foreground.')
112
+ parser.add_argument('--num_classes', default=500, type=int,
113
+ help='coefficient of loss for segmentation foreground.')
114
+ parser.add_argument('--temperature', default=1, type=float,
115
+ help='temperature for softmax (mostly for DeiT).')
116
+
117
+ best_loss = float('inf')
118
+
119
+ def main():
120
+ args = parser.parse_args()
121
+
122
+ if args.experiment_folder is None:
123
+ args.experiment_folder = f'experiment/' \
124
+ f'lr_{args.lr}_seg_{args.lambda_seg}_acc_{args.lambda_acc}' \
125
+ f'_bckg_{args.lambda_background}_fgd_{args.lambda_foreground}'
126
+ if args.temperature != 1:
127
+ args.experiment_folder = args.experiment_folder + f'_tempera_{args.temperature}'
128
+ if args.batch_size != 8:
129
+ args.experiment_folder = args.experiment_folder + f'_bs_{args.batch_size}'
130
+ if args.num_classes != 500:
131
+ args.experiment_folder = args.experiment_folder + f'_num_classes_{args.num_classes}'
132
+ if args.num_samples != 3:
133
+ args.experiment_folder = args.experiment_folder + f'_num_samples_{args.num_samples}'
134
+ if args.epochs != 150:
135
+ args.experiment_folder = args.experiment_folder + f'_num_epochs_{args.epochs}'
136
+
137
+ if os.path.exists(args.experiment_folder):
138
+ raise Exception(f"Experiment path {args.experiment_folder} already exists!")
139
+ os.mkdir(args.experiment_folder)
140
+ os.mkdir(f'{args.experiment_folder}/train_samples')
141
+ os.mkdir(f'{args.experiment_folder}/val_samples')
142
+
143
+ with open(f'{args.experiment_folder}/commandline_args.txt', 'w') as f:
144
+ json.dump(args.__dict__, f, indent=2)
145
+
146
+ if args.seed is not None:
147
+ random.seed(args.seed)
148
+ torch.manual_seed(args.seed)
149
+ cudnn.deterministic = True
150
+ warnings.warn('You have chosen to seed training. '
151
+ 'This will turn on the CUDNN deterministic setting, '
152
+ 'which can slow down your training considerably! '
153
+ 'You may see unexpected behavior when restarting '
154
+ 'from checkpoints.')
155
+
156
+ if args.gpu is not None:
157
+ warnings.warn('You have chosen a specific GPU. This will completely '
158
+ 'disable data parallelism.')
159
+
160
+ if args.dist_url == "env://" and args.world_size == -1:
161
+ args.world_size = int(os.environ["WORLD_SIZE"])
162
+
163
+ args.distributed = args.world_size > 1 or args.multiprocessing_distributed
164
+
165
+ ngpus_per_node = torch.cuda.device_count()
166
+ if args.multiprocessing_distributed:
167
+ # Since we have ngpus_per_node processes per node, the total world_size
168
+ # needs to be adjusted accordingly
169
+ args.world_size = ngpus_per_node * args.world_size
170
+ # Use torch.multiprocessing.spawn to launch distributed processes: the
171
+ # main_worker process function
172
+ mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
173
+ else:
174
+ # Simply call main_worker function
175
+ main_worker(args.gpu, ngpus_per_node, args)
176
+
177
+
178
+ def main_worker(gpu, ngpus_per_node, args):
179
+ global best_loss
180
+ args.gpu = gpu
181
+
182
+ if args.gpu is not None:
183
+ print("Use GPU: {} for training".format(args.gpu))
184
+
185
+ if args.distributed:
186
+ if args.dist_url == "env://" and args.rank == -1:
187
+ args.rank = int(os.environ["RANK"])
188
+ if args.multiprocessing_distributed:
189
+ # For multiprocessing distributed training, rank needs to be the
190
+ # global rank among all the processes
191
+ args.rank = args.rank * ngpus_per_node + gpu
192
+ dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
193
+ world_size=args.world_size, rank=args.rank)
194
+ # create model
195
+ print("=> creating model")
196
+ model = vit(pretrained=True).cuda()
197
+ model.train()
198
+ print("done")
199
+
200
+ if not torch.cuda.is_available():
201
+ print('using CPU, this will be slow')
202
+ elif args.distributed:
203
+ # For multiprocessing distributed, DistributedDataParallel constructor
204
+ # should always set the single device scope, otherwise,
205
+ # DistributedDataParallel will use all available devices.
206
+ if args.gpu is not None:
207
+ torch.cuda.set_device(args.gpu)
208
+ model.cuda(args.gpu)
209
+ # When using a single GPU per process and per
210
+ # DistributedDataParallel, we need to divide the batch size
211
+ # ourselves based on the total number of GPUs we have
212
+ args.batch_size = int(args.batch_size / ngpus_per_node)
213
+ args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
214
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
215
+ else:
216
+ model.cuda()
217
+ # DistributedDataParallel will divide and allocate batch_size to all
218
+ # available GPUs if device_ids are not set
219
+ model = torch.nn.parallel.DistributedDataParallel(model)
220
+ elif args.gpu is not None:
221
+ torch.cuda.set_device(args.gpu)
222
+ model = model.cuda(args.gpu)
223
+ else:
224
+ # DataParallel will divide and allocate batch_size to all available GPUs
225
+ print("start")
226
+ model = torch.nn.DataParallel(model).cuda()
227
+
228
+ # define loss function (criterion) and optimizer
229
+ criterion = nn.CrossEntropyLoss().cuda(args.gpu)
230
+ optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=args.weight_decay)
231
+
232
+ # optionally resume from a checkpoint
233
+ if args.resume:
234
+ if os.path.isfile(args.resume):
235
+ print("=> loading checkpoint '{}'".format(args.resume))
236
+ if args.gpu is None:
237
+ checkpoint = torch.load(args.resume)
238
+ else:
239
+ # Map model to be loaded to specified single gpu.
240
+ loc = 'cuda:{}'.format(args.gpu)
241
+ checkpoint = torch.load(args.resume, map_location=loc)
242
+ args.start_epoch = checkpoint['epoch']
243
+ best_loss = checkpoint['best_loss']
244
+ if args.gpu is not None:
245
+ # best_loss may be from a checkpoint from a different GPU
246
+ best_loss = best_loss.to(args.gpu)
247
+ model.load_state_dict(checkpoint['state_dict'])
248
+ optimizer.load_state_dict(checkpoint['optimizer'])
249
+ print("=> loaded checkpoint '{}' (epoch {})"
250
+ .format(args.resume, checkpoint['epoch']))
251
+ else:
252
+ print("=> no checkpoint found at '{}'".format(args.resume))
253
+
254
+ cudnn.benchmark = True
255
+
256
+ train_dataset = SegmentationDataset(args.seg_data, args.data, partition=TRAIN_PARTITION, train_classes=args.num_classes,
257
+ num_samples=args.num_samples)
258
+
259
+ if args.distributed:
260
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
261
+ else:
262
+ train_sampler = None
263
+
264
+ train_loader = torch.utils.data.DataLoader(
265
+ train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
266
+ num_workers=args.workers, pin_memory=True, sampler=train_sampler)
267
+
268
+ val_dataset = SegmentationDataset(args.seg_data, args.data, partition=VAL_PARTITION, train_classes=args.num_classes,
269
+ num_samples=1)
270
+
271
+ val_loader = torch.utils.data.DataLoader(
272
+ val_dataset, batch_size=10, shuffle=False,
273
+ num_workers=args.workers, pin_memory=True)
274
+
275
+ if args.evaluate:
276
+ validate(val_loader, model, criterion, 0, args)
277
+ return
278
+
279
+ for epoch in range(args.start_epoch, args.epochs):
280
+ if args.distributed:
281
+ train_sampler.set_epoch(epoch)
282
+ adjust_learning_rate(optimizer, epoch, args)
283
+
284
+ log_dir = os.path.join(args.experiment_folder, 'logs')
285
+ logger = SummaryWriter(log_dir=log_dir)
286
+ args.logger = logger
287
+
288
+ # train for one epoch
289
+ train(train_loader, model, criterion, optimizer, epoch, args)
290
+
291
+ # evaluate on validation set
292
+ loss1 = validate(val_loader, model, criterion, epoch, args)
293
+
294
+ # remember best acc@1 and save checkpoint
295
+ is_best = loss1 <= best_loss
296
+ best_loss = min(loss1, best_loss)
297
+
298
+ if not args.multiprocessing_distributed or (args.multiprocessing_distributed
299
+ and args.rank % ngpus_per_node == 0):
300
+ save_checkpoint({
301
+ 'epoch': epoch + 1,
302
+ 'state_dict': model.state_dict(),
303
+ 'best_loss': best_loss,
304
+ 'optimizer' : optimizer.state_dict(),
305
+ }, is_best, folder=args.experiment_folder)
306
+
307
+
308
+ def train(train_loader, model, criterion, optimizer, epoch, args):
309
+ mse_criterion = torch.nn.MSELoss(reduction='mean')
310
+ losses = AverageMeter('Loss', ':.4e')
311
+ top1 = AverageMeter('Acc@1', ':6.2f')
312
+ top5 = AverageMeter('Acc@5', ':6.2f')
313
+ orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
314
+ orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
315
+ progress = ProgressMeter(
316
+ len(train_loader),
317
+ # [batch_time, data_time, losses, top1, top5, orig_top1, orig_top5],
318
+ [losses, top1, top5, orig_top1, orig_top5],
319
+ prefix="Epoch: [{}]".format(epoch))
320
+
321
+ orig_model = vit(pretrained=True).cuda()
322
+ orig_model.eval()
323
+
324
+ # switch to train mode
325
+ model.train()
326
+
327
+ end = time.time()
328
+ for i, (seg_map, image_ten, class_name) in enumerate(train_loader):
329
+
330
+ if torch.cuda.is_available():
331
+ image_ten = image_ten.cuda(args.gpu, non_blocking=True)
332
+ seg_map = seg_map.cuda(args.gpu, non_blocking=True)
333
+ class_name = class_name.cuda(args.gpu, non_blocking=True)
334
+
335
+ # compute output
336
+
337
+ # segmentation loss
338
+ relevance = generate_relevance(model, image_ten, index=class_name)
339
+
340
+ reverse_seg_map = seg_map.clone()
341
+ reverse_seg_map[reverse_seg_map == 1] = -1
342
+ reverse_seg_map[reverse_seg_map == 0] = 1
343
+ reverse_seg_map[reverse_seg_map == -1] = 0
344
+ background_loss = mse_criterion(relevance * reverse_seg_map, torch.zeros_like(relevance))
345
+ foreground_loss = mse_criterion(relevance * seg_map, seg_map)
346
+ segmentation_loss = args.lambda_background * background_loss
347
+ segmentation_loss += args.lambda_foreground * foreground_loss
348
+
349
+ # classification loss
350
+ output = model(image_ten)
351
+ with torch.no_grad():
352
+ output_orig = orig_model(image_ten)
353
+
354
+ _, pred = output.topk(1, 1, True, True)
355
+ pred = pred.flatten()
356
+ if args.temperature != 1:
357
+ output = output / args.temperature
358
+ classification_loss = criterion(output, pred)
359
+
360
+ loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss
361
+
362
+ # debugging output
363
+ if i % args.save_interval == 0:
364
+ orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
365
+ for j in range(image_ten.shape[0]):
366
+ image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
367
+ new_vis = get_image_with_relevance(image_ten[j], relevance[j])
368
+ old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
369
+ gt = get_image_with_relevance(image_ten[j], seg_map[j])
370
+ h_img = cv2.hconcat([image, gt, old_vis, new_vis])
371
+ cv2.imwrite(f'{args.experiment_folder}/train_samples/res_{i}_{j}.jpg', h_img)
372
+
373
+ # measure accuracy and record loss
374
+ acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
375
+ losses.update(loss.item(), image_ten.size(0))
376
+ top1.update(acc1[0], image_ten.size(0))
377
+ top5.update(acc5[0], image_ten.size(0))
378
+
379
+ # metrics for original vit
380
+ acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
381
+ orig_top1.update(acc1_orig[0], image_ten.size(0))
382
+ orig_top5.update(acc5_orig[0], image_ten.size(0))
383
+
384
+ # compute gradient and do SGD step
385
+ optimizer.zero_grad()
386
+ loss.backward()
387
+ optimizer.step()
388
+
389
+ if i % args.print_freq == 0:
390
+ progress.display(i)
391
+ args.logger.add_scalar('{}/{}'.format('train', 'segmentation_loss'), segmentation_loss,
392
+ epoch*len(train_loader)+i)
393
+ args.logger.add_scalar('{}/{}'.format('train', 'classification_loss'), classification_loss,
394
+ epoch * len(train_loader) + i)
395
+ args.logger.add_scalar('{}/{}'.format('train', 'orig_top1'), acc1_orig,
396
+ epoch * len(train_loader) + i)
397
+ args.logger.add_scalar('{}/{}'.format('train', 'top1'), acc1,
398
+ epoch * len(train_loader) + i)
399
+ args.logger.add_scalar('{}/{}'.format('train', 'orig_top5'), acc5_orig,
400
+ epoch * len(train_loader) + i)
401
+ args.logger.add_scalar('{}/{}'.format('train', 'top5'), acc5,
402
+ epoch * len(train_loader) + i)
403
+ args.logger.add_scalar('{}/{}'.format('train', 'tot_loss'), loss,
404
+ epoch * len(train_loader) + i)
405
+
406
+
407
+ def validate(val_loader, model, criterion, epoch, args):
408
+ mse_criterion = torch.nn.MSELoss(reduction='mean')
409
+ losses = AverageMeter('Loss', ':.4e')
410
+ top1 = AverageMeter('Acc@1', ':6.2f')
411
+ top5 = AverageMeter('Acc@5', ':6.2f')
412
+ orig_top1 = AverageMeter('Acc@1_orig', ':6.2f')
413
+ orig_top5 = AverageMeter('Acc@5_orig', ':6.2f')
414
+ progress = ProgressMeter(
415
+ len(val_loader),
416
+ [losses, top1, top5, orig_top1, orig_top5],
417
+ prefix="Epoch: [{}]".format(val_loader))
418
+
419
+ # switch to evaluate mode
420
+ model.eval()
421
+
422
+ orig_model = vit(pretrained=True).cuda()
423
+ orig_model.eval()
424
+
425
+ with torch.no_grad():
426
+ end = time.time()
427
+ for i, (seg_map, image_ten, class_name) in enumerate(val_loader):
428
+ if args.gpu is not None:
429
+ image_ten = image_ten.cuda(args.gpu, non_blocking=True)
430
+ if torch.cuda.is_available():
431
+ seg_map = seg_map.cuda(args.gpu, non_blocking=True)
432
+ class_name = class_name.cuda(args.gpu, non_blocking=True)
433
+
434
+ # segmentation loss
435
+ with torch.enable_grad():
436
+ relevance = generate_relevance(model, image_ten, index=class_name)
437
+
438
+ reverse_seg_map = seg_map.clone()
439
+ reverse_seg_map[reverse_seg_map == 1] = -1
440
+ reverse_seg_map[reverse_seg_map == 0] = 1
441
+ reverse_seg_map[reverse_seg_map == -1] = 0
442
+ background_loss = mse_criterion(relevance * reverse_seg_map, torch.zeros_like(relevance))
443
+ foreground_loss = mse_criterion(relevance * seg_map, seg_map)
444
+ segmentation_loss = args.lambda_background * background_loss
445
+ segmentation_loss += args.lambda_foreground * foreground_loss
446
+
447
+ # classification loss
448
+ with torch.no_grad():
449
+ output = model(image_ten)
450
+ output_orig = orig_model(image_ten)
451
+
452
+ _, pred = output.topk(1, 1, True, True)
453
+ pred = pred.flatten()
454
+ if args.temperature != 1:
455
+ output = output / args.temperature
456
+ classification_loss = criterion(output, pred)
457
+ loss = args.lambda_seg * segmentation_loss + args.lambda_acc * classification_loss
458
+
459
+ # save results
460
+ if i % args.save_interval == 0:
461
+ with torch.enable_grad():
462
+ orig_relevance = generate_relevance(orig_model, image_ten, index=class_name)
463
+ for j in range(image_ten.shape[0]):
464
+ image = get_image_with_relevance(image_ten[j], torch.ones_like(image_ten[j]))
465
+ new_vis = get_image_with_relevance(image_ten[j], relevance[j])
466
+ old_vis = get_image_with_relevance(image_ten[j], orig_relevance[j])
467
+ gt = get_image_with_relevance(image_ten[j], seg_map[j])
468
+ h_img = cv2.hconcat([image, gt, old_vis, new_vis])
469
+ cv2.imwrite(f'{args.experiment_folder}/val_samples/res_{i}_{j}.jpg', h_img)
470
+
471
+ # measure accuracy and record loss
472
+ acc1, acc5 = accuracy(output, class_name, topk=(1, 5))
473
+ losses.update(loss.item(), image_ten.size(0))
474
+ top1.update(acc1[0], image_ten.size(0))
475
+ top5.update(acc5[0], image_ten.size(0))
476
+
477
+ # metrics for original vit
478
+ acc1_orig, acc5_orig = accuracy(output_orig, class_name, topk=(1, 5))
479
+ orig_top1.update(acc1_orig[0], image_ten.size(0))
480
+ orig_top5.update(acc5_orig[0], image_ten.size(0))
481
+
482
+ if i % args.print_freq == 0:
483
+ progress.display(i)
484
+ args.logger.add_scalar('{}/{}'.format('val', 'segmentation_loss'), segmentation_loss,
485
+ epoch * len(val_loader) + i)
486
+ args.logger.add_scalar('{}/{}'.format('val', 'classification_loss'), classification_loss,
487
+ epoch * len(val_loader) + i)
488
+ args.logger.add_scalar('{}/{}'.format('val', 'orig_top1'), acc1_orig,
489
+ epoch * len(val_loader) + i)
490
+ args.logger.add_scalar('{}/{}'.format('val', 'top1'), acc1,
491
+ epoch * len(val_loader) + i)
492
+ args.logger.add_scalar('{}/{}'.format('val', 'orig_top5'), acc5_orig,
493
+ epoch * len(val_loader) + i)
494
+ args.logger.add_scalar('{}/{}'.format('val', 'top5'), acc5,
495
+ epoch * len(val_loader) + i)
496
+ args.logger.add_scalar('{}/{}'.format('val', 'tot_loss'), loss,
497
+ epoch * len(val_loader) + i)
498
+
499
+ # TODO: this should also be done with the ProgressMeter
500
+ print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
501
+ .format(top1=top1, top5=top5))
502
+
503
+ return losses.avg
504
+
505
+
506
+ def save_checkpoint(state, is_best, folder, filename='checkpoint.pth.tar'):
507
+ torch.save(state, f'{folder}/{filename}')
508
+ if is_best:
509
+ shutil.copyfile(f'{folder}/{filename}', f'{folder}/model_best.pth.tar')
510
+
511
+
512
+ class AverageMeter(object):
513
+ """Computes and stores the average and current value"""
514
+ def __init__(self, name, fmt=':f'):
515
+ self.name = name
516
+ self.fmt = fmt
517
+ self.reset()
518
+
519
+ def reset(self):
520
+ self.val = 0
521
+ self.avg = 0
522
+ self.sum = 0
523
+ self.count = 0
524
+
525
+ def update(self, val, n=1):
526
+ self.val = val
527
+ self.sum += val * n
528
+ self.count += n
529
+ self.avg = self.sum / self.count
530
+
531
+ def __str__(self):
532
+ fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
533
+ return fmtstr.format(**self.__dict__)
534
+
535
+
536
+ class ProgressMeter(object):
537
+ def __init__(self, num_batches, meters, prefix=""):
538
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
539
+ self.meters = meters
540
+ self.prefix = prefix
541
+
542
+ def display(self, batch):
543
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
544
+ entries += [str(meter) for meter in self.meters]
545
+ print('\t'.join(entries))
546
+
547
+ def _get_batch_fmtstr(self, num_batches):
548
+ num_digits = len(str(num_batches // 1))
549
+ fmt = '{:' + str(num_digits) + 'd}'
550
+ return '[' + fmt + '/' + fmt.format(num_batches) + ']'
551
+
552
+ def adjust_learning_rate(optimizer, epoch, args):
553
+ """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
554
+ lr = args.lr * (0.85 ** (epoch // 2))
555
+ for param_group in optimizer.param_groups:
556
+ param_group['lr'] = lr
557
+
558
+
559
+ def accuracy(output, target, topk=(1,)):
560
+ """Computes the accuracy over the k top predictions for the specified values of k"""
561
+ with torch.no_grad():
562
+ maxk = max(topk)
563
+ batch_size = target.size(0)
564
+
565
+ _, pred = output.topk(maxk, 1, True, True)
566
+ pred = pred.t()
567
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
568
+
569
+ res = []
570
+ for k in topk:
571
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
572
+ res.append(correct_k.mul_(100.0 / batch_size))
573
+ return res
574
+
575
+
576
+ if __name__ == '__main__':
577
+ main()
label_str_to_imagenet_classes.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Dictionary mapping labels (strings) to imagenet classes (ints).
16
+
17
+ Generated manually.
18
+ """
19
+
20
+ label_str_to_imagenet_classes = {
21
+ 'ambulance': 407,
22
+ 'armadillo': 363,
23
+ 'artichoke': 944,
24
+ 'backpack': 414,
25
+ 'bagel': 931,
26
+ 'balance beam': 416,
27
+ 'banana': 954,
28
+ 'band-aid': 419,
29
+ 'beaker': 438,
30
+ 'bell pepper': 945,
31
+ 'billiard table': 736,
32
+ 'binoculars': 447,
33
+ 'broccoli': 937,
34
+ 'brown bear': 294,
35
+ 'burrito': 965,
36
+ 'candle': 470,
37
+ 'canoe': 472,
38
+ 'cello': 486,
39
+ 'cheetah': 293,
40
+ 'cocktail shaker': 503,
41
+ 'common fig': 952,
42
+ 'computer mouse': 673,
43
+ 'cowboy hat': 515,
44
+ 'cucumber': 943,
45
+ 'diaper': 529,
46
+ 'digital clock': 530,
47
+ 'dumbbell': 543,
48
+ 'envelope': 549,
49
+ 'eraser': 767,
50
+ 'filing cabinet': 553,
51
+ 'flowerpot': 738,
52
+ 'flute': 558,
53
+ 'frying pan': 567,
54
+ 'golf ball': 574,
55
+ 'goose': 99,
56
+ 'guacamole': 924,
57
+ 'hair dryer': 589,
58
+ 'hair spray': 585,
59
+ 'hammer': 587,
60
+ 'hamster': 333,
61
+ 'harmonica': 593,
62
+ 'hedgehog': 334,
63
+ 'hippopotamus': 344,
64
+ 'hot dog': 934,
65
+ 'ipod': 605,
66
+ 'jeans': 608,
67
+ 'kite': 21,
68
+ 'koala': 105,
69
+ 'ladle': 618,
70
+ 'laptop': 620,
71
+ 'lemon': 951,
72
+ 'light switch': 844,
73
+ 'lighthouse': 437,
74
+ 'limousine': 627,
75
+ 'lipstick': 629,
76
+ 'lynx': 287,
77
+ 'magpie': 18,
78
+ 'maracas': 641,
79
+ 'measuring cup': 647,
80
+ 'microwave oven': 651,
81
+ 'miniskirt': 655,
82
+ 'missile': 657,
83
+ 'mixing bowl': 659,
84
+ 'mobile phone': 487,
85
+ 'mushroom': 947,
86
+ 'orange': 950,
87
+ 'ostrich': 9,
88
+ 'otter': 360,
89
+ 'paper towel': 700,
90
+ 'pencil case': 709,
91
+ 'pig': 341,
92
+ 'pillow': 721,
93
+ 'pitcher (container)': 725,
94
+ 'pizza': 963,
95
+ 'plastic bag': 728,
96
+ 'polar bear': 296,
97
+ 'pomegranate': 957,
98
+ 'pretzel': 932,
99
+ 'printer': 742,
100
+ 'punching bag': 747,
101
+ 'racket': 752,
102
+ 'red panda': 387,
103
+ 'remote control': 761,
104
+ 'rugby ball': 768,
105
+ 'ruler': 769,
106
+ 'saxophone': 776,
107
+ 'screwdriver': 784,
108
+ 'sea lion': 150,
109
+ 'seat belt': 785,
110
+ 'skunk': 361,
111
+ 'snowmobile': 802,
112
+ 'soap dispenser': 804,
113
+ 'sock': 806,
114
+ 'sombrero': 808,
115
+ 'spatula': 813,
116
+ 'starfish': 327,
117
+ 'strawberry': 949,
118
+ 'studio couch': 831,
119
+ 'taxi': 468,
120
+ 'teapot': 849,
121
+ 'teddy bear': 850,
122
+ 'tennis ball': 852,
123
+ 'toaster': 859,
124
+ 'toilet paper': 999,
125
+ 'torch': 862,
126
+ 'traffic light': 920,
127
+ 'vase': 883,
128
+ 'volleyball (ball)': 890,
129
+ 'washing machine': 897,
130
+ 'wok': 909,
131
+ 'zebra': 340,
132
+ 'zucchini': 939
133
+ }
objectnet_dataset.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from torch.utils import data
3
+ from torchvision.datasets import ImageFolder
4
+ import torch
5
+ import os
6
+ from PIL import Image
7
+ import numpy as np
8
+ import argparse
9
+ from tqdm import tqdm
10
+ from munkres import Munkres
11
+ import multiprocessing
12
+ from multiprocessing import Process, Manager
13
+ import collections
14
+ import torchvision.transforms as transforms
15
+ import torchvision.transforms.functional as TF
16
+ import random
17
+ import torchvision
18
+ import cv2
19
+ from label_str_to_imagenet_classes import label_str_to_imagenet_classes
20
+
21
+ torch.manual_seed(0)
22
+ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
23
+ std=[0.5, 0.5, 0.5])
24
+
25
+ transform = transforms.Compose([
26
+ transforms.Resize(256),
27
+ transforms.CenterCrop(224),
28
+ transforms.ToTensor(),
29
+ normalize,
30
+ ])
31
+
32
+ class ObjectNetDataset(ImageFolder):
33
+ def __init__(self, imagenet_path):
34
+ self._imagenet_path = imagenet_path
35
+ self._all_images = []
36
+
37
+ o_dataset = ImageFolder(self._imagenet_path)
38
+ # get mappings folder
39
+ mappings_folder = os.path.abspath(
40
+ os.path.join(self._imagenet_path, "../mappings")
41
+ )
42
+
43
+ # get ObjectNet label to ImageNet label mapping
44
+ with open(
45
+ os.path.join(mappings_folder, "objectnet_to_imagenet_1k.json")
46
+ ) as file_handle:
47
+ o_label_to_all_i_labels = json.load(file_handle)
48
+
49
+ # now remove double i labels to avoid confusion
50
+ o_label_to_i_labels = {
51
+ o_label: all_i_label.split("; ")
52
+ for o_label, all_i_label in o_label_to_all_i_labels.items()
53
+ }
54
+
55
+ # some in-between mappings ...
56
+ o_folder_to_o_idx = o_dataset.class_to_idx
57
+ with open(
58
+ os.path.join(mappings_folder, "folder_to_objectnet_label.json")
59
+ ) as file_handle:
60
+ o_folder_o_label = json.load(file_handle)
61
+
62
+ # now get mapping from o_label to o_idx
63
+ o_label_to_o_idx = {
64
+ o_label: o_folder_to_o_idx[o_folder]
65
+ for o_folder, o_label in o_folder_o_label.items()
66
+ }
67
+
68
+ # some in-between mappings ...
69
+ with open(
70
+ os.path.join(mappings_folder, "pytorch_to_imagenet_2012_id.json")
71
+ ) as file_handle:
72
+ i_idx_to_i_line = json.load(file_handle)
73
+ with open(
74
+ os.path.join(mappings_folder, "imagenet_to_label_2012_v2")
75
+ ) as file_handle:
76
+ i_line_to_i_label = file_handle.readlines()
77
+
78
+ i_line_to_i_label = {
79
+ i_line: i_label[:-1]
80
+ for i_line, i_label in enumerate(i_line_to_i_label)
81
+ }
82
+
83
+ # now get mapping from i_label to i_idx
84
+ i_label_to_i_idx = {
85
+ i_line_to_i_label[i_line]: int(i_idx)
86
+ for i_idx, i_line in i_idx_to_i_line.items()
87
+ }
88
+
89
+ # now get the final mapping of interest!!!
90
+ o_idx_to_i_idxs = {
91
+ o_label_to_o_idx[o_label]: [
92
+ i_label_to_i_idx[i_label] for i_label in i_labels
93
+ ]
94
+ for o_label, i_labels in o_label_to_i_labels.items()
95
+ }
96
+
97
+ self._tag_list = []
98
+ # now get a list of files of interest
99
+ for filepath, o_idx in o_dataset.samples:
100
+ if o_idx not in o_idx_to_i_idxs:
101
+ continue
102
+ rel_file = os.path.relpath(filepath, self._imagenet_path)
103
+ if o_idx_to_i_idxs[o_idx][0] not in self._tag_list:
104
+ self._tag_list.append(o_idx_to_i_idxs[o_idx][0])
105
+ self._all_images.append((rel_file, o_idx_to_i_idxs[o_idx][0]))
106
+
107
+ def __getitem__(self, item):
108
+ image_path, classification = self._all_images[item]
109
+ image_path = os.path.join(self._imagenet_path, image_path)
110
+ image = Image.open(image_path)
111
+ image = image.convert('RGB')
112
+ image = transform(image)
113
+
114
+ return image, classification
115
+
116
+ def __len__(self):
117
+ return len(self._all_images)
robustness_dataset.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from torch.utils import data
3
+ from torchvision.datasets import ImageFolder
4
+ import torch
5
+ import os
6
+ from PIL import Image
7
+ import numpy as np
8
+ import argparse
9
+ from tqdm import tqdm
10
+ from munkres import Munkres
11
+ import multiprocessing
12
+ from multiprocessing import Process, Manager
13
+ import collections
14
+ import torchvision.transforms as transforms
15
+ import torchvision.transforms.functional as TF
16
+ import random
17
+ import torchvision
18
+ import cv2
19
+ from label_str_to_imagenet_classes import label_str_to_imagenet_classes
20
+
21
+ torch.manual_seed(0)
22
+
23
+ ImageItem = collections.namedtuple('ImageItem', ('image_name', 'tag'))
24
+ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
25
+ std=[0.5, 0.5, 0.5])
26
+
27
+ transform = transforms.Compose([
28
+ transforms.Resize(256),
29
+ transforms.CenterCrop(224),
30
+ transforms.ToTensor(),
31
+ normalize,
32
+ ])
33
+
34
+ class RobustnessDataset(ImageFolder):
35
+ def __init__(self, imagenet_path, imagenet_classes_path='imagenet_classes.json', isV2=False, isSI=False):
36
+ self._isV2 = isV2
37
+ self._isSI = isSI
38
+ self._imagenet_path = imagenet_path
39
+ with open(imagenet_classes_path, 'r') as f:
40
+ self._imagenet_classes = json.load(f)
41
+ self._tag_list = [tag for tag in os.listdir(self._imagenet_path)]
42
+ self._all_images = []
43
+ for tag in self._tag_list:
44
+ base_dir = os.path.join(self._imagenet_path, tag)
45
+ for i, file in enumerate(os.listdir(base_dir)):
46
+ self._all_images.append(ImageItem(file, tag))
47
+
48
+
49
+ def __getitem__(self, item):
50
+ image_item = self._all_images[item]
51
+ image_path = os.path.join(self._imagenet_path, image_item.tag, image_item.image_name)
52
+ image = Image.open(image_path)
53
+ image = image.convert('RGB')
54
+ image = transform(image)
55
+
56
+ if self._isV2:
57
+ class_name = int(image_item.tag)
58
+ elif self._isSI:
59
+ class_name = int(label_str_to_imagenet_classes[image_item.tag])
60
+ else:
61
+ class_name = int(self._imagenet_classes[image_item.tag])
62
+
63
+ return image, class_name
64
+
65
+ def __len__(self):
66
+ return len(self._all_images)
robustness_dataset_per_class.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from torchvision.datasets import ImageFolder
3
+ import torch
4
+ import os
5
+ from PIL import Image
6
+ import collections
7
+ import torchvision.transforms as transforms
8
+ from label_str_to_imagenet_classes import label_str_to_imagenet_classes
9
+
10
+ torch.manual_seed(0)
11
+
12
+ ImageItem = collections.namedtuple('ImageItem', ('image_name', 'tag'))
13
+
14
+ normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
15
+ std=[0.5, 0.5, 0.5])
16
+
17
+ transform = transforms.Compose([
18
+ transforms.Resize(256),
19
+ transforms.CenterCrop(224),
20
+ transforms.ToTensor(),
21
+ normalize,
22
+ ])
23
+
24
+ class RobustnessDataset(ImageFolder):
25
+ def __init__(self, imagenet_path, folder, imagenet_classes_path='imagenet_classes.json', isV2=False, isSI=False):
26
+ self._isV2 = isV2
27
+ self._isSI = isSI
28
+ self._folder = folder
29
+ self._imagenet_path = imagenet_path
30
+ with open(imagenet_classes_path, 'r') as f:
31
+ self._imagenet_classes = json.load(f)
32
+ self._all_images = []
33
+
34
+ base_dir = os.path.join(self._imagenet_path, folder)
35
+ for i, file in enumerate(os.listdir(base_dir)):
36
+ self._all_images.append(ImageItem(file, folder))
37
+
38
+
39
+ def __getitem__(self, item):
40
+ image_item = self._all_images[item]
41
+ image_path = os.path.join(self._imagenet_path, image_item.tag, image_item.image_name)
42
+ image = Image.open(image_path)
43
+ image = image.convert('RGB')
44
+ image = transform(image)
45
+
46
+ if self._isV2:
47
+ class_name = int(image_item.tag)
48
+ elif self._isSI:
49
+ class_name = int(label_str_to_imagenet_classes[image_item.tag])
50
+ else:
51
+ class_name = int(self._imagenet_classes[image_item.tag])
52
+
53
+ return image, class_name
54
+
55
+ def __len__(self):
56
+ return len(self._all_images)
57
+
58
+ def get_classname(self):
59
+ if self._isV2:
60
+ class_name = int(self._folder)
61
+ elif self._isSI:
62
+ class_name = int(label_str_to_imagenet_classes[self._folder])
63
+ else:
64
+ class_name = int(self._imagenet_classes[self._folder])
65
+ return class_name
samples/augreg_base/1_in.png ADDED
samples/augreg_base/2_in.png ADDED
samples/augreg_base/3_in.png ADDED
samples/augreg_base/a.png ADDED
samples/augreg_base/a_2.png ADDED
samples/augreg_base/a_3.png ADDED
samples/catdog.png ADDED
samples/deit_base/1_in.png ADDED
samples/deit_base/2_in.png ADDED
samples/deit_base/3_in.png ADDED
samples/deit_base/a.png ADDED
samples/deit_base/a_2.png ADDED