inf-wse-v2-base-zh / modeling_sparse.py
SamuelYang's picture
Upload 8 files
0e4c95f verified
raw
history blame contribute delete
738 Bytes
import torch
from transformers import RoFormerForMaskedLM
class RoFormerForSparseEmbeddingV2(RoFormerForMaskedLM):
    """Masked-LM RoFormer whose forward pass returns one vocabulary-sized vector per sequence.

    The per-token MLM logits are max-pooled over the sequence dimension and
    passed through ``log(1 + relu(x))``, giving a non-negative ``[B, V]``
    embedding.
    """

    def forward(self, input_ids, attention_mask, return_sparse=False):
        """Compute the pooled vocabulary-space embedding for a batch.

        Args:
            input_ids: Token ids forwarded unchanged to the parent model.
            attention_mask: ``[B, L]`` mask with 1 for attended positions and
                0 for padding. Integer, bool and float 0/1 masks are accepted.
            return_sparse: If true, the result is converted with
                ``Tensor.to_sparse()`` before being returned.

        Returns:
            A ``[B, V]`` tensor (dense, or sparse when ``return_sparse`` is set).
        """
        # Ask for dict-style output explicitly so the ['logits'] lookup does not
        # depend on the model config's `return_dict` setting.
        logits = super().forward(
            input_ids, attention_mask, return_dict=True
        )['logits']  # [B, L, V]

        # Integer view of the mask: `1 - mask` is rejected for bool tensors and
        # torch.scatter requires an int64 index tensor.
        int_mask = attention_mask.long()

        # Additive penalty: 0 for attended tokens, -1e4 for padding.
        token_mask = (1 - int_mask.unsqueeze(-1)) * -1e4  # [B, L, 1]

        # Exclude position 0 from pooling (presumably the [CLS] token).
        token_mask[:, 0, :] = -1e4

        # Exclude position sum(mask) - 1, i.e. the last attended token when
        # padding is on the right (presumably [SEP]). Clamp so that a row with
        # no attended tokens indexes position 0 (already masked) instead of -1,
        # which would be an out-of-range scatter index.
        last_ind = torch.sum(int_mask, -1, keepdim=True).unsqueeze(-1) - 1  # [B, 1, 1]
        last_ind = last_ind.clamp(min=0)
        token_mask = torch.scatter(token_mask, -2, last_ind, -1e4)

        logits = logits + token_mask

        # relu zeroes the penalised positions; max-pool over the sequence
        # dimension, then apply log(1 + x).
        emb = torch.log(1 + torch.max(torch.relu(logits), dim=-2).values)  # [B, V]

        if return_sparse:
            emb = emb.to_sparse()
        return emb