Upload 13 files
- app/FinalApp.py +223 -0
- app/__init__.py +0 -0
- app/__pycache__/dataloader_iam.cpython-311.pyc +0 -0
- app/__pycache__/model.cpython-311.pyc +0 -0
- app/__pycache__/preprocessor.cpython-311.pyc +0 -0
- app/dataloader_iam.py +133 -0
- app/model.py +334 -0
- app/preprocessor.py +191 -0
- app/runner.py +5 -0
- app/simple.py +64 -0
- app/userInput.png +0 -0
- app/webapp.py +132 -0
- app/word.png +0 -0
app/FinalApp.py
ADDED
@@ -0,0 +1,223 @@
import os
import cv2
import numpy as np
from PIL import Image
from path import Path
import streamlit as st
from typing import Tuple
from dataloader_iam import Batch
from model import Model, DecoderType
from preprocessor import Preprocessor
from streamlit_drawable_canvas import st_canvas
import easyocr  # Import EasyOCR

# Set page config at the very beginning (only executed once)
st.set_page_config(
    page_title="HTR App",
    page_icon=":pencil:",
    layout="centered",
    initial_sidebar_state="auto",
)

ms = st.session_state
if "themes" not in ms:
    ms.themes = {"current_theme": "light",
                 "refreshed": True,

                 "light": {"theme.base": "dark",
                           "theme.backgroundColor": "black",
                           "theme.primaryColor": "#c98bdb",
                           "theme.secondaryBackgroundColor": "#5591f5",
                           "theme.textColor": "white",
                           "button_face": "🌜"},

                 "dark": {"theme.base": "light",
                          "theme.backgroundColor": "white",
                          "theme.primaryColor": "#5591f5",
                          "theme.secondaryBackgroundColor": "#82E1D7",
                          "theme.textColor": "#0a1464",
                          "button_face": "🌞"},
                 }


def ChangeTheme():
    previous_theme = ms.themes["current_theme"]
    tdict = ms.themes["light"] if ms.themes["current_theme"] == "light" else ms.themes["dark"]
    for vkey, vval in tdict.items():
        if vkey.startswith("theme"): st._config.set_option(vkey, vval)

    ms.themes["refreshed"] = False
    if previous_theme == "dark": ms.themes["current_theme"] = "light"
    elif previous_theme == "light": ms.themes["current_theme"] = "dark"


btn_face = ms.themes["light"]["button_face"] if ms.themes["current_theme"] == "light" else ms.themes["dark"]["button_face"]
st.button(btn_face, on_click=ChangeTheme)

if ms.themes["refreshed"] == False:
    ms.themes["refreshed"] = True
    st.rerun()


def get_img_size(line_mode: bool = False) -> Tuple[int, int]:
    """
    Auxiliary method that sets the height and width.
    Height is fixed while width is set according to the Model used.
    """
    if line_mode:
        return 256, get_img_height()
    return 128, get_img_height()


def get_img_height() -> int:
    """
    Auxiliary method that sets the height, which is fixed for the Neural Network.
    """
    return 32


def infer(line_mode: bool, model: Model, fn_img: Path) -> list:
    """
    Auxiliary method that does inference using the pretrained models:
    Recognizes text in an image given its path.
    """
    img = cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE)
    assert img is not None

    preprocessor = Preprocessor(get_img_size(line_mode), dynamic_width=True, padding=16)
    img = preprocessor.process_img(img)

    batch = Batch([img], None, 1)
    recognized, probability = model.infer_batch(batch, True)
    return [recognized, probability]


def infer_super_model(image_path) -> Tuple[list, list]:
    reader = easyocr.Reader(['en'])  # Initialize EasyOCR reader
    result = reader.readtext(image_path)
    recognized_texts = [text[1] for text in result]  # Extract recognized texts
    probabilities = [text[2] for text in result]  # Extract probabilities
    return recognized_texts, probabilities


