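"""FastAPI inference service for document understanding.

Pipeline: OCR (Google Cloud Vision for printed text, TrOCR for handwriting)
-> semantic entity recognition with a LiLT token classifier -> relation
extraction that pairs questions with answers.
"""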
# Standard library
import io
import json
from base64 import b64decode
from contextlib import asynccontextmanager

# Third-party
import pandas as pd
import torch
from fastapi import FastAPI, Form, HTTPException, UploadFile
from PIL import Image
from transformers import AutoTokenizer, LiltForTokenClassification

# Local modules
import ocr
import token_classification
from config import Settings
from handwritting_detection import DetectHandwritting
from models import LiLTRobertaLikeForRelationExtraction
from preprocess import Preprocessor

config = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load settings, the preprocessor, tokenizer, and both models once at startup.
    settings = Settings()
    config['settings'] = settings
    config['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config['processor'] = Preprocessor(settings.TOKENIZER)
    config['tokenizer'] = AutoTokenizer.from_pretrained(settings.TOKENIZER)
    config['ser_model'] = LiltForTokenClassification.from_pretrained(settings.SER_MODEL)
    config['re_model'] = LiLTRobertaLikeForRelationExtraction.from_pretrained(settings.RE_MODEL)
    config['TROCR_API'] = settings.TROCR_API_URL
    yield
    # Clean up and release the resources
    config.clear()


app = FastAPI(lifespan=lifespan)

@app.get("/")
def api_home():
    return {'detail': 'Welcome to Sri-Doc space'}

@app.post("/submit-doc")
async def ProcessDocument(file: UploadFile):
content = await file.read()
ocr_df, image = ApplyOCR(content)
if len(ocr_df) < 2:
raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
try:
tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
except:
raise HTTPException(status_code=400, detail="Entity identification failed")
try:
reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
except Exception as e:
print(e)
raise HTTPException(status_code=400, detail="Relation extraction failed")
return reOutput
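
# Example request (a sketch; the localhost URL and file name are assumptions):
#   curl -X POST http://localhost:8000/submit-doc -F "file=@document.png"
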
@app.post("/submit-doc-base64")
async def ProcessDocument(file: str = Form(...)):
try:
head, file = file.split(',')
str_as_bytes = str.encode(file)
content = b64decode(str_as_bytes)
except:
raise HTTPException(status_code=400, detail="Invalid image")
ocr_df, image = ApplyOCR(content)
if len(ocr_df) < 2:
raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
try:
tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
except Exception as e:
print(e)
raise HTTPException(status_code=400, detail="Entity identification failed")
try:
reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
except Exception as e:
print(e)
raise HTTPException(status_code=400, detail="Relation extraction failed")
return reOutput
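
# Example client call (a sketch; the localhost URL and file name are assumptions):
#   import base64, requests
#   with open("document.png", "rb") as f:
#       payload = "data:image/png;base64," + base64.b64encode(f.read()).decode()
#   requests.post("http://localhost:8000/submit-doc-base64", data={"file": payload})
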
def ApplyOCR(content):
    try:
        image = Image.open(io.BytesIO(content))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image")
    # Split the page into printed and handwritten regions.
    try:
        printed_img, handwritten_imgs = DetectHandwritting(image)
    except Exception:
        raise HTTPException(status_code=400, detail="Handwriting detection failed")
    # OCR the printed regions with Google Cloud Vision.
    try:
        img_bytes = io.BytesIO()
        printed_img.save(img_bytes, format='PNG')
        printed_content = img_bytes.getvalue()
        vision_client = ocr.VisionClient(config['settings'].GCV_AUTH)
        printed_ocr_df = vision_client.ocr(printed_content, printed_img)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Printed OCR process failed")
    # OCR the handwritten regions with the TrOCR service.
    try:
        trocr_client = ocr.TrOCRClient(config['settings'].TROCR_API_URL)
        handwritten_ocr_df = trocr_client.ocr(handwritten_imgs, image)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Handwritten OCR process failed")
    ocr_df = pd.concat([handwritten_ocr_df, printed_ocr_df])
    return ocr_df, image

def LabelTokens(ocr_df, image):
    # Tokenize the OCR results and run semantic entity recognition (SER).
    (input_ids, attention_mask, token_type_ids, bbox,
     token_actual_boxes, offset_mapping) = config['processor'].process(ocr_df, image=image)
    token_labels = token_classification.classifyTokens(
        config['ser_model'], input_ids, attention_mask, bbox, offset_mapping)
    return {"token_labels": token_labels, "input_ids": input_ids,
            "bbox": bbox, "attention_mask": attention_mask}, image.size

def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
    token_labels = tokenClassificationOutput['token_labels']
    input_ids = tokenClassificationOutput['input_ids']
    attention_mask = tokenClassificationOutput["attention_mask"]
    bbox_org = tokenClassificationOutput["bbox"]

    # Merge token-level predictions into word-level entities.
    merged_output, merged_words = token_classification.createEntities(
        config['ser_model'], token_labels, input_ids, ocr_df,
        config['tokenizer'], img_size, bbox_org)

    entities = merged_output['entities']
    input_ids = torch.tensor([merged_output['input_ids']]).to(config['device'])
    bbox = torch.tensor([merged_output['bbox']]).to(config['device'])
    attention_mask = torch.tensor([merged_output['attention_mask']]).to(config['device'])

    # The RE model expects integer labels; map each entity's string label to its id.
    label2id = {"HEADER": 0, "QUESTION": 1, "ANSWER": 2}
    decoded_entities = []
    for entity in entities:
        decoded_entities.append((entity['label'],
                                 config['tokenizer'].decode(input_ids[0][entity['start']:entity['end']])))
        entity['label'] = label2id[entity['label']]

    config['re_model'].to(config['device'])
    entity_dict = {'start': [entity['start'] for entity in entities],
                   'end': [entity['end'] for entity in entities],
                   'label': [entity['label'] for entity in entities]}
    relations = [{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]
    with torch.no_grad():
        outputs = config['re_model'](input_ids=input_ids, bbox=bbox,
                                     attention_mask=attention_mask,
                                     entities=[entity_dict], relations=relations)

    # Decode each predicted (question, answer) pair back to text.
    decoded_pred_relations = []
    for relation in outputs.pred_relations[0]:
        head_start, head_end = relation['head']
        tail_start, tail_end = relation['tail']
        question = config['tokenizer'].decode(input_ids[0][head_start:head_end])
        answer = config['tokenizer'].decode(input_ids[0][tail_start:tail_end])
        decoded_pred_relations.append((question, answer))

    return {"pred_relations": json.dumps(outputs.pred_relations[0]),
            "entities": json.dumps(entities),
            "input_ids": json.dumps(input_ids.tolist()),
            "bboxes": json.dumps(bbox_org.tolist()),
            "token_labels": json.dumps(token_labels),
            "decoded_entities": json.dumps(decoded_entities),
            "decoded_pred_relations": json.dumps(decoded_pred_relations)}