import torch
from transformers import PreTrainedModel


class BiEncoderModelRegression(torch.nn.Module):
    """Shared-weight bi-encoder that scores a pair of texts by cosine similarity.

    Both texts are encoded by the same ``base_model``. The hidden state at
    position 0 of ``last_hidden_state`` (the "CLS" embedding) is taken as each
    text's embedding. The cosine similarity between the two embeddings is
    returned as ``logits``.

    Args:
        base_model: Encoder called as ``base_model(input_ids, attention_mask=...)``
            whose output exposes a rank-3 ``last_hidden_state``. Presumably a
            Hugging Face transformer; TODO confirm with callers.
        config: Stored unchanged on ``self.config``; not read by this class.
        loss_fn: Training loss used when ``labels`` are given to ``forward``.
            One of ``"mse"``, ``"mae"`` or ``"cosine_embedding"``.
    """

    _SUPPORTED_LOSSES = ("mse", "mae", "cosine_embedding")

    def __init__(self, base_model, config=None, loss_fn="mse"):
        super().__init__()
        self.base_model = base_model
        self.cos = torch.nn.CosineSimilarity(dim=1)
        self.loss_fn = loss_fn
        self.config = config

    def _embed(self, input_ids, attention_mask):
        """Encode one batch and return the position-0 hidden state, (batch, hidden)."""
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]

    def _compute_loss(self, emb1, emb2, cos_sim, labels):
        """Return the configured loss for a batch.

        Args:
            emb1, emb2: Per-text embeddings, shape (batch, hidden).
            cos_sim: Cosine similarity of ``emb1`` and ``emb2``, shape (batch,).
            labels: Target similarity per pair; shape (batch,) or (batch, 1),
                any numeric dtype.

        Raises:
            ValueError: If ``self.loss_fn`` is not a supported loss name.
        """
        # Align labels with cos_sim: flatten (batch, 1) -> (batch,) so MSE/L1
        # don't silently broadcast to (batch, batch). Cast to the prediction
        # dtype because these losses reject integer / float64 targets against
        # float32 predictions.
        labels = labels.reshape(-1).to(dtype=cos_sim.dtype)

        if self.loss_fn == "mse":
            return torch.nn.MSELoss()(cos_sim, labels)
        if self.loss_fn == "mae":
            return torch.nn.L1Loss()(cos_sim, labels)
        if self.loss_fn == "cosine_embedding":
            # CosineEmbeddingLoss scores the raw embeddings and requires +1/-1
            # targets, so similarity labels are binarised at 0.5.
            labels_cosine = 2 * (labels > 0.5).float() - 1
            return torch.nn.CosineEmbeddingLoss()(emb1, emb2, labels_cosine)
        raise ValueError(
            f"Unsupported loss_fn {self.loss_fn!r}; "
            f"expected one of {self._SUPPORTED_LOSSES}"
        )

    def forward(self, input_ids_text1, attention_mask_text1, input_ids_text2,
                attention_mask_text2, labels=None):
        """Score a batch of text pairs.

        Args:
            input_ids_text1, attention_mask_text1: Encoder inputs for the first
                text of each pair.
            input_ids_text2, attention_mask_text2: Encoder inputs for the second
                text of each pair.
            labels: Optional target similarity per pair. When given, a loss is
                computed with ``self.loss_fn``.

        Returns:
            dict with ``"loss"`` (scalar tensor, or ``None`` without labels) and
            ``"logits"`` (cosine similarities, shape (batch,)).

        Raises:
            ValueError: If ``labels`` is given and ``self.loss_fn`` is unsupported.
        """
        emb1 = self._embed(input_ids_text1, attention_mask_text1)
        emb2 = self._embed(input_ids_text2, attention_mask_text2)
        cos_sim = self.cos(emb1, emb2)

        loss = None
        if labels is not None:
            loss = self._compute_loss(emb1, emb2, cos_sim, labels)
        return {"loss": loss, "logits": cos_sim}