bjoernp commited on
Commit
9b9979d
·
1 Parent(s): f93990d

Update modeling_moe_mistral.py

Browse files
Files changed (1) hide show
  1. modeling_moe_mistral.py +5 -6
modeling_moe_mistral.py CHANGED
@@ -215,17 +215,16 @@ class MoE(nn.Module):
215
  orig_shape = x.shape
216
  x = x.view(-1, x.shape[-1])
217
 
218
- scores = self.gate(x)
219
  expert_weights, expert_indices = torch.topk(scores, self.num_experts_per_token, dim=-1)
220
- expert_weights = expert_weights.softmax(dim=-1)
221
  flat_expert_indices = expert_indices.view(-1)
222
 
223
  x = x.repeat_interleave(self.num_experts_per_token, dim=0)
224
- y = torch.empty_like(x)
225
  for i, expert in enumerate(self.experts):
226
- y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
227
- y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
228
- return y.view(*orig_shape)
229
 
230
 
231
  # Copied from transformers.models.llama.modeling_llama.repeat_kv
 
215
  orig_shape = x.shape
216
  x = x.view(-1, x.shape[-1])
217
 
218
+ scores = self.gate(x).softmax(dim=-1)
219
  expert_weights, expert_indices = torch.topk(scores, self.num_experts_per_token, dim=-1)
 
220
  flat_expert_indices = expert_indices.view(-1)
221
 
222
  x = x.repeat_interleave(self.num_experts_per_token, dim=0)
223
+ x = torch.empty_like(x)
224
  for i, expert in enumerate(self.experts):
225
+ x[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
226
+ x = (x.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
227
+ return x.view(*orig_shape)
228
 
229
 
230
  # Copied from transformers.models.llama.modeling_llama.repeat_kv