sri-doc / main.py
kavg's picture
fixed entity merging issue
ccb5ac8
raw
history blame
6.52 kB
from config import Settings
from preprocess import Preprocessor
import ocr
from PIL import Image
from transformers import LiltForTokenClassification, AutoTokenizer
import token_classification
import torch
from fastapi import FastAPI, UploadFile, Form, HTTPException
from contextlib import asynccontextmanager
import json
import io
from models import LiLTRobertaLikeForRelationExtraction
from base64 import b64decode
from handwritting_detection import DetectHandwritting
import pandas as pd
# Process-wide registry of settings and loaded models. Filled in by
# ``lifespan`` when the application starts and emptied again on shutdown.
config = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate ``config`` before the app serves requests and clear it afterwards.

    ``app`` is required by FastAPI's lifespan signature and is not used here.
    """
    settings = Settings()
    config['settings'] = settings
    # Entries are built lazily, one at a time and in this order, so if any
    # loader raises, the entries before it are already stored in ``config``.
    loaders = (
        ('device', lambda: torch.device("cuda" if torch.cuda.is_available() else "cpu")),
        ('processor', lambda: Preprocessor(settings.TOKENIZER)),
        ('tokenizer', lambda: AutoTokenizer.from_pretrained(settings.TOKENIZER)),
        ('ser_model', lambda: LiltForTokenClassification.from_pretrained(settings.SER_MODEL)),
        ('re_model', lambda: LiLTRobertaLikeForRelationExtraction.from_pretrained(settings.RE_MODEL)),
        # NOTE(review): the handlers visible in this file read
        # settings.TROCR_API_URL directly rather than this key.
        ('TROCR_API', lambda: settings.TROCR_API_URL),
    )
    for key, load in loaders:
        config[key] = load()
    yield
    # Shutdown: drop every reference held in the registry.
    config.clear()


app = FastAPI(lifespan=lifespan)
@app.get("/")
def api_home():
    """Landing route for ``GET /``: return a fixed welcome payload."""
    welcome = {'detail': 'Welcome to Sri-Doc space'}
    return welcome
@app.post("/submit-doc")
async def ProcessDocument(file: UploadFile):
    """Run the document pipeline on an uploaded image file.

    Stages: ``ApplyOCR`` -> ``LabelTokens`` -> ``ExtractRelations``.

    Args:
        file: The uploaded image.

    Returns:
        The dict produced by ``ExtractRelations``.

    Raises:
        HTTPException(400): if OCR fails or yields fewer than two rows, or if
            entity identification or relation extraction fails.

    NOTE: a second function with this same name is defined for
    ``/submit-doc-base64``. FastAPI registers both routes at decoration time,
    but at module level only the later definition is bound to the name.
    NOTE(review): the stages below are synchronous calls made inside an
    ``async def``, so they run on the event-loop thread.
    """
    content = await file.read()
    ocr_df, image = ApplyOCR(content)
    # Fewer than two OCR rows is rejected as an OCR failure.
    if len(ocr_df) < 2:
        raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
    try:
        tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
    except Exception as e:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit are
        # not turned into HTTP 400s; logged like the relation-extraction stage.
        print(e)
        raise HTTPException(status_code=400, detail="Entity identification failed") from e
    try:
        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Relation extraction failed") from e
    return reOutput
@app.post("/submit-doc-base64")
async def ProcessDocument(file: str = Form(...)):
    """Run the document pipeline on a base64-encoded image sent as a form field.

    Args:
        file: A data URL such as ``data:image/png;base64,<payload>``. A bare
            base64 payload with no ``<header>,`` prefix is also accepted.

    Returns:
        The dict produced by ``ExtractRelations``.

    Raises:
        HTTPException(400): if the payload cannot be base64-decoded, OCR fails
            or yields fewer than two rows, or if entity identification or
            relation extraction fails.

    NOTE: this definition reuses the name of the ``/submit-doc`` handler and
    shadows it at module level. Both routes were registered with FastAPI at
    decoration time.
    """
    # The base64 alphabet contains no ',', so the payload is whatever follows
    # the last comma. With no comma at all, the whole string is the payload.
    _, _, payload = file.rpartition(',')
    try:
        content = b64decode(payload.encode())
    except ValueError as e:
        # binascii.Error (bad padding) and UnicodeEncodeError both subclass ValueError.
        raise HTTPException(status_code=400, detail="Invalid image") from e
    ocr_df, image = ApplyOCR(content)
    # Fewer than two OCR rows is rejected as an OCR failure.
    if len(ocr_df) < 2:
        raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
    try:
        tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Entity identification failed") from e
    try:
        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Relation extraction failed") from e
    return reOutput
def ApplyOCR(content):
    """OCR an image by combining printed-text and handwritten-text results.

    ``DetectHandwritting`` splits the image into a printed-text image and a
    second value holding the handwritten parts. The printed image is sent, as
    PNG bytes, to ``ocr.VisionClient`` (built from ``settings.GCV_AUTH``). The
    handwritten parts go to ``ocr.TrOCRClient`` at ``settings.TROCR_API_URL``.

    Args:
        content: Raw bytes of an image file that PIL can open.

    Returns:
        ``(ocr_df, image)``: the concatenated OCR DataFrame (handwritten rows
        first, then printed rows) and the opened PIL image.

    Raises:
        HTTPException(400): with a stage-specific ``detail`` when opening the
            image, handwriting detection, printed OCR or handwritten OCR fails.
    """
    try:
        image = Image.open(io.BytesIO(content))
    except Exception as e:
        raise HTTPException(status_code=400, detail="Invalid image") from e
    try:
        printed_img, handwritten_imgs = DetectHandwritting(image)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Handwritting detection failed") from e
    try:
        # Re-encode the printed-text image as PNG bytes for the Vision client.
        png_buffer = io.BytesIO()
        printed_img.save(png_buffer, format='PNG')
        vision_client = ocr.VisionClient(config['settings'].GCV_AUTH)
        printed_ocr_df = vision_client.ocr(png_buffer.getvalue(), printed_img)
    except Exception as e:
        # Previously the error was swallowed without a trace; log it like the
        # other stages do.
        print(e)
        raise HTTPException(status_code=400, detail="Printed OCR process failed") from e
    try:
        trocr_client = ocr.TrOCRClient(config['settings'].TROCR_API_URL)
        handwritten_ocr_df = trocr_client.ocr(handwritten_imgs, image)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="handwritten OCR process failed") from e
    # NOTE(review): the index is not reset, so rows from the two frames can
    # share index labels. Confirm downstream code does not rely on a unique index.
    ocr_df = pd.concat([handwritten_ocr_df, printed_ocr_df])
    return ocr_df, image
def LabelTokens(ocr_df, image):
    """Preprocess the OCR output and run the ``ser_model`` token classifier.

    Returns:
        ``(features, image.size)``. ``features`` holds the predicted
        ``token_labels`` plus the ``input_ids``, ``bbox`` and
        ``attention_mask`` that ``ExtractRelations`` reads back.
    """
    processed = config['processor'].process(ocr_df, image=image)
    # Token type ids and the actual (unnormalised) boxes are not needed here.
    input_ids, attention_mask, _token_type_ids, bbox, _actual_boxes, offset_mapping = processed
    labels = token_classification.classifyTokens(
        config['ser_model'], input_ids, attention_mask, bbox, offset_mapping
    )
    features = {
        "token_labels": labels,
        "input_ids": input_ids,
        "bbox": bbox,
        "attention_mask": attention_mask,
    }
    return features, image.size
def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
    """Merge labelled tokens into entities and predict relations between them.

    Args:
        tokenClassificationOutput: The feature dict returned by ``LabelTokens``.
        ocr_df: The OCR DataFrame the tokens were built from.
        img_size: The source image size, forwarded to ``createEntities``.

    Returns:
        A dict whose values are all JSON strings:
        - ``pred_relations``: the raw relations from the RE model.
        - ``entities``: the merged entities.
        - ``input_ids``: the merged token ids.
        - ``bboxes``: the first-pass bboxes.
        - ``token_labels``: the SER labels.
        - ``decoded_entities``: (label name, text) pairs.
        - ``decoded_pred_relations``: (head text, tail text) pairs.

    Side effect: each entity's ``label`` in ``merged_output['entities']`` is
    replaced in place by its integer id (HEADER=0, QUESTION=1, ANSWER=2).
    The returned ``entities`` therefore carry ids, not names.
    """
    # The first-pass attention mask is looked up, but only the merged one is used.
    token_labels, first_pass_ids, _first_pass_mask, bbox_org = (
        tokenClassificationOutput[key]
        for key in ('token_labels', 'input_ids', 'attention_mask', 'bbox')
    )
    tokenizer = config['tokenizer']
    device = config['device']
    merged_output, _merged_words = token_classification.createEntities(
        config['ser_model'], token_labels, first_pass_ids, ocr_df, tokenizer, img_size, bbox_org
    )
    entities = merged_output['entities']

    def as_batch(key):
        # Wrap one merged sequence into a batch of size 1 on the target device.
        return torch.tensor([merged_output[key]]).to(device)

    input_ids = as_batch('input_ids')
    bbox = as_batch('bbox')
    attention_mask = as_batch('attention_mask')

    def decode_span(span):
        start, end = span
        return tokenizer.decode(input_ids[0][start:end])

    # Maps label *names* to ids (the original code called this id2label).
    label2id = {"HEADER": 0, "QUESTION": 1, "ANSWER": 2}
    decoded_entities = []
    for entity in entities:
        text = decode_span((entity['start'], entity['end']))
        decoded_entities.append((entity['label'], text))
        entity['label'] = label2id[entity['label']]

    re_model = config['re_model']
    re_model.to(device)
    entity_dict = {
        field: [entity[field] for entity in entities]
        for field in ('start', 'end', 'label')
    }
    # Empty relation placeholder passed to the model at inference time.
    relations = [{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]
    with torch.no_grad():
        outputs = re_model(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            entities=[entity_dict],
            relations=relations,
        )
    predicted = outputs.pred_relations[0]
    # Head/tail spans are decoded to text (the original code named them question/answer).
    decoded_pred_relations = [
        (decode_span(rel['head']), decode_span(rel['tail'])) for rel in predicted
    ]
    return {
        "pred_relations": json.dumps(predicted),
        "entities": json.dumps(entities),
        "input_ids": json.dumps(input_ids.tolist()),
        "bboxes": json.dumps(bbox_org.tolist()),
        "token_labels": json.dumps(token_labels),
        "decoded_entities": json.dumps(decoded_entities),
        "decoded_pred_relations": json.dumps(decoded_pred_relations),
    }