|
from config import Settings |
|
from preprocess import Preprocessor |
|
import ocr |
|
from PIL import Image |
|
from transformers import LiltForTokenClassification, AutoTokenizer |
|
import token_classification |
|
import torch |
|
from fastapi import FastAPI, UploadFile, Form, HTTPException |
|
from contextlib import asynccontextmanager |
|
import json |
|
import io |
|
from models import LiLTRobertaLikeForRelationExtraction |
|
from base64 import b64decode |
|
from handwritting_detection import DetectHandwritting |
|
import pandas as pd |
|
# Module-level registry of shared resources (settings, device, tokenizer,
# models). ``lifespan`` fills it at startup and clears it at shutdown; the
# request helpers below read from it.
config = {}
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load shared resources into ``config`` for the application's lifetime.

    Used as the FastAPI ``lifespan`` handler: the code before ``yield`` is
    the startup phase, the code in ``finally`` is the shutdown phase.

    Keys stored in ``config``: ``settings``, ``device``, ``processor``,
    ``tokenizer``, ``ser_model``, ``re_model`` and ``TROCR_API``.

    Args:
        app: The FastAPI application being started (not used directly).
    """
    try:
        settings = Settings()
        config['settings'] = settings
        # Prefer the GPU when one is visible to torch, otherwise stay on CPU.
        config['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        config['processor'] = Preprocessor(settings.TOKENIZER)
        config['tokenizer'] = AutoTokenizer.from_pretrained(settings.TOKENIZER)
        config['ser_model'] = LiltForTokenClassification.from_pretrained(settings.SER_MODEL)
        config['re_model'] = LiLTRobertaLikeForRelationExtraction.from_pretrained(settings.RE_MODEL)
        config['TROCR_API'] = settings.TROCR_API_URL
        yield
    finally:
        # Runs on normal shutdown, on an exception/cancellation delivered at
        # the ``yield``, and on a startup failure, so references to the
        # loaded models are always dropped.
        config.clear()
|
|
|
# The web application; ``lifespan`` is passed in so the shared resources in
# ``config`` are set up and torn down together with the app.
app = FastAPI(lifespan=lifespan)
|
|
|
@app.get("/")
def api_home():
    """Return the static welcome payload served at the API root."""
    greeting = 'Welcome to Sri-Doc space'
    return {'detail': greeting}
|
|
|
@app.post("/submit-doc")
async def ProcessDocument(file: UploadFile):
    """Run OCR, entity labelling and relation extraction on an uploaded image.

    Args:
        file: The uploaded document image.

    Returns:
        The dictionary produced by ``ExtractRelations``.

    Raises:
        HTTPException: 400 when the OCR result has fewer than two entries,
            when entity identification fails, or when relation extraction
            fails. Exceptions raised by ``ApplyOCR`` propagate unchanged.
    """
    content = await file.read()
    ocr_df, image = ApplyOCR(content)
    if len(ocr_df) < 2:
        raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")

    try:
        tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
    except Exception as e:
        # Log the real cause; the client only receives the generic message.
        print(e)
        raise HTTPException(status_code=400, detail="Entity identification failed") from e

    try:
        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Relation extraction failed") from e
    return reOutput
|
|
|
@app.post("/submit-doc-base64")
async def ProcessDocument(file: str = Form(...)):
    """Run the document pipeline on a base64-encoded image sent as a form field.

    NOTE(review): this handler shares its Python name with the
    ``/submit-doc`` handler; the module-level name ends up bound to this one.

    Args:
        file: Either a data URL (``<header>,<base64 payload>``) or a bare
            base64 payload.

    Returns:
        The dictionary produced by ``ExtractRelations``.

    Raises:
        HTTPException: 400 when the payload cannot be base64-decoded, when
            the OCR result has fewer than two entries, when entity
            identification fails, or when relation extraction fails.
            Exceptions raised by ``ApplyOCR`` propagate unchanged.
    """
    try:
        # Everything after the last comma is the payload; when there is no
        # comma the whole string is treated as the payload.
        _, _, payload = file.rpartition(',')
        content = b64decode(payload.encode())
    except ValueError as e:
        # binascii.Error (bad padding) and UnicodeEncodeError are both
        # ValueError subclasses.
        raise HTTPException(status_code=400, detail="Invalid image") from e

    ocr_df, image = ApplyOCR(content)
    if len(ocr_df) < 2:
        raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")

    try:
        tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Entity identification failed") from e

    try:
        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Relation extraction failed") from e
    return reOutput
|
|
|
def ApplyOCR(content):
    """Open an image from raw bytes and OCR its printed and handwritten parts.

    Stages, each with its own error message:
      1. Open ``content`` with PIL.
      2. ``DetectHandwritting`` splits the image into a printed image and
         the handwritten parts.
      3. The printed image is re-encoded as PNG and passed to
         ``ocr.VisionClient``.
      4. The handwritten parts are passed to ``ocr.TrOCRClient``.
    The two OCR results are concatenated, handwritten first.

    Args:
        content: Raw bytes of the document image.

    Returns:
        A ``(ocr_df, image)`` tuple: the concatenated OCR result and the
        opened PIL image.

    Raises:
        HTTPException: 400 with a stage-specific ``detail`` when any stage
            fails.
    """
    try:
        image = Image.open(io.BytesIO(content))
    except Exception as e:
        raise HTTPException(status_code=400, detail="Invalid image") from e

    try:
        printed_img, handwritten_imgs = DetectHandwritting(image)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Handwritting detection failed") from e

    try:
        png_buffer = io.BytesIO()
        printed_img.save(png_buffer, format='PNG')
        printed_content = png_buffer.getvalue()

        vision_client = ocr.VisionClient(config['settings'].GCV_AUTH)
        printed_ocr_df = vision_client.ocr(printed_content, printed_img)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="Printed OCR process failed") from e

    try:
        trocr_client = ocr.TrOCRClient(config['settings'].TROCR_API_URL)
        handwritten_ocr_df = trocr_client.ocr(handwritten_imgs, image)
    except Exception as e:
        print(e)
        raise HTTPException(status_code=400, detail="handwritten OCR process failed") from e

    ocr_df = pd.concat([handwritten_ocr_df, printed_ocr_df])
    return ocr_df, image
|
|
|
|
|
def LabelTokens(ocr_df, image):
    """Encode the OCR result and classify every token with the SER model.

    Args:
        ocr_df: OCR result handed to the preprocessor.
        image: The document image; its ``size`` is returned alongside the
            labels.

    Returns:
        A ``(outputs, image.size)`` tuple where ``outputs`` is a dict with
        the keys ``token_labels``, ``input_ids``, ``bbox`` and
        ``attention_mask``.
    """
    (
        input_ids,
        attention_mask,
        _token_type_ids,      # produced by the preprocessor, not needed here
        bbox,
        _token_actual_boxes,  # produced by the preprocessor, not needed here
        offset_mapping,
    ) = config['processor'].process(ocr_df, image=image)

    labels = token_classification.classifyTokens(
        config['ser_model'], input_ids, attention_mask, bbox, offset_mapping
    )

    outputs = {
        "token_labels": labels,
        "input_ids": input_ids,
        "bbox": bbox,
        "attention_mask": attention_mask,
    }
    return outputs, image.size
|
|
|
def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
    """Merge labelled tokens into entities and predict relations between them.

    Args:
        tokenClassificationOutput: Dict with the keys ``token_labels``,
            ``input_ids``, ``attention_mask`` and ``bbox`` (as returned by
            ``LabelTokens``). The ``bbox`` value must provide ``tolist()``.
        ocr_df: OCR result, forwarded to ``createEntities``.
        img_size: Image size, forwarded to ``createEntities``.

    Returns:
        A dict whose values are all JSON strings, keyed by
        ``pred_relations``, ``entities``, ``input_ids``, ``bboxes``,
        ``token_labels``, ``decoded_entities`` and
        ``decoded_pred_relations``.

    Raises:
        KeyError: If an entity label is not HEADER, QUESTION or ANSWER.
    """
    token_labels, source_ids, _, bbox_org = (
        tokenClassificationOutput[key]
        for key in ('token_labels', 'input_ids', 'attention_mask', 'bbox')
    )

    merged, _merged_words = token_classification.createEntities(
        config['ser_model'],
        token_labels,
        source_ids,
        ocr_df,
        config['tokenizer'],
        img_size,
        bbox_org,
    )

    entities = merged['entities']
    device = config['device']
    # Wrap each merged sequence in a list to add the batch dimension.
    input_ids, bbox, attention_mask = (
        torch.tensor([merged[field]]).to(device)
        for field in ('input_ids', 'bbox', 'attention_mask')
    )

    tokenizer = config['tokenizer']
    label_to_id = {"HEADER": 0, "QUESTION": 1, "ANSWER": 2}
    decoded_entities = []
    for ent in entities:
        # Record the textual label first; it is replaced by its id right after.
        span_text = tokenizer.decode(input_ids[0][ent['start']:ent['end']])
        decoded_entities.append((ent['label'], span_text))
        ent['label'] = label_to_id[ent['label']]

    re_model = config['re_model']
    re_model.to(device)
    entity_dict = {
        field: [ent[field] for ent in entities]
        for field in ('start', 'end', 'label')
    }
    empty_relations = [{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]
    with torch.no_grad():
        outputs = re_model(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            entities=[entity_dict],
            relations=empty_relations,
        )

    predicted = outputs.pred_relations[0]

    def decode_span(span):
        # ``span`` is a (start, end) pair of token positions.
        start, end = span
        return tokenizer.decode(input_ids[0][start:end])

    decoded_pred_relations = [
        (decode_span(rel['head']), decode_span(rel['tail']))
        for rel in predicted
    ]

    payload = {
        "pred_relations": predicted,
        "entities": entities,
        "input_ids": input_ids.tolist(),
        "bboxes": bbox_org.tolist(),
        "token_labels": token_labels,
        "decoded_entities": decoded_entities,
        "decoded_pred_relations": decoded_pred_relations,
    }
    return {key: json.dumps(value) for key, value in payload.items()}