medmac01
Added multilingual_clip module
3bd5293
raw
history blame
1.18 kB
from Multilingual_CLIP.multilingual_clip import Config_MCLIP
import transformers
import torch
class MultilingualCLIP(transformers.PreTrainedModel):
    """Multilingual text encoder: a pretrained transformer plus a linear projection.

    The backbone is loaded from ``config.modelBase``. Its hidden states of width
    ``config.transformerDimensions`` are mapped to ``config.numDims`` by
    ``LinearTransformation``.
    """

    config_class = Config_MCLIP.MCLIPConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        # The backbone is loaded here via from_pretrained, before any weights
        # from this model's own checkpoint are applied.
        self.transformer = transformers.AutoModel.from_pretrained(config.modelBase)
        self.LinearTransformation = torch.nn.Linear(
            in_features=config.transformerDimensions,
            out_features=config.numDims,
        )

    def forward(self, txt, tokenizer, device, max_length=77):
        """Encode ``txt`` into projected, length-normalised per-token embeddings.

        Args:
            txt: Text or batch of texts accepted by ``tokenizer``.
            tokenizer: Tokenizer to use; presumably the one matching
                ``config.modelBase`` (TODO confirm at call sites).
            device: Device the tokenized tensors are moved to; presumably the
                device that holds this model's parameters.
            max_length: Fixed length every input is padded or truncated to.
                Defaults to 77, the value that was previously hard-coded.

        Returns:
            A tensor of shape ``(batch, max_length, numDims)``, assuming the
            backbone's first output is a ``(batch, seq, hidden)`` hidden-state
            tensor. Before projection, each token vector is zeroed at padding
            positions and divided by its sequence's count of attended tokens.
            After projection, padding positions therefore equal the projection
            bias. The sequence axis is NOT reduced.
        """
        txt_tok = tokenizer(
            txt,
            padding='max_length',
            max_length=max_length,
            truncation=True,
            return_tensors='pt',
        ).to(device)
        # First backbone output; assumed to be the last hidden state of shape
        # (batch, seq, transformerDimensions). TODO confirm for other backbones.
        embs = self.transformer(**txt_tok)[0]
        att = txt_tok['attention_mask']
        # Per-sequence count of attended tokens, reshaped (batch,) -> (batch, 1, 1)
        # so that it broadcasts over the token and feature axes.
        lengths = att.sum(dim=1)[:, None].unsqueeze(2)
        # Zero padding positions, then scale by 1/length.
        # NOTE(review): classic mean pooling would also sum over dim=1 here.
        # The sequence axis is kept, so callers receive per-token embeddings;
        # confirm this is what downstream consumers expect.
        embs = (embs * att.unsqueeze(2)) / lengths
        return self.LinearTransformation(embs)

    @classmethod
    def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True):
        """Load ``state_dict`` into ``model`` with a plain ``load_state_dict`` call.

        Presumably this overrides the ``transformers.PreTrainedModel`` loading
        hook of the same name. NOTE(review): whether ``from_pretrained`` still
        calls this hook depends on the installed ``transformers`` version.
        ``pretrained_model_name_or_path`` and ``_fast_init`` are accepted for
        signature compatibility and are not used.

        Returns:
            ``(model, [], [], [])``. The three empty lists presumably stand for
            missing keys, unexpected keys and error messages.
            ``torch.nn.Module.load_state_dict`` is strict by default, so a key
            mismatch raises an error instead of being reported in these lists.
        """
        model.load_state_dict(state_dict)
        return model, [], [], []