Spaces:
Runtime error
Runtime error
File size: 2,079 Bytes
0102e16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import numpy as np
from sklearn.metrics import roc_curve
import argparse
def _compute_eer(label, pred, positive_label=1):
    """
    Compute the equal error rate (EER) of a binary scoring system.

    The EER is the ROC operating point where the false positive rate (FPR)
    equals the false negative rate (FNR = 1 - TPR).
    ONLY tested on binary classification.

    :param label: ground-truth labels, a 1-d list or np.array with one entry per sample
    :param pred: model scores, a 1-d list or np.array with one entry per sample
    :param positive_label: the class that is viewed as positive class when computing EER
    :return: tuple ``(eer, eer_threshold)``; ``eer`` is a fraction in [0, 1]
        (not a percentage) and ``eer_threshold`` is the ROC threshold at the
        point where FPR and FNR are closest
    :raises ValueError: if ``label`` does not contain both positive and
        negative samples (the EER is undefined in that case)
    """
    label = np.asarray(label)
    # Without both classes one of FPR/TPR is all-NaN and nanargmin would fail
    # with an unhelpful "All-NaN slice" error; report the real cause instead.
    n_positive = np.count_nonzero(label == positive_label)
    if n_positive == 0 or n_positive == label.size:
        raise ValueError(
            "EER is undefined: both positive and negative samples are required"
        )

    # fpr, tpr and threshold are np.arrays of equal length (one entry per ROC point)
    fpr, tpr, threshold = roc_curve(label, pred, pos_label=positive_label)
    fnr = 1 - tpr
    # Index of the ROC point where FNR and FPR are closest; computed once and
    # reused for both the threshold and the rate lookups.
    eer_index = np.nanargmin(np.absolute(fnr - fpr))
    eer_threshold = threshold[eer_index]
    # Theoretically FPR and FNR are identical at the EER, but on a finite
    # sample they can differ slightly at the chosen point, so average them.
    eer = (fpr[eer_index] + fnr[eer_index]) / 2
    return eer, eer_threshold
def _read_last_fields(path):
    """Return the last whitespace-separated field of every non-blank line in *path*."""
    # `with` guarantees the file handle is closed even if parsing fails.
    with open(path, "r", encoding="utf-8") as handle:
        return [line.split()[-1] for line in handle if line.strip()]


def compute_eer(trials_path, scores_path):
    """
    Compute the EER of a score file against a trial list.

    The two files are paired by line order: the i-th non-blank line of
    ``scores_path`` scores the i-th non-blank line of ``trials_path``. Only
    the last whitespace-separated field (spaces or tabs) of each line is used.

    :param trials_path: path to the trial list; a trial is positive (label 1)
        iff its last field is exactly ``target``, otherwise negative (label 0)
    :param scores_path: path to the score file; the last field of each line
        must parse as a float
    :return: tuple ``(eer, threshold)`` as returned by ``_compute_eer``
    :raises ValueError: if the files contain different numbers of entries or
        a score cannot be parsed as a float
    """
    labels = np.array(
        [field == "target" for field in _read_last_fields(trials_path)],
        dtype=int,
    )
    scores = np.array(
        [float(field) for field in _read_last_fields(scores_path)],
        dtype=float,
    )
    # Pairing is purely positional, so a count mismatch means the files do
    # not line up; fail loudly rather than scoring misaligned trials.
    if labels.size != scores.size:
        raise ValueError(
            "trial list has {} entries but score file has {}".format(
                labels.size, scores.size
            )
        )
    eer, threshold = _compute_eer(labels, scores)
    return eer, threshold
def main():
    """Command-line entry point: read a trial list and a score file, print the EER.

    The EER is printed multiplied by 100 (a percentage) with four decimals,
    followed by the score threshold at which it was obtained.
    """
    cli = argparse.ArgumentParser()
    # Positional arguments, registered in the order they appear on the command line.
    for arg_name, arg_help in (
        ("trials", "trial list"),
        ("scores", "score file, normalized to [0, 1]"),
    ):
        cli.add_argument(arg_name, help=arg_help)
    opts = cli.parse_args()

    eer, threshold = compute_eer(opts.trials, opts.scores)
    eer_percent = eer * 100.0
    print("EER is {:.4f} at threshold {:.4f}".format(eer_percent, threshold))
# Run the command-line interface only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|