kavg committed
Commit 12af45e
1 Parent(s): 42cab8f

Added error handling

Files changed (2)
  1. download_model.ipynb +20 -23
  2. main.py +36 -12
download_model.ipynb CHANGED
@@ -59,12 +59,24 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 4,
+    "execution_count": 3,
     "metadata": {},
-    "outputs": [],
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "Downloading config.json: 100%|██████████| 1.13k/1.13k [00:00<00:00, 283kB/s]\n",
+       "d:\\FYP\\lilt-app-without-fd\\lilt-env\\lib\\site-packages\\huggingface_hub\\file_download.py:133: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\Gihantha Kavishka\\.cache\\huggingface\\hub. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
+       "To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
+       " warnings.warn(message)\n",
+       "Downloading model.safetensors: 100%|██████████| 1.13G/1.13G [08:02<00:00, 2.35MB/s]\n"
+      ]
+     }
+    ],
     "source": [
      "# download the model\n",
-     "MODEL = \"pierreguillou/lilt-xlm-roberta-base-finetuned-funsd-iob-original\"\n",
+     "MODEL = \"kavg/LiLT-SER-Sin\"\n",
      "model = LiltForTokenClassification.from_pretrained(MODEL)\n",
      "\n",
      "# save the model\n",
@@ -83,28 +95,13 @@
     "cell_type": "code",
     "execution_count": 5,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stderr",
-      "output_type": "stream",
-      "text": [
-       "Downloading config.json: 100%|██████████| 794/794 [00:00<00:00, 61.2kB/s]\n",
-       "d:\\FYP\\lilt-app-without-fd\\lilt-env\\lib\\site-packages\\huggingface_hub\\file_download.py:133: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\Gihantha Kavishka\\.cache\\huggingface\\hub. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
-       "To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
-       " warnings.warn(message)\n",
-       "Downloading pytorch_model.bin: 100%|██████████| 1.15G/1.15G [08:10<00:00, 2.34MB/s]\n",
-       "Some weights of the model checkpoint at kavg/layoutxlm-finetuned-xfund-fr-re were not used when initializing LiltModel: ['extractor.rel_classifier.linear.weight', 'extractor.entity_emb.weight', 'extractor.ffnn_tail.0.weight', 'extractor.ffnn_tail.3.bias', 'extractor.ffnn_head.3.weight', 'extractor.ffnn_head.0.weight', 'extractor.ffnn_tail.0.bias', 'extractor.ffnn_head.3.bias', 'extractor.rel_classifier.bilinear.weight', 'extractor.rel_classifier.linear.bias', 'extractor.ffnn_head.0.bias', 'extractor.ffnn_tail.3.weight']\n",
-       "- This IS expected if you are initializing LiltModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
-       "- This IS NOT expected if you are initializing LiltModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
-       "Some weights of LiltModel were not initialized from the model checkpoint at kavg/layoutxlm-finetuned-xfund-fr-re and are newly initialized: ['lilt.pooler.dense.bias', 'lilt.pooler.dense.weight']\n",
-       "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
+     "from models import LiLTRobertaLikeForRelationExtraction\n",
+     "\n",
      "# download the model\n",
-     "MODEL = 'kavg/layoutxlm-finetuned-xfund-fr-re'\n",
-     "model = LiltModel.from_pretrained(MODEL)\n",
+     "MODEL = 'kavg/LiLT-RE-IT-Sin'\n",
+     "model = LiLTRobertaLikeForRelationExtraction.from_pretrained(MODEL)\n",
      "\n",
      "# save the model\n",
      "save_dir = \"models/lilt-re\"\n",
main.py CHANGED
@@ -5,7 +5,7 @@ from PIL import Image
 from transformers import LiltForTokenClassification, AutoTokenizer
 import token_classification
 import torch
-from fastapi import FastAPI, UploadFile, Form
+from fastapi import FastAPI, UploadFile, Form, HTTPException
 from contextlib import asynccontextmanager
 import json
 import io
@@ -32,25 +32,49 @@ app = FastAPI(lifespan=lifespan)
 @app.post("/submit-doc")
 async def ProcessDocument(file: UploadFile):
     content = await file.read()
-    tokenClassificationOutput, ocr_df, img_size = LabelTokens(content)
-    reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
+    ocr_df, image = ApplyOCR(content)
+    if len(ocr_df) < 2:
+        raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
+    try:
+        tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
+        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
+    except:
+        raise HTTPException(status_code=400, detail="Invalid Image")
     return reOutput

 @app.post("/submit-doc-base64")
 async def ProcessDocument(file: str = Form(...)):
-    head, file = file.split(',')
-    str_as_bytes = str.encode(file)
-    content = b64decode(str_as_bytes)
-    tokenClassificationOutput, ocr_df, img_size = LabelTokens(content)
-    reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
+    try:
+        head, file = file.split(',')
+        str_as_bytes = str.encode(file)
+        content = b64decode(str_as_bytes)
+    except:
+        raise HTTPException(status_code=400, detail="Invalid image")
+    ocr_df, image = ApplyOCR(content)
+    if len(ocr_df) < 2:
+        raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
+    try:
+        tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
+        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
+    except:
+        raise HTTPException(status_code=400, detail="Invalid Image")
     return reOutput

-def LabelTokens(content):
-    image = Image.open(io.BytesIO(content))
-    ocr_df = config['vision_client'].ocr(content, image)
+def ApplyOCR(content):
+    try:
+        image = Image.open(io.BytesIO(content))
+    except:
+        raise HTTPException(status_code=400, detail="Invalid image")
+    try:
+        ocr_df = config['vision_client'].ocr(content, image)
+    except:
+        raise HTTPException(status_code=400, detail="OCR process failed")
+    return ocr_df, image
+
+def LabelTokens(ocr_df, image):
     input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = config['processor'].process(ocr_df, image = image)
     token_labels = token_classification.classifyTokens(config['ser_model'], input_ids, attention_mask, bbox, offset_mapping)
-    return {"token_labels": token_labels, "input_ids": input_ids, "bbox":bbox, "attention_mask":attention_mask}, ocr_df, image.size
+    return {"token_labels": token_labels, "input_ids": input_ids, "bbox":bbox, "attention_mask":attention_mask}, image.size

 def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
     token_labels = tokenClassificationOutput['token_labels']
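A minimal way to exercise the new error paths, assuming the app above lives in main.py and its lifespan can actually load the models; the TestClient call and the malformed payload below are illustrative and not part of the commit.

# Quick sketch: hitting the new 400 branches with FastAPI's TestClient.
from fastapi.testclient import TestClient
from main import app  # assumption: this module is importable as "main"

with TestClient(app) as client:  # using "with" runs the lifespan that loads config
    # No comma in the payload -> file.split(',') fails -> 400 "Invalid image"
    resp = client.post("/submit-doc-base64", data={"file": "not-a-data-url"})
    assert resp.status_code == 400
    assert resp.json()["detail"] == "Invalid image"

    # An image for which the OCR client returns fewer than two rows would instead
    # yield 400 "Cannot apply OCR to the image"; a failing OCR call gives
    # 400 "OCR process failed", and downstream SER/RE failures give 400 "Invalid Image".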