File size: 6,521 Bytes
a228fac
 
 
 
0dd8e27
a228fac
 
12af45e
a228fac
 
 
 
0d655c9
1be0846
 
a228fac
 
 
 
 
 
 
613ad82
a228fac
0dd8e27
a228fac
 
1be0846
a228fac
 
 
 
 
 
7b11b8d
 
 
 
a228fac
 
0d655c9
12af45e
 
 
 
 
 
16f0a9b
 
 
 
ccb5ac8
 
16f0a9b
a228fac
 
9d6cf42
 
12af45e
 
 
 
 
 
 
 
 
 
 
ccb5ac8
 
 
 
12af45e
ccb5ac8
 
 
0d655c9
 
12af45e
 
 
 
 
1be0846
 
 
 
75c54a9
1be0846
 
 
16f0a9b
 
 
613ad82
16f0a9b
 
 
 
1be0846
 
16f0a9b
 
 
 
 
 
 
1be0846
16f0a9b
12af45e
 
1be0846
12af45e
a228fac
 
12af45e
a228fac
0dd8e27
a228fac
 
 
0dd8e27
a228fac
0dd8e27
16f0a9b
0dd8e27
 
 
 
 
 
 
 
 
 
 
a228fac
0dd8e27
a228fac
 
 
 
0dd8e27
 
 
 
 
 
 
a228fac
0dd8e27
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from config import Settings
from preprocess import Preprocessor
import ocr
from PIL import Image
from transformers import LiltForTokenClassification, AutoTokenizer
import token_classification
import torch
from fastapi import FastAPI, UploadFile, Form, HTTPException
from contextlib import asynccontextmanager
import json
import io
from models import LiLTRobertaLikeForRelationExtraction
from base64 import b64decode 
from handwritting_detection import DetectHandwritting
import pandas as pd
config = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate the shared ``config`` dict for the app's lifetime, then clear it.

    Before yielding, this stores in ``config``: the ``Settings`` instance, the
    torch device (CUDA when available, otherwise CPU), the ``Preprocessor``,
    the tokenizer, the token-classification (SER) model, the
    relation-extraction (RE) model and the TrOCR API URL.

    Args:
        app: The FastAPI application this lifespan is attached to (not used
            inside the body).
    """
    try:
        settings = Settings()
        config['settings'] = settings
        config['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        config['processor'] = Preprocessor(settings.TOKENIZER)
        config['tokenizer'] = AutoTokenizer.from_pretrained(settings.TOKENIZER)
        config['ser_model'] = LiltForTokenClassification.from_pretrained(settings.SER_MODEL)
        config['re_model'] = LiLTRobertaLikeForRelationExtraction.from_pretrained(settings.RE_MODEL)
        config['TROCR_API'] = settings.TROCR_API_URL
        yield
    finally:
        # Release the resources even when loading fails part-way or an
        # exception is thrown into the generator at the ``yield``.
        config.clear()

# Application instance; the `lifespan` context manager fills the shared
# `config` dict before yielding and clears it afterwards.
app = FastAPI(lifespan=lifespan)

@app.get("/")
def api_home():
    """Serve the static welcome payload at the API root."""
    welcome = dict(detail='Welcome to Sri-Doc space')
    return welcome

@app.post("/submit-doc")
async def ProcessDocument(file: UploadFile):
    """Run OCR, entity labelling and relation extraction on an uploaded file.

    Args:
        file: The uploaded document; its raw bytes are handed to ``ApplyOCR``.

    Returns:
        The value returned by ``ExtractRelations``.

    Raises:
        HTTPException: Status 400 when the OCR result has fewer than two
            entries, when ``LabelTokens`` fails, or when ``ExtractRelations``
            fails. Exceptions raised by ``ApplyOCR`` propagate unchanged.
    """
    content = await file.read()
    ocr_df, image = ApplyOCR(content)
    if len(ocr_df) < 2:
        raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")

    try:
        tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
    except Exception as e:
        # Log the cause (as the base64 endpoint does) instead of discarding
        # it; a bare ``except:`` would also trap BaseException subclasses.
        print(e)
        raise HTTPException(status_code=400, detail="Entity identification failed") from e

    try:
        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Relation extraction failed") from e
    return reOutput

@app.post("/submit-doc-base64")
async def ProcessDocument(file: str = Form(...)):
    """Run OCR, entity labelling and relation extraction on a base64 form field.

    Args:
        file: Form field of the shape ``<header>,<base64 payload>``; it must
            contain exactly one comma.

    Returns:
        The value returned by ``ExtractRelations``.

    Raises:
        HTTPException: Status 400 when the field cannot be split or decoded,
            when the OCR result has fewer than two entries, when
            ``LabelTokens`` fails, or when ``ExtractRelations`` fails.
            Exceptions raised by ``ApplyOCR`` propagate unchanged.
    """
    try:
        # Unpacking into two names rejects input with zero or several commas.
        _header, payload = file.split(',')
        content = b64decode(payload.encode())
    except ValueError as e:
        # Covers the unpack failure and binascii.Error (a ValueError subclass)
        # without trapping unrelated exceptions the way a bare ``except:`` does.
        raise HTTPException(status_code=400, detail="Invalid image") from e

    ocr_df, image = ApplyOCR(content)
    if len(ocr_df) < 2:
        raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")

    try:
        tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Entity identification failed") from e

    try:
        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Relation extraction failed") from e
    return reOutput

def ApplyOCR(content):
    """Open an image from raw bytes and OCR it with two separate clients.

    Steps: open the bytes with PIL, pass the image to ``DetectHandwritting``
    (which returns two values, used here as the printed image and the
    handwritten images), OCR the first with ``ocr.VisionClient`` and the
    second with ``ocr.TrOCRClient``, then concatenate both results.

    Args:
        content: Raw bytes of the image file.

    Returns:
        A ``(ocr_df, image)`` tuple: the concatenated OCR results (handwritten
        results first, printed results second) and the opened PIL image.

    Raises:
        HTTPException: Status 400 with a stage-specific ``detail`` when
            opening the image, handwriting detection, printed OCR or
            handwritten OCR fails.
    """
    try:
        image = Image.open(io.BytesIO(content))
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Invalid image") from e

    try:
        printed_img, handwritten_imgs = DetectHandwritting(image)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Handwritting detection failed") from e

    try:
        # Encode the printed image as PNG in memory so its bytes can be
        # passed to the Vision client alongside the image object.
        with io.BytesIO() as png_buffer:
            printed_img.save(png_buffer, format='PNG')
            printed_content = png_buffer.getvalue()
        vision_client = ocr.VisionClient(config['settings'].GCV_AUTH)
        printed_ocr_df = vision_client.ocr(printed_content, printed_img)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Printed OCR process failed") from e

    try:
        trocr_client = ocr.TrOCRClient(config['settings'].TROCR_API_URL)
        handwritten_ocr_df = trocr_client.ocr(handwritten_imgs, image)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="handwritten OCR process failed") from e

    ocr_df = pd.concat([handwritten_ocr_df, printed_ocr_df])
    return ocr_df, image


def LabelTokens(ocr_df, image):
  """Preprocess the OCR output and classify every token with the SER model.

  Returns a ``(result, size)`` tuple: ``result`` is a dict holding the token
  labels together with the ``input_ids``, ``bbox`` and ``attention_mask``
  produced by the preprocessor, and ``size`` is ``image.size``.
  """
  processed = config['processor'].process(ocr_df, image = image)
  # Only four of the six values from the preprocessor are used here.
  input_ids, attention_mask, _token_type_ids, bbox, _token_actual_boxes, offset_mapping = processed

  labels = token_classification.classifyTokens(
    config['ser_model'], input_ids, attention_mask, bbox, offset_mapping
  )
  labelled = {
    "token_labels": labels,
    "input_ids": input_ids,
    "bbox": bbox,
    "attention_mask": attention_mask,
  }
  return labelled, image.size

def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
  """Merge labelled tokens into entities and predict relations between them.

  Takes the dict produced by ``LabelTokens``, builds entities with
  ``token_classification.createEntities``, runs the RE model on them and
  returns a dict whose values are all JSON-encoded strings: the predicted
  relations, the entities (labels replaced by integer ids), the input ids,
  the original bounding boxes, the token labels, and decoded text for both
  the entities and the predicted relations.
  """
  # All four keys are read up front, although the attention mask used
  # below comes from the merged output instead.
  token_labels, ser_input_ids, _ser_attention_mask, bbox_org = (
    tokenClassificationOutput[key]
    for key in ('token_labels', 'input_ids', 'attention_mask', 'bbox')
  )

  ser_model = config['ser_model']
  tokenizer = config['tokenizer']
  merged_output, merged_words = token_classification.createEntities(
    ser_model, token_labels, ser_input_ids, ocr_df, tokenizer, img_size, bbox_org
  )
  entities = merged_output['entities']

  def as_batch(key):
    # Wrap the merged sequence in a batch dimension of one and move it to the device.
    return torch.tensor([merged_output[key]]).to(config['device'])

  input_ids = as_batch('input_ids')
  bbox = as_batch('bbox')
  attention_mask = as_batch('attention_mask')

  label_to_id = {"HEADER":0, "QUESTION":1, "ANSWER":2}
  decoded_entities = []
  for ent in entities:
    label_name = ent['label']
    text = tokenizer.decode(input_ids[0][ent['start']:ent['end']])
    decoded_entities.append((label_name, text))
    # The entity dicts are updated in place: label name becomes its integer id.
    ent['label'] = label_to_id[label_name]

  re_model = config['re_model']
  re_model.to(config['device'])
  entity_dict = {
    field: [ent[field] for ent in entities]
    for field in ('start', 'end', 'label')
  }
  relations = [{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]
  with torch.no_grad():
    outputs = re_model(
      input_ids=input_ids,
      bbox=bbox,
      attention_mask=attention_mask,
      entities=[entity_dict],
      relations=relations,
    )

  def span_text(span):
    # Decode the tokens covered by a (start, end) pair.
    start, end = span
    return tokenizer.decode(input_ids[0][start:end])

  predicted = outputs.pred_relations[0]
  decoded_pred_relations = [
    (span_text(rel['head']), span_text(rel['tail'])) for rel in predicted
  ]

  return {
    "pred_relations": json.dumps(predicted),
    "entities": json.dumps(entities),
    "input_ids": json.dumps(input_ids.tolist()),
    "bboxes": json.dumps(bbox_org.tolist()),
    "token_labels": json.dumps(token_labels),
    "decoded_entities": json.dumps(decoded_entities),
    "decoded_pred_relations": json.dumps(decoded_pred_relations),
  }