Spaces:
Runtime error
Runtime error
import numpy as np | |
from sklearn.metrics import roc_curve | |
import argparse | |
def _compute_eer(label, pred, positive_label=1): | |
""" | |
Python compute equal error rate (eer) | |
ONLY tested on binary classification | |
:param label: ground-truth label, should be a 1-d list or np.array, each element represents the ground-truth label of one sample | |
:param pred: model prediction, should be a 1-d list or np.array, each element represents the model prediction of one sample | |
:param positive_label: the class that is viewed as positive class when computing EER | |
:return: equal error rate (EER) | |
""" | |
# all fpr, tpr, fnr, fnr, threshold are lists (in the format of np.array) | |
fpr, tpr, threshold = roc_curve(label, pred, pos_label=positive_label) | |
fnr = 1 - tpr | |
# the threshold of fnr == fpr | |
eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))] | |
# theoretically eer from fpr and eer from fnr should be identical but they can be slightly differ in reality | |
eer_1 = fpr[np.nanargmin(np.absolute((fnr - fpr)))] | |
eer_2 = fnr[np.nanargmin(np.absolute((fnr - fpr)))] | |
# return the mean of eer from fpr and from fnr | |
eer = (eer_1 + eer_2) / 2 | |
return eer, eer_threshold | |
def compute_eer(trials_path, scores_path): | |
labels = [] | |
for one_line in open(trials_path, "r"): | |
labels.append(one_line.strip().rsplit(" ", 1)[-1] == "target") | |
labels = np.array(labels, dtype=int) | |
scores = [] | |
for one_line in open(scores_path, "r"): | |
scores.append(float(one_line.strip().rsplit(" ", 1)[-1])) | |
scores = np.array(scores, dtype=float) | |
eer, threshold = _compute_eer(labels, scores) | |
return eer, threshold | |
def main(): | |
parser = argparse.ArgumentParser() | |
parser.add_argument("trials", help="trial list") | |
parser.add_argument("scores", help="score file, normalized to [0, 1]") | |
args = parser.parse_args() | |
eer, threshold = compute_eer(args.trials, args.scores) | |
print("EER is {:.4f} at threshold {:.4f}".format(eer * 100.0, threshold)) | |
if __name__ == "__main__": | |
main() | |