jpxkqx commited on
Commit
c3f2132
·
1 Parent(s): 2145128

remove name

Browse files
Files changed (2) hide show
  1. structural_similarity_index_measure.py +45 -14
  2. tests.py +2 -2
structural_similarity_index_measure.py CHANGED
@@ -71,13 +71,32 @@ class StructuralSimilarityIndexMeasure(evaluate.Metric):
71
  description=_DESCRIPTION,
72
  citation=_CITATION,
73
  inputs_description=_KWARGS_DESCRIPTION,
74
- features=datasets.Features({
75
- "predictions": datasets.Sequence(datasets.Array2D("float32")),
76
- "references": datasets.Sequence(datasets.Array2D("float32")),
77
- }),
78
  reference_urls=["https://scikit-image.org/docs/dev/auto_examples/transform/plot_ssim.html"],
79
  )
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  def _compute(
82
  self,
83
  predictions,
@@ -89,19 +108,31 @@ class StructuralSimilarityIndexMeasure(evaluate.Metric):
89
  sample_weight=None,
90
  **kwargs
91
  ) -> Dict[str, float]:
92
- def func_ssim(args):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  return structural_similarity(
94
- *args,
 
95
  win_size=win_size,
96
  gaussian_weights=gaussian_weights,
97
  data_range=data_range,
98
  multichannel=multichannel,
99
  **kwargs
100
- )
101
-
102
- return {
103
- "ssim": np.average(
104
- list(map(func_ssim, zip(predictions, references))),
105
- weights=sample_weight
106
- )
107
- }
 
71
  description=_DESCRIPTION,
72
  citation=_CITATION,
73
  inputs_description=_KWARGS_DESCRIPTION,
74
+ features=datasets.Features(self._get_feature_types()),
 
 
 
75
  reference_urls=["https://scikit-image.org/docs/dev/auto_examples/transform/plot_ssim.html"],
76
  )
77
 
78
+ def _get_feature_types(self):
79
+ if self.config_name == "multilist":
80
+ return {
81
+ # 1st Seq - num_samples, 2nd Seq - Height, 3rd Seq - Width
82
+ "predictions": datasets.Sequence(
83
+ datasets.Sequence(datasets.Sequence(datasets.Value("float32")))
84
+ ),
85
+ "references": datasets.Sequence(
86
+ datasets.Sequence(datasets.Sequence(datasets.Value("float32")))
87
+ ),
88
+ }
89
+ else:
90
+ return {
91
+ # 1st Seq - Height, 2rd Seq - Width
92
+ "predictions": datasets.Sequence(
93
+ datasets.Sequence(datasets.Value("float32"))
94
+ ),
95
+ "references": datasets.Sequence(
96
+ datasets.Sequence(datasets.Value("float32"))
97
+ ),
98
+ }
99
+
100
  def _compute(
101
  self,
102
  predictions,
 
108
  sample_weight=None,
109
  **kwargs
110
  ) -> Dict[str, float]:
111
+ if self.config_name == "multilist":
112
+ def func_ssim(args):
113
+ pred, target = args
114
+ pred = np.array(pred)
115
+ target = np.array(target)
116
+ return structural_similarity(
117
+ pred,
118
+ target,
119
+ win_size=win_size,
120
+ gaussian_weights=gaussian_weights,
121
+ data_range=data_range,
122
+ multichannel=multichannel,
123
+ **kwargs
124
+ )
125
+ return np.average(
126
+ list(map(func_ssim, zip(predictions, references))),
127
+ weights=sample_weight
128
+ )
129
+ else:
130
  return structural_similarity(
131
+ np.array(predictions),
132
+ np.array(references),
133
  win_size=win_size,
134
  gaussian_weights=gaussian_weights,
135
  data_range=data_range,
136
  multichannel=multichannel,
137
  **kwargs
138
+ )
 
 
 
 
 
 
 
tests.py CHANGED
@@ -2,11 +2,11 @@ test_cases = [
2
  {
3
  "predictions": [[0.1, 0.1], [1.1, 0.1]],
4
  "references": [[0.1, 0.1], [1.1, 0.1]],
5
- "result": {"Peak Signal-to-Noise Ratio": 23.010298856486173}
6
  },
7
  {
8
  "predictions": [[0, 1], [0, 0]],
9
  "references": [[0, 0], [-1, -1]],
10
- "result": {"Peak Signal-to-Noise Ratio": 1.2493873660829993}
11
  }
12
  ]
 
2
  {
3
  "predictions": [[0.1, 0.1], [1.1, 0.1]],
4
  "references": [[0.1, 0.1], [1.1, 0.1]],
5
+ "result": 1.0
6
  },
7
  {
8
  "predictions": [[0, 1], [0, 0]],
9
  "references": [[0, 0], [-1, -1]],
10
+ "result": 0.2
11
  }
12
  ]