def main():

    st.title('Extract text from Image Demo')

    st.markdown("""
    Streamlit Web Interface for Handwritten Text Recognition (HTR), Optical Character Recognition (OCR)
    implemented with TensorFlow and trained on the IAM off-line HTR dataset.
    The model takes images of single words or text lines (multiple words) as input and outputs the recognized text.
    """, unsafe_allow_html=True)

    st.markdown("""
    Predictions can be made using one of the following models:
    - Single_Model (Trained on Single Word Images)
    - Line_Model (Trained on Text Line Images)
    - Super_Model (Most Robust Option for English)
    - Burmese (Link)
    """, unsafe_allow_html=True)

    st.subheader('Select a Model, Choose the Arguments and Draw in the box below or Upload an Image to obtain a prediction.')

    # Selectors for the model and decoder
    modelSelect = st.selectbox("Select a Model", ['Single_Model', 'Line_Model', 'Super_Model'])

    if modelSelect != 'Super_Model':
        decoderSelect = st.selectbox("Select a Decoder", ['Bestpath', 'Beamsearch', 'Wordbeamsearch'])

    # Mappings (dictionaries) for the model and decoder. Assigns the directory or the DecoderType of the selected option.
    modelMapping = {
        "Single_Model": '../model/word-model',
        "Line_Model": '../model/line-model'
    }

    decoderMapping = {
        'Bestpath': DecoderType.BestPath,
        'Beamsearch': DecoderType.BeamSearch,
        'Wordbeamsearch': DecoderType.WordBeamSearch
    }

    # Slider for pencil width
    strokeWidth = st.slider("Stroke Width: ", 1, 25, 6)

    # Canvas/Text Box for user input. Background color must be white (#FFFFFF) or else text will not be properly recognised.
    inputDrawn = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",
        stroke_width=strokeWidth,
        update_streamlit=True,
        background_image=None,
        height=200,
        width=400,
        drawing_mode='freedraw',
        key="canvas",
        background_color='#FFFFFF'
    )

    # Buffer for user input (images uploaded from the user's device)
    inputBuffer = st.file_uploader("Upload an Image", type=["png"])

    # Inference Button
    inferBool = st.button("Recognize Text")

    # After clicking the "Recognize Text" button, check if the model selected is Super_Model
    if inferBool:
        if modelSelect == 'Super_Model':
            inputArray = None  # Initialize inputArray to None

            # Handling uploaded file
            if inputBuffer is not None:
                with Image.open(inputBuffer).convert('RGB') as img:
                    inputArray = np.array(img)

            # Handling canvas data
            elif inputDrawn.image_data is not None:
                # Convert RGBA to RGB
                inputArray = cv2.cvtColor(np.array(inputDrawn.image_data, dtype=np.uint8), cv2.COLOR_RGBA2RGB)

            # Now check if inputArray has been set
            if inputArray is not None:
                # Initialize EasyOCR Reader
                reader = easyocr.Reader(['en'])  # Assuming English language; adjust as necessary
                # Perform OCR
                results = reader.readtext(inputArray)

                # Display results
                all_text = ''
                for (bbox, text, prob) in results:
                    all_text += f'{text} (confidence: {prob:.2f})\n'

                st.write("**Recognized Texts and their Confidence Scores:**")
                st.text(all_text)
            else:
                st.write("No image data found. Please upload an image or draw on the canvas.")

        else:
            # Handle other model selections as before
            if ((inputDrawn.image_data is not None or inputBuffer is not None) and inferBool == True):
                # We turn the input into a numpy array
                if inputDrawn.image_data is not None:
                    inputArray = np.array(inputDrawn.image_data)

                if inputBuffer is not None:
                    inputBufferImage = Image.open(inputBuffer)
                    inputArray = np.array(inputBufferImage)

                # We turn this array into a .png format and save it.
                inputImage = Image.fromarray(inputArray.astype('uint8'), 'RGBA')
                inputImage.save('userInput.png')

                # We obtain the model directory and the decoder type from their mapping
                modelDir = modelMapping[modelSelect]
                decoderType = decoderMapping[decoderSelect]

                # Finally, we call the model with this image as attribute and display the Best Candidate and its probability on the Interface
                model = Model(list(open(modelDir + "/charList.txt").read()), modelDir, decoderType, must_restore=True)
                inferedText = infer(modelDir == '../model/line-model', model, 'userInput.png')

                st.write("**Best Candidate: **", inferedText[0][0])
                st.write("**Probability: **", str(inferedText[1][0]*100) + "%")


if __name__ == "__main__":
    main()
app/__init__.py
ADDED
File without changes
app/__pycache__/dataloader_iam.cpython-311.pyc
ADDED
Binary file (8.17 kB)
app/__pycache__/model.cpython-311.pyc
ADDED
Binary file (19.5 kB)
app/__pycache__/preprocessor.cpython-311.pyc
ADDED
Binary file (10.7 kB)
app/dataloader_iam.py
ADDED
@@ -0,0 +1,133 @@
import pickle
import random
from collections import namedtuple
from typing import Tuple

import cv2
import lmdb  # LMDB image cache used when fast=True
import numpy as np
from path import Path

Sample = namedtuple('Sample', 'gt_text, file_path')
Batch = namedtuple('Batch', 'imgs, gt_texts, batch_size')


