# D_Nikud/src/models_utils.py
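"""Model utilities for the D_Nikud diacritization models: batch and
single-batch inference, training with per-class losses, and evaluation with
confusion-matrix plots for the nikud, dagesh and sin prediction heads."""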
# general
import json
import os

# ML
import numpy as np
import pandas as pd
import torch

# visual
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

from src.running_params import DEBUG_MODE
from src.utiles_data import Nikud, create_missing_folders

CLASSES_LIST = ["nikud", "dagesh", "sin"]
def calc_num_correct_words(inputs, letter_correct_mask):
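    """Count whole-word accuracy in a batch of token-id rows.

    Space/BOS/EOS tokens are zeroed in place so that zeros delimit words; a
    word is counted as correct only if every letter in it is marked correct in
    `letter_correct_mask`. Only words longer than one letter are scored.
    Returns (correct_words_count, words_count). The token ids below are
    specific to this project's tokenizer.
    """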
SPACE_TOKEN = 104
START_SENTENCE_TOKEN = 1
END_SENTENCE_TOKEN = 2
correct_words_count = 0
words_count = 0
    for index in range(inputs.shape[0]):
        # Zero out separator tokens so that zeros mark word boundaries.
        inputs[index][inputs[index] == SPACE_TOKEN] = 0
        inputs[index][inputs[index] == START_SENTENCE_TOKEN] = 0
        inputs[index][inputs[index] == END_SENTENCE_TOKEN] = 0
        words_end_index = np.concatenate(
            (np.array([-1]), np.where(inputs[index] == 0)[0])
        )
        # A word spans the tokens between consecutive zeros; only words longer
        # than one letter are scored.
is_correct_words_array = [
bool(
letter_correct_mask[index][
list(range((words_end_index[s] + 1), words_end_index[s + 1]))
].all()
)
for s in range(len(words_end_index) - 1)
if words_end_index[s + 1] - (words_end_index[s] + 1) > 1
]
correct_words_count += np.array(is_correct_words_array).sum()
words_count += len(is_correct_words_array)
return correct_words_count, words_count
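
# Example with hypothetical token ids (104/1/2 are space/BOS/EOS here):
#     inputs = torch.tensor([[1, 50, 51, 104, 60, 61, 2]])
#     mask = torch.tensor([[True, True, True, True, True, False, True]])
#     calc_num_correct_words(inputs, mask)  # -> (1, 2): 1 of 2 words correct
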
def predict(model, data_loader, device="cpu"):
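    """Predict nikud, dagesh and sin class ids for every batch in `data_loader`.

    Returns an array of shape (num_samples, seq_len, 3); positions whose label
    is -1 (the class cannot apply there) are set to -1 in the output as well.
    """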
    model.to(device)
    model.eval()
    all_labels = None
with torch.no_grad():
for index_data, data in enumerate(data_loader):
(inputs, attention_mask, labels_demo) = data
inputs = inputs.to(device)
attention_mask = attention_mask.to(device)
labels_demo = labels_demo.to(device)
mask_cant_be_nikud = np.array(labels_demo.cpu())[:, :, 0] == -1
mask_cant_be_dagesh = np.array(labels_demo.cpu())[:, :, 1] == -1
mask_cant_be_sin = np.array(labels_demo.cpu())[:, :, 2] == -1
nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
pred_nikud = np.array(torch.max(nikud_probs, 2).indices.cpu()).reshape(
inputs.shape[0], inputs.shape[1], 1
)
pred_dagesh = np.array(torch.max(dagesh_probs, 2).indices.cpu()).reshape(
inputs.shape[0], inputs.shape[1], 1
)
pred_sin = np.array(torch.max(sin_probs, 2).indices.cpu()).reshape(
inputs.shape[0], inputs.shape[1], 1
)
pred_nikud[mask_cant_be_nikud] = -1
pred_dagesh[mask_cant_be_dagesh] = -1
pred_sin[mask_cant_be_sin] = -1
pred_labels = np.concatenate((pred_nikud, pred_dagesh, pred_sin), axis=2)
if all_labels is None:
all_labels = pred_labels
else:
all_labels = np.concatenate((all_labels, pred_labels), axis=0)
return all_labels
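
# Minimal usage sketch (assuming a DataLoader that yields
# (inputs, attention_mask, labels) batches, as the training code below does):
#
#     all_preds = predict(model, test_loader, device="cuda")
#     # all_preds.shape == (num_samples, seq_len, 3)
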
def predict_single(model, data, device="cpu"):
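    """Run inference on a single (inputs, attention_mask, labels_demo) batch.

    Same post-processing as `predict`: returns an array of shape
    (batch, seq_len, 3) with -1 wherever a class cannot apply.
    """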
    # model.to(device)
    model.eval()
with torch.no_grad():
(inputs, attention_mask, labels_demo) = data
inputs = inputs.to(device)
attention_mask = attention_mask.to(device)
labels_demo = labels_demo.to(device)
mask_cant_be_nikud = np.array(labels_demo.cpu())[:, :, 0] == -1
mask_cant_be_dagesh = np.array(labels_demo.cpu())[:, :, 1] == -1
mask_cant_be_sin = np.array(labels_demo.cpu())[:, :, 2] == -1
        nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
pred_nikud = np.array(torch.max(nikud_probs, 2).indices.cpu()).reshape(
inputs.shape[0], inputs.shape[1], 1
)
pred_dagesh = np.array(torch.max(dagesh_probs, 2).indices.cpu()).reshape(
inputs.shape[0], inputs.shape[1], 1
)
pred_sin = np.array(torch.max(sin_probs, 2).indices.cpu()).reshape(
inputs.shape[0], inputs.shape[1], 1
)
pred_nikud[mask_cant_be_nikud] = -1
pred_dagesh[mask_cant_be_dagesh] = -1
pred_sin[mask_cant_be_sin] = -1
        pred_labels = np.concatenate((pred_nikud, pred_dagesh, pred_sin), axis=2)
        all_labels = pred_labels
return all_labels
def training(
model,
train_loader,
dev_loader,
criterion_nikud,
criterion_dagesh,
criterion_sin,
training_params,
logger,
output_model_path,
optimizer,
device="cpu",
):
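    """Train the three classification heads, evaluating on dev every epoch.

    Tracks per-class train/dev losses and dev accuracies (letter- and
    word-level), saves periodic checkpoints, and keeps the model with the best
    combined letter-level dev accuracy as `best_model.pth`.
    """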
max_length = None
best_accuracy = 0.0
logger.info(f"start training with training_params: {training_params}")
model = model.to(device)
criteria = {
"nikud": criterion_nikud.to(device),
"dagesh": criterion_dagesh.to(device),
"sin": criterion_sin.to(device),
}
output_checkpoints_path = os.path.join(output_model_path, "checkpoints")
create_missing_folders(output_checkpoints_path)
train_steps_loss_values = {"nikud": [], "dagesh": [], "sin": []}
train_epochs_loss_values = {"nikud": [], "dagesh": [], "sin": []}
dev_loss_values = {"nikud": [], "dagesh": [], "sin": []}
dev_accuracy_values = {
"nikud": [],
"dagesh": [],
"sin": [],
"all_nikud_letter": [],
"all_nikud_word": [],
}
for epoch in tqdm(range(training_params["n_epochs"]), desc="Training"):
model.train()
train_loss = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
relevant_count = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
for index_data, data in enumerate(train_loader):
(inputs, attention_mask, labels) = data
if max_length is None:
max_length = labels.shape[1]
inputs = inputs.to(device)
attention_mask = attention_mask.to(device)
labels = labels.to(device)
optimizer.zero_grad()
nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
for i, (probs, class_name) in enumerate(
zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
):
                # Reorder logits to (batch, num_classes, seq_len) as the
                # per-class criterion expects.
                reshaped_tensor = torch.transpose(probs, 1, 2).contiguous()
loss = criteria[class_name](reshaped_tensor, labels[:, :, i]).to(device)
num_relevant = (labels[:, :, i] != -1).sum()
train_loss[class_name] += loss.item() * num_relevant
relevant_count[class_name] += num_relevant
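                # Backward per class head; retain_graph keeps the shared
                # encoder activations alive for the next class's backward.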
loss.backward(retain_graph=True)
for i, class_name in enumerate(CLASSES_LIST):
train_steps_loss_values[class_name].append(
float(train_loss[class_name] / relevant_count[class_name])
)
optimizer.step()
if (index_data + 1) % 100 == 0:
msg = f"epoch: {epoch} , index_data: {index_data + 1}\n"
for i, class_name in enumerate(CLASSES_LIST):
msg += f"mean loss train {class_name}: {float(train_loss[class_name] / relevant_count[class_name])}, "
logger.debug(msg[:-2])
for i, class_name in enumerate(CLASSES_LIST):
train_epochs_loss_values[class_name].append(
float(train_loss[class_name] / relevant_count[class_name])
)
for class_name in train_loss.keys():
train_loss[class_name] /= relevant_count[class_name]
msg = f"Epoch {epoch + 1}/{training_params['n_epochs']}\n"
for i, class_name in enumerate(CLASSES_LIST):
msg += f"mean loss train {class_name}: {train_loss[class_name]}, "
logger.debug(msg[:-2])
model.eval()
dev_loss = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
dev_accuracy = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
relevant_count = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
correct_preds = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
un_masks = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
predictions = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
labels_class = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
all_nikud_types_correct_preds_letter = 0.0
letter_count = 0.0
correct_words_count = 0.0
word_count = 0.0
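        # Dev pass: accumulate per-class loss/accuracy plus combined
        # letter-level and word-level accuracy across all three classes.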
with torch.no_grad():
for index_data, data in enumerate(dev_loader):
(inputs, attention_mask, labels) = data
inputs = inputs.to(device)
attention_mask = attention_mask.to(device)
labels = labels.to(device)
nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
for i, (probs, class_name) in enumerate(
zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
):
                    reshaped_tensor = torch.transpose(probs, 1, 2).contiguous()
loss = criteria[class_name](reshaped_tensor, labels[:, :, i]).to(
device
)
un_masked = labels[:, :, i] != -1
num_relevant = un_masked.sum()
relevant_count[class_name] += num_relevant
_, preds = torch.max(probs, 2)
dev_loss[class_name] += loss.item() * num_relevant
correct_preds[class_name] += torch.sum(
preds[un_masked] == labels[:, :, i][un_masked]
)
un_masks[class_name] = un_masked
predictions[class_name] = preds
labels_class[class_name] = labels[:, :, i]
un_mask_all_or = torch.logical_or(
torch.logical_or(un_masks["nikud"], un_masks["dagesh"]),
un_masks["sin"],
)
                correct = {
                    class_name: torch.ones(
                        un_mask_all_or.shape, dtype=torch.bool, device=device
                    )
                    for class_name in CLASSES_LIST
                }
for i, class_name in enumerate(CLASSES_LIST):
correct[class_name][un_masks[class_name]] = (
predictions[class_name][un_masks[class_name]]
== labels_class[class_name][un_masks[class_name]]
)
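                # A letter is fully correct only when nikud, dagesh and sin
                # are all predicted correctly.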
letter_correct_mask = torch.logical_and(
torch.logical_and(correct["sin"], correct["dagesh"]),
correct["nikud"],
)
all_nikud_types_correct_preds_letter += torch.sum(
letter_correct_mask[un_mask_all_or]
)
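                # Positions with no diacritic label at all are treated as
                # correct so they do not break whole-word accuracy.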
letter_correct_mask[~un_mask_all_or] = True
correct_num, total_words_num = calc_num_correct_words(
inputs.cpu(), letter_correct_mask
)
word_count += total_words_num
correct_words_count += correct_num
letter_count += un_mask_all_or.sum()
for class_name in CLASSES_LIST:
dev_loss[class_name] /= relevant_count[class_name]
dev_accuracy[class_name] = float(
correct_preds[class_name].double() / relevant_count[class_name]
)
dev_loss_values[class_name].append(float(dev_loss[class_name]))
dev_accuracy_values[class_name].append(float(dev_accuracy[class_name]))
dev_all_nikud_types_accuracy_letter = float(
all_nikud_types_correct_preds_letter / letter_count
)
dev_accuracy_values["all_nikud_letter"].append(
dev_all_nikud_types_accuracy_letter
)
word_all_nikud_accuracy = correct_words_count / word_count
dev_accuracy_values["all_nikud_word"].append(word_all_nikud_accuracy)
        msg = (
            f"Epoch {epoch + 1}/{training_params['n_epochs']}\n"
            f'mean loss Dev nikud: {dev_loss["nikud"]}, '
            f'mean loss Dev dagesh: {dev_loss["dagesh"]}, '
            f'mean loss Dev sin: {dev_loss["sin"]}, '
            f"Dev all nikud types letter Accuracy: {dev_all_nikud_types_accuracy_letter}, "
            f'Dev nikud letter Accuracy: {dev_accuracy["nikud"]}, '
            f'Dev dagesh letter Accuracy: {dev_accuracy["dagesh"]}, '
            f'Dev sin letter Accuracy: {dev_accuracy["sin"]}, '
            f"Dev word Accuracy: {word_all_nikud_accuracy}"
        )
logger.debug(msg)
save_progress_details(
dev_accuracy_values,
train_epochs_loss_values,
dev_loss_values,
train_steps_loss_values,
)
if dev_all_nikud_types_accuracy_letter > best_accuracy:
best_accuracy = dev_all_nikud_types_accuracy_letter
best_model = {
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": loss,
}
if epoch % training_params["checkpoints_frequency"] == 0:
save_checkpoint_path = os.path.join(
output_checkpoints_path, f"checkpoint_model_epoch_{epoch + 1}.pth"
)
            # Checkpoints store only the model weights; optimizer state is
            # kept with `best_model` above.
            torch.save(model.state_dict(), save_checkpoint_path)
save_model_path = os.path.join(output_model_path, "best_model.pth")
torch.save(best_model["model_state_dict"], save_model_path)
return (
best_model,
best_accuracy,
train_epochs_loss_values,
train_steps_loss_values,
dev_loss_values,
dev_accuracy_values,
)
def save_progress_details(
accuracy_dev_values,
epochs_loss_train_values,
loss_dev_values,
steps_loss_train_values,
):
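    """Dump the tracked train/dev loss and accuracy histories as JSON files
    under the `epochs_data` folder."""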
epochs_data_path = "epochs_data"
create_missing_folders(epochs_data_path)
save_dict_as_json(
steps_loss_train_values, epochs_data_path, "steps_loss_train_values.json"
)
save_dict_as_json(
epochs_loss_train_values, epochs_data_path, "epochs_loss_train_values.json"
)
save_dict_as_json(loss_dev_values, epochs_data_path, "loss_dev_values.json")
save_dict_as_json(accuracy_dev_values, epochs_data_path, "accuracy_dev_values.json")
def save_dict_as_json(data, file_path, file_name):
    with open(os.path.join(file_path, file_name), "w") as json_file:
        json.dump(data, json_file, indent=4)
def evaluate(model, test_data, plots_folder=None, device="cpu"):
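    """Evaluate on `test_data` and plot a confusion matrix per class.

    Returns (word-level accuracy, letter-level accuracy) computed over all
    nikud types combined.
    """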
model.to(device)
model.eval()
true_labels = {"nikud": [], "dagesh": [], "sin": []}
predictions = {"nikud": 0, "dagesh": 0, "sin": 0}
predicted_labels_2_report = {"nikud": [], "dagesh": [], "sin": []}
not_masks = {"nikud": 0, "dagesh": 0, "sin": 0}
correct_preds = {"nikud": 0, "dagesh": 0, "sin": 0}
relevant_count = {"nikud": 0, "dagesh": 0, "sin": 0}
labels_class = {"nikud": 0.0, "dagesh": 0.0, "sin": 0.0}
all_nikud_types_letter_level_correct = 0.0
nikud_letter_level_correct = 0.0
dagesh_letter_level_correct = 0.0
sin_letter_level_correct = 0.0
letters_count = 0.0
words_count = 0.0
correct_words_count = 0.0
with torch.no_grad():
for index_data, data in enumerate(test_data):
if DEBUG_MODE and index_data > 100:
break
(inputs, attention_mask, labels) = data
inputs = inputs.to(device)
attention_mask = attention_mask.to(device)
labels = labels.to(device)
nikud_probs, dagesh_probs, sin_probs = model(inputs, attention_mask)
for i, (probs, class_name) in enumerate(
zip([nikud_probs, dagesh_probs, sin_probs], CLASSES_LIST)
):
labels_class[class_name] = labels[:, :, i]
not_masked = labels_class[class_name] != -1
num_relevant = not_masked.sum()
relevant_count[class_name] += num_relevant
_, preds = torch.max(probs, 2)
correct_preds[class_name] += torch.sum(
preds[not_masked] == labels_class[class_name][not_masked]
)
predictions[class_name] = preds
not_masks[class_name] = not_masked
if len(true_labels[class_name]) == 0:
true_labels[class_name] = (
labels_class[class_name][not_masked].cpu().numpy()
)
else:
true_labels[class_name] = np.concatenate(
(
true_labels[class_name],
labels_class[class_name][not_masked].cpu().numpy(),
)
)
if len(predicted_labels_2_report[class_name]) == 0:
predicted_labels_2_report[class_name] = (
preds[not_masked].cpu().numpy()
)
else:
predicted_labels_2_report[class_name] = np.concatenate(
(
predicted_labels_2_report[class_name],
preds[not_masked].cpu().numpy(),
)
)
not_mask_all_or = torch.logical_or(
torch.logical_or(not_masks["nikud"], not_masks["dagesh"]),
not_masks["sin"],
)
correct_nikud = (torch.ones(not_mask_all_or.shape) == 1).to(device)
correct_dagesh = (torch.ones(not_mask_all_or.shape) == 1).to(device)
correct_sin = (torch.ones(not_mask_all_or.shape) == 1).to(device)
correct_nikud[not_masks["nikud"]] = (
predictions["nikud"][not_masks["nikud"]]
== labels_class["nikud"][not_masks["nikud"]]
)
correct_dagesh[not_masks["dagesh"]] = (
predictions["dagesh"][not_masks["dagesh"]]
== labels_class["dagesh"][not_masks["dagesh"]]
)
correct_sin[not_masks["sin"]] = (
predictions["sin"][not_masks["sin"]]
== labels_class["sin"][not_masks["sin"]]
)
letter_correct_mask = torch.logical_and(
torch.logical_and(correct_sin, correct_dagesh), correct_nikud
)
all_nikud_types_letter_level_correct += torch.sum(
letter_correct_mask[not_mask_all_or]
)
letter_correct_mask[~not_mask_all_or] = True
total_correct_count, total_words_num = calc_num_correct_words(
inputs.cpu(), letter_correct_mask
)
words_count += total_words_num
correct_words_count += total_correct_count
letters_count += not_mask_all_or.sum()
nikud_letter_level_correct += torch.sum(correct_nikud[not_mask_all_or])
dagesh_letter_level_correct += torch.sum(correct_dagesh[not_mask_all_or])
sin_letter_level_correct += torch.sum(correct_sin[not_mask_all_or])
for i, name in enumerate(CLASSES_LIST):
index_labels = np.unique(true_labels[name])
cm = confusion_matrix(
true_labels[name], predicted_labels_2_report[name], labels=index_labels
)
vowel_label = [Nikud.id_2_label[name][l] for l in index_labels]
unique_vowels_names = [
Nikud.sign_2_name[int(vowel)] for vowel in vowel_label if vowel != "WITHOUT"
]
if "WITHOUT" in vowel_label:
unique_vowels_names += ["WITHOUT"]
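        # "WITHOUT" is moved to the end of the display names; this assumes its
        # label id sorts last in `index_labels`, so the names stay aligned
        # with the confusion-matrix rows.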
cm_df = pd.DataFrame(cm, index=unique_vowels_names, columns=unique_vowels_names)
# Display confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm_df, annot=True, cmap="Blues", fmt="d")
plt.title("Confusion Matrix")
plt.xlabel("True Label")
plt.ylabel("Predicted Label")
if plots_folder is None:
plt.show()
else:
plt.savefig(os.path.join(plots_folder, f"Confusion_Matrix_{name}.jpg"))
    all_nikud_types_letter_level_correct = float(
        all_nikud_types_letter_level_correct / letters_count
    )
    all_nikud_types_word_level_correct = correct_words_count / words_count
    nikud_letter_level_correct = float(nikud_letter_level_correct / letters_count)
    dagesh_letter_level_correct = float(dagesh_letter_level_correct / letters_count)
    sin_letter_level_correct = float(sin_letter_level_correct / letters_count)
print("\n")
print(f"nikud_letter_level_correct = {nikud_letter_level_correct}")
print(f"dagesh_letter_level_correct = {dagesh_letter_level_correct}")
print(f"sin_letter_level_correct = {sin_letter_level_correct}")
print(f"word_level_correct = {all_nikud_types_word_level_correct}")
return all_nikud_types_word_level_correct, all_nikud_types_letter_level_correct