suayptalha
committed on
Update modeling_minGRULM.py
Browse files- modeling_minGRULM.py +12 -1
modeling_minGRULM.py
CHANGED
@@ -93,7 +93,7 @@ class MinGRULMForCausalLM(PreTrainedModel):
|
|
93 |
# Ensure that inputs for generation are properly handled
|
94 |
return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}
|
95 |
|
96 |
-
|
97 |
self,
|
98 |
input_ids: torch.LongTensor,
|
99 |
labels: Optional[torch.LongTensor] = None,
|
@@ -103,6 +103,11 @@ class MinGRULMForCausalLM(PreTrainedModel):
|
|
103 |
# Forward pass through the wrapped model
|
104 |
logits = self.model(input_ids)
|
105 |
|
|
|
|
|
|
|
|
|
|
|
106 |
loss = None
|
107 |
if labels is not None:
|
108 |
shift_logits = logits[..., :-1, :].contiguous()
|
@@ -113,6 +118,11 @@ class MinGRULMForCausalLM(PreTrainedModel):
|
|
113 |
shift_labels.view(-1),
|
114 |
)
|
115 |
|
|
|
|
|
|
|
|
|
|
|
116 |
if not return_dict:
|
117 |
return (loss, logits) if loss is not None else (logits,)
|
118 |
|
@@ -121,6 +131,7 @@ class MinGRULMForCausalLM(PreTrainedModel):
|
|
121 |
logits=logits,
|
122 |
)
|
123 |
|
|
|
124 |
@classmethod
|
125 |
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
126 |
"""
|
|
|
93 |
# Ensure that inputs for generation are properly handled
|
94 |
return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}
|
95 |
|
96 |
+
def forward(
|
97 |
self,
|
98 |
input_ids: torch.LongTensor,
|
99 |
labels: Optional[torch.LongTensor] = None,
|
|
|
103 |
# Forward pass through the wrapped model
|
104 |
logits = self.model(input_ids)
|
105 |
|
106 |
+
# NaN check
|
107 |
+
if torch.isnan(logits).any():
|
108 |
+
print("NaN detected in logits! Replacing with zeros.")
|
109 |
+
logits = torch.nan_to_num(logits, nan=0.0) # NaN'ları sıfırla değiştir
|
110 |
+
|
111 |
loss = None
|
112 |
if labels is not None:
|
113 |
shift_logits = logits[..., :-1, :].contiguous()
|
|
|
118 |
shift_labels.view(-1),
|
119 |
)
|
120 |
|
121 |
+
# Apply the same NaN check to the loss as well
|
122 |
+
if torch.isnan(loss).any():
|
123 |
+
print("NaN detected in loss! Setting loss to 0.")
|
124 |
+
loss = torch.tensor(0.0, device=loss.device) # NaN olan loss'u sıfırla değiştir
|
125 |
+
|
126 |
if not return_dict:
|
127 |
return (loss, logits) if loss is not None else (logits,)
|
128 |
|
|
|
131 |
logits=logits,
|
132 |
)
|
133 |
|
134 |
+
|
135 |
@classmethod
|
136 |
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
137 |
"""
|