kavg committed on
Commit 1be0846
1 Parent(s): 613ad82

implemented two ocr methods

Files changed (4)
  1. config.py +4 -0
  2. handwritting_detection.py +41 -0
  3. main.py +23 -2
  4. ocr.py +60 -1
config.py CHANGED
@@ -7,3 +7,7 @@ class Settings(BaseSettings):
     SER_MODEL: str
     TOKENIZER: str
     RE_MODEL: str
+    ROBOFLOW_API_KEY: str
+    ROBOFLOW_URL: str
+    YOLO_MODEL_ID: str
+    TROCR_API_URL: str
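Since these are required fields on a pydantic `BaseSettings` class, the app will not start unless they are supplied, typically via environment variables or a `.env` file. A minimal sketch with placeholder values (the other required fields, such as `GCV_AUTH` used elsewhere in the app, must be set too):

```python
# Sketch: provide the new settings through the environment (all values
# below are placeholders, not real credentials or endpoints).
import os

os.environ.update({
    "SER_MODEL": "path/to/ser-model",            # pre-existing fields
    "TOKENIZER": "path/to/tokenizer",
    "RE_MODEL": "path/to/re-model",
    "GCV_AUTH": "path/to/gcv-credentials.json",
    "ROBOFLOW_API_KEY": "your-roboflow-key",     # new: detector auth
    "ROBOFLOW_URL": "https://detect.roboflow.com",
    "YOLO_MODEL_ID": "your-project/1",
    "TROCR_API_URL": "http://localhost:8001/ocr",
})

from config import Settings
settings = Settings()  # validates that every declared field is present
```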
handwritting_detection.py ADDED
@@ -0,0 +1,41 @@
+from inference_sdk import InferenceHTTPClient
+from config import Settings
+from PIL import Image, ImageDraw
+
+def draw_rectangle(image, x, y, width, height, **kwargs):
+    # Create a draw object
+    draw = ImageDraw.Draw(image)
+    # Calculate the top-left and bottom-right corners (x, y is the box centre)
+    x1 = x - width // 2
+    y1 = y - height // 2
+    x2 = x1 + width
+    y2 = y1 + height
+
+    # Mask the detected region with a white rectangle
+    draw.rectangle(((x1, y1), (x2, y2)), fill=(255, 255, 255))
+    return image
+
+def crop_image(image, x, y, width, height, **kwargs):
+    # Calculate the top-left and bottom-right corners of the cropping area
+    left = x - width // 2
+    top = y - height // 2
+    right = left + width
+    bottom = top + height
+
+    # Crop the image
+    cropped_image = image.crop((left, top, right, bottom))
+    return cropped_image, left, top, (right-left), (bottom-top)
+
+def DetectHandwritting(image):
+    settings = Settings()
+    CLIENT = InferenceHTTPClient(
+        api_url=settings.ROBOFLOW_URL,
+        api_key=settings.ROBOFLOW_API_KEY
+    )
+    result = CLIENT.infer(image, model_id=settings.YOLO_MODEL_ID)
+    cpy = image.copy()
+    handwritten_parts = []
+    for prediction in result['predictions']:
+        handwritten_parts.append(crop_image(image, **prediction))  # crop the original, not the masked copy
+        cpy = draw_rectangle(cpy, **prediction)
+    return cpy, handwritten_parts
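The helpers assume Roboflow-style detections, where `x`/`y` are the centre of the box and extra keys such as `confidence` and `class` are absorbed by `**kwargs`. A small sketch with a hand-built prediction dict (values illustrative):

```python
from PIL import Image
from handwritting_detection import draw_rectangle, crop_image

page = Image.new("RGB", (800, 600), "lightgray")  # stand-in for a scanned form

# Shape of one Roboflow prediction: centre coordinates plus box size.
prediction = {"x": 400, "y": 300, "width": 200, "height": 80,
              "confidence": 0.9, "class": "handwriting"}

crop, left, top, w, h = crop_image(page, **prediction)   # crop first...
masked = draw_rectangle(page.copy(), **prediction)       # ...then mask the page
assert (left, top, w, h) == (300, 260, 200, 80)          # centre -> top-left
```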
main.py CHANGED
@@ -11,6 +11,8 @@ import json
 import io
 from models import LiLTRobertaLikeForRelationExtraction
 from base64 import b64decode
+from handwritting_detection import DetectHandwritting
+import pandas as pd
 config = {}
 
 @asynccontextmanager
@@ -23,6 +25,7 @@ async def lifespan(app: FastAPI):
     config['tokenizer'] = AutoTokenizer.from_pretrained(settings.TOKENIZER)
     config['ser_model'] = LiltForTokenClassification.from_pretrained(settings.SER_MODEL)
     config['re_model'] = LiLTRobertaLikeForRelationExtraction.from_pretrained(settings.RE_MODEL)
+    config['TROCR_API'] = settings.TROCR_API_URL
     yield
     # Clean up and release the resources
     config.clear()
@@ -69,13 +72,31 @@ def ApplyOCR(content):
         image = Image.open(io.BytesIO(content))
     except:
         raise HTTPException(status_code=400, detail="Invalid image")
+
     try:
+        printed_img, handwritten_imgs = DetectHandwritting(image)
+    except:
+        raise HTTPException(status_code=400, detail="Handwriting detection failed")
+
+    try:
+        trocr_client = ocr.TrOCRClient(config['settings'].TROCR_API_URL)
+        handwritten_ocr_df = trocr_client.ocr(handwritten_imgs, image)
+    except:
+        raise HTTPException(status_code=400, detail="Handwritten OCR process failed")
+
+    try:
+        jpeg_bytes = io.BytesIO()
+        printed_img.save(jpeg_bytes, format='JPEG')
+        jpeg_content = jpeg_bytes.getvalue()
         vision_client = ocr.VisionClient(config['settings'].GCV_AUTH)
-        ocr_df = vision_client.ocr(content, image)
+        printed_ocr_df = vision_client.ocr(jpeg_content, printed_img)
     except:
-        raise HTTPException(status_code=400, detail="OCR process failed")
+        raise HTTPException(status_code=400, detail="Printed OCR process failed")
+
+    ocr_df = pd.concat([handwritten_ocr_df, printed_ocr_df])
     return ocr_df, image
 
+
 def LabelTokens(ocr_df, image):
     input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = config['processor'].process(ocr_df, image = image)
     token_labels = token_classification.classifyTokens(config['ser_model'], input_ids, attention_mask, bbox, offset_mapping)
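`ApplyOCR` now runs two passes, TrOCR over the handwritten crops and Google Cloud Vision over the masked page, and merges the two DataFrames with `pd.concat`. One thing to note: `pd.concat` keeps each frame's original row index, so the combined frame has duplicate indices. A minimal sketch of the merge step (frame contents are illustrative, not real OCR output):

```python
import pandas as pd

# Stand-ins for the two OCR outputs; the real frames also carry the
# *_scaled box columns produced by convert_to_df.
handwritten_ocr_df = pd.DataFrame([{"text": "paid", "left": 40, "top": 120}])
printed_ocr_df = pd.DataFrame([{"text": "Total:", "left": 10, "top": 120}])

# reset_index gives downstream positional consumers a clean 0..n-1 index.
ocr_df = pd.concat([handwritten_ocr_df, printed_ocr_df]).reset_index(drop=True)
```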
ocr.py CHANGED
@@ -6,6 +6,7 @@ import json
 import numpy as np
 from PIL import Image
 import io
+import requests
 
 image_ext = ("*.jpg", "*.jpeg", "*.png")
 
@@ -86,4 +87,62 @@ class VisionClient:
         resp_js = self.get_response(content)
         boxObjects = self.post_process(resp_js)
         ocr_df = self.convert_to_df(boxObjects, image)
-        return ocr_df
+        return ocr_df
+
+
+class TrOCRClient():
+    def __init__(self, api_url):
+        self.api_url = api_url
+
+    def convert_to_df(self, boxObjects, image):
+        ocr_df = pd.DataFrame(boxObjects)
+
+        # ocr_df = ocr_df.sort_values(by=['top', 'left'], ascending=True).reset_index(drop=True)
+        width, height = image.size
+        w_scale = 1000/width
+        h_scale = 1000/height
+
+        ocr_df = ocr_df.dropna() \
+                .assign(left_scaled = ocr_df.left*w_scale,
+                        width_scaled = ocr_df.width*w_scale,
+                        top_scaled = ocr_df.top*h_scale,
+                        height_scaled = ocr_df.height*h_scale,
+                        right_scaled = lambda x: x.left_scaled + x.width_scaled,
+                        bottom_scaled = lambda x: x.top_scaled + x.height_scaled)
+
+        float_cols = ocr_df.select_dtypes('float').columns
+        ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
+        ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
+        ocr_df = ocr_df.dropna().reset_index(drop=True)
+        return ocr_df
+
+    def send_request(self, handwritten_img):
+        jpeg_bytes = io.BytesIO()
+        handwritten_img.save(jpeg_bytes, format='JPEG')
+        jpeg_content = jpeg_bytes.getvalue()
+        # Send a POST request with the image file
+        response = requests.post(self.api_url, files={"file": jpeg_content})
+        # Check the response status code
+        if response.status_code == 200:
+            # Return the extracted text from the response
+            return response.json()["text"]
+        else:
+            # Surface the failure so the caller's except block can handle it
+            raise RuntimeError(f"TrOCR request failed: {response.text}")
+
+    def ocr(self, handwritten_imgs, image):
+        boxObjects = []
+        for i in range(len(handwritten_imgs)):
+            handwritten_img = handwritten_imgs[i]  # (crop, left, top, width, height)
+            ocr_result = self.send_request(handwritten_img[0])
+            boxObjects.append({
+                "id": i,
+                "text": ocr_result,
+                "left": handwritten_img[1],
+                "width": handwritten_img[3],
+                "top": handwritten_img[2],
+                "height": handwritten_img[4]
+            })
+        ocr_df = self.convert_to_df(boxObjects, image)
+        return ocr_df
+
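A minimal sketch of how the new client is exercised (the endpoint URL, image path, and crop tuple are placeholders; in the app the crops come from `DetectHandwritting` and the URL from `Settings().TROCR_API_URL`):

```python
from PIL import Image
import ocr

# Hypothetical TrOCR endpoint for local testing.
trocr_client = ocr.TrOCRClient("http://localhost:8001/ocr")

page = Image.open("form.jpg")  # placeholder path to a scanned page
# One crop tuple in the (image, left, top, width, height) layout
# produced by crop_image.
crop = (page.crop((40, 120, 200, 160)), 40, 120, 160, 40)

# One POST request per crop; box positions are scaled to a
# 1000x1000 grid in the resulting DataFrame.
ocr_df = trocr_client.ocr([crop], page)
print(ocr_df[["text", "left_scaled", "top_scaled"]])
```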