IlayMalinyak commited on
Commit
99dc7bf
·
1 Parent(s): 81f7a68

3layer kan

Browse files
tasks/models/frugal_2025-01-26/frugal_kan_2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f5652f3be0033b5a249ded449c1f7c40dd2f32d649b07ca7c2f6158a4b57cb5
3
+ size 1710980
tasks/models/frugal_2025-01-27/CNNEncoder_frugal_2.json ADDED
The diff for this file is too large to render. See raw diff
 
tasks/models/frugal_2025-01-27/frugal_kan_2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f520cff8b9531981e16a8b009b6a55fb8ca98573fc4d3dc6806df60b07a49c2
3
+ size 1710980
tasks/run.py CHANGED
@@ -9,6 +9,7 @@ import yaml
9
  import datetime
10
  import json
11
  import numpy as np
 
12
 
13
  # local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
  current_date = datetime.date.today().strftime("%Y-%m-%d")
@@ -56,6 +57,13 @@ model = CNNKan(model_args, conformer_args, kan_args.get_dict())
56
  # model.kan.speed()
57
  # model = KanEncoder(kan_args.get_dict())
58
  model = model.to(local_rank)
 
 
 
 
 
 
 
59
  # model = DDP(model, device_ids=[local_rank], output_device=local_rank)
60
  num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
61
  print(f"Number of parameters: {num_params}")
 
9
  import datetime
10
  import json
11
  import numpy as np
12
+ from collections import OrderedDict
13
 
14
  # local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
  current_date = datetime.date.today().strftime("%Y-%m-%d")
 
57
  # model.kan.speed()
58
  # model = KanEncoder(kan_args.get_dict())
59
  model = model.to(local_rank)
60
+ state_dict = torch.load(data_args.checkpoint_path, map_location=torch.device('cpu'))
61
+ new_state_dict = OrderedDict()
62
+ for key, value in state_dict.items():
63
+ if key.startswith('module.'):
64
+ key = key[7:]
65
+ new_state_dict[key] = value
66
+ missing, unexpected = model.load_state_dict(new_state_dict)
67
  # model = DDP(model, device_ids=[local_rank], output_device=local_rank)
68
  num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
69
  print(f"Number of parameters: {num_params}")
tasks/utils/config.yaml CHANGED
@@ -12,7 +12,7 @@ Data:
12
  max_days_lc: 270
13
  lc_freq: 0.0208
14
  create_umap: True
15
- checkpoint_path: 'tasks/models/frugal_2025-01-21/frugal_kan_2.pth'
16
 
17
  CNNEncoder:
18
  # Model
@@ -31,7 +31,7 @@ CNNEncoder:
31
  avg_output: False
32
 
33
  KAN:
34
- layers_hidden: [1125,32,8,8,1]
35
  grid_min: -1.2
36
  grid_max: 1.2
37
  num_grids: 8
 
12
  max_days_lc: 270
13
  lc_freq: 0.0208
14
  create_umap: True
15
+ checkpoint_path: 'tasks/models/frugal_2025-01-27/frugal_kan_2.pth'
16
 
17
  CNNEncoder:
18
  # Model
 
31
  avg_output: False
32
 
33
  KAN:
34
+ layers_hidden: [1125,32,8,1]
35
  grid_min: -1.2
36
  grid_max: 1.2
37
  num_grids: 8