"""Gradio demo: answer a fixed set of OK-VQA example questions with a VQA model.

Example N (1-based) pairs the local image file ``imageN.jpg`` with
``questions[N - 1]``; the user picks N with a slider.
"""
import gradio as gr
from transformers import ViltProcessor, ViltForQuestionAnswering
from transformers import AutoProcessor, AutoModelForVisualQuestionAnswering
from PIL import Image
import torch

dataset_name = "Multimodal-Fatima/OK-VQA_train"
model_name = "microsoft/git-base-vqav2"
model_path = "git-base-vqav2"

# How many candidate answers to list in the text output.
TOP_K = 5

questions = [
    "What can happen the objects shown are thrown on the ground?",
    "What was the machine beside the bowl used for?",
    "What kind of cars are in the photo?",
    "What is the hairstyle of the blond called?",
    "How old do you have to be in canada to do this?",
    "Can you guess the place where the man is playing?",
    "What loony tune character is in this photo?",
    "Whose birthday is being celebrated?",
    "Where can that toilet seat be bought?",
    "What do you call the kind of pants that the man on the right is wearing?",
]

# NOTE(review): the scoring in `main` (sigmoid over `logits` + `config.id2label`)
# assumes a classification-style VQA head such as ViLT's. GIT checkpoints
# (e.g. microsoft/git-base-vqav2) are generative and have no such head —
# confirm what the checkpoint at `model_path` actually is.
processor = AutoProcessor.from_pretrained(model_path)
model = AutoModelForVisualQuestionAnswering.from_pretrained(model_path)


def main(select_exemple_num):
    """Answer example question number ``select_exemple_num`` about its image.

    Args:
        select_exemple_num: 1-based example index from the slider. Gradio may
            deliver it as a float, so it is coerced to ``int`` before use.

    Returns:
        Tuple ``(image_path, question, report)`` where ``report`` lists the
        top-k (score, label) pairs followed by the top-1 answer.

    Raises:
        ValueError: if the index is outside ``1..len(questions)``.
    """
    # A float such as 3.0 would otherwise yield "image3.0.jpg" and break indexing.
    selectednum = int(select_exemple_num)
    if not 1 <= selectednum <= len(questions):
        raise ValueError(
            f"example number must be in 1..{len(questions)}, got {selectednum}"
        )
    exemple_img = f"image{selectednum}.jpg"
    question = questions[selectednum - 1]

    # Close the file promptly; force 3 channels so RGBA/greyscale files work.
    with Image.open(exemple_img) as raw_img:
        img = raw_img.convert("RGB")

    # Keyword arguments: positional order differs between processor classes.
    encoding = processor(images=img, text=question, return_tensors="pt")
    # Inference only — no need to build the autograd graph.
    with torch.no_grad():
        logits = model(**encoding).logits

    # Independent per-answer sigmoid scores (multi-label), not a softmax.
    scores = torch.sigmoid(logits)
    k = min(TOP_K, scores.shape[-1])
    probs, classes = torch.topk(scores, k)

    # Take batch row 0 explicitly; squeeze() would collapse to a scalar if k == 1.
    id2label = model.config.id2label
    ranked = [
        (prob, id2label[class_idx])
        for prob, class_idx in zip(probs[0].tolist(), classes[0].tolist())
    ]
    for prob, label in ranked:
        print(prob, label)

    # topk returns scores in descending order, so the first entry is the answer.
    ans = ranked[0][1] if ranked else ""
    print(ans)

    output_str = (
        "predicted : \n"
        + "".join(f"{prob} {label}\n" for prob, label in ranked)
        + f"\nso I think its answer is : \n{ans}"
    )
    return exemple_img, question, output_str


demo = gr.Interface(
    fn=main,
    inputs=[gr.Slider(1, len(questions), step=1)],
    outputs=["image", "text", "text"],
)

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    demo.launch(share=True)