amaye15 commited on
Commit
a11aa48
1 Parent(s): ee91fd0

End of training

Browse files
Files changed (7) hide show
  1. README.md +191 -0
  2. config.json +303 -0
  3. config.toml +27 -0
  4. model.safetensors +3 -0
  5. preprocessor_config.json +36 -0
  6. train.ipynb +2084 -0
  7. training_args.bin +3 -0
README.md ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ base_model: microsoft/swinv2-base-patch4-window16-256
4
+ tags:
5
+ - generated_from_trainer
6
+ datasets:
7
+ - stanford-dogs
8
+ metrics:
9
+ - accuracy
10
+ - f1
11
+ - precision
12
+ - recall
13
+ model-index:
14
+ - name: microsoft-swinv2-base-patch4-window16-256-batch32-lr0.0005-standford-dogs
15
+ results:
16
+ - task:
17
+ name: Image Classification
18
+ type: image-classification
19
+ dataset:
20
+ name: stanford-dogs
21
+ type: stanford-dogs
22
+ config: default
23
+ split: full
24
+ args: default
25
+ metrics:
26
+ - name: Accuracy
27
+ type: accuracy
28
+ value: 0.9429057337220602
29
+ - name: F1
30
+ type: f1
31
+ value: 0.9410841953165723
32
+ - name: Precision
33
+ type: precision
34
+ value: 0.9431724455914652
35
+ - name: Recall
36
+ type: recall
37
+ value: 0.9417046971391595
38
+ ---
39
+
40
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
41
+ should probably proofread and complete it, then remove this comment. -->
42
+
43
+ # microsoft-swinv2-base-patch4-window16-256-batch32-lr0.0005-standford-dogs
44
+
45
+ This model is a fine-tuned version of [microsoft/swinv2-base-patch4-window16-256](https://huggingface.co/microsoft/swinv2-base-patch4-window16-256) on the stanford-dogs dataset.
46
+ It achieves the following results on the evaluation set:
47
+ - Loss: 0.1810
48
+ - Accuracy: 0.9429
49
+ - F1: 0.9411
50
+ - Precision: 0.9432
51
+ - Recall: 0.9417
52
+
53
+ ## Model description
54
+
55
+ More information needed
56
+
57
+ ## Intended uses & limitations
58
+
59
+ More information needed
60
+
61
+ ## Training and evaluation data
62
+
63
+ More information needed
64
+
65
+ ## Training procedure
66
+
67
+ ### Training hyperparameters
68
+
69
+ The following hyperparameters were used during training:
70
+ - learning_rate: 5e-05
71
+ - train_batch_size: 32
72
+ - eval_batch_size: 32
73
+ - seed: 42
74
+ - gradient_accumulation_steps: 4
75
+ - total_train_batch_size: 128
76
+ - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
77
+ - lr_scheduler_type: linear
78
+ - training_steps: 1000
79
+
80
+ ### Training results
81
+
82
+ | Training Loss | Epoch | Step | Validation Loss | Accuracy | F1 | Precision | Recall |
83
+ |:-------------:|:------:|:----:|:---------------:|:--------:|:------:|:---------:|:------:|
84
+ | 4.7518 | 0.0777 | 10 | 4.6391 | 0.0741 | 0.0533 | 0.0667 | 0.0705 |
85
+ | 4.5585 | 0.1553 | 20 | 4.3463 | 0.1919 | 0.1445 | 0.1900 | 0.1794 |
86
+ | 4.2377 | 0.2330 | 30 | 3.8243 | 0.3525 | 0.3100 | 0.4154 | 0.3382 |
87
+ | 3.6654 | 0.3107 | 40 | 2.9276 | 0.6409 | 0.6111 | 0.6994 | 0.6300 |
88
+ | 2.7617 | 0.3883 | 50 | 1.7703 | 0.8248 | 0.8042 | 0.8361 | 0.8182 |
89
+ | 1.9475 | 0.4660 | 60 | 1.0440 | 0.8863 | 0.8781 | 0.8924 | 0.8821 |
90
+ | 1.3629 | 0.5437 | 70 | 0.6490 | 0.9099 | 0.9031 | 0.9191 | 0.9062 |
91
+ | 1.0488 | 0.6214 | 80 | 0.4485 | 0.9150 | 0.9075 | 0.9147 | 0.9118 |
92
+ | 0.8477 | 0.6990 | 90 | 0.3744 | 0.9206 | 0.9169 | 0.9294 | 0.9190 |
93
+ | 0.7184 | 0.7767 | 100 | 0.3301 | 0.9259 | 0.9215 | 0.9283 | 0.9227 |
94
+ | 0.7149 | 0.8544 | 110 | 0.2970 | 0.9186 | 0.9152 | 0.9227 | 0.9156 |
95
+ | 0.6429 | 0.9320 | 120 | 0.2675 | 0.9286 | 0.9238 | 0.9301 | 0.9256 |
96
+ | 0.5864 | 1.0097 | 130 | 0.2609 | 0.9291 | 0.9258 | 0.9338 | 0.9272 |
97
+ | 0.5414 | 1.0874 | 140 | 0.2644 | 0.9162 | 0.9122 | 0.9212 | 0.9156 |
98
+ | 0.5323 | 1.1650 | 150 | 0.2454 | 0.9281 | 0.9225 | 0.9362 | 0.9256 |
99
+ | 0.5061 | 1.2427 | 160 | 0.2481 | 0.9269 | 0.9235 | 0.9308 | 0.9251 |
100
+ | 0.5898 | 1.3204 | 170 | 0.2306 | 0.9346 | 0.9324 | 0.9389 | 0.9331 |
101
+ | 0.5277 | 1.3981 | 180 | 0.2192 | 0.9368 | 0.9327 | 0.9384 | 0.9350 |
102
+ | 0.4824 | 1.4757 | 190 | 0.2171 | 0.9337 | 0.9297 | 0.9375 | 0.9311 |
103
+ | 0.4632 | 1.5534 | 200 | 0.2244 | 0.9346 | 0.9315 | 0.9379 | 0.9326 |
104
+ | 0.4882 | 1.6311 | 210 | 0.2237 | 0.9361 | 0.9323 | 0.9404 | 0.9345 |
105
+ | 0.4583 | 1.7087 | 220 | 0.2228 | 0.9327 | 0.9289 | 0.9373 | 0.9304 |
106
+ | 0.4692 | 1.7864 | 230 | 0.2098 | 0.9354 | 0.9316 | 0.9370 | 0.9332 |
107
+ | 0.5407 | 1.8641 | 240 | 0.2102 | 0.9356 | 0.9342 | 0.9375 | 0.9351 |
108
+ | 0.4629 | 1.9417 | 250 | 0.2045 | 0.9378 | 0.9349 | 0.9396 | 0.9367 |
109
+ | 0.4363 | 2.0194 | 260 | 0.2023 | 0.9373 | 0.9346 | 0.9398 | 0.9355 |
110
+ | 0.4328 | 2.0971 | 270 | 0.2063 | 0.9354 | 0.9320 | 0.9360 | 0.9343 |
111
+ | 0.3554 | 2.1748 | 280 | 0.1948 | 0.9439 | 0.9398 | 0.9475 | 0.9418 |
112
+ | 0.4024 | 2.2524 | 290 | 0.1985 | 0.9388 | 0.9372 | 0.9397 | 0.9377 |
113
+ | 0.4006 | 2.3301 | 300 | 0.2153 | 0.9334 | 0.9275 | 0.9420 | 0.9311 |
114
+ | 0.3935 | 2.4078 | 310 | 0.2021 | 0.9393 | 0.9346 | 0.9416 | 0.9368 |
115
+ | 0.3591 | 2.4854 | 320 | 0.2126 | 0.9346 | 0.9311 | 0.9403 | 0.9333 |
116
+ | 0.4058 | 2.5631 | 330 | 0.2020 | 0.9378 | 0.9357 | 0.9393 | 0.9358 |
117
+ | 0.396 | 2.6408 | 340 | 0.2038 | 0.9371 | 0.9339 | 0.9410 | 0.9357 |
118
+ | 0.4157 | 2.7184 | 350 | 0.2091 | 0.9332 | 0.9288 | 0.9352 | 0.9308 |
119
+ | 0.4222 | 2.7961 | 360 | 0.1933 | 0.9393 | 0.9372 | 0.9399 | 0.9378 |
120
+ | 0.3521 | 2.8738 | 370 | 0.1984 | 0.9397 | 0.9381 | 0.9430 | 0.9388 |
121
+ | 0.3925 | 2.9515 | 380 | 0.1874 | 0.9383 | 0.9347 | 0.9390 | 0.9358 |
122
+ | 0.3475 | 3.0291 | 390 | 0.1994 | 0.9383 | 0.9364 | 0.9410 | 0.9376 |
123
+ | 0.3526 | 3.1068 | 400 | 0.1941 | 0.9390 | 0.9352 | 0.9402 | 0.9373 |
124
+ | 0.351 | 3.1845 | 410 | 0.1893 | 0.9417 | 0.9403 | 0.9438 | 0.9410 |
125
+ | 0.3549 | 3.2621 | 420 | 0.1960 | 0.9390 | 0.9370 | 0.9410 | 0.9381 |
126
+ | 0.3291 | 3.3398 | 430 | 0.1948 | 0.9397 | 0.9358 | 0.9387 | 0.9374 |
127
+ | 0.3153 | 3.4175 | 440 | 0.1992 | 0.9441 | 0.9415 | 0.9453 | 0.9427 |
128
+ | 0.3116 | 3.4951 | 450 | 0.2005 | 0.9417 | 0.9389 | 0.9432 | 0.9404 |
129
+ | 0.3053 | 3.5728 | 460 | 0.1974 | 0.9412 | 0.9372 | 0.9424 | 0.9394 |
130
+ | 0.3141 | 3.6505 | 470 | 0.1941 | 0.9405 | 0.9386 | 0.9420 | 0.9395 |
131
+ | 0.3275 | 3.7282 | 480 | 0.2182 | 0.9334 | 0.9301 | 0.9374 | 0.9321 |
132
+ | 0.2997 | 3.8058 | 490 | 0.2029 | 0.9376 | 0.9343 | 0.9392 | 0.9360 |
133
+ | 0.3242 | 3.8835 | 500 | 0.1996 | 0.9380 | 0.9344 | 0.9399 | 0.9361 |
134
+ | 0.3585 | 3.9612 | 510 | 0.1935 | 0.9405 | 0.9378 | 0.9421 | 0.9389 |
135
+ | 0.2942 | 4.0388 | 520 | 0.2028 | 0.9368 | 0.9341 | 0.9428 | 0.9367 |
136
+ | 0.3233 | 4.1165 | 530 | 0.2029 | 0.9378 | 0.9353 | 0.9406 | 0.9364 |
137
+ | 0.2942 | 4.1942 | 540 | 0.1959 | 0.9385 | 0.9368 | 0.9395 | 0.9372 |
138
+ | 0.3079 | 4.2718 | 550 | 0.1941 | 0.9371 | 0.9349 | 0.9373 | 0.9354 |
139
+ | 0.2931 | 4.3495 | 560 | 0.1871 | 0.9414 | 0.9388 | 0.9410 | 0.9394 |
140
+ | 0.3058 | 4.4272 | 570 | 0.1879 | 0.9419 | 0.9403 | 0.9430 | 0.9407 |
141
+ | 0.3402 | 4.5049 | 580 | 0.1833 | 0.9434 | 0.9409 | 0.9435 | 0.9420 |
142
+ | 0.3169 | 4.5825 | 590 | 0.1882 | 0.9412 | 0.9391 | 0.9425 | 0.9402 |
143
+ | 0.3071 | 4.6602 | 600 | 0.1821 | 0.9448 | 0.9425 | 0.9460 | 0.9431 |
144
+ | 0.313 | 4.7379 | 610 | 0.1879 | 0.9429 | 0.9401 | 0.9441 | 0.9413 |
145
+ | 0.3338 | 4.8155 | 620 | 0.1843 | 0.9456 | 0.9424 | 0.9469 | 0.9439 |
146
+ | 0.2468 | 4.8932 | 630 | 0.1866 | 0.9436 | 0.9412 | 0.9441 | 0.9426 |
147
+ | 0.2567 | 4.9709 | 640 | 0.1882 | 0.9405 | 0.9387 | 0.9417 | 0.9393 |
148
+ | 0.2792 | 5.0485 | 650 | 0.1914 | 0.9429 | 0.9407 | 0.9442 | 0.9418 |
149
+ | 0.2985 | 5.1262 | 660 | 0.1880 | 0.9429 | 0.9393 | 0.9442 | 0.9411 |
150
+ | 0.2744 | 5.2039 | 670 | 0.1865 | 0.9410 | 0.9378 | 0.9420 | 0.9390 |
151
+ | 0.2662 | 5.2816 | 680 | 0.1877 | 0.9419 | 0.9400 | 0.9423 | 0.9407 |
152
+ | 0.2613 | 5.3592 | 690 | 0.1890 | 0.9393 | 0.9369 | 0.9401 | 0.9378 |
153
+ | 0.2698 | 5.4369 | 700 | 0.1849 | 0.9429 | 0.9409 | 0.9441 | 0.9417 |
154
+ | 0.2592 | 5.5146 | 710 | 0.1854 | 0.9429 | 0.9414 | 0.9439 | 0.9425 |
155
+ | 0.2819 | 5.5922 | 720 | 0.1868 | 0.9429 | 0.9414 | 0.9443 | 0.9418 |
156
+ | 0.2625 | 5.6699 | 730 | 0.1832 | 0.9434 | 0.9417 | 0.9438 | 0.9422 |
157
+ | 0.273 | 5.7476 | 740 | 0.1862 | 0.9439 | 0.9408 | 0.9445 | 0.9424 |
158
+ | 0.2718 | 5.8252 | 750 | 0.1838 | 0.9441 | 0.9417 | 0.9443 | 0.9428 |
159
+ | 0.3055 | 5.9029 | 760 | 0.1852 | 0.9422 | 0.9396 | 0.9426 | 0.9407 |
160
+ | 0.276 | 5.9806 | 770 | 0.1843 | 0.9424 | 0.9409 | 0.9434 | 0.9415 |
161
+ | 0.2614 | 6.0583 | 780 | 0.1839 | 0.9429 | 0.9403 | 0.9431 | 0.9411 |
162
+ | 0.2452 | 6.1359 | 790 | 0.1858 | 0.9407 | 0.9384 | 0.9414 | 0.9390 |
163
+ | 0.2608 | 6.2136 | 800 | 0.1851 | 0.9429 | 0.9411 | 0.9437 | 0.9417 |
164
+ | 0.2639 | 6.2913 | 810 | 0.1842 | 0.9453 | 0.9432 | 0.9463 | 0.9438 |
165
+ | 0.2696 | 6.3689 | 820 | 0.1812 | 0.9424 | 0.9406 | 0.9425 | 0.9412 |
166
+ | 0.2524 | 6.4466 | 830 | 0.1830 | 0.9427 | 0.9411 | 0.9433 | 0.9417 |
167
+ | 0.2673 | 6.5243 | 840 | 0.1823 | 0.9451 | 0.9436 | 0.9464 | 0.9442 |
168
+ | 0.2991 | 6.6019 | 850 | 0.1837 | 0.9429 | 0.9408 | 0.9431 | 0.9419 |
169
+ | 0.2704 | 6.6796 | 860 | 0.1833 | 0.9439 | 0.9424 | 0.9446 | 0.9431 |
170
+ | 0.2437 | 6.7573 | 870 | 0.1857 | 0.9424 | 0.9410 | 0.9434 | 0.9416 |
171
+ | 0.2266 | 6.8350 | 880 | 0.1846 | 0.9431 | 0.9416 | 0.9436 | 0.9423 |
172
+ | 0.2276 | 6.9126 | 890 | 0.1825 | 0.9441 | 0.9426 | 0.9448 | 0.9433 |
173
+ | 0.2249 | 6.9903 | 900 | 0.1813 | 0.9436 | 0.9419 | 0.9441 | 0.9425 |
174
+ | 0.2559 | 7.0680 | 910 | 0.1813 | 0.9444 | 0.9425 | 0.9448 | 0.9431 |
175
+ | 0.2616 | 7.1456 | 920 | 0.1813 | 0.9441 | 0.9421 | 0.9443 | 0.9428 |
176
+ | 0.2247 | 7.2233 | 930 | 0.1813 | 0.9439 | 0.9421 | 0.9442 | 0.9426 |
177
+ | 0.2471 | 7.3010 | 940 | 0.1813 | 0.9448 | 0.9430 | 0.9453 | 0.9436 |
178
+ | 0.2446 | 7.3786 | 950 | 0.1817 | 0.9444 | 0.9427 | 0.9450 | 0.9432 |
179
+ | 0.2262 | 7.4563 | 960 | 0.1819 | 0.9434 | 0.9417 | 0.9439 | 0.9423 |
180
+ | 0.2632 | 7.5340 | 970 | 0.1818 | 0.9439 | 0.9422 | 0.9444 | 0.9427 |
181
+ | 0.2258 | 7.6117 | 980 | 0.1815 | 0.9434 | 0.9416 | 0.9439 | 0.9422 |
182
+ | 0.2404 | 7.6893 | 990 | 0.1811 | 0.9429 | 0.9410 | 0.9432 | 0.9416 |
183
+ | 0.2379 | 7.7670 | 1000 | 0.1810 | 0.9429 | 0.9411 | 0.9432 | 0.9417 |
184
+
185
+
186
+ ### Framework versions
187
+
188
+ - Transformers 4.40.2
189
+ - Pytorch 2.3.0
190
+ - Datasets 2.19.1
191
+ - Tokenizers 0.19.1
config.json ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "microsoft/swinv2-base-patch4-window16-256",
3
+ "architectures": [
4
+ "Swinv2ForImageClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.0,
7
+ "depths": [
8
+ 2,
9
+ 2,
10
+ 18,
11
+ 2
12
+ ],
13
+ "drop_path_rate": 0.1,
14
+ "embed_dim": 128,
15
+ "encoder_stride": 32,
16
+ "hidden_act": "gelu",
17
+ "hidden_dropout_prob": 0.0,
18
+ "hidden_size": 1024,
19
+ "id2label": {
20
+ "0": "Affenpinscher",
21
+ "1": "Afghan Hound",
22
+ "2": "African Hunting Dog",
23
+ "3": "Airedale",
24
+ "4": "American Staffordshire Terrier",
25
+ "5": "Appenzeller",
26
+ "6": "Australian Terrier",
27
+ "7": "Basenji",
28
+ "8": "Basset",
29
+ "9": "Beagle",
30
+ "10": "Bedlington Terrier",
31
+ "11": "Bernese Mountain Dog",
32
+ "12": "Black And Tan Coonhound",
33
+ "13": "Blenheim Spaniel",
34
+ "14": "Bloodhound",
35
+ "15": "Bluetick",
36
+ "16": "Border Collie",
37
+ "17": "Border Terrier",
38
+ "18": "Borzoi",
39
+ "19": "Boston Bull",
40
+ "20": "Bouvier Des Flandres",
41
+ "21": "Boxer",
42
+ "22": "Brabancon Griffon",
43
+ "23": "Briard",
44
+ "24": "Brittany Spaniel",
45
+ "25": "Bull Mastiff",
46
+ "26": "Cairn",
47
+ "27": "Cardigan",
48
+ "28": "Chesapeake Bay Retriever",
49
+ "29": "Chihuahua",
50
+ "30": "Chow",
51
+ "31": "Clumber",
52
+ "32": "Cocker Spaniel",
53
+ "33": "Collie",
54
+ "34": "Curly Coated Retriever",
55
+ "35": "Dandie Dinmont",
56
+ "36": "Dhole",
57
+ "37": "Dingo",
58
+ "38": "Doberman",
59
+ "39": "English Foxhound",
60
+ "40": "English Setter",
61
+ "41": "English Springer",
62
+ "42": "Entlebucher",
63
+ "43": "Eskimo Dog",
64
+ "44": "Flat Coated Retriever",
65
+ "45": "French Bulldog",
66
+ "46": "German Shepherd",
67
+ "47": "German Short Haired Pointer",
68
+ "48": "Giant Schnauzer",
69
+ "49": "Golden Retriever",
70
+ "50": "Gordon Setter",
71
+ "51": "Great Dane",
72
+ "52": "Great Pyrenees",
73
+ "53": "Greater Swiss Mountain Dog",
74
+ "54": "Groenendael",
75
+ "55": "Ibizan Hound",
76
+ "56": "Irish Setter",
77
+ "57": "Irish Terrier",
78
+ "58": "Irish Water Spaniel",
79
+ "59": "Irish Wolfhound",
80
+ "60": "Italian Greyhound",
81
+ "61": "Japanese Spaniel",
82
+ "62": "Keeshond",
83
+ "63": "Kelpie",
84
+ "64": "Kerry Blue Terrier",
85
+ "65": "Komondor",
86
+ "66": "Kuvasz",
87
+ "67": "Labrador Retriever",
88
+ "68": "Lakeland Terrier",
89
+ "69": "Leonberg",
90
+ "70": "Lhasa",
91
+ "71": "Malamute",
92
+ "72": "Malinois",
93
+ "73": "Maltese Dog",
94
+ "74": "Mexican Hairless",
95
+ "75": "Miniature Pinscher",
96
+ "76": "Miniature Poodle",
97
+ "77": "Miniature Schnauzer",
98
+ "78": "Newfoundland",
99
+ "79": "Norfolk Terrier",
100
+ "80": "Norwegian Elkhound",
101
+ "81": "Norwich Terrier",
102
+ "82": "Old English Sheepdog",
103
+ "83": "Otterhound",
104
+ "84": "Papillon",
105
+ "85": "Pekinese",
106
+ "86": "Pembroke",
107
+ "87": "Pomeranian",
108
+ "88": "Pug",
109
+ "89": "Redbone",
110
+ "90": "Rhodesian Ridgeback",
111
+ "91": "Rottweiler",
112
+ "92": "Saint Bernard",
113
+ "93": "Saluki",
114
+ "94": "Samoyed",
115
+ "95": "Schipperke",
116
+ "96": "Scotch Terrier",
117
+ "97": "Scottish Deerhound",
118
+ "98": "Sealyham Terrier",
119
+ "99": "Shetland Sheepdog",
120
+ "100": "Shih Tzu",
121
+ "101": "Siberian Husky",
122
+ "102": "Silky Terrier",
123
+ "103": "Soft Coated Wheaten Terrier",
124
+ "104": "Staffordshire Bullterrier",
125
+ "105": "Standard Poodle",
126
+ "106": "Standard Schnauzer",
127
+ "107": "Sussex Spaniel",
128
+ "108": "Tibetan Mastiff",
129
+ "109": "Tibetan Terrier",
130
+ "110": "Toy Poodle",
131
+ "111": "Toy Terrier",
132
+ "112": "Vizsla",
133
+ "113": "Walker Hound",
134
+ "114": "Weimaraner",
135
+ "115": "Welsh Springer Spaniel",
136
+ "116": "West Highland White Terrier",
137
+ "117": "Whippet",
138
+ "118": "Wire Haired Fox Terrier",
139
+ "119": "Yorkshire Terrier"
140
+ },
141
+ "image_size": 256,
142
+ "initializer_range": 0.02,
143
+ "label2id": {
144
+ "Affenpinscher": 0,
145
+ "Afghan Hound": 1,
146
+ "African Hunting Dog": 2,
147
+ "Airedale": 3,
148
+ "American Staffordshire Terrier": 4,
149
+ "Appenzeller": 5,
150
+ "Australian Terrier": 6,
151
+ "Basenji": 7,
152
+ "Basset": 8,
153
+ "Beagle": 9,
154
+ "Bedlington Terrier": 10,
155
+ "Bernese Mountain Dog": 11,
156
+ "Black And Tan Coonhound": 12,
157
+ "Blenheim Spaniel": 13,
158
+ "Bloodhound": 14,
159
+ "Bluetick": 15,
160
+ "Border Collie": 16,
161
+ "Border Terrier": 17,
162
+ "Borzoi": 18,
163
+ "Boston Bull": 19,
164
+ "Bouvier Des Flandres": 20,
165
+ "Boxer": 21,
166
+ "Brabancon Griffon": 22,
167
+ "Briard": 23,
168
+ "Brittany Spaniel": 24,
169
+ "Bull Mastiff": 25,
170
+ "Cairn": 26,
171
+ "Cardigan": 27,
172
+ "Chesapeake Bay Retriever": 28,
173
+ "Chihuahua": 29,
174
+ "Chow": 30,
175
+ "Clumber": 31,
176
+ "Cocker Spaniel": 32,
177
+ "Collie": 33,
178
+ "Curly Coated Retriever": 34,
179
+ "Dandie Dinmont": 35,
180
+ "Dhole": 36,
181
+ "Dingo": 37,
182
+ "Doberman": 38,
183
+ "English Foxhound": 39,
184
+ "English Setter": 40,
185
+ "English Springer": 41,
186
+ "Entlebucher": 42,
187
+ "Eskimo Dog": 43,
188
+ "Flat Coated Retriever": 44,
189
+ "French Bulldog": 45,
190
+ "German Shepherd": 46,
191
+ "German Short Haired Pointer": 47,
192
+ "Giant Schnauzer": 48,
193
+ "Golden Retriever": 49,
194
+ "Gordon Setter": 50,
195
+ "Great Dane": 51,
196
+ "Great Pyrenees": 52,
197
+ "Greater Swiss Mountain Dog": 53,
198
+ "Groenendael": 54,
199
+ "Ibizan Hound": 55,
200
+ "Irish Setter": 56,
201
+ "Irish Terrier": 57,
202
+ "Irish Water Spaniel": 58,
203
+ "Irish Wolfhound": 59,
204
+ "Italian Greyhound": 60,
205
+ "Japanese Spaniel": 61,
206
+ "Keeshond": 62,
207
+ "Kelpie": 63,
208
+ "Kerry Blue Terrier": 64,
209
+ "Komondor": 65,
210
+ "Kuvasz": 66,
211
+ "Labrador Retriever": 67,
212
+ "Lakeland Terrier": 68,
213
+ "Leonberg": 69,
214
+ "Lhasa": 70,
215
+ "Malamute": 71,
216
+ "Malinois": 72,
217
+ "Maltese Dog": 73,
218
+ "Mexican Hairless": 74,
219
+ "Miniature Pinscher": 75,
220
+ "Miniature Poodle": 76,
221
+ "Miniature Schnauzer": 77,
222
+ "Newfoundland": 78,
223
+ "Norfolk Terrier": 79,
224
+ "Norwegian Elkhound": 80,
225
+ "Norwich Terrier": 81,
226
+ "Old English Sheepdog": 82,
227
+ "Otterhound": 83,
228
+ "Papillon": 84,
229
+ "Pekinese": 85,
230
+ "Pembroke": 86,
231
+ "Pomeranian": 87,
232
+ "Pug": 88,
233
+ "Redbone": 89,
234
+ "Rhodesian Ridgeback": 90,
235
+ "Rottweiler": 91,
236
+ "Saint Bernard": 92,
237
+ "Saluki": 93,
238
+ "Samoyed": 94,
239
+ "Schipperke": 95,
240
+ "Scotch Terrier": 96,
241
+ "Scottish Deerhound": 97,
242
+ "Sealyham Terrier": 98,
243
+ "Shetland Sheepdog": 99,
244
+ "Shih Tzu": 100,
245
+ "Siberian Husky": 101,
246
+ "Silky Terrier": 102,
247
+ "Soft Coated Wheaten Terrier": 103,
248
+ "Staffordshire Bullterrier": 104,
249
+ "Standard Poodle": 105,
250
+ "Standard Schnauzer": 106,
251
+ "Sussex Spaniel": 107,
252
+ "Tibetan Mastiff": 108,
253
+ "Tibetan Terrier": 109,
254
+ "Toy Poodle": 110,
255
+ "Toy Terrier": 111,
256
+ "Vizsla": 112,
257
+ "Walker Hound": 113,
258
+ "Weimaraner": 114,
259
+ "Welsh Springer Spaniel": 115,
260
+ "West Highland White Terrier": 116,
261
+ "Whippet": 117,
262
+ "Wire Haired Fox Terrier": 118,
263
+ "Yorkshire Terrier": 119
264
+ },
265
+ "layer_norm_eps": 1e-05,
266
+ "mlp_ratio": 4.0,
267
+ "model_type": "swinv2",
268
+ "num_channels": 3,
269
+ "num_heads": [
270
+ 4,
271
+ 8,
272
+ 16,
273
+ 32
274
+ ],
275
+ "num_layers": 4,
276
+ "out_features": [
277
+ "stage4"
278
+ ],
279
+ "out_indices": [
280
+ 4
281
+ ],
282
+ "patch_size": 4,
283
+ "path_norm": true,
284
+ "pretrained_window_sizes": [
285
+ 0,
286
+ 0,
287
+ 0,
288
+ 0
289
+ ],
290
+ "problem_type": "single_label_classification",
291
+ "qkv_bias": true,
292
+ "stage_names": [
293
+ "stem",
294
+ "stage1",
295
+ "stage2",
296
+ "stage3",
297
+ "stage4"
298
+ ],
299
+ "torch_dtype": "float32",
300
+ "transformers_version": "4.40.2",
301
+ "use_absolute_embeddings": false,
302
+ "window_size": 16
303
+ }
config.toml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [training_args]
2
+ output_dir="/Users/andrewmayes/Openclassroom/CanineNet/code/"
3
+ evaluation_strategy="steps"
4
+ save_strategy="steps"
5
+ learning_rate=5e-5
6
+ #per_device_train_batch_size=32 # 512
7
+ #per_device_eval_batch_size=32 # 512
8
+ # num_train_epochs=5,
9
+ eval_delay=0 # 50
10
+ eval_steps=0.01
11
+ #eval_accumulation_steps
12
+ gradient_accumulation_steps=4
13
+ gradient_checkpointing=true
14
+ optim="adafactor"
15
+ max_steps=1000 # 100
16
+ #logging_dir=""
17
+ #log_level="error"
18
+ load_best_model_at_end=true
19
+ metric_for_best_model="f1"
20
+ greater_is_better=true
21
+ #use_mps_device=true
22
+ logging_steps=0.01
23
+ save_steps=0.01
24
+ #auto_find_batch_size=true
25
+ report_to="mlflow"
26
+ save_total_limit=2
27
+ #hub_model_id="amaye15/SwinV2-Base-Document-Classifier"
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bfa4144ca46aaf446d9067e8b320acfe81723f93d0a51fab6f4f3824a1306e92
3
+ size 348129336
preprocessor_config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_valid_processor_keys": [
3
+ "images",
4
+ "do_resize",
5
+ "size",
6
+ "resample",
7
+ "do_rescale",
8
+ "rescale_factor",
9
+ "do_normalize",
10
+ "image_mean",
11
+ "image_std",
12
+ "return_tensors",
13
+ "data_format",
14
+ "input_data_format"
15
+ ],
16
+ "do_normalize": true,
17
+ "do_rescale": true,
18
+ "do_resize": true,
19
+ "image_mean": [
20
+ 0.485,
21
+ 0.456,
22
+ 0.406
23
+ ],
24
+ "image_processor_type": "ViTImageProcessor",
25
+ "image_std": [
26
+ 0.229,
27
+ 0.224,
28
+ 0.225
29
+ ],
30
+ "resample": 3,
31
+ "rescale_factor": 0.00392156862745098,
32
+ "size": {
33
+ "height": 256,
34
+ "width": 256
35
+ }
36
+ }
train.ipynb ADDED
@@ -0,0 +1,2084 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Install"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "%pip install uv"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "!uv pip install dagshub setuptools accelerate toml torch torchvision transformers mlflow datasets ipywidgets python-dotenv evaluate"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "markdown",
30
+ "metadata": {},
31
+ "source": [
32
+ "# Setup"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 1,
38
+ "metadata": {},
39
+ "outputs": [
40
+ {
41
+ "data": {
42
+ "text/html": [
43
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Initialized MLflow to track repo <span style=\"color: #008000; text-decoration-color: #008000\">\"amaye15/CanineNet\"</span>\n",
44
+ "</pre>\n"
45
+ ],
46
+ "text/plain": [
47
+ "Initialized MLflow to track repo \u001b[32m\"amaye15/CanineNet\"\u001b[0m\n"
48
+ ]
49
+ },
50
+ "metadata": {},
51
+ "output_type": "display_data"
52
+ },
53
+ {
54
+ "data": {
55
+ "text/html": [
56
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Repository amaye15/CanineNet initialized!\n",
57
+ "</pre>\n"
58
+ ],
59
+ "text/plain": [
60
+ "Repository amaye15/CanineNet initialized!\n"
61
+ ]
62
+ },
63
+ "metadata": {},
64
+ "output_type": "display_data"
65
+ }
66
+ ],
67
+ "source": [
68
+ "import os\n",
69
+ "import toml\n",
70
+ "import torch\n",
71
+ "import mlflow\n",
72
+ "import dagshub\n",
73
+ "import datasets\n",
74
+ "import evaluate\n",
75
+ "from dotenv import load_dotenv\n",
76
+ "from torchvision.transforms import v2\n",
77
+ "from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer\n",
78
+ "\n",
79
+ "ENV_PATH = \"/Users/andrewmayes/Openclassroom/CanineNet/.env\"\n",
80
+ "CONFIG_PATH = \"/Users/andrewmayes/Openclassroom/CanineNet/code/config.toml\"\n",
81
+ "CONFIG = toml.load(CONFIG_PATH)\n",
82
+ "\n",
83
+ "load_dotenv(ENV_PATH)\n",
84
+ "\n",
85
+ "dagshub.init(repo_name=os.environ['MLFLOW_TRACKING_PROJECTNAME'], repo_owner=os.environ['MLFLOW_TRACKING_USERNAME'], mlflow=True, dvc=True)\n",
86
+ "\n",
87
+ "os.environ['MLFLOW_TRACKING_USERNAME'] = \"amaye15\"\n",
88
+ "\n",
89
+ "mlflow.set_tracking_uri(f'https://dagshub.com/' + os.environ['MLFLOW_TRACKING_USERNAME']\n",
90
+ " + '/' + os.environ['MLFLOW_TRACKING_PROJECTNAME'] + '.mlflow')\n",
91
+ "\n",
92
+ "CREATE_DATASET = True\n",
93
+ "ORIGINAL_DATASET = \"Alanox/stanford-dogs\"\n",
94
+ "MODIFIED_DATASET = \"amaye15/stanford-dogs\"\n",
95
+ "REMOVE_COLUMNS = [\"name\", \"annotations\"]\n",
96
+ "RENAME_COLUMNS = {\"image\":\"pixel_values\", \"target\":\"label\"}\n",
97
+ "SPLIT = 0.2\n",
98
+ "\n",
99
+ "METRICS = [\"accuracy\", \"f1\", \"precision\", \"recall\"]\n",
100
+ "# MODELS = 'google/vit-base-patch16-224'\n",
101
+ "# MODELS = \"google/siglip-base-patch16-224\"\n",
102
+ "\n"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "markdown",
107
+ "metadata": {},
108
+ "source": [
109
+ "# Dataset"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": 2,
115
+ "metadata": {},
116
+ "outputs": [
117
+ {
118
+ "name": "stdout",
119
+ "output_type": "stream",
120
+ "text": [
121
+ "Affenpinscher: 0\n",
122
+ "Afghan Hound: 1\n",
123
+ "African Hunting Dog: 2\n",
124
+ "Airedale: 3\n",
125
+ "American Staffordshire Terrier: 4\n",
126
+ "Appenzeller: 5\n",
127
+ "Australian Terrier: 6\n",
128
+ "Basenji: 7\n",
129
+ "Basset: 8\n",
130
+ "Beagle: 9\n",
131
+ "Bedlington Terrier: 10\n",
132
+ "Bernese Mountain Dog: 11\n",
133
+ "Black And Tan Coonhound: 12\n",
134
+ "Blenheim Spaniel: 13\n",
135
+ "Bloodhound: 14\n",
136
+ "Bluetick: 15\n",
137
+ "Border Collie: 16\n",
138
+ "Border Terrier: 17\n",
139
+ "Borzoi: 18\n",
140
+ "Boston Bull: 19\n",
141
+ "Bouvier Des Flandres: 20\n",
142
+ "Boxer: 21\n",
143
+ "Brabancon Griffon: 22\n",
144
+ "Briard: 23\n",
145
+ "Brittany Spaniel: 24\n",
146
+ "Bull Mastiff: 25\n",
147
+ "Cairn: 26\n",
148
+ "Cardigan: 27\n",
149
+ "Chesapeake Bay Retriever: 28\n",
150
+ "Chihuahua: 29\n",
151
+ "Chow: 30\n",
152
+ "Clumber: 31\n",
153
+ "Cocker Spaniel: 32\n",
154
+ "Collie: 33\n",
155
+ "Curly Coated Retriever: 34\n",
156
+ "Dandie Dinmont: 35\n",
157
+ "Dhole: 36\n",
158
+ "Dingo: 37\n",
159
+ "Doberman: 38\n",
160
+ "English Foxhound: 39\n",
161
+ "English Setter: 40\n",
162
+ "English Springer: 41\n",
163
+ "Entlebucher: 42\n",
164
+ "Eskimo Dog: 43\n",
165
+ "Flat Coated Retriever: 44\n",
166
+ "French Bulldog: 45\n",
167
+ "German Shepherd: 46\n",
168
+ "German Short Haired Pointer: 47\n",
169
+ "Giant Schnauzer: 48\n",
170
+ "Golden Retriever: 49\n",
171
+ "Gordon Setter: 50\n",
172
+ "Great Dane: 51\n",
173
+ "Great Pyrenees: 52\n",
174
+ "Greater Swiss Mountain Dog: 53\n",
175
+ "Groenendael: 54\n",
176
+ "Ibizan Hound: 55\n",
177
+ "Irish Setter: 56\n",
178
+ "Irish Terrier: 57\n",
179
+ "Irish Water Spaniel: 58\n",
180
+ "Irish Wolfhound: 59\n",
181
+ "Italian Greyhound: 60\n",
182
+ "Japanese Spaniel: 61\n",
183
+ "Keeshond: 62\n",
184
+ "Kelpie: 63\n",
185
+ "Kerry Blue Terrier: 64\n",
186
+ "Komondor: 65\n",
187
+ "Kuvasz: 66\n",
188
+ "Labrador Retriever: 67\n",
189
+ "Lakeland Terrier: 68\n",
190
+ "Leonberg: 69\n",
191
+ "Lhasa: 70\n",
192
+ "Malamute: 71\n",
193
+ "Malinois: 72\n",
194
+ "Maltese Dog: 73\n",
195
+ "Mexican Hairless: 74\n",
196
+ "Miniature Pinscher: 75\n",
197
+ "Miniature Poodle: 76\n",
198
+ "Miniature Schnauzer: 77\n",
199
+ "Newfoundland: 78\n",
200
+ "Norfolk Terrier: 79\n",
201
+ "Norwegian Elkhound: 80\n",
202
+ "Norwich Terrier: 81\n",
203
+ "Old English Sheepdog: 82\n",
204
+ "Otterhound: 83\n",
205
+ "Papillon: 84\n",
206
+ "Pekinese: 85\n",
207
+ "Pembroke: 86\n",
208
+ "Pomeranian: 87\n",
209
+ "Pug: 88\n",
210
+ "Redbone: 89\n",
211
+ "Rhodesian Ridgeback: 90\n",
212
+ "Rottweiler: 91\n",
213
+ "Saint Bernard: 92\n",
214
+ "Saluki: 93\n",
215
+ "Samoyed: 94\n",
216
+ "Schipperke: 95\n",
217
+ "Scotch Terrier: 96\n",
218
+ "Scottish Deerhound: 97\n",
219
+ "Sealyham Terrier: 98\n",
220
+ "Shetland Sheepdog: 99\n",
221
+ "Shih Tzu: 100\n",
222
+ "Siberian Husky: 101\n",
223
+ "Silky Terrier: 102\n",
224
+ "Soft Coated Wheaten Terrier: 103\n",
225
+ "Staffordshire Bullterrier: 104\n",
226
+ "Standard Poodle: 105\n",
227
+ "Standard Schnauzer: 106\n",
228
+ "Sussex Spaniel: 107\n",
229
+ "Tibetan Mastiff: 108\n",
230
+ "Tibetan Terrier: 109\n",
231
+ "Toy Poodle: 110\n",
232
+ "Toy Terrier: 111\n",
233
+ "Vizsla: 112\n",
234
+ "Walker Hound: 113\n",
235
+ "Weimaraner: 114\n",
236
+ "Welsh Springer Spaniel: 115\n",
237
+ "West Highland White Terrier: 116\n",
238
+ "Whippet: 117\n",
239
+ "Wire Haired Fox Terrier: 118\n",
240
+ "Yorkshire Terrier: 119\n"
241
+ ]
242
+ }
243
+ ],
244
+ "source": [
245
+ "if CREATE_DATASET:\n",
246
+ " ds = datasets.load_dataset(ORIGINAL_DATASET, token=os.getenv(\"HF_TOKEN\"), split=\"full\", trust_remote_code=True)\n",
247
+ " ds = ds.remove_columns(REMOVE_COLUMNS).rename_columns(RENAME_COLUMNS)\n",
248
+ "\n",
249
+ " labels = ds.select_columns(\"label\").to_pandas().sort_values(\"label\").get(\"label\").unique().tolist()\n",
250
+ " numbers = range(len(labels))\n",
251
+ " label2int = dict(zip(labels, numbers))\n",
252
+ " int2label = dict(zip(numbers, labels))\n",
253
+ "\n",
254
+ " for key, val in label2int.items():\n",
255
+ " print(f\"{key}: {val}\")\n",
256
+ "\n",
257
+ " ds = ds.class_encode_column(\"label\")\n",
258
+ " ds = ds.align_labels_with_mapping(label2int, \"label\")\n",
259
+ "\n",
260
+ " ds = ds.train_test_split(test_size=SPLIT, stratify_by_column = \"label\")\n",
261
+ " #ds.push_to_hub(MODIFIED_DATASET, token=os.getenv(\"HF_TOKEN\"))\n",
262
+ "\n",
263
+ " CONFIG[\"label2int\"] = str(label2int)\n",
264
+ " CONFIG[\"int2label\"] = str(int2label)\n",
265
+ "\n",
266
+ " # with open(\"output.toml\", \"w\") as toml_file:\n",
267
+ " # toml.dump(toml.dumps(CONFIG), toml_file)\n",
268
+ "\n",
269
+ " #ds = datasets.load_dataset(MODIFIED_DATASET, token=os.getenv(\"HF_TOKEN\"), trust_remote_code=True, streaming=True)"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "code",
274
+ "execution_count": 3,
275
+ "metadata": {},
276
+ "outputs": [
277
+ {
278
+ "name": "stderr",
279
+ "output_type": "stream",
280
+ "text": [
281
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
282
+ " warnings.warn(\n",
283
+ "Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration. Please open a PR/issue to update `preprocessor_config.json` to use `image_processor_type` instead of `feature_extractor_type`. This warning will be removed in v4.40.\n",
284
+ "Some weights of Swinv2ForImageClassification were not initialized from the model checkpoint at microsoft/swinv2-base-patch4-window16-256 and are newly initialized because the shapes did not match:\n",
285
+ "- classifier.weight: found shape torch.Size([1000, 1024]) in the checkpoint and torch.Size([120, 1024]) in the model instantiated\n",
286
+ "- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([120]) in the model instantiated\n",
287
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
288
+ "max_steps is given, it will override any value given in num_train_epochs\n"
289
+ ]
290
+ },
291
+ {
292
+ "data": {
293
+ "application/vnd.jupyter.widget-view+json": {
294
+ "model_id": "a476611547454ef08359b18f9122e993",
295
+ "version_major": 2,
296
+ "version_minor": 0
297
+ },
298
+ "text/plain": [
299
+ " 0%| | 0/1000 [00:00<?, ?it/s]"
300
+ ]
301
+ },
302
+ "metadata": {},
303
+ "output_type": "display_data"
304
+ },
305
+ {
306
+ "name": "stderr",
307
+ "output_type": "stream",
308
+ "text": [
309
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
310
+ " warnings.warn(\n"
311
+ ]
312
+ },
313
+ {
314
+ "name": "stdout",
315
+ "output_type": "stream",
316
+ "text": [
317
+ "{'loss': 4.7518, 'grad_norm': 9.390625953674316, 'learning_rate': 4.9500000000000004e-05, 'epoch': 0.08}\n"
318
+ ]
319
+ },
320
+ {
321
+ "data": {
322
+ "application/vnd.jupyter.widget-view+json": {
323
+ "model_id": "8790a7299dba41f08047eb867a620e40",
324
+ "version_major": 2,
325
+ "version_minor": 0
326
+ },
327
+ "text/plain": [
328
+ " 0%| | 0/129 [00:00<?, ?it/s]"
329
+ ]
330
+ },
331
+ "metadata": {},
332
+ "output_type": "display_data"
333
+ },
334
+ {
335
+ "name": "stderr",
336
+ "output_type": "stream",
337
+ "text": [
338
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
339
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
340
+ ]
341
+ },
342
+ {
343
+ "name": "stdout",
344
+ "output_type": "stream",
345
+ "text": [
346
+ "{'eval_loss': 4.639115810394287, 'eval_accuracy': 0.07410106899902819, 'eval_f1': 0.053338767481424125, 'eval_precision': 0.06668646345851367, 'eval_recall': 0.07050091579227699, 'eval_runtime': 122.7601, 'eval_samples_per_second': 33.529, 'eval_steps_per_second': 1.051, 'epoch': 0.08}\n"
347
+ ]
348
+ },
349
+ {
350
+ "name": "stderr",
351
+ "output_type": "stream",
352
+ "text": [
353
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
354
+ " warnings.warn(\n"
355
+ ]
356
+ },
357
+ {
358
+ "name": "stdout",
359
+ "output_type": "stream",
360
+ "text": [
361
+ "{'loss': 4.5585, 'grad_norm': 16.583328247070312, 'learning_rate': 4.9e-05, 'epoch': 0.16}\n"
362
+ ]
363
+ },
364
+ {
365
+ "data": {
366
+ "application/vnd.jupyter.widget-view+json": {
367
+ "model_id": "2d1989e313e24c889ce12f06eb6387c4",
368
+ "version_major": 2,
369
+ "version_minor": 0
370
+ },
371
+ "text/plain": [
372
+ " 0%| | 0/129 [00:00<?, ?it/s]"
373
+ ]
374
+ },
375
+ "metadata": {},
376
+ "output_type": "display_data"
377
+ },
378
+ {
379
+ "name": "stderr",
380
+ "output_type": "stream",
381
+ "text": [
382
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
383
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
384
+ ]
385
+ },
386
+ {
387
+ "name": "stdout",
388
+ "output_type": "stream",
389
+ "text": [
390
+ "{'eval_loss': 4.346254348754883, 'eval_accuracy': 0.19193391642371235, 'eval_f1': 0.14445541276201207, 'eval_precision': 0.18999966240710864, 'eval_recall': 0.1794198274691427, 'eval_runtime': 125.4121, 'eval_samples_per_second': 32.82, 'eval_steps_per_second': 1.029, 'epoch': 0.16}\n"
391
+ ]
392
+ },
393
+ {
394
+ "name": "stderr",
395
+ "output_type": "stream",
396
+ "text": [
397
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
398
+ " warnings.warn(\n"
399
+ ]
400
+ },
401
+ {
402
+ "name": "stdout",
403
+ "output_type": "stream",
404
+ "text": [
405
+ "{'loss': 4.2377, 'grad_norm': 26.891368865966797, 'learning_rate': 4.85e-05, 'epoch': 0.23}\n"
406
+ ]
407
+ },
408
+ {
409
+ "data": {
410
+ "application/vnd.jupyter.widget-view+json": {
411
+ "model_id": "1670e0352d4a4e42b84d6e43726fefdf",
412
+ "version_major": 2,
413
+ "version_minor": 0
414
+ },
415
+ "text/plain": [
416
+ " 0%| | 0/129 [00:00<?, ?it/s]"
417
+ ]
418
+ },
419
+ "metadata": {},
420
+ "output_type": "display_data"
421
+ },
422
+ {
423
+ "name": "stderr",
424
+ "output_type": "stream",
425
+ "text": [
426
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
427
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
428
+ ]
429
+ },
430
+ {
431
+ "name": "stdout",
432
+ "output_type": "stream",
433
+ "text": [
434
+ "{'eval_loss': 3.8243072032928467, 'eval_accuracy': 0.3525267249757046, 'eval_f1': 0.31002515422438126, 'eval_precision': 0.41537451879077536, 'eval_recall': 0.33815280861269076, 'eval_runtime': 121.7465, 'eval_samples_per_second': 33.808, 'eval_steps_per_second': 1.06, 'epoch': 0.23}\n"
435
+ ]
436
+ },
437
+ {
438
+ "name": "stderr",
439
+ "output_type": "stream",
440
+ "text": [
441
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
442
+ " warnings.warn(\n"
443
+ ]
444
+ },
445
+ {
446
+ "name": "stdout",
447
+ "output_type": "stream",
448
+ "text": [
449
+ "{'loss': 3.6654, 'grad_norm': 36.484336853027344, 'learning_rate': 4.8e-05, 'epoch': 0.31}\n"
450
+ ]
451
+ },
452
+ {
453
+ "data": {
454
+ "application/vnd.jupyter.widget-view+json": {
455
+ "model_id": "f6ab5c20d53e4f9b961fda0faf7ed68e",
456
+ "version_major": 2,
457
+ "version_minor": 0
458
+ },
459
+ "text/plain": [
460
+ " 0%| | 0/129 [00:00<?, ?it/s]"
461
+ ]
462
+ },
463
+ "metadata": {},
464
+ "output_type": "display_data"
465
+ },
466
+ {
467
+ "name": "stderr",
468
+ "output_type": "stream",
469
+ "text": [
470
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
471
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
472
+ ]
473
+ },
474
+ {
475
+ "name": "stdout",
476
+ "output_type": "stream",
477
+ "text": [
478
+ "{'eval_loss': 2.9276485443115234, 'eval_accuracy': 0.640913508260447, 'eval_f1': 0.6111447372819326, 'eval_precision': 0.6994107673804264, 'eval_recall': 0.6299850736069784, 'eval_runtime': 119.9978, 'eval_samples_per_second': 34.301, 'eval_steps_per_second': 1.075, 'epoch': 0.31}\n"
479
+ ]
480
+ },
481
+ {
482
+ "name": "stderr",
483
+ "output_type": "stream",
484
+ "text": [
485
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
486
+ " warnings.warn(\n"
487
+ ]
488
+ },
489
+ {
490
+ "name": "stdout",
491
+ "output_type": "stream",
492
+ "text": [
493
+ "{'loss': 2.7617, 'grad_norm': 39.120765686035156, 'learning_rate': 4.75e-05, 'epoch': 0.39}\n"
494
+ ]
495
+ },
496
+ {
497
+ "data": {
498
+ "application/vnd.jupyter.widget-view+json": {
499
+ "model_id": "94a2240afe624a2ebe8649ce43b156bc",
500
+ "version_major": 2,
501
+ "version_minor": 0
502
+ },
503
+ "text/plain": [
504
+ " 0%| | 0/129 [00:00<?, ?it/s]"
505
+ ]
506
+ },
507
+ "metadata": {},
508
+ "output_type": "display_data"
509
+ },
510
+ {
511
+ "name": "stderr",
512
+ "output_type": "stream",
513
+ "text": [
514
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
515
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
516
+ ]
517
+ },
518
+ {
519
+ "name": "stdout",
520
+ "output_type": "stream",
521
+ "text": [
522
+ "{'eval_loss': 1.7703156471252441, 'eval_accuracy': 0.8248299319727891, 'eval_f1': 0.8041716661861199, 'eval_precision': 0.8361251271357693, 'eval_recall': 0.8181542323069564, 'eval_runtime': 118.2837, 'eval_samples_per_second': 34.798, 'eval_steps_per_second': 1.091, 'epoch': 0.39}\n"
523
+ ]
524
+ },
525
+ {
526
+ "name": "stderr",
527
+ "output_type": "stream",
528
+ "text": [
529
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
530
+ " warnings.warn(\n"
531
+ ]
532
+ },
533
+ {
534
+ "name": "stdout",
535
+ "output_type": "stream",
536
+ "text": [
537
+ "{'loss': 1.9475, 'grad_norm': 39.01482009887695, 'learning_rate': 4.7e-05, 'epoch': 0.47}\n"
538
+ ]
539
+ },
540
+ {
541
+ "data": {
542
+ "application/vnd.jupyter.widget-view+json": {
543
+ "model_id": "0ca6adbf62c74f90840087469c261dcb",
544
+ "version_major": 2,
545
+ "version_minor": 0
546
+ },
547
+ "text/plain": [
548
+ " 0%| | 0/129 [00:00<?, ?it/s]"
549
+ ]
550
+ },
551
+ "metadata": {},
552
+ "output_type": "display_data"
553
+ },
554
+ {
555
+ "name": "stderr",
556
+ "output_type": "stream",
557
+ "text": [
558
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
559
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
560
+ ]
561
+ },
562
+ {
563
+ "name": "stdout",
564
+ "output_type": "stream",
565
+ "text": [
566
+ "{'eval_loss': 1.043952465057373, 'eval_accuracy': 0.8862973760932945, 'eval_f1': 0.8780889702421721, 'eval_precision': 0.8924114337387474, 'eval_recall': 0.882146610947727, 'eval_runtime': 116.9676, 'eval_samples_per_second': 35.189, 'eval_steps_per_second': 1.103, 'epoch': 0.47}\n"
567
+ ]
568
+ },
569
+ {
570
+ "name": "stderr",
571
+ "output_type": "stream",
572
+ "text": [
573
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
574
+ " warnings.warn(\n"
575
+ ]
576
+ },
577
+ {
578
+ "name": "stdout",
579
+ "output_type": "stream",
580
+ "text": [
581
+ "{'loss': 1.3629, 'grad_norm': 43.029022216796875, 'learning_rate': 4.6500000000000005e-05, 'epoch': 0.54}\n"
582
+ ]
583
+ },
584
+ {
585
+ "data": {
586
+ "application/vnd.jupyter.widget-view+json": {
587
+ "model_id": "af9d01f55c914ca0bab82f3ae3941567",
588
+ "version_major": 2,
589
+ "version_minor": 0
590
+ },
591
+ "text/plain": [
592
+ " 0%| | 0/129 [00:00<?, ?it/s]"
593
+ ]
594
+ },
595
+ "metadata": {},
596
+ "output_type": "display_data"
597
+ },
598
+ {
599
+ "name": "stdout",
600
+ "output_type": "stream",
601
+ "text": [
602
+ "{'eval_loss': 0.6490112543106079, 'eval_accuracy': 0.9098639455782312, 'eval_f1': 0.9031378990742644, 'eval_precision': 0.9191206744821958, 'eval_recall': 0.9061687520491638, 'eval_runtime': 117.6616, 'eval_samples_per_second': 34.982, 'eval_steps_per_second': 1.096, 'epoch': 0.54}\n"
603
+ ]
604
+ },
605
+ {
606
+ "name": "stderr",
607
+ "output_type": "stream",
608
+ "text": [
609
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
610
+ " warnings.warn(\n"
611
+ ]
612
+ },
613
+ {
614
+ "name": "stdout",
615
+ "output_type": "stream",
616
+ "text": [
617
+ "{'loss': 1.0488, 'grad_norm': 39.756900787353516, 'learning_rate': 4.600000000000001e-05, 'epoch': 0.62}\n"
618
+ ]
619
+ },
620
+ {
621
+ "data": {
622
+ "application/vnd.jupyter.widget-view+json": {
623
+ "model_id": "15b72c0511c34c4892b22af3d3aa30f0",
624
+ "version_major": 2,
625
+ "version_minor": 0
626
+ },
627
+ "text/plain": [
628
+ " 0%| | 0/129 [00:00<?, ?it/s]"
629
+ ]
630
+ },
631
+ "metadata": {},
632
+ "output_type": "display_data"
633
+ },
634
+ {
635
+ "name": "stderr",
636
+ "output_type": "stream",
637
+ "text": [
638
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
639
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
640
+ ]
641
+ },
642
+ {
643
+ "name": "stdout",
644
+ "output_type": "stream",
645
+ "text": [
646
+ "{'eval_loss': 0.4484623670578003, 'eval_accuracy': 0.9149659863945578, 'eval_f1': 0.9074605144391643, 'eval_precision': 0.9146605032717635, 'eval_recall': 0.9118267161714619, 'eval_runtime': 119.5898, 'eval_samples_per_second': 34.418, 'eval_steps_per_second': 1.079, 'epoch': 0.62}\n"
647
+ ]
648
+ },
649
+ {
650
+ "name": "stderr",
651
+ "output_type": "stream",
652
+ "text": [
653
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
654
+ " warnings.warn(\n"
655
+ ]
656
+ },
657
+ {
658
+ "name": "stdout",
659
+ "output_type": "stream",
660
+ "text": [
661
+ "{'loss': 0.8477, 'grad_norm': 33.365509033203125, 'learning_rate': 4.55e-05, 'epoch': 0.7}\n"
662
+ ]
663
+ },
664
+ {
665
+ "data": {
666
+ "application/vnd.jupyter.widget-view+json": {
667
+ "model_id": "113e272ba5454f859e9cef987ca0c638",
668
+ "version_major": 2,
669
+ "version_minor": 0
670
+ },
671
+ "text/plain": [
672
+ " 0%| | 0/129 [00:00<?, ?it/s]"
673
+ ]
674
+ },
675
+ "metadata": {},
676
+ "output_type": "display_data"
677
+ },
678
+ {
679
+ "name": "stdout",
680
+ "output_type": "stream",
681
+ "text": [
682
+ "{'eval_loss': 0.3743549585342407, 'eval_accuracy': 0.9205539358600583, 'eval_f1': 0.9169219807550186, 'eval_precision': 0.9293988137338194, 'eval_recall': 0.9189619888825881, 'eval_runtime': 120.9735, 'eval_samples_per_second': 34.024, 'eval_steps_per_second': 1.066, 'epoch': 0.7}\n"
683
+ ]
684
+ },
685
+ {
686
+ "name": "stderr",
687
+ "output_type": "stream",
688
+ "text": [
689
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
690
+ " warnings.warn(\n"
691
+ ]
692
+ },
693
+ {
694
+ "name": "stdout",
695
+ "output_type": "stream",
696
+ "text": [
697
+ "{'loss': 0.7184, 'grad_norm': 34.209716796875, 'learning_rate': 4.5e-05, 'epoch': 0.78}\n"
698
+ ]
699
+ },
700
+ {
701
+ "data": {
702
+ "application/vnd.jupyter.widget-view+json": {
703
+ "model_id": "627a14cb020c4a77ba73f92cdfcaa8a2",
704
+ "version_major": 2,
705
+ "version_minor": 0
706
+ },
707
+ "text/plain": [
708
+ " 0%| | 0/129 [00:00<?, ?it/s]"
709
+ ]
710
+ },
711
+ "metadata": {},
712
+ "output_type": "display_data"
713
+ },
714
+ {
715
+ "name": "stdout",
716
+ "output_type": "stream",
717
+ "text": [
718
+ "{'eval_loss': 0.3300754427909851, 'eval_accuracy': 0.9258989310009719, 'eval_f1': 0.9215361801462706, 'eval_precision': 0.9282778107282813, 'eval_recall': 0.9227499673950158, 'eval_runtime': 120.3629, 'eval_samples_per_second': 34.197, 'eval_steps_per_second': 1.072, 'epoch': 0.78}\n"
719
+ ]
720
+ },
721
+ {
722
+ "name": "stderr",
723
+ "output_type": "stream",
724
+ "text": [
725
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
726
+ " warnings.warn(\n"
727
+ ]
728
+ },
729
+ {
730
+ "name": "stdout",
731
+ "output_type": "stream",
732
+ "text": [
733
+ "{'loss': 0.7149, 'grad_norm': 30.539337158203125, 'learning_rate': 4.4500000000000004e-05, 'epoch': 0.85}\n"
734
+ ]
735
+ },
736
+ {
737
+ "data": {
738
+ "application/vnd.jupyter.widget-view+json": {
739
+ "model_id": "f2f8b8fc76bb4cd9aa79a3c2e1762716",
740
+ "version_major": 2,
741
+ "version_minor": 0
742
+ },
743
+ "text/plain": [
744
+ " 0%| | 0/129 [00:00<?, ?it/s]"
745
+ ]
746
+ },
747
+ "metadata": {},
748
+ "output_type": "display_data"
749
+ },
750
+ {
751
+ "name": "stdout",
752
+ "output_type": "stream",
753
+ "text": [
754
+ "{'eval_loss': 0.2970269024372101, 'eval_accuracy': 0.9186103012633625, 'eval_f1': 0.9152438416028247, 'eval_precision': 0.9226596717426415, 'eval_recall': 0.9156450507816004, 'eval_runtime': 118.1418, 'eval_samples_per_second': 34.839, 'eval_steps_per_second': 1.092, 'epoch': 0.85}\n"
755
+ ]
756
+ },
757
+ {
758
+ "name": "stderr",
759
+ "output_type": "stream",
760
+ "text": [
761
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
762
+ " warnings.warn(\n"
763
+ ]
764
+ },
765
+ {
766
+ "name": "stdout",
767
+ "output_type": "stream",
768
+ "text": [
769
+ "{'loss': 0.6429, 'grad_norm': 66.82299041748047, 'learning_rate': 4.4000000000000006e-05, 'epoch': 0.93}\n"
770
+ ]
771
+ },
772
+ {
773
+ "data": {
774
+ "application/vnd.jupyter.widget-view+json": {
775
+ "model_id": "3b9f408845014c45ac4724ba33a21fee",
776
+ "version_major": 2,
777
+ "version_minor": 0
778
+ },
779
+ "text/plain": [
780
+ " 0%| | 0/129 [00:00<?, ?it/s]"
781
+ ]
782
+ },
783
+ "metadata": {},
784
+ "output_type": "display_data"
785
+ },
786
+ {
787
+ "name": "stdout",
788
+ "output_type": "stream",
789
+ "text": [
790
+ "{'eval_loss': 0.2674591839313507, 'eval_accuracy': 0.9285714285714286, 'eval_f1': 0.9238129060193967, 'eval_precision': 0.9301368936281837, 'eval_recall': 0.9256311870951985, 'eval_runtime': 118.2453, 'eval_samples_per_second': 34.809, 'eval_steps_per_second': 1.091, 'epoch': 0.93}\n"
791
+ ]
792
+ },
793
+ {
794
+ "name": "stderr",
795
+ "output_type": "stream",
796
+ "text": [
797
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
798
+ " warnings.warn(\n"
799
+ ]
800
+ },
801
+ {
802
+ "name": "stdout",
803
+ "output_type": "stream",
804
+ "text": [
805
+ "{'loss': 0.5864, 'grad_norm': 33.499786376953125, 'learning_rate': 4.35e-05, 'epoch': 1.01}\n"
806
+ ]
807
+ },
808
+ {
809
+ "data": {
810
+ "application/vnd.jupyter.widget-view+json": {
811
+ "model_id": "5c5220c50a3542689623f7923b7e19af",
812
+ "version_major": 2,
813
+ "version_minor": 0
814
+ },
815
+ "text/plain": [
816
+ " 0%| | 0/129 [00:00<?, ?it/s]"
817
+ ]
818
+ },
819
+ "metadata": {},
820
+ "output_type": "display_data"
821
+ },
822
+ {
823
+ "name": "stdout",
824
+ "output_type": "stream",
825
+ "text": [
826
+ "{'eval_loss': 0.260906845331192, 'eval_accuracy': 0.9290573372206026, 'eval_f1': 0.9258240260144658, 'eval_precision': 0.9338417051433888, 'eval_recall': 0.9271689104847066, 'eval_runtime': 113.4782, 'eval_samples_per_second': 36.271, 'eval_steps_per_second': 1.137, 'epoch': 1.01}\n"
827
+ ]
828
+ },
829
+ {
830
+ "name": "stderr",
831
+ "output_type": "stream",
832
+ "text": [
833
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
834
+ " warnings.warn(\n"
835
+ ]
836
+ },
837
+ {
838
+ "name": "stdout",
839
+ "output_type": "stream",
840
+ "text": [
841
+ "{'loss': 0.5414, 'grad_norm': 27.897539138793945, 'learning_rate': 4.3e-05, 'epoch': 1.09}\n"
842
+ ]
843
+ },
844
+ {
845
+ "data": {
846
+ "application/vnd.jupyter.widget-view+json": {
847
+ "model_id": "974d2f0a759544b3ae1f20cd0d56e64b",
848
+ "version_major": 2,
849
+ "version_minor": 0
850
+ },
851
+ "text/plain": [
852
+ " 0%| | 0/129 [00:00<?, ?it/s]"
853
+ ]
854
+ },
855
+ "metadata": {},
856
+ "output_type": "display_data"
857
+ },
858
+ {
859
+ "name": "stderr",
860
+ "output_type": "stream",
861
+ "text": [
862
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/sklearn/metrics/_classification.py:1509: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
863
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
864
+ ]
865
+ },
866
+ {
867
+ "name": "stdout",
868
+ "output_type": "stream",
869
+ "text": [
870
+ "{'eval_loss': 0.2644219696521759, 'eval_accuracy': 0.9161807580174927, 'eval_f1': 0.9122389605151222, 'eval_precision': 0.9211654018756885, 'eval_recall': 0.9155822994203213, 'eval_runtime': 113.8144, 'eval_samples_per_second': 36.164, 'eval_steps_per_second': 1.133, 'epoch': 1.09}\n"
871
+ ]
872
+ },
873
+ {
874
+ "name": "stderr",
875
+ "output_type": "stream",
876
+ "text": [
877
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
878
+ " warnings.warn(\n"
879
+ ]
880
+ },
881
+ {
882
+ "name": "stdout",
883
+ "output_type": "stream",
884
+ "text": [
885
+ "{'loss': 0.5323, 'grad_norm': 42.06348419189453, 'learning_rate': 4.25e-05, 'epoch': 1.17}\n"
886
+ ]
887
+ },
888
+ {
889
+ "data": {
890
+ "application/vnd.jupyter.widget-view+json": {
891
+ "model_id": "a9560d3b29354120a668f2a82581251f",
892
+ "version_major": 2,
893
+ "version_minor": 0
894
+ },
895
+ "text/plain": [
896
+ " 0%| | 0/129 [00:00<?, ?it/s]"
897
+ ]
898
+ },
899
+ "metadata": {},
900
+ "output_type": "display_data"
901
+ },
902
+ {
903
+ "name": "stdout",
904
+ "output_type": "stream",
905
+ "text": [
906
+ "{'eval_loss': 0.2454349249601364, 'eval_accuracy': 0.9280855199222546, 'eval_f1': 0.922547904522988, 'eval_precision': 0.9361557799152643, 'eval_recall': 0.9256003853911859, 'eval_runtime': 117.6285, 'eval_samples_per_second': 34.992, 'eval_steps_per_second': 1.097, 'epoch': 1.17}\n"
907
+ ]
908
+ },
909
+ {
910
+ "name": "stderr",
911
+ "output_type": "stream",
912
+ "text": [
913
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
914
+ " warnings.warn(\n"
915
+ ]
916
+ },
917
+ {
918
+ "name": "stdout",
919
+ "output_type": "stream",
920
+ "text": [
921
+ "{'loss': 0.5061, 'grad_norm': 27.641525268554688, 'learning_rate': 4.2e-05, 'epoch': 1.24}\n"
922
+ ]
923
+ },
924
+ {
925
+ "data": {
926
+ "application/vnd.jupyter.widget-view+json": {
927
+ "model_id": "72904f84b88b4f8b957f68895478cc7d",
928
+ "version_major": 2,
929
+ "version_minor": 0
930
+ },
931
+ "text/plain": [
932
+ " 0%| | 0/129 [00:00<?, ?it/s]"
933
+ ]
934
+ },
935
+ "metadata": {},
936
+ "output_type": "display_data"
937
+ },
938
+ {
939
+ "name": "stdout",
940
+ "output_type": "stream",
941
+ "text": [
942
+ "{'eval_loss': 0.2481379508972168, 'eval_accuracy': 0.9268707482993197, 'eval_f1': 0.9234504572819081, 'eval_precision': 0.9308003097280835, 'eval_recall': 0.9250780783392921, 'eval_runtime': 118.0847, 'eval_samples_per_second': 34.856, 'eval_steps_per_second': 1.092, 'epoch': 1.24}\n"
943
+ ]
944
+ },
945
+ {
946
+ "name": "stderr",
947
+ "output_type": "stream",
948
+ "text": [
949
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
950
+ " warnings.warn(\n"
951
+ ]
952
+ },
953
+ {
954
+ "name": "stdout",
955
+ "output_type": "stream",
956
+ "text": [
957
+ "{'loss': 0.5898, 'grad_norm': 29.9394588470459, 'learning_rate': 4.15e-05, 'epoch': 1.32}\n"
958
+ ]
959
+ },
960
+ {
961
+ "data": {
962
+ "application/vnd.jupyter.widget-view+json": {
963
+ "model_id": "058fe79d2fe24cf3ad6a5399acaaa419",
964
+ "version_major": 2,
965
+ "version_minor": 0
966
+ },
967
+ "text/plain": [
968
+ " 0%| | 0/129 [00:00<?, ?it/s]"
969
+ ]
970
+ },
971
+ "metadata": {},
972
+ "output_type": "display_data"
973
+ },
974
+ {
975
+ "name": "stdout",
976
+ "output_type": "stream",
977
+ "text": [
978
+ "{'eval_loss': 0.23057658970355988, 'eval_accuracy': 0.934645286686103, 'eval_f1': 0.9324039888619194, 'eval_precision': 0.9388904263022082, 'eval_recall': 0.9331000934899664, 'eval_runtime': 118.2346, 'eval_samples_per_second': 34.812, 'eval_steps_per_second': 1.091, 'epoch': 1.32}\n"
979
+ ]
980
+ },
981
+ {
982
+ "name": "stderr",
983
+ "output_type": "stream",
984
+ "text": [
985
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
986
+ " warnings.warn(\n"
987
+ ]
988
+ },
989
+ {
990
+ "name": "stdout",
991
+ "output_type": "stream",
992
+ "text": [
993
+ "{'loss': 0.5277, 'grad_norm': 22.529125213623047, 'learning_rate': 4.1e-05, 'epoch': 1.4}\n"
994
+ ]
995
+ },
996
+ {
997
+ "data": {
998
+ "application/vnd.jupyter.widget-view+json": {
999
+ "model_id": "ef3f81cfed384185ade86ccb13896f17",
1000
+ "version_major": 2,
1001
+ "version_minor": 0
1002
+ },
1003
+ "text/plain": [
1004
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1005
+ ]
1006
+ },
1007
+ "metadata": {},
1008
+ "output_type": "display_data"
1009
+ },
1010
+ {
1011
+ "name": "stdout",
1012
+ "output_type": "stream",
1013
+ "text": [
1014
+ "{'eval_loss': 0.21916857361793518, 'eval_accuracy': 0.9368318756073858, 'eval_f1': 0.9326511251722105, 'eval_precision': 0.9383714959002939, 'eval_recall': 0.9349716395661107, 'eval_runtime': 118.7182, 'eval_samples_per_second': 34.67, 'eval_steps_per_second': 1.087, 'epoch': 1.4}\n"
1015
+ ]
1016
+ },
1017
+ {
1018
+ "name": "stderr",
1019
+ "output_type": "stream",
1020
+ "text": [
1021
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1022
+ " warnings.warn(\n"
1023
+ ]
1024
+ },
1025
+ {
1026
+ "name": "stdout",
1027
+ "output_type": "stream",
1028
+ "text": [
1029
+ "{'loss': 0.4824, 'grad_norm': 36.2570915222168, 'learning_rate': 4.05e-05, 'epoch': 1.48}\n"
1030
+ ]
1031
+ },
1032
+ {
1033
+ "data": {
1034
+ "application/vnd.jupyter.widget-view+json": {
1035
+ "model_id": "089d52cc63444a2c8cde2124e45caa15",
1036
+ "version_major": 2,
1037
+ "version_minor": 0
1038
+ },
1039
+ "text/plain": [
1040
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1041
+ ]
1042
+ },
1043
+ "metadata": {},
1044
+ "output_type": "display_data"
1045
+ },
1046
+ {
1047
+ "name": "stdout",
1048
+ "output_type": "stream",
1049
+ "text": [
1050
+ "{'eval_loss': 0.21711421012878418, 'eval_accuracy': 0.9336734693877551, 'eval_f1': 0.929709056352932, 'eval_precision': 0.9374557958848866, 'eval_recall': 0.9311224774812498, 'eval_runtime': 119.027, 'eval_samples_per_second': 34.58, 'eval_steps_per_second': 1.084, 'epoch': 1.48}\n"
1051
+ ]
1052
+ },
1053
+ {
1054
+ "name": "stderr",
1055
+ "output_type": "stream",
1056
+ "text": [
1057
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1058
+ " warnings.warn(\n"
1059
+ ]
1060
+ },
1061
+ {
1062
+ "name": "stdout",
1063
+ "output_type": "stream",
1064
+ "text": [
1065
+ "{'loss': 0.4632, 'grad_norm': 22.403745651245117, 'learning_rate': 4e-05, 'epoch': 1.55}\n"
1066
+ ]
1067
+ },
1068
+ {
1069
+ "data": {
1070
+ "application/vnd.jupyter.widget-view+json": {
1071
+ "model_id": "01e5da7e274a40d29c13929ed2494e9c",
1072
+ "version_major": 2,
1073
+ "version_minor": 0
1074
+ },
1075
+ "text/plain": [
1076
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1077
+ ]
1078
+ },
1079
+ "metadata": {},
1080
+ "output_type": "display_data"
1081
+ },
1082
+ {
1083
+ "name": "stdout",
1084
+ "output_type": "stream",
1085
+ "text": [
1086
+ "{'eval_loss': 0.22438718378543854, 'eval_accuracy': 0.934645286686103, 'eval_f1': 0.9314573088004496, 'eval_precision': 0.9378662507720702, 'eval_recall': 0.9326384727705795, 'eval_runtime': 121.7252, 'eval_samples_per_second': 33.814, 'eval_steps_per_second': 1.06, 'epoch': 1.55}\n"
1087
+ ]
1088
+ },
1089
+ {
1090
+ "name": "stderr",
1091
+ "output_type": "stream",
1092
+ "text": [
1093
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1094
+ " warnings.warn(\n"
1095
+ ]
1096
+ },
1097
+ {
1098
+ "name": "stdout",
1099
+ "output_type": "stream",
1100
+ "text": [
1101
+ "{'loss': 0.4882, 'grad_norm': 19.468097686767578, 'learning_rate': 3.9500000000000005e-05, 'epoch': 1.63}\n"
1102
+ ]
1103
+ },
1104
+ {
1105
+ "data": {
1106
+ "application/vnd.jupyter.widget-view+json": {
1107
+ "model_id": "0281f400fb664087a3c2f50895bcb7f0",
1108
+ "version_major": 2,
1109
+ "version_minor": 0
1110
+ },
1111
+ "text/plain": [
1112
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1113
+ ]
1114
+ },
1115
+ "metadata": {},
1116
+ "output_type": "display_data"
1117
+ },
1118
+ {
1119
+ "name": "stdout",
1120
+ "output_type": "stream",
1121
+ "text": [
1122
+ "{'eval_loss': 0.2236798256635666, 'eval_accuracy': 0.9361030126336248, 'eval_f1': 0.9322983411373361, 'eval_precision': 0.9404104668252149, 'eval_recall': 0.9345130791697032, 'eval_runtime': 115.5911, 'eval_samples_per_second': 35.608, 'eval_steps_per_second': 1.116, 'epoch': 1.63}\n"
1123
+ ]
1124
+ },
1125
+ {
1126
+ "name": "stderr",
1127
+ "output_type": "stream",
1128
+ "text": [
1129
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1130
+ " warnings.warn(\n"
1131
+ ]
1132
+ },
1133
+ {
1134
+ "name": "stdout",
1135
+ "output_type": "stream",
1136
+ "text": [
1137
+ "{'loss': 0.4583, 'grad_norm': 29.652774810791016, 'learning_rate': 3.9000000000000006e-05, 'epoch': 1.71}\n"
1138
+ ]
1139
+ },
1140
+ {
1141
+ "data": {
1142
+ "application/vnd.jupyter.widget-view+json": {
1143
+ "model_id": "0d1e88ee9e994a8595b8920dbccb6600",
1144
+ "version_major": 2,
1145
+ "version_minor": 0
1146
+ },
1147
+ "text/plain": [
1148
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1149
+ ]
1150
+ },
1151
+ "metadata": {},
1152
+ "output_type": "display_data"
1153
+ },
1154
+ {
1155
+ "name": "stdout",
1156
+ "output_type": "stream",
1157
+ "text": [
1158
+ "{'eval_loss': 0.2227955460548401, 'eval_accuracy': 0.9327016520894071, 'eval_f1': 0.9288976376084872, 'eval_precision': 0.9372589912326925, 'eval_recall': 0.9304356780487182, 'eval_runtime': 118.2652, 'eval_samples_per_second': 34.803, 'eval_steps_per_second': 1.091, 'epoch': 1.71}\n"
1159
+ ]
1160
+ },
1161
+ {
1162
+ "name": "stderr",
1163
+ "output_type": "stream",
1164
+ "text": [
1165
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1166
+ " warnings.warn(\n"
1167
+ ]
1168
+ },
1169
+ {
1170
+ "name": "stdout",
1171
+ "output_type": "stream",
1172
+ "text": [
1173
+ "{'loss': 0.4692, 'grad_norm': 45.817378997802734, 'learning_rate': 3.85e-05, 'epoch': 1.79}\n"
1174
+ ]
1175
+ },
1176
+ {
1177
+ "data": {
1178
+ "application/vnd.jupyter.widget-view+json": {
1179
+ "model_id": "15718cac9aaf476e83184921d6eb2aff",
1180
+ "version_major": 2,
1181
+ "version_minor": 0
1182
+ },
1183
+ "text/plain": [
1184
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1185
+ ]
1186
+ },
1187
+ "metadata": {},
1188
+ "output_type": "display_data"
1189
+ },
1190
+ {
1191
+ "name": "stdout",
1192
+ "output_type": "stream",
1193
+ "text": [
1194
+ "{'eval_loss': 0.20977580547332764, 'eval_accuracy': 0.935374149659864, 'eval_f1': 0.931613286271804, 'eval_precision': 0.9370143572286509, 'eval_recall': 0.9332420309524705, 'eval_runtime': 116.3014, 'eval_samples_per_second': 35.391, 'eval_steps_per_second': 1.109, 'epoch': 1.79}\n"
1195
+ ]
1196
+ },
1197
+ {
1198
+ "name": "stderr",
1199
+ "output_type": "stream",
1200
+ "text": [
1201
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1202
+ " warnings.warn(\n"
1203
+ ]
1204
+ },
1205
+ {
1206
+ "name": "stdout",
1207
+ "output_type": "stream",
1208
+ "text": [
1209
+ "{'loss': 0.5407, 'grad_norm': 48.70719528198242, 'learning_rate': 3.8e-05, 'epoch': 1.86}\n"
1210
+ ]
1211
+ },
1212
+ {
1213
+ "data": {
1214
+ "application/vnd.jupyter.widget-view+json": {
1215
+ "model_id": "367811b91762488a90f8d82e3b5e35d6",
1216
+ "version_major": 2,
1217
+ "version_minor": 0
1218
+ },
1219
+ "text/plain": [
1220
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1221
+ ]
1222
+ },
1223
+ "metadata": {},
1224
+ "output_type": "display_data"
1225
+ },
1226
+ {
1227
+ "name": "stdout",
1228
+ "output_type": "stream",
1229
+ "text": [
1230
+ "{'eval_loss': 0.21022267639636993, 'eval_accuracy': 0.935617103984451, 'eval_f1': 0.9342051845781768, 'eval_precision': 0.9374551819177072, 'eval_recall': 0.9350949929613589, 'eval_runtime': 119.8444, 'eval_samples_per_second': 34.345, 'eval_steps_per_second': 1.076, 'epoch': 1.86}\n"
1231
+ ]
1232
+ },
1233
+ {
1234
+ "name": "stderr",
1235
+ "output_type": "stream",
1236
+ "text": [
1237
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1238
+ " warnings.warn(\n"
1239
+ ]
1240
+ },
1241
+ {
1242
+ "name": "stdout",
1243
+ "output_type": "stream",
1244
+ "text": [
1245
+ "{'loss': 0.4629, 'grad_norm': 40.25789260864258, 'learning_rate': 3.7500000000000003e-05, 'epoch': 1.94}\n"
1246
+ ]
1247
+ },
1248
+ {
1249
+ "data": {
1250
+ "application/vnd.jupyter.widget-view+json": {
1251
+ "model_id": "638ffc2535074361b9c7f32ab25e0d87",
1252
+ "version_major": 2,
1253
+ "version_minor": 0
1254
+ },
1255
+ "text/plain": [
1256
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1257
+ ]
1258
+ },
1259
+ "metadata": {},
1260
+ "output_type": "display_data"
1261
+ },
1262
+ {
1263
+ "name": "stdout",
1264
+ "output_type": "stream",
1265
+ "text": [
1266
+ "{'eval_loss': 0.20446407794952393, 'eval_accuracy': 0.9378036929057337, 'eval_f1': 0.9348527011477886, 'eval_precision': 0.9395619022083945, 'eval_recall': 0.936717917602765, 'eval_runtime': 114.3207, 'eval_samples_per_second': 36.004, 'eval_steps_per_second': 1.128, 'epoch': 1.94}\n"
1267
+ ]
1268
+ },
1269
+ {
1270
+ "name": "stderr",
1271
+ "output_type": "stream",
1272
+ "text": [
1273
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1274
+ " warnings.warn(\n"
1275
+ ]
1276
+ },
1277
+ {
1278
+ "name": "stdout",
1279
+ "output_type": "stream",
1280
+ "text": [
1281
+ "{'loss': 0.4363, 'grad_norm': 34.60853958129883, 'learning_rate': 3.7e-05, 'epoch': 2.02}\n"
1282
+ ]
1283
+ },
1284
+ {
1285
+ "data": {
1286
+ "application/vnd.jupyter.widget-view+json": {
1287
+ "model_id": "6b5c279e63a74a0ca7d999a3116dc304",
1288
+ "version_major": 2,
1289
+ "version_minor": 0
1290
+ },
1291
+ "text/plain": [
1292
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1293
+ ]
1294
+ },
1295
+ "metadata": {},
1296
+ "output_type": "display_data"
1297
+ },
1298
+ {
1299
+ "name": "stdout",
1300
+ "output_type": "stream",
1301
+ "text": [
1302
+ "{'eval_loss': 0.20228153467178345, 'eval_accuracy': 0.9373177842565598, 'eval_f1': 0.9346485296103381, 'eval_precision': 0.9397930084364895, 'eval_recall': 0.9354963888099695, 'eval_runtime': 115.3383, 'eval_samples_per_second': 35.686, 'eval_steps_per_second': 1.118, 'epoch': 2.02}\n"
1303
+ ]
1304
+ },
1305
+ {
1306
+ "name": "stderr",
1307
+ "output_type": "stream",
1308
+ "text": [
1309
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1310
+ " warnings.warn(\n"
1311
+ ]
1312
+ },
1313
+ {
1314
+ "name": "stdout",
1315
+ "output_type": "stream",
1316
+ "text": [
1317
+ "{'loss': 0.4328, 'grad_norm': 26.99337387084961, 'learning_rate': 3.65e-05, 'epoch': 2.1}\n"
1318
+ ]
1319
+ },
1320
+ {
1321
+ "data": {
1322
+ "application/vnd.jupyter.widget-view+json": {
1323
+ "model_id": "b53ad317ab1a44699c22f32119cbc833",
1324
+ "version_major": 2,
1325
+ "version_minor": 0
1326
+ },
1327
+ "text/plain": [
1328
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1329
+ ]
1330
+ },
1331
+ "metadata": {},
1332
+ "output_type": "display_data"
1333
+ },
1334
+ {
1335
+ "name": "stdout",
1336
+ "output_type": "stream",
1337
+ "text": [
1338
+ "{'eval_loss': 0.20625373721122742, 'eval_accuracy': 0.935374149659864, 'eval_f1': 0.9319561975473203, 'eval_precision': 0.9360019214632094, 'eval_recall': 0.93430737851741, 'eval_runtime': 114.9885, 'eval_samples_per_second': 35.795, 'eval_steps_per_second': 1.122, 'epoch': 2.1}\n"
1339
+ ]
1340
+ },
1341
+ {
1342
+ "name": "stderr",
1343
+ "output_type": "stream",
1344
+ "text": [
1345
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1346
+ " warnings.warn(\n"
1347
+ ]
1348
+ },
1349
+ {
1350
+ "name": "stdout",
1351
+ "output_type": "stream",
1352
+ "text": [
1353
+ "{'loss': 0.3554, 'grad_norm': 28.036481857299805, 'learning_rate': 3.6e-05, 'epoch': 2.17}\n"
1354
+ ]
1355
+ },
1356
+ {
1357
+ "data": {
1358
+ "application/vnd.jupyter.widget-view+json": {
1359
+ "model_id": "5d74e2a9b5094d0a93eb68fb5564da50",
1360
+ "version_major": 2,
1361
+ "version_minor": 0
1362
+ },
1363
+ "text/plain": [
1364
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1365
+ ]
1366
+ },
1367
+ "metadata": {},
1368
+ "output_type": "display_data"
1369
+ },
1370
+ {
1371
+ "name": "stdout",
1372
+ "output_type": "stream",
1373
+ "text": [
1374
+ "{'eval_loss': 0.19484364986419678, 'eval_accuracy': 0.9438775510204082, 'eval_f1': 0.9397775702488407, 'eval_precision': 0.9474960731219763, 'eval_recall': 0.9418466963494533, 'eval_runtime': 116.995, 'eval_samples_per_second': 35.181, 'eval_steps_per_second': 1.103, 'epoch': 2.17}\n"
1375
+ ]
1376
+ },
1377
+ {
1378
+ "name": "stderr",
1379
+ "output_type": "stream",
1380
+ "text": [
1381
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1382
+ " warnings.warn(\n"
1383
+ ]
1384
+ },
1385
+ {
1386
+ "name": "stdout",
1387
+ "output_type": "stream",
1388
+ "text": [
1389
+ "{'loss': 0.4024, 'grad_norm': 30.97524642944336, 'learning_rate': 3.55e-05, 'epoch': 2.25}\n"
1390
+ ]
1391
+ },
1392
+ {
1393
+ "data": {
1394
+ "application/vnd.jupyter.widget-view+json": {
1395
+ "model_id": "fe6a83c4081645daa098b36014c35911",
1396
+ "version_major": 2,
1397
+ "version_minor": 0
1398
+ },
1399
+ "text/plain": [
1400
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1401
+ ]
1402
+ },
1403
+ "metadata": {},
1404
+ "output_type": "display_data"
1405
+ },
1406
+ {
1407
+ "name": "stdout",
1408
+ "output_type": "stream",
1409
+ "text": [
1410
+ "{'eval_loss': 0.19847826659679413, 'eval_accuracy': 0.9387755102040817, 'eval_f1': 0.9372478002991185, 'eval_precision': 0.9397167553305189, 'eval_recall': 0.9376525555929388, 'eval_runtime': 113.5301, 'eval_samples_per_second': 36.255, 'eval_steps_per_second': 1.136, 'epoch': 2.25}\n"
1411
+ ]
1412
+ },
1413
+ {
1414
+ "name": "stderr",
1415
+ "output_type": "stream",
1416
+ "text": [
1417
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1418
+ " warnings.warn(\n"
1419
+ ]
1420
+ },
1421
+ {
1422
+ "name": "stdout",
1423
+ "output_type": "stream",
1424
+ "text": [
1425
+ "{'loss': 0.4006, 'grad_norm': 30.08006477355957, 'learning_rate': 3.5e-05, 'epoch': 2.33}\n"
1426
+ ]
1427
+ },
1428
+ {
1429
+ "data": {
1430
+ "application/vnd.jupyter.widget-view+json": {
1431
+ "model_id": "ea3e385421cd451eb7297c5c319781dd",
1432
+ "version_major": 2,
1433
+ "version_minor": 0
1434
+ },
1435
+ "text/plain": [
1436
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1437
+ ]
1438
+ },
1439
+ "metadata": {},
1440
+ "output_type": "display_data"
1441
+ },
1442
+ {
1443
+ "name": "stdout",
1444
+ "output_type": "stream",
1445
+ "text": [
1446
+ "{'eval_loss': 0.2153274416923523, 'eval_accuracy': 0.9334305150631681, 'eval_f1': 0.9275387845824814, 'eval_precision': 0.9419875530424547, 'eval_recall': 0.9310524562810819, 'eval_runtime': 116.2218, 'eval_samples_per_second': 35.415, 'eval_steps_per_second': 1.11, 'epoch': 2.33}\n"
1447
+ ]
1448
+ },
1449
+ {
1450
+ "name": "stderr",
1451
+ "output_type": "stream",
1452
+ "text": [
1453
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1454
+ " warnings.warn(\n"
1455
+ ]
1456
+ },
1457
+ {
1458
+ "name": "stdout",
1459
+ "output_type": "stream",
1460
+ "text": [
1461
+ "{'loss': 0.3935, 'grad_norm': 26.754928588867188, 'learning_rate': 3.45e-05, 'epoch': 2.41}\n"
1462
+ ]
1463
+ },
1464
+ {
1465
+ "data": {
1466
+ "application/vnd.jupyter.widget-view+json": {
1467
+ "model_id": "f54d46999e934ae7bd92203e0723b29d",
1468
+ "version_major": 2,
1469
+ "version_minor": 0
1470
+ },
1471
+ "text/plain": [
1472
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1473
+ ]
1474
+ },
1475
+ "metadata": {},
1476
+ "output_type": "display_data"
1477
+ },
1478
+ {
1479
+ "name": "stdout",
1480
+ "output_type": "stream",
1481
+ "text": [
1482
+ "{'eval_loss': 0.20206879079341888, 'eval_accuracy': 0.9392614188532555, 'eval_f1': 0.9345899675258124, 'eval_precision': 0.9415653519418664, 'eval_recall': 0.9368182103451178, 'eval_runtime': 114.9203, 'eval_samples_per_second': 35.816, 'eval_steps_per_second': 1.123, 'epoch': 2.41}\n"
1483
+ ]
1484
+ },
1485
+ {
1486
+ "name": "stderr",
1487
+ "output_type": "stream",
1488
+ "text": [
1489
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1490
+ " warnings.warn(\n"
1491
+ ]
1492
+ },
1493
+ {
1494
+ "name": "stdout",
1495
+ "output_type": "stream",
1496
+ "text": [
1497
+ "{'loss': 0.3591, 'grad_norm': 18.381698608398438, 'learning_rate': 3.4000000000000007e-05, 'epoch': 2.49}\n"
1498
+ ]
1499
+ },
1500
+ {
1501
+ "data": {
1502
+ "application/vnd.jupyter.widget-view+json": {
1503
+ "model_id": "a939a26c314c4bc48a1a37c17fea7592",
1504
+ "version_major": 2,
1505
+ "version_minor": 0
1506
+ },
1507
+ "text/plain": [
1508
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1509
+ ]
1510
+ },
1511
+ "metadata": {},
1512
+ "output_type": "display_data"
1513
+ },
1514
+ {
1515
+ "name": "stdout",
1516
+ "output_type": "stream",
1517
+ "text": [
1518
+ "{'eval_loss': 0.21262311935424805, 'eval_accuracy': 0.934645286686103, 'eval_f1': 0.9310688457336352, 'eval_precision': 0.9403388539780413, 'eval_recall': 0.9332890437488752, 'eval_runtime': 115.8495, 'eval_samples_per_second': 35.529, 'eval_steps_per_second': 1.114, 'epoch': 2.49}\n"
1519
+ ]
1520
+ },
1521
+ {
1522
+ "name": "stderr",
1523
+ "output_type": "stream",
1524
+ "text": [
1525
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1526
+ " warnings.warn(\n"
1527
+ ]
1528
+ },
1529
+ {
1530
+ "name": "stdout",
1531
+ "output_type": "stream",
1532
+ "text": [
1533
+ "{'loss': 0.4058, 'grad_norm': 38.42707824707031, 'learning_rate': 3.35e-05, 'epoch': 2.56}\n"
1534
+ ]
1535
+ },
1536
+ {
1537
+ "data": {
1538
+ "application/vnd.jupyter.widget-view+json": {
1539
+ "model_id": "50d0c42c4f8a415595c696781b01847b",
1540
+ "version_major": 2,
1541
+ "version_minor": 0
1542
+ },
1543
+ "text/plain": [
1544
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1545
+ ]
1546
+ },
1547
+ "metadata": {},
1548
+ "output_type": "display_data"
1549
+ },
1550
+ {
1551
+ "name": "stdout",
1552
+ "output_type": "stream",
1553
+ "text": [
1554
+ "{'eval_loss': 0.20197580754756927, 'eval_accuracy': 0.9378036929057337, 'eval_f1': 0.9356625337908084, 'eval_precision': 0.9393282055887032, 'eval_recall': 0.935787907740436, 'eval_runtime': 115.4582, 'eval_samples_per_second': 35.649, 'eval_steps_per_second': 1.117, 'epoch': 2.56}\n"
1555
+ ]
1556
+ },
1557
+ {
1558
+ "name": "stderr",
1559
+ "output_type": "stream",
1560
+ "text": [
1561
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1562
+ " warnings.warn(\n"
1563
+ ]
1564
+ },
1565
+ {
1566
+ "name": "stdout",
1567
+ "output_type": "stream",
1568
+ "text": [
1569
+ "{'loss': 0.396, 'grad_norm': 29.439712524414062, 'learning_rate': 3.3e-05, 'epoch': 2.64}\n"
1570
+ ]
1571
+ },
1572
+ {
1573
+ "data": {
1574
+ "application/vnd.jupyter.widget-view+json": {
1575
+ "model_id": "0462c202c3dc4b22a53d4c822d2705d7",
1576
+ "version_major": 2,
1577
+ "version_minor": 0
1578
+ },
1579
+ "text/plain": [
1580
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1581
+ ]
1582
+ },
1583
+ "metadata": {},
1584
+ "output_type": "display_data"
1585
+ },
1586
+ {
1587
+ "name": "stdout",
1588
+ "output_type": "stream",
1589
+ "text": [
1590
+ "{'eval_loss': 0.20382580161094666, 'eval_accuracy': 0.9370748299319728, 'eval_f1': 0.9338700635095457, 'eval_precision': 0.9410414254819953, 'eval_recall': 0.9357427719978543, 'eval_runtime': 113.0645, 'eval_samples_per_second': 36.404, 'eval_steps_per_second': 1.141, 'epoch': 2.64}\n"
1591
+ ]
1592
+ },
1593
+ {
1594
+ "name": "stderr",
1595
+ "output_type": "stream",
1596
+ "text": [
1597
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1598
+ " warnings.warn(\n"
1599
+ ]
1600
+ },
1601
+ {
1602
+ "name": "stdout",
1603
+ "output_type": "stream",
1604
+ "text": [
1605
+ "{'loss': 0.4157, 'grad_norm': 30.86858367919922, 'learning_rate': 3.2500000000000004e-05, 'epoch': 2.72}\n"
1606
+ ]
1607
+ },
1608
+ {
1609
+ "data": {
1610
+ "application/vnd.jupyter.widget-view+json": {
1611
+ "model_id": "48e3c7dae1ca4c43b188b4a124923773",
1612
+ "version_major": 2,
1613
+ "version_minor": 0
1614
+ },
1615
+ "text/plain": [
1616
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1617
+ ]
1618
+ },
1619
+ "metadata": {},
1620
+ "output_type": "display_data"
1621
+ },
1622
+ {
1623
+ "name": "stdout",
1624
+ "output_type": "stream",
1625
+ "text": [
1626
+ "{'eval_loss': 0.20911905169487, 'eval_accuracy': 0.9331875607385811, 'eval_f1': 0.9287535942840445, 'eval_precision': 0.9351518924076643, 'eval_recall': 0.9307581587262745, 'eval_runtime': 114.6445, 'eval_samples_per_second': 35.902, 'eval_steps_per_second': 1.125, 'epoch': 2.72}\n"
1627
+ ]
1628
+ },
1629
+ {
1630
+ "name": "stderr",
1631
+ "output_type": "stream",
1632
+ "text": [
1633
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1634
+ " warnings.warn(\n"
1635
+ ]
1636
+ },
1637
+ {
1638
+ "name": "stdout",
1639
+ "output_type": "stream",
1640
+ "text": [
1641
+ "{'loss': 0.4222, 'grad_norm': 29.99662208557129, 'learning_rate': 3.2000000000000005e-05, 'epoch': 2.8}\n"
1642
+ ]
1643
+ },
1644
+ {
1645
+ "data": {
1646
+ "application/vnd.jupyter.widget-view+json": {
1647
+ "model_id": "109e55a7b44a44b9a482037011296482",
1648
+ "version_major": 2,
1649
+ "version_minor": 0
1650
+ },
1651
+ "text/plain": [
1652
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1653
+ ]
1654
+ },
1655
+ "metadata": {},
1656
+ "output_type": "display_data"
1657
+ },
1658
+ {
1659
+ "name": "stdout",
1660
+ "output_type": "stream",
1661
+ "text": [
1662
+ "{'eval_loss': 0.1933327168226242, 'eval_accuracy': 0.9392614188532555, 'eval_f1': 0.9372313674271271, 'eval_precision': 0.9398983217337015, 'eval_recall': 0.9377663643140918, 'eval_runtime': 114.7869, 'eval_samples_per_second': 35.858, 'eval_steps_per_second': 1.124, 'epoch': 2.8}\n"
1663
+ ]
1664
+ },
1665
+ {
1666
+ "name": "stderr",
1667
+ "output_type": "stream",
1668
+ "text": [
1669
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1670
+ " warnings.warn(\n"
1671
+ ]
1672
+ },
1673
+ {
1674
+ "name": "stdout",
1675
+ "output_type": "stream",
1676
+ "text": [
1677
+ "{'loss': 0.3521, 'grad_norm': 30.95831298828125, 'learning_rate': 3.15e-05, 'epoch': 2.87}\n"
1678
+ ]
1679
+ },
1680
+ {
1681
+ "data": {
1682
+ "application/vnd.jupyter.widget-view+json": {
1683
+ "model_id": "d6ea2392bbb14a9ca7ee10bd152d377e",
1684
+ "version_major": 2,
1685
+ "version_minor": 0
1686
+ },
1687
+ "text/plain": [
1688
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1689
+ ]
1690
+ },
1691
+ "metadata": {},
1692
+ "output_type": "display_data"
1693
+ },
1694
+ {
1695
+ "name": "stdout",
1696
+ "output_type": "stream",
1697
+ "text": [
1698
+ "{'eval_loss': 0.198385551571846, 'eval_accuracy': 0.9397473275024295, 'eval_f1': 0.9380846555670623, 'eval_precision': 0.9429846718188393, 'eval_recall': 0.9388009151105176, 'eval_runtime': 115.1456, 'eval_samples_per_second': 35.746, 'eval_steps_per_second': 1.12, 'epoch': 2.87}\n"
1699
+ ]
1700
+ },
1701
+ {
1702
+ "name": "stderr",
1703
+ "output_type": "stream",
1704
+ "text": [
1705
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1706
+ " warnings.warn(\n"
1707
+ ]
1708
+ },
1709
+ {
1710
+ "name": "stdout",
1711
+ "output_type": "stream",
1712
+ "text": [
1713
+ "{'loss': 0.3925, 'grad_norm': 39.25502395629883, 'learning_rate': 3.1e-05, 'epoch': 2.95}\n"
1714
+ ]
1715
+ },
1716
+ {
1717
+ "data": {
1718
+ "application/vnd.jupyter.widget-view+json": {
1719
+ "model_id": "7312e1a6d9e84c36afd247a2f2c1b37e",
1720
+ "version_major": 2,
1721
+ "version_minor": 0
1722
+ },
1723
+ "text/plain": [
1724
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1725
+ ]
1726
+ },
1727
+ "metadata": {},
1728
+ "output_type": "display_data"
1729
+ },
1730
+ {
1731
+ "name": "stdout",
1732
+ "output_type": "stream",
1733
+ "text": [
1734
+ "{'eval_loss': 0.18743416666984558, 'eval_accuracy': 0.9382896015549077, 'eval_f1': 0.9347460787377938, 'eval_precision': 0.9389896925557485, 'eval_recall': 0.9357984136549564, 'eval_runtime': 113.3981, 'eval_samples_per_second': 36.297, 'eval_steps_per_second': 1.138, 'epoch': 2.95}\n"
1735
+ ]
1736
+ },
1737
+ {
1738
+ "name": "stderr",
1739
+ "output_type": "stream",
1740
+ "text": [
1741
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1742
+ " warnings.warn(\n"
1743
+ ]
1744
+ },
1745
+ {
1746
+ "name": "stdout",
1747
+ "output_type": "stream",
1748
+ "text": [
1749
+ "{'loss': 0.3475, 'grad_norm': 34.49732208251953, 'learning_rate': 3.05e-05, 'epoch': 3.03}\n"
1750
+ ]
1751
+ },
1752
+ {
1753
+ "data": {
1754
+ "application/vnd.jupyter.widget-view+json": {
1755
+ "model_id": "e5b85ef8da594a2f8cf915433ba23db8",
1756
+ "version_major": 2,
1757
+ "version_minor": 0
1758
+ },
1759
+ "text/plain": [
1760
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1761
+ ]
1762
+ },
1763
+ "metadata": {},
1764
+ "output_type": "display_data"
1765
+ },
1766
+ {
1767
+ "name": "stdout",
1768
+ "output_type": "stream",
1769
+ "text": [
1770
+ "{'eval_loss': 0.19942405819892883, 'eval_accuracy': 0.9382896015549077, 'eval_f1': 0.9364100359103614, 'eval_precision': 0.9410069422883035, 'eval_recall': 0.9376210497900824, 'eval_runtime': 115.7969, 'eval_samples_per_second': 35.545, 'eval_steps_per_second': 1.114, 'epoch': 3.03}\n"
1771
+ ]
1772
+ },
1773
+ {
1774
+ "name": "stderr",
1775
+ "output_type": "stream",
1776
+ "text": [
1777
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1778
+ " warnings.warn(\n"
1779
+ ]
1780
+ },
1781
+ {
1782
+ "name": "stdout",
1783
+ "output_type": "stream",
1784
+ "text": [
1785
+ "{'loss': 0.3526, 'grad_norm': 34.14906311035156, 'learning_rate': 3e-05, 'epoch': 3.11}\n"
1786
+ ]
1787
+ },
1788
+ {
1789
+ "data": {
1790
+ "application/vnd.jupyter.widget-view+json": {
1791
+ "model_id": "a1c980fe777f4de392fbe12a110ecdcd",
1792
+ "version_major": 2,
1793
+ "version_minor": 0
1794
+ },
1795
+ "text/plain": [
1796
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1797
+ ]
1798
+ },
1799
+ "metadata": {},
1800
+ "output_type": "display_data"
1801
+ },
1802
+ {
1803
+ "name": "stdout",
1804
+ "output_type": "stream",
1805
+ "text": [
1806
+ "{'eval_loss': 0.1941181868314743, 'eval_accuracy': 0.9390184645286687, 'eval_f1': 0.9351668156232781, 'eval_precision': 0.9402405362825343, 'eval_recall': 0.9372908916182126, 'eval_runtime': 112.8197, 'eval_samples_per_second': 36.483, 'eval_steps_per_second': 1.143, 'epoch': 3.11}\n"
1807
+ ]
1808
+ },
1809
+ {
1810
+ "name": "stderr",
1811
+ "output_type": "stream",
1812
+ "text": [
1813
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1814
+ " warnings.warn(\n"
1815
+ ]
1816
+ },
1817
+ {
1818
+ "name": "stdout",
1819
+ "output_type": "stream",
1820
+ "text": [
1821
+ "{'loss': 0.351, 'grad_norm': 25.905424118041992, 'learning_rate': 2.95e-05, 'epoch': 3.18}\n"
1822
+ ]
1823
+ },
1824
+ {
1825
+ "data": {
1826
+ "application/vnd.jupyter.widget-view+json": {
1827
+ "model_id": "630981b183234e758971e712e999028d",
1828
+ "version_major": 2,
1829
+ "version_minor": 0
1830
+ },
1831
+ "text/plain": [
1832
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1833
+ ]
1834
+ },
1835
+ "metadata": {},
1836
+ "output_type": "display_data"
1837
+ },
1838
+ {
1839
+ "name": "stdout",
1840
+ "output_type": "stream",
1841
+ "text": [
1842
+ "{'eval_loss': 0.18932673335075378, 'eval_accuracy': 0.9416909620991254, 'eval_f1': 0.940318085045109, 'eval_precision': 0.9438062512440089, 'eval_recall': 0.9409543473093137, 'eval_runtime': 114.6924, 'eval_samples_per_second': 35.887, 'eval_steps_per_second': 1.125, 'epoch': 3.18}\n"
1843
+ ]
1844
+ },
1845
+ {
1846
+ "name": "stderr",
1847
+ "output_type": "stream",
1848
+ "text": [
1849
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1850
+ " warnings.warn(\n"
1851
+ ]
1852
+ },
1853
+ {
1854
+ "name": "stdout",
1855
+ "output_type": "stream",
1856
+ "text": [
1857
+ "{'loss': 0.3549, 'grad_norm': 26.349733352661133, 'learning_rate': 2.9e-05, 'epoch': 3.26}\n"
1858
+ ]
1859
+ },
1860
+ {
1861
+ "data": {
1862
+ "application/vnd.jupyter.widget-view+json": {
1863
+ "model_id": "b4bf00320dee42bbad21874c2a2cf471",
1864
+ "version_major": 2,
1865
+ "version_minor": 0
1866
+ },
1867
+ "text/plain": [
1868
+ " 0%| | 0/129 [00:00<?, ?it/s]"
1869
+ ]
1870
+ },
1871
+ "metadata": {},
1872
+ "output_type": "display_data"
1873
+ },
1874
+ {
1875
+ "name": "stdout",
1876
+ "output_type": "stream",
1877
+ "text": [
1878
+ "{'eval_loss': 0.19598565995693207, 'eval_accuracy': 0.9390184645286687, 'eval_f1': 0.936971760890017, 'eval_precision': 0.9409831330401592, 'eval_recall': 0.938080489222113, 'eval_runtime': 118.5544, 'eval_samples_per_second': 34.718, 'eval_steps_per_second': 1.088, 'epoch': 3.26}\n"
1879
+ ]
1880
+ },
1881
+ {
1882
+ "name": "stderr",
1883
+ "output_type": "stream",
1884
+ "text": [
1885
+ "/Users/andrewmayes/Openclassroom/CanineNet/env/lib/python3.12/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
1886
+ " warnings.warn(\n"
1887
+ ]
1888
+ }
1889
+ ],
1890
+ "source": [
1891
+ "metrics = {metric: evaluate.load(metric) for metric in METRICS}\n",
1892
+ "\n",
1893
+ "\n",
1894
+ "# for lr in [5e-3, 5e-4, 5e-5]: # 5e-5\n",
1895
+ "# for batch in [64]: # 32\n",
1896
+ "# for model_name in [\"google/vit-base-patch16-224\", \"microsoft/swinv2-base-patch4-window16-256\", \"google/siglip-base-patch16-224\"]: # \"facebook/dinov2-base\"\n",
1897
+ "\n",
1898
+ "lr = 5e-4\n",
1899
+ "batch = 32\n",
1900
+ "model_name = \"microsoft/swinv2-base-patch4-window16-256\"\n",
1901
+ "\n",
1902
+ "image_processor = AutoImageProcessor.from_pretrained(model_name)\n",
1903
+ "model = AutoModelForImageClassification.from_pretrained(\n",
1904
+ "model_name,\n",
1905
+ "num_labels=len(label2int),\n",
1906
+ "id2label=int2label,\n",
1907
+ "label2id=label2int,\n",
1908
+ "ignore_mismatched_sizes=True,\n",
1909
+ ")\n",
1910
+ "\n",
1911
+ "# Then, in your transformations:\n",
1912
+ "def train_transform(examples, num_ops=10, magnitude=9, num_magnitude_bins=31):\n",
1913
+ "\n",
1914
+ " transformation = v2.Compose(\n",
1915
+ " [\n",
1916
+ " v2.RandAugment(\n",
1917
+ " num_ops=num_ops,\n",
1918
+ " magnitude=magnitude,\n",
1919
+ " num_magnitude_bins=num_magnitude_bins,\n",
1920
+ " )\n",
1921
+ " ]\n",
1922
+ " )\n",
1923
+ " # Ensure each image has three dimensions (in this case, ensure it's RGB)\n",
1924
+ " examples[\"pixel_values\"] = [\n",
1925
+ " image.convert(\"RGB\") for image in examples[\"pixel_values\"]\n",
1926
+ " ]\n",
1927
+ " # Apply transformations\n",
1928
+ " examples[\"pixel_values\"] = [\n",
1929
+ " image_processor(transformation(image), return_tensors=\"pt\")[\n",
1930
+ " \"pixel_values\"\n",
1931
+ " ].squeeze()\n",
1932
+ " for image in examples[\"pixel_values\"]\n",
1933
+ " ]\n",
1934
+ " return examples\n",
1935
+ "\n",
1936
+ "\n",
1937
+ "def test_transform(examples):\n",
1938
+ " # Ensure each image is RGB\n",
1939
+ " examples[\"pixel_values\"] = [\n",
1940
+ " image.convert(\"RGB\") for image in examples[\"pixel_values\"]\n",
1941
+ " ]\n",
1942
+ " # Apply processing\n",
1943
+ " examples[\"pixel_values\"] = [\n",
1944
+ " image_processor(image, return_tensors=\"pt\")[\"pixel_values\"].squeeze()\n",
1945
+ " for image in examples[\"pixel_values\"]\n",
1946
+ " ]\n",
1947
+ " return examples\n",
1948
+ "\n",
1949
+ "\n",
1950
+ "def compute_metrics(eval_pred):\n",
1951
+ " predictions, labels = eval_pred\n",
1952
+ " # predictions = np.argmax(logits, axis=-1)\n",
1953
+ " results = {}\n",
1954
+ " for key, val in metrics.items():\n",
1955
+ " if \"accuracy\" == key:\n",
1956
+ " result = next(\n",
1957
+ " iter(val.compute(predictions=predictions, references=labels).items())\n",
1958
+ " )\n",
1959
+ " if \"accuracy\" != key:\n",
1960
+ " result = next(\n",
1961
+ " iter(\n",
1962
+ " val.compute(\n",
1963
+ " predictions=predictions, references=labels, average=\"macro\"\n",
1964
+ " ).items()\n",
1965
+ " )\n",
1966
+ " )\n",
1967
+ " results[result[0]] = result[1]\n",
1968
+ " return results\n",
1969
+ "\n",
1970
+ "\n",
1971
+ "def collate_fn(examples):\n",
1972
+ " pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n",
1973
+ " labels = torch.tensor([example[\"label\"] for example in examples])\n",
1974
+ " return {\"pixel_values\": pixel_values, \"labels\": labels}\n",
1975
+ "\n",
1976
+ "\n",
1977
+ "def preprocess_logits_for_metrics(logits, labels):\n",
1978
+ " \"\"\"\n",
1979
+ " Original Trainer may have a memory leak.\n",
1980
+ " This is a workaround to avoid storing too many tensors that are not needed.\n",
1981
+ " \"\"\"\n",
1982
+ " pred_ids = torch.argmax(logits, dim=-1)\n",
1983
+ " return pred_ids\n",
1984
+ "\n",
1985
+ "ds[\"train\"].set_transform(train_transform)\n",
1986
+ "ds[\"test\"].set_transform(test_transform)\n",
1987
+ "\n",
1988
+ "training_args = TrainingArguments(**CONFIG[\"training_args\"])\n",
1989
+ "training_args.per_device_train_batch_size = batch\n",
1990
+ "training_args.per_device_eval_batch_size = batch\n",
1991
+ "training_args.hub_model_id = f\"amaye15/{model_name.replace('/','-')}-batch{batch}-lr{lr}-standford-dogs\"\n",
1992
+ "\n",
1993
+ "mlflow.start_run(run_name=f\"{model_name.replace('/','-')}-batch{batch}-lr{lr}\")\n",
1994
+ "\n",
1995
+ "trainer = Trainer(\n",
1996
+ " model=model,\n",
1997
+ " args=training_args,\n",
1998
+ " train_dataset=ds[\"train\"],\n",
1999
+ " eval_dataset=ds[\"test\"],\n",
2000
+ " tokenizer=image_processor,\n",
2001
+ " data_collator=collate_fn,\n",
2002
+ " compute_metrics=compute_metrics,\n",
2003
+ " # callbacks=[early_stopping_callback],\n",
2004
+ " preprocess_logits_for_metrics=preprocess_logits_for_metrics,\n",
2005
+ ")\n",
2006
+ "\n",
2007
+ "# Train the model\n",
2008
+ "trainer.train()\n",
2009
+ "\n",
2010
+ "trainer.push_to_hub()\n",
2011
+ "\n",
2012
+ "mlflow.end_run()"
2013
+ ]
2014
+ },
2015
+ {
2016
+ "cell_type": "code",
2017
+ "execution_count": null,
2018
+ "metadata": {},
2019
+ "outputs": [],
2020
+ "source": [
2021
+ "mlflow.end_run()"
2022
+ ]
2023
+ },
2024
+ {
2025
+ "cell_type": "code",
2026
+ "execution_count": null,
2027
+ "metadata": {},
2028
+ "outputs": [],
2029
+ "source": [
2030
+ "# training_args = TrainingArguments(**CONFIG[\"training_args\"])\n",
2031
+ "\n",
2032
+ "# image_processor = AutoImageProcessor.from_pretrained(MODELS)\n",
2033
+ "# model = AutoModelForImageClassification.from_pretrained(\n",
2034
+ "# MODELS,\n",
2035
+ "# num_labels=len(CONFIG[\"label2int\"]),\n",
2036
+ "# id2label=CONFIG[\"label2int\"],\n",
2037
+ "# label2id=CONFIG[\"int2label\"],\n",
2038
+ "# ignore_mismatched_sizes=True,\n",
2039
+ "# )\n",
2040
+ "\n",
2041
+ "\n",
2042
+ "# training_args = TrainingArguments(**CONFIG[\"training_args\"])\n",
2043
+ "\n",
2044
+ "# trainer = Trainer(\n",
2045
+ "# model=model,\n",
2046
+ "# args=training_args,\n",
2047
+ "# train_dataset=ds[\"train\"],\n",
2048
+ "# eval_dataset=ds[\"test\"],\n",
2049
+ "# tokenizer=image_processor,\n",
2050
+ "# data_collator=collate_fn,\n",
2051
+ "# compute_metrics=compute_metrics,\n",
2052
+ "# # callbacks=[early_stopping_callback],\n",
2053
+ "# preprocess_logits_for_metrics=preprocess_logits_for_metrics,\n",
2054
+ "# )\n",
2055
+ "\n",
2056
+ "# # Train the model\n",
2057
+ "# trainer.train()\n",
2058
+ "\n",
2059
+ "# mlflow.end_run()"
2060
+ ]
2061
+ }
2062
+ ],
2063
+ "metadata": {
2064
+ "kernelspec": {
2065
+ "display_name": "env",
2066
+ "language": "python",
2067
+ "name": "python3"
2068
+ },
2069
+ "language_info": {
2070
+ "codemirror_mode": {
2071
+ "name": "ipython",
2072
+ "version": 3
2073
+ },
2074
+ "file_extension": ".py",
2075
+ "mimetype": "text/x-python",
2076
+ "name": "python",
2077
+ "nbconvert_exporter": "python",
2078
+ "pygments_lexer": "ipython3",
2079
+ "version": "3.12.3"
2080
+ }
2081
+ },
2082
+ "nbformat": 4,
2083
+ "nbformat_minor": 2
2084
+ }
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b49c485d331c8a810b7deb770daa71926b779ee6ba9b6b950b22ce9530991c6e
3
+ size 5112