class DataLoaderIAM:
    """
    Loads data which corresponds to IAM format,
    see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
    """

    def __init__(self,
                 data_dir: Path,
                 batch_size: int,
                 data_split: float = 0.95,
                 fast: bool = True) -> None:
        """Loader for dataset."""

        assert data_dir.exists()

        self.fast = fast
        if fast:
            self.env = lmdb.open(str(data_dir / 'lmdb'), readonly=True)

        self.data_augmentation = False
        self.curr_idx = 0
        self.batch_size = batch_size
        self.samples = []

        f = open(data_dir / 'gt/words.txt')
        chars = set()
        bad_samples_reference = ['a01-117-05-02', 'r06-022-03-05']  # known broken images in IAM dataset
        for line in f:
            # ignore comment line
            if not line or line[0] == '#':
                continue

            line_split = line.strip().split(' ')
            assert len(line_split) >= 9

            # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
            file_name_split = line_split[0].split('-')
            file_name_subdir1 = file_name_split[0]
            file_name_subdir2 = f'{file_name_split[0]}-{file_name_split[1]}'
            file_base_name = line_split[0] + '.png'
            file_name = data_dir / 'img' / file_name_subdir1 / file_name_subdir2 / file_base_name

            if line_split[0] in bad_samples_reference:
                print('Ignoring known broken image:', file_name)
                continue

            # GT text are columns starting at 9
            gt_text = ' '.join(line_split[8:])
            chars = chars.union(set(list(gt_text)))

            # put sample into list
            self.samples.append(Sample(gt_text, file_name))

        # split into training and validation set: 95% - 5%
        split_idx = int(data_split * len(self.samples))
        self.train_samples = self.samples[:split_idx]
        self.validation_samples = self.samples[split_idx:]

        # put words into lists
        self.train_words = [x.gt_text for x in self.train_samples]
        self.validation_words = [x.gt_text for x in self.validation_samples]

        # start with train set
        self.train_set()

        # list of all chars in dataset
        self.char_list = sorted(list(chars))

    def train_set(self) -> None:
        """Switch to randomly chosen subset of training set."""
        self.data_augmentation = True
        self.curr_idx = 0
        random.shuffle(self.train_samples)
        self.samples = self.train_samples
        self.curr_set = 'train'

    def validation_set(self) -> None:
        """Switch to validation set."""
        self.data_augmentation = False
        self.curr_idx = 0
        self.samples = self.validation_samples
        self.curr_set = 'val'

    def get_iterator_info(self) -> Tuple[int, int]:
        """Current batch index and overall number of batches."""
        if self.curr_set == 'train':
            num_batches = int(np.floor(len(self.samples) / self.batch_size))  # train set: only full-sized batches
        else:
            num_batches = int(np.ceil(len(self.samples) / self.batch_size))  # val set: allow last batch to be smaller
        curr_batch = self.curr_idx // self.batch_size + 1
        return curr_batch, num_batches

    def has_next(self) -> bool:
        """Is there a next element?"""
        if self.curr_set == 'train':
            return self.curr_idx + self.batch_size <= len(self.samples)  # train set: only full-sized batches
        else:
            return self.curr_idx < len(self.samples)  # val set: allow last batch to be smaller

    def _get_img(self, i: int) -> np.ndarray:
        if self.fast:
            with self.env.begin() as txn:
                basename = Path(self.samples[i].file_path).basename()
                data = txn.get(basename.encode("ascii"))
                img = pickle.loads(data)
        else:
            img = cv2.imread(self.samples[i].file_path, cv2.IMREAD_GRAYSCALE)

        return img

    def get_next(self) -> Batch:
        """Get next element."""
        batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples)))

        imgs = [self._get_img(i) for i in batch_range]
        gt_texts = [self.samples[i].gt_text for i in batch_range]

        self.curr_idx += self.batch_size
        return Batch(imgs, gt_texts, len(imgs))
app/model.py
ADDED
@@ -0,0 +1,334 @@
import os
import sys
from typing import List, Tuple
import tf_keras as keras
import numpy as np
from dataloader_iam import Batch

import tensorflow.compat.v1 as tf
tf.compat.v1.disable_v2_behavior()

# Disable eager mode
tf.compat.v1.disable_eager_execution()

class DecoderType:
    """
    CTC decoder types.
    """
    BestPath = 0
    BeamSearch = 1
    WordBeamSearch = 2


