Mattral committed on
Commit 57462b3 · verified · 1 Parent(s): 19e0cf1

Upload 13 files

app/FinalApp.py ADDED
@@ -0,0 +1,223 @@
+ import os
+ import cv2
+ import numpy as np
+ from PIL import Image
+ from path import Path
+ import streamlit as st
+ from typing import Tuple
+ from dataloader_iam import Batch
+ from model import Model, DecoderType
+ from preprocessor import Preprocessor
+ from streamlit_drawable_canvas import st_canvas
+ import easyocr  # Import EasyOCR
+
+ # Set page config at the very beginning (only executed once)
+ st.set_page_config(
+     page_title="HTR App",
+     page_icon=":pencil:",
+     layout="centered",
+     initial_sidebar_state="auto",
+ )
+
+ ms = st.session_state
+ if "themes" not in ms:
+     ms.themes = {"current_theme": "light",
+                  "refreshed": True,
+
+                  "light": {"theme.base": "dark",
+                            "theme.backgroundColor": "black",
+                            "theme.primaryColor": "#c98bdb",
+                            "theme.secondaryBackgroundColor": "#5591f5",
+                            "theme.textColor": "white",
+                            "button_face": "🌜"},
+
+                  "dark": {"theme.base": "light",
+                           "theme.backgroundColor": "white",
+                           "theme.primaryColor": "#5591f5",
+                           "theme.secondaryBackgroundColor": "#82E1D7",
+                           "theme.textColor": "#0a1464",
+                           "button_face": "🌞"},
+                  }
+
+
+ def ChangeTheme():
+     previous_theme = ms.themes["current_theme"]
+     tdict = ms.themes["light"] if ms.themes["current_theme"] == "light" else ms.themes["dark"]
+     for vkey, vval in tdict.items():
+         if vkey.startswith("theme"): st._config.set_option(vkey, vval)
+
+     ms.themes["refreshed"] = False
+     if previous_theme == "dark": ms.themes["current_theme"] = "light"
+     elif previous_theme == "light": ms.themes["current_theme"] = "dark"
+
+
+ btn_face = ms.themes["light"]["button_face"] if ms.themes["current_theme"] == "light" else ms.themes["dark"]["button_face"]
+ st.button(btn_face, on_click=ChangeTheme)
+
+ if ms.themes["refreshed"] == False:
+     ms.themes["refreshed"] = True
+     st.rerun()
+
+
+ def get_img_size(line_mode: bool = False) -> Tuple[int, int]:
+     """
+     Auxiliary method that sets the height and width.
+     Height is fixed while width is set according to the Model used.
+     """
+     if line_mode:
+         return 256, get_img_height()
+     return 128, get_img_height()
+
+ def get_img_height() -> int:
+     """
+     Auxiliary method that sets the height, which is fixed for the Neural Network.
+     """
+     return 32
+
+ def infer(line_mode: bool, model: Model, fn_img: Path) -> list:
+     """
+     Auxiliary method that does inference using the pretrained models:
+     Recognizes text in an image given its path.
+     """
+     img = cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE)
+     assert img is not None
+
+     preprocessor = Preprocessor(get_img_size(line_mode), dynamic_width=True, padding=16)
+     img = preprocessor.process_img(img)
+
+     batch = Batch([img], None, 1)
+     recognized, probability = model.infer_batch(batch, True)
+     return [recognized, probability]
+
+ def infer_super_model(image_path) -> Tuple[list, list]:
+     reader = easyocr.Reader(['en'])  # Initialize EasyOCR reader
+     result = reader.readtext(image_path)
+     recognized_texts = [text[1] for text in result]  # Extract recognized texts
+     probabilities = [text[2] for text in result]  # Extract probabilities
+     return recognized_texts, probabilities
+
+
+ def main():
+
+     st.title('Extract text from Image Demo')
+
+     st.markdown("""
+     Streamlit Web Interface for Handwritten Text Recognition (HTR) and Optical Character Recognition (OCR),
+     implemented with TensorFlow and trained on the IAM off-line HTR dataset.
+     The model takes images of single words or text lines (multiple words) as input and outputs the recognized text.
+     """, unsafe_allow_html=True)
+
+     st.markdown("""
+     Predictions can be made using one of the following models:
+     - Single_Model (Trained on Single Word Images)
+     - Line_Model (Trained on Text Line Images)
+     - Super_Model (Most Robust Option for English)
+     - Burmese (Link)
+     """, unsafe_allow_html=True)
+
+     st.subheader('Select a Model, Choose the Arguments and Draw in the box below or Upload an Image to obtain a prediction.')
+
+     # Selectors for the model and decoder
+     modelSelect = st.selectbox("Select a Model", ['Single_Model', 'Line_Model', 'Super_Model'])
+
+     if modelSelect != 'Super_Model':
+         decoderSelect = st.selectbox("Select a Decoder", ['Bestpath', 'Beamsearch', 'Wordbeamsearch'])
+
+     # Mappings (dictionaries) for the model and decoder. Assigns the directory or the DecoderType of the selected option.
+     modelMapping = {
+         "Single_Model": '../model/word-model',
+         "Line_Model": '../model/line-model'
+     }
+
+     decoderMapping = {
+         'Bestpath': DecoderType.BestPath,
+         'Beamsearch': DecoderType.BeamSearch,
+         'Wordbeamsearch': DecoderType.WordBeamSearch
+     }
+
+     # Slider for pencil width
+     strokeWidth = st.slider("Stroke Width: ", 1, 25, 6)
+
+     # Canvas/Text Box for user input. Background color must be white (#FFFFFF) or else text will not be properly recognised.
+     inputDrawn = st_canvas(
+         fill_color="rgba(255, 165, 0, 0.3)",
+         stroke_width=strokeWidth,
+         update_streamlit=True,
+         background_image=None,
+         height=200,
+         width=400,
+         drawing_mode='freedraw',
+         key="canvas",
+         background_color='#FFFFFF'
+     )
+
+     # Buffer for user input (images uploaded from the user's device)
+     inputBuffer = st.file_uploader("Upload an Image", type=["png"])
+
+     # Inference Button
+     inferBool = st.button("Recognize Text")
+
+     # After clicking the "Recognize Text" button, check if the model selected is Super_Model
+     if inferBool:
+         if modelSelect == 'Super_Model':
+             inputArray = None  # Initialize inputArray to None
+
+             # Handling uploaded file
+             if inputBuffer is not None:
+                 with Image.open(inputBuffer).convert('RGB') as img:
+                     inputArray = np.array(img)
+
+             # Handling canvas data
+             elif inputDrawn.image_data is not None:
+                 # Convert RGBA to RGB
+                 inputArray = cv2.cvtColor(np.array(inputDrawn.image_data, dtype=np.uint8), cv2.COLOR_RGBA2RGB)
+
+             # Now check if inputArray has been set
+             if inputArray is not None:
+                 # Initialize EasyOCR Reader
+                 reader = easyocr.Reader(['en'])  # Assuming English language; adjust as necessary
+                 # Perform OCR
+                 results = reader.readtext(inputArray)
+
+                 # Display results
+                 all_text = ''
+                 for (bbox, text, prob) in results:
+                     all_text += f'{text} (confidence: {prob:.2f})\n'
+
+                 st.write("**Recognized Texts and their Confidence Scores:**")
+                 st.text(all_text)
+             else:
+                 st.write("No image data found. Please upload an image or draw on the canvas.")
+
+         else:
+             # Handle other model selections as before
+             if ((inputDrawn.image_data is not None or inputBuffer is not None) and inferBool == True):
+                 # We turn the input into a numpy array
+                 if inputDrawn.image_data is not None:
+                     inputArray = np.array(inputDrawn.image_data)
+
+                 if inputBuffer is not None:
+                     inputBufferImage = Image.open(inputBuffer)
+                     inputArray = np.array(inputBufferImage)
+
+                 # We turn this array into a .png format and save it.
+                 inputImage = Image.fromarray(inputArray.astype('uint8'), 'RGBA')
+                 inputImage.save('userInput.png')
+                 # We obtain the model directory and the decoder type from their mapping
+                 modelDir = modelMapping[modelSelect]
+                 decoderType = decoderMapping[decoderSelect]
+
+                 # Finally, we call the model with this image as attribute and display the Best Candidate and its probability on the Interface
+                 model = Model(list(open(modelDir + "/charList.txt").read()), modelDir, decoderType, must_restore=True)
+                 inferedText = infer(modelDir == '../model/line-model', model, 'userInput.png')
+
+                 st.write("**Best Candidate: **", inferedText[0][0])
+                 st.write("**Probability: **", str(inferedText[1][0] * 100) + "%")
+
+ if __name__ == "__main__":
+     main()
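
FinalApp.py is a Streamlit script, so it is meant to be launched through the Streamlit CLI rather than executed directly. A minimal launcher in the same style as app/runner.py further below (assuming streamlit and the other dependencies are installed) could look like this sketch:

    import os
    import subprocess

    # launch the Streamlit app from the directory that contains FinalApp.py
    file_dir = os.path.dirname(os.path.realpath(__file__))
    subprocess.run(["streamlit", "run", "FinalApp.py"], cwd=file_dir)
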
app/__init__.py ADDED
File without changes
app/__pycache__/dataloader_iam.cpython-311.pyc ADDED
Binary file (8.17 kB).
 
app/__pycache__/model.cpython-311.pyc ADDED
Binary file (19.5 kB).
 
app/__pycache__/preprocessor.cpython-311.pyc ADDED
Binary file (10.7 kB).
 
app/dataloader_iam.py ADDED
@@ -0,0 +1,133 @@
+ import pickle
+ import random
+ from collections import namedtuple
+ from typing import Tuple
+
+ import cv2
+ import lmdb
+ import numpy as np
+ from path import Path
+
+ Sample = namedtuple('Sample', 'gt_text, file_path')
+ Batch = namedtuple('Batch', 'imgs, gt_texts, batch_size')
+
+
+ class DataLoaderIAM:
+     """
+     Loads data which corresponds to IAM format,
+     see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
+     """
+
+     def __init__(self,
+                  data_dir: Path,
+                  batch_size: int,
+                  data_split: float = 0.95,
+                  fast: bool = True) -> None:
+         """Loader for dataset."""
+
+         assert data_dir.exists()
+
+         self.fast = fast
+         if fast:
+             self.env = lmdb.open(str(data_dir / 'lmdb'), readonly=True)
+
+         self.data_augmentation = False
+         self.curr_idx = 0
+         self.batch_size = batch_size
+         self.samples = []
+
+         f = open(data_dir / 'gt/words.txt')
+         chars = set()
+         bad_samples_reference = ['a01-117-05-02', 'r06-022-03-05']  # known broken images in IAM dataset
+         for line in f:
+             # ignore comment line
+             if not line or line[0] == '#':
+                 continue
+
+             line_split = line.strip().split(' ')
+             assert len(line_split) >= 9
+
+             # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
+             file_name_split = line_split[0].split('-')
+             file_name_subdir1 = file_name_split[0]
+             file_name_subdir2 = f'{file_name_split[0]}-{file_name_split[1]}'
+             file_base_name = line_split[0] + '.png'
+             file_name = data_dir / 'img' / file_name_subdir1 / file_name_subdir2 / file_base_name
+
+             if line_split[0] in bad_samples_reference:
+                 print('Ignoring known broken image:', file_name)
+                 continue
+
+             # GT text are columns starting at 9
+             gt_text = ' '.join(line_split[8:])
+             chars = chars.union(set(list(gt_text)))
+
+             # put sample into list
+             self.samples.append(Sample(gt_text, file_name))
+
+         # split into training and validation set: 95% - 5%
+         split_idx = int(data_split * len(self.samples))
+         self.train_samples = self.samples[:split_idx]
+         self.validation_samples = self.samples[split_idx:]
+
+         # put words into lists
+         self.train_words = [x.gt_text for x in self.train_samples]
+         self.validation_words = [x.gt_text for x in self.validation_samples]
+
+         # start with train set
+         self.train_set()
+
+         # list of all chars in dataset
+         self.char_list = sorted(list(chars))
+
+     def train_set(self) -> None:
+         """Switch to randomly chosen subset of training set."""
+         self.data_augmentation = True
+         self.curr_idx = 0
+         random.shuffle(self.train_samples)
+         self.samples = self.train_samples
+         self.curr_set = 'train'
+
+     def validation_set(self) -> None:
+         """Switch to validation set."""
+         self.data_augmentation = False
+         self.curr_idx = 0
+         self.samples = self.validation_samples
+         self.curr_set = 'val'
+
+     def get_iterator_info(self) -> Tuple[int, int]:
+         """Current batch index and overall number of batches."""
+         if self.curr_set == 'train':
+             num_batches = int(np.floor(len(self.samples) / self.batch_size))  # train set: only full-sized batches
+         else:
+             num_batches = int(np.ceil(len(self.samples) / self.batch_size))  # val set: allow last batch to be smaller
+         curr_batch = self.curr_idx // self.batch_size + 1
+         return curr_batch, num_batches
+
+     def has_next(self) -> bool:
+         """Is there a next element?"""
+         if self.curr_set == 'train':
+             return self.curr_idx + self.batch_size <= len(self.samples)  # train set: only full-sized batches
+         else:
+             return self.curr_idx < len(self.samples)  # val set: allow last batch to be smaller
+
+     def _get_img(self, i: int) -> np.ndarray:
+         if self.fast:
+             with self.env.begin() as txn:
+                 basename = Path(self.samples[i].file_path).basename()
+                 data = txn.get(basename.encode("ascii"))
+                 img = pickle.loads(data)
+         else:
+             img = cv2.imread(self.samples[i].file_path, cv2.IMREAD_GRAYSCALE)
+
+         return img
+
+     def get_next(self) -> Batch:
+         """Get next element."""
+         batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples)))
+
+         imgs = [self._get_img(i) for i in batch_range]
+         gt_texts = [self.samples[i].gt_text for i in batch_range]
+
+         self.curr_idx += self.batch_size
+         return Batch(imgs, gt_texts, len(imgs))
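
DataLoaderIAM exposes a small iterator interface (train_set/validation_set, has_next, get_iterator_info, get_next). A minimal sketch of a training-style loop over the loader, combined with the Preprocessor added further below, might look like this (the ../data directory and batch size are placeholder assumptions):

    from path import Path

    from dataloader_iam import DataLoaderIAM
    from preprocessor import Preprocessor

    # placeholder data directory containing gt/words.txt and img/ in IAM layout
    loader = DataLoaderIAM(Path('../data'), batch_size=100, fast=False)
    preprocessor = Preprocessor((128, 32), data_augmentation=True)

    loader.train_set()
    while loader.has_next():
        curr_batch, num_batches = loader.get_iterator_info()
        batch = preprocessor.process_batch(loader.get_next())
        print(f'Batch {curr_batch}/{num_batches}: {batch.batch_size} images')
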
app/model.py ADDED
@@ -0,0 +1,334 @@
+ import os
+ import sys
+ from typing import List, Tuple
+ import tf_keras as keras
+ import numpy as np
+ from dataloader_iam import Batch
+
+ import tensorflow.compat.v1 as tf
+ tf.compat.v1.disable_v2_behavior()
+
+ # Disable eager mode
+ tf.compat.v1.disable_eager_execution()
+
+ class DecoderType:
+     """
+     CTC decoder types.
+     """
+     BestPath = 0
+     BeamSearch = 1
+     WordBeamSearch = 2
+
+
+ class Model:
+     """
+     Minimalistic TF model for HTR.
+     """
+
+     def __init__(self,
+                  char_list: List[str],
+                  model_dir: str,
+                  decoder_type: str = DecoderType.BestPath,
+                  must_restore: bool = False,
+                  dump: bool = False) -> None:
+         """
+         Init model: add CNN, RNN and CTC and initialize TF.
+         """
+         self.dump = dump
+         self.char_list = char_list
+         self.decoder_type = decoder_type
+         self.must_restore = must_restore
+         self.snap_ID = 0
+         self.model_dir = model_dir
+
+         tf.compat.v1.disable_eager_execution()
+         # Whether to use normalization over a batch or a population
+         self.is_train = tf.compat.v1.placeholder(tf.bool, name='is_train')
+
+         # input image batch
+         self.input_imgs = tf.compat.v1.placeholder(tf.float32, shape=(None, None, None))
+
+         # setup CNN, RNN and CTC
+         self.setup_cnn()
+         self.setup_rnn()
+         self.setup_ctc()
+
+         # setup optimizer to train NN
+         self.batches_trained = 0
+         self.update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
+         with tf.control_dependencies(self.update_ops):
+             self.optimizer = tf.compat.v1.train.AdamOptimizer().minimize(self.loss)
+
+         # initialize TF
+         self.sess, self.saver = self.setup_tf()
+
+     def setup_cnn(self) -> None:
+         """
+         Create CNN layers.
+         """
+         cnn_in4d = tf.expand_dims(input=self.input_imgs, axis=3)
+
+         # list of parameters for the layers
+         kernel_vals = [5, 5, 3, 3, 3]
+         feature_vals = [1, 32, 64, 128, 128, 256]
+         stride_vals = pool_vals = [(2, 2), (2, 2), (1, 2), (1, 2), (1, 2)]
+         num_layers = len(stride_vals)
+
+         # create layers
+         pool = cnn_in4d  # input to first CNN layer
+         for i in range(num_layers):
+             kernel = tf.Variable(
+                 tf.random.truncated_normal([kernel_vals[i], kernel_vals[i], feature_vals[i], feature_vals[i + 1]],
+                                            stddev=0.1))
+             conv = tf.nn.conv2d(input=pool, filters=kernel, padding='SAME', strides=(1, 1, 1, 1))
+             conv_norm = tf.keras.layers.BatchNormalization()(conv, training=self.is_train)
+             relu = tf.nn.relu(conv_norm)
+             pool = tf.nn.max_pool2d(input=relu, ksize=(1, pool_vals[i][0], pool_vals[i][1], 1),
+                                     strides=(1, stride_vals[i][0], stride_vals[i][1], 1), padding='VALID')
+
+         self.cnn_out_4d = pool
+
+     def setup_rnn(self) -> None:
+         """
+         Create RNN layers.
+         """
+         rnn_in3d = tf.squeeze(self.cnn_out_4d, axis=[2])
+
+         # basic cells which is used to build RNN
+         num_hidden = 256
+         cells = [tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=num_hidden, state_is_tuple=True) for _ in
+                  range(2)]  # 2 layers
+
+         # stack basic cells
+         stacked = tf.compat.v1.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)
+
+         # bidirectional RNN
+         # BxTxF -> BxTx2H
+         (fw, bw), _ = tf.compat.v1.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnn_in3d,
+                                                                 dtype=rnn_in3d.dtype)
+
+         # BxTxH + BxTxH -> BxTx2H -> BxTx1X2H
+         concat = tf.expand_dims(tf.concat([fw, bw], 2), 2)
+
+         # project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC
+         kernel = tf.Variable(tf.random.truncated_normal([1, 1, num_hidden * 2, len(self.char_list) + 1], stddev=0.1))
+         self.rnn_out_3d = tf.squeeze(tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME'),
+                                      axis=[2])
+
+     def setup_ctc(self) -> None:
+         """
+         Create CTC loss and decoder.
+         """
+         # BxTxC -> TxBxC
+         self.ctc_in_3d_tbc = tf.transpose(a=self.rnn_out_3d, perm=[1, 0, 2])
+         # ground truth text as sparse tensor
+         self.gt_texts = tf.SparseTensor(tf.compat.v1.placeholder(tf.int64, shape=[None, 2]),
+                                         tf.compat.v1.placeholder(tf.int32, [None]),
+                                         tf.compat.v1.placeholder(tf.int64, [2]))
+
+         # calc loss for batch
+         self.seq_len = tf.compat.v1.placeholder(tf.int32, [None])
+         self.loss = tf.reduce_mean(
+             input_tensor=tf.compat.v1.nn.ctc_loss(labels=self.gt_texts, inputs=self.ctc_in_3d_tbc,
+                                                   sequence_length=self.seq_len,
+                                                   ctc_merge_repeated=True))
+
+         # calc loss for each element to compute label probability
+         self.saved_ctc_input = tf.compat.v1.placeholder(tf.float32,
+                                                         shape=[None, None, len(self.char_list) + 1])
+         self.loss_per_element = tf.compat.v1.nn.ctc_loss(labels=self.gt_texts, inputs=self.saved_ctc_input,
+                                                          sequence_length=self.seq_len, ctc_merge_repeated=True)
+
+         # best path decoding or beam search decoding
+         if self.decoder_type == DecoderType.BestPath:
+             self.decoder = tf.nn.ctc_greedy_decoder(inputs=self.ctc_in_3d_tbc, sequence_length=self.seq_len)
+         elif self.decoder_type == DecoderType.BeamSearch:
+             self.decoder = tf.nn.ctc_beam_search_decoder(inputs=self.ctc_in_3d_tbc, sequence_length=self.seq_len,
+                                                          beam_width=50)
+         # word beam search decoding (see https://github.com/githubharald/CTCWordBeamSearch)
+         elif self.decoder_type == DecoderType.WordBeamSearch:
+             # prepare information about language (dictionary, characters in dataset, characters forming words)
+             chars = ''.join(self.char_list)
+             word_chars = open('../model/wordCharList.txt').read().splitlines()[0]
+             corpus = open('../data/corpus.txt').read()
+
+             # decode using the "Words" mode of word beam search
+             from word_beam_search import WordBeamSearch
+             self.decoder = WordBeamSearch(50, 'Words', 0.0, corpus.encode('utf8'), chars.encode('utf8'),
+                                           word_chars.encode('utf8'))
+
+             # the input to the decoder must have softmax already applied
+             self.wbs_input = tf.nn.softmax(self.ctc_in_3d_tbc, axis=2)
+
+     def setup_tf(self) -> Tuple[tf.compat.v1.Session, tf.compat.v1.train.Saver]:
+         """
+         Initialize TF.
+         """
+         print('Python: ' + sys.version)
+         print('Tensorflow: ' + tf.__version__)
+
+         sess = tf.compat.v1.Session()  # TF session
+
+         saver = tf.compat.v1.train.Saver(max_to_keep=1)  # saver saves model to file
+         latest_snapshot = tf.train.latest_checkpoint(self.model_dir)  # is there a saved model?
+
+         # if model must be restored (for inference), there must be a snapshot
+         if self.must_restore and not latest_snapshot:
+             raise Exception('No saved model found in: ' + self.model_dir)
+
+         # load saved model if available
+         if latest_snapshot:
+             print('Init with stored values from ' + latest_snapshot)
+             saver.restore(sess, latest_snapshot)
+         else:
+             print('Init with new values')
+             sess.run(tf.compat.v1.global_variables_initializer())
+
+         return sess, saver
+
+     def to_sparse(self, texts: List[str]) -> Tuple[List[List[int]], List[int], List[int]]:
+         """
+         Put ground truth texts into sparse tensor for ctc_loss.
+         """
+         indices = []
+         values = []
+         shape = [len(texts), 0]  # last entry must be max(labelList[i])
+
+         # go over all texts
+         for batchElement, text in enumerate(texts):
+             # convert to string of label (i.e. class-ids)
+             label_str = [self.char_list.index(c) for c in text]
+             # sparse tensor must have size of max. label-string
+             if len(label_str) > shape[1]:
+                 shape[1] = len(label_str)
+             # put each label into sparse tensor
+             for i, label in enumerate(label_str):
+                 indices.append([batchElement, i])
+                 values.append(label)
+
+         return indices, values, shape
+
+     def decoder_output_to_text(self, ctc_output: tuple, batch_size: int) -> List[str]:
+         """
+         Extract texts from output of CTC decoder.
+         """
+
+         # word beam search: already contains label strings
+         if self.decoder_type == DecoderType.WordBeamSearch:
+             label_strs = ctc_output
+
+         # TF decoders: label strings are contained in sparse tensor
+         else:
+             # ctc returns tuple, first element is SparseTensor
+             decoded = ctc_output[0][0]
+
+             # contains string of labels for each batch element
+             label_strs = [[] for _ in range(batch_size)]
+
+             # go over all indices and save mapping: batch -> values
+             for (idx, idx2d) in enumerate(decoded.indices):
+                 label = decoded.values[idx]
+                 batch_element = idx2d[0]  # index according to [b,t]
+                 label_strs[batch_element].append(label)
+
+         # map labels to chars for all batch elements
+         return [''.join([self.char_list[c] for c in labelStr]) for labelStr in label_strs]
+
+     def train_batch(self, batch: Batch) -> float:
+         """
+         Feed a batch into the NN to train it.
+         """
+         num_batch_elements = len(batch.imgs)
+         max_text_len = batch.imgs[0].shape[0] // 4
+         sparse = self.to_sparse(batch.gt_texts)
+         eval_list = [self.optimizer, self.loss]
+         feed_dict = {self.input_imgs: batch.imgs, self.gt_texts: sparse,
+                      self.seq_len: [max_text_len] * num_batch_elements, self.is_train: True}
+         _, loss_val = self.sess.run(eval_list, feed_dict)
+         self.batches_trained += 1
+         return loss_val
+
+     @staticmethod
+     def dump_nn_output(rnn_output: np.ndarray) -> None:
+         """
+         Dump the output of the NN to CSV file(s).
+         """
+         dump_dir = '../dump/'
+         if not os.path.isdir(dump_dir):
+             os.mkdir(dump_dir)
+
+         # iterate over all batch elements and create a CSV file for each one
+         max_t, max_b, max_c = rnn_output.shape
+         for b in range(max_b):
+             csv = ''
+             for t in range(max_t):
+                 for c in range(max_c):
+                     csv += str(rnn_output[t, b, c]) + ';'
+                 csv += '\n'
+             fn = dump_dir + 'rnnOutput_' + str(b) + '.csv'
+             print('Write dump of NN to file: ' + fn)
+             with open(fn, 'w') as f:
+                 f.write(csv)
+
+     def infer_batch(self, batch: Batch, calc_probability: bool = False, probability_of_gt: bool = False):
+         """
+         Feed a batch into the NN to recognize the texts.
+         """
+
+         # decode, optionally save RNN output
+         num_batch_elements = len(batch.imgs)
+
+         # put tensors to be evaluated into list
+         eval_list = []
+
+         if self.decoder_type == DecoderType.WordBeamSearch:
+             eval_list.append(self.wbs_input)
+         else:
+             eval_list.append(self.decoder)
+
+         if self.dump or calc_probability:
+             eval_list.append(self.ctc_in_3d_tbc)
+
+         # sequence length depends on input image size (model downsizes width by 4)
+         max_text_len = batch.imgs[0].shape[0] // 4
+
+         # dict containing all tensor fed into the model
+         feed_dict = {self.input_imgs: batch.imgs, self.seq_len: [max_text_len] * num_batch_elements,
+                      self.is_train: False}
+
+         # evaluate model
+         eval_res = self.sess.run(eval_list, feed_dict)
+
+         # TF decoders: decoding already done in TF graph
+         if self.decoder_type != DecoderType.WordBeamSearch:
+             decoded = eval_res[0]
+         # word beam search decoder: decoding is done in C++ function compute()
+         else:
+             decoded = self.decoder.compute(eval_res[0])
+
+         # map labels (numbers) to character string
+         texts = self.decoder_output_to_text(decoded, num_batch_elements)
+
+         # feed RNN output and recognized text into CTC loss to compute labeling probability
+         probs = None
+         if calc_probability:
+             sparse = self.to_sparse(batch.gt_texts) if probability_of_gt else self.to_sparse(texts)
+             ctc_input = eval_res[1]
+             eval_list = self.loss_per_element
+             feed_dict = {self.saved_ctc_input: ctc_input, self.gt_texts: sparse,
+                          self.seq_len: [max_text_len] * num_batch_elements, self.is_train: False}
+             loss_vals = self.sess.run(eval_list, feed_dict)
+             probs = np.exp(-loss_vals)
+
+         # dump the output of the NN to CSV file(s)
+         if self.dump:
+             self.dump_nn_output(eval_res[1])
+
+         return texts, probs
+
+     def save(self) -> None:
+         """
+         Save model to file.
+         """
+         self.snap_ID += 1
+         self.saver.save(self.sess, '../model/snapshot', global_step=self.snap_ID)
app/preprocessor.py ADDED
@@ -0,0 +1,191 @@
+ import random
+ from typing import Tuple
+
+ import cv2
+ import numpy as np
+
+ from dataloader_iam import Batch
+
+
+ class Preprocessor:
+     def __init__(self,
+                  img_size: Tuple[int, int],
+                  padding: int = 0,
+                  dynamic_width: bool = False,
+                  data_augmentation: bool = False,
+                  line_mode: bool = False) -> None:
+         # dynamic width only supported when no data augmentation happens
+         assert not (dynamic_width and data_augmentation)
+         # when padding is on, we need dynamic width enabled
+         assert not (padding > 0 and not dynamic_width)
+
+         self.img_size = img_size
+         self.padding = padding
+         self.dynamic_width = dynamic_width
+         self.data_augmentation = data_augmentation
+         self.line_mode = line_mode
+
+     @staticmethod
+     def _truncate_label(text: str, max_text_len: int) -> str:
+         """
+         Function ctc_loss can't compute loss if it cannot find a mapping between text label and input
+         labels. Repeat letters cost double because of the blank symbol needing to be inserted.
+         If a too-long label is provided, ctc_loss returns an infinite gradient.
+         """
+         cost = 0
+         for i in range(len(text)):
+             if i != 0 and text[i] == text[i - 1]:
+                 cost += 2
+             else:
+                 cost += 1
+             if cost > max_text_len:
+                 return text[:i]
+         return text
+
+     def _simulate_text_line(self, batch: Batch) -> Batch:
+         """Create image of a text line by pasting multiple word images into an image."""
+
+         default_word_sep = 30
+         default_num_words = 5
+
+         # go over all batch elements
+         res_imgs = []
+         res_gt_texts = []
+         for i in range(batch.batch_size):
+             # number of words to put into current line
+             num_words = random.randint(1, 8) if self.data_augmentation else default_num_words
+
+             # concat ground truth texts
+             curr_gt = ' '.join([batch.gt_texts[(i + j) % batch.batch_size] for j in range(num_words)])
+             res_gt_texts.append(curr_gt)
+
+             # put selected word images into list, compute target image size
+             sel_imgs = []
+             word_seps = [0]
+             h = 0
+             w = 0
+             for j in range(num_words):
+                 curr_sel_img = batch.imgs[(i + j) % batch.batch_size]
+                 curr_word_sep = random.randint(20, 50) if self.data_augmentation else default_word_sep
+                 h = max(h, curr_sel_img.shape[0])
+                 w += curr_sel_img.shape[1]
+                 sel_imgs.append(curr_sel_img)
+                 if j + 1 < num_words:
+                     w += curr_word_sep
+                     word_seps.append(curr_word_sep)
+
+             # put all selected word images into target image
+             target = np.ones([h, w], np.uint8) * 255
+             x = 0
+             for curr_sel_img, curr_word_sep in zip(sel_imgs, word_seps):
+                 x += curr_word_sep
+                 y = (h - curr_sel_img.shape[0]) // 2
+                 target[y:y + curr_sel_img.shape[0]:, x:x + curr_sel_img.shape[1]] = curr_sel_img
+                 x += curr_sel_img.shape[1]
+
+             # put image of line into result
+             res_imgs.append(target)
+
+         return Batch(res_imgs, res_gt_texts, batch.batch_size)
+
+     def process_img(self, img: np.ndarray) -> np.ndarray:
+         """Resize to target size, apply data augmentation."""
+
+         # there are damaged files in IAM dataset - just use black image instead
+         if img is None:
+             img = np.zeros(self.img_size[::-1])
+
+         # data augmentation
+         img = img.astype(float)
+         if self.data_augmentation:
+             # photometric data augmentation
+             if random.random() < 0.25:
+                 def rand_odd():
+                     return random.randint(1, 3) * 2 + 1
+                 img = cv2.GaussianBlur(img, (rand_odd(), rand_odd()), 0)
+             if random.random() < 0.25:
+                 img = cv2.dilate(img, np.ones((3, 3)))
+             if random.random() < 0.25:
+                 img = cv2.erode(img, np.ones((3, 3)))
+
+             # geometric data augmentation
+             wt, ht = self.img_size
+             h, w = img.shape
+             f = min(wt / w, ht / h)
+             fx = f * np.random.uniform(0.75, 1.05)
+             fy = f * np.random.uniform(0.75, 1.05)
+
+             # random position around center
+             txc = (wt - w * fx) / 2
+             tyc = (ht - h * fy) / 2
+             freedom_x = max((wt - fx * w) / 2, 0)
+             freedom_y = max((ht - fy * h) / 2, 0)
+             tx = txc + np.random.uniform(-freedom_x, freedom_x)
+             ty = tyc + np.random.uniform(-freedom_y, freedom_y)
+
+             # map image into target image
+             M = np.float32([[fx, 0, tx], [0, fy, ty]])
+             target = np.ones(self.img_size[::-1]) * 255
+             img = cv2.warpAffine(img, M, dsize=self.img_size, dst=target, borderMode=cv2.BORDER_TRANSPARENT)
+
+             # photometric data augmentation
+             if random.random() < 0.5:
+                 img = img * (0.25 + random.random() * 0.75)
+             if random.random() < 0.25:
+                 img = np.clip(img + (np.random.random(img.shape) - 0.5) * random.randint(1, 25), 0, 255)
+             if random.random() < 0.1:
+                 img = 255 - img
+
+         # no data augmentation
+         else:
+             if self.dynamic_width:
+                 ht = self.img_size[1]
+                 h, w = img.shape
+                 f = ht / h
+                 wt = int(f * w + self.padding)
+                 wt = wt + (4 - wt) % 4
+                 tx = (wt - w * f) / 2
+                 ty = 0
+             else:
+                 wt, ht = self.img_size
+                 h, w = img.shape
+                 f = min(wt / w, ht / h)
+                 tx = (wt - w * f) / 2
+                 ty = (ht - h * f) / 2
+
+             # map image into target image
+             M = np.float32([[f, 0, tx], [0, f, ty]])
+             target = np.ones([ht, wt]) * 255
+             img = cv2.warpAffine(img, M, dsize=(wt, ht), dst=target, borderMode=cv2.BORDER_TRANSPARENT)
+
+         # transpose for TF
+         img = cv2.transpose(img)
+
+         # convert to range [-1, 1]
+         img = img / 255 - 0.5
+         return img
+
+     def process_batch(self, batch: Batch) -> Batch:
+         if self.line_mode:
+             batch = self._simulate_text_line(batch)
+
+         res_imgs = [self.process_img(img) for img in batch.imgs]
+         max_text_len = res_imgs[0].shape[0] // 4
+         res_gt_texts = [self._truncate_label(gt_text, max_text_len) for gt_text in batch.gt_texts]
+         return Batch(res_imgs, res_gt_texts, batch.batch_size)
+
+
+ def main():
+     import matplotlib.pyplot as plt
+
+     img = cv2.imread('../data/test.png', cv2.IMREAD_GRAYSCALE)
+     img_aug = Preprocessor((256, 32), data_augmentation=True).process_img(img)
+     plt.subplot(121)
+     plt.imshow(img, cmap='gray')
+     plt.subplot(122)
+     plt.imshow(cv2.transpose(img_aug) + 0.5, cmap='gray', vmin=0, vmax=1)
+     plt.show()
+
+
+ if __name__ == '__main__':
+     main()
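
Preprocessor.process_img takes a grayscale numpy array and returns a transposed image normalised to the range [-1, 1]; with dynamic_width=True it keeps the aspect ratio and pads the width, which is the configuration the apps above use at inference time. A small usage sketch (word.png is a placeholder input file):

    import cv2
    from preprocessor import Preprocessor

    img = cv2.imread('word.png', cv2.IMREAD_GRAYSCALE)  # placeholder image path
    preprocessor = Preprocessor((128, 32), dynamic_width=True, padding=16)
    processed = preprocessor.process_img(img)
    print(processed.shape, processed.min(), processed.max())
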
app/runner.py ADDED
@@ -0,0 +1,5 @@
+ import os
+ import subprocess
+
+ fileDir = os.path.dirname(os.path.realpath(__file__))
+ subprocess.run(["streamlit", "run", "webapp.py"], cwd=fileDir)
app/simple.py ADDED
@@ -0,0 +1,64 @@
+ import cv2
+ import numpy as np
+ from pathlib import Path
+ from model import Model, DecoderType
+ from preprocessor import Preprocessor
+ from dataloader_iam import Batch
+
+ import tensorflow as tf
+
+ def get_img_size(line_mode: bool = False) -> tuple[int, int]:
+     """
+     Auxiliary method that sets the height and width.
+     Height is fixed while width is set according to the Model used.
+     """
+     if line_mode:
+         return 256, get_img_height()
+     return 128, get_img_height()
+
+ def get_img_height() -> int:
+     """
+     Auxiliary method that sets the fixed height for the Neural Network.
+     """
+     return 32
+
+ def infer(line_mode: bool, model: Model, fn_img: str) -> tuple:
+     """
+     Auxiliary method that does inference using the pretrained models:
+     Recognizes text in an image given its path.
+     """
+     img = cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE)
+     assert img is not None
+
+     preprocessor = Preprocessor(get_img_size(line_mode), dynamic_width=True, padding=16)
+     img = preprocessor.process_img(img)
+
+     batch = Batch([img], None, 1)
+     recognized, probability = model.infer_batch(batch, True)
+     return recognized, probability
+
+ def main(image_path: str, model_path: str, decoder_type: DecoderType):
+     """
+     Main function to load the model, perform inference on the input image,
+     and print the result.
+     """
+     # Load the model
+     char_list_path = model_path + "/charList.txt"
+     model = Model(list(open(char_list_path).read()), model_path, decoder_type, must_restore=True)
+
+     # Perform inference
+     recognized, probability = infer(model_path.endswith('line-model'), model, image_path)
+
+     # Print the results
+     print("Recognized Text:", recognized[0])
+     print("Probability:", probability[0])
+
+ if __name__ == "__main__":
+     # Example usage
+     # Define the image path, model directory, and decoder type here
+     image_path = 'word.png'  # Update this path
+     model_path = '../model/word-model'  # or '../model/line-model' depending on your model
+     decoder_type = DecoderType.BestPath  # Change as needed: BestPath, BeamSearch, WordBeamSearch
+
+     # Call the main function with the specified parameters
+     main(image_path, model_path, decoder_type)
app/userInput.png ADDED
app/webapp.py ADDED
@@ -0,0 +1,132 @@
+ import os
+ import cv2
+ import numpy as np
+ from PIL import Image
+ from path import Path
+ import streamlit as st
+ from typing import Tuple
+ from dataloader_iam import Batch
+ from model import Model, DecoderType
+ from preprocessor import Preprocessor
+ from streamlit_drawable_canvas import st_canvas
+
+
+ def get_img_size(line_mode: bool = False) -> Tuple[int, int]:
+     """
+     Auxiliary method that sets the height and width.
+     Height is fixed while width is set according to the Model used.
+     """
+     if line_mode:
+         return 256, get_img_height()
+     return 128, get_img_height()
+
+ def get_img_height() -> int:
+     """
+     Auxiliary method that sets the height, which is fixed for the Neural Network.
+     """
+     return 32
+
+ def infer(line_mode: bool, model: Model, fn_img: Path) -> list:
+     """
+     Auxiliary method that does inference using the pretrained models:
+     Recognizes text in an image given its path.
+     """
+     img = cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE)
+     assert img is not None
+
+     preprocessor = Preprocessor(get_img_size(line_mode), dynamic_width=True, padding=16)
+     img = preprocessor.process_img(img)
+
+     batch = Batch([img], None, 1)
+     recognized, probability = model.infer_batch(batch, True)
+     return [recognized, probability]
+
+ def main():
+
+     # Website properties
+     st.set_page_config(
+         page_title="HTR App",
+         page_icon=":pencil:",
+         layout="centered",
+         initial_sidebar_state="auto",
+     )
+
+     st.title('HTR Simple Application')
+
+     st.markdown("""
+     Streamlit Web Interface for Handwritten Text Recognition (HTR), implemented with TensorFlow and trained on the IAM off-line HTR dataset. The model takes images of single words or text lines (multiple words) as input and outputs the recognized text.
+     """, unsafe_allow_html=True)
+
+     st.markdown("""
+     Predictions can be made using one of two models:
+     - [Model 1](https://www.dropbox.com/s/mya8hw6jyzqm0a3/word-model.zip?dl=1) (Trained on Single Word Images)
+     - [Model 2](https://www.dropbox.com/s/7xwkcilho10rthn/line-model.zip?dl=1) (Trained on Text Line Images)
+     """, unsafe_allow_html=True)
+
+     st.subheader('Select a Model, Choose the Arguments and Draw in the box below or Upload an Image to obtain a prediction.')
+
+     # Selectors for the model and decoder
+     modelSelect = st.selectbox("Select a Model", ['Single_Model', 'Line_Model'])
+
+     decoderSelect = st.selectbox("Select a Decoder", ['Bestpath', 'Beamsearch', 'Wordbeamsearch'])
+
+     # Mappings (dictionaries) for the model and decoder. Assigns the directory or the DecoderType of the selected option.
+     modelMapping = {
+         "Single_Model": '../model/word-model',
+         "Line_Model": '../model/line-model'
+     }
+
+     decoderMapping = {
+         'Bestpath': DecoderType.BestPath,
+         'Beamsearch': DecoderType.BeamSearch,
+         'Wordbeamsearch': DecoderType.WordBeamSearch
+     }
+
+     # Slider for pencil width
+     strokeWidth = st.slider("Stroke Width: ", 1, 25, 6)
+
+     # Canvas/Text Box for user input. Background color must be white (#FFFFFF) or else text will not be properly recognised.
+     inputDrawn = st_canvas(
+         fill_color="rgba(255, 165, 0, 0.3)",
+         stroke_width=strokeWidth,
+         update_streamlit=True,
+         height=200,
+         width=400,
+         drawing_mode='freedraw',
+         key="canvas",
+         background_color='#FFFFFF'
+     )
+
+     # Buffer for user input (images uploaded from the user's device)
+     inputBuffer = st.file_uploader("Upload an Image", type=["png"])
+
+     # Infer Button
+     inferBool = st.button("Recognize Word")
+
+     # We start inferring once we have the user input and the Infer button is pressed.
+     if ((inputDrawn.image_data is not None or inputBuffer is not None) and inferBool == True):
+
+         # We turn the input into a numpy array
+         if inputDrawn.image_data is not None:
+             inputArray = np.array(inputDrawn.image_data)
+
+         if inputBuffer is not None:
+             inputBufferImage = Image.open(inputBuffer)
+             inputArray = np.array(inputBufferImage)
+
+         # We turn this array into a .png format and save it.
+         inputImage = Image.fromarray(inputArray.astype('uint8'), 'RGBA')
+         inputImage.save('userInput.png')
+         # We obtain the model directory and the decoder type from their mapping
+         modelDir = modelMapping[modelSelect]
+         decoderType = decoderMapping[decoderSelect]
+
+         # Finally, we call the model with this image as attribute and display the Best Candidate and its probability on the Interface
+         model = Model(list(open(modelDir + "/charList.txt").read()), modelDir, decoderType, must_restore=True)
+         inferedText = infer(modelDir == '../model/line-model', model, 'userInput.png')
+
+         st.write("**Best Candidate: **", inferedText[0][0])
+         st.write("**Probability: **", str(inferedText[1][0] * 100) + "%")
+
+ if __name__ == "__main__":
+     main()
app/word.png ADDED