File size: 751 Bytes
72bf36a
 
 
 
 
eeb533b
 
 
72bf36a
eeb533b
 
 
 
72bf36a
eeb533b
 
 
 
72bf36a
eeb533b
72bf36a
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from sklearn_proxy import SklearnProxy
import unittest


# Table of accuracy cases. Each entry holds the predicted labels, the
# reference labels, and the exact dict the test compares the result of
# SklearnProxy.compute(..., metric_name="accuracy") against.
accuracy_test_cases = [
    {"predictions": preds, "references": refs, "result": {"accuracy": expected}}
    for preds, refs, expected in (
        ([0, 0], [1, 1], 0.0),  # every prediction wrong
        ([1, 1], [1, 1], 1.0),  # every prediction right
        ([1, 0], [1, 1], 0.5),  # half right
    )
]


class TestGeneral(unittest.TestCase):
    """Table-driven tests for SklearnProxy."""

    def test_accuracy(self):
        """Check metric_name="accuracy" against every accuracy_test_cases entry."""
        metric = SklearnProxy()
        for index, test_case in enumerate(accuracy_test_cases):
            # subTest makes each case pass/fail independently and reports its
            # inputs, instead of aborting the loop at the first mismatch.
            with self.subTest(
                case=index,
                predictions=test_case["predictions"],
                references=test_case["references"],
            ):
                result = metric.compute(
                    predictions=test_case["predictions"],
                    references=test_case["references"],
                    metric_name="accuracy",
                )
                self.assertEqual(result, test_case["result"])