if001 committed on
Commit 7a518b4 · 1 Parent(s): 23ebed3
Files changed (1)
  1. modeling_llama.py +2 -5
modeling_llama.py CHANGED
@@ -69,7 +69,6 @@ def convert_linear_to_moe(
     in_features: int,
     out_features: int,
     bias: bool = True,
-    show_debug: bool = False,
 ):
     """Converts nn.Linear to MoeLayer
     Args:
@@ -89,7 +88,6 @@ def convert_linear_to_moe(
             bias=bias,
             name=name,
             layer_idx=layer_idx,
-            show_debug=show_debug
         )
     else:
         return MoeLayer(
@@ -145,7 +143,7 @@ class MoeLayer(nn.Module):
         return results
 
 class LoRAMoeLayer(torch.nn.Module):
-    def __init__(self, config, in_features, out_features, bias, name = "", layer_idx = -1, show_debug=False) -> None:
+    def __init__(self, config, in_features, out_features, bias, name = "", layer_idx = -1) -> None:
         super().__init__()
 
         self.config = config
@@ -195,8 +193,7 @@ class LoRAMoeLayer(torch.nn.Module):
         weights, selected_experts = torch.topk(
             gate_logits, self.num_experts_per_tok
         ) # b,s,n
-        if hasattr(config, "show_debug") and config["show_debug"] == True:
-            if self._layer_idx == 0 or self._layer_idx == 16 or self._layer_idx == 31:
+        if getattr(self.config, "show_debug", False) and self._layer_idx == 0 or self._layer_idx == 16 or self._layer_idx == 31:
                 print(f"{self._name}_{self._layer_idx}: {selected_experts}")
                 print("-"*10)
         weights = F.softmax(weights, dim=2, dtype=torch.float).to(
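
In effect, the commit stops threading show_debug through convert_linear_to_moe and LoRAMoeLayer.__init__ and instead reads the flag off the model config at print time. Below is a minimal sketch of that gating as a standalone helper, for illustration only: maybe_print_routing is a hypothetical name, and the explicit `in (0, 16, 31)` membership check reflects the presumed intent of printing routing only for those probe layers when the flag is enabled.

import torch

def maybe_print_routing(config, name: str, layer_idx: int, selected_experts: torch.Tensor) -> None:
    # Mirrors the debug gate in LoRAMoeLayer.forward after this commit:
    # the flag lives on the config object (no constructor argument), and
    # only a few probe layers (0, 16, 31) print their selected experts.
    if getattr(config, "show_debug", False) and layer_idx in (0, 16, 31):
        print(f"{name}_{layer_idx}: {selected_experts}")
        print("-" * 10)

# Example: a config without the attribute stays silent by default.
class _Cfg: ...
maybe_print_routing(_Cfg(), "q_proj", 16, torch.tensor([[0, 2]]))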