Spaces:
Runtime error
Runtime error
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Signal-to-Reconstruction Error (SRE) metric.""" | |
import evaluate | |
import datasets | |
import numpy as np | |
_DESCRIPTION = """\
Compute the Signal-to-Reconstruction Error (SRE) metric. This metric is commonly used to
assess the performance of denoising, super-resolution and style transfer algorithms in
audio and image processing.
"""

# NOTE(review): this is the placeholder citation from the `evaluate` module
# template, not a reference for SRE itself — replace with a real citation.
_CITATION = """\
@InProceedings{huggingface:module,
title = {A great new module},
authors={huggingface, Inc.},
year={2020}
}
"""

_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `np.array`): Reconstructed (predicted) signals, one per example.
    references (`list` of `np.array`): Ground-truth signals, aligned with `predictions`.
    sample_weight (`list` of `float`): Per-example weights used when averaging the
        per-example scores. Defaults to None (plain mean).
Returns:
    Signal-to-Reconstruction Error (`float`): Mean SRE over all examples, expressed in
        decibels (dB). Each example is scored as
        10 * log10(sum(reference ** 2) / sum((reference - prediction) ** 2)).
        The higher the SRE value, the better; it is negative when the error energy
        exceeds the signal energy.
Examples:
    Example 1-A simple example
        >>> sre = evaluate.load("jpxkqx/signal_to_reconstruction_error")
        >>> results = sre.compute(references=[[[1, 1], [1, 1]]], predictions=[[[1, 1], [1, 0]]])
        >>> print(round(results["Signal-to-Reconstruction Error"], 2))
        6.02
"""
def signal_reconstruction_error(y_true: np.ndarray, y_hat: np.ndarray) -> float:
    """Compute the Signal-to-Reconstruction Error (SRE) of one sample, in dB.

    SRE = 10 * log10(sum(y_true ** 2) / sum((y_true - y_hat) ** 2))

    Both inputs are converted to float64 before any arithmetic. This stops
    integer inputs (e.g. ``uint8`` images) from wrapping around when squared
    or subtracted, and lets plain nested lists be passed in.

    Args:
        y_true: Reference (ground-truth) signal; any array-like.
        y_hat: Reconstructed / predicted signal; must have the same shape
            as ``y_true``.

    Returns:
        The SRE in decibels (a NumPy float64). It is ``inf`` when the
        reconstruction is exact (zero error energy). It is ``-inf`` when the
        reference has zero energy but the error does not, and ``nan`` when
        both energies are zero.

    Raises:
        ValueError: If ``y_true`` and ``y_hat`` do not have the same shape.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_hat = np.asarray(y_hat, dtype=np.float64)
    if y_true.shape != y_hat.shape:
        # Broadcasting would silently compare against the wrong elements.
        raise ValueError(
            f"Shape mismatch: y_true has shape {y_true.shape} "
            f"but y_hat has shape {y_hat.shape}."
        )
    signal_energy = np.sum(y_true ** 2)
    error_energy = np.sum((y_true - y_hat) ** 2)
    # Zero energies are legitimate edge cases (exact reconstruction, silent
    # reference). Let IEEE arithmetic yield inf/-inf/nan without RuntimeWarnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        return 10 * np.log10(signal_energy / error_energy)
class SignaltoReconstrutionError(evaluate.Metric):
    """`evaluate` metric reporting the mean per-example SRE in decibels.

    The config name selects the per-example input layout:

    - default: a 2-D nested list of floats per example;
    - ``"multilist"``: a 3-D nested list of floats per example.

    Each (reference, prediction) pair is scored with
    ``signal_reconstruction_error``. The scores are then averaged, optionally
    weighted by ``sample_weight``.
    """

    def _info(self):
        """Describe the metric: docs, citation, input features and homepage."""
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(self._get_feature_types()),
            # NOTE(review): the slug here ("reconstrution") differs from the one
            # used in the usage example ("reconstruction"); confirm which is live.
            homepage="https://huggingface.co/spaces/jpxkqx/signal_to_reconstrution_error",
        )

    def _get_feature_types(self):
        """Return the per-example feature schema for the active config."""
        if self.config_name == "multilist":
            # Three nesting levels per example (originally documented as
            # num_samples x height x width — TODO confirm axis meaning).
            return {
                "predictions": datasets.Sequence(
                    datasets.Sequence(datasets.Sequence(datasets.Value("float32")))
                ),
                "references": datasets.Sequence(
                    datasets.Sequence(datasets.Sequence(datasets.Value("float32")))
                ),
            }
        # Two nesting levels per example: 1st Seq - Height, 2nd Seq - Width.
        return {
            "predictions": datasets.Sequence(
                datasets.Sequence(datasets.Value("float32"))
            ),
            "references": datasets.Sequence(
                datasets.Sequence(datasets.Value("float32"))
            ),
        }

    def _compute(self, predictions, references, sample_weight=None):
        """Return the (optionally weighted) mean SRE over all example pairs.

        Args:
            predictions: Reconstructed signals, one nested list/array per example.
            references: Ground-truth signals, aligned with ``predictions``.
            sample_weight: Optional per-example weights for the average;
                ``None`` gives a plain mean.

        Returns:
            ``{"Signal-to-Reconstruction Error": float}``, in decibels.
        """
        # Score each pair on its own rather than stacking every example into
        # a single np.array. Examples may differ in shape (e.g. image sizes),
        # and a single stacked array cannot represent that.
        scores = [
            signal_reconstruction_error(
                np.asarray(reference, dtype=np.float64),
                np.asarray(prediction, dtype=np.float64),
            )
            for reference, prediction in zip(references, predictions)
        ]
        return {
            "Signal-to-Reconstruction Error": float(
                np.average(scores, weights=sample_weight)
            )
        }