Fill-Mask · Transformers · PyTorch · Safetensors · English · nomic_bert · custom_code
zpn committed
Commit 359596a · 1 parent: f69b8d1

fix: local

Files changed (1):
  1. modeling_hf_nomic_bert.py (+0, -8)
modeling_hf_nomic_bert.py CHANGED

@@ -1244,18 +1244,10 @@ class NomicMoELayer(nn.Module):
 
     def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        batch_size, seq_len, hidden_dim = x.shape
-       if attention_mask is not None:
-           valid_indices = attention_mask.bool().view(-1)
-           x_valid = x.view(-1, hidden_dim)[valid_indices]
 
        weights, top_weights, top_experts = self.router(x)
        out = self.experts(x, weights, top_weights, top_experts)
 
-       if attention_mask is not None:
-           full_out = torch.zeros(batch_size * seq_len, hidden_dim, dtype=out.dtype, device=out.device)
-           full_out[valid_indices] = out
-           out = full_out.view(batch_size, seq_len, hidden_dim)
-
        return out
 
 
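For context, below is a minimal, self-contained sketch of how the forward pass reads after this commit: the attention-mask gather/scatter branch is removed, so every position (padding included) is routed through the experts. DummyRouter, DummyExperts, and MoELayerSketch are hypothetical stand-ins introduced here for illustration only, not the repository's actual router or expert implementations; only the shape flow of the post-commit forward is shown.

# Minimal sketch (not the repository's implementation) of the post-commit forward.
# Router/expert internals are dummy stand-ins; the point is that attention_mask is
# accepted but no longer used to filter or scatter tokens.
from typing import Optional

import torch
import torch.nn as nn


class DummyRouter(nn.Module):
    # Stand-in router: uniform weights, always picks one expert per token.
    def __init__(self, hidden_dim: int, num_experts: int):
        super().__init__()
        self.num_experts = num_experts

    def forward(self, x: torch.Tensor):
        flat = x.view(-1, x.shape[-1])
        weights = torch.full((flat.shape[0], self.num_experts), 1.0 / self.num_experts)
        top_weights, top_experts = weights.topk(1, dim=-1)
        return weights, top_weights, top_experts


class DummyExperts(nn.Module):
    # Stand-in experts: a single shared linear layer applied to every token.
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.mlp = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, weights, top_weights, top_experts):
        return self.mlp(x)


class MoELayerSketch(nn.Module):
    def __init__(self, hidden_dim: int = 8, num_experts: int = 4):
        super().__init__()
        self.router = DummyRouter(hidden_dim, num_experts)
        self.experts = DummyExperts(hidden_dim)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        # After the commit, all positions are routed; no valid-token gather/scatter.
        batch_size, seq_len, hidden_dim = x.shape
        weights, top_weights, top_experts = self.router(x)
        out = self.experts(x, weights, top_weights, top_experts)
        return out


if __name__ == "__main__":
    layer = MoELayerSketch()
    x = torch.randn(2, 5, 8)
    mask = torch.ones(2, 5, dtype=torch.long)
    print(layer(x, mask).shape)  # torch.Size([2, 5, 8])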
 
 