Spaces:
Running
Running
madhavkotecha
commited on
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pycparser.ply.yacc import token
|
2 |
+
from ultralytics import YOLO
|
3 |
+
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoModelForCausalLM, pipeline, AutoModelForMaskedLM
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
from nltk.translate import bleu_score
|
8 |
+
from nltk.translate.bleu_score import SmoothingFunction
|
9 |
+
import torch
|
10 |
+
|
11 |
+
# Weights file for the YOLO detector loaded below (resolved relative to the
# current working directory).
yolo_weights_path = "final_wts.pt"

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

# TrOCR processor (image preprocessing + tokenizer) and encoder-decoder model.
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
trocr_model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten').to(device)
trocr_model.config.num_beams = 1

# Use the detected device rather than a fixed 'mps': a hard-coded 'mps' fails
# on any host without Apple Silicon and disagrees with the other models.
yolo_model = YOLO(yolo_weights_path).to(device)

# NOTE(review): 'roberta-large' is loaded twice (once inside the pipeline and
# once as a bare masked-LM model) — confirm both copies are really needed.
unmasker_large = pipeline('fill-mask', model='roberta-large', device=device)
roberta_model = AutoModelForMaskedLM.from_pretrained("roberta-large").to(device)

print(f'TrOCR and YOLO Models loaded on {device}')
|
24 |
+
|
25 |
+
|
26 |
+
# -------------------------------------------------------
|
27 |
+
|
28 |
+
|
29 |
+
# A line is kept only if its mean per-step top-1 probability exceeds this.
CONFIDENCE_THRESHOLD = 0.72
# A line whose BLEU against the next line reaches this value is treated as
# overlapping with it (see remove_overlapping_texts).
BLEU_THRESHOLD = 0.6


