|
import gradio as gr |
|
from transformers import ViltProcessor, ViltForQuestionAnswering |
|
from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering |
|
from PIL import Image |
|
import torch |
|
|
|
# Identifiers that look like Hugging Face Hub ids (org/name); not referenced
# by the visible code below.
dataset_name = "Multimodal-Fatima/OK-VQA_train"
model_name = "microsoft/git-base-vqav2"

# Path handed to from_pretrained() for both the processor and the model
# (presumably a local directory — confirm).
model_path = "git-base-vqav2"

# Example questions, one per bundled image: questions[i - 1] is asked about
# the file image{i}.jpg.
questions = [
    "What can happen the objects shown are thrown on the ground?",
    "What was the machine beside the bowl used for?",
    "What kind of cars are in the photo?",
    "What is the hairstyle of the blond called?",
    "How old do you have to be in canada to do this?",
    "Can you guess the place where the man is playing?",
    "What loony tune character is in this photo?",
    "Whose birthday is being celebrated?",
    "Where can that toilet seat be bought?",
    "What do you call the kind of pants that the man on the right is wearing?",
]
|
|
|
# Load the input processor and the VQA model once, at import time, from the
# same model_path; every call to main() reuses these two objects.
# The tuple is evaluated left to right, so the processor is loaded first.
processor, model = (
    AutoProcessor.from_pretrained(model_path),
    AutoModelForVisualQuestionAnswering.from_pretrained(model_path),
)
|
|
|
|
|
def main(select_exemple_num):
    """Run the VQA model on one bundled example and format its top answers.

    The example is chosen by a 1-based index: the image is read from
    ``image{N}.jpg`` in the current working directory and the question is
    ``questions[N - 1]``. Both are encoded with the module-level
    ``processor`` and scored by the module-level ``model``. The five
    highest sigmoid scores are printed and collected into a text report;
    fewer are used if the model has fewer labels.

    Args:
        select_exemple_num: 1-based example index from the Gradio slider.
            It may arrive as a float such as ``3.0``, so it is converted
            to ``int`` before use.

    Returns:
        ``(image_path, question, report)``, matching the interface outputs
        ``["image", "text", "text"]``. ``report`` has one
        ``"<score> <label>"`` line per candidate, followed by the top answer.

    Raises:
        ValueError: if the index is outside ``1..len(questions)``.
        FileNotFoundError: if the example image file does not exist.
    """
    # A float index would produce "image3.0.jpg" and fail as a list index.
    selectednum = int(select_exemple_num)
    # Without this check, 0 would silently wrap around to questions[-1].
    if not 1 <= selectednum <= len(questions):
        raise ValueError(
            f"example index must be in 1..{len(questions)}, got {selectednum}"
        )

    exemple_img = f"image{selectednum}.jpg"
    question = questions[selectednum - 1]

    # Release the image file handle once the processor has turned the
    # pixels into tensors (a bare Image.open() leaves it open).
    with Image.open(exemple_img) as img:
        encoding = processor(img, question, return_tensors='pt')

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**encoding).logits

    # Score each answer label independently with a sigmoid, then keep the
    # best candidates, never asking topk for more labels than exist.
    scores = torch.sigmoid(logits)
    k = min(5, scores.shape[-1])
    probs, classes = torch.topk(scores, k)

    id2label = model.config.id2label
    report = ['pridicted : \n']
    ans = ''
    # One image/question pair was encoded (batch of one — assumed from the
    # single-sample processor call), so flattening only drops the batch axis.
    # Unlike squeeze(), it still yields a list when k == 1.
    top_probs = probs.reshape(-1).tolist()
    top_classes = classes.reshape(-1).tolist()
    for prob, class_idx in zip(top_probs, top_classes):
        label = id2label[class_idx]
        print(prob, label)
        report.append(f"{prob} {label}\n")
        # topk returns candidates best-first; keep the first non-empty label.
        if not ans:
            ans = label

    print(ans)
    report.append(f"\nso I think it's answer is : \n{ans}")
    return exemple_img, question, ''.join(report)
|
|
|
|
|
# Input: a 1-based integer slider spanning every bundled example question.
_example_picker = gr.Slider(1, len(questions), step=1)
# Outputs, in the order main() returns them: image path, question, report.
_result_slots = ["image", "text", "text"]

demo = gr.Interface(
    fn=main,
    inputs=[_example_picker],
    outputs=_result_slots,
)

# share=True asks Gradio for a public share link in addition to the local
# server.
demo.launch(share=True)
|
|