iakarshu committed on
Commit 0089672
1 Parent(s): 0695986

Upload app.py

Files changed (1)
  1. app.py +140 -0
app.py ADDED
@@ -0,0 +1,140 @@
+ # Imports
+ import gradio as gr
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
+ import torch
+ import torch.nn as nn
+ from torch import cuda
+ from torchvision import transforms
+ from utils import convert_ans_to_token, convert_ques_to_token, rotate, convert_token_to_ques, convert_token_to_answer
+ from modeling import LaTr_for_pretraining, LaTr_for_finetuning, LaTrForVQA
+ from dataset import load_json_file, get_specific_file, resize_align_bbox, get_tokens_with_boxes, create_features
+ from PIL import Image, ImageDraw
+ import pytesseract
+ import pandas as pd
+ from tqdm.auto import tqdm
+ import numpy as np
+ import json
+ import os
+
+
+ # install pytesseract at runtime, in case it is missing from the environment
+ os.system('pip install -q pytesseract')
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+ # For displaying the progress of pandas' map functions
+ tqdm.pandas()
+
+ # Setting the hyperparameters and the primary configuration
+ PAD_TOKEN_BOX = [0, 0, 0, 0]
+ max_seq_len = 512
+ batch_size = 2
+ target_size = (500, 384)
+ t5_model = "t5-base"
+ max_steps = 50000  # assumed placeholder; passed to LaTrForVQA below and presumably only used by its training scheduler
+
+
+ device = 'cuda' if cuda.is_available() else 'cpu'
+
+
+ # Configuration for the model
+ config = {
+     't5_model': 't5-base',
+     'vocab_size': 32128,
+     'hidden_state': 768,
+     'max_2d_position_embeddings': 1001,
+     'classes': 32128,  # number of tokens
+     'seq_len': 512
+ }
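+ # Note: 32128 matches the size of the t5-base embedding matrix, and
+ # max_2d_position_embeddings = 1001 covers the 0-1000 coordinate range the
+ # boxes are clamped to below.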
+
+ tokenizer = T5Tokenizer.from_pretrained(t5_model)
+ latr = LaTrForVQA(config, max_steps=max_steps)
+ url = 'https://www.kaggleusercontent.com/kf/99663112/eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0..tGHcmnLDazeyRNWAxV-KDQ.6unLNRwl7AyVy0Qz3ONE1m_mRNmgC-8VGyS61PdkSeBMV7PpG2B1cD5liuLlok5LQiYrGujrULdtIXKTqCUU_PA3MMSRhi1VKkGMdtrzJLMvzA4jxlWh_qak8P89w4ir4LENyuPCan24M0MOLXYjrm4d1iiy4Hg8pp2o5zWgs0OrVYoh_AJNazOD7pRIjLEAqnM-Pa0LSmvJkfN7j3Zn_Fu9jJ7Pq3Z0rWVtEb-PbeY06f9t-0QK6-JU8K2LdQjuBaCxjgB3BlufgFhKuhU3CZXsJitG7tDnwMSl4JImGfMmBntE2kn9-0dl_aANxaQd2Lsy8KGUDNAdQ2vBpowGQ0-tgDT_w7DpG6DzmUlmzIegqJF1-JyurCO0TrX_RatoPa7jGzuqA5vUT4263-MkoAlR0Xuulq4_pwGV-WnJsrcLuuDtEKFVsYjQvikWM3c9Arw0MsXchYCQkl_OZ6ZqYZ6TZrYxujHE2B6nHxu0F-5xj33vQ2ojaMpHtDplTnqCe4TdmzRWV6LhopfL4x1NXIXry8we4IqgPPwnIy3G2lZVR39nPmNR-8IGjbvweVr6Ci6y1COdbLR4JiTMVc_Nvf2glVKRjppTdcEwLv-j1YR8JsZpZvjaOEokrNkyCG7J0PLJAHlY8iX-pRdBG4vivbSHxnKl3Qppa689VH0RARpOsOBYv-IF-rM1nSmKq7Ci.tXi1B0oNQFlUtxesMcma3w/models/epoch=0-step=34602.ckpt'
+
+
+ try:
+     latr = latr.load_from_checkpoint(url)
+     print("Checkpoint loaded successfully")
+ except Exception:
+     print("Checkpoint not loaded")
+
+
+ image = gr.inputs.Image(type="pil")
+ question = gr.inputs.Textbox(label="Question")
+ answer = gr.outputs.Textbox(label="Predicted answer")
+ examples = [["remote.jpg", "what number is the button near the top left?"]]
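+
+
+ # answer_question drives the full inference pipeline: save the uploaded image,
+ # extract OCR tokens and boxes with create_features, append width/height to each
+ # box, run LaTr, and decode the predicted token ids back into text.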
+ def answer_question(image, question):
+     image.save('sample_img.jpg')
+
+     # Extracting features from the image
+     img, boxes, tokenized_words = create_features(image_path='sample_img.jpg',
+                                                   tokenizer=tokenizer,
+                                                   target_size=target_size,
+                                                   max_seq_length=max_seq_len,
+                                                   use_ocr=True
+                                                   )
+
+     ## Converting the boxes to the format required for model input: [x1, y1, x2, y2, width, height]
+     boxes = torch.as_tensor(boxes, dtype=torch.int32)
+     width = (boxes[:, 2] - boxes[:, 0]).view(-1, 1)
+     height = (boxes[:, 3] - boxes[:, 1]).view(-1, 1)
+     boxes = torch.cat([boxes, width, height], axis=-1)
+
+     ## Clamping the values, as some of the box coordinates are out of bounds;
+     ## clip everything into the [0, 1000] range of the 2D position embeddings
+     boxes[:, 0] = torch.clamp(boxes[:, 0], min=0, max=1000)
+     boxes[:, 2] = torch.clamp(boxes[:, 2], min=0, max=1000)
+     boxes[:, 4] = torch.clamp(boxes[:, 4], min=0, max=1000)
+
+     boxes[:, 1] = torch.clamp(boxes[:, 1], min=0, max=1000)
+     boxes[:, 3] = torch.clamp(boxes[:, 3], min=0, max=1000)
+     boxes[:, 5] = torch.clamp(boxes[:, 5], min=0, max=1000)
+
+     ## Tokenized words as a tensor
+     tokenized_words = torch.as_tensor(tokenized_words, dtype=torch.int32)
+
+     img = transforms.ToTensor()(img)
+     question = convert_ques_to_token(question=question, tokenizer=tokenizer)
+
+     ## Expanding the dimensions (adding a batch axis) for inference
+     img = img.unsqueeze(0)
+     boxes = boxes.unsqueeze(0)
+     tokenized_words = tokenized_words.unsqueeze(0)
+     question = question.unsqueeze(0)
+
+     encoding = {'img': img, 'boxes': boxes, 'tokenized_words': tokenized_words, 'question': question}
+
+     with torch.no_grad():
+         logits = latr.forward(encoding)
+         logits = logits.squeeze(0)
+
+     # Greedy decoding: pick the highest-scoring token at each position
+     _, preds = torch.max(logits, dim=1)
+     preds = preds.detach().cpu()
+
+     # Keep the prediction up to the last non-pad (non-zero) token
+     mask = torch.clamp(preds, min=0, max=1)
+     last_non_zero_argument = (mask != 0).nonzero()[-1].item()
+
+     predicted_ans = convert_token_to_ques(preds[:last_non_zero_argument], tokenizer)
+     return predicted_ans
+
+
+ # UI text adapted from: https://huggingface.co/spaces/nielsr/vilt-vqa/blob/main/app.py
+ title = "Interactive demo: LaTr (Layout Aware Transformer) for VQA"
+ description = "Gradio demo for LaTr (Layout Aware Transformer), trained on the TextVQA dataset. To use it, simply upload your image, type a question and click 'submit', or click one of the examples to load them. Read more at the links below."
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.12494' target='_blank'>LaTr: Layout-aware transformer for scene-text VQA, a novel multimodal architecture for Scene Text Visual Question Answering (STVQA)</a> | <a href='https://github.com/uakarsh/latr' target='_blank'>Github Repo</a></p>"
+
+ interface = gr.Interface(fn=answer_question,
+                          inputs=[image, question],
+                          outputs=answer,
+                          examples=examples,
+                          title=title,
+                          description=description,
+                          article=article,
+                          enable_queue=True)
+ interface.launch(debug=True)
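
A quick way to sanity-check the prediction path without the Gradio UI (a minimal sketch, not part of this commit) is to call answer_question directly in a session where the definitions above have already run; "remote.jpg" is the image referenced by the examples list:

from PIL import Image

# Assumes tokenizer, latr and answer_question from app.py are already in scope
sample = Image.open("remote.jpg").convert("RGB")
prediction = answer_question(sample, "what number is the button near the top left?")
print(prediction)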