# EdgeTA/dnns/vilt/__init__.py
import timm
from timm.models._factory import load_checkpoint
import torch
import os
from typing import List, Union
from torch import nn
from torch.jit import Final
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from utils.dl.common.model import get_model_device, set_module
import torch.nn.functional as F
from utils.common.log import logger
from transformers import ViltModel, ViltForQuestionAnswering


def vilt_b_32(num_classes):
    """
    ViLT for VQA.

    Settings follow the VQAv2 dataset (3129 answer classes):
    1. use half of the classes for LoRA adaptation;
    2. use that same half of the classes for DA evaluation (corruptions are used
       to generate domain shifts), and use the other half for CL evaluation.
    (A sketch of this class split follows the function below.)
    """
    # Upstream checkpoint:
    # model = ViltForQuestionAnswering.from_pretrained('dandelin/vilt-b32-mlm-itm')
    # Here the weights are loaded from a local copy of the checkpoint instead.
    model = ViltForQuestionAnswering.from_pretrained('new_impl/mm/Vis_bert/QuestionAnswering/vilt')

    # Swap the last linear layer of the VQA classifier head for one with
    # `num_classes` outputs, keeping the input width of the pretrained layer.
    linear = model.classifier[3]
    new_linear = nn.Linear(linear.in_features, num_classes, bias=True)
    set_module(model, 'classifier.3', new_linear)
    return model
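

# Hedged sketch (not in the original file): one way to realize the class split
# described in the vilt_b_32 docstring. The helper name `split_vqa_classes` is
# an assumption; note that 1565, used in the demo below, is half of the 3129
# VQAv2 answer classes rounded up.
def split_vqa_classes(num_total_classes: int = 3129):
    """Return (adapt_classes, cl_classes): the first half of the answer classes
    for LoRA adaptation and DA evaluation, the second half for CL evaluation.
    """
    half = (num_total_classes + 1) // 2  # 3129 -> 1565
    adapt_classes = list(range(half))                  # LoRA adaptation + DA evaluation
    cl_classes = list(range(half, num_total_classes))  # CL evaluation
    return adapt_classes, cl_classes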


if __name__ == '__main__':
    model = vilt_b_32(1565)
    print(model)

    from transformers import ViltProcessor, ViltModel
    from PIL import Image
    import requests

    # Prepare an example image and text pair.
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    text = "hello world"

    processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm")
    model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm-itm")

    inputs = processor(image, text, return_tensors="pt")
    print(inputs)

    # Forward pass through the bare ViLT backbone (no task head).
    outputs = model(**inputs)
    last_hidden_states = outputs.last_hidden_state
    print(last_hidden_states.shape)
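
    # Hedged usage sketch (illustrative, not upstream code): feed the same inputs
    # to the VQA model built by vilt_b_32 above; its replaced classifier head
    # yields `num_classes` logits, so the expected shape is torch.Size([1, 1565]).
    vqa_model = vilt_b_32(1565)
    vqa_outputs = vqa_model(**inputs)
    print(vqa_outputs.logits.shape)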