Spaces:
Runtime error
Runtime error
import gradio as gr | |
import requests | |
import torch | |
import os | |
from tqdm import tqdm | |
# import wandb | |
from ultralytics import YOLO | |
import cv2 | |
import numpy as np | |
import pandas as pd | |
from skimage.transform import resize | |
from skimage import img_as_bool | |
from skimage.morphology import convex_hull_image | |
import json | |
# wandb.init(mode='disabled') | |
def tableConvexHull(img, masks):
    """Return the union of the convex hulls of every mask in ``masks``.

    Each instance mask is replaced by its convex hull before being OR-ed
    into the result, so irregular table segmentations become solid regions.

    Args:
        img: Unused; kept so existing call sites keep working.
        masks: Non-empty indexable collection of torch mask tensors that all
            share one shape (``.cpu().detach()`` is called on each).

    Returns:
        Boolean numpy array with the shape of ``masks[0]``.
    """
    combined = np.zeros(masks[0].shape, dtype="bool")
    for tensor_mask in masks:
        hull = convex_hull_image(tensor_mask.cpu().detach().numpy())
        combined = np.bitwise_or(combined, hull)
    return combined
def cls_exists(clss, cls):
    """Return True when class id ``cls`` occurs at least once in ``clss``.

    Args:
        clss: torch tensor of class ids.
        cls: Class id to look for.
    """
    return bool(torch.any(clss == cls))
def empty_mask(img):
    """Return an all-False boolean mask matching the height/width of ``img``."""
    return np.zeros(img.shape[:2], dtype=bool)
def extract_img_mask(img_model, img, config):
    """Run the dedicated image model and merge its class-0 instance masks.

    Args:
        img_model: Model passed to ``get_predictions``.
        img: Input image array; its height/width size the empty mask.
        config: Prediction config forwarded to ``get_predictions``.

    Returns:
        dict with 'status':
          -1 -> prediction failed; no 'mask' key.
           1 -> 'mask' holds a boolean mask (all False when nothing was
                detected).
    """
    prediction = get_predictions(img_model, img, config)
    if prediction['status'] == -1:
        return {'status': -1}
    if prediction['status'] == 0:
        return {'status': 1, 'mask': empty_mask(img)}
    # Column 5 of the box tensor is used as the class id throughout this file.
    class_ids = prediction['boxes'][:, 5]
    merged = extract_mask(img, prediction['masks'], prediction['boxes'], class_ids, 0)
    return {'status': 1, 'mask': merged}
def get_predictions(model, img2, config):
    """Run ``model`` on one image and return the first result's masks and boxes.

    Args:
        model: Object exposing ``predict(...)`` (the YOLO models in this file).
        img2: Image source passed straight to ``model.predict``.
        config: dict with keys 'rm' (retina_masks), 'sz' (imgsz), 'conf'
            and 'classes', forwarded to ``model.predict``.

    Returns:
        dict with 'status':
           1 -> success; 'masks' and 'boxes' hold ``result.masks.data`` and
                ``result.boxes.data`` of the first result.
           0 -> prediction ran but yielded no result or no masks.
          -1 -> prediction itself raised (including a malformed config).
    """
    res_dict = {'status': 1}
    try:
        results = model.predict(source=img2, verbose=False,
                                retina_masks=config['rm'],
                                imgsz=config['sz'], conf=config['conf'],
                                stream=True, classes=config['classes'])
        # stream=True yields results lazily; only the first one is needed.
        result = next(iter(results), None)
    except Exception:
        # Narrower than a bare except: Ctrl-C / SystemExit still propagate.
        res_dict['status'] = -1
        return res_dict
    if result is None:
        # An empty stream means nothing to report: signal "no detections"
        # rather than success without 'masks'/'boxes' keys.
        res_dict['status'] = 0
        return res_dict
    try:
        # result.masks is None when nothing was detected -> AttributeError.
        masks, boxes = result.masks.data, result.boxes.data
    except Exception:
        res_dict['status'] = 0
        return res_dict
    res_dict['masks'] = masks
    res_dict['boxes'] = boxes
    return res_dict
