# xBitterT5 / src/utils.py
# Author: ndhieunguyen
# Commit: "feat: first commit" (6a53dd4)
# File size: 1.25 kB
# (Header captured from the repository web viewer: raw / history / blame.)
import evaluate
import numpy as np
from datetime import datetime
from zoneinfo import ZoneInfo
from torch.nn.functional import softmax
from torch import tensor
from sklearn.metrics import confusion_matrix, roc_curve, auc
# Names of the `evaluate` metrics that are bundled into one combined metric.
_BITTER_METRIC_NAMES = [
    "accuracy",
    "f1",
    "precision",
    "recall",
    "matthews_correlation",
]

# Combined metric object used by compute_metrics(); built once at import time.
bitter_metrics = evaluate.combine(_BITTER_METRIC_NAMES)
def compute_metrics(eval_pred):
    """Compute binary-classification metrics for one evaluation pass.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. Only ``predictions[0]``
            is used; it is treated as per-class scores, and column 1 is taken
            as the positive class.
            NOTE(review): assumes ``predictions[0]`` is a 2-D array of shape
            ``(n_samples, 2)`` and that labels are 0/1 -- confirm against the
            caller.

    Returns:
        dict: Key-sorted mapping holding the results of ``bitter_metrics``
        plus ``eval_specificity``, ``eval_tn``, ``eval_fp``, ``eval_fn``,
        ``eval_tp`` and ``eval_auc``.
    """
    predictions, labels = eval_pred
    logits = predictions[0]

    # Hard class decisions and the softmax probability of class 1.
    preds = np.argmax(logits, axis=1)
    prediction_scores = softmax(tensor(logits), dim=-1)
    prediction_scores = prediction_scores[:, 1].cpu().numpy()

    metrics = bitter_metrics.compute(predictions=preds, references=labels)

    # Pin the label set so the matrix is always 2x2. Without it, a batch in
    # which only one class occurs yields a 1x1 matrix and the four-way
    # unpacking below raises ValueError.
    tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel()

    # Specificity is undefined when there are no negative samples; report NaN
    # explicitly instead of triggering a NumPy divide-by-zero warning.
    negatives = tn + fp
    specificity = tn / negatives if negatives else float("nan")

    metrics.update(
        {
            "eval_specificity": specificity,
            "eval_tn": tn,
            "eval_fp": fp,
            "eval_fn": fn,
            "eval_tp": tp,
        }
    )

    # Area under the ROC curve, computed from the class-1 probabilities.
    fpr2, tpr2, _ = roc_curve(labels, prediction_scores, pos_label=1)
    auc2 = auc(fpr2, tpr2)
    metrics.update({"eval_auc": auc2})

    # Sort by key so the metric order is stable.
    metrics = dict(sorted(metrics.items()))
    return metrics
def get_time_string():
    """Return the current Asia/Seoul time formatted as ``YYYY_MM_DD__HH_MM_SS``."""
    seoul_tz = ZoneInfo("Asia/Seoul")
    now = datetime.now(tz=seoul_tz)
    return now.strftime("%Y_%m_%d__%H_%M_%S")