File size: 1,518 Bytes
c65ed72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

import torch
from transformers import PreTrainedModel

class BiEncoderModelRegression(torch.nn.Module):
    """Score a pair of texts by the cosine similarity of their first-token embeddings.

    Both texts are encoded by the same ``base_model`` (shared weights). For each
    text, the embedding at position 0 of ``last_hidden_state`` is taken, and the
    cosine similarity between the two embeddings is returned as ``"logits"``.
    When labels are supplied, a loss selected by ``loss_fn`` is also returned.

    Args:
        base_model: Encoder called as ``base_model(input_ids, attention_mask=...)``
            whose output exposes ``last_hidden_state``. Presumably a Hugging Face
            transformer backbone with a (batch, seq, hidden) output — TODO confirm.
        config: Stored as ``self.config``; not used inside this class.
        loss_fn: ``"mse"`` or ``"mae"`` regress the cosine similarity onto the
            labels. ``"cosine_embedding"`` binarizes labels at 0.5 into +/-1
            targets and applies a cosine-embedding loss to the embeddings.
    """

    def __init__(self, base_model, config=None, loss_fn="mse"):
        super().__init__()
        self.base_model = base_model
        self.cos = torch.nn.CosineSimilarity(dim=1)
        self.loss_fn = loss_fn
        self.config = config

    def forward(self, input_ids_text1, attention_mask_text1, input_ids_text2, attention_mask_text2, labels=None):
        """Encode both texts and return their cosine similarity (and loss).

        Args:
            input_ids_text1, attention_mask_text1: Inputs for the first text.
            input_ids_text2, attention_mask_text2: Inputs for the second text.
            labels: Optional per-pair target scores, shaped ``(batch,)`` or
                ``(batch, 1)``; any numeric dtype is accepted.

        Returns:
            dict with ``"loss"`` (a scalar tensor, or ``None`` when ``labels`` is
            ``None``) and ``"logits"`` (the per-pair cosine similarity).

        Raises:
            ValueError: if ``labels`` is given and ``self.loss_fn`` is not one of
                ``"mse"``, ``"mae"``, ``"cosine_embedding"``.
        """
        outputs_text1 = self.base_model(input_ids_text1, attention_mask=attention_mask_text1)
        outputs_text2 = self.base_model(input_ids_text2, attention_mask=attention_mask_text2)

        # Position-0 token of each sequence ([CLS] for BERT-style tokenizers).
        cls_embedding_text1 = outputs_text1.last_hidden_state[:, 0, :]
        cls_embedding_text2 = outputs_text2.last_hidden_state[:, 0, :]

        cos_sim = self.cos(cls_embedding_text1, cls_embedding_text2)

        loss = None
        if labels is not None:
            loss = self._compute_loss(cls_embedding_text1, cls_embedding_text2, cos_sim, labels)

        return {"loss": loss, "logits": cos_sim}

    def _compute_loss(self, embedding_text1, embedding_text2, cos_sim, labels):
        """Return the loss selected by ``self.loss_fn`` for one batch of pairs."""
        # One score per pair, in cos_sim's dtype: a (batch, 1) label tensor would
        # otherwise broadcast against the (batch,) similarities into a
        # (batch, batch) loss, and integer labels would mismatch the float input.
        labels = labels.reshape(cos_sim.shape).to(cos_sim.dtype)

        if self.loss_fn == "mse":
            return torch.nn.functional.mse_loss(cos_sim, labels)
        if self.loss_fn == "mae":
            return torch.nn.functional.l1_loss(cos_sim, labels)
        if self.loss_fn == "cosine_embedding":
            # This loss expects +1 (similar) / -1 (dissimilar) targets.
            labels_cosine = 2 * (labels > 0.5).float() - 1
            return torch.nn.functional.cosine_embedding_loss(embedding_text1, embedding_text2, labels_cosine)
        raise ValueError(
            f"Unsupported loss_fn {self.loss_fn!r}; expected one of 'mse', 'mae', 'cosine_embedding'"
        )