def extract_mask(img, masks, boxes, clss, cls):
    """Merge every instance mask whose class id equals ``cls`` into one mask.

    Args:
        img: Input image; only its height/width are used, for the empty mask.
        masks: torch tensor of instance masks, indexed along its first axis.
        boxes: Unused; kept so existing call sites keep working.
        clss: torch tensor of per-instance class ids aligned with ``masks``.
        cls: Class id to select.

    Returns:
        Boolean numpy array: the pixel-wise OR of the selected masks, or an
        all-False mask when ``cls`` does not occur in ``clss``.
    """
    if cls_exists(clss, cls):
        selected = masks[torch.where(clss == cls)]
        merged = torch.any(selected, dim=0).bool()
        return merged.cpu().detach().numpy()
    return empty_mask(img)
def get_masks(img, model, img_model, flags, configs):
    """Predict the four layout masks for ``img``.

    ``model`` is run twice: with the 'paratext' config (classes 0 and 1) and
    with the 'imgtab' config (classes 2 and 3). When class 2 (image) is
    present, its mask comes from ``img_model`` run with the 'image' config.
    Class 3 (table) instances are each filled to their convex hull.

    Args:
        img: Input image array (a cv2.imread result in this file).
        model: General layout segmentation model.
        img_model: Model that produces the image-class mask.
        flags: Unused here; kept so existing call sites keep working.
        configs: dict with 'paratext', 'imgtab' and 'image' prediction
            configs (see ``get_predictions``).

    Returns:
        dict with 'status' (-1 if any prediction failed, otherwise 1) and,
        on success, 'masks': four boolean arrays of the image's height/width,
        ordered class 0, class 1, image (2), table (3).
    """
    response = {'status': 1}
    ans_masks = []

    # ***** Getting paragraph and text masks (classes 0 and 1)
    res = get_predictions(model, img, configs['paratext'])
    if res['status'] == -1:
        response['status'] = -1
        return response
    if res['status'] == 0:
        ans_masks += [empty_mask(img), empty_mask(img)]
    else:
        masks, boxes = res['masks'], res['boxes']
        clss = boxes[:, 5]
        ans_masks += [extract_mask(img, masks, boxes, clss, cls) for cls in (0, 1)]

    # ***** Getting image and table masks (classes 2 and 3)
    res2 = get_predictions(model, img, configs['imgtab'])
    if res2['status'] == -1:
        response['status'] = -1
        return response
    if res2['status'] == 0:
        ans_masks += [empty_mask(img), empty_mask(img)]
    else:
        masks, boxes = res2['masks'], res2['boxes']
        clss = boxes[:, 5]
        if cls_exists(clss, 2):
            # Image regions are re-segmented by the dedicated image model.
            img_res = extract_img_mask(img_model, img, configs['image'])
            if img_res['status'] != 1:
                response['status'] = -1
                return response
            img_mask = img_res['mask']
        else:
            img_mask = empty_mask(img)
        ans_masks.append(img_mask)
        if cls_exists(clss, 3):
            tbl_mask = tableConvexHull(img, masks[torch.where(clss == 3)])
        else:
            tbl_mask = empty_mask(img)
        ans_masks.append(tbl_mask)

    # Masks predicted without retina masks are not at image resolution.
    # Check each mask individually so that mixed 'rm' settings across the
    # configs still yield masks that all match the image.
    h, w = img.shape[:2]
    response['masks'] = [
        mask if mask.shape == (h, w) else img_as_bool(resize(mask, (h, w)))
        for mask in ans_masks
    ]
    return response
def overlay(image, mask, color, alpha, resize=None):
    """Combine an image and its segmentation mask into a single blended image.

    Adapted from
    https://www.kaggle.com/code/purplejester/showing-samples-with-segmentation-mask-overlay

    Params:
        image: Channels-last image, np.ndarray of shape (H, W, 3). In this
            file it is a cv2.imread result (BGR channel order).
        mask: Segmentation mask, np.ndarray of shape (H, W); truthy pixels
            are painted with ``color``.
        color: tuple[int, int, int] in RGB order. It is reversed before
            painting so that it matches an OpenCV BGR image.
        alpha: Weight of the painted copy in the blend; ``image`` gets
            weight ``1 - alpha``.
        resize: Optional ``(width, height)`` passed to ``cv2.resize``. If
            given, both the image and its painted copy are resized before
            blending. (Inside this function the parameter shadows the
            module-level ``resize`` import.)

    Returns:
        image_combined: The blended image, np.ndarray.
    """
    color = color[::-1]
    # Repeat the (H, W) mask to (H, W, 3) so that every channel of a masked
    # pixel is replaced by the matching channel of ``color``.
    colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0)
    colored_mask = np.moveaxis(colored_mask, 0, -1)
    masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
    image_overlay = masked.filled()
    if resize is not None:
        # Both arrays are already channels-last, which is what cv2.resize
        # expects. They must not be transposed first.
        image = cv2.resize(image, resize)
        image_overlay = cv2.resize(image_overlay, resize)
    image_combined = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
    return image_combined