class Model:
    """
    Minimalistic TF model for HTR.
    """

    def __init__(self,
                 char_list: List[str],
                 model_dir: str,
                 decoder_type: str = DecoderType.BestPath,
                 must_restore: bool = False,
                 dump: bool = False) -> None:
        """
        Init model: add CNN, RNN and CTC and initialize TF.
        """
        self.dump = dump
        self.char_list = char_list
        self.decoder_type = decoder_type
        self.must_restore = must_restore
        self.snap_ID = 0
        self.model_dir = model_dir

        tf.compat.v1.disable_eager_execution()
        # Whether to use normalization over a batch or a population
        self.is_train = tf.compat.v1.placeholder(tf.bool, name='is_train')

        # input image batch
        self.input_imgs = tf.compat.v1.placeholder(tf.float32, shape=(None, None, None))

        # setup CNN, RNN and CTC
        self.setup_cnn()
        self.setup_rnn()
        self.setup_ctc()

        # setup optimizer to train NN
        self.batches_trained = 0
        self.update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(self.update_ops):
            self.optimizer = tf.compat.v1.train.AdamOptimizer().minimize(self.loss)

        # initialize TF
        self.sess, self.saver = self.setup_tf()

    def setup_cnn(self) -> None:
        """
        Create CNN layers.
        """
        cnn_in4d = tf.expand_dims(input=self.input_imgs, axis=3)

        # list of parameters for the layers
        kernel_vals = [5, 5, 3, 3, 3]
        feature_vals = [1, 32, 64, 128, 128, 256]
        stride_vals = pool_vals = [(2, 2), (2, 2), (1, 2), (1, 2), (1, 2)]
        num_layers = len(stride_vals)

        # create layers
        pool = cnn_in4d  # input to first CNN layer
        for i in range(num_layers):
            kernel = tf.Variable(
                tf.random.truncated_normal([kernel_vals[i], kernel_vals[i], feature_vals[i], feature_vals[i + 1]],
                                           stddev=0.1))
            conv = tf.nn.conv2d(input=pool, filters=kernel, padding='SAME', strides=(1, 1, 1, 1))
            conv_norm = tf.keras.layers.BatchNormalization()(conv, training=self.is_train)
            relu = tf.nn.relu(conv_norm)
            pool = tf.nn.max_pool2d(input=relu, ksize=(1, pool_vals[i][0], pool_vals[i][1], 1),
                                    strides=(1, stride_vals[i][0], stride_vals[i][1], 1), padding='VALID')

        self.cnn_out_4d = pool

    def setup_rnn(self) -> None:
        """
        Create RNN layers.
        """
        rnn_in3d = tf.squeeze(self.cnn_out_4d, axis=[2])

        # basic cells which are used to build RNN
        num_hidden = 256
        cells = [tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=num_hidden, state_is_tuple=True) for _ in
                 range(2)]  # 2 layers

        # stack basic cells
        stacked = tf.compat.v1.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)

        # bidirectional RNN
        # BxTxF -> BxTx2H
        (fw, bw), _ = tf.compat.v1.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnn_in3d,
                                                                dtype=rnn_in3d.dtype)

        # BxTxH + BxTxH -> BxTx2H -> BxTx1X2H
        concat = tf.expand_dims(tf.concat([fw, bw], 2), 2)

        # project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC
        kernel = tf.Variable(tf.random.truncated_normal([1, 1, num_hidden * 2, len(self.char_list) + 1], stddev=0.1))
        self.rnn_out_3d = tf.squeeze(tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME'),
                                     axis=[2])

    def setup_ctc(self) -> None:
        """
        Create CTC loss and decoder.
        """
        # BxTxC -> TxBxC
        self.ctc_in_3d_tbc = tf.transpose(a=self.rnn_out_3d, perm=[1, 0, 2])
        # ground truth text as sparse tensor
        self.gt_texts = tf.SparseTensor(tf.compat.v1.placeholder(tf.int64, shape=[None, 2]),
                                        tf.compat.v1.placeholder(tf.int32, [None]),
                                        tf.compat.v1.placeholder(tf.int64, [2]))

        # calc loss for batch
        self.seq_len = tf.compat.v1.placeholder(tf.int32, [None])
        self.loss = tf.reduce_mean(
            input_tensor=tf.compat.v1.nn.ctc_loss(labels=self.gt_texts, inputs=self.ctc_in_3d_tbc,
                                                  sequence_length=self.seq_len,
                                                  ctc_merge_repeated=True))

        # calc loss for each element to compute label probability
        self.saved_ctc_input = tf.compat.v1.placeholder(tf.float32,
                                                        shape=[None, None, len(self.char_list) + 1])
        self.loss_per_element = tf.compat.v1.nn.ctc_loss(labels=self.gt_texts, inputs=self.saved_ctc_input,
                                                         sequence_length=self.seq_len, ctc_merge_repeated=True)

        # best path decoding or beam search decoding
        if self.decoder_type == DecoderType.BestPath:
            self.decoder = tf.nn.ctc_greedy_decoder(inputs=self.ctc_in_3d_tbc, sequence_length=self.seq_len)
        elif self.decoder_type == DecoderType.BeamSearch:
            self.decoder = tf.nn.ctc_beam_search_decoder(inputs=self.ctc_in_3d_tbc, sequence_length=self.seq_len,
                                                         beam_width=50)
        # word beam search decoding (see https://github.com/githubharald/CTCWordBeamSearch)
        elif self.decoder_type == DecoderType.WordBeamSearch:
            # prepare information about language (dictionary, characters in dataset, characters forming words)
            chars = ''.join(self.char_list)
            word_chars = open('../model/wordCharList.txt').read().splitlines()[0]
            corpus = open('../data/corpus.txt').read()

            # decode using the "Words" mode of word beam search
            from word_beam_search import WordBeamSearch
            self.decoder = WordBeamSearch(50, 'Words', 0.0, corpus.encode('utf8'), chars.encode('utf8'),
                                          word_chars.encode('utf8'))

            # the input to the decoder must have softmax already applied
            self.wbs_input = tf.nn.softmax(self.ctc_in_3d_tbc, axis=2)

    def setup_tf(self) -> Tuple[tf.compat.v1.Session, tf.compat.v1.train.Saver]:
        """
        Initialize TF.
        """
        print('Python: ' + sys.version)
        print('Tensorflow: ' + tf.__version__)

        sess = tf.compat.v1.Session()  # TF session

        saver = tf.compat.v1.train.Saver(max_to_keep=1)  # saver saves model to file
        latest_snapshot = tf.train.latest_checkpoint(self.model_dir)  # is there a saved model?

        # if model must be restored (for inference), there must be a snapshot
        if self.must_restore and not latest_snapshot:
            raise Exception('No saved model found in: ' + self.model_dir)

        # load saved model if available
        if latest_snapshot:
            print('Init with stored values from ' + latest_snapshot)
            saver.restore(sess, latest_snapshot)
        else:
            print('Init with new values')
            sess.run(tf.compat.v1.global_variables_initializer())

        return sess, saver

    def to_sparse(self, texts: List[str]) -> Tuple[List[List[int]], List[int], List[int]]:
        """
        Put ground truth texts into sparse tensor for ctc_loss.
        """
        indices = []
        values = []
        shape = [len(texts), 0]  # last entry must be max(labelList[i])

        # go over all texts
        for batchElement, text in enumerate(texts):
            # convert to string of label (i.e. class-ids)
            label_str = [self.char_list.index(c) for c in text]
            # sparse tensor must have size of max. label-string
            if len(label_str) > shape[1]:
                shape[1] = len(label_str)
            # put each label into sparse tensor
            for i, label in enumerate(label_str):
                indices.append([batchElement, i])
                values.append(label)

        return indices, values, shape

    def decoder_output_to_text(self, ctc_output: tuple, batch_size: int) -> List[str]:
        """
        Extract texts from output of CTC decoder.
        """

        # word beam search: already contains label strings
        if self.decoder_type == DecoderType.WordBeamSearch:
            label_strs = ctc_output

        # TF decoders: label strings are contained in sparse tensor
        else:
            # ctc returns tuple, first element is SparseTensor
            decoded = ctc_output[0][0]

            # contains string of labels for each batch element
            label_strs = [[] for _ in range(batch_size)]

            # go over all indices and save mapping: batch -> values
            for (idx, idx2d) in enumerate(decoded.indices):
                label = decoded.values[idx]
                batch_element = idx2d[0]  # index according to [b,t]
                label_strs[batch_element].append(label)

        # map labels to chars for all batch elements
        return [''.join([self.char_list[c] for c in labelStr]) for labelStr in label_strs]

    def train_batch(self, batch: Batch) -> float:
        """
        Feed a batch into the NN to train it.
        """
        num_batch_elements = len(batch.imgs)
        max_text_len = batch.imgs[0].shape[0] // 4
        sparse = self.to_sparse(batch.gt_texts)
        eval_list = [self.optimizer, self.loss]
        feed_dict = {self.input_imgs: batch.imgs, self.gt_texts: sparse,
                     self.seq_len: [max_text_len] * num_batch_elements, self.is_train: True}
        _, loss_val = self.sess.run(eval_list, feed_dict)
        self.batches_trained += 1
        return loss_val

    @staticmethod
    def dump_nn_output(rnn_output: np.ndarray) -> None:
        """
        Dump the output of the NN to CSV file(s).
        """
        dump_dir = '../dump/'
        if not os.path.isdir(dump_dir):
            os.mkdir(dump_dir)

        # iterate over all batch elements and create a CSV file for each one
        max_t, max_b, max_c = rnn_output.shape
        for b in range(max_b):
            csv = ''
            for t in range(max_t):
                for c in range(max_c):
                    csv += str(rnn_output[t, b, c]) + ';'
                csv += '\n'
            fn = dump_dir + 'rnnOutput_' + str(b) + '.csv'
            print('Write dump of NN to file: ' + fn)
            with open(fn, 'w') as f:
                f.write(csv)

    def infer_batch(self, batch: Batch, calc_probability: bool = False, probability_of_gt: bool = False):
        """
        Feed a batch into the NN to recognize the texts.
        """

        # decode, optionally save RNN output
        num_batch_elements = len(batch.imgs)

        # put tensors to be evaluated into list
        eval_list = []

        if self.decoder_type == DecoderType.WordBeamSearch:
            eval_list.append(self.wbs_input)
        else:
            eval_list.append(self.decoder)

        if self.dump or calc_probability:
            eval_list.append(self.ctc_in_3d_tbc)

        # sequence length depends on input image size (model downsizes width by 4)
        max_text_len = batch.imgs[0].shape[0] // 4

        # dict containing all tensor fed into the model
        feed_dict = {self.input_imgs: batch.imgs, self.seq_len: [max_text_len] * num_batch_elements,
                     self.is_train: False}

        # evaluate model
        eval_res = self.sess.run(eval_list, feed_dict)

        # TF decoders: decoding already done in TF graph
        if self.decoder_type != DecoderType.WordBeamSearch:
            decoded = eval_res[0]
        # word beam search decoder: decoding is done in C++ function compute()
        else:
            decoded = self.decoder.compute(eval_res[0])

        # map labels (numbers) to character string
        texts = self.decoder_output_to_text(decoded, num_batch_elements)

        # feed RNN output and recognized text into CTC loss to compute labeling probability
        probs = None
        if calc_probability:
            sparse = self.to_sparse(batch.gt_texts) if probability_of_gt else self.to_sparse(texts)
            ctc_input = eval_res[1]
            eval_list = self.loss_per_element
            feed_dict = {self.saved_ctc_input: ctc_input, self.gt_texts: sparse,
                         self.seq_len: [max_text_len] * num_batch_elements, self.is_train: False}
            loss_vals = self.sess.run(eval_list, feed_dict)
            probs = np.exp(-loss_vals)

        # dump the output of the NN to CSV file(s)
        if self.dump:
            self.dump_nn_output(eval_res[1])

        return texts, probs

    def save(self) -> None:
        """
        Save model to file.
        """
        self.snap_ID += 1
        self.saver.save(self.sess, '../model/snapshot', global_step=self.snap_ID)
