Spaces:
Sleeping
Sleeping
IlayMalinyak
commited on
Commit
·
99dc7bf
1
Parent(s):
81f7a68
3layer kan
Browse files
tasks/models/frugal_2025-01-26/frugal_kan_2.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5f5652f3be0033b5a249ded449c1f7c40dd2f32d649b07ca7c2f6158a4b57cb5
|
3 |
+
size 1710980
|
tasks/models/frugal_2025-01-27/CNNEncoder_frugal_2.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tasks/models/frugal_2025-01-27/frugal_kan_2.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3f520cff8b9531981e16a8b009b6a55fb8ca98573fc4d3dc6806df60b07a49c2
|
3 |
+
size 1710980
|
tasks/run.py
CHANGED
@@ -9,6 +9,7 @@ import yaml
|
|
9 |
import datetime
|
10 |
import json
|
11 |
import numpy as np
|
|
|
12 |
|
13 |
# local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
14 |
current_date = datetime.date.today().strftime("%Y-%m-%d")
|
@@ -56,6 +57,13 @@ model = CNNKan(model_args, conformer_args, kan_args.get_dict())
|
|
56 |
# model.kan.speed()
|
57 |
# model = KanEncoder(kan_args.get_dict())
|
58 |
model = model.to(local_rank)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
# model = DDP(model, device_ids=[local_rank], output_device=local_rank)
|
60 |
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
61 |
print(f"Number of parameters: {num_params}")
|
|
|
9 |
import datetime
|
10 |
import json
|
11 |
import numpy as np
|
12 |
+
from collections import OrderedDict
|
13 |
|
14 |
# local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
15 |
current_date = datetime.date.today().strftime("%Y-%m-%d")
|
|
|
57 |
# model.kan.speed()
|
58 |
# model = KanEncoder(kan_args.get_dict())
|
59 |
model = model.to(local_rank)
|
60 |
+
state_dict = torch.load(data_args.checkpoint_path, map_location=torch.device('cpu'))
|
61 |
+
new_state_dict = OrderedDict()
|
62 |
+
for key, value in state_dict.items():
|
63 |
+
if key.startswith('module.'):
|
64 |
+
key = key[7:]
|
65 |
+
new_state_dict[key] = value
|
66 |
+
missing, unexpected = model.load_state_dict(new_state_dict)
|
67 |
# model = DDP(model, device_ids=[local_rank], output_device=local_rank)
|
68 |
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
69 |
print(f"Number of parameters: {num_params}")
|
tasks/utils/config.yaml
CHANGED
@@ -12,7 +12,7 @@ Data:
|
|
12 |
max_days_lc: 270
|
13 |
lc_freq: 0.0208
|
14 |
create_umap: True
|
15 |
-
checkpoint_path: 'tasks/models/frugal_2025-01-
|
16 |
|
17 |
CNNEncoder:
|
18 |
# Model
|
@@ -31,7 +31,7 @@ CNNEncoder:
|
|
31 |
avg_output: False
|
32 |
|
33 |
KAN:
|
34 |
-
layers_hidden: [1125,32,8,
|
35 |
grid_min: -1.2
|
36 |
grid_max: 1.2
|
37 |
num_grids: 8
|
|
|
12 |
max_days_lc: 270
|
13 |
lc_freq: 0.0208
|
14 |
create_umap: True
|
15 |
+
checkpoint_path: 'tasks/models/frugal_2025-01-27/frugal_kan_2.pth'
|
16 |
|
17 |
CNNEncoder:
|
18 |
# Model
|
|
|
31 |
avg_output: False
|
32 |
|
33 |
KAN:
|
34 |
+
layers_hidden: [1125,32,8,1]
|
35 |
grid_min: -1.2
|
36 |
grid_max: 1.2
|
37 |
num_grids: 8
|