Spaces:
Runtime error
Runtime error
File size: 6,517 Bytes
0102e16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
#!/usr/bin/env python3
# Copyright 2018 David Snyder
# Apache 2.0
# This script computes the minimum detection cost function, which is a common
# error metric used in speaker recognition. Compared to equal error-rate,
# which assigns equal weight to false negatives and false positives, this
# error-rate is usually used to assess performance in settings where achieving
# a low false positive rate is more important than achieving a low false
# negative rate. See the NIST 2016 Speaker Recognition Evaluation Plan at
# https://www.nist.gov/sites/default/files/documents/2016/10/07/sre16_eval_plan_v1.3.pdf
# for more details about the metric.
from __future__ import print_function
from operator import itemgetter
import sys, argparse, os
def GetArgs():
    """Parse and validate the command line.

    The full command line is echoed to stderr before parsing, so it shows up
    in logs even if parsing fails.

    Returns:
        The argparse namespace (attributes p_target, c_miss, c_fa,
        scores_filename, trials_filename) after CheckArgs() has validated it.
    """
    description_text = (
        "Compute the minimum "
        "detection cost function along with the threshold at which it occurs. "
        "Usage: sid/compute_min_dcf.py [options...] <scores-file> "
        "<trials-file> "
        "E.g., sid/compute_min_dcf.py --p-target 0.01 --c-miss 1 --c-fa 1 "
        "exp/scores/trials data/test/trials"
    )
    parser = argparse.ArgumentParser(
        description=description_text,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # (flag, destination attribute, default, help text) for each optional
    # float-valued argument, registered in this order.
    float_options = (
        (
            "--p-target",
            "p_target",
            0.01,
            "The prior probability of the target speaker in a trial.",
        ),
        (
            "--c-miss",
            "c_miss",
            1,
            "Cost of a missed detection. This is usually not changed.",
        ),
        (
            "--c-fa",
            "c_fa",
            1,
            "Cost of a spurious detection. This is usually not changed.",
        ),
    )
    for flag, destination, default_value, help_text in float_options:
        parser.add_argument(
            flag,
            type=float,
            dest=destination,
            default=default_value,
            help=help_text,
        )

    # Positional arguments: the scores file first, then the trials file.
    parser.add_argument(
        "scores_filename",
        help="Input scores file, with columns of the form <utt1> <utt2> <score>",
    )
    parser.add_argument(
        "trials_filename",
        help="Input trials file, with columns of the form "
        "<utt1> <utt2> <target/nontarget>",
    )

    sys.stderr.write(" ".join(sys.argv) + "\n")
    parsed = parser.parse_args()
    return CheckArgs(parsed)
def CheckArgs(args):
    """Validate the cost-function parameters held in ``args``.

    Checks run in the order c_fa, c_miss, p_target. Each attribute is read
    only when its check is reached, and the first failing check raises.

    Args:
        args: object with numeric attributes c_fa, c_miss and p_target.

    Returns:
        The same ``args`` object, unchanged, when every value is valid.

    Raises:
        Exception: if c_fa or c_miss is not positive, or p_target is not
            strictly between 0 and 1.
    """
    validations = (
        ("c_fa", lambda value: value <= 0, "--c-fa must be greater than 0"),
        ("c_miss", lambda value: value <= 0, "--c-miss must be greater than 0"),
        (
            "p_target",
            lambda value: value <= 0 or value >= 1,
            "--p-target must be greater than 0 and less than 1",
        ),
    )
    for attribute, is_invalid, message in validations:
        if is_invalid(getattr(args, attribute)):
            raise Exception(message)
    return args
def ComputeErrorRates(scores, labels):
    """Compute false-negative and false-positive rates at every score threshold.

    The scores are sorted in ascending order, and each sorted score is used as
    a decision threshold. At position i, every trial whose sorted position is
    <= i (score at or below thresholds[i]) counts as rejected, and every later
    trial counts as accepted.

    Args:
        scores: iterable of numeric scores, one per trial.
        labels: indexable sequence of 0/1 values aligned with ``scores``.
            1 marks a target trial and 0 a nontarget trial.

    Returns:
        (fnrs, fprs, thresholds):
            fnrs and fprs are lists of floats, and thresholds is a tuple of
            the sorted scores, all of length len(scores).
            fnrs[i] is the fraction of target trials rejected at
            thresholds[i].
            fprs[i] is the fraction of nontarget trials still accepted there.

    Raises:
        ValueError: if no scores are given, or if the labels contain no
            target trials or no nontarget trials (the rates would divide by
            zero).
    """
    # Sort (index, score) pairs by score. sorted() is stable, so tied scores
    # keep their original relative order.
    indexed_scores = sorted(enumerate(scores), key=itemgetter(1))
    if not indexed_scores:
        raise ValueError("Cannot compute error rates: no scores were given")
    sorted_indexes, thresholds = zip(*indexed_scores)
    labels = [labels[i] for i in sorted_indexes]

    # Running counts over the sorted trials. After position i:
    #   fns[i] = target trials with score <= thresholds[i] (wrongly rejected)
    #   tns[i] = nontarget trials with score <= thresholds[i] (correctly rejected)
    fns = []
    tns = []
    fn_count = 0
    tn_count = 0
    for label in labels:
        fn_count += label
        tn_count += 1 - label
        fns.append(fn_count)
        tns.append(tn_count)

    positives = sum(labels)
    negatives = len(labels) - positives
    if positives == 0 or negatives == 0:
        raise ValueError(
            "Cannot compute error rates: need at least one target and one "
            "nontarget trial (got {0} target, {1} nontarget)".format(
                positives, negatives
            )
        )

    # False-negative rate: rejected targets over all targets.
    fnrs = [fn / float(positives) for fn in fns]
    # True-negative rate is rejected nontargets over all nontargets;
    # one minus it is the false-positive rate.
    fprs = [1 - tn / float(negatives) for tn in tns]
    return fnrs, fprs, thresholds
def ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa):
    """Find the lowest normalized detection cost over all thresholds.

    Equation numbers refer to Section 3 of the NIST 2016 Speaker Recognition
    Evaluation Plan.

    Args:
        fnrs: false-negative rate at each threshold.
        fprs: false-positive rate at each threshold, aligned with ``fnrs``.
        thresholds: decision threshold for each position, aligned with
            ``fnrs``.
        p_target: prior probability of a target trial.
        c_miss: cost of a missed detection.
        c_fa: cost of a false alarm.

    Returns:
        (min_dcf, threshold):
            min_dcf is the minimum detection cost divided by
            min(c_miss * p_target, c_fa * (1 - p_target)).
            threshold is where that minimum is first reached; when several
            positions tie, the earliest one wins.
    """
    lowest_cost = float("inf")
    lowest_cost_threshold = thresholds[0]
    p_nontarget = 1 - p_target
    for position, fnr in enumerate(fnrs):
        # Equation (2): cost- and prior-weighted sum of the miss rate and
        # the false-alarm rate.
        cost = c_miss * fnr * p_target + c_fa * fprs[position] * p_nontarget
        if cost < lowest_cost:
            lowest_cost = cost
            lowest_cost_threshold = thresholds[position]
    # Equations (3) and (4): normalize by the smaller of the costs of
    # rejecting everything (c_miss * p_target) and accepting everything
    # (c_fa * (1 - p_target)).
    c_default = min(c_miss * p_target, c_fa * p_nontarget)
    return lowest_cost / c_default, lowest_cost_threshold
