##### VQA MED Demo

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from transformers import CLIPTokenizer

from CLIP import clip
from Transformers_for_Caption import Transformer_Caption

device = "cuda" if torch.cuda.is_available() else "cpu"


class Config(object):
    def __init__(self):
        # Transformer decoder hyperparameters
        self.hidden_dim = 512
        self.pad_token_id = 0
        self.max_position_embeddings = 76
        self.layer_norm_eps = 1e-12
        self.dropout = 0.1
        self.vocab_size = 49408
        self.enc_layers = 1
        self.dec_layers = 1
        self.dim_feedforward = 1024
        self.nheads = 4
        self.pre_norm = True
        # Dataset
        self.limit = -1


##### OUR MODEL
class VQA_Net(nn.Module):
    def __init__(self, num_classes):
        super(VQA_Net, self).__init__()
        # CLIP ViT-B/32 serves as the joint image/question encoder.
        self.backbone, _ = clip.load('ViT-B/32', device, jit=False)
        self.input_proj = nn.LayerNorm(512)
        self.transformer_decoder = Transformer_Caption(config, num_decoder_layers=2)
        # The nested Sequential is kept so parameter names match the released checkpoint.
        self.mlp = nn.Sequential(nn.Sequential(nn.Linear(512, num_classes)))
        self.samples_proj = nn.Identity()
        self.question_proj = nn.Identity()

    def forward(self, samples, question_in, answer_out, mask_answer):
        # Encode the image and the question with the CLIP backbone.
        _, _, samples = self.backbone.encode_image(samples)
        samples = self.samples_proj(samples.float())
        _, _, question_in = self.backbone.encode_text(question_in)
        question_in = self.question_proj(question_in.float())
        # Concatenate image tokens and question tokens into one sequence.
        samples = torch.cat((samples, question_in), dim=1)
        # Decode the answer conditioned on the fused image/question sequence.
        hs = self.transformer_decoder(self.input_proj(samples.permute(1, 0, 2).float()),
                                      answer_out, tgt_mask=mask_answer)
        out = self.mlp(hs.permute(1, 0, 2))
        return out


config = Config()
Tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

My_VQA = VQA_Net(num_classes=len(Tokenizer))
My_VQA.load_state_dict(torch.load("./PathVQA_2Decoders_1024_30iterations_Trial4_CLIPVIT32.pth.tar",
                                  map_location=torch.device(device)))
My_VQA.eval()

tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def answer_question(image, text_question):
    with torch.no_grad():
        start_token = Tokenizer.convert_tokens_to_ids("<|startoftext|>")
        end_token = Tokenizer.convert_tokens_to_ids("<|endoftext|>")

        # Start with an empty answer sequence: only the start token is unmasked.
        caption = torch.zeros((1, config.max_position_embeddings), dtype=torch.long)
        cap_mask = torch.ones((1, config.max_position_embeddings), dtype=torch.bool)
        caption[:, 0] = start_token
        cap_mask[:, 0] = False

        # Drop the trailing question mark and lower-case the question before tokenizing.
        if text_question.find('?') > -1:
            text_question = text_question.split('?')[0].lower()
        text_question = np.array(Tokenizer.encode_plus(text_question,
                                                       max_length=77,
                                                       padding='max_length',
                                                       return_attention_mask=True,
                                                       return_token_type_ids=False,
                                                       truncation=True)['input_ids'])

        # Greedy autoregressive decoding of the answer.
        for i in range(config.max_position_embeddings - 1):
            predictions = My_VQA(image.unsqueeze(0),
                                 torch.Tensor(text_question).unsqueeze(0).long(),
                                 caption, cap_mask)
            predictions = predictions[:, i, :]
            predicted_id = torch.argmax(predictions, axis=-1)
            caption[:, i + 1] = predicted_id[0]
            cap_mask[:, i + 1] = False
            if predicted_id[0] == end_token:  # 49407 = <|endoftext|>
                break

        cap_result_intermediate = Tokenizer.decode(caption[0].tolist(), skip_special_tokens=True)
        cap_result = cap_result_intermediate.split('!')
        return cap_result


def infer_answer_question(image, text):
    if text is None:
        cap_result = "please write a question"
    elif image is None:
        cap_result = "please upload an image"
    else:
        image_encoded = tfms(image)
        cap_result = answer_question(image_encoded, text)[0]
    return cap_result


image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
answer = gr.Textbox(label="Predicted answer")

examples = [["train_0000.jpg", "Where are liver stem cells (oval cells) located?"],
            ["train_0001.jpg", "What are stained here with an immunohistochemical stain for cytokeratin 7?"],
            ["train_0002.jpg", "What are bile duct cells and canals of Hering stained here with for cytokeratin 7?"],
            ["train_0003.jpg", "Are bile duct cells and canals of Hering stained here with an immunohistochemical stain for cytokeratin 7?"],
            ["train_0018.jpg", "Is there an infarct in the brain hypertrophy?"],
            ["train_0019.jpg", "What is ischemic coagulative necrosis?"]]

title = "Vision–Language Model for Visual Question Answering in Medical Imagery"
description = "Y. Bazi, M. M. A. Rahhal, L. Bashmal, M. Zuair. Vision–Language Model for Visual Question Answering in Medical Imagery. Bioengineering, 2023. " \
              "Gradio demo for a VQA medical model trained on the PathVQA dataset. To use it, upload an image, type a question, and click 'Submit', or click one of the examples to load them."

### link to paper and github code
website = ""
article = "BigMed@KSU"

interface = gr.Interface(fn=infer_answer_question,
                         inputs=[image, question],
                         outputs=answer,
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)
interface.launch(debug=True, enable_queue=True)