fix: local
Browse files
modeling_hf_nomic_bert.py
CHANGED
@@ -1244,18 +1244,10 @@ class NomicMoELayer(nn.Module):
|
|
1244 |
|
1245 |
def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
|
1246 |
batch_size, seq_len, hidden_dim = x.shape
|
1247 |
-
if attention_mask is not None:
|
1248 |
-
valid_indices = attention_mask.bool().view(-1)
|
1249 |
-
x_valid = x.view(-1, hidden_dim)[valid_indices]
|
1250 |
|
1251 |
weights, top_weights, top_experts = self.router(x)
|
1252 |
out = self.experts(x, weights, top_weights, top_experts)
|
1253 |
|
1254 |
-
if attention_mask is not None:
|
1255 |
-
full_out = torch.zeros(batch_size * seq_len, hidden_dim, dtype=out.dtype, device=out.device)
|
1256 |
-
full_out[valid_indices] = out
|
1257 |
-
out = full_out.view(batch_size, seq_len, hidden_dim)
|
1258 |
-
|
1259 |
return out
|
1260 |
|
1261 |
|
|
|
1244 |
|
1245 |
def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
|
1246 |
batch_size, seq_len, hidden_dim = x.shape
|
|
|
|
|
|
|
1247 |
|
1248 |
weights, top_weights, top_experts = self.router(x)
|
1249 |
out = self.experts(x, weights, top_weights, top_experts)
|
1250 |
|
|
|
|
|
|
|
|
|
|
|
1251 |
return out
|
1252 |
|
1253 |
|