suayptalha committed on
Commit
2077bda
·
verified ·
1 Parent(s): 8165af0

Update modeling_minGRULM.py

Browse files
Files changed (1) hide show
  1. modeling_minGRULM.py +11 -7
modeling_minGRULM.py CHANGED
@@ -99,16 +99,15 @@ class MinGRULMForCausalLM(PreTrainedModel):
99
  # Ensure that inputs for generation are properly handled
100
  return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}
101
 
102
- def forward(
103
- self,
104
- input_ids: torch.LongTensor,
105
- labels: Optional[torch.LongTensor] = None,
106
- return_dict: Optional[bool] = True,
107
- **kwargs
108
- ):
109
  # Forward pass through the wrapped model
110
  logits = self.model(input_ids)
111
 
 
 
 
 
 
112
  loss = None
113
  if labels is not None:
114
  shift_logits = logits[..., :-1, :].contiguous()
@@ -119,6 +118,11 @@ class MinGRULMForCausalLM(PreTrainedModel):
119
  shift_labels.view(-1),
120
  )
121
 
 
 
 
 
 
122
  if not return_dict:
123
  return (loss, logits) if loss is not None else (logits,)
124
 
 
99
  # Ensure that inputs for generation are properly handled
100
  return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}
101
 
102
+ def forward(self, input_ids: torch.LongTensor, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = True, **kwargs):
 
 
 
 
 
 
103
  # Forward pass through the wrapped model
104
  logits = self.model(input_ids)
105
 
106
+ # NaN kontrolü: Eğer logits'te NaN varsa, sıfırlama
107
+ if torch.isnan(logits).any():
108
+ print("NaN detected in logits! Replacing with zeros.")
109
+ logits = torch.nan_to_num(logits, nan=0.0)
110
+
111
  loss = None
112
  if labels is not None:
113
  shift_logits = logits[..., :-1, :].contiguous()
 
118
  shift_labels.view(-1),
119
  )
120
 
121
+ # NaN kontrolü: Eğer loss'ta NaN varsa, sıfırlama
122
+ if torch.isnan(loss).any():
123
+ print("NaN detected in loss! Replacing with zeros.")
124
+ loss = torch.tensor(0.0, device=loss.device) # NaN olan loss'u sıfırlıyoruz
125
+
126
  if not return_dict:
127
  return (loss, logits) if loss is not None else (logits,)
128