suayptalha commited on
Commit
c11afbe
·
verified ·
1 Parent(s): 44cb4e5

Update modeling_minGRULM.py

Browse files
Files changed (1) hide show
  1. modeling_minGRULM.py +11 -19
modeling_minGRULM.py CHANGED
@@ -121,24 +121,6 @@ class MinGRULMForCausalLM(PreTrainedModel):
121
  logits=logits,
122
  )
123
 
124
- def state_dict(self):
125
- """
126
- Custom state_dict function to return the model's state dict.
127
- This includes the wrapped model and any extra components like the language model head.
128
- """
129
- state_dict = {}
130
-
131
- # Add min_gru_model's state_dict
132
- state_dict['model'] = self.model.min_gru_model.state_dict()
133
-
134
- # Add lm_head's state_dict
135
- state_dict['lm_head'] = self.lm_head.state_dict()
136
-
137
- # Add config as a dictionary (not state_dict, since it is not available)
138
- state_dict['config'] = self.config.__dict__
139
-
140
- return state_dict
141
-
142
  @classmethod
143
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
144
  """
@@ -163,7 +145,17 @@ class MinGRULMForCausalLM(PreTrainedModel):
163
  print("Saving with safe serialization.")
164
 
165
  # Save the model's state_dict (model weights)
166
- state_dict = self.state_dict()
 
 
 
 
 
 
 
 
 
 
167
  torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
168
 
169
  # Save the configuration
 
121
  logits=logits,
122
  )
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  @classmethod
125
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
126
  """
 
145
  print("Saving with safe serialization.")
146
 
147
  # Save the model's state_dict (model weights)
148
+ #state_dict = self.state_dict()
149
+ state_dict = {}
150
+
151
+ # Add min_gru_model's state_dict
152
+ state_dict['model'] = self.model.min_gru_model.state_dict()
153
+
154
+ # Add lm_head's state_dict
155
+ state_dict['lm_head'] = self.lm_head.state_dict()
156
+
157
+ # Add config as a dictionary (not state_dict, since it is not available)
158
+ state_dict['config'] = self.config.__dict__
159
  torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
160
 
161
  # Save the configuration