add debug flag
modeling_llama.py (+4 -3)
@@ -195,9 +195,10 @@ class LoRAMoeLayer(torch.nn.Module):
         weights, selected_experts = torch.topk(
             gate_logits, self.num_experts_per_tok
         ) # b,s,n
-        if
-
-
+        if hasattr(config, "show_debug") and config["show_debug"] == True:
+            if self._layer_idx == 0 or self._layer_idx == 16 or self._layer_idx == 31:
+                print(f"{self._name}_{self._layer_idx}: {selected_experts}")
+                print("-"*10)
         weights = F.softmax(weights, dim=2, dtype=torch.float).to(
             previous_dtype
         ) # b,s,n
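For context, here is a minimal, self-contained sketch of the patched routing path. Everything the diff does not show is an assumption: the method name `route`, the constructor signature, and the idea that `config` is a plain dict. Note that the committed check mixes attribute access (`hasattr(config, "show_debug")`) with mapping access (`config["show_debug"]`); the sketch sides with the mapping style and uses `dict.get`, which is a deliberate swap rather than the commit's exact code.

import torch
import torch.nn.functional as F

class LoRAMoeLayer(torch.nn.Module):
    # Hypothetical constructor: the diff only shows the routing lines,
    # so these arguments and defaults are assumptions for the sketch.
    def __init__(self, config, name, layer_idx, num_experts=8, num_experts_per_tok=2):
        super().__init__()
        self.config = config
        self._name = name
        self._layer_idx = layer_idx
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok

    def route(self, gate_logits):  # hypothetical method name
        previous_dtype = gate_logits.dtype
        # Keep the k largest gate logits per token: shapes (batch, seq, k).
        weights, selected_experts = torch.topk(
            gate_logits, self.num_experts_per_tok
        )  # b,s,n
        # The commit's debug flag, rewritten with consistent dict-style
        # access (the commit mixes hasattr and config[...] subscripting).
        if self.config.get("show_debug", False):
            # Print only at layers 0, 16, and 31, presumably the first,
            # middle, and last layers of a 32-layer model.
            if self._layer_idx in (0, 16, 31):
                print(f"{self._name}_{self._layer_idx}: {selected_experts}")
                print("-" * 10)
        # Renormalize the kept logits in float32, then cast back.
        weights = F.softmax(weights, dim=2, dtype=torch.float).to(
            previous_dtype
        )  # b,s,n
        return weights, selected_experts

# Usage: a dict config with the new flag enabled.
cfg = {"show_debug": True}
layer = LoRAMoeLayer(cfg, name="moe", layer_idx=0)
gate_logits = torch.randn(2, 4, layer.num_experts)  # (batch, seq, experts)
weights, experts = layer.route(gate_logits)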