suayptalha committed
Update modeling_minGRULM.py

modeling_minGRULM.py CHANGED (+34 -10)
@@ -28,6 +28,7 @@ class MinGRULMWrapped(nn.Module):
         return self
 
 
+
 class MinGRULMPreTrainedModel(PreTrainedModel):
     config_class = MinGRULMConfig
     base_model_prefix = "model"
@@ -45,6 +46,7 @@ class MinGRULMPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+
 
 class MinGRULMForCausalLM(PreTrainedModel):
     config_class = MinGRULMConfig
@@ -119,6 +121,25 @@ class MinGRULMForCausalLM(PreTrainedModel):
             logits=logits,
         )
 
+    def state_dict(self):
+        """
+        Custom state_dict function to return the model's state dict.
+        This includes the wrapped model and any extra components like the language model head.
+        """
+        state_dict = {}
+
+        # Add min_gru_model's state_dict
+        state_dict['model'] = self.model.min_gru_model.state_dict()
+
+        # Add lm_head's state_dict
+        state_dict['lm_head'] = self.lm_head.state_dict()
+
+        # Optionally, add config if needed
+        state_dict['config'] = self.config.state_dict()
+
+        return state_dict
+
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         """
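The state_dict() added above returns a nested dict keyed 'model', 'lm_head', and 'config' rather than PyTorch's usual flat parameter mapping, so the checkpoint it produces cannot be passed directly to load_state_dict on the whole model. Below is a minimal loading sketch under that assumption; the helper name is ours, not part of the commit. Note also that transformers' PretrainedConfig exposes to_dict() rather than a state_dict() method, so the 'config' entry may need that accessor instead.

import os
import torch

def load_nested_checkpoint(model, save_directory):
    # Illustrative helper (not part of the commit): unpack the nested
    # checkpoint layout written by the custom state_dict() above.
    checkpoint = torch.load(
        os.path.join(save_directory, "pytorch_model.bin"), map_location="cpu"
    )
    # 'model' holds the wrapped minGRU weights, 'lm_head' the output head.
    model.model.min_gru_model.load_state_dict(checkpoint["model"])
    model.lm_head.load_state_dict(checkpoint["lm_head"])
    # The 'config' entry is informational only; the configuration itself is
    # normally restored from config.json by from_pretrained.
    return model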
@@ -137,15 +158,18 @@ class MinGRULMForCausalLM(PreTrainedModel):
         """
         # Create the save directory if it doesn't exist
         os.makedirs(save_directory, exist_ok=True)
-
-        #
-        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
-
-        # Save the configuration
-        self.config.save_pretrained(save_directory)
-
-        # Optionally print messages based on the safe_serialization flag
+
+        # Check if safe_serialization is enabled
         if safe_serialization:
-            print("
+            print("Saving with safe serialization.")
+
+            # Save the model's state_dict (model weights)
+            state_dict = self.state_dict()
+            torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
+
+            # Save the configuration
+            self.config.save_pretrained(save_directory)
         else:
-            print("
+            print("Saving without safe serialization.")
+            # If not safe_serialization, use the default save mechanism from the base class
+            super().save_pretrained(save_directory)
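As revised, save_pretrained only writes the checkpoint itself when safe_serialization is truthy and otherwise defers entirely to the base class. A brief usage sketch follows; the import path and the default construction are assumptions based on the file name in this commit, not confirmed by the diff.

# Usage sketch (hypothetical import path and constructor arguments):
from modeling_minGRULM import MinGRULMForCausalLM, MinGRULMConfig

model = MinGRULMForCausalLM(MinGRULMConfig())
model.save_pretrained("./minGRULM-checkpoint", safe_serialization=True)
# -> prints "Saving with safe serialization." and writes
#    pytorch_model.bin (the nested state dict) plus config.json

Note that despite the flag's name, the True branch still serializes with torch.save rather than the safetensors format, while the False branch falls back to PreTrainedModel.save_pretrained, which in recent transformers versions itself defaults to safetensors output.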