suayptalha committed (verified)
Commit 498e304 · 1 Parent(s): 2f89d54

Update modeling_minGRULM.py

Files changed (1)
  1. modeling_minGRULM.py +12 -1
modeling_minGRULM.py CHANGED
@@ -93,7 +93,7 @@ class MinGRULMForCausalLM(PreTrainedModel):
         # Ensure that inputs for generation are properly handled
         return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}
 
-    def forward(
+    def forward(
         self,
         input_ids: torch.LongTensor,
         labels: Optional[torch.LongTensor] = None,
@@ -103,6 +103,11 @@ class MinGRULMForCausalLM(PreTrainedModel):
         # Forward pass through the wrapped model
         logits = self.model(input_ids)
 
+        # NaN check
+        if torch.isnan(logits).any():
+            print("NaN detected in logits! Replacing with zeros.")
+            logits = torch.nan_to_num(logits, nan=0.0)  # replace NaNs with zeros
+
         loss = None
         if labels is not None:
             shift_logits = logits[..., :-1, :].contiguous()
@@ -113,6 +118,11 @@ class MinGRULMForCausalLM(PreTrainedModel):
                 shift_labels.view(-1),
             )
 
+        # Apply the same NaN handling to the loss
+        if torch.isnan(loss).any():
+            print("NaN detected in loss! Setting loss to 0.")
+            loss = torch.tensor(0.0, device=loss.device)  # replace a NaN loss with zero
+
         if not return_dict:
             return (loss, logits) if loss is not None else (logits,)
 
@@ -121,6 +131,7 @@ class MinGRULMForCausalLM(PreTrainedModel):
             logits=logits,
         )
 
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         """