File size: 1,182 Bytes
3bd5293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from Multilingual_CLIP.multilingual_clip import Config_MCLIP
import transformers
import torch


class MultilingualCLIP(transformers.PreTrainedModel):
    """Text encoder: a pretrained transformer followed by a linear projection.

    The base transformer is loaded from ``config.modelBase`` and its output
    is projected from ``config.transformerDimensions`` features to
    ``config.numDims`` features.
    """

    config_class = Config_MCLIP.MCLIPConfig

    def __init__(self, config, *args, **kwargs):
        """Build the base transformer and the projection layer from *config*.

        Args:
            config: Configuration object; this constructor reads
                ``modelBase``, ``transformerDimensions`` and ``numDims``.
            *args: Forwarded unchanged to the parent constructor.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(config, *args, **kwargs)
        self.transformer = transformers.AutoModel.from_pretrained(config.modelBase)
        self.LinearTransformation = torch.nn.Linear(
            in_features=config.transformerDimensions,
            out_features=config.numDims,
        )

    def forward(self, txt, tokenizer, device):
        """Tokenize *txt*, encode it and project the masked token embeddings.

        Args:
            txt: Text input handed directly to ``tokenizer``.
            tokenizer: Callable tokenizer; it is invoked with padding and
                truncation to exactly 77 tokens and must return PyTorch
                tensors that include an ``attention_mask`` entry.
            device: Device the tokenized batch is moved to before encoding.

        Returns:
            The output of the linear projection applied to the masked,
            length-normalised embeddings.
        """
        txt_tok = tokenizer(
            txt,
            padding='max_length',
            max_length=77,
            truncation=True,
            return_tensors='pt',
        ).to(device)

        # Index the first output directly instead of calling ``.keys()`` on
        # it: positional indexing works whether the transformer returns a
        # ModelOutput or a plain tuple, and nothing is printed per call.
        embs = self.transformer(**txt_tok)[0]
        att = txt_tok['attention_mask']

        # Zero out padded positions and divide every position by the number
        # of attended tokens in its sequence.
        # NOTE(review): there is no reduction over the token axis here, so
        # the result keeps one vector per token position - confirm that
        # per-token output (rather than a pooled sentence vector) is what
        # callers expect.
        token_counts = att.sum(dim=1)[:, None].unsqueeze(2)
        embs = (embs * att.unsqueeze(2)) / token_counts
        return self.LinearTransformation(embs)

    @classmethod
    def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True):
        """Load *state_dict* into *model* and report nothing missing.

        ``pretrained_model_name_or_path`` and ``_fast_init`` are accepted but
        not used. The three empty lists are presumably the missing keys,
        unexpected keys and error messages expected by the transformers
        loading machinery - confirm against the installed version.

        Returns:
            A 4-tuple ``(model, [], [], [])``.
        """
        model.load_state_dict(state_dict)
        return model, [], [], []