suayptalha committed on
Commit
b08ebd4
·
verified ·
1 Parent(s): ae6d9ff

Update modeling_minGRULM.py

Browse files
Files changed (1) hide show
  1. modeling_minGRULM.py +11 -5
modeling_minGRULM.py CHANGED
@@ -61,13 +61,19 @@ class MinGRULMForCausalLM(MinGRULMPreTrainedModel):
61
 
62
  # Language modeling head
63
  self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
64
-
65
- # Copy weights instead of sharing them
66
- with torch.no_grad():
67
- self.lm_head.weight.data.copy_(self.model.min_gru_model.token_emb.weight.data)
68
-
69
  self.post_init()
70
 
 
 
 
 
 
 
 
 
 
 
71
  def get_input_embeddings(self):
72
  return self.model.min_gru_model.token_emb
73
 
 
61
 
62
  # Language modeling head
63
  self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
64
+
 
 
 
 
65
  self.post_init()
66
 
67
+ def post_init(self):
68
+ super().post_init()
69
+
70
+ # Ensure tied weights
71
+ self.tie_weights()
72
+
73
+ def tie_weights(self):
74
+ # Tie lm_head weights to the embedding layer weights
75
+ self.lm_head.weight = self.model.min_gru_model.token_emb.weight
76
+
77
  def get_input_embeddings(self):
78
  return self.model.min_gru_model.token_emb
79