app/preprocessor.py
ADDED
@@ -0,0 +1,191 @@
import random
from typing import Tuple

import cv2
import numpy as np

from dataloader_iam import Batch


class Preprocessor:
    def __init__(self,
                 img_size: Tuple[int, int],
                 padding: int = 0,
                 dynamic_width: bool = False,
                 data_augmentation: bool = False,
                 line_mode: bool = False) -> None:
        # dynamic width only supported when no data augmentation happens
        assert not (dynamic_width and data_augmentation)
        # when padding is on, we need dynamic width enabled
        assert not (padding > 0 and not dynamic_width)

        self.img_size = img_size
        self.padding = padding
        self.dynamic_width = dynamic_width
        self.data_augmentation = data_augmentation
        self.line_mode = line_mode

    @staticmethod
    def _truncate_label(text: str, max_text_len: int) -> str:
        """
        Function ctc_loss can't compute loss if it cannot find a mapping between text label and input
        labels. Repeat letters cost double because of the blank symbol needing to be inserted.
        If a too-long label is provided, ctc_loss returns an infinite gradient.
        """
        cost = 0
        for i in range(len(text)):
            if i != 0 and text[i] == text[i - 1]:
                cost += 2
            else:
                cost += 1
            if cost > max_text_len:
                return text[:i]
        return text

    def _simulate_text_line(self, batch: Batch) -> Batch:
        """Create image of a text line by pasting multiple word images into an image."""

        default_word_sep = 30
        default_num_words = 5

        # go over all batch elements
        res_imgs = []
        res_gt_texts = []
        for i in range(batch.batch_size):
            # number of words to put into current line
            num_words = random.randint(1, 8) if self.data_augmentation else default_num_words

            # concat ground truth texts
            curr_gt = ' '.join([batch.gt_texts[(i + j) % batch.batch_size] for j in range(num_words)])
            res_gt_texts.append(curr_gt)

            # put selected word images into list, compute target image size
            sel_imgs = []
            word_seps = [0]
            h = 0
            w = 0
            for j in range(num_words):
                curr_sel_img = batch.imgs[(i + j) % batch.batch_size]
                curr_word_sep = random.randint(20, 50) if self.data_augmentation else default_word_sep
                h = max(h, curr_sel_img.shape[0])
                w += curr_sel_img.shape[1]
                sel_imgs.append(curr_sel_img)
                if j + 1 < num_words:
                    w += curr_word_sep
                    word_seps.append(curr_word_sep)

            # put all selected word images into target image
            target = np.ones([h, w], np.uint8) * 255
            x = 0
            for curr_sel_img, curr_word_sep in zip(sel_imgs, word_seps):
                x += curr_word_sep
                y = (h - curr_sel_img.shape[0]) // 2
                target[y:y + curr_sel_img.shape[0], x:x + curr_sel_img.shape[1]] = curr_sel_img
                x += curr_sel_img.shape[1]

            # put image of line into result
            res_imgs.append(target)

        return Batch(res_imgs, res_gt_texts, batch.batch_size)

    def process_img(self, img: np.ndarray) -> np.ndarray:
        """Resize to target size, apply data augmentation."""

        # there are damaged files in IAM dataset - just use black image instead
        if img is None:
            img = np.zeros(self.img_size[::-1])

        # data augmentation
        img = img.astype(float)
        if self.data_augmentation:
            # photometric data augmentation
            if random.random() < 0.25:
                def rand_odd():
                    return random.randint(1, 3) * 2 + 1
                img = cv2.GaussianBlur(img, (rand_odd(), rand_odd()), 0)
            if random.random() < 0.25:
                img = cv2.dilate(img, np.ones((3, 3)))
            if random.random() < 0.25:
                img = cv2.erode(img, np.ones((3, 3)))

            # geometric data augmentation
            wt, ht = self.img_size
            h, w = img.shape
            f = min(wt / w, ht / h)
            fx = f * np.random.uniform(0.75, 1.05)
            fy = f * np.random.uniform(0.75, 1.05)

            # random position around center
            txc = (wt - w * fx) / 2
            tyc = (ht - h * fy) / 2
            freedom_x = max((wt - fx * w) / 2, 0)
            freedom_y = max((ht - fy * h) / 2, 0)
            tx = txc + np.random.uniform(-freedom_x, freedom_x)
            ty = tyc + np.random.uniform(-freedom_y, freedom_y)

            # map image into target image
            M = np.float32([[fx, 0, tx], [0, fy, ty]])
            target = np.ones(self.img_size[::-1]) * 255
            img = cv2.warpAffine(img, M, dsize=self.img_size, dst=target, borderMode=cv2.BORDER_TRANSPARENT)

            # photometric data augmentation
            if random.random() < 0.5:
                img = img * (0.25 + random.random() * 0.75)
            if random.random() < 0.25:
                img = np.clip(img + (np.random.random(img.shape) - 0.5) * random.randint(1, 25), 0, 255)
            if random.random() < 0.1:
                img = 255 - img

        # no data augmentation
        else:
            if self.dynamic_width:
                ht = self.img_size[1]
                h, w = img.shape
                f = ht / h
                wt = int(f * w + self.padding)
                wt = wt + (4 - wt) % 4
                tx = (wt - w * f) / 2
                ty = 0
            else:
                wt, ht = self.img_size
                h, w = img.shape
                f = min(wt / w, ht / h)
                tx = (wt - w * f) / 2
                ty = (ht - h * f) / 2

            # map image into target image
            M = np.float32([[f, 0, tx], [0, f, ty]])
            target = np.ones([ht, wt]) * 255
            img = cv2.warpAffine(img, M, dsize=(wt, ht), dst=target, borderMode=cv2.BORDER_TRANSPARENT)

        # transpose for TF
        img = cv2.transpose(img)

        # convert to range [-1, 1]
        img = img / 255 - 0.5
        return img

    def process_batch(self, batch: Batch) -> Batch:
        if self.line_mode:
            batch = self._simulate_text_line(batch)

        res_imgs = [self.process_img(img) for img in batch.imgs]
        max_text_len = res_imgs[0].shape[0] // 4
        res_gt_texts = [self._truncate_label(gt_text, max_text_len) for gt_text in batch.gt_texts]
        return Batch(res_imgs, res_gt_texts, batch.batch_size)


