roberta-base-use-qa-bg / modeling_roberta.py
import torch
from transformers import XLMRobertaModel as XLMRobertaModelBase

class XLMRobertaModel(XLMRobertaModelBase):
    """XLM-RoBERTa dual encoder with separate question/answer projection heads."""

    def __init__(self, config):
        super().__init__(config)
        # Project the pooled encoder output (hidden_size=768 for this base
        # model) into a shared 512-d embedding space, one head per input type.
        self.question_projection = torch.nn.Linear(config.hidden_size, 512)
        self.answer_projection = torch.nn.Linear(config.hidden_size, 512)
    def _embed(self, input_ids, attention_mask, projection):
        # Run the base encoder; outputs[0] is the last hidden state,
        # shaped (batch, seq_len, hidden_size).
        outputs = super().__call__(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        # Mean-pool over non-padding tokens only; the clamp guards against
        # division by zero for an all-zero mask.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(sequence_output.size()).float()
        embeddings = torch.sum(sequence_output * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )
        # Project into the shared 512-d space and squash with tanh.
        return torch.tanh(projection(embeddings))
    def question(self, input_ids, attention_mask):
        """Embed a question with the question projection head."""
        return self._embed(input_ids, attention_mask, self.question_projection)

    def answer(self, input_ids, attention_mask):
        """Embed a candidate answer with the answer projection head."""
        return self._embed(input_ids, attention_mask, self.answer_projection)
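

# Usage sketch: score a Bulgarian question against a candidate answer with the
# two heads above. The checkpoint id is an assumption inferred from the repo
# name; adjust it if the hub id differs.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_id = "rmihaylov/roberta-base-use-qa-bg"  # assumed hub id
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = XLMRobertaModel.from_pretrained(model_id)
    model.eval()

    # "What is the population of Sofia?" / "Sofia has about 1.2 million residents."
    q = tokenizer("Колко е населението на София?", return_tensors="pt")
    a = tokenizer("София има около 1,2 милиона жители.", return_tensors="pt")

    with torch.no_grad():
        q_emb = model.question(q["input_ids"], q["attention_mask"])  # (1, 512)
        a_emb = model.answer(a["input_ids"], a["attention_mask"])    # (1, 512)

    # Higher cosine similarity means the answer is a better match.
    score = torch.nn.functional.cosine_similarity(q_emb, a_emb)
    print(score.item())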