suayptalha committed
Commit 109fa21 · verified · 1 Parent(s): f2f029b

Update modeling_minGRULM.py

Files changed (1):
  1. modeling_minGRULM.py +34 -10
modeling_minGRULM.py CHANGED
@@ -28,6 +28,7 @@ class MinGRULMWrapped(nn.Module):
         return self
 
 
+
 class MinGRULMPreTrainedModel(PreTrainedModel):
     config_class = MinGRULMConfig
     base_model_prefix = "model"
@@ -45,6 +46,7 @@ class MinGRULMPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+
 
 class MinGRULMForCausalLM(PreTrainedModel):
     config_class = MinGRULMConfig
@@ -119,6 +121,25 @@ class MinGRULMForCausalLM(PreTrainedModel):
             logits=logits,
         )
 
+    def state_dict(self):
+        """
+        Custom state_dict function to return the model's state dict.
+        This includes the wrapped model and any extra components like the language model head.
+        """
+        state_dict = {}
+
+        # Add min_gru_model's state_dict
+        state_dict['model'] = self.model.min_gru_model.state_dict()
+
+        # Add lm_head's state_dict
+        state_dict['lm_head'] = self.lm_head.state_dict()
+
+        # Optionally, add config if needed
+        state_dict['config'] = self.config.state_dict()
+
+        return state_dict
+
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         """
@@ -137,15 +158,18 @@ class MinGRULMForCausalLM(PreTrainedModel):
         """
         # Create the save directory if it doesn't exist
         os.makedirs(save_directory, exist_ok=True)
-
-        # Save the model's state_dict (model weights)
-        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
-
-        # Save the configuration
-        self.config.save_pretrained(save_directory)
-
-        # Optionally print messages based on the safe_serialization flag
+
+        # Check if safe_serialization is enabled
         if safe_serialization:
-            print("Model and configuration have been saved safely.")
+            print("Saving with safe serialization.")
+
+            # Save the model's state_dict (model weights)
+            state_dict = self.state_dict()
+            torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
+
+            # Save the configuration
+            self.config.save_pretrained(save_directory)
         else:
-            print("Model and configuration have been saved (unsafe serialization).")
+            print("Saving without safe serialization.")
+            # If not safe_serialization, use the default save mechanism from the base class
+            super().save_pretrained(save_directory)
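
For reference, a minimal loading sketch (not part of the commit) showing how a checkpoint written by the new save_pretrained path with safe_serialization=True could be restored. The import path for MinGRULMConfig and the checkpoint directory are assumptions; the nested keys simply mirror the custom state_dict() added above.

# Minimal sketch, assuming the checkpoint was produced by
# save_pretrained(save_directory, safe_serialization=True) from this commit.
import os

import torch

from configuration_minGRULM import MinGRULMConfig   # assumed module name
from modeling_minGRULM import MinGRULMForCausalLM

save_directory = "./minGRULM-checkpoint"  # hypothetical path

# Rebuild the model from its saved configuration.
config = MinGRULMConfig.from_pretrained(save_directory)
model = MinGRULMForCausalLM(config)

# The custom state_dict() stores nested sub-dicts under 'model' and 'lm_head',
# so each component is restored separately rather than via one flat load_state_dict.
checkpoint = torch.load(os.path.join(save_directory, "pytorch_model.bin"), map_location="cpu")
model.model.min_gru_model.load_state_dict(checkpoint["model"])
model.lm_head.load_state_dict(checkpoint["lm_head"])
model.eval()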