suayptalha
committed on
Update modeling_minGRULM.py
modeling_minGRULM.py
CHANGED
+11 -7
@@ -99,16 +99,15 @@ class MinGRULMForCausalLM(PreTrainedModel):
         # Ensure that inputs for generation are properly handled
         return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}
 
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        labels: Optional[torch.LongTensor] = None,
-        return_dict: Optional[bool] = True,
-        **kwargs
-    ):
+    def forward(self, input_ids: torch.LongTensor, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = True, **kwargs):
         # Forward pass through the wrapped model
         logits = self.model(input_ids)
 
+        # NaN check: if the logits contain NaN values, replace them with zeros
+        if torch.isnan(logits).any():
+            print("NaN detected in logits! Replacing with zeros.")
+            logits = torch.nan_to_num(logits, nan=0.0)
+
         loss = None
         if labels is not None:
             shift_logits = logits[..., :-1, :].contiguous()
@@ -119,6 +118,11 @@ class MinGRULMForCausalLM(PreTrainedModel):
             shift_labels.view(-1),
         )
 
+        # NaN check: if the loss is NaN, replace it with zero
+        if torch.isnan(loss).any():
+            print("NaN detected in loss! Replacing with zeros.")
+            loss = torch.tensor(0.0, device=loss.device)  # zero out the NaN loss
+
         if not return_dict:
             return (loss, logits) if loss is not None else (logits,)
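A minimal, self-contained sketch of the NaN-guard pattern this commit adds around the logits and the loss. The toy tensors and standalone setting are invented for illustration and are not taken from the repository.

import torch

# Toy logits with a deliberately injected NaN (illustrative values only).
logits = torch.tensor([[0.5, float("nan")], [1.0, 2.0]])

# Guard 1: scrub NaNs out of the logits before they reach the loss.
if torch.isnan(logits).any():
    print("NaN detected in logits! Replacing with zeros.")
    logits = torch.nan_to_num(logits, nan=0.0)

# Guard 2: if the computed loss is itself NaN, fall back to a zero loss.
loss = torch.tensor(float("nan"))
if torch.isnan(loss).any():
    print("NaN detected in loss! Replacing with zeros.")
    loss = torch.tensor(0.0, device=loss.device)

print(logits)  # the NaN entry is now 0.0
print(loss)    # tensor(0.)

Zeroing a NaN loss keeps training from crashing, but the replacement tensor is created outside the autograd graph, so the affected step contributes no gradient update.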