# Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import evaluate
import datasets
import numpy as np
from vendi_score import vendi, image_utils, text_utils
_CITATION = """\
@article{friedman2022vendi,
  title={The Vendi Score: A Diversity Evaluation Metric for Machine Learning},
  author={Friedman, Dan and Dieng, Adji Bousso},
  journal={arXiv preprint arXiv:2210.02410},
  year={2022}
}
"""
_DESCRIPTION = """\
The Vendi Score is a metric for evaluating diversity in machine learning.
The input to the metric is a collection of samples and a pairwise similarity function; the output is a number that can be interpreted as the effective number of unique elements in the sample.
See the project's README at https://github.com/vertaix/Vendi-Score for more information.
The interactive example calculates the Vendi Score for a set of strings using the n-gram overlap similarity, averaged between n=1 and n=2.
"""
_KWARGS_DESCRIPTION = """
Calculates the Vendi Score given samples and a similarity function.
Args:
samples: an iterable containing n samples to score, an n x n similarity
matrix K, or an n x d feature matrix X.
k: a pairwise similarity function, or a string identifying a predefined
similarity function.
Options: ngram_overlap, text_embeddings, pixels, image_embeddings.
score_K: if true, samples is an n x n similarity matrix K.
score_X: if true, samples is an n x d feature matrix X.
score_dual: if true, compute diversity score of X @ X.T.
normalize: if true, normalize the similarity scores.
model (optional): if k is "text_embeddings", a model mapping sentences to
embeddings (output should be an object with an attribute called
`pooler_output` or `last_hidden_state`). If k is "image_embeddings", a
model mapping images to embeddings.
tokenizer (optional): if k is "text_embeddings" or "ngram_overlap", a
tokenizer mapping strings to lists.
transform (optional): if k is "image_embeddings", a torchvision transform
to apply to the samples.
model_path (optional): if k is "text_embeddings", the name of a model on the
HuggingFace hub.
ns (optional): if k is "ngram_overlap", the values of n to calculate.
    batch_size (optional): batch size to use if k is "text_embeddings" or
        "image_embeddings".
    device (optional): a string (e.g. "cuda", "cpu") or torch.device identifying
        the device to use if k is "text_embeddings" or "image_embeddings".
Returns:
VS: The Vendi Score.
Examples:
>>> vendiscore = evaluate.load("danf0/vendiscore")
>>> samples = ["Look, Jane.",
"See Spot.",
"See Spot run.",
"Run, Spot, run.",
"Jane sees Spot run."]
>>> results = vendiscore.compute(samples, k="ngram_overlap", ns=[1, 2])
>>> print(results)
{'VS': 3.90657...}
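
    Other input modes (illustrative sketches; outputs omitted, and the arrays
    below are made up for the example):

    >>> import numpy as np
    >>> K = np.ones((5, 5))        # a precomputed 5 x 5 similarity matrix
    >>> results = vendiscore.compute(samples=K, score_K=True)
    >>> X = np.random.rand(5, 10)  # 5 samples with 10 features each
    >>> results = vendiscore.compute(samples=X, score_X=True)

    A custom pairwise similarity function can also be passed as k (a sketch;
    any symmetric function with k(x, x) = 1 should work):

    >>> results = vendiscore.compute(samples=[0, 0, 1, 2], k=lambda a, b: float(a == b))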
"""
@evaluate.utils.file_utils.add_start_docstrings(
_DESCRIPTION, _KWARGS_DESCRIPTION
)
class VendiScore(evaluate.Metric):
"""TODO: Short description of my evaluation module."""
def _info(self):
        # Metadata describing this evaluation module (shown on its hub page).
        return evaluate.MetricInfo(
module_type="metric",
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"samples": datasets.Array2D,
}
),
homepage="http://github.com/Vertaix/Vendi-Score",
codebase_urls=["http://github.com/Vertaix/Vendi-Score"],
reference_urls=[],
)
def _download_and_prepare(self, dl_manager):
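        # The "punkt" models are downloaded for the default tokenization used
        # by the n-gram overlap similarity.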
import nltk
nltk.download("punkt")
def _compute(
self,
samples,
k="ngram_overlap",
score_K=False,
score_X=False,
score_dual=False,
normalize=False,
model=None,
tokenizer=None,
transform=None,
model_path=None,
ns=[1, 2],
batch_size=16,
device="cpu",
):
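        # Dispatch: precomputed matrices first, then the predefined similarity
        # names, otherwise treat k as a user-supplied pairwise similarity function.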
if score_K:
vs = vendi.score_K(samples, normalize=normalize)
elif score_dual:
vs = vendi.score_dual(samples, normalize=normalize)
elif score_X:
vs = vendi.score_X(samples, normalize=normalize)
        elif isinstance(k, str) and k == "ngram_overlap":
vs = text_utils.ngram_vendi_score(
samples, ns=ns, tokenizer=tokenizer
)
        elif isinstance(k, str) and k == "text_embeddings":
vs = text_utils.embedding_vendi_score(
samples,
model=model,
tokenizer=tokenizer,
batch_size=batch_size,
device=device,
model_path=model_path,
)
        elif isinstance(k, str) and k == "pixels":
vs = image_utils.pixel_vendi_score(samples)
        elif isinstance(k, str) and k == "image_embeddings":
vs = image_utils.embedding_vendi_score(
samples,
batch_size=batch_size,
device=device,
model=model,
transform=transform,
)
else:
vs = vendi.score(samples, k)
return {"VS": vs}