import os
os.system('pip install "detectron2@git+https://github.com/facebookresearch/[email protected]#egg=detectron2"')

import io
import pandas as pd
import numpy as np
import gradio as gr
## for plotting
import matplotlib.pyplot as plt
## for ocr
import pdf2image
import cv2
import layoutparser as lp
from docx import Document
from docx.shared import Inches
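
## Note: besides the pip packages above, this pipeline assumes two system
## binaries are available on the host (on a Hugging Face Space these would
## typically go in a packages.txt, not shown here): poppler-utils, which
## pdf2image needs to rasterize PDFs, and tesseract-ocr, which layoutparser's
## TesseractAgent wraps for text extraction.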
def parse_doc(dic):
    for k,v in dic.items():
        if "Title" in k:
            print('\x1b[1;31m'+ v +'\x1b[0m')
        elif "Figure" in k:
            plt.figure(figsize=(10,5))
            plt.imshow(v)
            plt.show()
        else:
            print(v)
        print(" ")
def to_image(filename):
    doc = pdf2image.convert_from_path(filename, dpi=350, last_page=1)
    # Save imgs
    folder = "doc"
    if folder not in os.listdir():
        os.makedirs(folder)
    p = 1
    for page in doc:
        image_name = "page_"+str(p)+".jpg"
        page.save(os.path.join(folder, image_name), "JPEG")
        p = p+1
    return doc
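
## Example (a sketch, not executed here; "sample.pdf" is a placeholder file):
##   doc = to_image("sample.pdf")
##   doc[0].size          # (width, height) of page 1 in pixels at 350 dpi
## to_image also writes the rendered page to doc/page_1.jpg as a side effect.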
def detect(doc):
    # General
    model = lp.Detectron2LayoutModel("lp://PubLayNet/mask_rcnn_X_101_32x8d_FPN_3x/config",
                                     extra_config=["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.8],
                                     label_map={0:"Text", 1:"Title", 2:"List", 3:"Table", 4:"Figure"})
    ## turn img into array
    img = np.asarray(doc[0])
    ## predict
    detected = model.detect(img)
    return img, detected
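
## Sketch of inspecting the detection result (illustrative, assuming the
## PubLayNet weights download successfully):
##   img, detected = detect(to_image("sample.pdf"))
##   for block in detected:
##       print(block.id, block.type, block.coordinates, block.score)
## Each block is a layoutparser TextBlock labelled with one of the PubLayNet
## classes mapped above (Text, Title, List, Table, Figure).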
# sort detected
def split_page(img, detected, n, axis):
    new_detected, start = [], 0
    for s in range(n):
        ## each section spans 1/n of the page width (axis="x") or height (axis="y")
        end = len(img[0])/n * (s+1) if axis == "x" else len(img)/n * (s+1)
        section = lp.Interval(start=start, end=end, axis=axis).put_on_canvas(img)
        filter_detected = detected.filter_by(section, center=True)._blocks
        new_detected = new_detected + filter_detected
        start = end
    return lp.Layout([block.set(id=idx) for idx,block in enumerate(new_detected)])
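
## Illustrative use: for a two-column page the blocks could be re-ordered
## column by column with
##   detected = split_page(img, detected, n=2, axis="x")
## which keeps, strip by strip, the blocks whose centers fall inside each
## vertical section and re-assigns their ids in that order.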
def get_detected(img, detected):
    n_cols,n_rows = 1,1
    ## if single page just sort based on y
    if (n_cols == 1) and (n_rows == 1):
        new_detected = detected.sort(key=lambda x: x.coordinates[1])
        detected = lp.Layout([block.set(id=idx) for idx,block in enumerate(new_detected)])
    ## if multi columns sort by x,y
    elif (n_cols > 1) and (n_rows == 1):
        detected = split_page(img, detected, n_cols, axis="x")
    ## if multi rows sort by y,x
    elif (n_cols == 1) and (n_rows > 1):
        detected = split_page(img, detected, n_rows, axis="y")
    ## if multi columns-rows
    else:
        pass
    return detected
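
## Note: n_cols and n_rows are hard-coded to 1, so this demo always orders the
## blocks top-to-bottom by their y coordinate; the split_page branches only
## apply if a multi-column or multi-row layout were configured here.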
def predict_elements(img, detected)->dict:
    model = lp.TesseractAgent(languages='eng')
    dic_predicted = {}
    for block in [block for block in detected if block.type in ["Title","Text","List"]]:
        ## segmentation
        segmented = block.pad(left=15, right=15, top=5, bottom=5).crop_image(img)
        ## extraction
        extracted = model.detect(segmented)
        ## save
        dic_predicted[str(block.id)+"-"+block.type] = extracted.replace('\n',' ').strip()
    for block in [block for block in detected if block.type == "Figure"]:
        ## segmentation
        segmented = block.pad(left=15, right=15, top=5, bottom=5).crop_image(img)
        ## save
        dic_predicted[str(block.id)+"-"+block.type] = segmented
    for block in [block for block in detected if block.type == "Table"]:
        ## segmentation
        segmented = block.pad(left=15, right=15, top=5, bottom=5).crop_image(img)
        ## extraction
        extracted = model.detect(segmented)
        ## save
        dic_predicted[str(block.id)+"-"+block.type] = pd.read_csv(io.StringIO(extracted))
    return dic_predicted
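
## The returned dictionary maps "<block id>-<block type>" keys to the extracted
## content, e.g. (values are illustrative):
##   {"0-Title": "Some heading", "1-Text": "Body text ...",
##    "2-Figure": <numpy image crop>, "3-Table": <pandas DataFrame>}
## Note that tables are rebuilt by parsing the raw Tesseract output as CSV,
## which assumes the OCR text can be read as comma-separated values.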
def gen_doc(dic_predicted:dict):
    document = Document()
    for k,v in dic_predicted.items():
        if "Figure" in k:
            cv2.imwrite(f'{k}.jpg', dic_predicted[k])
            document.add_picture(f'{k}.jpg', width=Inches(3))
        elif "Table" in k:
            ## header row
            table = document.add_table(rows=1, cols=v.shape[1])
            hdr_cells = table.rows[0].cells
            for idx, col in enumerate(v.columns):
                hdr_cells[idx].text = str(col)
            ## one table row per DataFrame row
            for _, row in v.iterrows():
                row_cells = table.add_row().cells
                for idx, col in enumerate(v.columns):
                    row_cells[idx].text = "" if pd.isna(row[col]) else str(row[col]).strip()
        else:
            document.add_paragraph(str(v))
    document.save('demo.docx')
def main_convert(filename):
    print(filename.name)
    doc = to_image(filename.name)
    img, detected = detect(doc)
    n_detected = get_detected(img, detected)
    dic_predicted = predict_elements(img, n_detected)
    gen_doc(dic_predicted)
    im_out = lp.draw_box(img, detected, box_width=5, box_alpha=0.2, show_element_type=True)
    dict_out = {}
    for k,v in dic_predicted.items():
        if "figure" not in k.lower():
            dict_out[k] = dic_predicted[k]
    return 'demo.docx', im_out, dict_out
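
## Sketch of running the pipeline outside Gradio (assuming a local file
## "sample.pdf"; in the app, Gradio passes a temp-file object, hence
## filename.name above):
##   doc = to_image("sample.pdf")
##   img, detected = detect(doc)
##   dic_predicted = predict_elements(img, get_detected(img, detected))
##   gen_doc(dic_predicted)       # writes demo.docx next to the script
##   parse_doc(dic_predicted)     # optional: print/plot the parsed elements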
inputs = [gr.File(type='file', label="Original PDF File")]
outputs = [gr.File(label="Converted DOC File"), gr.Image(type="pil", label="Detected Image"), gr.JSON()]
title = "A Document AI parser"
description = "This demo uses AI models to detect text, titles, tables, figures and lists, as well as table cells, from a scanned document.\nBased on the layout it determines the reading order and generates an MS Word (.docx) file to download."
## the Interface is not named "io" to avoid shadowing the io module used in predict_elements
iface = gr.Interface(fn=main_convert, inputs=inputs, outputs=outputs, title=title, description=description,
                     css=""".gr-button-primary {
                         background: #355764;
                         background: linear-gradient(90deg, #355764 0%, #55a8a1 100%) !important;
                         background: -moz-linear-gradient(90deg, #355764 0%, #55a8a1 100%) !important;
                         background: -webkit-linear-gradient(90deg, #355764 0%, #55a8a1 100%) !important;
                         color: white !important}""")
iface.launch()