import torch

from transformers import XLMRobertaModel as XLMRobertaModelBase


class XLMRobertaModel(XLMRobertaModelBase):
    """
    XLM-RoBERTa encoder with separate projection heads for question and answer embeddings.
    """

    def __init__(self, config):
        super().__init__(config)

        # Project the 768-dim transformer outputs into a shared 512-dim embedding space,
        # using independent heads for questions and answers
        self.question_projection = torch.nn.Linear(768, 512)
        self.answer_projection = torch.nn.Linear(768, 512)

    def _embed(self, input_ids, attention_mask, projection):
        # Run the base transformer; outputs[0] is the last hidden state
        # with shape (batch size, sequence length, hidden size)
        outputs = super().__call__(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]

        # Mean pooling over tokens, ignoring padding positions via the attention mask
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(sequence_output.size()).float()
        embeddings = torch.sum(sequence_output * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        # Project into the shared embedding space and squash values to [-1, 1]
        return torch.tanh(projection(embeddings))

    def question(self, input_ids, attention_mask):
        return self._embed(input_ids, attention_mask, self.question_projection)

    def answer(self, input_ids, attention_mask):
        return self._embed(input_ids, attention_mask, self.answer_projection)
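

# Minimal usage sketch, not part of the original module. Assumes the standard
# transformers tokenizer API and the "xlm-roberta-base" checkpoint; with a base
# checkpoint the projection heads are randomly initialized until fine-tuned.
if __name__ == "__main__":
    from transformers import XLMRobertaTokenizer

    tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
    model = XLMRobertaModel.from_pretrained("xlm-roberta-base")

    inputs = tokenizer("How do I sort a list in Python?", return_tensors="pt")
    with torch.no_grad():
        embedding = model.question(inputs["input_ids"], inputs["attention_mask"])

    # 512-dim question embedding, one row per input sequence
    print(embedding.shape)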