def compute_min_dcf(scores_filename, trials_filename, c_miss=1, c_fa=1, p_target=0.01):
    """Compute minDCF for a scores file evaluated against a trials file.

    Both files are whitespace-separated with exactly three columns per
    non-blank line. Blank lines are skipped.

    Args:
        scores_filename: path to a file with lines ``<utt1> <utt2> <score>``.
        trials_filename: path to a file with lines
            ``<utt1> <utt2> <target/nontarget>``. The label "target" marks a
            target trial; any other label is treated as nontarget.
        c_miss: cost of a missed detection.
        c_fa: cost of a false alarm.
        p_target: prior probability of a target trial.

    Returns:
        (min_dcf, threshold): the normalized minimum detection cost and the
        score threshold at which it is reached.

    Raises:
        Exception: if a pair in the scores file has no entry in the trials
            file.
    """
    # Read both files up front, in the same order as before, and close them
    # deterministically.
    with open(scores_filename, "r") as scores_file:
        score_lines = scores_file.readlines()
    with open(trials_filename, "r") as trials_file:
        trial_lines = trials_file.readlines()

    # Map "utt1 utt2" -> label string from the trials file.
    trials = {}
    for line in trial_lines:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines
        utt1, utt2, target = fields
        trials[utt1 + " " + utt2] = target

    scores = []
    labels = []
    for line in score_lines:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines
        utt1, utt2, score = fields
        trial = utt1 + " " + utt2
        if trial not in trials:
            raise Exception(
                "Missing entry for " + utt1 + " and " + utt2 + " in "
                + trials_filename + " (referenced by " + scores_filename + ")"
            )
        scores.append(float(score))
        labels.append(1 if trials[trial] == "target" else 0)

    fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)
    mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa)
    return mindcf, threshold
def main():
    """Command-line entry point: parse arguments, compute minDCF, print it.

    Writes one line to stdout of the form
    "minDCF is X at threshold Y (p-target=..., c-miss=..., c-fa=...)".
    """
    args = GetArgs()
    min_dcf, best_threshold = compute_min_dcf(
        args.scores_filename,
        args.trials_filename,
        c_miss=args.c_miss,
        c_fa=args.c_fa,
        p_target=args.p_target,
    )
    template = (
        "minDCF is {0:.4f} at threshold {1:.4f} (p-target={2}, c-miss={3}, "
        "c-fa={4})\n"
    )
    report = template.format(
        min_dcf, best_threshold, args.p_target, args.c_miss, args.c_fa
    )
    sys.stdout.write(report)
# Run the command-line entry point only when executed as a script, not when
# this module is imported (e.g. to call compute_min_dcf directly).
if __name__ == "__main__":
    main()
|