suayptalha commited on
Commit
2f89d54
·
verified ·
1 Parent(s): c11afbe

Update modeling_minGRULM.py

Browse files
Files changed (1) hide show
  1. modeling_minGRULM.py +6 -7
modeling_minGRULM.py CHANGED
@@ -147,14 +147,13 @@ class MinGRULMForCausalLM(PreTrainedModel):
147
  # Save the model's state_dict (model weights)
148
  #state_dict = self.state_dict()
149
  state_dict = {}
150
-
151
- # Add min_gru_model's state_dict
152
- state_dict['model'] = self.model.min_gru_model.state_dict()
153
-
154
- # Add lm_head's state_dict
155
- state_dict['lm_head'] = self.lm_head.state_dict()
156
 
157
- # Add config as a dictionary (not state_dict, since it is not available)
 
 
158
  state_dict['config'] = self.config.__dict__
159
  torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
160
 
 
147
  # Save the model's state_dict (model weights)
148
  #state_dict = self.state_dict()
149
  state_dict = {}
150
+
151
+ for name, param in self.model.min_gru_model.named_parameters():
152
+ state_dict[f"model.{name}"] = param
 
 
 
153
 
154
+ for name, param in self.lm_head.named_parameters():
155
+ state_dict[f"lm_head.{name}"] = param
156
+
157
  state_dict['config'] = self.config.__dict__
158
  torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
159