# VQA MED Demo
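# Gradio demo for visual question answering on medical images (PathVQA):
# a CLIP ViT-B/32 backbone encodes both the image and the question, and a small
# Transformer caption decoder generates the answer token by token over the CLIP vocabulary.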
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from transformers import CLIPTokenizer

from CLIP import clip
from Transformers_for_Caption import Transformer_Caption

device = "cuda" if torch.cuda.is_available() else "cpu"
class Config(object):
    def __init__(self):
        # Learning Rates
        # Transformer
        self.hidden_dim = 512
        self.pad_token_id = 0
        self.max_position_embeddings = 76
        self.layer_norm_eps = 1e-12
        self.dropout = 0.1
        self.vocab_size = 49408
        self.enc_layers = 1
        self.dec_layers = 1
        self.dim_feedforward = 1024  # 2048
        self.nheads = 4
        self.pre_norm = True
        # Dataset
        # self.dir = os.getcwd() + '/data/coco'
        self.limit = -1
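# Note: vocab_size = 49408 matches the CLIP BPE tokenizer, and pad_token_id = 0
# corresponds to the token '!' in that vocabulary (hence the split('!') in decoding below).
# max_position_embeddings caps the length of the generated answer.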
##### OUR MODEL
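# VQA_Net: both modalities are embedded with a (forked) CLIP ViT-B/32 model,
# the image-token and question-token sequences are concatenated into a single
# memory, and a 2-layer Transformer caption decoder autoregressively predicts
# the answer; a final linear layer maps to logits over the CLIP vocabulary.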
class VQA_Net(nn.Module):
    def __init__(self, num_classes):
        super(VQA_Net, self).__init__()
        # Backbones tried during development:
        # self.VIT = deit_base_distilled_patch16_224(pretrained=True)
        # self.VIT = vit_base_patch16_224_dino(pretrained=True)
        # self.VIT = vit_base_patch32_sam_224(pretrained=True)  # note that we used only 6 layers
        # self.VIT = maxvit_rmlp_nano_rw_256(pretrained=True)
        # self.VIT = vit_base_patch8_224(pretrained=True)
        # self.VIT = tf_efficientnetv2_m(pretrained=True, features_only=True, out_indices=(1, 3), feature_location='expansion')
        self.backbone, _ = clip.load('ViT-B/32', device, jit=False)
        self.input_proj = nn.LayerNorm(512)  # nn.Sequential(nn.LayerNorm(768), nn.Linear(768, 768), nn.GELU(), nn.Dropout(0.1))
        self.transformer_decoder = Transformer_Caption(config, num_decoder_layers=2)
        # nested Sequential kept so the state_dict keys match the released checkpoint
        self.mlp = nn.Sequential(nn.Sequential(nn.Linear(512, num_classes)))
        # self.samples_proj = nn.Sequential(nn.Linear(768, 512))
        self.samples_proj = nn.Identity()
        self.question_proj = nn.Identity()  # nn.Sequential(nn.Linear(512, 512, bias=False))
        # self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    def forward(self, samples, question_in, answer_out, mask_answer):
        # encode the image into a sequence of patch tokens
        _, _, samples = self.backbone.encode_image(samples)
        # samples = self.VIT(samples)
        samples = samples.float()
        # samples = samples.view(-1, 512, 8 * 8)
        # samples = samples.permute(0, 2, 1)
        # samples = samples[:, 0:, :] @ self.samples_proj
        samples = self.samples_proj(samples)
        # encode the question into a sequence of text tokens
        _, _, question_in = self.backbone.encode_text(question_in)
        question_in = self.question_proj(question_in.float())
        # concatenate image and question tokens along the sequence dimension
        samples = torch.cat((samples, question_in), dim=1)
        # src, mask = features[-1].decompose()
        # assert mask is not None
        hs = self.transformer_decoder(self.input_proj(samples.permute(1, 0, 2).float()),
                                      answer_out, tgt_mask=mask_answer)
        out = self.mlp(hs.permute(1, 0, 2))
        return out
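# Shape sketch (assuming the forked CLIP returns the full token sequences):
# ViT-B/32 on a 224x224 image yields roughly 50 image tokens of dim 512 and the
# text encoder yields 77 question tokens of dim 512, so the decoder memory holds
# about 127 tokens; `hs` is (answer_len, batch, 512) and `out` is (batch, answer_len, vocab_size).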
config = Config()
Tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
My_VQA = VQA_Net(num_classes=len(Tokenizer))
My_VQA.load_state_dict(torch.load("./PathVQA_2Decoders_1024_30iterations_Trial4_CLIPVIT32.pth.tar",
                                  map_location=torch.device(device)))
My_VQA.eval()  # disable dropout for inference
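# Preprocessing for uploaded images: resize to the 224x224 input expected by ViT-B/32
# and normalize with ImageNet mean/std (assumed to match how the model was trained).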
tfms = transforms.Compose([
    # transforms.Lambda(under_max),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    # transforms.Normalize(0.5, 0.5),
])
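# Greedy autoregressive decoding: the answer buffer starts as all pad tokens (id 0)
# with <|startoftext|> in position 0; at each step the model predicts the next token,
# the buffer and mask are updated, and generation stops at <|endoftext|> (id 49407).
# Decoding the buffer and splitting on '!' (the string form of pad id 0) strips the padding.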
def answer_question(image, text_question):
    with torch.no_grad():
        start_token = Tokenizer.convert_tokens_to_ids("<|startoftext|>")
        # end_token = Tokenizer.convert_tokens_to_ids("<|endoftext|>")
        caption = torch.zeros((1, config.max_position_embeddings), dtype=torch.long)
        cap_mask = torch.ones((1, config.max_position_embeddings), dtype=torch.bool)
        caption[:, 0] = start_token
        cap_mask[:, 0] = False
        # drop anything after the first '?' and lowercase the question
        if text_question.find('?') > -1:
            text_question = text_question.split('?')[0].lower()
        text_question = np.array(Tokenizer.encode_plus(text_question, max_length=77,
                                                       padding='max_length',
                                                       return_attention_mask=True,
                                                       return_token_type_ids=False,
                                                       truncation=True)['input_ids'])
        for i in range(config.max_position_embeddings - 1):
            predictions = My_VQA(image.unsqueeze(0),
                                 torch.Tensor(text_question).unsqueeze(0).long(),
                                 caption, cap_mask)
            predictions = predictions[:, i, :]
            predicted_id = torch.argmax(predictions, axis=-1)
            caption[:, i + 1] = predicted_id[0]
            cap_mask[:, i + 1] = False
            if predicted_id[0] == 49407:  # <|endoftext|>
                break
        cap_result_intermediate = Tokenizer.decode(caption[0].tolist(), skip_special_tokens=True)
        cap_result = cap_result_intermediate.split('!')
        return cap_result
def infer_answer_question(image, text):
    if text is None:
        cap_result = "please write a question"
    elif image is None:
        cap_result = "please upload an image"
    else:
        image_encoded = tfms(image)
        cap_result = answer_question(image_encoded, text)[0]
    return cap_result
image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
answer = gr.Textbox(label="Predicted answer")
examples = [["train_0000.jpg", "Where are liver stem cells (oval cells) located?"],
            ["train_0001.jpg", "What are stained here with an immunohistochemical stain for cytokeratin 7?"],
            ["train_0002.jpg", "What are bile duct cells and canals of Hering stained here with for cytokeratin 7?"],
            ["train_0003.jpg", "Are bile duct cells and canals of Hering stained here with an immunohistochemical stain for cytokeratin 7?"],
            ["train_0018.jpg", "Is there an infarct in the brain hypertrophy?"],
            ["train_0019.jpg", "What is ischemic coagulative necrosis?"]]
title = "Vision–Language Model for Visual Question Answering in Medical Imagery" | |
description = "Y Bazi, MMA Rahhal, L Bashmal, M Zuair. <a href='https://www.mdpi.com/2306-5354/10/3/380' target='_blank'> Vision–Language Model for Visual Question Answering in Medical Imagery</a>. Bioengineering, 2023<br><br>"\ | |
"Gradio Demo for VQA medical model trained on PathVQA dataset, To use it, upload your image and type a question and click 'submit', or click one of the examples to load them." \ | |
### link to paper and github code | |
website = "" | |
article = f"<p style='text-align: center'><a href='{website}' target='_blank'>BigMed@KSU</a></p>" | |
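# Wire the inputs/outputs into a Gradio Interface. enable_queue serializes requests
# on older Gradio releases; newer versions expect interface.queue() instead.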
interface = gr.Interface(fn=infer_answer_question,
                         inputs=[image, question],
                         outputs=answer,
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)
interface.launch(debug=True, enable_queue=True)