#!/usr/bin/env python3
# Copyright 2018  David Snyder
# Apache 2.0

# This script computes the minimum detection cost function, which is a common
# error metric used in speaker recognition.  Compared to equal error-rate,
# which assigns equal weight to false negatives and false positives, this
# error-rate is usually used to assess performance in settings where achieving
# a low false positive rate is more important than achieving a low false
# negative rate.  See the NIST 2016 Speaker Recognition Evaluation Plan at
# https://www.nist.gov/sites/default/files/documents/2016/10/07/sre16_eval_plan_v1.3.pdf
# for more details about the metric.
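#
# Concretely, the quantity minimized below is the detection cost at a
# decision threshold t,
#     C_det(t) = c_miss * FNR(t) * p_target + c_fa * FPR(t) * (1 - p_target),
# normalized by the cost of the best trivial system,
#     C_def = min(c_miss * p_target, c_fa * (1 - p_target)),
# so that minDCF = min over t of C_det(t) / C_def.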
from operator import itemgetter
import sys, argparse


def GetArgs():
    parser = argparse.ArgumentParser(
        description="Compute the minimum "
        "detection cost function along with the threshold at which it occurs. "
        "Usage: sid/compute_min_dcf.py [options...] <scores-file> "
        "<trials-file> "
        "E.g., sid/compute_min_dcf.py --p-target 0.01 --c-miss 1 --c-fa 1 "
        "exp/scores/trials data/test/trials",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--p-target",
        type=float,
        dest="p_target",
        default=0.01,
        help="The prior probability of the target speaker in a trial.",
    )
    parser.add_argument(
        "--c-miss",
        type=float,
        dest="c_miss",
        default=1,
        help="Cost of a missed detection.  This is usually not changed.",
    )
    parser.add_argument(
        "--c-fa",
        type=float,
        dest="c_fa",
        default=1,
        help="Cost of a spurious detection.  This is usually not changed.",
    )
    parser.add_argument(
        "scores_filename",
        help="Input scores file, with columns of the form " "<utt1> <utt2> <score>",
    )
    parser.add_argument(
        "trials_filename",
        help="Input trials file, with columns of the form "
        "<utt1> <utt2> <target/nontarget>",
    )
    sys.stderr.write(" ".join(sys.argv) + "\n")
    args = parser.parse_args()
    args = CheckArgs(args)
    return args


def CheckArgs(args):
    if args.c_fa <= 0:
        raise Exception("--c-fa must be greater than 0")
    if args.c_miss <= 0:
        raise Exception("--c-miss must be greater than 0")
    if args.p_target <= 0 or args.p_target >= 1:
        raise Exception("--p-target must be greater than 0 and less than 1")
    return args


# Creates a list of false-negative rates, a list of false-positive rates
# and a list of decision thresholds that give those error-rates.
def ComputeErrorRates(scores, labels):

    # Sort the scores from smallest to largest, and also get the corresponding
    # indexes of the sorted scores.  We will treat the sorted scores as the
    # thresholds at which the error-rates are evaluated.
    sorted_indexes, thresholds = zip(
        *sorted(
            [(index, threshold) for index, threshold in enumerate(scores)],
            key=itemgetter(1),
        )
    )
    labels = [labels[i] for i in sorted_indexes]
    fns = []
    tns = []

    # At the end of this loop, fns[i] is the number of errors made by
    # incorrectly rejecting scores less than or equal to thresholds[i].
    # And, tns[i] is the total number of times that we have correctly
    # rejected scores less than or equal to thresholds[i].
    for i in range(0, len(labels)):
        if i == 0:
            fns.append(labels[i])
            tns.append(1 - labels[i])
        else:
            fns.append(fns[i - 1] + labels[i])
            tns.append(tns[i - 1] + 1 - labels[i])
    positives = sum(labels)
    negatives = len(labels) - positives

    # Now divide the false negatives by the total number of
    # positives to obtain the false negative rates across
    # all thresholds.
    fnrs = [fn / float(positives) for fn in fns]

    # Divide the true negatives by the total number of
    # negatives to get the true negative rate. Subtract these
    # quantities from 1 to get the false positive rates.
    fprs = [1 - tn / float(negatives) for tn in tns]
    return fnrs, fprs, thresholds
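
# A small worked example: with scores = [0.1, 0.4, 0.35, 0.8] and
# labels = [0, 0, 1, 1], sorting yields thresholds = (0.1, 0.35, 0.4, 0.8)
# with reordered labels [0, 1, 0, 1], so fns = [0, 1, 1, 2] and
# tns = [1, 1, 2, 2], giving fnrs = [0.0, 0.5, 0.5, 1.0] and
# fprs = [0.5, 0.5, 0.0, 0.0].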


# Computes the minimum of the detection cost function.  The comments refer to
# equations in Section 3 of the NIST 2016 Speaker Recognition Evaluation Plan.
def ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa):
    min_c_det = float("inf")
    min_c_det_threshold = thresholds[0]
    for i in range(0, len(fnrs)):
        # See Equation (2).  It is a weighted sum of false negative
        # and false positive errors.
        c_det = c_miss * fnrs[i] * p_target + c_fa * fprs[i] * (1 - p_target)
        if c_det < min_c_det:
            min_c_det = c_det
            min_c_det_threshold = thresholds[i]
    # See Equations (3) and (4).  Now we normalize the cost by c_def, the
    # cost of the best trivial system (one that either always accepts or
    # always rejects).
    c_def = min(c_miss * p_target, c_fa * (1 - p_target))
    min_dcf = min_c_det / c_def
    return min_dcf, min_c_det_threshold
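
# Continuing the example above with p_target = 0.01, c_miss = 1, c_fa = 1:
# c_det takes the values [0.495, 0.5, 0.005, 0.01], so min_c_det = 0.005 at
# threshold 0.4; c_def = min(0.01, 0.99) = 0.01, and ComputeMinDcf returns
# (0.5, 0.4).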


def compute_min_dcf(scores_filename, trials_filename, c_miss=1, c_fa=1, p_target=0.01):
    with open(scores_filename, "r") as f:
        scores_file = f.readlines()
    with open(trials_filename, "r") as f:
        trials_file = f.readlines()

    scores = []
    labels = []

    trials = {}
    for line in trials_file:
        utt1, utt2, target = line.rstrip().split()
        trial = utt1 + " " + utt2
        trials[trial] = target

    for line in scores_file:
        utt1, utt2, score = line.rstrip().split()
        trial = utt1 + " " + utt2
        if trial in trials:
            scores.append(float(score))
            if trials[trial] == "target":
                labels.append(1)
            else:
                labels.append(0)
        else:
            raise Exception(
                "Missing entry for " + utt1 + " and " + utt2 + " in " + trials_filename
            )

    fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)
    mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa)
    return mindcf, threshold
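
# Example call from another script; the paths here are placeholders:
#     mindcf, threshold = compute_min_dcf("exp/scores/eval_scores",
#                                         "data/test/trials")
#     print(mindcf, threshold)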


def main():
    args = GetArgs()
    mindcf, threshold = compute_min_dcf(
        args.scores_filename,
        args.trials_filename,
        args.c_miss,
        args.c_fa,
        args.p_target,
    )
    sys.stdout.write(
        "minDCF is {0:.4f} at threshold {1:.4f} (p-target={2}, c-miss={3}, "
        "c-fa={4})\n".format(mindcf, threshold, args.p_target, args.c_miss, args.c_fa)
    )


if __name__ == "__main__":
    main()