def main():
    import matplotlib.pyplot as plt

    img = cv2.imread('../data/test.png', cv2.IMREAD_GRAYSCALE)
    img_aug = Preprocessor((256, 32), data_augmentation=True).process_img(img)
    plt.subplot(121)
    plt.imshow(img, cmap='gray')
    plt.subplot(122)
    plt.imshow(cv2.transpose(img_aug) + 0.5, cmap='gray', vmin=0, vmax=1)
    plt.show()


if __name__ == '__main__':
    main()
app/runner.py
ADDED
@@ -0,0 +1,5 @@
import os
import subprocess

fileDir = os.path.dirname(os.path.realpath(__file__))
subprocess.run(["streamlit", "run", "webapp.py"], cwd=fileDir)
app/simple.py
ADDED
@@ -0,0 +1,64 @@
import cv2
import numpy as np
from pathlib import Path
from model import Model, DecoderType
from preprocessor import Preprocessor
from dataloader_iam import Batch

import tensorflow as tf

def get_img_size(line_mode: bool = False) -> tuple[int, int]:
    """
    Auxiliary method that sets the height and width.
    Height is fixed while width is set according to the Model used.
    """
    if line_mode:
        return 256, get_img_height()
    return 128, get_img_height()

def get_img_height() -> int:
    """
    Auxiliary method that sets the fixed height for the Neural Network.
    """
    return 32

