|
|
|
import torch |
|
from transformers import PreTrainedModel |
|
|
|
class BiEncoderModelRegression(torch.nn.Module):
    """Score a text pair by the cosine similarity of their first-token embeddings.

    Both texts are encoded by the same ``base_model`` (shared weights). The
    hidden state at position 0 of each text is used as its embedding. The
    cosine similarity between the two embeddings is returned as ``logits``.
    When ``labels`` are supplied, a loss is computed as well.

    Args:
        base_model: Encoder called as ``base_model(input_ids, attention_mask=...)``.
            Its output must expose ``last_hidden_state``. That tensor is indexed
            as ``[:, 0, :]``, so it is assumed to be ``(batch, seq_len, hidden)``
            — confirm against the encoder in use.
        config: Optional config object. It is only stored as ``self.config``;
            presumably this is for tooling that expects the attribute (verify).
        loss_fn: Loss used when ``labels`` are given: ``"mse"``, ``"mae"``
            or ``"cosine_embedding"``.
    """

    # Names accepted for ``loss_fn``; used in the error message for bad values.
    _SUPPORTED_LOSSES = ("mse", "mae", "cosine_embedding")

    def __init__(self, base_model, config=None, loss_fn="mse"):
        super().__init__()
        self.base_model = base_model
        # dim=1: compare embeddings row-wise, giving one similarity per pair.
        self.cos = torch.nn.CosineSimilarity(dim=1)
        self.loss_fn = loss_fn
        self.config = config

    def forward(self, input_ids_text1, attention_mask_text1, input_ids_text2, attention_mask_text2, labels=None):
        """Encode both texts and return their similarity (and loss if labelled).

        Args:
            input_ids_text1, attention_mask_text1: Encoder inputs for the first text.
            input_ids_text2, attention_mask_text2: Encoder inputs for the second text.
            labels: Optional target with one value per pair. Any shape holding
                ``batch`` elements is accepted (e.g. ``(batch,)`` or ``(batch, 1)``).

        Returns:
            dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and
            ``"logits"`` (the per-pair cosine similarity).

        Raises:
            ValueError: if ``labels`` is given and ``self.loss_fn`` is not supported.
        """
        outputs_text1 = self.base_model(input_ids_text1, attention_mask=attention_mask_text1)
        outputs_text2 = self.base_model(input_ids_text2, attention_mask=attention_mask_text2)

        # The first-token hidden state serves as the sentence embedding.
        cls_embedding_text1 = outputs_text1.last_hidden_state[:, 0, :]
        cls_embedding_text2 = outputs_text2.last_hidden_state[:, 0, :]

        cos_sim = self.cos(cls_embedding_text1, cls_embedding_text2)

        loss = None
        if labels is not None:
            loss = self._compute_loss(cls_embedding_text1, cls_embedding_text2, cos_sim, labels)

        return {"loss": loss, "logits": cos_sim}

    def _compute_loss(self, embedding_text1, embedding_text2, cos_sim, labels):
        """Return the configured loss between the pair similarities and ``labels``."""
        # Align the target with the similarity vector. Otherwise a (batch, 1)
        # target would broadcast against cos_sim to (batch, batch) in
        # MSELoss/L1Loss. Integer labels would also hit a dtype mismatch.
        target = labels.to(device=cos_sim.device, dtype=cos_sim.dtype).reshape(cos_sim.shape)

        if self.loss_fn == "mse":
            return torch.nn.MSELoss()(cos_sim, target)
        if self.loss_fn == "mae":
            return torch.nn.L1Loss()(cos_sim, target)
        if self.loss_fn == "cosine_embedding":
            # CosineEmbeddingLoss expects targets of +1 / -1. Labels above 0.5
            # count as "similar" (+1); the rest count as "dissimilar" (-1).
            labels_cosine = 2 * (target > 0.5).to(cos_sim.dtype) - 1
            return torch.nn.CosineEmbeddingLoss()(embedding_text1, embedding_text2, labels_cosine)

        raise ValueError(
            f"Unsupported loss_fn {self.loss_fn!r}; expected one of {self._SUPPORTED_LOSSES}"
        )
|
|