NadavShaked
commited on
Commit
โข
91da6cc
1
Parent(s):
c55ba1a
Upload 7 files
Browse files- handler.py +126 -0
- main.py +596 -0
- src/models.py +74 -0
- src/models_utils.py +561 -0
- src/plot_helpers.py +58 -0
- src/running_params.py +3 -0
- src/utiles_data.py +737 -0
handler.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Any
|
2 |
+
from transformers import AutoConfig, AutoTokenizer
|
3 |
+
from src.models import DNikudModel, ModelConfig
|
4 |
+
from src.running_params import BATCH_SIZE, MAX_LENGTH_SEN
|
5 |
+
from src.utiles_data import Nikud, NikudDataset
|
6 |
+
from src.models_utils import predict_single, predict
|
7 |
+
import torch
|
8 |
+
import os
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
|
12 |
+
class EndpointHandler:
|
13 |
+
def __init__(self, path=""):
|
14 |
+
self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
15 |
+
|
16 |
+
self.tokenizer = AutoTokenizer.from_pretrained("tau/tavbert-he")
|
17 |
+
dir_model_config = os.path.join("models", "config.yml")
|
18 |
+
self.config = ModelConfig.load_from_file(dir_model_config)
|
19 |
+
self.model = DNikudModel(
|
20 |
+
self.config,
|
21 |
+
len(Nikud.label_2_id["nikud"]),
|
22 |
+
len(Nikud.label_2_id["dagesh"]),
|
23 |
+
len(Nikud.label_2_id["sin"]),
|
24 |
+
device=self.DEVICE,
|
25 |
+
).to(self.DEVICE)
|
26 |
+
state_dict_model = self.model.state_dict()
|
27 |
+
state_dict_model.update(torch.load("./models/Dnikud_best_model.pth"))
|
28 |
+
self.model.load_state_dict(state_dict_model)
|
29 |
+
self.max_length = MAX_LENGTH_SEN
|
30 |
+
|
31 |
+
def back_2_text(self, labels, text):
|
32 |
+
nikud = Nikud()
|
33 |
+
new_line = ""
|
34 |
+
|
35 |
+
for indx_char, c in enumerate(text):
|
36 |
+
new_line += (
|
37 |
+
c
|
38 |
+
+ nikud.id_2_char(labels[indx_char][1][1], "dagesh")
|
39 |
+
+ nikud.id_2_char(labels[indx_char][1][2], "sin")
|
40 |
+
+ nikud.id_2_char(labels[indx_char][1][0], "nikud")
|
41 |
+
)
|
42 |
+
print(indx_char, c)
|
43 |
+
print(labels)
|
44 |
+
return new_line
|
45 |
+
|
46 |
+
def prepare_data(self, data, name="train"):
|
47 |
+
print("Data = ", data)
|
48 |
+
dataset = []
|
49 |
+
for index, (sentence, label) in tqdm(
|
50 |
+
enumerate(data), desc=f"Prepare data {name}"
|
51 |
+
):
|
52 |
+
encoded_sequence = self.tokenizer.encode_plus(
|
53 |
+
sentence,
|
54 |
+
add_special_tokens=True,
|
55 |
+
max_length=self.max_length,
|
56 |
+
padding="max_length",
|
57 |
+
truncation=True,
|
58 |
+
return_attention_mask=True,
|
59 |
+
return_tensors="pt",
|
60 |
+
)
|
61 |
+
label_lists = [
|
62 |
+
[letter.nikud, letter.dagesh, letter.sin] for letter in label
|
63 |
+
]
|
64 |
+
label = torch.tensor(
|
65 |
+
[
|
66 |
+
[
|
67 |
+
Nikud.PAD_OR_IRRELEVANT,
|
68 |
+
Nikud.PAD_OR_IRRELEVANT,
|
69 |
+
Nikud.PAD_OR_IRRELEVANT,
|
70 |
+
]
|
71 |
+
]
|
72 |
+
+ label_lists[: (self.max_length - 1)]
|
73 |
+
+ [
|
74 |
+
[
|
75 |
+
Nikud.PAD_OR_IRRELEVANT,
|
76 |
+
Nikud.PAD_OR_IRRELEVANT,
|
77 |
+
Nikud.PAD_OR_IRRELEVANT,
|
78 |
+
]
|
79 |
+
for i in range(self.max_length - len(label) - 1)
|
80 |
+
]
|
81 |
+
)
|
82 |
+
|
83 |
+
dataset.append(
|
84 |
+
(
|
85 |
+
encoded_sequence["input_ids"][0],
|
86 |
+
encoded_sequence["attention_mask"][0],
|
87 |
+
label,
|
88 |
+
)
|
89 |
+
)
|
90 |
+
|
91 |
+
self.prepered_data = dataset
|
92 |
+
|
93 |
+
def predict_single_text(
|
94 |
+
self,
|
95 |
+
text,
|
96 |
+
):
|
97 |
+
dataset = NikudDataset(tokenizer=self.tokenizer, max_length=MAX_LENGTH_SEN)
|
98 |
+
data, orig_data = dataset.read_single_text(text)
|
99 |
+
print("data", data, len(data))
|
100 |
+
dataset.prepare_data(name="inference")
|
101 |
+
mtb_prediction_dl = torch.utils.data.DataLoader(
|
102 |
+
dataset.prepered_data, batch_size=BATCH_SIZE
|
103 |
+
)
|
104 |
+
# print("dataset", dataset, len(dataset))
|
105 |
+
# data = self.tokenizer(text, return_tensors="pt")
|
106 |
+
all_labels = predict(self.model, mtb_prediction_dl, self.DEVICE)
|
107 |
+
text_data_with_labels = dataset.back_2_text(labels=all_labels)
|
108 |
+
# all_labels = predict_single(self.model, dataset, self.DEVICE)
|
109 |
+
return text_data_with_labels
|
110 |
+
|
111 |
+
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
112 |
+
"""
|
113 |
+
data args:
|
114 |
+
"""
|
115 |
+
|
116 |
+
# get inputs
|
117 |
+
inputs = data.pop("text", data)
|
118 |
+
|
119 |
+
# run normal prediction
|
120 |
+
prediction = self.predict_single_text(inputs)
|
121 |
+
|
122 |
+
# result = []
|
123 |
+
# for pred in prediction:
|
124 |
+
# result.append(self.back_2_text(pred, inputs))
|
125 |
+
# result = self.back_2_text(prediction, inputs)
|
126 |
+
return prediction
|
main.py
ADDED
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# general
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
from datetime import datetime
|
6 |
+
import logging
|
7 |
+
from logging.handlers import RotatingFileHandler
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
# ML
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from transformers import AutoConfig, AutoTokenizer
|
14 |
+
|
15 |
+
# DL
|
16 |
+
from src.models import DNikudModel, ModelConfig
|
17 |
+
from src.models_utils import training, evaluate, predict
|
18 |
+
from src.plot_helpers import (
|
19 |
+
generate_plot_by_nikud_dagesh_sin_dict,
|
20 |
+
generate_word_and_letter_accuracy_plot,
|
21 |
+
)
|
22 |
+
from src.running_params import BATCH_SIZE, MAX_LENGTH_SEN
|
23 |
+
from src.utiles_data import (
|
24 |
+
NikudDataset,
|
25 |
+
Nikud,
|
26 |
+
create_missing_folders,
|
27 |
+
extract_text_to_compare_nakdimon,
|
28 |
+
)
|
29 |
+
|
30 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
31 |
+
assert DEVICE == "cuda"
|
32 |
+
|
33 |
+
|
34 |
+
def get_logger(
|
35 |
+
log_level, name_func, date_time=datetime.now().strftime("%d_%m_%y__%H_%M")
|
36 |
+
):
|
37 |
+
log_location = os.path.join(
|
38 |
+
os.path.join(Path(__file__).parent, "logging"),
|
39 |
+
f"log_model_{name_func}_{date_time}",
|
40 |
+
)
|
41 |
+
create_missing_folders(log_location)
|
42 |
+
|
43 |
+
log_format = "%(asctime)s %(levelname)-8s Thread_%(thread)-6d ::: %(funcName)s(%(lineno)d) ::: %(message)s"
|
44 |
+
logger = logging.getLogger("algo")
|
45 |
+
logger.setLevel(getattr(logging, log_level))
|
46 |
+
cnsl_log_formatter = logging.Formatter(log_format)
|
47 |
+
cnsl_handler = logging.StreamHandler()
|
48 |
+
cnsl_handler.setFormatter(cnsl_log_formatter)
|
49 |
+
cnsl_handler.setLevel(log_level)
|
50 |
+
logger.addHandler(cnsl_handler)
|
51 |
+
|
52 |
+
create_missing_folders(log_location)
|
53 |
+
|
54 |
+
file_location = os.path.join(log_location, "Diacritization_Model_DEBUG.log")
|
55 |
+
file_log_formatter = logging.Formatter(log_format)
|
56 |
+
|
57 |
+
SINGLE_LOG_SIZE = 2 * 1024 * 1024 # in Bytes
|
58 |
+
MAX_LOG_FILES = 20
|
59 |
+
file_handler = RotatingFileHandler(
|
60 |
+
file_location, mode="a", maxBytes=SINGLE_LOG_SIZE, backupCount=MAX_LOG_FILES
|
61 |
+
)
|
62 |
+
file_handler.setFormatter(file_log_formatter)
|
63 |
+
file_handler.setLevel(log_level)
|
64 |
+
logger.addHandler(file_handler)
|
65 |
+
|
66 |
+
return logger
|
67 |
+
|
68 |
+
|
69 |
+
def evaluate_text(
|
70 |
+
path,
|
71 |
+
dnikud_model,
|
72 |
+
tokenizer_tavbert,
|
73 |
+
logger,
|
74 |
+
plots_folder=None,
|
75 |
+
batch_size=BATCH_SIZE,
|
76 |
+
):
|
77 |
+
path_name = os.path.basename(path)
|
78 |
+
|
79 |
+
msg = f"evaluate text: {path_name} on D-nikud Model"
|
80 |
+
logger.debug(msg)
|
81 |
+
|
82 |
+
if os.path.isfile(path):
|
83 |
+
dataset = NikudDataset(
|
84 |
+
tokenizer_tavbert, file=path, logger=logger, max_length=MAX_LENGTH_SEN
|
85 |
+
)
|
86 |
+
elif os.path.isdir(path):
|
87 |
+
dataset = NikudDataset(
|
88 |
+
tokenizer_tavbert, folder=path, logger=logger, max_length=MAX_LENGTH_SEN
|
89 |
+
)
|
90 |
+
else:
|
91 |
+
raise Exception("input path doesnt exist")
|
92 |
+
|
93 |
+
dataset.prepare_data(name="evaluate")
|
94 |
+
mtb_dl = torch.utils.data.DataLoader(dataset.prepered_data, batch_size=batch_size)
|
95 |
+
|
96 |
+
word_level_correct, letter_level_correct_dev = evaluate(
|
97 |
+
dnikud_model, mtb_dl, plots_folder, device=DEVICE
|
98 |
+
)
|
99 |
+
|
100 |
+
msg = (
|
101 |
+
f"Dnikud Model\n{path_name} evaluate\nLetter level accuracy:{letter_level_correct_dev}\n"
|
102 |
+
f"Word level accuracy: {word_level_correct}"
|
103 |
+
)
|
104 |
+
logger.debug(msg)
|
105 |
+
|
106 |
+
|
107 |
+
def predict_text(
|
108 |
+
text_file,
|
109 |
+
tokenizer_tavbert,
|
110 |
+
output_file,
|
111 |
+
logger,
|
112 |
+
dnikud_model,
|
113 |
+
compare_nakdimon=False,
|
114 |
+
):
|
115 |
+
dataset = NikudDataset(
|
116 |
+
tokenizer_tavbert, file=text_file, logger=logger, max_length=MAX_LENGTH_SEN
|
117 |
+
)
|
118 |
+
|
119 |
+
dataset.prepare_data(name="prediction")
|
120 |
+
mtb_prediction_dl = torch.utils.data.DataLoader(
|
121 |
+
dataset.prepered_data, batch_size=BATCH_SIZE
|
122 |
+
)
|
123 |
+
all_labels = predict(dnikud_model, mtb_prediction_dl, DEVICE)
|
124 |
+
text_data_with_labels = dataset.back_2_text(labels=all_labels)
|
125 |
+
|
126 |
+
if output_file is None:
|
127 |
+
for line in text_data_with_labels:
|
128 |
+
print(line)
|
129 |
+
else:
|
130 |
+
with open(output_file, "w", encoding="utf-8") as f:
|
131 |
+
if compare_nakdimon:
|
132 |
+
f.write(extract_text_to_compare_nakdimon(text_data_with_labels))
|
133 |
+
else:
|
134 |
+
f.write(text_data_with_labels)
|
135 |
+
|
136 |
+
|
137 |
+
def predict_folder(
|
138 |
+
folder,
|
139 |
+
output_folder,
|
140 |
+
logger,
|
141 |
+
tokenizer_tavbert,
|
142 |
+
dnikud_model,
|
143 |
+
compare_nakdimon=False,
|
144 |
+
):
|
145 |
+
create_missing_folders(output_folder)
|
146 |
+
|
147 |
+
for filename in os.listdir(folder):
|
148 |
+
file_path = os.path.join(folder, filename)
|
149 |
+
|
150 |
+
if filename.lower().endswith(".txt") and os.path.isfile(file_path):
|
151 |
+
output_file = os.path.join(output_folder, filename)
|
152 |
+
predict_text(
|
153 |
+
file_path,
|
154 |
+
output_file=output_file,
|
155 |
+
logger=logger,
|
156 |
+
tokenizer_tavbert=tokenizer_tavbert,
|
157 |
+
dnikud_model=dnikud_model,
|
158 |
+
compare_nakdimon=compare_nakdimon,
|
159 |
+
)
|
160 |
+
elif (
|
161 |
+
os.path.isdir(file_path) and filename != ".git" and filename != "README.md"
|
162 |
+
):
|
163 |
+
sub_folder = file_path
|
164 |
+
sub_folder_output = os.path.join(output_folder, filename)
|
165 |
+
predict_folder(
|
166 |
+
sub_folder,
|
167 |
+
sub_folder_output,
|
168 |
+
logger,
|
169 |
+
tokenizer_tavbert,
|
170 |
+
dnikud_model,
|
171 |
+
compare_nakdimon=compare_nakdimon,
|
172 |
+
)
|
173 |
+
|
174 |
+
|
175 |
+
def update_compare_folder(folder, output_folder):
|
176 |
+
create_missing_folders(output_folder)
|
177 |
+
|
178 |
+
for filename in os.listdir(folder):
|
179 |
+
file_path = os.path.join(folder, filename)
|
180 |
+
|
181 |
+
if filename.lower().endswith(".txt") and os.path.isfile(file_path):
|
182 |
+
output_file = os.path.join(output_folder, filename)
|
183 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
184 |
+
text_data_with_labels = f.read()
|
185 |
+
with open(output_file, "w", encoding="utf-8") as f:
|
186 |
+
f.write(extract_text_to_compare_nakdimon(text_data_with_labels))
|
187 |
+
elif os.path.isdir(file_path) and filename != ".git":
|
188 |
+
sub_folder = file_path
|
189 |
+
sub_folder_output = os.path.join(output_folder, filename)
|
190 |
+
update_compare_folder(sub_folder, sub_folder_output)
|
191 |
+
|
192 |
+
|
193 |
+
def check_files_excepted(folder):
|
194 |
+
for filename in os.listdir(folder):
|
195 |
+
file_path = os.path.join(folder, filename)
|
196 |
+
|
197 |
+
if filename.lower().endswith(".txt") and os.path.isfile(file_path):
|
198 |
+
try:
|
199 |
+
x = NikudDataset(None, file=file_path)
|
200 |
+
except:
|
201 |
+
print(f"failed in file: {filename}")
|
202 |
+
elif os.path.isdir(file_path) and filename != ".git":
|
203 |
+
check_files_excepted(file_path)
|
204 |
+
|
205 |
+
|
206 |
+
def do_predict(
|
207 |
+
input_path, output_path, tokenizer_tavbert, logger, dnikud_model, compare_nakdimon
|
208 |
+
):
|
209 |
+
if os.path.isdir(input_path):
|
210 |
+
predict_folder(
|
211 |
+
input_path,
|
212 |
+
output_path,
|
213 |
+
logger,
|
214 |
+
tokenizer_tavbert,
|
215 |
+
dnikud_model,
|
216 |
+
compare_nakdimon=compare_nakdimon,
|
217 |
+
)
|
218 |
+
elif os.path.isfile(input_path):
|
219 |
+
predict_text(
|
220 |
+
input_path,
|
221 |
+
output_file=output_path,
|
222 |
+
logger=logger,
|
223 |
+
tokenizer_tavbert=tokenizer_tavbert,
|
224 |
+
dnikud_model=dnikud_model,
|
225 |
+
compare_nakdimon=compare_nakdimon,
|
226 |
+
)
|
227 |
+
else:
|
228 |
+
raise Exception("Input file not exist")
|
229 |
+
|
230 |
+
|
231 |
+
def evaluate_folder(folder_path, logger, dnikud_model, tokenizer_tavbert, plots_folder):
|
232 |
+
msg = f"evaluate sub folder: {folder_path}"
|
233 |
+
logger.info(msg)
|
234 |
+
|
235 |
+
evaluate_text(
|
236 |
+
folder_path,
|
237 |
+
dnikud_model=dnikud_model,
|
238 |
+
tokenizer_tavbert=tokenizer_tavbert,
|
239 |
+
logger=logger,
|
240 |
+
plots_folder=plots_folder,
|
241 |
+
batch_size=BATCH_SIZE,
|
242 |
+
)
|
243 |
+
|
244 |
+
msg = f"\n***************************************\n"
|
245 |
+
logger.info(msg)
|
246 |
+
|
247 |
+
for sub_folder_name in os.listdir(folder_path):
|
248 |
+
sub_folder_path = os.path.join(folder_path, sub_folder_name)
|
249 |
+
|
250 |
+
if (
|
251 |
+
not os.path.isdir(sub_folder_path)
|
252 |
+
or sub_folder_path == ".git"
|
253 |
+
or "not_use" in sub_folder_path
|
254 |
+
or "NakdanResults" in sub_folder_path
|
255 |
+
):
|
256 |
+
continue
|
257 |
+
|
258 |
+
evaluate_folder(
|
259 |
+
sub_folder_path, logger, dnikud_model, tokenizer_tavbert, plots_folder
|
260 |
+
)
|
261 |
+
|
262 |
+
|
263 |
+
def do_evaluate(
|
264 |
+
input_path,
|
265 |
+
logger,
|
266 |
+
dnikud_model,
|
267 |
+
tokenizer_tavbert,
|
268 |
+
plots_folder,
|
269 |
+
eval_sub_folders=False,
|
270 |
+
):
|
271 |
+
msg = f"evaluate all_data: {input_path}"
|
272 |
+
logger.info(msg)
|
273 |
+
|
274 |
+
evaluate_text(
|
275 |
+
input_path,
|
276 |
+
dnikud_model=dnikud_model,
|
277 |
+
tokenizer_tavbert=tokenizer_tavbert,
|
278 |
+
logger=logger,
|
279 |
+
plots_folder=plots_folder,
|
280 |
+
batch_size=BATCH_SIZE,
|
281 |
+
)
|
282 |
+
|
283 |
+
msg = f"\n\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n"
|
284 |
+
logger.info(msg)
|
285 |
+
|
286 |
+
if eval_sub_folders:
|
287 |
+
for sub_folder_name in os.listdir(input_path):
|
288 |
+
sub_folder_path = os.path.join(input_path, sub_folder_name)
|
289 |
+
|
290 |
+
if (
|
291 |
+
not os.path.isdir(sub_folder_path)
|
292 |
+
or sub_folder_path == ".git"
|
293 |
+
or "not_use" in sub_folder_path
|
294 |
+
or "NakdanResults" in sub_folder_path
|
295 |
+
):
|
296 |
+
continue
|
297 |
+
|
298 |
+
evaluate_folder(
|
299 |
+
sub_folder_path, logger, dnikud_model, tokenizer_tavbert, plots_folder
|
300 |
+
)
|
301 |
+
|
302 |
+
|
303 |
+
def do_train(
|
304 |
+
logger,
|
305 |
+
plots_folder,
|
306 |
+
dir_model_config,
|
307 |
+
tokenizer_tavbert,
|
308 |
+
dnikud_model,
|
309 |
+
output_trained_model_dir,
|
310 |
+
data_folder,
|
311 |
+
n_epochs,
|
312 |
+
checkpoints_frequency,
|
313 |
+
learning_rate,
|
314 |
+
batch_size,
|
315 |
+
):
|
316 |
+
msg = "Loading data..."
|
317 |
+
logger.debug(msg)
|
318 |
+
|
319 |
+
dataset_train = NikudDataset(
|
320 |
+
tokenizer_tavbert,
|
321 |
+
folder=os.path.join(data_folder, "train"),
|
322 |
+
logger=logger,
|
323 |
+
max_length=MAX_LENGTH_SEN,
|
324 |
+
is_train=True,
|
325 |
+
)
|
326 |
+
dataset_dev = NikudDataset(
|
327 |
+
tokenizer=tokenizer_tavbert,
|
328 |
+
folder=os.path.join(data_folder, "dev"),
|
329 |
+
logger=logger,
|
330 |
+
max_length=dataset_train.max_length,
|
331 |
+
is_train=True,
|
332 |
+
)
|
333 |
+
dataset_test = NikudDataset(
|
334 |
+
tokenizer=tokenizer_tavbert,
|
335 |
+
folder=os.path.join(data_folder, "test"),
|
336 |
+
logger=logger,
|
337 |
+
max_length=dataset_train.max_length,
|
338 |
+
is_train=True,
|
339 |
+
)
|
340 |
+
|
341 |
+
dataset_train.show_data_labels(plots_folder=plots_folder)
|
342 |
+
|
343 |
+
msg = f"Max length of data: {dataset_train.max_length}"
|
344 |
+
logger.debug(msg)
|
345 |
+
|
346 |
+
msg = (
|
347 |
+
f"Num rows in train data: {len(dataset_train.data)}, "
|
348 |
+
f"Num rows in dev data: {len(dataset_dev.data)}, "
|
349 |
+
f"Num rows in test data: {len(dataset_test.data)}"
|
350 |
+
)
|
351 |
+
logger.debug(msg)
|
352 |
+
|
353 |
+
msg = "Loading tokenizer and prepare data..."
|
354 |
+
logger.debug(msg)
|
355 |
+
|
356 |
+
dataset_train.prepare_data(name="train")
|
357 |
+
dataset_dev.prepare_data(name="dev")
|
358 |
+
dataset_test.prepare_data(name="test")
|
359 |
+
|
360 |
+
mtb_train_dl = torch.utils.data.DataLoader(
|
361 |
+
dataset_train.prepered_data, batch_size=batch_size
|
362 |
+
)
|
363 |
+
mtb_dev_dl = torch.utils.data.DataLoader(
|
364 |
+
dataset_dev.prepered_data, batch_size=batch_size
|
365 |
+
)
|
366 |
+
|
367 |
+
if not os.path.isfile(dir_model_config):
|
368 |
+
our_model_config = ModelConfig(dataset_train.max_length)
|
369 |
+
our_model_config.save_to_file(dir_model_config)
|
370 |
+
|
371 |
+
optimizer = torch.optim.Adam(dnikud_model.parameters(), lr=learning_rate)
|
372 |
+
|
373 |
+
msg = "training..."
|
374 |
+
logger.debug(msg)
|
375 |
+
|
376 |
+
criterion_nikud = nn.CrossEntropyLoss(ignore_index=Nikud.PAD_OR_IRRELEVANT).to(
|
377 |
+
DEVICE
|
378 |
+
)
|
379 |
+
criterion_dagesh = nn.CrossEntropyLoss(ignore_index=Nikud.PAD_OR_IRRELEVANT).to(
|
380 |
+
DEVICE
|
381 |
+
)
|
382 |
+
criterion_sin = nn.CrossEntropyLoss(ignore_index=Nikud.PAD_OR_IRRELEVANT).to(DEVICE)
|
383 |
+
|
384 |
+
training_params = {
|
385 |
+
"n_epochs": n_epochs,
|
386 |
+
"checkpoints_frequency": checkpoints_frequency,
|
387 |
+
}
|
388 |
+
(
|
389 |
+
best_model_details,
|
390 |
+
best_accuracy,
|
391 |
+
epochs_loss_train_values,
|
392 |
+
steps_loss_train_values,
|
393 |
+
loss_dev_values,
|
394 |
+
accuracy_dev_values,
|
395 |
+
) = training(
|
396 |
+
dnikud_model,
|
397 |
+
mtb_train_dl,
|
398 |
+
mtb_dev_dl,
|
399 |
+
criterion_nikud,
|
400 |
+
criterion_dagesh,
|
401 |
+
criterion_sin,
|
402 |
+
training_params,
|
403 |
+
logger,
|
404 |
+
output_trained_model_dir,
|
405 |
+
optimizer,
|
406 |
+
device=DEVICE,
|
407 |
+
)
|
408 |
+
|
409 |
+
generate_plot_by_nikud_dagesh_sin_dict(
|
410 |
+
epochs_loss_train_values, "Train epochs loss", "Loss", plots_folder
|
411 |
+
)
|
412 |
+
generate_plot_by_nikud_dagesh_sin_dict(
|
413 |
+
steps_loss_train_values, "Train steps loss", "Loss", plots_folder
|
414 |
+
)
|
415 |
+
generate_plot_by_nikud_dagesh_sin_dict(
|
416 |
+
loss_dev_values, "Dev epochs loss", "Loss", plots_folder
|
417 |
+
)
|
418 |
+
generate_plot_by_nikud_dagesh_sin_dict(
|
419 |
+
accuracy_dev_values, "Dev accuracy", "Accuracy", plots_folder
|
420 |
+
)
|
421 |
+
generate_word_and_letter_accuracy_plot(
|
422 |
+
accuracy_dev_values, "Accuracy", plots_folder
|
423 |
+
)
|
424 |
+
|
425 |
+
msg = "Done"
|
426 |
+
logger.info(msg)
|
427 |
+
|
428 |
+
|
429 |
+
if __name__ == "__main__":
|
430 |
+
tokenizer_tavbert = AutoTokenizer.from_pretrained("tau/tavbert-he")
|
431 |
+
|
432 |
+
parser = argparse.ArgumentParser(
|
433 |
+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
434 |
+
description="""Predict D-nikud""",
|
435 |
+
)
|
436 |
+
parser.add_argument(
|
437 |
+
"-l",
|
438 |
+
"--log",
|
439 |
+
dest="log_level",
|
440 |
+
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
441 |
+
default="DEBUG",
|
442 |
+
help="Set the logging level",
|
443 |
+
)
|
444 |
+
parser.add_argument(
|
445 |
+
"-m",
|
446 |
+
"--output_model_dir",
|
447 |
+
type=str,
|
448 |
+
default="models",
|
449 |
+
help="save directory for model",
|
450 |
+
)
|
451 |
+
subparsers = parser.add_subparsers(
|
452 |
+
help="sub-command help", dest="command", required=True
|
453 |
+
)
|
454 |
+
|
455 |
+
parser_predict = subparsers.add_parser("predict", help="diacritize a text files ")
|
456 |
+
parser_predict.add_argument("input_path", help="input file or folder")
|
457 |
+
parser_predict.add_argument("output_path", help="output file")
|
458 |
+
parser_predict.add_argument(
|
459 |
+
"-ptmp",
|
460 |
+
"--pretrain_model_path",
|
461 |
+
type=str,
|
462 |
+
default=os.path.join(Path(__file__).parent, "models", "Dnikud_best_model.pth"),
|
463 |
+
help="pre-train model path - use only if you want to use trained model weights",
|
464 |
+
)
|
465 |
+
parser_predict.add_argument(
|
466 |
+
"-c",
|
467 |
+
"--compare",
|
468 |
+
dest="compare_nakdimon",
|
469 |
+
default=False,
|
470 |
+
help="predict text for comparing with Nakdimon",
|
471 |
+
)
|
472 |
+
parser_predict.set_defaults(func=do_predict)
|
473 |
+
|
474 |
+
parser_evaluate = subparsers.add_parser("evaluate", help="evaluate D-nikud")
|
475 |
+
parser_evaluate.add_argument("input_path", help="input file or folder")
|
476 |
+
parser_evaluate.add_argument(
|
477 |
+
"-ptmp",
|
478 |
+
"--pretrain_model_path",
|
479 |
+
type=str,
|
480 |
+
default=os.path.join(Path(__file__).parent, "models", "Dnikud_best_model.pth"),
|
481 |
+
help="pre-train model path - use only if you want to use trained model weights",
|
482 |
+
)
|
483 |
+
parser_evaluate.add_argument(
|
484 |
+
"-df",
|
485 |
+
"--plots_folder",
|
486 |
+
dest="plots_folder",
|
487 |
+
default=os.path.join(Path(__file__).parent, "plots"),
|
488 |
+
help="set the debug folder",
|
489 |
+
)
|
490 |
+
parser_evaluate.add_argument(
|
491 |
+
"-es",
|
492 |
+
"--eval_sub_folders",
|
493 |
+
dest="eval_sub_folders",
|
494 |
+
default=False,
|
495 |
+
help="accuracy calculation includes the evaluation of sub-folders "
|
496 |
+
"within the input_path folder, providing independent assessments "
|
497 |
+
"for each subfolder.",
|
498 |
+
)
|
499 |
+
parser_evaluate.set_defaults(func=do_evaluate)
|
500 |
+
|
501 |
+
# train --n_epochs 20
|
502 |
+
|
503 |
+
parser_train = subparsers.add_parser("train", help="train D-nikud")
|
504 |
+
parser_train.add_argument(
|
505 |
+
"-ptmp",
|
506 |
+
"--pretrain_model_path",
|
507 |
+
type=str,
|
508 |
+
default=None,
|
509 |
+
help="pre-train model path - use only if you want to use trained model weights",
|
510 |
+
)
|
511 |
+
parser_train.add_argument(
|
512 |
+
"--learning_rate", type=float, default=0.001, help="Learning rate"
|
513 |
+
)
|
514 |
+
parser_train.add_argument("--batch_size", type=int, default=32, help="batch_size")
|
515 |
+
parser_train.add_argument(
|
516 |
+
"--n_epochs", type=int, default=10, help="number of epochs"
|
517 |
+
)
|
518 |
+
parser_train.add_argument(
|
519 |
+
"--data_folder",
|
520 |
+
dest="data_folder",
|
521 |
+
default=os.path.join(Path(__file__).parent, "data"),
|
522 |
+
help="Set the debug folder",
|
523 |
+
)
|
524 |
+
parser_train.add_argument(
|
525 |
+
"--checkpoints_frequency",
|
526 |
+
type=int,
|
527 |
+
default=1,
|
528 |
+
help="checkpoints frequency for save the model",
|
529 |
+
)
|
530 |
+
parser_train.add_argument(
|
531 |
+
"-df",
|
532 |
+
"--plots_folder",
|
533 |
+
dest="plots_folder",
|
534 |
+
default=os.path.join(Path(__file__).parent, "plots"),
|
535 |
+
help="Set the debug folder",
|
536 |
+
)
|
537 |
+
parser_train.set_defaults(func=do_train)
|
538 |
+
|
539 |
+
args = parser.parse_args()
|
540 |
+
kwargs = vars(args).copy()
|
541 |
+
date_time = datetime.now().strftime("%d_%m_%y__%H_%M")
|
542 |
+
logger = get_logger(kwargs["log_level"], args.command, date_time)
|
543 |
+
|
544 |
+
del kwargs["log_level"]
|
545 |
+
|
546 |
+
kwargs["tokenizer_tavbert"] = tokenizer_tavbert
|
547 |
+
kwargs["logger"] = logger
|
548 |
+
|
549 |
+
msg = "Loading model..."
|
550 |
+
logger.debug(msg)
|
551 |
+
|
552 |
+
if args.command in ["evaluate", "predict"] or (
|
553 |
+
args.command == "train" and args.pretrain_model_path is not None
|
554 |
+
):
|
555 |
+
dir_model_config = os.path.join("models", "config.yml")
|
556 |
+
config = ModelConfig.load_from_file(dir_model_config)
|
557 |
+
|
558 |
+
dnikud_model = DNikudModel(
|
559 |
+
config,
|
560 |
+
len(Nikud.label_2_id["nikud"]),
|
561 |
+
len(Nikud.label_2_id["dagesh"]),
|
562 |
+
len(Nikud.label_2_id["sin"]),
|
563 |
+
device=DEVICE,
|
564 |
+
).to(DEVICE)
|
565 |
+
state_dict_model = dnikud_model.state_dict()
|
566 |
+
state_dict_model.update(torch.load(args.pretrain_model_path))
|
567 |
+
dnikud_model.load_state_dict(state_dict_model)
|
568 |
+
else:
|
569 |
+
base_model_name = "tau/tavbert-he"
|
570 |
+
config = AutoConfig.from_pretrained(base_model_name)
|
571 |
+
dnikud_model = DNikudModel(
|
572 |
+
config,
|
573 |
+
len(Nikud.label_2_id["nikud"]),
|
574 |
+
len(Nikud.label_2_id["dagesh"]),
|
575 |
+
len(Nikud.label_2_id["sin"]),
|
576 |
+
pretrain_model=base_model_name,
|
577 |
+
device=DEVICE,
|
578 |
+
).to(DEVICE)
|
579 |
+
|
580 |
+
if args.command == "train":
|
581 |
+
output_trained_model_dir = os.path.join(
|
582 |
+
kwargs["output_model_dir"], "latest", f"output_models_{date_time}"
|
583 |
+
)
|
584 |
+
create_missing_folders(output_trained_model_dir)
|
585 |
+
dir_model_config = os.path.join(kwargs["output_model_dir"], "config.yml")
|
586 |
+
kwargs["dir_model_config"] = dir_model_config
|
587 |
+
kwargs["output_trained_model_dir"] = output_trained_model_dir
|
588 |
+
del kwargs["pretrain_model_path"]
|
589 |
+
del kwargs["output_model_dir"]
|
590 |
+
kwargs["dnikud_model"] = dnikud_model
|
591 |
+
|
592 |
+
del kwargs["command"]
|
593 |
+
del kwargs["func"]
|
594 |
+
args.func(**kwargs)
|
595 |
+
|
596 |
+
sys.exit(0)
|
src/models.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# general
|
2 |
+
import subprocess
|
3 |
+
import yaml
|
4 |
+
|
5 |
+
# ML
|
6 |
+
import torch.nn as nn
|
7 |
+
from transformers import AutoConfig, RobertaForMaskedLM, PretrainedConfig
|
8 |
+
|
9 |
+
|
10 |
+
class DNikudModel(nn.Module):
|
11 |
+
def __init__(self, config, nikud_size, dagesh_size, sin_size, pretrain_model=None, device='cpu'):
|
12 |
+
super(DNikudModel, self).__init__()
|
13 |
+
|
14 |
+
if pretrain_model is not None:
|
15 |
+
model_base = RobertaForMaskedLM.from_pretrained(pretrain_model).to(device)
|
16 |
+
else:
|
17 |
+
model_base = RobertaForMaskedLM(config=config).to(device)
|
18 |
+
|
19 |
+
self.model = model_base.roberta
|
20 |
+
for name, param in self.model.named_parameters():
|
21 |
+
param.requires_grad = False
|
22 |
+
|
23 |
+
self.lstm1 = nn.LSTM(config.hidden_size, config.hidden_size, bidirectional=True, dropout=0.1, batch_first=True)
|
24 |
+
self.lstm2 = nn.LSTM(2 * config.hidden_size, config.hidden_size, bidirectional=True, dropout=0.1, batch_first=True)
|
25 |
+
self.dense = nn.Linear(2 * config.hidden_size, config.hidden_size)
|
26 |
+
self.out_n = nn.Linear(config.hidden_size, nikud_size)
|
27 |
+
self.out_d = nn.Linear(config.hidden_size, dagesh_size)
|
28 |
+
self.out_s = nn.Linear(config.hidden_size, sin_size)
|
29 |
+
|
30 |
+
def forward(self, input_ids, attention_mask):
|
31 |
+
last_hidden_state = self.model(input_ids, attention_mask=attention_mask).last_hidden_state
|
32 |
+
lstm1, _ = self.lstm1(last_hidden_state)
|
33 |
+
lstm2, _ = self.lstm2(lstm1)
|
34 |
+
dense = self.dense(lstm2)
|
35 |
+
|
36 |
+
nikud = self.out_n(dense)
|
37 |
+
dagesh = self.out_d(dense)
|
38 |
+
sin = self.out_s(dense)
|
39 |
+
|
40 |
+
return nikud, dagesh, sin
|
41 |
+
|
42 |
+
|
43 |
+
def get_git_commit_hash():
|
44 |
+
try:
|
45 |
+
commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
|
46 |
+
return commit_hash
|
47 |
+
except subprocess.CalledProcessError:
|
48 |
+
# This will be raised if you're not in a Git repository
|
49 |
+
print("Not inside a Git repository!")
|
50 |
+
return None
|
51 |
+
|
52 |
+
|
53 |
+
class ModelConfig(PretrainedConfig):
|
54 |
+
def __init__(self, max_length=None, dict=None):
|
55 |
+
super(ModelConfig, self).__init__()
|
56 |
+
if dict is None:
|
57 |
+
self.__dict__.update(AutoConfig.from_pretrained("tau/tavbert-he").__dict__)
|
58 |
+
self.max_length = max_length
|
59 |
+
self._commit_hash = get_git_commit_hash()
|
60 |
+
else:
|
61 |
+
self.__dict__.update(dict)
|
62 |
+
|
63 |
+
def print(self):
|
64 |
+
print(self.__dict__)
|
65 |
+
|
66 |
+
def save_to_file(self, file_path):
|
67 |
+
with open(file_path, "w") as yaml_file:
|
68 |
+
yaml.dump(self.__dict__, yaml_file, default_flow_style=False)
|
69 |
+
|
70 |
+
@classmethod
|
71 |
+
def load_from_file(cls, file_path):
|
72 |
+
with open(file_path, "r") as yaml_file:
|
73 |
+
config_dict = yaml.safe_load(yaml_file)
|
74 |
+
return cls(dict=config_dict)
|
src/models_utils.py
ADDED
@@ -0,0 +1,561 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# general
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
# ML
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
import torch
|
9 |
+
|
10 |
+
# visual
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
import seaborn as sns
|
13 |
+
from sklearn.metrics import confusion_matrix
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from src.running_params import DEBUG_MODE
|
17 |
+
from src.utiles_data import Nikud, create_missing_folders
|
18 |
+
|
19 |
+
CLASSES_LIST = ["nikud", "dagesh", "sin"]
|
20 |
+
|
21 |
+
|
22 |
+
def calc_num_correct_words(input, letter_correct_mask):
|
23 |
+
SPACE_TOKEN = 104
|
24 |
+
START_SENTENCE_TOKEN = 1
|
25 |
+
END_SENTENCE_TOKEN = 2
|
26 |
+
|
27 |
+
correct_words_count = 0
|
28 |
+
words_count = 0
|
29 |
+
for index in range(input.shape[0]):
|
30 |
+
input[index][np.where(input[index] == SPACE_TOKEN)[0]] = 0
|
31 |
+
input[index][np.where(input[index] == START_SENTENCE_TOKEN)[0]] = 0
|
32 |
+
input[index][np.where(input[index] == END_SENTENCE_TOKEN)[0]] = 0
|
33 |
+
words_end_index = np.concatenate(
|
34 |
+
(np.array([-1]), np.where(input[index] == 0)[0])
|
35 |
+
)
|
36 |
+
is_correct_words_array = [
|
37 |
+
bool(
|
38 |
+
letter_correct_mask[index][
|
39 |
+
list(range((words_end_index[s] + 1), words_end_index[s + 1]))
|
40 |
+
].all()
|
41 |
+
)
|
42 |
+
for s in range(len(words_end_index) - 1)
|
43 |
+
if words_end_index[s + 1] - (words_end_index[s] + 1) > 1
|
44 |
+
]
|
45 |
+
correct_words_count += np.array(is_correct_words_array).sum()
|
46 |
+
words_count += len(is_correct_words_array)
|
47 |
+
|
48 |
+
return correct_words_count, words_count
|
49 |
+
|
50 |
+
|
51 |
+
def predict(model, data_loader, device="cpu"):
|
52 |
+
model.to(device)
|
53 |
+
|
54 |
+
all_labels = None
|
55 |
+
with torch.no_grad():
|
56 |
+
for index_data, data in enumerate(data_loader):
|
57 |
+
(inputs, attention_mask, labels_demo) = data
|
58 |
+
inputs = inputs.to(device)
|
59 |
+
attention_mask = attention_mask.to(device)
|
60 |
+
labels_demo = labels_demo.to(device)
|
61 |
+
|
62 |
+
mask_cant_be_nikud = np.array(labels_demo.cpu())[:, :, 0] == -1
|
63 |
+
mask_cant_be_dagesh = np.array(labels_demo.cpu())[:, :, 1] == -1
|
64 |
+
mask_cant_be_sin = np.array(labels_demo.cpu())[:, :, 2] == -1
|
65 |
+
|
66 |
+
nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
|
67 |
+
|
68 |
+
pred_nikud = np.array(torch.max(nikud_probs, 2).indices.cpu()).reshape(
|
69 |
+
inputs.shape[0], inputs.shape[1], 1
|
70 |
+
)
|
71 |
+
pred_dagesh = np.array(torch.max(dagesh_probs, 2).indices.cpu()).reshape(
|
72 |
+
inputs.shape[0], inputs.shape[1], 1
|
73 |
+
)
|
74 |
+
pred_sin = np.array(torch.max(sin_probs, 2).indices.cpu()).reshape(
|
75 |
+
inputs.shape[0], inputs.shape[1], 1
|
76 |
+
)
|
77 |
+
|
78 |
+
pred_nikud[mask_cant_be_nikud] = -1
|
79 |
+
pred_dagesh[mask_cant_be_dagesh] = -1
|
80 |
+
pred_sin[mask_cant_be_sin] = -1
|
81 |
+
|
82 |
+
pred_labels = np.concatenate((pred_nikud, pred_dagesh, pred_sin), axis=2)
|
83 |
+
|
84 |
+
if all_labels is None:
|
85 |
+
all_labels = pred_labels
|
86 |
+
else:
|
87 |
+
all_labels = np.concatenate((all_labels, pred_labels), axis=0)
|
88 |
+
|
89 |
+
return all_labels
|
90 |
+
|
91 |
+
|
92 |
+
def predict_single(model, data, device="cpu"):
|
93 |
+
# model.to(device)
|
94 |
+
|
95 |
+
all_labels = None
|
96 |
+
with torch.no_grad():
|
97 |
+
(inputs, attention_mask, labels_demo) = data
|
98 |
+
inputs = inputs.to(device)
|
99 |
+
attention_mask = attention_mask.to(device)
|
100 |
+
labels_demo = labels_demo.to(device)
|
101 |
+
|
102 |
+
mask_cant_be_nikud = np.array(labels_demo.cpu())[:, :, 0] == -1
|
103 |
+
mask_cant_be_dagesh = np.array(labels_demo.cpu())[:, :, 1] == -1
|
104 |
+
mask_cant_be_sin = np.array(labels_demo.cpu())[:, :, 2] == -1
|
105 |
+
|
106 |
+
nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
|
107 |
+
print("model output: ", nikud_probs, dagesh_probs, sin_probs)
|
108 |
+
|
109 |
+
pred_nikud = np.array(torch.max(nikud_probs, 2).indices.cpu()).reshape(
|
110 |
+
inputs.shape[0], inputs.shape[1], 1
|
111 |
+
)
|
112 |
+
pred_dagesh = np.array(torch.max(dagesh_probs, 2).indices.cpu()).reshape(
|
113 |
+
inputs.shape[0], inputs.shape[1], 1
|
114 |
+
)
|
115 |
+
pred_sin = np.array(torch.max(sin_probs, 2).indices.cpu()).reshape(
|
116 |
+
inputs.shape[0], inputs.shape[1], 1
|
117 |
+
)
|
118 |
+
|
119 |
+
pred_nikud[mask_cant_be_nikud] = -1
|
120 |
+
pred_dagesh[mask_cant_be_dagesh] = -1
|
121 |
+
pred_sin[mask_cant_be_sin] = -1
|
122 |
+
# print(pred_nikud, pred_dagesh, pred_sin)
|
123 |
+
pred_labels = np.concatenate((pred_nikud, pred_dagesh, pred_sin), axis=2)
|
124 |
+
print(pred_labels)
|
125 |
+
if all_labels is None:
|
126 |
+
all_labels = pred_labels
|
127 |
+
else:
|
128 |
+
all_labels = np.concatenate((all_labels, pred_labels), axis=0)
|
129 |
+
|
130 |
+
return all_labels
|
131 |
+
|
132 |
+
|
133 |
+
def training(
|
134 |
+
model,
|
135 |
+
train_loader,
|
136 |
+
dev_loader,
|
137 |
+
criterion_nikud,
|
138 |
+
criterion_dagesh,
|
139 |
+
criterion_sin,
|
140 |
+
training_params,
|
141 |
+
logger,
|
142 |
+
output_model_path,
|
143 |
+
optimizer,
|
144 |
+
device="cpu",
|
145 |
+
):
|
146 |
+
max_length = None
|
147 |
+
best_accuracy = 0.0
|
148 |
+
|
149 |
+
logger.info(f"start training with training_params: {training_params}")
|
150 |
+
model = model.to(device)
|
151 |
+
|
152 |
+
criteria = {
|
153 |
+
"nikud": criterion_nikud.to(device),
|
154 |
+
"dagesh": criterion_dagesh.to(device),
|
155 |
+
"sin": criterion_sin.to(device),
|
156 |
+
}
|
157 |
+
|
158 |
+
output_checkpoints_path = os.path.join(output_model_path, "checkpoints")
|
159 |
+
create_missing_folders(output_checkpoints_path)
|
160 |
+
|
161 |
+
train_steps_loss_values = {"nikud": [], "dagesh": [], "sin": []}
|
162 |
+
train_epochs_loss_values = {"nikud": [], "dagesh": [], "sin": []}
|
163 |
+
dev_loss_values = {"nikud": [], "dagesh": [], "sin": []}
|
164 |
+
dev_accuracy_values = {
|
165 |
+
"nikud": [],
|
166 |
+
"dagesh": [],
|
167 |
+
"sin": [],
|
168 |
+
"all_nikud_letter": [],
|
169 |
+
"all_nikud_word": [],
|
170 |
+
}
|
171 |
+
|
172 |
+
for epoch in tqdm(range(training_params["n_epochs"]), desc="Training"):
|
173 |
+
model.train()
|
174 |
+
train_loss = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
|
175 |
+
relevant_count = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
|
176 |
+
|
177 |
+
for index_data, data in enumerate(train_loader):
|
178 |
+
(inputs, attention_mask, labels) = data
|
179 |
+
|
180 |
+
if max_length is None:
|
181 |
+
max_length = labels.shape[1]
|
182 |
+
|
183 |
+
inputs = inputs.to(device)
|
184 |
+
attention_mask = attention_mask.to(device)
|
185 |
+
labels = labels.to(device)
|
186 |
+
|
187 |
+
optimizer.zero_grad()
|
188 |
+
nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
|
189 |
+
|
190 |
+
for i, (probs, class_name) in enumerate(
|
191 |
+
zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
|
192 |
+
):
|
193 |
+
reshaped_tensor = (
|
194 |
+
torch.transpose(probs, 1, 2)
|
195 |
+
.contiguous()
|
196 |
+
.view(probs.shape[0], probs.shape[2], probs.shape[1])
|
197 |
+
)
|
198 |
+
loss = criteria[class_name](reshaped_tensor, labels[:, :, i]).to(device)
|
199 |
+
|
200 |
+
num_relevant = (labels[:, :, i] != -1).sum()
|
201 |
+
train_loss[class_name] += loss.item() * num_relevant
|
202 |
+
relevant_count[class_name] += num_relevant
|
203 |
+
|
204 |
+
loss.backward(retain_graph=True)
|
205 |
+
|
206 |
+
for i, class_name in enumerate(CLASSES_LIST):
|
207 |
+
train_steps_loss_values[class_name].append(
|
208 |
+
float(train_loss[class_name] / relevant_count[class_name])
|
209 |
+
)
|
210 |
+
|
211 |
+
optimizer.step()
|
212 |
+
if (index_data + 1) % 100 == 0:
|
213 |
+
msg = f"epoch: {epoch} , index_data: {index_data + 1}\n"
|
214 |
+
for i, class_name in enumerate(CLASSES_LIST):
|
215 |
+
msg += f"mean loss train {class_name}: {float(train_loss[class_name] / relevant_count[class_name])}, "
|
216 |
+
|
217 |
+
logger.debug(msg[:-2])
|
218 |
+
|
219 |
+
for i, class_name in enumerate(CLASSES_LIST):
|
220 |
+
train_epochs_loss_values[class_name].append(
|
221 |
+
float(train_loss[class_name] / relevant_count[class_name])
|
222 |
+
)
|
223 |
+
|
224 |
+
for class_name in train_loss.keys():
|
225 |
+
train_loss[class_name] /= relevant_count[class_name]
|
226 |
+
|
227 |
+
msg = f"Epoch {epoch + 1}/{training_params['n_epochs']}\n"
|
228 |
+
for i, class_name in enumerate(CLASSES_LIST):
|
229 |
+
msg += f"mean loss train {class_name}: {train_loss[class_name]}, "
|
230 |
+
logger.debug(msg[:-2])
|
231 |
+
|
232 |
+
model.eval()
|
233 |
+
dev_loss = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
|
234 |
+
dev_accuracy = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
|
235 |
+
relevant_count = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
|
236 |
+
correct_preds = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
|
237 |
+
un_masks = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
|
238 |
+
predictions = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
|
239 |
+
labels_class = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
|
240 |
+
|
241 |
+
all_nikud_types_correct_preds_letter = 0.0
|
242 |
+
|
243 |
+
letter_count = 0.0
|
244 |
+
correct_words_count = 0.0
|
245 |
+
word_count = 0.0
|
246 |
+
with torch.no_grad():
|
247 |
+
for index_data, data in enumerate(dev_loader):
|
248 |
+
(inputs, attention_mask, labels) = data
|
249 |
+
inputs = inputs.to(device)
|
250 |
+
attention_mask = attention_mask.to(device)
|
251 |
+
labels = labels.to(device)
|
252 |
+
|
253 |
+
nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
|
254 |
+
|
255 |
+
for i, (probs, class_name) in enumerate(
|
256 |
+
zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
|
257 |
+
):
|
258 |
+
reshaped_tensor = (
|
259 |
+
torch.transpose(probs, 1, 2)
|
260 |
+
.contiguous()
|
261 |
+
.view(probs.shape[0], probs.shape[2], probs.shape[1])
|
262 |
+
)
|
263 |
+
loss = criteria[class_name](reshaped_tensor, labels[:, :, i]).to(
|
264 |
+
device
|
265 |
+
)
|
266 |
+
un_masked = labels[:, :, i] != -1
|
267 |
+
num_relevant = un_masked.sum()
|
268 |
+
relevant_count[class_name] += num_relevant
|
269 |
+
_, preds = torch.max(probs, 2)
|
270 |
+
dev_loss[class_name] += loss.item() * num_relevant
|
271 |
+
correct_preds[class_name] += torch.sum(
|
272 |
+
preds[un_masked] == labels[:, :, i][un_masked]
|
273 |
+
)
|
274 |
+
un_masks[class_name] = un_masked
|
275 |
+
predictions[class_name] = preds
|
276 |
+
labels_class[class_name] = labels[:, :, i]
|
277 |
+
|
278 |
+
un_mask_all_or = torch.logical_or(
|
279 |
+
torch.logical_or(un_masks["nikud"], un_masks["dagesh"]),
|
280 |
+
un_masks["sin"],
|
281 |
+
)
|
282 |
+
|
283 |
+
correct = {
|
284 |
+
class_name: (torch.ones(un_mask_all_or.shape) == 1).to(device)
|
285 |
+
for class_name in CLASSES_LIST
|
286 |
+
}
|
287 |
+
|
288 |
+
for i, class_name in enumerate(CLASSES_LIST):
|
289 |
+
correct[class_name][un_masks[class_name]] = (
|
290 |
+
predictions[class_name][un_masks[class_name]]
|
291 |
+
== labels_class[class_name][un_masks[class_name]]
|
292 |
+
)
|
293 |
+
|
294 |
+
letter_correct_mask = torch.logical_and(
|
295 |
+
torch.logical_and(correct["sin"], correct["dagesh"]),
|
296 |
+
correct["nikud"],
|
297 |
+
)
|
298 |
+
all_nikud_types_correct_preds_letter += torch.sum(
|
299 |
+
letter_correct_mask[un_mask_all_or]
|
300 |
+
)
|
301 |
+
|
302 |
+
letter_correct_mask[~un_mask_all_or] = True
|
303 |
+
correct_num, total_words_num = calc_num_correct_words(
|
304 |
+
inputs.cpu(), letter_correct_mask
|
305 |
+
)
|
306 |
+
|
307 |
+
word_count += total_words_num
|
308 |
+
correct_words_count += correct_num
|
309 |
+
letter_count += un_mask_all_or.sum()
|
310 |
+
|
311 |
+
for class_name in CLASSES_LIST:
|
312 |
+
dev_loss[class_name] /= relevant_count[class_name]
|
313 |
+
dev_accuracy[class_name] = float(
|
314 |
+
correct_preds[class_name].double() / relevant_count[class_name]
|
315 |
+
)
|
316 |
+
|
317 |
+
dev_loss_values[class_name].append(float(dev_loss[class_name]))
|
318 |
+
dev_accuracy_values[class_name].append(float(dev_accuracy[class_name]))
|
319 |
+
|
320 |
+
dev_all_nikud_types_accuracy_letter = float(
|
321 |
+
all_nikud_types_correct_preds_letter / letter_count
|
322 |
+
)
|
323 |
+
|
324 |
+
dev_accuracy_values["all_nikud_letter"].append(
|
325 |
+
dev_all_nikud_types_accuracy_letter
|
326 |
+
)
|
327 |
+
|
328 |
+
word_all_nikud_accuracy = correct_words_count / word_count
|
329 |
+
dev_accuracy_values["all_nikud_word"].append(word_all_nikud_accuracy)
|
330 |
+
|
331 |
+
msg = (
|
332 |
+
f"Epoch {epoch + 1}/{training_params['n_epochs']}\n"
|
333 |
+
f'mean loss Dev nikud: {train_loss["nikud"]}, '
|
334 |
+
f'mean loss Dev dagesh: {train_loss["dagesh"]}, '
|
335 |
+
f'mean loss Dev sin: {train_loss["sin"]}, '
|
336 |
+
f"Dev all nikud types letter Accuracy: {dev_all_nikud_types_accuracy_letter}, "
|
337 |
+
f'Dev nikud letter Accuracy: {dev_accuracy["nikud"]}, '
|
338 |
+
f'Dev dagesh letter Accuracy: {dev_accuracy["dagesh"]}, '
|
339 |
+
f'Dev sin letter Accuracy: {dev_accuracy["sin"]}, '
|
340 |
+
f"Dev word Accuracy: {word_all_nikud_accuracy}"
|
341 |
+
)
|
342 |
+
logger.debug(msg)
|
343 |
+
|
344 |
+
save_progress_details(
|
345 |
+
dev_accuracy_values,
|
346 |
+
train_epochs_loss_values,
|
347 |
+
dev_loss_values,
|
348 |
+
train_steps_loss_values,
|
349 |
+
)
|
350 |
+
|
351 |
+
if dev_all_nikud_types_accuracy_letter > best_accuracy:
|
352 |
+
best_accuracy = dev_all_nikud_types_accuracy_letter
|
353 |
+
best_model = {
|
354 |
+
"epoch": epoch,
|
355 |
+
"model_state_dict": model.state_dict(),
|
356 |
+
"optimizer_state_dict": optimizer.state_dict(),
|
357 |
+
"loss": loss,
|
358 |
+
}
|
359 |
+
|
360 |
+
if epoch % training_params["checkpoints_frequency"] == 0:
|
361 |
+
save_checkpoint_path = os.path.join(
|
362 |
+
output_checkpoints_path, f"checkpoint_model_epoch_{epoch + 1}.pth"
|
363 |
+
)
|
364 |
+
checkpoint = {
|
365 |
+
"epoch": epoch,
|
366 |
+
"model_state_dict": model.state_dict(),
|
367 |
+
"optimizer_state_dict": optimizer.state_dict(),
|
368 |
+
"loss": loss,
|
369 |
+
}
|
370 |
+
torch.save(checkpoint["model_state_dict"], save_checkpoint_path)
|
371 |
+
|
372 |
+
save_model_path = os.path.join(output_model_path, "best_model.pth")
|
373 |
+
torch.save(best_model["model_state_dict"], save_model_path)
|
374 |
+
return (
|
375 |
+
best_model,
|
376 |
+
best_accuracy,
|
377 |
+
train_epochs_loss_values,
|
378 |
+
train_steps_loss_values,
|
379 |
+
dev_loss_values,
|
380 |
+
dev_accuracy_values,
|
381 |
+
)
|
382 |
+
|
383 |
+
|
384 |
+
def save_progress_details(
|
385 |
+
accuracy_dev_values,
|
386 |
+
epochs_loss_train_values,
|
387 |
+
loss_dev_values,
|
388 |
+
steps_loss_train_values,
|
389 |
+
):
|
390 |
+
epochs_data_path = "epochs_data"
|
391 |
+
create_missing_folders(epochs_data_path)
|
392 |
+
|
393 |
+
save_dict_as_json(
|
394 |
+
steps_loss_train_values, epochs_data_path, "steps_loss_train_values.json"
|
395 |
+
)
|
396 |
+
save_dict_as_json(
|
397 |
+
epochs_loss_train_values, epochs_data_path, "epochs_loss_train_values.json"
|
398 |
+
)
|
399 |
+
save_dict_as_json(loss_dev_values, epochs_data_path, "loss_dev_values.json")
|
400 |
+
save_dict_as_json(accuracy_dev_values, epochs_data_path, "accuracy_dev_values.json")
|
401 |
+
|
402 |
+
|
403 |
+
def save_dict_as_json(dict, file_path, file_name):
|
404 |
+
json_data = json.dumps(dict, indent=4)
|
405 |
+
with open(os.path.join(file_path, file_name), "w") as json_file:
|
406 |
+
json_file.write(json_data)
|
407 |
+
|
408 |
+
|
409 |
+
def evaluate(model, test_data, plots_folder=None, device="cpu"):
|
410 |
+
model.to(device)
|
411 |
+
model.eval()
|
412 |
+
|
413 |
+
true_labels = {"nikud": [], "dagesh": [], "sin": []}
|
414 |
+
predictions = {"nikud": 0, "dagesh": 0, "sin": 0}
|
415 |
+
predicted_labels_2_report = {"nikud": [], "dagesh": [], "sin": []}
|
416 |
+
not_masks = {"nikud": 0, "dagesh": 0, "sin": 0}
|
417 |
+
correct_preds = {"nikud": 0, "dagesh": 0, "sin": 0}
|
418 |
+
relevant_count = {"nikud": 0, "dagesh": 0, "sin": 0}
|
419 |
+
labels_class = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
|
420 |
+
|
421 |
+
all_nikud_types_letter_level_correct = 0.0
|
422 |
+
nikud_letter_level_correct = 0.0
|
423 |
+
dagesh_letter_level_correct = 0.0
|
424 |
+
sin_letter_level_correct = 0.0
|
425 |
+
|
426 |
+
letters_count = 0.0
|
427 |
+
words_count = 0.0
|
428 |
+
correct_words_count = 0.0
|
429 |
+
with torch.no_grad():
|
430 |
+
for index_data, data in enumerate(test_data):
|
431 |
+
if DEBUG_MODE and index_data > 100:
|
432 |
+
break
|
433 |
+
|
434 |
+
(inputs, attention_mask, labels) = data
|
435 |
+
|
436 |
+
inputs = inputs.to(device)
|
437 |
+
attention_mask = attention_mask.to(device)
|
438 |
+
labels = labels.to(device)
|
439 |
+
|
440 |
+
nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
|
441 |
+
|
442 |
+
for i, (probs, class_name) in enumerate(
|
443 |
+
zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
|
444 |
+
):
|
445 |
+
labels_class[class_name] = labels[:, :, i]
|
446 |
+
not_masked = labels_class[class_name] != -1
|
447 |
+
num_relevant = not_masked.sum()
|
448 |
+
relevant_count[class_name] += num_relevant
|
449 |
+
_, preds = torch.max(probs, 2)
|
450 |
+
correct_preds[class_name] += torch.sum(
|
451 |
+
preds[not_masked] == labels_class[class_name][not_masked]
|
452 |
+
)
|
453 |
+
predictions[class_name] = preds
|
454 |
+
not_masks[class_name] = not_masked
|
455 |
+
|
456 |
+
if len(true_labels[class_name]) == 0:
|
457 |
+
true_labels[class_name] = (
|
458 |
+
labels_class[class_name][not_masked].cpu().numpy()
|
459 |
+
)
|
460 |
+
else:
|
461 |
+
true_labels[class_name] = np.concatenate(
|
462 |
+
(
|
463 |
+
true_labels[class_name],
|
464 |
+
labels_class[class_name][not_masked].cpu().numpy(),
|
465 |
+
)
|
466 |
+
)
|
467 |
+
|
468 |
+
if len(predicted_labels_2_report[class_name]) == 0:
|
469 |
+
predicted_labels_2_report[class_name] = (
|
470 |
+
preds[not_masked].cpu().numpy()
|
471 |
+
)
|
472 |
+
else:
|
473 |
+
predicted_labels_2_report[class_name] = np.concatenate(
|
474 |
+
(
|
475 |
+
predicted_labels_2_report[class_name],
|
476 |
+
preds[not_masked].cpu().numpy(),
|
477 |
+
)
|
478 |
+
)
|
479 |
+
|
480 |
+
not_mask_all_or = torch.logical_or(
|
481 |
+
torch.logical_or(not_masks["nikud"], not_masks["dagesh"]),
|
482 |
+
not_masks["sin"],
|
483 |
+
)
|
484 |
+
|
485 |
+
correct_nikud = (torch.ones(not_mask_all_or.shape) == 1).to(device)
|
486 |
+
correct_dagesh = (torch.ones(not_mask_all_or.shape) == 1).to(device)
|
487 |
+
correct_sin = (torch.ones(not_mask_all_or.shape) == 1).to(device)
|
488 |
+
|
489 |
+
correct_nikud[not_masks["nikud"]] = (
|
490 |
+
predictions["nikud"][not_masks["nikud"]]
|
491 |
+
== labels_class["nikud"][not_masks["nikud"]]
|
492 |
+
)
|
493 |
+
correct_dagesh[not_masks["dagesh"]] = (
|
494 |
+
predictions["dagesh"][not_masks["dagesh"]]
|
495 |
+
== labels_class["dagesh"][not_masks["dagesh"]]
|
496 |
+
)
|
497 |
+
correct_sin[not_masks["sin"]] = (
|
498 |
+
predictions["sin"][not_masks["sin"]]
|
499 |
+
== labels_class["sin"][not_masks["sin"]]
|
500 |
+
)
|
501 |
+
|
502 |
+
letter_correct_mask = torch.logical_and(
|
503 |
+
torch.logical_and(correct_sin, correct_dagesh), correct_nikud
|
504 |
+
)
|
505 |
+
all_nikud_types_letter_level_correct += torch.sum(
|
506 |
+
letter_correct_mask[not_mask_all_or]
|
507 |
+
)
|
508 |
+
|
509 |
+
letter_correct_mask[~not_mask_all_or] = True
|
510 |
+
total_correct_count, total_words_num = calc_num_correct_words(
|
511 |
+
inputs.cpu(), letter_correct_mask
|
512 |
+
)
|
513 |
+
|
514 |
+
words_count += total_words_num
|
515 |
+
correct_words_count += total_correct_count
|
516 |
+
|
517 |
+
letters_count += not_mask_all_or.sum()
|
518 |
+
|
519 |
+
nikud_letter_level_correct += torch.sum(correct_nikud[not_mask_all_or])
|
520 |
+
dagesh_letter_level_correct += torch.sum(correct_dagesh[not_mask_all_or])
|
521 |
+
sin_letter_level_correct += torch.sum(correct_sin[not_mask_all_or])
|
522 |
+
|
523 |
+
for i, name in enumerate(CLASSES_LIST):
|
524 |
+
index_labels = np.unique(true_labels[name])
|
525 |
+
cm = confusion_matrix(
|
526 |
+
true_labels[name], predicted_labels_2_report[name], labels=index_labels
|
527 |
+
)
|
528 |
+
|
529 |
+
vowel_label = [Nikud.id_2_label[name][l] for l in index_labels]
|
530 |
+
unique_vowels_names = [
|
531 |
+
Nikud.sign_2_name[int(vowel)] for vowel in vowel_label if vowel != "WITHOUT"
|
532 |
+
]
|
533 |
+
if "WITHOUT" in vowel_label:
|
534 |
+
unique_vowels_names += ["WITHOUT"]
|
535 |
+
cm_df = pd.DataFrame(cm, index=unique_vowels_names, columns=unique_vowels_names)
|
536 |
+
|
537 |
+
# Display confusion matrix
|
538 |
+
plt.figure(figsize=(10, 8))
|
539 |
+
sns.heatmap(cm_df, annot=True, cmap="Blues", fmt="d")
|
540 |
+
plt.title("Confusion Matrix")
|
541 |
+
plt.xlabel("True Label")
|
542 |
+
plt.ylabel("Predicted Label")
|
543 |
+
if plots_folder is None:
|
544 |
+
plt.show()
|
545 |
+
else:
|
546 |
+
plt.savefig(os.path.join(plots_folder, f"Confusion_Matrix_{name}.jpg"))
|
547 |
+
|
548 |
+
all_nikud_types_letter_level_correct = (
|
549 |
+
all_nikud_types_letter_level_correct / letters_count
|
550 |
+
)
|
551 |
+
all_nikud_types_word_level_correct = correct_words_count / words_count
|
552 |
+
nikud_letter_level_correct = nikud_letter_level_correct / letters_count
|
553 |
+
dagesh_letter_level_correct = dagesh_letter_level_correct / letters_count
|
554 |
+
sin_letter_level_correct = sin_letter_level_correct / letters_count
|
555 |
+
print("\n")
|
556 |
+
print(f"nikud_letter_level_correct = {nikud_letter_level_correct}")
|
557 |
+
print(f"dagesh_letter_level_correct = {dagesh_letter_level_correct}")
|
558 |
+
print(f"sin_letter_level_correct = {sin_letter_level_correct}")
|
559 |
+
print(f"word_level_correct = {all_nikud_types_word_level_correct}")
|
560 |
+
|
561 |
+
return all_nikud_types_word_level_correct, all_nikud_types_letter_level_correct
|
src/plot_helpers.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# general
|
2 |
+
import os
|
3 |
+
|
4 |
+
# visual
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
|
7 |
+
cols = ["precision", "recall", "f1-score", "support"]
|
8 |
+
|
9 |
+
|
10 |
+
def generate_plot_by_nikud_dagesh_sin_dict(nikud_dagesh_sin_dict, title, y_axis, plot_folder=None):
|
11 |
+
# Create a figure and axis
|
12 |
+
plt.figure(figsize=(8, 6))
|
13 |
+
plt.title(title)
|
14 |
+
|
15 |
+
ax = plt.gca()
|
16 |
+
indexes = list(range(1, len(nikud_dagesh_sin_dict["nikud"]) + 1))
|
17 |
+
|
18 |
+
# Plot data series with different colors and labels
|
19 |
+
ax.plot(indexes, nikud_dagesh_sin_dict["nikud"], color='blue', label='Nikud')
|
20 |
+
ax.plot(indexes, nikud_dagesh_sin_dict["dagesh"], color='green', label='Dagesh')
|
21 |
+
ax.plot(indexes, nikud_dagesh_sin_dict["sin"], color='red', label='Sin')
|
22 |
+
|
23 |
+
# Add legend
|
24 |
+
ax.legend()
|
25 |
+
|
26 |
+
# Set labels and title
|
27 |
+
ax.set_xlabel('Epoch')
|
28 |
+
ax.set_ylabel(y_axis)
|
29 |
+
|
30 |
+
if plot_folder is None:
|
31 |
+
plt.show()
|
32 |
+
else:
|
33 |
+
plt.savefig(os.path.join(plot_folder, f'{title.replace(" ", "_")}_plot.jpg'))
|
34 |
+
|
35 |
+
|
36 |
+
def generate_word_and_letter_accuracy_plot(word_and_letter_accuracy_dict, title, plot_folder=None):
|
37 |
+
# Create a figure and axis
|
38 |
+
plt.figure(figsize=(8, 6))
|
39 |
+
plt.title(title)
|
40 |
+
|
41 |
+
ax = plt.gca()
|
42 |
+
indexes = list(range(1, len(word_and_letter_accuracy_dict["all_nikud_letter"]) + 1))
|
43 |
+
|
44 |
+
# Plot data series with different colors and labels
|
45 |
+
ax.plot(indexes, word_and_letter_accuracy_dict["all_nikud_letter"], color='blue', label='Letter')
|
46 |
+
ax.plot(indexes, word_and_letter_accuracy_dict["all_nikud_word"], color='green', label='Word')
|
47 |
+
|
48 |
+
# Add legend
|
49 |
+
ax.legend()
|
50 |
+
|
51 |
+
# Set labels and title
|
52 |
+
ax.set_xlabel("Epoch")
|
53 |
+
ax.set_ylabel("Accuracy")
|
54 |
+
|
55 |
+
if plot_folder is None:
|
56 |
+
plt.show()
|
57 |
+
else:
|
58 |
+
plt.savefig(os.path.join(plot_folder, 'word_and_letter_accuracy_plot.jpg'))
|
src/running_params.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
DEBUG_MODE = False
|
2 |
+
BATCH_SIZE = 32
|
3 |
+
MAX_LENGTH_SEN = 1024
|
src/utiles_data.py
ADDED
@@ -0,0 +1,737 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# general
|
2 |
+
import os.path
|
3 |
+
from datetime import datetime
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import List, Tuple
|
6 |
+
from uuid import uuid1
|
7 |
+
import re
|
8 |
+
import glob2
|
9 |
+
|
10 |
+
# visual
|
11 |
+
import matplotlib
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
# ML
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
from torch.utils.data import Dataset
|
19 |
+
|
20 |
+
from src.running_params import DEBUG_MODE, MAX_LENGTH_SEN
|
21 |
+
|
22 |
+
matplotlib.use("agg")
|
23 |
+
unique_key = str(uuid1())
|
24 |
+
|
25 |
+
|
26 |
+
class Nikud:
|
27 |
+
"""
|
28 |
+
1456 HEBREW POINT SHEVA
|
29 |
+
1457 HEBREW POINT HATAF SEGOL
|
30 |
+
1458 HEBREW POINT HATAF PATAH
|
31 |
+
1459 HEBREW POINT HATAF QAMATS
|
32 |
+
1460 HEBREW POINT HIRIQ
|
33 |
+
1461 HEBREW POINT TSERE
|
34 |
+
1462 HEBREW POINT SEGOL
|
35 |
+
1463 HEBREW POINT PATAH
|
36 |
+
1464 HEBREW POINT QAMATS
|
37 |
+
1465 HEBREW POINT HOLAM
|
38 |
+
1466 HEBREW POINT HOLAM HASER FOR VAV ***EXTENDED***
|
39 |
+
1467 HEBREW POINT QUBUTS
|
40 |
+
1468 HEBREW POINT DAGESH OR MAPIQ
|
41 |
+
1469 HEBREW POINT METEG ***EXTENDED***
|
42 |
+
1470 HEBREW PUNCTUATION MAQAF ***EXTENDED***
|
43 |
+
1471 HEBREW POINT RAFE ***EXTENDED***
|
44 |
+
1472 HEBREW PUNCTUATION PASEQ ***EXTENDED***
|
45 |
+
1473 HEBREW POINT SHIN DOT
|
46 |
+
1474 HEBREW POINT SIN DOT
|
47 |
+
"""
|
48 |
+
|
49 |
+
nikud_dict = {
|
50 |
+
"SHVA": 1456,
|
51 |
+
"REDUCED_SEGOL": 1457,
|
52 |
+
"REDUCED_PATAKH": 1458,
|
53 |
+
"REDUCED_KAMATZ": 1459,
|
54 |
+
"HIRIK": 1460,
|
55 |
+
"TZEIRE": 1461,
|
56 |
+
"SEGOL": 1462,
|
57 |
+
"PATAKH": 1463,
|
58 |
+
"KAMATZ": 1464,
|
59 |
+
"KAMATZ_KATAN": 1479,
|
60 |
+
"HOLAM": 1465,
|
61 |
+
"HOLAM HASER VAV": 1466,
|
62 |
+
"KUBUTZ": 1467,
|
63 |
+
"DAGESH OR SHURUK": 1468,
|
64 |
+
"METEG": 1469,
|
65 |
+
"PUNCTUATION MAQAF": 1470,
|
66 |
+
"RAFE": 1471,
|
67 |
+
"PUNCTUATION PASEQ": 1472,
|
68 |
+
"SHIN_YEMANIT": 1473,
|
69 |
+
"SHIN_SMALIT": 1474,
|
70 |
+
}
|
71 |
+
|
72 |
+
skip_nikud = (
|
73 |
+
[]
|
74 |
+
) # [nikud_dict["KAMATZ_KATAN"], nikud_dict["HOLAM HASER VAV"], nikud_dict["METEG"], nikud_dict["PUNCTUATION MAQAF"], nikud_dict["PUNCTUATION PASEQ"]]
|
75 |
+
sign_2_name = {sign: name for name, sign in nikud_dict.items()}
|
76 |
+
sin = [nikud_dict["RAFE"], nikud_dict["SHIN_YEMANIT"], nikud_dict["SHIN_SMALIT"]]
|
77 |
+
dagesh = [
|
78 |
+
nikud_dict["RAFE"],
|
79 |
+
nikud_dict["DAGESH OR SHURUK"],
|
80 |
+
] # note that DAGESH and SHURUK are one and the same
|
81 |
+
nikud = []
|
82 |
+
for v in nikud_dict.values():
|
83 |
+
if v not in sin and v not in skip_nikud:
|
84 |
+
nikud.append(v)
|
85 |
+
all_nikud_ord = {v for v in nikud_dict.values()}
|
86 |
+
all_nikud_chr = {chr(v) for v in nikud_dict.values()}
|
87 |
+
|
88 |
+
label_2_id = {
|
89 |
+
"nikud": {label: i for i, label in enumerate(nikud + ["WITHOUT"])},
|
90 |
+
"dagesh": {label: i for i, label in enumerate(dagesh + ["WITHOUT"])},
|
91 |
+
"sin": {label: i for i, label in enumerate(sin + ["WITHOUT"])},
|
92 |
+
}
|
93 |
+
id_2_label = {
|
94 |
+
"nikud": {i: label for i, label in enumerate(nikud + ["WITHOUT"])},
|
95 |
+
"dagesh": {i: label for i, label in enumerate(dagesh + ["WITHOUT"])},
|
96 |
+
"sin": {i: label for i, label in enumerate(sin + ["WITHOUT"])},
|
97 |
+
}
|
98 |
+
|
99 |
+
DAGESH_LETTER = nikud_dict["DAGESH OR SHURUK"]
|
100 |
+
RAFE = nikud_dict["RAFE"]
|
101 |
+
PAD_OR_IRRELEVANT = -1
|
102 |
+
|
103 |
+
LEN_NIKUD = len(label_2_id["nikud"])
|
104 |
+
LEN_DAGESH = len(label_2_id["dagesh"])
|
105 |
+
LEN_SIN = len(label_2_id["sin"])
|
106 |
+
|
107 |
+
def id_2_char(self, c, class_type):
|
108 |
+
if c == -1:
|
109 |
+
return ""
|
110 |
+
|
111 |
+
label = self.id_2_label[class_type][c]
|
112 |
+
|
113 |
+
if label != "WITHOUT":
|
114 |
+
print("Label =", chr(self.id_2_label[class_type][c]))
|
115 |
+
return chr(self.id_2_label[class_type][c])
|
116 |
+
return ""
|
117 |
+
|
118 |
+
|
119 |
+
class Letters:
|
120 |
+
hebrew = [chr(c) for c in range(0x05D0, 0x05EA + 1)]
|
121 |
+
VALID_LETTERS = [
|
122 |
+
" ",
|
123 |
+
"!",
|
124 |
+
'"',
|
125 |
+
"'",
|
126 |
+
"(",
|
127 |
+
")",
|
128 |
+
",",
|
129 |
+
"-",
|
130 |
+
".",
|
131 |
+
":",
|
132 |
+
";",
|
133 |
+
"?",
|
134 |
+
] + hebrew
|
135 |
+
SPECIAL_TOKENS = ["H", "O", "5", "1"]
|
136 |
+
ENDINGS_TO_REGULAR = dict(zip("ืืืืฃืฅ", "ืืื ืคืฆ"))
|
137 |
+
vocab = VALID_LETTERS + SPECIAL_TOKENS
|
138 |
+
vocab_size = len(vocab)
|
139 |
+
|
140 |
+
|
141 |
+
class Letter:
|
142 |
+
def __init__(self, letter):
|
143 |
+
self.letter = letter
|
144 |
+
self.normalized = None
|
145 |
+
self.dagesh = None
|
146 |
+
self.sin = None
|
147 |
+
self.nikud = None
|
148 |
+
|
149 |
+
def normalize(self, letter):
|
150 |
+
if letter in Letters.VALID_LETTERS:
|
151 |
+
return letter
|
152 |
+
if letter in Letters.ENDINGS_TO_REGULAR:
|
153 |
+
return Letters.ENDINGS_TO_REGULAR[letter]
|
154 |
+
if letter in ["\n", "\t"]:
|
155 |
+
return " "
|
156 |
+
if letter in ["โ", "โ", "โ", "โ", "โ", "+"]:
|
157 |
+
return "-"
|
158 |
+
if letter == "[":
|
159 |
+
return "("
|
160 |
+
if letter == "]":
|
161 |
+
return ")"
|
162 |
+
if letter in ["ยด", "โ", "โ"]:
|
163 |
+
return "'"
|
164 |
+
if letter in ["โ", "โ", "ืด"]:
|
165 |
+
return '"'
|
166 |
+
if letter.isdigit():
|
167 |
+
if int(letter) == 1:
|
168 |
+
return "1"
|
169 |
+
else:
|
170 |
+
return "5"
|
171 |
+
if letter == "โฆ":
|
172 |
+
return ","
|
173 |
+
if letter in ["ืฒ", "ืฐ", "ืฑ"]:
|
174 |
+
return "H"
|
175 |
+
return "O"
|
176 |
+
|
177 |
+
def can_dagesh(self, letter):
|
178 |
+
return letter in ("ืืืืืืืืืืืื ืกืคืฆืงืฉืช" + "ืืฃ")
|
179 |
+
|
180 |
+
def can_sin(self, letter):
|
181 |
+
return letter == "ืฉ"
|
182 |
+
|
183 |
+
def can_nikud(self, letter):
|
184 |
+
return letter in ("ืืืืืืืืืืืืืื ืกืขืคืฆืงืจืฉืช" + "ืื")
|
185 |
+
|
186 |
+
def get_label_letter(self, labels):
|
187 |
+
dagesh_sin_nikud = [
|
188 |
+
True if self.can_dagesh(self.letter) else False,
|
189 |
+
True if self.can_sin(self.letter) else False,
|
190 |
+
True if self.can_nikud(self.letter) else False,
|
191 |
+
]
|
192 |
+
|
193 |
+
labels_ids = {
|
194 |
+
"nikud": Nikud.PAD_OR_IRRELEVANT,
|
195 |
+
"dagesh": Nikud.PAD_OR_IRRELEVANT,
|
196 |
+
"sin": Nikud.PAD_OR_IRRELEVANT,
|
197 |
+
}
|
198 |
+
|
199 |
+
normalized = self.normalize(self.letter)
|
200 |
+
|
201 |
+
i = 0
|
202 |
+
if Nikud.nikud_dict["PUNCTUATION PASEQ"] in labels:
|
203 |
+
labels.remove(Nikud.nikud_dict["PUNCTUATION PASEQ"])
|
204 |
+
if Nikud.nikud_dict["PUNCTUATION MAQAF"] in labels:
|
205 |
+
labels.remove(Nikud.nikud_dict["PUNCTUATION MAQAF"])
|
206 |
+
if Nikud.nikud_dict["HOLAM HASER VAV"] in labels:
|
207 |
+
labels.remove(Nikud.nikud_dict["HOLAM HASER VAV"])
|
208 |
+
if Nikud.nikud_dict["METEG"] in labels:
|
209 |
+
labels.remove(Nikud.nikud_dict["METEG"])
|
210 |
+
if Nikud.nikud_dict["KAMATZ_KATAN"] in labels:
|
211 |
+
labels[labels.index(Nikud.nikud_dict["KAMATZ_KATAN"])] = Nikud.nikud_dict[
|
212 |
+
"KAMATZ"
|
213 |
+
]
|
214 |
+
for index, (class_name, group) in enumerate(
|
215 |
+
zip(
|
216 |
+
["dagesh", "sin", "nikud"],
|
217 |
+
[[Nikud.DAGESH_LETTER], Nikud.sin, Nikud.nikud],
|
218 |
+
)
|
219 |
+
):
|
220 |
+
# notice - order is important: dagesh then sin and then nikud
|
221 |
+
if dagesh_sin_nikud[index]:
|
222 |
+
if i < len(labels) and labels[i] in group:
|
223 |
+
labels_ids[class_name] = Nikud.label_2_id[class_name][labels[i]]
|
224 |
+
i += 1
|
225 |
+
else:
|
226 |
+
labels_ids[class_name] = Nikud.label_2_id[class_name]["WITHOUT"]
|
227 |
+
|
228 |
+
if (
|
229 |
+
np.array(dagesh_sin_nikud).all()
|
230 |
+
and len(labels) == 3
|
231 |
+
and labels[0] in Nikud.sin
|
232 |
+
):
|
233 |
+
labels_ids["nikud"] = Nikud.label_2_id["nikud"][labels[2]]
|
234 |
+
labels_ids["dagesh"] = Nikud.label_2_id["dagesh"][labels[1]]
|
235 |
+
|
236 |
+
if (
|
237 |
+
self.can_sin(self.letter)
|
238 |
+
and len(labels) == 2
|
239 |
+
and labels[1] == Nikud.DAGESH_LETTER
|
240 |
+
):
|
241 |
+
labels_ids["dagesh"] = Nikud.label_2_id["dagesh"][labels[1]]
|
242 |
+
labels_ids["nikud"] = Nikud.label_2_id[class_name]["WITHOUT"]
|
243 |
+
|
244 |
+
if (
|
245 |
+
self.letter == "ื"
|
246 |
+
and labels_ids["dagesh"] == Nikud.DAGESH_LETTER
|
247 |
+
and labels_ids["nikud"] == Nikud.label_2_id["nikud"]["WITHOUT"]
|
248 |
+
):
|
249 |
+
labels_ids["dagesh"] = Nikud.label_2_id["dagesh"]["WITHOUT"]
|
250 |
+
labels_ids["nikud"] = Nikud.DAGESH_LETTER
|
251 |
+
|
252 |
+
self.normalized = normalized
|
253 |
+
self.dagesh = labels_ids["dagesh"]
|
254 |
+
self.sin = labels_ids["sin"]
|
255 |
+
self.nikud = labels_ids["nikud"]
|
256 |
+
|
257 |
+
def name_of(self, letter):
|
258 |
+
if "ื" <= letter <= "ืช":
|
259 |
+
return letter
|
260 |
+
if letter == Nikud.DAGESH_LETTER:
|
261 |
+
return "ืืืฉ\ืฉืืจืืง"
|
262 |
+
if letter == Nikud.KAMATZ:
|
263 |
+
return "ืงืืฅ"
|
264 |
+
if letter == Nikud.PATAKH:
|
265 |
+
return "ืคืชื"
|
266 |
+
if letter == Nikud.TZEIRE:
|
267 |
+
return "ืฆืืจื"
|
268 |
+
if letter == Nikud.SEGOL:
|
269 |
+
return "ืกืืื"
|
270 |
+
if letter == Nikud.SHVA:
|
271 |
+
return "ืฉืื"
|
272 |
+
if letter == Nikud.HOLAM:
|
273 |
+
return "ืืืื"
|
274 |
+
if letter == Nikud.KUBUTZ:
|
275 |
+
return "ืงืืืืฅ"
|
276 |
+
if letter == Nikud.HIRIK:
|
277 |
+
return "ืืืจืืง"
|
278 |
+
if letter == Nikud.REDUCED_KAMATZ:
|
279 |
+
return "ืืืฃ-ืงืืฅ"
|
280 |
+
if letter == Nikud.REDUCED_PATAKH:
|
281 |
+
return "ืืืฃ-ืคืชื"
|
282 |
+
if letter == Nikud.REDUCED_SEGOL:
|
283 |
+
return "ืืืฃ-ืกืืื"
|
284 |
+
if letter == Nikud.SHIN_SMALIT:
|
285 |
+
return "ืฉืื-ืฉืืืืืช"
|
286 |
+
if letter == Nikud.SHIN_YEMANIT:
|
287 |
+
return "ืฉืื-ืืื ืืช"
|
288 |
+
if letter.isprintable():
|
289 |
+
return letter
|
290 |
+
return "ืื ืืืืข ({})".format(hex(ord(letter)))
|
291 |
+
|
292 |
+
|
293 |
+
def text_contains_nikud(text):
|
294 |
+
return len(set(text) & Nikud.all_nikud_chr) > 0
|
295 |
+
|
296 |
+
|
297 |
+
def combine_sentences(list_sentences, max_length=0, is_train=False):
|
298 |
+
all_new_sentences = []
|
299 |
+
new_sen = ""
|
300 |
+
index = 0
|
301 |
+
while index < len(list_sentences):
|
302 |
+
sen = list_sentences[index]
|
303 |
+
|
304 |
+
if not text_contains_nikud(sen) and (
|
305 |
+
"------------------" in sen or sen == "\n"
|
306 |
+
):
|
307 |
+
if len(new_sen) > 0:
|
308 |
+
all_new_sentences.append(new_sen)
|
309 |
+
if not is_train:
|
310 |
+
all_new_sentences.append(sen)
|
311 |
+
new_sen = ""
|
312 |
+
index += 1
|
313 |
+
continue
|
314 |
+
|
315 |
+
if not text_contains_nikud(sen) and is_train:
|
316 |
+
index += 1
|
317 |
+
continue
|
318 |
+
|
319 |
+
if len(sen) > max_length:
|
320 |
+
update_sen = sen.replace(". ", f". {unique_key}")
|
321 |
+
update_sen = update_sen.replace("? ", f"? {unique_key}")
|
322 |
+
update_sen = update_sen.replace("! ", f"! {unique_key}")
|
323 |
+
update_sen = update_sen.replace("โ ", f"โ {unique_key}")
|
324 |
+
update_sen = update_sen.replace("\t", f"\t{unique_key}")
|
325 |
+
part_sentence = update_sen.split(unique_key)
|
326 |
+
|
327 |
+
good_parts = []
|
328 |
+
for p in part_sentence:
|
329 |
+
if len(p) < max_length:
|
330 |
+
good_parts.append(p)
|
331 |
+
else:
|
332 |
+
prev = 0
|
333 |
+
while prev <= len(p):
|
334 |
+
part = p[prev : (prev + max_length)]
|
335 |
+
last_space = 0
|
336 |
+
if " " in part:
|
337 |
+
last_space = part[::-1].index(" ") + 1
|
338 |
+
next = prev + max_length - last_space
|
339 |
+
part = p[prev:next]
|
340 |
+
good_parts.append(part)
|
341 |
+
prev = next
|
342 |
+
list_sentences = (
|
343 |
+
list_sentences[:index] + good_parts + list_sentences[index + 1 :]
|
344 |
+
)
|
345 |
+
continue
|
346 |
+
if new_sen == "":
|
347 |
+
new_sen = sen
|
348 |
+
elif len(new_sen) + len(sen) < max_length:
|
349 |
+
new_sen += sen
|
350 |
+
else:
|
351 |
+
all_new_sentences.append(new_sen)
|
352 |
+
new_sen = sen
|
353 |
+
|
354 |
+
index += 1
|
355 |
+
if len(new_sen) > 0:
|
356 |
+
all_new_sentences.append(new_sen)
|
357 |
+
return all_new_sentences
|
358 |
+
|
359 |
+
|
360 |
+
class NikudDataset(Dataset):
|
361 |
+
def __init__(
|
362 |
+
self,
|
363 |
+
tokenizer,
|
364 |
+
folder=None,
|
365 |
+
file=None,
|
366 |
+
logger=None,
|
367 |
+
max_length=0,
|
368 |
+
is_train=False,
|
369 |
+
):
|
370 |
+
self.max_length = max_length
|
371 |
+
self.tokenizer = tokenizer
|
372 |
+
self.is_train = is_train
|
373 |
+
self.data = None
|
374 |
+
self.origin_data = None
|
375 |
+
if folder is not None:
|
376 |
+
self.data, self.origin_data = self.read_data_folder(folder, logger)
|
377 |
+
elif file is not None:
|
378 |
+
self.data, self.origin_data = self.read_data(file, logger)
|
379 |
+
self.prepered_data = None
|
380 |
+
|
381 |
+
def read_data_folder(self, folder_path: str, logger=None):
|
382 |
+
all_files = glob2.glob(f"{folder_path}/**/*.txt", recursive=True)
|
383 |
+
msg = f"number of files: " + str(len(all_files))
|
384 |
+
if logger:
|
385 |
+
logger.debug(msg)
|
386 |
+
else:
|
387 |
+
print(msg)
|
388 |
+
all_data = []
|
389 |
+
all_origin_data = []
|
390 |
+
if DEBUG_MODE:
|
391 |
+
all_files = all_files[0:2]
|
392 |
+
for file in all_files:
|
393 |
+
if "not_use" in file or "NakdanResults" in file:
|
394 |
+
continue
|
395 |
+
data, origin_data = self.read_data(file, logger)
|
396 |
+
all_data.extend(data)
|
397 |
+
all_origin_data.extend(origin_data)
|
398 |
+
return all_data, all_origin_data
|
399 |
+
|
400 |
+
def read_data(self, filepath: str, logger=None) -> List[Tuple[str, list]]:
|
401 |
+
msg = f"read file: {filepath}"
|
402 |
+
if logger:
|
403 |
+
logger.debug(msg)
|
404 |
+
else:
|
405 |
+
print(msg)
|
406 |
+
data = []
|
407 |
+
orig_data = []
|
408 |
+
with open(filepath, "r", encoding="utf-8") as file:
|
409 |
+
file_data = file.read()
|
410 |
+
data_list = self.split_text(file_data)
|
411 |
+
|
412 |
+
for sen in tqdm(data_list, desc=f"Source: {os.path.basename(filepath)}"):
|
413 |
+
if sen == "":
|
414 |
+
continue
|
415 |
+
|
416 |
+
labels = []
|
417 |
+
text = ""
|
418 |
+
text_org = ""
|
419 |
+
index = 0
|
420 |
+
sentence_length = len(sen)
|
421 |
+
while index < sentence_length:
|
422 |
+
if (
|
423 |
+
ord(sen[index]) == Nikud.nikud_dict["PUNCTUATION MAQAF"]
|
424 |
+
or ord(sen[index]) == Nikud.nikud_dict["PUNCTUATION PASEQ"]
|
425 |
+
or ord(sen[index]) == Nikud.nikud_dict["METEG"]
|
426 |
+
):
|
427 |
+
index += 1
|
428 |
+
continue
|
429 |
+
|
430 |
+
label = []
|
431 |
+
l = Letter(sen[index])
|
432 |
+
if not (l.letter not in Nikud.all_nikud_chr):
|
433 |
+
if sen[index - 1] == "\n":
|
434 |
+
index += 1
|
435 |
+
continue
|
436 |
+
assert l.letter not in Nikud.all_nikud_chr
|
437 |
+
if sen[index] in Letters.hebrew:
|
438 |
+
index += 1
|
439 |
+
while (
|
440 |
+
index < sentence_length
|
441 |
+
and ord(sen[index]) in Nikud.all_nikud_ord
|
442 |
+
):
|
443 |
+
label.append(ord(sen[index]))
|
444 |
+
index += 1
|
445 |
+
else:
|
446 |
+
index += 1
|
447 |
+
|
448 |
+
l.get_label_letter(label)
|
449 |
+
text += l.normalized
|
450 |
+
text_org += l.letter
|
451 |
+
labels.append(l)
|
452 |
+
|
453 |
+
data.append((text, labels))
|
454 |
+
orig_data.append(text_org)
|
455 |
+
|
456 |
+
return data, orig_data
|
457 |
+
|
458 |
+
def read_single_text(self, text: str, logger=None) -> List[Tuple[str, list]]:
|
459 |
+
# msg = f"read file: {filepath}"
|
460 |
+
# if logger:
|
461 |
+
# logger.debug(msg)
|
462 |
+
# else:
|
463 |
+
# print(msg)
|
464 |
+
data = []
|
465 |
+
orig_data = []
|
466 |
+
# with open(filepath, "r", encoding="utf-8") as file:
|
467 |
+
# file_data = file.read()
|
468 |
+
data_list = self.split_text(text)
|
469 |
+
# print("data_list", data_list)
|
470 |
+
for sen in tqdm(data_list, desc=f"Source: {data}"):
|
471 |
+
if sen == "":
|
472 |
+
continue
|
473 |
+
|
474 |
+
labels = []
|
475 |
+
text = ""
|
476 |
+
text_org = ""
|
477 |
+
index = 0
|
478 |
+
sentence_length = len(sen)
|
479 |
+
while index < sentence_length:
|
480 |
+
if (
|
481 |
+
ord(sen[index]) == Nikud.nikud_dict["PUNCTUATION MAQAF"]
|
482 |
+
or ord(sen[index]) == Nikud.nikud_dict["PUNCTUATION PASEQ"]
|
483 |
+
or ord(sen[index]) == Nikud.nikud_dict["METEG"]
|
484 |
+
):
|
485 |
+
index += 1
|
486 |
+
continue
|
487 |
+
|
488 |
+
label = []
|
489 |
+
l = Letter(sen[index])
|
490 |
+
if not (l.letter not in Nikud.all_nikud_chr):
|
491 |
+
if sen[index - 1] == "\n":
|
492 |
+
index += 1
|
493 |
+
continue
|
494 |
+
assert l.letter not in Nikud.all_nikud_chr
|
495 |
+
if sen[index] in Letters.hebrew:
|
496 |
+
index += 1
|
497 |
+
while (
|
498 |
+
index < sentence_length
|
499 |
+
and ord(sen[index]) in Nikud.all_nikud_ord
|
500 |
+
):
|
501 |
+
label.append(ord(sen[index]))
|
502 |
+
index += 1
|
503 |
+
else:
|
504 |
+
index += 1
|
505 |
+
|
506 |
+
l.get_label_letter(label)
|
507 |
+
text += l.normalized
|
508 |
+
text_org += l.letter
|
509 |
+
labels.append(l)
|
510 |
+
|
511 |
+
data.append((text, labels))
|
512 |
+
orig_data.append(text_org)
|
513 |
+
self.data = data
|
514 |
+
self.origin_data = orig_data
|
515 |
+
return data, orig_data
|
516 |
+
|
517 |
+
def split_text(self, file_data):
|
518 |
+
file_data = file_data.replace("\n", f"\n{unique_key}")
|
519 |
+
data_list = file_data.split(unique_key)
|
520 |
+
data_list = combine_sentences(
|
521 |
+
data_list, is_train=self.is_train, max_length=MAX_LENGTH_SEN
|
522 |
+
)
|
523 |
+
return data_list
|
524 |
+
|
525 |
+
def show_data_labels(self, plots_folder=None):
|
526 |
+
nikud = [
|
527 |
+
Nikud.id_2_label["nikud"][label.nikud]
|
528 |
+
for _, label_list in self.data
|
529 |
+
for label in label_list
|
530 |
+
if label.nikud != -1
|
531 |
+
]
|
532 |
+
dagesh = [
|
533 |
+
Nikud.id_2_label["dagesh"][label.dagesh]
|
534 |
+
for _, label_list in self.data
|
535 |
+
for label in label_list
|
536 |
+
if label.dagesh != -1
|
537 |
+
]
|
538 |
+
sin = [
|
539 |
+
Nikud.id_2_label["sin"][label.sin]
|
540 |
+
for _, label_list in self.data
|
541 |
+
for label in label_list
|
542 |
+
if label.sin != -1
|
543 |
+
]
|
544 |
+
|
545 |
+
vowels = nikud + dagesh + sin
|
546 |
+
unique_vowels, label_counts = np.unique(vowels, return_counts=True)
|
547 |
+
unique_vowels_names = [
|
548 |
+
Nikud.sign_2_name[int(vowel)]
|
549 |
+
for vowel in unique_vowels
|
550 |
+
if vowel != "WITHOUT"
|
551 |
+
] + ["WITHOUT"]
|
552 |
+
fig, ax = plt.subplots(figsize=(16, 6))
|
553 |
+
|
554 |
+
bar_positions = np.arange(len(unique_vowels))
|
555 |
+
bar_width = 0.15
|
556 |
+
ax.bar(bar_positions, list(label_counts), bar_width)
|
557 |
+
|
558 |
+
ax.set_title("Distribution of Vowels in dataset")
|
559 |
+
ax.set_xlabel("Vowels")
|
560 |
+
ax.set_ylabel("Count")
|
561 |
+
ax.legend(loc="right", bbox_to_anchor=(1, 0.85))
|
562 |
+
ax.set_xticks(bar_positions)
|
563 |
+
ax.set_xticklabels(unique_vowels_names, rotation=30, ha="right", fontsize=8)
|
564 |
+
|
565 |
+
if plots_folder is None:
|
566 |
+
plt.show()
|
567 |
+
else:
|
568 |
+
plt.savefig(os.path.join(plots_folder, "show_data_labels.jpg"))
|
569 |
+
|
570 |
+
def calc_max_length(self, maximum=MAX_LENGTH_SEN):
|
571 |
+
if self.max_length > maximum:
|
572 |
+
self.max_length = maximum
|
573 |
+
return self.max_length
|
574 |
+
|
575 |
+
def prepare_data(self, name="train"):
|
576 |
+
dataset = []
|
577 |
+
for index, (sentence, label) in tqdm(
|
578 |
+
enumerate(self.data), desc=f"prepare data {name}"
|
579 |
+
):
|
580 |
+
encoded_sequence = self.tokenizer.encode_plus(
|
581 |
+
sentence,
|
582 |
+
add_special_tokens=True,
|
583 |
+
max_length=self.max_length,
|
584 |
+
padding="max_length",
|
585 |
+
truncation=True,
|
586 |
+
return_attention_mask=True,
|
587 |
+
return_tensors="pt",
|
588 |
+
)
|
589 |
+
label_lists = [
|
590 |
+
[letter.nikud, letter.dagesh, letter.sin] for letter in label
|
591 |
+
]
|
592 |
+
label = torch.tensor(
|
593 |
+
[
|
594 |
+
[
|
595 |
+
Nikud.PAD_OR_IRRELEVANT,
|
596 |
+
Nikud.PAD_OR_IRRELEVANT,
|
597 |
+
Nikud.PAD_OR_IRRELEVANT,
|
598 |
+
]
|
599 |
+
]
|
600 |
+
+ label_lists[: (self.max_length - 1)]
|
601 |
+
+ [
|
602 |
+
[
|
603 |
+
Nikud.PAD_OR_IRRELEVANT,
|
604 |
+
Nikud.PAD_OR_IRRELEVANT,
|
605 |
+
Nikud.PAD_OR_IRRELEVANT,
|
606 |
+
]
|
607 |
+
for i in range(self.max_length - len(label) - 1)
|
608 |
+
]
|
609 |
+
)
|
610 |
+
|
611 |
+
dataset.append(
|
612 |
+
(
|
613 |
+
encoded_sequence["input_ids"][0],
|
614 |
+
encoded_sequence["attention_mask"][0],
|
615 |
+
label,
|
616 |
+
)
|
617 |
+
)
|
618 |
+
|
619 |
+
self.prepered_data = dataset
|
620 |
+
|
621 |
+
def back_2_text(self, labels):
|
622 |
+
nikud = Nikud()
|
623 |
+
all_text = ""
|
624 |
+
for indx_sentance, (input_ids, _, label) in enumerate(self.prepered_data):
|
625 |
+
new_line = ""
|
626 |
+
for indx_char, c in enumerate(self.origin_data[indx_sentance]):
|
627 |
+
new_line += (
|
628 |
+
c
|
629 |
+
+ nikud.id_2_char(labels[indx_sentance, indx_char + 1, 1], "dagesh")
|
630 |
+
+ nikud.id_2_char(labels[indx_sentance, indx_char + 1, 2], "sin")
|
631 |
+
+ nikud.id_2_char(labels[indx_sentance, indx_char + 1, 0], "nikud")
|
632 |
+
)
|
633 |
+
all_text += new_line
|
634 |
+
return all_text
|
635 |
+
|
636 |
+
def __len__(self):
|
637 |
+
return self.data.shape[0]
|
638 |
+
|
639 |
+
def __getitem__(self, idx):
|
640 |
+
row = self.data[idx]
|
641 |
+
|
642 |
+
|
643 |
+
def get_sub_folders_paths(main_folder):
|
644 |
+
list_paths = []
|
645 |
+
for filename in os.listdir(main_folder):
|
646 |
+
path = os.path.join(main_folder, filename)
|
647 |
+
if os.path.isdir(path) and filename != ".git":
|
648 |
+
list_paths.append(path)
|
649 |
+
list_paths.extend(get_sub_folders_paths(path))
|
650 |
+
return list_paths
|
651 |
+
|
652 |
+
|
653 |
+
def create_missing_folders(folder_path):
|
654 |
+
# Check if the folder doesn't exist and create it if needed
|
655 |
+
if not os.path.exists(folder_path):
|
656 |
+
os.makedirs(folder_path)
|
657 |
+
|
658 |
+
|
659 |
+
def info_folder(folder, num_files, num_hebrew_letters):
|
660 |
+
"""
|
661 |
+
Recursively counts the number of files and the number of Hebrew letters in all subfolders of the given folder path.
|
662 |
+
|
663 |
+
Args:
|
664 |
+
folder (str): The path of the folder to be analyzed.
|
665 |
+
num_files (int): The running total of the number of files encountered so far.
|
666 |
+
num_hebrew_letters (int): The running total of the number of Hebrew letters encountered so far.
|
667 |
+
|
668 |
+
Returns:
|
669 |
+
Tuple[int, int]: A tuple containing the total number of files and the total number of Hebrew letters.
|
670 |
+
"""
|
671 |
+
for filename in os.listdir(folder):
|
672 |
+
file_path = os.path.join(folder, filename)
|
673 |
+
if filename.lower().endswith(".txt") and os.path.isfile(file_path):
|
674 |
+
num_files += 1
|
675 |
+
dataset = NikudDataset(None, file=file_path)
|
676 |
+
for line in dataset.data:
|
677 |
+
for c in line[0]:
|
678 |
+
if c in Letters.hebrew:
|
679 |
+
num_hebrew_letters += 1
|
680 |
+
|
681 |
+
elif os.path.isdir(file_path) and filename != ".git":
|
682 |
+
sub_folder = file_path
|
683 |
+
n1, n2 = info_folder(sub_folder, num_files, num_hebrew_letters)
|
684 |
+
num_files += n1
|
685 |
+
num_hebrew_letters += n2
|
686 |
+
return num_files, num_hebrew_letters
|
687 |
+
|
688 |
+
|
689 |
+
def extract_text_to_compare_nakdimon(text):
|
690 |
+
res = text.replace("|", "")
|
691 |
+
res = res.replace(
|
692 |
+
chr(Nikud.nikud_dict["KUBUTZ"]) + "ื" + chr(Nikud.nikud_dict["METEG"]),
|
693 |
+
"ื" + chr(Nikud.nikud_dict["DAGESH OR SHURUK"]),
|
694 |
+
)
|
695 |
+
res = res.replace(
|
696 |
+
chr(Nikud.nikud_dict["HOLAM"]) + "ื" + chr(Nikud.nikud_dict["METEG"]), "ื"
|
697 |
+
)
|
698 |
+
res = res.replace(
|
699 |
+
"ื" + chr(Nikud.nikud_dict["HOLAM"]) + chr(Nikud.nikud_dict["KAMATZ"]),
|
700 |
+
"ื" + chr(Nikud.nikud_dict["KAMATZ"]),
|
701 |
+
)
|
702 |
+
res = res.replace(chr(Nikud.nikud_dict["METEG"]), "")
|
703 |
+
res = res.replace(
|
704 |
+
chr(Nikud.nikud_dict["KAMATZ"]) + chr(Nikud.nikud_dict["HIRIK"]),
|
705 |
+
chr(Nikud.nikud_dict["KAMATZ"]) + "ื" + chr(Nikud.nikud_dict["HIRIK"]),
|
706 |
+
)
|
707 |
+
res = res.replace(
|
708 |
+
chr(Nikud.nikud_dict["PATAKH"]) + chr(Nikud.nikud_dict["HIRIK"]),
|
709 |
+
chr(Nikud.nikud_dict["PATAKH"]) + "ื" + chr(Nikud.nikud_dict["HIRIK"]),
|
710 |
+
)
|
711 |
+
res = res.replace(chr(Nikud.nikud_dict["PUNCTUATION MAQAF"]), "")
|
712 |
+
res = res.replace(chr(Nikud.nikud_dict["PUNCTUATION PASEQ"]), "")
|
713 |
+
res = res.replace(
|
714 |
+
chr(Nikud.nikud_dict["KAMATZ_KATAN"]), chr(Nikud.nikud_dict["KAMATZ"])
|
715 |
+
)
|
716 |
+
|
717 |
+
res = re.sub(chr(Nikud.nikud_dict["KUBUTZ"]) + "ื" + "(?=[ื-ืช])", "ื", res)
|
718 |
+
res = res.replace(chr(Nikud.nikud_dict["REDUCED_KAMATZ"]) + "ื", "ื")
|
719 |
+
|
720 |
+
res = res.replace(
|
721 |
+
chr(Nikud.nikud_dict["DAGESH OR SHURUK"]) * 2,
|
722 |
+
chr(Nikud.nikud_dict["DAGESH OR SHURUK"]),
|
723 |
+
)
|
724 |
+
res = res.replace("\u05be", "-")
|
725 |
+
res = res.replace("ืึฐืืึนึธื", "ืืืื")
|
726 |
+
|
727 |
+
return res
|
728 |
+
|
729 |
+
|
730 |
+
def orgenize_data(main_folder, logger):
|
731 |
+
x = NikudDataset(None)
|
732 |
+
x.delete_files(os.path.join(Path(main_folder).parent, "train"))
|
733 |
+
x.delete_files(os.path.join(Path(main_folder).parent, "dev"))
|
734 |
+
x.delete_files(os.path.join(Path(main_folder).parent, "test"))
|
735 |
+
x.split_data(
|
736 |
+
main_folder, main_folder_name=os.path.basename(main_folder), logger=logger
|
737 |
+
)
|