Spaces:
Runtime error
Runtime error
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Requirements.txt
|
2 |
+
import gradio as gr
|
3 |
+
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
4 |
+
from torch import cuda
|
5 |
+
from utils import convert_ans_to_token, convert_ques_to_token, rotate, convert_token_to_ques, convert_token_to_answer
|
6 |
+
from modeling import LaTr_for_pretraining, LaTr_for_finetuning, LaTrForVQA
|
7 |
+
from dataset import load_json_file, get_specific_file, resize_align_bbox, get_tokens_with_boxes, create_features
|
8 |
+
import torch.nn as nn
|
9 |
+
from PIL import Image, ImageDraw
|
10 |
+
import pytesseract
|
11 |
+
import pandas as pd
|
12 |
+
from tqdm.auto import tqdm
|
13 |
+
import numpy as np
|
14 |
+
import json
|
15 |
+
import os
|
16 |
+
|
17 |
+
|
18 |
+
# install PyTesseract
|
19 |
+
os.system('pip install -q pytesseract')
|
20 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
21 |
+
|
22 |
+
|
23 |
+
# Default Library import
|
24 |
+
|
25 |
+
|
26 |
+
# For the purpose of displaying the progress of map function
|
27 |
+
tqdm.pandas()
|
28 |
+
|
29 |
+
# Visualization libraries
|
30 |
+
|
31 |
+
# Specific libraries of LaTr
|
32 |
+
|
33 |
+
# Setting the hyperparameters as well as primary configurations
|
34 |
+
|
35 |
+
PAD_TOKEN_BOX = [0, 0, 0, 0]
|
36 |
+
max_seq_len = 512
|
37 |
+
batch_size = 2
|
38 |
+
target_size = (500, 384)
|
39 |
+
t5_model = "t5-base"
|
40 |
+
|
41 |
+
|
42 |
+
device = 'cuda' if cuda.is_available() else 'cpu'
|
43 |
+
|
44 |
+
|
45 |
+
# Configuration for the model
|
46 |
+
config = {
|
47 |
+
't5_model': 't5-base',
|
48 |
+
'vocab_size': 32128,
|
49 |
+
'hidden_state': 768,
|
50 |
+
'max_2d_position_embeddings': 1001,
|
51 |
+
'classes': 32128, # number of tokens
|
52 |
+
'seq_len': 512
|
53 |
+
}
|
54 |
+
|
55 |
+
tokenizer = T5Tokenizer.from_pretrained(t5_model)
|
56 |
+
latr = LaTrForVQA(config, max_steps=max_steps)
|
57 |
+
url = 'https://www.kaggleusercontent.com/kf/99663112/eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0..tGHcmnLDazeyRNWAxV-KDQ.6unLNRwl7AyVy0Qz3ONE1m_mRNmgC-8VGyS61PdkSeBMV7PpG2B1cD5liuLlok5LQiYrGujrULdtIXKTqCUU_PA3MMSRhi1VKkGMdtrzJLMvzA4jxlWh_qak8P89w4ir4LENyuPCan24M0MOLXYjrm4d1iiy4Hg8pp2o5zWgs0OrVYoh_AJNazOD7pRIjLEAqnM-Pa0LSmvJkfN7j3Zn_Fu9jJ7Pq3Z0rWVtEb-PbeY06f9t-0QK6-JU8K2LdQjuBaCxjgB3BlufgFhKuhU3CZXsJitG7tDnwMSl4JImGfMmBntE2kn9-0dl_aANxaQd2Lsy8KGUDNAdQ2vBpowGQ0-tgDT_w7DpG6DzmUlmzIegqJF1-JyurCO0TrX_RatoPa7jGzuqA5vUT4263-MkoAlR0Xuulq4_pwGV-WnJsrcLuuDtEKFVsYjQvikWM3c9Arw0MsXchYCQkl_OZ6ZqYZ6TZrYxujHE2B6nHxu0F-5xj33vQ2ojaMpHtDplTnqCe4TdmzRWV6LhopfL4x1NXIXry8we4IqgPPwnIy3G2lZVR39nPmNR-8IGjbvweVr6Ci6y1COdbLR4JiTMVc_Nvf2glVKRjppTdcEwLv-j1YR8JsZpZvjaOEokrNkyCG7J0PLJAHlY8iX-pRdBG4vivbSHxnKl3Qppa689VH0RARpOsOBYv-IF-rM1nSmKq7Ci.tXi1B0oNQFlUtxesMcma3w/models/epoch=0-step=34602.ckpt'
|
58 |
+
|
59 |
+
|
60 |
+
try:
|
61 |
+
latr = latr.load_from_checkpoint(url)
|
62 |
+
print("Checkpoint loaded successfully")
|
63 |
+
except:
|
64 |
+
print("Checkpoint not loaded")
|
65 |
+
pass
|
66 |
+
|
67 |
+
|
68 |
+
image = gr.inputs.Image(type="pil")
|
69 |
+
question = gr.inputs.Textbox(label="Question")
|
70 |
+
answer = gr.outputs.Textbox(label="Predicted answer")
|
71 |
+
examples = [["remote.jpg", "what number is the button near the top left?"]]
|
72 |
+
|
73 |
+
|
74 |
+
def answer_question(image, question):
|
75 |
+
image.save('sample_img.jpg')
|
76 |
+
|
77 |
+
# Extracting features from the image
|
78 |
+
img, boxes, tokenized_words = create_features(image_path='sample_img.jpg',
|
79 |
+
tokenizer=tokenizer,
|
80 |
+
target_size=target_size,
|
81 |
+
max_seq_length=max_seq_length,
|
82 |
+
use_ocr=True
|
83 |
+
)
|
84 |
+
|
85 |
+
## Converting the boxes as per the format required for model input
|
86 |
+
boxes = torch.as_tensor(boxes, dtype=torch.int32)
|
87 |
+
width = (boxes[:, 2] - boxes[:, 0]).view(-1, 1)
|
88 |
+
height = (boxes[:, 3] - boxes[:, 1]).view(-1, 1)
|
89 |
+
boxes = torch.cat([boxes, width, height], axis = -1)
|
90 |
+
|
91 |
+
## Clamping the value,as some of the box values are out of bound
|
92 |
+
boxes[:, 0] = torch.clamp(boxes[:, 0], min = 0, max = 0)
|
93 |
+
boxes[:, 2] = torch.clamp(boxes[:, 2], min = 1000, max = 1000)
|
94 |
+
boxes[:, 4] = torch.clamp(boxes[:, 4], min = 1000, max = 1000)
|
95 |
+
|
96 |
+
boxes[:, 1] = torch.clamp(boxes[:, 1], min = 0, max = 0)
|
97 |
+
boxes[:, 3] = torch.clamp(boxes[:, 3], min = 1000, max = 1000)
|
98 |
+
boxes[:, 5] = torch.clamp(boxes[:, 5], min = 1000, max = 1000)
|
99 |
+
|
100 |
+
## Tensor tokenized words
|
101 |
+
tokenized_words = torch.as_tensor(tokenized_words, dtype=torch.int32)
|
102 |
+
|
103 |
+
img = transforms.ToTensor()(img)
|
104 |
+
question = convert_ques_to_token(question = question, tokenizer = tokenizer)
|
105 |
+
|
106 |
+
## Expanding the dimension for inference
|
107 |
+
img = img.unsqueeze(0)
|
108 |
+
boxes = boxes.unsqueeze(0)
|
109 |
+
tokenized_words = tokenized_words.unsqueeze(0)
|
110 |
+
question = question.unsqueeze(0)
|
111 |
+
|
112 |
+
encoding = {'img': img, 'boxes': boxes, 'tokenized_words': tokenized_words, 'question': question}
|
113 |
+
|
114 |
+
with torch.no_grad():
|
115 |
+
logits = latr.forward(encoding)
|
116 |
+
logits = logits.squeeze(0)
|
117 |
+
|
118 |
+
_, preds = torch.max(logits, dim = 1)
|
119 |
+
preds = preds.detach().cpu()
|
120 |
+
mask = torch.clamp(preds, min = 0, max = 1)
|
121 |
+
last_non_zero_argument = (mask != 0).nonzero()[1][-1]
|
122 |
+
|
123 |
+
predicted_ans = convert_token_to_ques(individual_ans_pred[:last_non_zero_argument], tokenizer)
|
124 |
+
return predicted_ans
|
125 |
+
|
126 |
+
|
127 |
+
# Taken from here: https://huggingface.co/spaces/nielsr/vilt-vqa/blob/main/app.py
|
128 |
+
title = "Interactive demo: laTr (Layout Aware Transformer) for VQA"
|
129 |
+
description = "Gradio Demo for LaTr (Layout Aware Transformer),trained on TextVQA Dataset. To use it, simply upload your image and type a question and click 'submit', or click one of the examples to load them. Read more at the links below."
|
130 |
+
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.12494' target='_blank'>LaTr: Layout-aware transformer for scene-text VQA,a novel multimodal architecture for Scene Text Visual Question Answering (STVQA)</a> | <a href='https://github.com/uakarsh/latr' target='_blank'>Github Repo</a></p>"
|
131 |
+
|
132 |
+
interface = gr.Interface(fn=answer_question,
|
133 |
+
inputs=[image, question],
|
134 |
+
outputs=answer,
|
135 |
+
examples=examples,
|
136 |
+
title=title,
|
137 |
+
description=description,
|
138 |
+
article=article,
|
139 |
+
enable_queue=True)
|
140 |
+
interface.launch(debug=True)
|