import torch
from transformers import RoFormerForMaskedLM

class RoFormerForSparseEmbeddingV2(RoFormerForMaskedLM):
    """Sparse lexical embedding model: reuses the RoFormer masked-LM head to score vocabulary terms."""

    def forward(self, input_ids, attention_mask, return_sparse=False):
        # Masked-LM logits for every token position: (batch, seq_len, vocab_size)
        logits = super().forward(input_ids=input_ids, attention_mask=attention_mask)['logits']

        # Additive mask that pushes padding positions to a large negative value
        token_mask = (1 - attention_mask.unsqueeze(-1)) * -1e4
        # Also mask the first token ([CLS]) ...
        token_mask[:, 0, :] = -1e4
        # ... and the last non-padding token ([SEP]) of each sequence
        last_ind = torch.sum(attention_mask, -1, keepdim=True).unsqueeze(-1) - 1
        token_mask = torch.scatter(token_mask, -2, last_ind, -1e4)
        logits = logits + token_mask

        # Pool over the sequence: log(1 + max over tokens of ReLU(logits)),
        # giving one non-negative weight per vocabulary term: (batch, vocab_size)
        emb = torch.log(1 + torch.max(torch.relu(logits), dim=-2).values)

        if return_sparse:
            emb = emb.to_sparse()

        return emb
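

# Usage sketch: shows how the forward pass above is meant to be called.
# "path/to/roformer_checkpoint" is a placeholder (assumption); substitute the
# actual trained RoFormer checkpoint and its matching tokenizer.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    ckpt = "path/to/roformer_checkpoint"  # placeholder path, not a real checkpoint
    tokenizer = AutoTokenizer.from_pretrained(ckpt)
    model = RoFormerForSparseEmbeddingV2.from_pretrained(ckpt)
    model.eval()

    batch = tokenizer(["an example query"], padding=True, return_tensors="pt")
    with torch.no_grad():
        emb = model(batch["input_ids"], batch["attention_mask"], return_sparse=True)
    print(emb.shape)  # torch.Size([1, vocab_size]); sparse COO tensor when return_sparse=True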