def infer(line_mode: bool, model: Model, fn_img: str) -> tuple:
    """
    Auxiliary method that does inference using the pretrained models:
    Recognizes text in an image given its path.
    """
    img = cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE)
    assert img is not None

    preprocessor = Preprocessor(get_img_size(line_mode), dynamic_width=True, padding=16)
    img = preprocessor.process_img(img)

    batch = Batch([img], None, 1)
    recognized, probability = model.infer_batch(batch, True)
    return recognized, probability

def main(image_path: str, model_path: str, decoder_type: DecoderType):
    """
    Main function to load the model, perform inference on the input image,
    and print the result.
    """
    # Load the model
    char_list_path = model_path + "/charList.txt"
    model = Model(list(open(char_list_path).read()), model_path, decoder_type, must_restore=True)

    # Perform inference
    recognized, probability = infer(model_path.endswith('line-model'), model, image_path)

    # Print the results
    print("Recognized Text:", recognized[0])
    print("Probability:", probability[0])

if __name__ == "__main__":
    # Example usage
    # Define the image path, model directory, and decoder type here
    image_path = 'word.png'  # Update this path
    model_path = '../model/word-model'  # or '../model/line-model' depending on your model
    decoder_type = DecoderType.BestPath  # Change as needed: BestPath, BeamSearch, WordBeamSearch

    # Call the main function with the specified parameters
    main(image_path, model_path, decoder_type)
