tianchiguaixia commited on
Commit
03fb5e6
·
1 Parent(s): a2f23b0

Upload 2 files

Browse files
Files changed (2) hide show
  1. AI-医学图片OCR.py +65 -0
  2. ocr_utils.py +281 -0
AI-医学图片OCR.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # time: 2022/10/17 11:22
3
+ # file: AI-医学图片OCR.py
4
+
5
+
6
+ import streamlit as st
7
+
8
+ from ocr.ocr import detect, recognize
9
+ from ocr.utils import bytes_to_numpy
10
+ import pandas as pd
11
+
12
+ import os
13
+ import cv2
14
+ from paddleocr import PPStructure, draw_structure_result, save_structure_res
15
+
16
+ st.title("AI-医学图片OCR")
17
+
18
+
19
def convert_df(df):
    """Serialize a DataFrame to CSV bytes for the download button.

    The CSV text is encoded as GBK first, which is what this function has
    always produced. If the table contains characters that GBK cannot
    represent, the text is encoded as UTF-8 with a BOM instead of raising
    ``UnicodeEncodeError``.

    Args:
        df: the pandas DataFrame to export.

    Returns:
        bytes: the encoded CSV document (index column included).
    """
    # NOTE: nothing is cached here; the conversion runs on every call.
    csv_text = df.to_csv()
    try:
        return csv_text.encode("gbk")
    except UnicodeEncodeError:
        # Lossless fallback for characters outside the GBK repertoire.
        return csv_text.encode("utf-8-sig")
22
+
23
+
24
# Upload an image from the sidebar.
uploaded_file = st.sidebar.file_uploader(
    '请选择一张图片', type=['png', 'jpg', 'jpeg'])
print('uploaded_file:', uploaded_file)
# NOTE(review): the structure engine is constructed on every script run;
# consider caching it if start-up cost matters.
table_engine = PPStructure(show_log=True)
if uploaded_file is not None:
    # Read the upload as bytes and convert it to an image array
    # (channels='RGB' is requested from bytes_to_numpy).
    bytes_data = uploaded_file.getvalue()
    img = bytes_to_numpy(bytes_data, channels='RGB')

    option_task = st.sidebar.radio('请选择要执行的任务', ('查看原图', '文本检测'))
    if option_task == '查看原图':
        st.image(img, caption='原图')
    elif option_task == '文本检测':
        im_show = detect(img)
        st.image(im_show, caption='文本检测后的图片')

    base_path = "streamlit_data"
    # Results are stored in a folder named after the part of the uploaded
    # file name that precedes the first dot; compute it once and reuse it.
    image_name = uploaded_file.name.split('.')[0]
    local_path = base_path + "/" + image_name

    if st.button('✨ 启动!'):
        result = table_engine(img)
        save_structure_res(result, base_path, image_name)
        with st.container():
            with st.expander(label="json结果展示", expanded=False):
                st.write(result)
            # Show every exported spreadsheet found in the result folder.
            for i in sorted(os.listdir(local_path)):
                if i.endswith(".xlsx"):
                    df = pd.read_excel(os.path.join(local_path, i))
                    df = df.fillna("")
                    st.write(df)
                    csv = convert_df(df)
                    st.download_button(
                        label="Download data as csv",
                        data=csv,
                        file_name='large_df.csv',
                        mime='text/csv',
                        # An explicit key keeps the buttons distinct when
                        # several tables are found for one image.
                        key="download_" + i,
                    )
ocr_utils.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # time: 2022/10/17 13:25
3
+ # file: ocr_utils.py
4
+
5
+ import cv2
6
+ import math
7
+ import numpy as np
8
+
9
+ from PIL import Image, ImageDraw, ImageFont
10
+
11
+
12
def resize_img(img, input_size=600):
    """
    Scale an image uniformly so that its longest side becomes input_size.

    args:
        img(Image|array): the image; it is converted with np.array first
        input_size(int): target length of the longest of the first two axes
    return(array):
        the resized image produced by cv2.resize
    """
    pixels = np.array(img)
    longest_side = np.max(pixels.shape[0:2])
    scale = float(input_size) / float(longest_side)
    # The same factor on both axes keeps the aspect ratio.
    return cv2.resize(pixels, None, None, fx=scale, fy=scale)
