suayptalha commited on
Commit
f2f029b
·
verified ·
1 Parent(s): a9cdf5f

Update modeling_minGRULM.py

Browse files
Files changed (1) hide show
  1. modeling_minGRULM.py +13 -9
modeling_minGRULM.py CHANGED
@@ -130,18 +130,22 @@ class MinGRULMForCausalLM(PreTrainedModel):
130
  def save_pretrained(self, save_directory, safe_serialization: Optional[bool] = True):
131
  """
132
  Save the model and configuration to a directory.
133
-
134
  Args:
135
  save_directory (str): Directory to save the model.
136
  safe_serialization (bool, optional): Whether to use safe serialization. Defaults to True.
137
  """
 
 
 
 
 
 
 
 
 
 
138
  if safe_serialization:
139
- print("true")
140
- # Save the model's state_dict (model weights)
141
- torch.save(self.state_dict(), f"{save_directory}/pytorch_model.bin")
142
- # Save the configuration to a file
143
- self.config.save_pretrained(save_directory)
144
  else:
145
- print("false")
146
- # If not safe_serialization, you can use the default save mechanism from the base class
147
- super().save_pretrained(save_directory)
 
130
  def save_pretrained(self, save_directory, safe_serialization: Optional[bool] = True):
131
  """
132
  Save the model and configuration to a directory.
133
+
134
  Args:
135
  save_directory (str): Directory to save the model.
136
  safe_serialization (bool, optional): Whether to use safe serialization. Defaults to True.
137
  """
138
+ # Create the save directory if it doesn't exist
139
+ os.makedirs(save_directory, exist_ok=True)
140
+
141
+ # Save the model's state_dict (model weights)
142
+ torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
143
+
144
+ # Save the configuration
145
+ self.config.save_pretrained(save_directory)
146
+
147
+ # Optionally print messages based on the safe_serialization flag
148
  if safe_serialization:
149
+ print("Model and configuration have been saved safely.")
 
 
 
 
150
  else:
151
+ print("Model and configuration have been saved (unsafe serialization).")