app/userInput.png
ADDED
app/webapp.py
ADDED
@@ -0,0 +1,132 @@
import os
import cv2
import numpy as np
from PIL import Image
from path import Path
import streamlit as st
from typing import Tuple
from dataloader_iam import Batch
from model import Model, DecoderType
from preprocessor import Preprocessor
from streamlit_drawable_canvas import st_canvas


def get_img_size(line_mode: bool = False) -> Tuple[int, int]:
    """
    Auxiliary method that sets the height and width.
    Height is fixed while width is set according to the Model used.
    """
    if line_mode:
        return 256, get_img_height()
    return 128, get_img_height()

def get_img_height() -> int:
    """
    Auxiliary method that sets the height, which is fixed for the Neural Network.
    """
    return 32

def infer(line_mode: bool, model: Model, fn_img: Path) -> list:
    """
    Auxiliary method that does inference using the pretrained models:
    Recognizes text in an image given its path.
    """
    img = cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE)
    assert img is not None

    preprocessor = Preprocessor(get_img_size(line_mode), dynamic_width=True, padding=16)
    img = preprocessor.process_img(img)

    batch = Batch([img], None, 1)
    recognized, probability = model.infer_batch(batch, True)
    return [recognized, probability]

def main():

    # Website properties
    st.set_page_config(
        page_title="HTR App",
        page_icon=":pencil:",
        layout="centered",
        initial_sidebar_state="auto",
    )

    st.title('HTR Simple Application')

    st.markdown("""
    Streamlit Web Interface for Handwritten Text Recognition (HTR), implemented with TensorFlow and trained on the IAM off-line HTR dataset. The model takes images of single words or text lines (multiple words) as input and outputs the recognized text.
    """, unsafe_allow_html=True)

    st.markdown("""
    Predictions can be made using one of two models:
    - [Model 1](https://www.dropbox.com/s/mya8hw6jyzqm0a3/word-model.zip?dl=1) (Trained on Single Word Images)
    - [Model 2](https://www.dropbox.com/s/7xwkcilho10rthn/line-model.zip?dl=1) (Trained on Text Line Images)
    """, unsafe_allow_html=True)

    st.subheader('Select a Model, Choose the Arguments and Draw in the box below or Upload an Image to obtain a prediction.')

    # Selectors for the model and decoder
    modelSelect = st.selectbox("Select a Model", ['Single_Model', 'Line_Model'])

    decoderSelect = st.selectbox("Select a Decoder", ['Bestpath', 'Beamsearch', 'Wordbeamsearch'])

    # Mappings (dictionaries) for the model and decoder. Assigns the directory or the DecoderType of the selected option.
    modelMapping = {
        "Single_Model": '../model/word-model',
        "Line_Model": '../model/line-model'
    }

    decoderMapping = {
        'Bestpath': DecoderType.BestPath,
        'Beamsearch': DecoderType.BeamSearch,
        'Wordbeamsearch': DecoderType.WordBeamSearch
    }

    # Slider for pencil width
    strokeWidth = st.slider("Stroke Width: ", 1, 25, 6)

    # Canvas/Text Box for user input. Background color must be white (#FFFFFF) or else text will not be properly recognised.
    inputDrawn = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",
        stroke_width=strokeWidth,
        update_streamlit=True,
        height=200,
        width=400,
        drawing_mode='freedraw',
        key="canvas",
        background_color='#FFFFFF'
    )

    # Buffer for user input (images uploaded from the user's device)
    inputBuffer = st.file_uploader("Upload an Image", type=["png"])

    # Infer Button
    inferBool = st.button("Recognize Word")

    # We start inferring once we have the user input and the Infer button is pressed.
    if ((inputDrawn.image_data is not None or inputBuffer is not None) and inferBool == True):

        # We turn the input into a numpy array
        if inputDrawn.image_data is not None:
            inputArray = np.array(inputDrawn.image_data)

        if inputBuffer is not None:
            inputBufferImage = Image.open(inputBuffer)
            inputArray = np.array(inputBufferImage)

        # We turn this array into a .png format and save it.
        inputImage = Image.fromarray(inputArray.astype('uint8'), 'RGBA')
        inputImage.save('userInput.png')

        # We obtain the model directory and the decoder type from their mapping
        modelDir = modelMapping[modelSelect]
        decoderType = decoderMapping[decoderSelect]

        # Finally, we call the model with this image as attribute and display the Best Candidate and its probability on the Interface
        model = Model(list(open(modelDir + "/charList.txt").read()), modelDir, decoderType, must_restore=True)
        inferedText = infer(modelDir == '../model/line-model', model, 'userInput.png')

        st.write("**Best Candidate: **", inferedText[0][0])
        st.write("**Probability: **", str(inferedText[1][0]*100) + "%")

if __name__ == "__main__":
    main()
app/word.png
ADDED