hynky HF staff commited on
Commit
72bf36a
·
1 Parent(s): ad4b7f5
Files changed (2) hide show
  1. sklearn_proxy.py +16 -1
  2. tests.py +18 -5
sklearn_proxy.py CHANGED
@@ -16,6 +16,7 @@
16
  import evaluate
17
  import datasets
18
  from sklearn.metrics import get_scorer
 
19
 
20
 
21
  # TODO: Add BibTeX citation
@@ -61,6 +62,9 @@ BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
61
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
62
  class SklearnProxy(evaluate.Metric):
63
  """TODO: Short description of my evaluation module."""
 
 
 
64
 
65
  def _info(self):
66
  # TODO: Specifies the evaluate.EvaluationModuleInfo object
@@ -74,6 +78,7 @@ class SklearnProxy(evaluate.Metric):
74
  features=datasets.Features({
75
  'predictions': datasets.Value('int64'),
76
  'references': datasets.Value('int64'),
 
77
  }),
78
  # Homepage of the module for documentation
79
  homepage="http://module.homepage",
@@ -89,4 +94,14 @@ class SklearnProxy(evaluate.Metric):
89
 
90
  def _compute(self, predictions, references, metric_name="accuracy", **kwargs):
91
  scorer = get_scorer(metric_name)
92
- return scorer(references, predictions, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
16
  import evaluate
17
  import datasets
18
  from sklearn.metrics import get_scorer
19
+ from sklearn.base import BaseEstimator
20
 
21
 
22
  # TODO: Add BibTeX citation
 
62
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
63
  class SklearnProxy(evaluate.Metric):
64
  """TODO: Short description of my evaluation module."""
65
+ def __init__(self, **kwargs):
66
+ super().__init__(**kwargs)
67
+ self.dummy_estimator = PassThroughtEstimator()
68
 
69
  def _info(self):
70
  # TODO: Specifies the evaluate.EvaluationModuleInfo object
 
78
  features=datasets.Features({
79
  'predictions': datasets.Value('int64'),
80
  'references': datasets.Value('int64'),
81
+
82
  }),
83
  # Homepage of the module for documentation
84
  homepage="http://module.homepage",
 
94
 
95
  def _compute(self, predictions, references, metric_name="accuracy", **kwargs):
96
  scorer = get_scorer(metric_name)
97
+
98
+ return {metric_name: scorer(self.dummy_estimator, references, predictions, **kwargs)}
99
+
100
+
101
+ class PassThroughtEstimator(BaseEstimator):
102
+ def __init__(self):
103
+ pass
104
+ def fit(self, X, y):
105
+ return self
106
+ def predict(self, X):
107
+ return X
tests.py CHANGED
@@ -1,17 +1,30 @@
1
- test_cases = [
 
 
 
 
2
  {
3
  "predictions": [0, 0],
4
  "references": [1, 1],
5
- "result": {"metric_score": 0}
6
  },
7
  {
8
  "predictions": [1, 1],
9
  "references": [1, 1],
10
- "result": {"metric_score": 1}
11
  },
12
  {
13
  "predictions": [1, 0],
14
  "references": [1, 1],
15
- "result": {"metric_score": 0.5}
16
  }
17
- ]
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn_proxy import SklearnProxy
2
+ import unittest
3
+
4
+
5
+ accuracy_test_cases = [
6
  {
7
  "predictions": [0, 0],
8
  "references": [1, 1],
9
+ "result": {"accuracy": 0.0}
10
  },
11
  {
12
  "predictions": [1, 1],
13
  "references": [1, 1],
14
+ "result": {"accuracy": 1.0}
15
  },
16
  {
17
  "predictions": [1, 0],
18
  "references": [1, 1],
19
+ "result": {"accuracy": 0.5}
20
  }
21
+ ]
22
+
23
+
24
+ class TestGeneral(unittest.TestCase):
25
+
26
+ def test_accuracy(self):
27
+ metric = SklearnProxy()
28
+ for test_case in accuracy_test_cases:
29
+ result = metric.compute(predictions=test_case["predictions"],references=test_case["references"], metric_name="accuracy")
30
+ self.assertEqual(result, test_case["result"])