22
+
23
+
24
def draw_ocr(
    image,
    boxes,
    txts=None,
    scores=None,
    drop_score=0.5,
    font_path="./fonts/font.ttf"
):
    """
    Visualize the results of OCR detection and recognition
    args:
        image(Image|array): RGB image
        boxes(list): boxes with shape(N, 4, 2)
        txts(list): the texts
        scores(list): the scores corresponding to txts
        drop_score(float): only scores greater than drop_threshold will be visualized
        font_path: the path of font which is used to draw text
    return(array):
        the visualized img
    """
    if scores is None:
        # Without scores every box gets a score of 1.
        scores = [1] * len(boxes)

    for position in range(len(boxes)):
        confidence = scores[position]
        if confidence < drop_score or math.isnan(confidence):
            continue
        outline = np.array(boxes[position])
        outline = np.reshape(outline, [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(np.array(image), [outline], True, (255, 0, 0), 2)

    if txts is None:
        return image

    # Append a text panel to the right of the image resized to 600 px.
    resized = np.array(resize_img(image, input_size=600))
    text_panel = text_visual(
        txts,
        scores,
        img_h=resized.shape[0],
        img_w=600,
        threshold=drop_score,
        font_path=font_path
    )
    return np.concatenate([np.array(resized), np.array(text_panel)], axis=1)
65
+
66
+
67
def draw_ocr_box_txt(
    image,
    boxes,
    txts,
    scores=None,
    drop_score=0.5,
    font_path="./fonts/font.ttf"
):
    """
    Render OCR results side by side: the input image with colour-filled
    boxes on the left, and the texts drawn at the box positions on a white
    canvas on the right.

    args:
        image(array): image array accepted by Image.fromarray
        boxes(list): four-point boxes, read as box[k][0], box[k][1] for k in 0..3
        txts(list): the text of each box
        scores(list|None): score of each box; None keeps every box
        drop_score(float): boxes whose score is below this value are skipped
        font_path: the path of font which is used to draw text
    return(array):
        an image twice as wide as the input (left and right halves)
    """
    import random

    def glyph_bottom(font, char):
        # ImageFont.getsize() was removed in Pillow 10. getbbox()[3] is the
        # same vertical extent that getsize()[1] used to report.
        if hasattr(font, "getbbox"):
            return font.getbbox(char)[3]
        return font.getsize(char)[1]

    image = Image.fromarray(image)
    h, w = image.height, image.width
    img_left = image.copy()
    img_right = Image.new('RGB', (w, h), (255, 255, 255))

    # A private generator seeded with 0 yields the same colour sequence as
    # random.seed(0) did, without resetting the process-wide random state.
    rng = random.Random(0)
    draw_left = ImageDraw.Draw(img_left)
    draw_right = ImageDraw.Draw(img_right)
    for idx, (box, txt) in enumerate(zip(boxes, txts)):
        if scores is not None and scores[idx] < drop_score:
            continue
        color = (rng.randint(0, 255), rng.randint(0, 255), rng.randint(0, 255))
        corners = [
            box[0][0], box[0][1], box[1][0], box[1][1],
            box[2][0], box[2][1], box[3][0], box[3][1]
        ]
        draw_left.polygon(corners, fill=color)
        draw_right.polygon(corners, outline=color)
        box_height = math.sqrt(
            (box[0][0] - box[3][0])**2 + (box[0][1] - box[3][1])**2)
        box_width = math.sqrt(
            (box[0][0] - box[1][0])**2 + (box[0][1] - box[1][1])**2)
        if box_height > 2 * box_width:
            # Tall box: draw the text one character per row, top to bottom.
            font_size = max(int(box_width * 0.9), 10)
            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
            cur_y = box[0][1]
            for c in txt:
                draw_right.text((box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font)
                cur_y += glyph_bottom(font, c)
        else:
            font_size = max(int(box_height * 0.8), 10)
            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
            draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
    # Blend so the original pixels stay visible under the filled boxes.
    img_left = Image.blend(image, img_left, 0.5)
    img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
    img_show.paste(img_left, (0, 0, w, h))
    img_show.paste(img_right, (w, 0, w * 2, h))
    return np.array(img_show)
122
+
123
+
124
def str_count(s):
    """
    Measure a string in full-width character units.

    ASCII letters, digits and whitespace are treated as half-width: the
    number of such characters, halved and rounded up, is subtracted from
    the total length of the string.
    args:
        s(string): the input of string
    return(int):
        len(s) minus ceil(half-width character count / 2)
    """
    import string

    half_width_total = sum(
        1
        for ch in s
        if ch in string.ascii_letters or ch.isdigit() or ch.isspace()
    )
    return len(s) - math.ceil(half_width_total / 2)
146
+
147
+
148
def text_visual(
    texts,
    scores,
    img_h=400,
    img_w=600,
    threshold=0.,
    font_path="./fonts/font.ttf"
):
    """
    create new blank img and draw txt on it
    args:
        texts(list): the text will be draw
        scores(list|None): corresponding score of each txt; when None no
            text is filtered out and no score is appended to the lines
        img_h(int): the height of blank img
        img_w(int): the width of blank img
        threshold(float): texts whose score is below this value, or NaN,
            are skipped
        font_path: the path of font which is used to draw text
    return(array):
        the drawn panel; when the texts do not fit on one panel, further
        panels are concatenated along the width axis
    """
    if scores is not None:
        assert len(texts) == len(scores), "The number of txts and corresponding scores must match"

    def create_blank_img():
        # uint8 rather than int8: 255 is out of range for int8, and
        # NumPy >= 2 raises OverflowError for `int8_array * 255`.
        blank_img = np.ones(shape=[img_h, img_w], dtype=np.uint8) * 255
        # Paint the right-most pixel column black.
        blank_img[:, img_w - 1:] = 0
        blank_img = Image.fromarray(blank_img).convert("RGB")
        draw_txt = ImageDraw.Draw(blank_img)
        return blank_img, draw_txt

    blank_img, draw_txt = create_blank_img()

    font_size = 20
    txt_color = (0, 0, 0)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

    gap = font_size + 5
    # Character budget of one line. Clamped to 1 so the wrapping loop below
    # always consumes text, even for very narrow panels.
    line_limit = max(img_w // font_size - 4, 1)
    last_row = img_h // gap - 1
    txt_img_list = []
    count, index = 1, 0
    for idx, txt in enumerate(texts):
        score = None if scores is None else scores[idx]
        if score is not None and (score < threshold or math.isnan(score)):
            continue
        # Only texts that are actually drawn receive a running number.
        index += 1
        score_suffix = '' if score is None else ' ' + '%.3f' % score
        first_line = True
        # Wrap long texts: emit full lines until the remainder fits.
        while str_count(txt) >= line_limit:
            tmp = txt
            txt = tmp[:line_limit]
            if first_line:
                new_txt = str(index) + ': ' + txt
                first_line = False
            else:
                new_txt = ' ' + txt
            draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
            txt = tmp[line_limit:]
            if count >= last_row:
                # Panel is full: store it and continue on a fresh one.
                txt_img_list.append(np.array(blank_img))
                blank_img, draw_txt = create_blank_img()
                count = 0
            count += 1
        if first_line:
            new_txt = str(index) + ': ' + txt + score_suffix
        else:
            new_txt = " " + txt + score_suffix
        draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
        # whether add new blank img or not
        if count >= last_row and idx + 1 < len(texts):
            txt_img_list.append(np.array(blank_img))
            blank_img, draw_txt = create_blank_img()
            count = 0
        count += 1
    txt_img_list.append(np.array(blank_img))
    if len(txt_img_list) == 1:
        blank_img = np.array(txt_img_list[0])
    else:
        blank_img = np.concatenate(txt_img_list, axis=1)
    return np.array(blank_img)
223
+
224
+
225
def base64_to_cv2(b64str):
    """
    Decode a base64-encoded image into an array via cv2.imdecode.

    args:
        b64str(str|bytes): base64 text of the encoded image file
    return(array):
        the result of cv2.imdecode(..., cv2.IMREAD_COLOR)
    """
    import base64

    if isinstance(b64str, str):
        b64str = b64str.encode('utf8')
    raw = base64.b64decode(b64str)
    # np.fromstring is deprecated for binary input; frombuffer reads the
    # same bytes without copying.
    buffer = np.frombuffer(raw, np.uint8)
    return cv2.imdecode(buffer, cv2.IMREAD_COLOR)
231
+
232
+
233
def draw_boxes(image, boxes, scores=None, drop_score=0.5):
    """
    Outline the given boxes on an image.

    args:
        image(Image|array): the image to draw on
        boxes(list): point lists, each reshaped to (-1, 1, 2) before drawing
        scores(list|None): score of each box; None gives every box a score of 1
        drop_score(float): boxes whose score is below this value are skipped
    return:
        the image after drawing (the input itself when no box was drawn)
    """
    if scores is None:
        scores = [1] * len(boxes)

    line_color = (255, 0, 0)
    for quad, confidence in zip(boxes, scores):
        if not confidence < drop_score:
            outline = np.reshape(np.array(quad), [-1, 1, 2]).astype(np.int64)
            image = cv2.polylines(np.array(image), [outline], True, line_color, 2)
    return image
242
+
243
+
244
def get_rotate_crop_image(img, points):
    """
    Cut a four-point region out of an image with a perspective warp.

    The four points are mapped, in order, onto the corners (0, 0), (w, 0),
    (w, h) and (0, h) of the output, where w and h are the longer of the
    two opposite edge lengths. A result whose height is at least 1.5 times
    its width is rotated with np.rot90.

    args:
        img(array): the source image
        points(array): the four corner points of the region
    return(array):
        the warped (and possibly rotated) crop
    """
    assert len(points) == 4, "shape of points must be 4*2"
    p0, p1, p2, p3 = points

    top_edge = np.linalg.norm(p0 - p1)
    bottom_edge = np.linalg.norm(p2 - p3)
    crop_w = int(max(top_edge, bottom_edge))

    left_edge = np.linalg.norm(p0 - p3)
    right_edge = np.linalg.norm(p1 - p2)
    crop_h = int(max(left_edge, right_edge))

    target = np.float32([
        [0, 0],
        [crop_w, 0],
        [crop_w, crop_h],
        [0, crop_h],
    ])
    transform = cv2.getPerspectiveTransform(points, target)
    warped = cv2.warpPerspective(
        img,
        transform,
        (crop_w, crop_h),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC
    )
    warped_h, warped_w = warped.shape[0:2]
    if warped_h * 1.0 / warped_w >= 1.5:
        warped = np.rot90(warped)
    return warped
278
+
279
+
280
if __name__ == '__main__':
    # Running this module directly does nothing.
    pass