def inference(image_path, debug=False, return_texts='final'):
    """Detect text regions in an image, transcribe them and filter the result.

    Pipeline: YOLO boxes (sorted by vertical centre) -> TrOCR transcription ->
    strip leading '# ' / trailing ' #' -> drop lines at or below
    CONFIDENCE_THRESHOLD -> BLEU between each line and the next -> collapse
    runs of overlapping lines, keeping the highest-scoring one.

    Args:
        image_path: Path of the image file to process.
        debug: When true, print the number of detected crops.
        return_texts: Selects the stage whose result is returned:
            'generated'           raw TrOCR output,
            'post_processed'      after '#' stripping,
            'qualified'           after the confidence filter,
            'qualified_with_bleu' with the adjacent-line BLEU added,
            'final'               after overlap removal.
            Any other value returns the full tuple described below.

    Returns:
        For the named stages, a pandas DataFrame. Otherwise the tuple
        ``(final_texts_df, bounding_box_path, final_tokens, final_logits,
        generated_texts)`` where ``generated_texts`` is the post-processed
        list of all transcriptions.
    """
    # Local import: keeps this block self-contained without touching the
    # file-level import section.
    from pathlib import Path

    def get_cropped_images(image_path):
        """Run YOLO and return (crops, top-edge y per crop, annotated image path)."""
        results = yolo_model(image_path, save=True)
        result = results[0]
        # Open and convert the page once instead of once per detected box.
        image = Image.open(image_path).convert("RGB")
        patches = []
        ys = []
        for box in sorted(result.boxes, key=lambda b: b.xywh[0][1]):
            x_center, y_center, w, h = box.xywh[0].cpu().numpy()
            # xywh is centre-based; convert to the top-left corner for cropping.
            x, y = x_center - w / 2, y_center - h / 2
            patches.append(image.crop((x, y, x + w, y + h)))
            ys.append(y)
        # Same "<save_dir>/<stem>.jpg" result as before, but without assuming
        # the input path contains '/' or has a three-character extension.
        bounding_box_path = str(Path(result.save_dir) / f"{Path(result.path).stem}.jpg")
        return patches, ys, bounding_box_path

    def get_model_output(images):
        """Transcribe crops with TrOCR; return (texts, stacked logits, tokens)."""
        pixel_values = processor(images=images, return_tensors="pt").pixel_values.to(device)
        output = trocr_model.generate(
            pixel_values,
            return_dict_in_generate=True,
            output_logits=True,
            max_new_tokens=30,
        )
        generated_texts = processor.batch_decode(output.sequences, skip_special_tokens=True)
        generated_tokens = [processor.tokenizer.convert_ids_to_tokens(seq) for seq in output.sequences]
        # output.logits is a per-step sequence; stack it along a new dim 1.
        stacked_logits = torch.stack(output.logits, dim=1)
        return generated_texts, stacked_logits, generated_tokens

    def get_scores(logits):
        """Return one confidence per crop: mean over steps of the top-1 probability.

        The result is a list of Python floats so that it can be placed in a
        DataFrame and compared with the threshold regardless of device.
        """
        # NOTE(review): the mean covers every generated step, presumably
        # including steps after a sequence has ended — confirm this is intended.
        return logits.softmax(-1).max(-1).values.mean(-1).cpu().tolist()

    def post_process_texts(generated_texts):
        """Strip a leading '# ' and a trailing ' #' from each text, in place."""
        for i, text in enumerate(generated_texts):
            if len(text) > 2 and text.startswith('# '):
                text = text[2:]
            if len(text) > 2 and text.endswith(' #'):
                text = text[:-2]
            generated_texts[i] = text
        return generated_texts

    def get_qualified_texts(generated_texts, scores, y, logits, tokens):
        """Keep only the entries whose score exceeds CONFIDENCE_THRESHOLD."""
        return [
            {
                'text': text,
                'score': score,
                'y': y_i,
                'logits': logits_i,
                'tokens': tokens_i,
            }
            for text, score, y_i, logits_i, tokens_i in zip(generated_texts, scores, y, logits, tokens)
            if score > CONFIDENCE_THRESHOLD
        ]

    def get_adjacent_bleu_scores(qualified_texts):
        """Attach 'bleu' to each entry: bigram BLEU against the following entry.

        The last entry has no successor and gets 0.
        """
        smoothing = SmoothingFunction()

        def get_bleu_score(hypothesis, references):
            return bleu_score.sentence_bleu(
                references,
                hypothesis,
                weights=[0.5, 0.5],
                smoothing_function=smoothing.method1,
            )

        last = len(qualified_texts) - 1
        for i, entry in enumerate(qualified_texts):
            bleu = 0
            if i < last:
                hyp = entry['text'].split()
                ref = qualified_texts[i + 1]['text'].split()
                bleu = get_bleu_score(hyp, [ref])
            entry['bleu'] = bleu
        return qualified_texts

    def remove_overlapping_texts(qualified_texts):
        """Collapse runs of overlapping entries, keeping the best-scoring one.

        An entry starts a new group unless the previous entry's BLEU against
        it reached BLEU_THRESHOLD; within a group only the entry with the
        highest score survives.
        """
        final_texts = []
        new = True
        for entry in qualified_texts:
            if new:
                final_texts.append(entry)
            elif final_texts[-1]['score'] < entry['score']:
                final_texts[-1] = entry
            new = entry['bleu'] < BLEU_THRESHOLD
        return final_texts

    cropped_images, y, bounding_box_path = get_cropped_images(image_path)
    if debug:
        print('Number of cropped images:', len(cropped_images))

    if cropped_images:
        generated_texts, logits, gen_tokens = get_model_output(cropped_images)
        normalised_scores = get_scores(logits)
    else:
        # Nothing detected: skip the models and let every stage below produce
        # an empty result instead of failing on an empty batch.
        generated_texts, logits, gen_tokens, normalised_scores = [], [], [], []

    if return_texts == 'generated':
        return pd.DataFrame({
            'text': generated_texts,
            'score': normalised_scores,
            'y': y,
        })

    generated_texts = post_process_texts(generated_texts)
    if return_texts == 'post_processed':
        return pd.DataFrame({
            'text': generated_texts,
            'score': normalised_scores,
            'y': y,
        })

    qualified_texts = get_qualified_texts(generated_texts, normalised_scores, y, logits, gen_tokens)
    if return_texts == 'qualified':
        return pd.DataFrame(qualified_texts)

    qualified_texts = get_adjacent_bleu_scores(qualified_texts)
    if return_texts == 'qualified_with_bleu':
        return pd.DataFrame(qualified_texts)

    final_texts = remove_overlapping_texts(qualified_texts)
    final_texts_df = pd.DataFrame(final_texts, columns=['text', 'score', 'y'])
    final_tokens = [text['tokens'] for text in final_texts]
    final_logits = [text['logits'] for text in final_texts]
    if return_texts == 'final':
        return final_texts_df

    return final_texts_df, bounding_box_path, final_tokens, final_logits, generated_texts
