Added error handling
Browse files- download_model.ipynb +20 -23
- main.py +36 -12
download_model.ipynb
CHANGED
@@ -59,12 +59,24 @@
|
|
59 |
},
|
60 |
{
|
61 |
"cell_type": "code",
|
62 |
-
"execution_count":
|
63 |
"metadata": {},
|
64 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
"source": [
|
66 |
"# download the model\n",
|
67 |
-
"MODEL = \"
|
68 |
"model = LiltForTokenClassification.from_pretrained(MODEL)\n",
|
69 |
"\n",
|
70 |
"# save the model\n",
|
@@ -83,28 +95,13 @@
|
|
83 |
"cell_type": "code",
|
84 |
"execution_count": 5,
|
85 |
"metadata": {},
|
86 |
-
"outputs": [
|
87 |
-
{
|
88 |
-
"name": "stderr",
|
89 |
-
"output_type": "stream",
|
90 |
-
"text": [
|
91 |
-
"Downloading config.json: 100%|██████████| 794/794 [00:00<00:00, 61.2kB/s]\n",
|
92 |
-
"d:\\FYP\\lilt-app-without-fd\\lilt-env\\lib\\site-packages\\huggingface_hub\\file_download.py:133: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\Gihantha Kavishka\\.cache\\huggingface\\hub. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
|
93 |
-
"To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
|
94 |
-
" warnings.warn(message)\n",
|
95 |
-
"Downloading pytorch_model.bin: 100%|██████████| 1.15G/1.15G [08:10<00:00, 2.34MB/s]\n",
|
96 |
-
"Some weights of the model checkpoint at kavg/layoutxlm-finetuned-xfund-fr-re were not used when initializing LiltModel: ['extractor.rel_classifier.linear.weight', 'extractor.entity_emb.weight', 'extractor.ffnn_tail.0.weight', 'extractor.ffnn_tail.3.bias', 'extractor.ffnn_head.3.weight', 'extractor.ffnn_head.0.weight', 'extractor.ffnn_tail.0.bias', 'extractor.ffnn_head.3.bias', 'extractor.rel_classifier.bilinear.weight', 'extractor.rel_classifier.linear.bias', 'extractor.ffnn_head.0.bias', 'extractor.ffnn_tail.3.weight']\n",
|
97 |
-
"- This IS expected if you are initializing LiltModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
98 |
-
"- This IS NOT expected if you are initializing LiltModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
99 |
-
"Some weights of LiltModel were not initialized from the model checkpoint at kavg/layoutxlm-finetuned-xfund-fr-re and are newly initialized: ['lilt.pooler.dense.bias', 'lilt.pooler.dense.weight']\n",
|
100 |
-
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
101 |
-
]
|
102 |
-
}
|
103 |
-
],
|
104 |
"source": [
|
|
|
|
|
105 |
"# download the model\n",
|
106 |
-
"MODEL = 'kavg/
|
107 |
-
"model =
|
108 |
"\n",
|
109 |
"# save the model\n",
|
110 |
"save_dir = \"models/lilt-re\"\n",
|
|
|
59 |
},
|
60 |
{
|
61 |
"cell_type": "code",
|
62 |
+
"execution_count": 3,
|
63 |
"metadata": {},
|
64 |
+
"outputs": [
|
65 |
+
{
|
66 |
+
"name": "stderr",
|
67 |
+
"output_type": "stream",
|
68 |
+
"text": [
|
69 |
+
"Downloading config.json: 100%|██████████| 1.13k/1.13k [00:00<00:00, 283kB/s]\n",
|
70 |
+
"d:\\FYP\\lilt-app-without-fd\\lilt-env\\lib\\site-packages\\huggingface_hub\\file_download.py:133: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\Gihantha Kavishka\\.cache\\huggingface\\hub. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
|
71 |
+
"To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
|
72 |
+
" warnings.warn(message)\n",
|
73 |
+
"Downloading model.safetensors: 100%|██████████| 1.13G/1.13G [08:02<00:00, 2.35MB/s]\n"
|
74 |
+
]
|
75 |
+
}
|
76 |
+
],
|
77 |
"source": [
|
78 |
"# download the model\n",
|
79 |
+
"MODEL = \"kavg/LiLT-SER-Sin\"\n",
|
80 |
"model = LiltForTokenClassification.from_pretrained(MODEL)\n",
|
81 |
"\n",
|
82 |
"# save the model\n",
|
|
|
95 |
"cell_type": "code",
|
96 |
"execution_count": 5,
|
97 |
"metadata": {},
|
98 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
"source": [
|
100 |
+
"from models import LiLTRobertaLikeForRelationExtraction\n",
|
101 |
+
"\n",
|
102 |
"# download the model\n",
|
103 |
+
"MODEL = 'kavg/LiLT-RE-IT-Sin'\n",
|
104 |
+
"model = LiLTRobertaLikeForRelationExtraction.from_pretrained(MODEL)\n",
|
105 |
"\n",
|
106 |
"# save the model\n",
|
107 |
"save_dir = \"models/lilt-re\"\n",
|
main.py
CHANGED
@@ -5,7 +5,7 @@ from PIL import Image
|
|
5 |
from transformers import LiltForTokenClassification, AutoTokenizer
|
6 |
import token_classification
|
7 |
import torch
|
8 |
-
from fastapi import FastAPI, UploadFile, Form
|
9 |
from contextlib import asynccontextmanager
|
10 |
import json
|
11 |
import io
|
@@ -32,25 +32,49 @@ app = FastAPI(lifespan=lifespan)
|
|
32 |
@app.post("/submit-doc")
|
33 |
async def ProcessDocument(file: UploadFile):
|
34 |
content = await file.read()
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
return reOutput
|
38 |
|
39 |
@app.post("/submit-doc-base64")
|
40 |
async def ProcessDocument(file: str = Form(...)):
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
return reOutput
|
47 |
|
48 |
-
def
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = config['processor'].process(ocr_df, image = image)
|
52 |
token_labels = token_classification.classifyTokens(config['ser_model'], input_ids, attention_mask, bbox, offset_mapping)
|
53 |
-
return {"token_labels": token_labels, "input_ids": input_ids, "bbox":bbox, "attention_mask":attention_mask},
|
54 |
|
55 |
def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
|
56 |
token_labels = tokenClassificationOutput['token_labels']
|
|
|
5 |
from transformers import LiltForTokenClassification, AutoTokenizer
|
6 |
import token_classification
|
7 |
import torch
|
8 |
+
from fastapi import FastAPI, UploadFile, Form, HTTPException
|
9 |
from contextlib import asynccontextmanager
|
10 |
import json
|
11 |
import io
|
|
|
32 |
@app.post("/submit-doc")
|
33 |
async def ProcessDocument(file: UploadFile):
|
34 |
content = await file.read()
|
35 |
+
ocr_df, image = ApplyOCR(content)
|
36 |
+
if len(ocr_df) < 2:
|
37 |
+
raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
|
38 |
+
try:
|
39 |
+
tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
|
40 |
+
reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
|
41 |
+
except:
|
42 |
+
raise HTTPException(status_code=400, detail="Invalid Image")
|
43 |
return reOutput
|
44 |
|
45 |
@app.post("/submit-doc-base64")
|
46 |
async def ProcessDocument(file: str = Form(...)):
|
47 |
+
try:
|
48 |
+
head, file = file.split(',')
|
49 |
+
str_as_bytes = str.encode(file)
|
50 |
+
content = b64decode(str_as_bytes)
|
51 |
+
except:
|
52 |
+
raise HTTPException(status_code=400, detail="Invalid image")
|
53 |
+
ocr_df, image = ApplyOCR(content)
|
54 |
+
if len(ocr_df) < 2:
|
55 |
+
raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
|
56 |
+
try:
|
57 |
+
tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
|
58 |
+
reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
|
59 |
+
except:
|
60 |
+
raise HTTPException(status_code=400, detail="Invalid Image")
|
61 |
return reOutput
|
62 |
|
63 |
+
def ApplyOCR(content):
|
64 |
+
try:
|
65 |
+
image = Image.open(io.BytesIO(content))
|
66 |
+
except:
|
67 |
+
raise HTTPException(status_code=400, detail="Invalid image")
|
68 |
+
try:
|
69 |
+
ocr_df = config['vision_client'].ocr(content, image)
|
70 |
+
except:
|
71 |
+
raise HTTPException(status_code=400, detail="OCR process failed")
|
72 |
+
return ocr_df, image
|
73 |
+
|
74 |
+
def LabelTokens(ocr_df, image):
|
75 |
input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = config['processor'].process(ocr_df, image = image)
|
76 |
token_labels = token_classification.classifyTokens(config['ser_model'], input_ids, attention_mask, bbox, offset_mapping)
|
77 |
+
return {"token_labels": token_labels, "input_ids": input_ids, "bbox":bbox, "attention_mask":attention_mask}, image.size
|
78 |
|
79 |
def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
|
80 |
token_labels = tokenClassificationOutput['token_labels']
|