Update modeling_moe_mistral.py
modeling_moe_mistral.py  CHANGED  +4 -4
@@ -220,11 +220,11 @@ class MoE(nn.Module):
         flat_expert_indices = expert_indices.view(-1)
 
         x = x.repeat_interleave(self.num_experts_per_token, dim=0)
-
+        y = torch.empty_like(x)
         for i, expert in enumerate(self.experts):
-
-
-            return
+            y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
+        y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
+        return y.view(*orig_shape)
 
 
 # Copied from transformers.models.llama.modeling_llama.repeat_kv
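For context, here is a minimal, self-contained sketch of the routed forward pass that the patched lines implement. It is illustrative only: the class name (MoESketch), the use of plain nn.Linear layers as experts, the gate definition, and the constructor arguments (hidden_dim, num_experts, num_experts_per_token) are assumptions made for the sake of a runnable example, and the steps before the dispatch loop (gating, top-k selection, softmax) are inferred from the surrounding context lines rather than taken from this hunk; only the dispatch loop itself mirrors the added lines.

import torch
import torch.nn as nn


class MoESketch(nn.Module):
    def __init__(self, hidden_dim: int, num_experts: int, num_experts_per_token: int):
        super().__init__()
        # Hypothetical experts: plain linear layers standing in for the real FFN experts.
        self.experts = nn.ModuleList(
            [nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)]
        )
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
        self.num_experts_per_token = num_experts_per_token

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_shape = x.shape
        x = x.view(-1, x.shape[-1])  # (tokens, hidden)

        # Pick the top-k experts per token and normalize their routing scores.
        scores = self.gate(x)
        expert_weights, expert_indices = torch.topk(
            scores, self.num_experts_per_token, dim=-1
        )
        expert_weights = expert_weights.softmax(dim=-1)
        flat_expert_indices = expert_indices.view(-1)

        # Duplicate each token once per selected expert, then let each expert
        # process only the rows routed to it (the loop added in this patch).
        x = x.repeat_interleave(self.num_experts_per_token, dim=0)
        y = torch.empty_like(x)
        for i, expert in enumerate(self.experts):
            y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])

        # Weight each expert's output by its routing score, sum per token,
        # and restore the original (batch, seq, hidden) shape.
        y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(dim=1)
        return y.view(*orig_shape)


if __name__ == "__main__":
    moe = MoESketch(hidden_dim=16, num_experts=8, num_experts_per_token=2)
    out = moe(torch.randn(2, 5, 16))  # (batch, seq, hidden)
    print(out.shape)  # torch.Size([2, 5, 16])

The point of the per-expert loop is that each expert only sees the rows actually routed to it: tokens are duplicated num_experts_per_token times up front, each expert processes its slice in one batched call, and the per-expert outputs are recombined with the softmaxed gate weights before the original shape is restored.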