suayptalha committed
Commit e55adf4 · verified · 1 Parent(s): a32544d

Update modeling_minGRULM.py

Files changed (1)
  1. modeling_minGRULM.py +3 -21
modeling_minGRULM.py CHANGED
@@ -8,7 +8,6 @@ from .configuration_minGRULM import MinGRULMConfig
 from minGRU_pytorch.minGRULM import minGRULM
 
 
-# Wrapper class for device compatibility
 class MinGRULMWrapped(nn.Module):
     def __init__(self, min_gru_model):
         super().__init__()
@@ -16,13 +15,11 @@ class MinGRULMWrapped(nn.Module):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     def forward(self, *args, **kwargs):
-        # Move input tensors to the correct device
         args = [arg.to(self.device) if isinstance(arg, torch.Tensor) else arg for arg in args]
         kwargs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
         return self.min_gru_model(*args, **kwargs)
 
     def to(self, device):
-        # Update device information
         self.device = device
         self.min_gru_model.to(device)
         return self
@@ -47,11 +44,10 @@ class MinGRULMPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-        # NaN check: if any parameter contains NaN, fix it using `torch.nan_to_num`
         for name, param in module.named_parameters():
             if torch.isnan(param).any():
                 print(f"NaN detected in parameter {name}. Replacing with a safe number.")
-                param.data = torch.nan_to_num(param.data, nan=1e-6)  # replace NaNs with 1e-6
+                param.data = torch.nan_to_num(param.data, nan=1e-6)
 
 
 class MinGRULMForCausalLM(PreTrainedModel):
@@ -61,7 +57,6 @@ class MinGRULMForCausalLM(PreTrainedModel):
     def __init__(self, config: MinGRULMConfig):
         super().__init__(config)
 
-        # Load model from minGRULM library and wrap it
         raw_min_gru = minGRULM(
             num_tokens=config.vocab_size,
             dim=config.d_model,
@@ -72,18 +67,15 @@ class MinGRULMForCausalLM(PreTrainedModel):
         )
         self.model = MinGRULMWrapped(raw_min_gru)
 
-        # Language modeling head
         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
 
         self.post_init()
 
     def post_init(self):
-        # Ensure tied weights and any additional setup
        super().post_init()
        self.tie_weights()
 
     def tie_weights(self):
-        # Tie lm_head weights to the embedding layer weights
         self.lm_head.weight = self.model.min_gru_model.token_emb.weight
 
     def get_input_embeddings(self):
@@ -96,17 +88,14 @@ class MinGRULMForCausalLM(PreTrainedModel):
         return self.lm_head
 
     def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs):
-        # Ensure that inputs for generation are properly handled
         return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}
 
     def forward(self, input_ids: torch.LongTensor, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = True, **kwargs):
-        # Forward pass through the wrapped model
         logits = self.model(input_ids)
 
-        # NaN check: if the logits contain NaN, fix them using `torch.nan_to_num`
         if torch.isnan(logits).any():
             print("NaN detected in logits! Replacing with a safe number.")
-            logits = torch.nan_to_num(logits, nan=1e-6)  # replace NaNs with 1e-6
+            logits = torch.nan_to_num(logits, nan=1e-6)
 
         loss = None
         if labels is not None:
@@ -118,10 +107,9 @@ class MinGRULMForCausalLM(PreTrainedModel):
                 shift_labels.view(-1),
             )
 
-            # NaN check: if the loss contains NaN, fix it using `torch.nan_to_num`
             if torch.isnan(loss).any():
                 print("NaN detected in loss! Replacing with a safe number.")
-                loss = torch.nan_to_num(loss, nan=1e-6)  # replace NaNs with 1e-6
+                loss = torch.nan_to_num(loss, nan=1e-6)
 
         if not return_dict:
             return (loss, logits) if loss is not None else (logits,)
@@ -148,15 +136,11 @@ class MinGRULMForCausalLM(PreTrainedModel):
             save_directory (str): Directory to save the model.
             safe_serialization (bool, optional): Whether to use safe serialization. Defaults to True.
         """
-        # Create the save directory if it doesn't exist
         os.makedirs(save_directory, exist_ok=True)
 
-        # Check if safe_serialization is enabled
         if safe_serialization:
             print("Saving with safe serialization.")
 
-            # Save the model's state_dict (model weights)
-            #state_dict = self.state_dict()
             state_dict = {}
 
             for name, param in self.model.min_gru_model.named_parameters():
@@ -168,9 +152,7 @@ class MinGRULMForCausalLM(PreTrainedModel):
             state_dict['config'] = self.config.__dict__
             torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
 
-            # Save the configuration
             self.config.save_pretrained(save_directory)
         else:
             print("Saving without safe serialization.")
-            # If not safe_serialization, use the default save mechanism from the base class
             super().save_pretrained(save_directory)
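
The commit strips comments but keeps the NaN guards intact. Below is a minimal standalone sketch of that guard pattern; the tensor values are made up for the demo, and only the `torch.isnan` check and `torch.nan_to_num(..., nan=1e-6)` call mirror the diff.

import torch

# Illustrative logits containing a NaN (hypothetical values).
logits = torch.tensor([[0.5, float("nan"), -1.2]])

if torch.isnan(logits).any():
    print("NaN detected in logits! Replacing with a safe number.")
    logits = torch.nan_to_num(logits, nan=1e-6)  # NaN -> 1e-6, as in the diff

print(logits)  # the NaN entry is now 1e-6; finite values are unchanged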
 
8
  from minGRU_pytorch.minGRULM import minGRULM
9
 
10
 
 
11
  class MinGRULMWrapped(nn.Module):
12
  def __init__(self, min_gru_model):
13
  super().__init__()
 
15
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
  def forward(self, *args, **kwargs):
 
18
  args = [arg.to(self.device) if isinstance(arg, torch.Tensor) else arg for arg in args]
19
  kwargs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
20
  return self.min_gru_model(*args, **kwargs)
21
 
22
  def to(self, device):
 
23
  self.device = device
24
  self.min_gru_model.to(device)
25
  return self
 
44
  module.bias.data.zero_()
45
  module.weight.data.fill_(1.0)
46
 
 
47
  for name, param in module.named_parameters():
48
  if torch.isnan(param).any():
49
  print(f"NaN detected in parameter {name}. Replacing with a safe number.")
50
+ param.data = torch.nan_to_num(param.data, nan=1e-6)
51
 
52
 
53
  class MinGRULMForCausalLM(PreTrainedModel):
 
57
  def __init__(self, config: MinGRULMConfig):
58
  super().__init__(config)
59
 
 
60
  raw_min_gru = minGRULM(
61
  num_tokens=config.vocab_size,
62
  dim=config.d_model,
 
67
  )
68
  self.model = MinGRULMWrapped(raw_min_gru)
69
 
 
70
  self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
71
 
72
  self.post_init()
73
 
74
  def post_init(self):
 
75
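For context, the `tie_weights` method shown above shares a single weight matrix between the token embedding and the LM head. A minimal sketch of that pattern, with hypothetical layer sizes (the real model ties `lm_head.weight` to `token_emb.weight`):

import torch.nn as nn

vocab_size, d_model = 100, 16  # hypothetical sizes for the demo

token_emb = nn.Embedding(vocab_size, d_model)          # weight: (vocab_size, d_model)
lm_head = nn.Linear(d_model, vocab_size, bias=False)   # weight: (vocab_size, d_model)

# Tying: both modules now hold the same Parameter object,
# so any update to one is immediately seen by the other.
lm_head.weight = token_emb.weight
assert lm_head.weight is token_emb.weight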
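The custom `save_pretrained` path builds a plain state dict from `named_parameters` and serializes it with `torch.save`. A rough standalone equivalent, using a toy module in place of the wrapped minGRULM and a hypothetical output path:

import os
import torch
import torch.nn as nn

model = nn.Linear(8, 8)          # toy stand-in for self.model.min_gru_model
save_directory = "./checkpoint"  # hypothetical path

os.makedirs(save_directory, exist_ok=True)

# Collect parameters by name and move them to CPU before saving,
# mirroring the manual serialization path in save_pretrained.
state_dict = {name: param.detach().cpu() for name, param in model.named_parameters()}
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))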