|
142 |
+
|
143 |
+
|
144 |
+
# Image to run the pipeline on.
image_path = "raw_dataset/g06-037h.png"

# 'final_v2' is not one of the named stages handled by inference(), so the
# call falls through to the full tuple return (DataFrame, annotated image
# path, tokens, logits, post-processed texts).
df, bounding_path, tokens, logits, gen_texts = inference(
    image_path,
    debug=False,
    return_texts='final_v2',
)
|
146 |
+
|
147 |
+
# ----------------------------------------------------------------
|
148 |
+
|
149 |
+
|
150 |
+
def get_new_logits(tokens):
    """Run the RoBERTa masked-LM over ``tokens`` and return its logits.

    All ids in ``tokens`` are flattened into one sequence of shape (1, N) for
    a single forward pass, and the resulting logits are reshaped back to
    ``tokens.shape + (vocab_size,)``. The argmax decoding of the output is
    printed as a side effect.

    Args:
        tokens: Integer tensor of token ids, of any shape.

    Returns:
        Tensor of masked-LM logits with shape ``tokens.shape + (vocab_size,)``.
    """
    # NOTE(review): every row is concatenated into a single sequence, so the
    # total token count presumably must stay within the model's maximum
    # sequence length — confirm for pages with many lines.
    inputs = tokens.reshape(1, -1)
    with torch.no_grad():
        # ones_like keeps the mask on the same device as the input ids.
        outputs = roberta_model(input_ids=inputs, attention_mask=torch.ones_like(inputs))
    logits = outputs.logits

    # Use this call's own vocabulary dimension rather than a module-level
    # tensor that happens to exist when the script runs top to bottom.
    logits_flattened = logits.reshape(-1, logits.shape[-1])
    print(processor.batch_decode([logits_flattened.argmax(-1)], skip_special_tokens=True))
    return logits.reshape(*tokens.shape, logits.shape[-1])
|
161 |
+
|
162 |
+
|
163 |
+
# Token id substituted at low-confidence positions (the mask id, 50264).
MASK_TOKEN_ID = 50264
# Positions whose top-1 probability is below this value get masked.
LOW_CONFIDENCE_THRESHOLD = 0.5

# Stack the per-line logits returned by inference() into one tensor with the
# line index as the leading dimension.
slogits = torch.stack(list(logits), dim=0)
tokens = slogits.argmax(-1)
confidence = slogits.softmax(-1).max(-1).values
indices = torch.where(confidence < LOW_CONFIDENCE_THRESHOLD)

# Mask every low-confidence position. The earlier per-element loop skipped
# all rows except index 6, so other lines were never masked even though their
# logits are blended with the masked-LM output afterwards.
tokens[indices] = MASK_TOKEN_ID

new_logits = get_new_logits(tokens)
|
174 |
+
|
175 |
+
|
176 |
+
# ----------------------------------------------------------------
|
177 |
+
|
178 |
+
|
179 |
+
# At every low-confidence position, replace the logits with a weighted mix of
# the original logits (weight 0.1) and the masked-LM logits (weight 0.5).
# One indexed assignment covers all (row, column) pairs in `indices`.
slogits[indices] = slogits[indices] * 0.1 + new_logits[indices] * 0.5

# Decode the argmax of the blended logits as a single flat token sequence and
# print it; a bare expression would discard the result when run as a script.
logits_flattened = slogits.reshape(-1, slogits.shape[-1])
corrected_texts = processor.batch_decode([logits_flattened.argmax(-1)], skip_special_tokens=True)
print(corrected_texts)
|
184 |
+
|
185 |
+
|