suayptalha committed
Commit 783bbd7 · verified · 1 Parent(s): 8042f2b

Update modeling_minGRULM.py

Files changed (1):
  1. modeling_minGRULM.py +21 -6
modeling_minGRULM.py CHANGED

@@ -13,7 +13,7 @@ class MinGRULMWrapped(nn.Module):
     def __init__(self, min_gru_model):
         super().__init__()
         self.min_gru_model = min_gru_model
-        self.device = torch.device("cuda")  # Default device
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     def forward(self, *args, **kwargs):
         # Move input tensors to the correct device
@@ -45,9 +45,11 @@ class MinGRULMPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-
-class MinGRULMForCausalLM(MinGRULMPreTrainedModel):
+class MinGRULMForCausalLM(PreTrainedModel):
+    config_class = MinGRULMConfig
+    base_model_prefix = "model"
+
     def __init__(self, config: MinGRULMConfig):
         super().__init__(config)
 
@@ -68,9 +70,8 @@ class MinGRULMForCausalLM(MinGRULMPreTrainedModel):
         self.post_init()
 
     def post_init(self):
+        # Ensure tied weights and any additional setup
        super().post_init()
-
-        # Ensure tied weights
         self.tie_weights()
 
     def tie_weights(self):
@@ -116,4 +117,18 @@ class MinGRULMForCausalLM(MinGRULMPreTrainedModel):
         return CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
-        )
+        )
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        """
+        Load model from a pretrained checkpoint.
+        """
+        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        return model
+
+    def save_pretrained(self, save_directory):
+        """
+        Save the model and configuration to a directory.
+        """
+        super().save_pretrained(save_directory)
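
The device change in the first hunk is the main portability fix: a hard-coded torch.device("cuda") breaks on machines without a GPU. A minimal sketch of the guarded pattern, using a hypothetical stand-in module rather than the real MinGRULMWrapped:

    import torch
    import torch.nn as nn

    class WrappedSketch(nn.Module):
        """Illustrative stand-in for MinGRULMWrapped; not the actual class."""
        def __init__(self, inner: nn.Module):
            super().__init__()
            self.inner = inner
            # Pattern from the commit: prefer CUDA, but fall back to CPU so the
            # wrapper can be constructed and run on CUDA-less machines.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Mirror the wrapper's behavior of moving inputs onto its device.
            return self.inner(x.to(self.device))

    wrapped = WrappedSketch(nn.Linear(8, 8)).to(
        torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )
    print(wrapped(torch.randn(2, 8)).shape)  # torch.Size([2, 8])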
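
The new from_pretrained/save_pretrained overrides simply delegate to the PreTrainedModel base implementations, so a checkpoint round trip is the natural smoke test. A hedged usage sketch: the import assumes modeling_minGRULM.py is importable, and the paths are hypothetical placeholders.

    # Round trip through the overridden save_pretrained / from_pretrained.
    from modeling_minGRULM import MinGRULMForCausalLM

    model = MinGRULMForCausalLM.from_pretrained("path/to/checkpoint")  # hypothetical path
    model.save_pretrained("path/to/save_dir")                          # hypothetical path
    model = MinGRULMForCausalLM.from_pretrained("path/to/save_dir")

One design note: the save_pretrained override narrows the signature to a single save_directory argument, so keyword arguments the base method accepts (e.g. safe_serialization) will not pass through this override.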