New endpoint for the mobile app that supports base64 encoding
Browse files- base64_test.ipynb +0 -0
- main.py +13 -13
base64_test.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
main.py
CHANGED
@@ -5,11 +5,12 @@ from PIL import Image
|
|
5 |
from transformers import LiltForTokenClassification, AutoTokenizer
|
6 |
import token_classification
|
7 |
import torch
|
8 |
-
from fastapi import FastAPI, UploadFile
|
9 |
from contextlib import asynccontextmanager
|
10 |
import json
|
11 |
import io
|
12 |
from models import LiLTRobertaLikeForRelationExtraction
|
|
|
13 |
config = {}
|
14 |
|
15 |
@asynccontextmanager
|
@@ -30,12 +31,20 @@ app = FastAPI(lifespan=lifespan)
|
|
30 |
|
31 |
@app.post("/submit-doc")
|
32 |
async def ProcessDocument(file: UploadFile):
|
33 |
-
|
|
|
34 |
reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
|
35 |
return reOutput
|
36 |
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
image = Image.open(io.BytesIO(content))
|
40 |
ocr_df = config['vision_client'].ocr(content, image)
|
41 |
input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = config['processor'].process(ocr_df, image = image)
|
@@ -74,14 +83,5 @@ def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
|
|
74 |
question = config['tokenizer'].decode(input_ids[0][head_start:head_end])
|
75 |
answer = config['tokenizer'].decode(input_ids[0][tail_start:tail_end])
|
76 |
decoded_pred_relations.append((question, answer))
|
77 |
-
# print("Question:", question)
|
78 |
-
# print("Answer:", answer)
|
79 |
-
## This prints bboxes of each question and answer
|
80 |
-
# for item in merged_words:
|
81 |
-
# if item['text'] == question:
|
82 |
-
# print('Question', item['box'])
|
83 |
-
# if item['text'] == answer:
|
84 |
-
# print('Answer', item['box'])
|
85 |
-
# print("----------")
|
86 |
|
87 |
return {"pred_relations":json.dumps(outputs.pred_relations[0]), "entities":json.dumps(entities), "input_ids": json.dumps(input_ids.tolist()), "bboxes": json.dumps(bbox_org.tolist()),"token_labels":json.dumps(token_labels), "decoded_entities": json.dumps(decoded_entities), "decoded_pred_relations":json.dumps(decoded_pred_relations)}
|
|
|
5 |
from transformers import LiltForTokenClassification, AutoTokenizer
|
6 |
import token_classification
|
7 |
import torch
|
8 |
+
from fastapi import FastAPI, UploadFile, Form
|
9 |
from contextlib import asynccontextmanager
|
10 |
import json
|
11 |
import io
|
12 |
from models import LiLTRobertaLikeForRelationExtraction
|
13 |
+
from base64 import b64decode
|
14 |
config = {}
|
15 |
|
16 |
@asynccontextmanager
|
|
|
31 |
|
32 |
@app.post("/submit-doc")
async def ProcessDocument(file: UploadFile):
    """Run the document-understanding pipeline on an uploaded file.

    Reads the raw upload bytes, labels the document tokens via LabelTokens,
    then returns the relation-extraction result from ExtractRelations.
    """
    raw_bytes = await file.read()
    labeled_tokens, ocr_df, img_size = LabelTokens(raw_bytes)
    return ExtractRelations(labeled_tokens, ocr_df, img_size)
|
38 |
|
39 |
+
@app.post("/submit-doc-mobile")
async def ProcessDocumentMobile(base64str: str = Form(...)):
    """Process a document sent by the mobile app as a base64-encoded form field.

    Same pipeline as /submit-doc, but the file content arrives as a base64
    string in the ``base64str`` form field instead of a multipart upload.
    Returns the relation-extraction output from ExtractRelations.

    Fixes: the handler was also named ``ProcessDocument``, shadowing the
    /submit-doc handler at module level; renamed to keep both importable
    and distinct in generated OpenAPI metadata (the HTTP interface — route
    path and form field — is unchanged).
    """
    # b64decode accepts an ASCII str directly; the previous
    # str.encode(base64str) round-trip was unnecessary.
    # NOTE(review): invalid base64 raises binascii.Error here -> HTTP 500;
    # consider catching it and returning a 400 instead — confirm desired API.
    content = b64decode(base64str)
    tokenClassificationOutput, ocr_df, img_size = LabelTokens(content)
    reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
    return reOutput
|
46 |
+
|
47 |
+
def LabelTokens(content):
|
48 |
image = Image.open(io.BytesIO(content))
|
49 |
ocr_df = config['vision_client'].ocr(content, image)
|
50 |
input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = config['processor'].process(ocr_df, image = image)
|
|
|
83 |
question = config['tokenizer'].decode(input_ids[0][head_start:head_end])
|
84 |
answer = config['tokenizer'].decode(input_ids[0][tail_start:tail_end])
|
85 |
decoded_pred_relations.append((question, answer))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
return {"pred_relations":json.dumps(outputs.pred_relations[0]), "entities":json.dumps(entities), "input_ids": json.dumps(input_ids.tolist()), "bboxes": json.dumps(bbox_org.tolist()),"token_labels":json.dumps(token_labels), "decoded_entities": json.dumps(decoded_entities), "decoded_pred_relations":json.dumps(decoded_pred_relations)}
|