import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

from src.running_params import DEBUG_MODE
from src.utiles_data import Nikud, create_missing_folders

CLASSES_LIST = ["nikud", "dagesh", "sin"]
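# Each letter gets three independent predictions: "nikud" (the vowel mark),
# "dagesh" (the gemination dot) and "sin" (the shin/sin dot).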


def calc_num_correct_words(input_ids, letter_correct_mask):
    """Count whole words whose letters were all predicted correctly.

    Word boundaries are the space, start-of-sentence and end-of-sentence
    tokens, which are zeroed in place in ``input_ids``. A word counts as
    correct only if every letter inside it is True in ``letter_correct_mask``.
    """
    SPACE_TOKEN = 104
    START_SENTENCE_TOKEN = 1
    END_SENTENCE_TOKEN = 2

    correct_words_count = 0
    words_count = 0
    for index in range(input_ids.shape[0]):
        # Zero out separator tokens so every word is delimited by 0s.
        input_ids[index][np.where(input_ids[index] == SPACE_TOKEN)[0]] = 0
        input_ids[index][np.where(input_ids[index] == START_SENTENCE_TOKEN)[0]] = 0
        input_ids[index][np.where(input_ids[index] == END_SENTENCE_TOKEN)[0]] = 0
        words_end_index = np.concatenate(
            (np.array([-1]), np.where(input_ids[index] == 0)[0])
        )
        # A word is correct when all of its letters are correct; spans
        # shorter than two tokens are skipped.
        is_correct_words_array = [
            bool(
                letter_correct_mask[index][
                    list(range((words_end_index[s] + 1), words_end_index[s + 1]))
                ].all()
            )
            for s in range(len(words_end_index) - 1)
            if words_end_index[s + 1] - (words_end_index[s] + 1) > 1
        ]
        correct_words_count += np.array(is_correct_words_array).sum()
        words_count += len(is_correct_words_array)

    return correct_words_count, words_count
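
# A minimal illustrative sketch (not part of the pipeline): token ids 1 and 2
# open and close the sentence, and id 0 delimits words after masking.
#
#     ids = torch.tensor([[1, 7, 8, 0, 9, 10, 2]])  # two 2-letter words
#     mask = torch.tensor([[True, True, True, True, True, False, True]])
#     calc_num_correct_words(ids, mask)  # -> (1, 2): second word has an error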


def predict(model, data_loader, device="cpu"):
    """Run inference over a DataLoader and return predicted labels.

    Positions whose labels are -1 (no diacritic possible there) are masked
    back to -1 in the output.
    """
    model.to(device)
    model.eval()

    all_labels = None
    with torch.no_grad():
        for index_data, data in enumerate(data_loader):
            (inputs, attention_mask, labels_demo) = data
            inputs = inputs.to(device)
            attention_mask = attention_mask.to(device)
            labels_demo = labels_demo.to(device)

            # Positions that can never carry a given diacritic are marked -1.
            mask_cant_be_nikud = np.array(labels_demo.cpu())[:, :, 0] == -1
            mask_cant_be_dagesh = np.array(labels_demo.cpu())[:, :, 1] == -1
            mask_cant_be_sin = np.array(labels_demo.cpu())[:, :, 2] == -1

            nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)

            # Argmax over the class dimension, reshaped to (batch, seq, 1).
            pred_nikud = np.array(torch.max(nikud_probs, 2).indices.cpu()).reshape(
                inputs.shape[0], inputs.shape[1], 1
            )
            pred_dagesh = np.array(torch.max(dagesh_probs, 2).indices.cpu()).reshape(
                inputs.shape[0], inputs.shape[1], 1
            )
            pred_sin = np.array(torch.max(sin_probs, 2).indices.cpu()).reshape(
                inputs.shape[0], inputs.shape[1], 1
            )

            pred_nikud[mask_cant_be_nikud] = -1
            pred_dagesh[mask_cant_be_dagesh] = -1
            pred_sin[mask_cant_be_sin] = -1

            pred_labels = np.concatenate((pred_nikud, pred_dagesh, pred_sin), axis=2)

            if all_labels is None:
                all_labels = pred_labels
            else:
                all_labels = np.concatenate((all_labels, pred_labels), axis=0)

    return all_labels
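
# Hedged usage sketch, assuming a trained model and a DataLoader yielding
# (input_ids, attention_mask, labels) batches (names are illustrative):
#
#     preds = predict(model, test_loader, device="cuda")
#     # preds has shape (num_examples, seq_len, 3): nikud, dagesh, sin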


def predict_single(model, data, device="cpu"):
    """Run inference on a single (inputs, attention_mask, labels) batch."""
    model.to(device)
    model.eval()

    with torch.no_grad():
        (inputs, attention_mask, labels_demo) = data
        inputs = inputs.to(device)
        attention_mask = attention_mask.to(device)
        labels_demo = labels_demo.to(device)

        mask_cant_be_nikud = np.array(labels_demo.cpu())[:, :, 0] == -1
        mask_cant_be_dagesh = np.array(labels_demo.cpu())[:, :, 1] == -1
        mask_cant_be_sin = np.array(labels_demo.cpu())[:, :, 2] == -1

        nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)

        pred_nikud = np.array(torch.max(nikud_probs, 2).indices.cpu()).reshape(
            inputs.shape[0], inputs.shape[1], 1
        )
        pred_dagesh = np.array(torch.max(dagesh_probs, 2).indices.cpu()).reshape(
            inputs.shape[0], inputs.shape[1], 1
        )
        pred_sin = np.array(torch.max(sin_probs, 2).indices.cpu()).reshape(
            inputs.shape[0], inputs.shape[1], 1
        )

        pred_nikud[mask_cant_be_nikud] = -1
        pred_dagesh[mask_cant_be_dagesh] = -1
        pred_sin[mask_cant_be_sin] = -1

        pred_labels = np.concatenate((pred_nikud, pred_dagesh, pred_sin), axis=2)

    return pred_labels
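
# predict_single mirrors predict() for one pre-collated batch, e.g.:
#
#     batch = next(iter(test_loader))
#     labels = predict_single(model, batch, device="cpu")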


def training(
    model,
    train_loader,
    dev_loader,
    criterion_nikud,
    criterion_dagesh,
    criterion_sin,
    training_params,
    logger,
    output_model_path,
    optimizer,
    device="cpu",
):
    max_length = None
    best_accuracy = 0.0
    best_model = None

    logger.info(f"start training with training_params: {training_params}")
    model = model.to(device)

    criteria = {
        "nikud": criterion_nikud.to(device),
        "dagesh": criterion_dagesh.to(device),
        "sin": criterion_sin.to(device),
    }
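
    # The criteria are expected to ignore the -1 padding label; a hedged
    # sketch of how a caller might construct them (an assumption, not
    # enforced here):
    #
    #     criterion_nikud = torch.nn.CrossEntropyLoss(ignore_index=-1)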

    output_checkpoints_path = os.path.join(output_model_path, "checkpoints")
    create_missing_folders(output_checkpoints_path)

    train_steps_loss_values = {"nikud": [], "dagesh": [], "sin": []}
    train_epochs_loss_values = {"nikud": [], "dagesh": [], "sin": []}
    dev_loss_values = {"nikud": [], "dagesh": [], "sin": []}
    dev_accuracy_values = {
        "nikud": [],
        "dagesh": [],
        "sin": [],
        "all_nikud_letter": [],
        "all_nikud_word": [],
    }

    for epoch in tqdm(range(training_params["n_epochs"]), desc="Training"):
        model.train()
        train_loss = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
        relevant_count = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}

        for index_data, data in enumerate(train_loader):
            (inputs, attention_mask, labels) = data

            if max_length is None:
                max_length = labels.shape[1]

            inputs = inputs.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)

            for i, (probs, class_name) in enumerate(
                zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
            ):
                # CrossEntropyLoss expects (batch, num_classes, seq_len).
                reshaped_tensor = (
                    torch.transpose(probs, 1, 2)
                    .contiguous()
                    .view(probs.shape[0], probs.shape[2], probs.shape[1])
                )
                loss = criteria[class_name](reshaped_tensor, labels[:, :, i]).to(device)

                # Weight the running loss by the number of unmasked positions.
                num_relevant = (labels[:, :, i] != -1).sum()
                train_loss[class_name] += loss.item() * num_relevant
                relevant_count[class_name] += num_relevant

                # retain_graph keeps the shared encoder graph alive for the
                # backward passes of the remaining heads.
                loss.backward(retain_graph=True)

            for i, class_name in enumerate(CLASSES_LIST):
                train_steps_loss_values[class_name].append(
                    float(train_loss[class_name] / relevant_count[class_name])
                )

            optimizer.step()
            if (index_data + 1) % 100 == 0:
                msg = f"epoch: {epoch} , index_data: {index_data + 1}\n"
                for i, class_name in enumerate(CLASSES_LIST):
                    msg += f"mean loss train {class_name}: {float(train_loss[class_name] / relevant_count[class_name])}, "

                logger.debug(msg[:-2])

        for i, class_name in enumerate(CLASSES_LIST):
            train_epochs_loss_values[class_name].append(
                float(train_loss[class_name] / relevant_count[class_name])
            )

        for class_name in train_loss.keys():
            train_loss[class_name] /= relevant_count[class_name]

        msg = f"Epoch {epoch + 1}/{training_params['n_epochs']}\n"
        for i, class_name in enumerate(CLASSES_LIST):
            msg += f"mean loss train {class_name}: {train_loss[class_name]}, "
        logger.debug(msg[:-2])

        model.eval()
        dev_loss = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
        dev_accuracy = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
        relevant_count = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
        correct_preds = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
        un_masks = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
        predictions = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
        labels_class = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}

        all_nikud_types_correct_preds_letter = 0.0

        letter_count = 0.0
        correct_words_count = 0.0
        word_count = 0.0
        with torch.no_grad():
            for index_data, data in enumerate(dev_loader):
                (inputs, attention_mask, labels) = data
                inputs = inputs.to(device)
                attention_mask = attention_mask.to(device)
                labels = labels.to(device)

                nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)

                for i, (probs, class_name) in enumerate(
                    zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
                ):
                    reshaped_tensor = (
                        torch.transpose(probs, 1, 2)
                        .contiguous()
                        .view(probs.shape[0], probs.shape[2], probs.shape[1])
                    )
                    loss = criteria[class_name](reshaped_tensor, labels[:, :, i]).to(
                        device
                    )
                    un_masked = labels[:, :, i] != -1
                    num_relevant = un_masked.sum()
                    relevant_count[class_name] += num_relevant
                    _, preds = torch.max(probs, 2)
                    dev_loss[class_name] += loss.item() * num_relevant
                    correct_preds[class_name] += torch.sum(
                        preds[un_masked] == labels[:, :, i][un_masked]
                    )
                    un_masks[class_name] = un_masked
                    predictions[class_name] = preds
                    labels_class[class_name] = labels[:, :, i]

                # A letter is evaluated if at least one diacritic class
                # applies to it.
                un_mask_all_or = torch.logical_or(
                    torch.logical_or(un_masks["nikud"], un_masks["dagesh"]),
                    un_masks["sin"],
                )

                # Masked positions default to "correct" so they do not
                # penalize the combined letter-level accuracy.
                correct = {
                    class_name: (torch.ones(un_mask_all_or.shape) == 1).to(device)
                    for class_name in CLASSES_LIST
                }

                for i, class_name in enumerate(CLASSES_LIST):
                    correct[class_name][un_masks[class_name]] = (
                        predictions[class_name][un_masks[class_name]]
                        == labels_class[class_name][un_masks[class_name]]
                    )

                # A letter is fully correct only if all three heads agree
                # with the gold labels.
                letter_correct_mask = torch.logical_and(
                    torch.logical_and(correct["sin"], correct["dagesh"]),
                    correct["nikud"],
                )
                all_nikud_types_correct_preds_letter += torch.sum(
                    letter_correct_mask[un_mask_all_or]
                )

                letter_correct_mask[~un_mask_all_or] = True
                correct_num, total_words_num = calc_num_correct_words(
                    inputs.cpu(), letter_correct_mask
                )

                word_count += total_words_num
                correct_words_count += correct_num
                letter_count += un_mask_all_or.sum()

        for class_name in CLASSES_LIST:
            dev_loss[class_name] /= relevant_count[class_name]
            dev_accuracy[class_name] = float(
                correct_preds[class_name].double() / relevant_count[class_name]
            )

            dev_loss_values[class_name].append(float(dev_loss[class_name]))
            dev_accuracy_values[class_name].append(float(dev_accuracy[class_name]))

        dev_all_nikud_types_accuracy_letter = float(
            all_nikud_types_correct_preds_letter / letter_count
        )

        dev_accuracy_values["all_nikud_letter"].append(
            dev_all_nikud_types_accuracy_letter
        )

        word_all_nikud_accuracy = correct_words_count / word_count
        dev_accuracy_values["all_nikud_word"].append(word_all_nikud_accuracy)

        msg = (
            f"Epoch {epoch + 1}/{training_params['n_epochs']}\n"
            f'mean loss Dev nikud: {dev_loss["nikud"]}, '
            f'mean loss Dev dagesh: {dev_loss["dagesh"]}, '
            f'mean loss Dev sin: {dev_loss["sin"]}, '
            f"Dev all nikud types letter Accuracy: {dev_all_nikud_types_accuracy_letter}, "
            f'Dev nikud letter Accuracy: {dev_accuracy["nikud"]}, '
            f'Dev dagesh letter Accuracy: {dev_accuracy["dagesh"]}, '
            f'Dev sin letter Accuracy: {dev_accuracy["sin"]}, '
            f"Dev word Accuracy: {word_all_nikud_accuracy}"
        )
        logger.debug(msg)

        save_progress_details(
            dev_accuracy_values,
            train_epochs_loss_values,
            dev_loss_values,
            train_steps_loss_values,
        )

        if dev_all_nikud_types_accuracy_letter > best_accuracy:
            best_accuracy = dev_all_nikud_types_accuracy_letter
            best_model = {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
            }

        if epoch % training_params["checkpoints_frequency"] == 0:
            save_checkpoint_path = os.path.join(
                output_checkpoints_path, f"checkpoint_model_epoch_{epoch + 1}.pth"
            )
            checkpoint = {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
            }
            # Note: only the model weights are persisted; epoch, optimizer
            # state and loss stay in the in-memory checkpoint dict.
            torch.save(checkpoint["model_state_dict"], save_checkpoint_path)

    if best_model is not None:
        save_model_path = os.path.join(output_model_path, "best_model.pth")
        torch.save(best_model["model_state_dict"], save_model_path)
    return (
        best_model,
        best_accuracy,
        train_epochs_loss_values,
        train_steps_loss_values,
        dev_loss_values,
        dev_accuracy_values,
    )
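
# Hedged usage sketch, assuming an optimizer and per-head losses built by the
# caller (names here are illustrative, not part of this module):
#
#     criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#     best_model, best_acc, *curves = training(
#         model, train_loader, dev_loader,
#         criterion, criterion, criterion,
#         {"n_epochs": 10, "checkpoints_frequency": 1},
#         logger, "models/out", optimizer, device="cuda",
#     )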


def save_progress_details(
    accuracy_dev_values,
    epochs_loss_train_values,
    loss_dev_values,
    steps_loss_train_values,
):
    epochs_data_path = "epochs_data"
    create_missing_folders(epochs_data_path)

    save_dict_as_json(
        steps_loss_train_values, epochs_data_path, "steps_loss_train_values.json"
    )
    save_dict_as_json(
        epochs_loss_train_values, epochs_data_path, "epochs_loss_train_values.json"
    )
    save_dict_as_json(loss_dev_values, epochs_data_path, "loss_dev_values.json")
    save_dict_as_json(accuracy_dev_values, epochs_data_path, "accuracy_dev_values.json")


def save_dict_as_json(data, file_path, file_name):
    json_data = json.dumps(data, indent=4)
    with open(os.path.join(file_path, file_name), "w") as json_file:
        json_file.write(json_data)


def evaluate(model, test_data, plots_folder=None, device="cpu"):
    """Evaluate the model and report letter- and word-level accuracy.

    Saves (or shows) one confusion matrix per diacritic class.
    """
    model.to(device)
    model.eval()

    true_labels = {"nikud": [], "dagesh": [], "sin": []}
    predictions = {"nikud": 0, "dagesh": 0, "sin": 0}
    predicted_labels_2_report = {"nikud": [], "dagesh": [], "sin": []}
    not_masks = {"nikud": 0, "dagesh": 0, "sin": 0}
    correct_preds = {"nikud": 0, "dagesh": 0, "sin": 0}
    relevant_count = {"nikud": 0, "dagesh": 0, "sin": 0}
    labels_class = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}

    all_nikud_types_letter_level_correct = 0.0
    nikud_letter_level_correct = 0.0
    dagesh_letter_level_correct = 0.0
    sin_letter_level_correct = 0.0

    letters_count = 0.0
    words_count = 0.0
    correct_words_count = 0.0
    with torch.no_grad():
        for index_data, data in enumerate(test_data):
            if DEBUG_MODE and index_data > 100:
                break

            (inputs, attention_mask, labels) = data

            inputs = inputs.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)

            for i, (probs, class_name) in enumerate(
                zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
            ):
                labels_class[class_name] = labels[:, :, i]
                not_masked = labels_class[class_name] != -1
                num_relevant = not_masked.sum()
                relevant_count[class_name] += num_relevant
                _, preds = torch.max(probs, 2)
                correct_preds[class_name] += torch.sum(
                    preds[not_masked] == labels_class[class_name][not_masked]
                )
                predictions[class_name] = preds
                not_masks[class_name] = not_masked

                # Accumulate gold and predicted labels for the per-class
                # confusion matrices.
                if len(true_labels[class_name]) == 0:
                    true_labels[class_name] = (
                        labels_class[class_name][not_masked].cpu().numpy()
                    )
                else:
                    true_labels[class_name] = np.concatenate(
                        (
                            true_labels[class_name],
                            labels_class[class_name][not_masked].cpu().numpy(),
                        )
                    )

                if len(predicted_labels_2_report[class_name]) == 0:
                    predicted_labels_2_report[class_name] = (
                        preds[not_masked].cpu().numpy()
                    )
                else:
                    predicted_labels_2_report[class_name] = np.concatenate(
                        (
                            predicted_labels_2_report[class_name],
                            preds[not_masked].cpu().numpy(),
                        )
                    )

            not_mask_all_or = torch.logical_or(
                torch.logical_or(not_masks["nikud"], not_masks["dagesh"]),
                not_masks["sin"],
            )

            # Masked positions default to "correct" so only real letters
            # affect the combined accuracies.
            correct_nikud = (torch.ones(not_mask_all_or.shape) == 1).to(device)
            correct_dagesh = (torch.ones(not_mask_all_or.shape) == 1).to(device)
            correct_sin = (torch.ones(not_mask_all_or.shape) == 1).to(device)

            correct_nikud[not_masks["nikud"]] = (
                predictions["nikud"][not_masks["nikud"]]
                == labels_class["nikud"][not_masks["nikud"]]
            )
            correct_dagesh[not_masks["dagesh"]] = (
                predictions["dagesh"][not_masks["dagesh"]]
                == labels_class["dagesh"][not_masks["dagesh"]]
            )
            correct_sin[not_masks["sin"]] = (
                predictions["sin"][not_masks["sin"]]
                == labels_class["sin"][not_masks["sin"]]
            )

            # A letter is fully correct only if all three heads are correct.
            letter_correct_mask = torch.logical_and(
                torch.logical_and(correct_sin, correct_dagesh), correct_nikud
            )
            all_nikud_types_letter_level_correct += torch.sum(
                letter_correct_mask[not_mask_all_or]
            )

            letter_correct_mask[~not_mask_all_or] = True
            total_correct_count, total_words_num = calc_num_correct_words(
                inputs.cpu(), letter_correct_mask
            )

            words_count += total_words_num
            correct_words_count += total_correct_count

            letters_count += not_mask_all_or.sum()

            nikud_letter_level_correct += torch.sum(correct_nikud[not_mask_all_or])
            dagesh_letter_level_correct += torch.sum(correct_dagesh[not_mask_all_or])
            sin_letter_level_correct += torch.sum(correct_sin[not_mask_all_or])

    for i, name in enumerate(CLASSES_LIST):
        index_labels = np.unique(true_labels[name])
        cm = confusion_matrix(
            true_labels[name], predicted_labels_2_report[name], labels=index_labels
        )

        # Keep the display names in the same order as index_labels so the
        # heatmap axes stay aligned with the matrix rows and columns.
        vowel_label = [Nikud.id_2_label[name][l] for l in index_labels]
        unique_vowels_names = [
            "WITHOUT" if vowel == "WITHOUT" else Nikud.sign_2_name[int(vowel)]
            for vowel in vowel_label
        ]
        cm_df = pd.DataFrame(cm, index=unique_vowels_names, columns=unique_vowels_names)

        # confusion_matrix puts true labels on rows and predictions on
        # columns, so the y-axis is "True" and the x-axis "Predicted".
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm_df, annot=True, cmap="Blues", fmt="d")
        plt.title(f"Confusion Matrix - {name}")
        plt.xlabel("Predicted Label")
        plt.ylabel("True Label")
        if plots_folder is None:
            plt.show()
        else:
            plt.savefig(os.path.join(plots_folder, f"Confusion_Matrix_{name}.jpg"))

    all_nikud_types_letter_level_correct = (
        all_nikud_types_letter_level_correct / letters_count
    )
    all_nikud_types_word_level_correct = correct_words_count / words_count
    nikud_letter_level_correct = nikud_letter_level_correct / letters_count
    dagesh_letter_level_correct = dagesh_letter_level_correct / letters_count
    sin_letter_level_correct = sin_letter_level_correct / letters_count
    print("\n")
    print(f"nikud_letter_level_correct = {nikud_letter_level_correct}")
    print(f"dagesh_letter_level_correct = {dagesh_letter_level_correct}")
    print(f"sin_letter_level_correct = {sin_letter_level_correct}")
    print(f"word_level_correct = {all_nikud_types_word_level_correct}")

    return all_nikud_types_word_level_correct, all_nikud_types_letter_level_correct
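
# Hedged usage sketch, assuming a trained model and a test DataLoader yielding
# (input_ids, attention_mask, labels) batches:
#
#     word_acc, letter_acc = evaluate(model, test_loader, plots_folder="plots")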