# YOLO segmentation weights, loaded once at import time.
model_path = 'models'
general_model_name = 'e50_aug.pt'
image_model_name = 'e100_img.pt'
general_model, image_model = (
    YOLO(os.path.join(model_path, weights))
    for weights in (general_model_name, image_model_name)
)

# Example inputs offered (and cached) by the Gradio interface.
image_path = 'examples'
sample_name = [
    '0040da34-25c8-4a5a-a6aa-36733ea3b8eb.png',
    '0050a8ee-382b-447e-9c5b-8506d9507bef.png',
    '0064d3e2-3ba2-4332-a28f-3a165f2b84b1.png',
]
sample_path = [os.path.join(image_path, name) for name in sample_name]

flags = {'hist': False, 'bz': False}

# Per-stage prediction settings consumed by get_predictions:
#   sz -> imgsz, conf -> confidence threshold, rm -> retina_masks,
#   classes -> class filter passed to model.predict.
configs = {
    'paratext': {'sz': 640, 'conf': 0.25, 'rm': True, 'classes': [0, 1]},
    'imgtab': {'sz': 640, 'conf': 0.35, 'rm': True, 'classes': [2, 3]},
    'image': {'sz': 640, 'conf': 0.35, 'rm': True, 'classes': [0]},
}
def evaluate(img_path, model=general_model, img_model=image_model,
             configs=configs, flags=flags):
    """Segment the document image at ``img_path`` and return a colour overlay.

    The first attempt uses ``configs`` as given. If prediction fails, one
    retry is made with retina masks disabled for every stage. The retry uses
    a copy of the configs, so the shared default ``configs`` is never
    mutated.

    Args:
        img_path: Path to the input image (the Gradio input uses type="filepath").
        model: General layout segmentation model.
        img_model: Model that produces the image-class mask.
        configs: Per-stage prediction configs (see ``get_masks``).
        flags: Forwarded to ``get_masks``.

    Returns:
        RGB image array with the four class masks blended in, using red,
        green, blue and yellow for classes 0-3 respectively.

    Raises:
        ValueError: If no image is given or it cannot be read.
        RuntimeError: If prediction fails even with retina masks disabled.
    """
    img = None if img_path is None else cv2.imread(img_path)
    if img is None:
        raise ValueError(f"Could not read image: {img_path}")

    res = get_masks(img, model, img_model, flags, configs)
    if res['status'] == -1:
        fallback_configs = {name: dict(cfg, rm=False) for name, cfg in configs.items()}
        res = get_masks(img, model, img_model, flags, fallback_configs)
        if res['status'] == -1:
            raise RuntimeError(f"Segmentation failed for image: {img_path}")

    # Colours are given in RGB; overlay() reverses them to match the BGR image.
    color_map = {
        0: (255, 0, 0),
        1: (0, 255, 0),
        2: (0, 0, 255),
        3: (255, 255, 0),
    }
    for i, mask in enumerate(res['masks']):
        img = overlay(image=img, mask=mask, color=color_map[i], alpha=0.4)
    # cv2 works in BGR, but Gradio renders numpy outputs as RGB.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Gradio UI: a file-path image input in, the annotated numpy (RGB) image out.
inputs_image = [
    gr.components.Image(type="filepath", label="Input Image"),
]
outputs_image = [
    gr.components.Image(type="numpy", label="Output Image"),
]
# Keep the Interface object itself (not launch()'s return value) so it can be
# reused, e.g. for tests or mounting.
interface_image = gr.Interface(
    fn=evaluate,
    inputs=inputs_image,
    outputs=outputs_image,
    title="Document Layout Segmentor",
    examples=sample_path,
    cache_examples=True,
)
interface_image.launch()