suayptalha committed
Update modeling_minGRULM.py

modeling_minGRULM.py CHANGED (+3 -21)
@@ -8,7 +8,6 @@ from .configuration_minGRULM import MinGRULMConfig
 from minGRU_pytorch.minGRULM import minGRULM


-# Wrapper class for device compatibility
 class MinGRULMWrapped(nn.Module):
     def __init__(self, min_gru_model):
         super().__init__()
@@ -16,13 +15,11 @@ class MinGRULMWrapped(nn.Module):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     def forward(self, *args, **kwargs):
-        # Move input tensors to the correct device
         args = [arg.to(self.device) if isinstance(arg, torch.Tensor) else arg for arg in args]
         kwargs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
         return self.min_gru_model(*args, **kwargs)

     def to(self, device):
-        # Update device information
         self.device = device
         self.min_gru_model.to(device)
         return self
@@ -47,11 +44,10 @@ class MinGRULMPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

-        # NaN check: if any parameter contains NaN, repair it with torch.nan_to_num
         for name, param in module.named_parameters():
             if torch.isnan(param).any():
                 print(f"NaN detected in parameter {name}. Replacing with a safe number.")
-                param.data = torch.nan_to_num(param.data, nan=1e-6)
+                param.data = torch.nan_to_num(param.data, nan=1e-6)


 class MinGRULMForCausalLM(PreTrainedModel):
@@ -61,7 +57,6 @@ class MinGRULMForCausalLM(PreTrainedModel):
     def __init__(self, config: MinGRULMConfig):
         super().__init__(config)

-        # Load model from minGRULM library and wrap it
         raw_min_gru = minGRULM(
             num_tokens=config.vocab_size,
             dim=config.d_model,
@@ -72,18 +67,15 @@ class MinGRULMForCausalLM(PreTrainedModel):
         )
         self.model = MinGRULMWrapped(raw_min_gru)

-        # Language modeling head
         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

         self.post_init()

     def post_init(self):
-        # Ensure tied weights and any additional setup
         super().post_init()
         self.tie_weights()

     def tie_weights(self):
-        # Tie lm_head weights to the embedding layer weights
         self.lm_head.weight = self.model.min_gru_model.token_emb.weight

     def get_input_embeddings(self):
@@ -96,17 +88,14 @@ class MinGRULMForCausalLM(PreTrainedModel):
         return self.lm_head

     def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs):
-        # Ensure that inputs for generation are properly handled
         return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}

     def forward(self, input_ids: torch.LongTensor, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = True, **kwargs):
-        # Forward pass through the wrapped model
         logits = self.model(input_ids)

-        # NaN check: if the logits contain NaN, repair them with torch.nan_to_num
         if torch.isnan(logits).any():
             print("NaN detected in logits! Replacing with a safe number.")
-            logits = torch.nan_to_num(logits, nan=1e-6)
+            logits = torch.nan_to_num(logits, nan=1e-6)

         loss = None
         if labels is not None:
@@ -118,10 +107,9 @@ class MinGRULMForCausalLM(PreTrainedModel):
                 shift_labels.view(-1),
             )

-            # NaN check: if the loss contains NaN, repair it with torch.nan_to_num
             if torch.isnan(loss).any():
                 print("NaN detected in loss! Replacing with a safe number.")
-                loss = torch.nan_to_num(loss, nan=1e-6)
+                loss = torch.nan_to_num(loss, nan=1e-6)

         if not return_dict:
             return (loss, logits) if loss is not None else (logits,)
@@ -148,15 +136,11 @@ class MinGRULMForCausalLM(PreTrainedModel):
             save_directory (str): Directory to save the model.
             safe_serialization (bool, optional): Whether to use safe serialization. Defaults to True.
         """
-        # Create the save directory if it doesn't exist
         os.makedirs(save_directory, exist_ok=True)

-        # Check if safe_serialization is enabled
         if safe_serialization:
             print("Saving with safe serialization.")

-            # Save the model's state_dict (model weights)
-            #state_dict = self.state_dict()
             state_dict = {}

             for name, param in self.model.min_gru_model.named_parameters():
@@ -168,9 +152,7 @@ class MinGRULMForCausalLM(PreTrainedModel):
             state_dict['config'] = self.config.__dict__
             torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))

-            # Save the configuration
             self.config.save_pretrained(save_directory)
         else:
             print("Saving without safe serialization.")
-            # If not safe_serialization, use the default save mechanism from the base class
             super().save_pretrained(save_directory)
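The three surviving `+` lines all use the same guard: torch.isnan to detect corrupted values, then torch.nan_to_num to overwrite them. A minimal standalone sketch of that pattern (the tensor here is made up; it only stands in for the model's real logits):

import torch

logits = torch.tensor([1.0, float("nan"), -2.0])  # stand-in for real logits

if torch.isnan(logits).any():
    print("NaN detected in logits! Replacing with a safe number.")
    logits = torch.nan_to_num(logits, nan=1e-6)  # NaN entries become 1e-6

print(logits)  # tensor([ 1.0000e+00,  1.0000e-06, -2.0000e+00])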
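tie_weights assigns the embedding matrix directly to lm_head, so the two modules share one Parameter instead of holding separate copies. A sketch of the same idea with stand-in sizes (vocab 100 and width 16 are arbitrary):

import torch.nn as nn

token_emb = nn.Embedding(100, 16)         # weight shape: (100, 16)
lm_head = nn.Linear(16, 100, bias=False)  # weight shape: (100, 16)

# Same assignment as tie_weights above: one shared tensor, so an update
# made through either module is seen by both.
lm_head.weight = token_emb.weight
assert lm_head.weight is token_emb.weight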
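Because the safe-serialization branch writes a plain dict of parameter tensors plus a 'config' entry rather than a standard state_dict, reading the file back needs a matching loader. A sketch under that assumption; the directory name is hypothetical, while "pytorch_model.bin" and the 'config' key come from the code above:

import os
import torch

save_directory = "minGRULM-checkpoint"  # hypothetical path

checkpoint = torch.load(os.path.join(save_directory, "pytorch_model.bin"))
config_dict = checkpoint.pop("config")  # plain dict of config attributes

# What remains maps the inner minGRULM parameter names to tensors.
for name, tensor in checkpoint.items():
    print(name, tuple(tensor.shape))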