kavg committed on
Commit
0d655c9
1 Parent(s): 0dd8e27

new endpoint for mobile app that supports base64 encoding

Browse files
Files changed (2) hide show
  1. base64_test.ipynb +0 -0
  2. main.py +13 -13
base64_test.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
main.py CHANGED
@@ -5,11 +5,12 @@ from PIL import Image
5
  from transformers import LiltForTokenClassification, AutoTokenizer
6
  import token_classification
7
  import torch
8
- from fastapi import FastAPI, UploadFile
9
  from contextlib import asynccontextmanager
10
  import json
11
  import io
12
  from models import LiLTRobertaLikeForRelationExtraction
 
13
  config = {}
14
 
15
  @asynccontextmanager
@@ -30,12 +31,20 @@ app = FastAPI(lifespan=lifespan)
30
 
31
  @app.post("/submit-doc")
32
  async def ProcessDocument(file: UploadFile):
33
- tokenClassificationOutput, ocr_df, img_size = await LabelTokens(file)
 
34
  reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
35
  return reOutput
36
 
37
- async def LabelTokens(file):
38
- content = await file.read()
 
 
 
 
 
 
 
39
  image = Image.open(io.BytesIO(content))
40
  ocr_df = config['vision_client'].ocr(content, image)
41
  input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = config['processor'].process(ocr_df, image = image)
@@ -74,14 +83,5 @@ def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
74
  question = config['tokenizer'].decode(input_ids[0][head_start:head_end])
75
  answer = config['tokenizer'].decode(input_ids[0][tail_start:tail_end])
76
  decoded_pred_relations.append((question, answer))
77
- # print("Question:", question)
78
- # print("Answer:", answer)
79
- ## This prints bboxes of each question and answer
80
- # for item in merged_words:
81
- # if item['text'] == question:
82
- # print('Question', item['box'])
83
- # if item['text'] == answer:
84
- # print('Answer', item['box'])
85
- # print("----------")
86
 
87
  return {"pred_relations":json.dumps(outputs.pred_relations[0]), "entities":json.dumps(entities), "input_ids": json.dumps(input_ids.tolist()), "bboxes": json.dumps(bbox_org.tolist()),"token_labels":json.dumps(token_labels), "decoded_entities": json.dumps(decoded_entities), "decoded_pred_relations":json.dumps(decoded_pred_relations)}
 
5
  from transformers import LiltForTokenClassification, AutoTokenizer
6
  import token_classification
7
  import torch
8
+ from fastapi import FastAPI, UploadFile, Form
9
  from contextlib import asynccontextmanager
10
  import json
11
  import io
12
  from models import LiLTRobertaLikeForRelationExtraction
13
+ from base64 import b64decode
14
  config = {}
15
 
16
  @asynccontextmanager
 
31
 
32
@app.post("/submit-doc")
async def ProcessDocument(file: UploadFile):
    """Label tokens in an uploaded document and extract question/answer relations.

    Reads the raw upload bytes, runs token classification plus OCR via
    LabelTokens, then returns the relation-extraction result as JSON-ready dict.
    """
    raw_bytes = await file.read()
    token_output, ocr_frame, dimensions = LabelTokens(raw_bytes)
    return ExtractRelations(token_output, ocr_frame, dimensions)
38
 
39
@app.post("/submit-doc-mobile")
async def ProcessDocumentMobile(base64str: str = Form(...)):
    """Mobile endpoint: accept the document as a base64-encoded form field.

    The mobile app cannot do multipart file uploads, so it posts the image
    bytes base64-encoded in a form field instead; the pipeline afterwards is
    identical to /submit-doc.

    Fix: the original handler was also named ``ProcessDocument``, silently
    shadowing the /submit-doc handler at module level (FastAPI had already
    registered both routes, but the duplicate name breaks introspection and
    tracebacks). Renamed; the route path and form field are unchanged, so
    clients are unaffected.
    """
    # str.encode(base64str) was an unidiomatic spelling of the same call.
    content = b64decode(base64str.encode())
    tokenClassificationOutput, ocr_df, img_size = LabelTokens(content)
    reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
    return reOutput
46
+
47
+ def LabelTokens(content):
48
  image = Image.open(io.BytesIO(content))
49
  ocr_df = config['vision_client'].ocr(content, image)
50
  input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = config['processor'].process(ocr_df, image = image)
 
83
  question = config['tokenizer'].decode(input_ids[0][head_start:head_end])
84
  answer = config['tokenizer'].decode(input_ids[0][tail_start:tail_end])
85
  decoded_pred_relations.append((question, answer))
 
 
 
 
 
 
 
 
 
86
 
87
  return {"pred_relations":json.dumps(outputs.pred_relations[0]), "entities":json.dumps(entities), "input_ids": json.dumps(input_ids.tolist()), "bboxes": json.dumps(bbox_org.tolist()),"token_labels":json.dumps(token_labels), "decoded_entities": json.dumps(decoded_entities), "decoded_pred_relations":json.dumps(decoded_pred_relations)}