saad noor committed on
Commit dab2f85
1 Parent(s): cd5e9c8

init commit
.gitignore ADDED
@@ -0,0 +1 @@
+ yoloenv/
0040da34-25c8-4a5a-a6aa-36733ea3b8eb.png ADDED
app.py ADDED
@@ -0,0 +1,262 @@
+ import gradio as gr
+ import requests
+ import torch
+ import os
+ from tqdm import tqdm
+ # import wandb
+ from ultralytics import YOLO
+ import cv2
+ import numpy as np
+ import pandas as pd
+ from skimage.transform import resize
+ from skimage import img_as_bool
+ from skimage.morphology import convex_hull_image
+ import json
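+ # (requests, os, tqdm, pandas and json are imported but unused below.)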
+
+ # wandb.init(mode='disabled')
+
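+ # Merge per-instance table masks into one boolean mask, taking the convex
+ # hull of each instance to smooth out ragged table boundaries.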
+ def tableConvexHull(img, masks):
+     mask = np.zeros(masks[0].shape, dtype="bool")
+     for msk in masks:
+         temp = msk.cpu().detach().numpy()
+         chull = convex_hull_image(temp)
+         mask = np.bitwise_or(mask, chull)
+     return mask
+
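+ # True if class id `cls` appears among the predicted class ids `clss`.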
+ def cls_exists(clss, cls):
+     indices = torch.where(clss == cls)
+     return len(indices[0]) > 0
+
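+ # All-False boolean mask matching the image's height and width.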
+ def empty_mask(img):
+     mask = np.zeros(img.shape[:2], dtype="uint8")
+     return np.array(mask, dtype=bool)
+
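+ # Predict with the dedicated image model and return its class-0 mask;
+ # status -1 signals that prediction itself failed.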
+ def extract_img_mask(img_model, img, config):
+     res_dict = {
+         'status': 1
+     }
+     res = get_predictions(img_model, img, config)
+
+     if res['status'] == -1:
+         res_dict['status'] = -1
+
+     elif res['status'] == 0:
+         res_dict['mask'] = empty_mask(img)
+
+     else:
+         masks = res['masks']
+         boxes = res['boxes']
+         clss = boxes[:, 5]
+         mask = extract_mask(img, masks, boxes, clss, 0)
+         res_dict['mask'] = mask
+     return res_dict
+
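+ # Single YOLO prediction pass. status is 1 on success (masks/boxes filled
+ # in), 0 when nothing was detected, and -1 when prediction itself failed;
+ # callers react to -1 by retrying with retina_masks disabled.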
+ def get_predictions(model, img2, config):
+     res_dict = {
+         'status': 1
+     }
+     try:
+         for result in model.predict(source=img2, verbose=False, retina_masks=config['rm'],
+                                     imgsz=config['sz'], conf=config['conf'], stream=True,
+                                     classes=config['classes']):
+             try:
+                 res_dict['masks'] = result.masks.data
+                 res_dict['boxes'] = result.boxes.data
+                 del result
+                 return res_dict
+             except Exception:
+                 # result.masks is None when nothing was detected
+                 res_dict['status'] = 0
+                 return res_dict
+     except Exception:
+         # prediction itself failed (e.g. out of memory at this setting)
+         res_dict['status'] = -1
+         return res_dict
+
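+ # Union of all instance masks for class id `cls`, as a boolean numpy array
+ # (empty mask if the class was not detected).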
+ def extract_mask(img, masks, boxes, clss, cls):
+     if not cls_exists(clss, cls):
+         return empty_mask(img)
+     indices = torch.where(clss == cls)
+     c_masks = masks[indices]
+     mask_arr = torch.any(c_masks, dim=0).bool()
+     return mask_arr.cpu().detach().numpy()
+
+
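+ # Build the four layout masks in a fixed order. Judging by the config names,
+ # classes 0/1 are paragraph/text and 2/3 are image/table; image regions are
+ # refined with the dedicated image model, table masks with convex hulls.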
+ def get_masks(img, model, img_model, flags, configs):
+     response = {
+         'status': 1
+     }
+     ans_masks = []
+     img2 = img
+
+     # ***** Getting paragraph and text masks
+     res = get_predictions(model, img2, configs['paratext'])
+     if res['status'] == -1:
+         response['status'] = -1
+         return response
+     elif res['status'] == 0:
+         for i in range(2): ans_masks.append(empty_mask(img))
+     else:
+         masks, boxes = res['masks'], res['boxes']
+         clss = boxes[:, 5]
+         for cls in range(2):
+             mask = extract_mask(img, masks, boxes, clss, cls)
+             ans_masks.append(mask)
+
+     # ***** Getting image and table masks
+     res2 = get_predictions(model, img2, configs['imgtab'])
+     if res2['status'] == -1:
+         response['status'] = -1
+         return response
+     elif res2['status'] == 0:
+         for i in range(2): ans_masks.append(empty_mask(img))
+     else:
+         masks, boxes = res2['masks'], res2['boxes']
+         clss = boxes[:, 5]
+
+         if cls_exists(clss, 2):
+             img_res = extract_img_mask(img_model, img, configs['image'])
+             if img_res['status'] == 1:
+                 img_mask = img_res['mask']
+             else:
+                 response['status'] = -1
+                 return response
+         else:
+             img_mask = empty_mask(img)
+         ans_masks.append(img_mask)
+
+         if cls_exists(clss, 3):
+             indices = torch.where(clss == 3)
+             tbl_mask = tableConvexHull(img, masks[indices])
+         else:
+             tbl_mask = empty_mask(img)
+         ans_masks.append(tbl_mask)
+
+     if not configs['paratext']['rm']:
+         h, w, c = img.shape
+         for i in range(4):
+             ans_masks[i] = img_as_bool(resize(ans_masks[i], (h, w)))
+
+     response['masks'] = ans_masks
+     return response
+
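+ # Blend a colored mask over the page. `color` is given as RGB, while OpenCV
+ # images are BGR, hence the color[::-1] flip below.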
+ def overlay(image, mask, color, alpha, resize=None):
+     """Combines image and its segmentation mask into a single image.
+     https://www.kaggle.com/code/purplejester/showing-samples-with-segmentation-mask-overlay
+
+     Params:
+         image: Training image. np.ndarray,
+         mask: Segmentation mask. np.ndarray,
+         color: Color for segmentation mask rendering. tuple[int, int, int] = (255, 0, 0)
+         alpha: Segmentation mask's transparency. float = 0.5,
+         resize: If provided, both image and its mask are resized before blending them together.
+             tuple[int, int] = (1024, 1024)
+
+     Returns:
+         image_combined: The combined image. np.ndarray
+     """
+     color = color[::-1]
+     colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0)
+     colored_mask = np.moveaxis(colored_mask, 0, -1)
+     masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
+     image_overlay = masked.filled()
+
+     if resize is not None:
+         image = cv2.resize(image.transpose(1, 2, 0), resize)
+         image_overlay = cv2.resize(image_overlay.transpose(1, 2, 0), resize)
+
+     image_combined = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
+
+     return image_combined
+
+
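+ # Model weights shipped with this commit (via Git LFS): a general layout
+ # model and, judging by its name, an image-only refinement model.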
+ general_model_path = 'e50_aug.pt'
+ image_model_path = 'e100_img.pt'
+
+ general_model = YOLO(general_model_path)
+ image_model = YOLO(image_model_path)
+
+ sample_path = ['0040da34-25c8-4a5a-a6aa-36733ea3b8eb.png']
+
+ # flags is threaded through the pipeline but not read anywhere below.
+ flags = {
+     'hist': False,
+     'bz': False
+ }
+
+
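+ # Per-pass prediction settings: sz = inference size, conf = confidence
+ # threshold, rm = retina_masks (full-resolution masks), classes = class ids
+ # kept in that pass.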
+ configs = {}
+ configs['paratext'] = {
+     'sz': 640,
+     'conf': 0.25,
+     'rm': True,
+     'classes': [0, 1]
+ }
+ configs['imgtab'] = {
+     'sz': 640,
+     'conf': 0.35,
+     'rm': True,
+     'classes': [2, 3]
+ }
+ configs['image'] = {
+     'sz': 640,
+     'conf': 0.35,
+     'rm': True,
+     'classes': [0]
+ }
+
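+ # Gradio entry point: read the page image, build the four masks, and blend
+ # each one over the page in its own color.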
+ def evaluate(img_path, model=general_model, img_model=image_model,
+              configs=configs, flags=flags):
+     print('starting')
+     img = cv2.imread(img_path)
+     res = get_masks(img, model, img_model, flags, configs)
+     if res['status'] == -1:
+         # Prediction failed: retry once with retina_masks disabled,
+         # which is cheaper.
+         for idx in configs.keys():
+             configs[idx]['rm'] = False
+         return evaluate(img_path, model, img_model, configs=configs, flags=flags)
+     else:
+         masks = res['masks']
+
+     color_map = {
+         0: (255, 0, 0),
+         1: (0, 255, 0),
+         2: (0, 0, 255),
+         3: (255, 255, 0),
+     }
+     for i, mask in enumerate(masks):
+         img = overlay(image=img, mask=mask, color=color_map[i], alpha=0.4)
+     print('finishing')
+     return img
+
+ # output = evaluate(img_path=sample_path, model=general_model, img_model=image_model,
+ #                   configs=configs, flags=flags)
+
+
+ # inputs_img / outputs_img are defined but never used by the interface below.
+ inputs_img = [
+     gr.components.Video(type="filepath", label="Input Video"),
+ ]
+ outputs_img = [
+     gr.components.Image(type="numpy", label="Output Image"),
+ ]
+
+ inputs_image = [
+     gr.components.Image(type="filepath", label="Input Image"),
+ ]
+ outputs_image = [
+     gr.components.Image(type="numpy", label="Output Image"),
+ ]
+ interface_image = gr.Interface(
+     fn=evaluate,
+     inputs=inputs_image,
+     outputs=outputs_image,
+     title="Document Layout Segmentor",
+     examples=sample_path,
+     cache_examples=True,
+ )
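+ # NOTE: this file never calls interface_image.launch(); a local run would
+ # typically need that call at the end (assumption: the hosting environment
+ # launches the interface itself).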
e100_img.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7424265a528fd1a2f741bb48a3586e69496de55f14e4a4c5ba867e83c2d159f8
+ size 54786656
e50_aug.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:12dba7a7156750342fb35ef2305a0bffa31615258aced63811e9220990f1f0a3
+ size 54792992
epoch50hgeq2.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:40c00f2b620f539f9054bd17f4fbda064782aa64c089f1c366a607189a112acf
+ size 218670661
raytuneYolo50epoch.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:971d22657b3a263a44150bbcb9a2a0726e15c3460a0f6a4810ae949c623bc5fa
+ size 54793056
requirements.txt ADDED
Binary file (2.46 kB).