danf0 committed
Commit c3f0353
1 Parent(s): 1b92067

Update vendiscore.

Files changed (2):
  1. requirements.txt +11 -1
  2. vendiscore.py +98 -47
requirements.txt CHANGED
@@ -1 +1,11 @@
-git+https://github.com/huggingface/evaluate@main
+git+https://github.com/huggingface/evaluate@main
+numpy>=1.13
+scipy>=1.3.2
+scikit-learn>=1.1
+torch
+torchvision
+matplotlib
+transformers
+datasets
+nltk
+vendi_score
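
Of the new entries, vendi_score is the scorer itself; the others back its text and image similarity options. As a quick smoke test that the pinned stack resolves, a minimal sketch assuming only that vendi.score_X behaves as the module below uses it:

import numpy as np
from vendi_score import vendi

X = np.random.randn(8, 4)                         # 8 samples, 4 features each
X = X / np.linalg.norm(X, axis=1, keepdims=True)  # unit rows so K[i, i] = 1
print(vendi.score_X(X))                           # effective sample count, in [1, 8]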
vendiscore.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
+# Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,49 +15,62 @@
 
 import evaluate
 import datasets
+import numpy as np
 
+from vendi_score import vendi, image_utils, text_utils
 
 # TODO: Add BibTeX citation
-_CITATION = """\
-@InProceedings{huggingface:module,
-    title = {A great new module},
-    authors={huggingface, Inc.},
-    year={2020}
-}
-"""
-
-# TODO: Add description of the module here
+_CITATION = ""
 _DESCRIPTION = """\
-This new module is designed to solve this great ML task and is crafted with a lot of care.
+A diversity evaluation metric for machine learning.
 """
 
 
-# TODO: Add description of the arguments of the module here
 _KWARGS_DESCRIPTION = """
-Calculates how good are predictions given some references, using certain scores
+Calculates the Vendi Score given samples and a similarity function.
 Args:
-    predictions: list of predictions to score. Each predictions
-        should be a string with tokens separated by spaces.
-    references: list of reference for each prediction. Each
-        reference should be a string with tokens separated by spaces.
+    samples: list of n sentences to score, an n x n similarity matrix K, or
+        an n x d feature matrix X.
+    k: a pairwise similarity function, or a string identifying a predefined
+        similarity function.
+        Options: ngram_overlap, text_embeddings, pixels, image_embeddings.
+    score_K: if true, samples is an n x n similarity matrix K.
+    score_X: if true, samples is an n x d feature matrix X.
+    score_dual: if true, compute the diversity score of X @ X.T.
+    normalize: if true, normalize the similarity scores.
+    model (optional): if k is "text_embeddings", a model mapping sentences to
+        embeddings (output should be an object with an attribute called
+        `pooler_output` or `last_hidden_state`). If k is "image_embeddings", a
+        model mapping images to embeddings.
+    tokenizer (optional): if k is "text_embeddings" or "ngram_overlap", a
+        tokenizer mapping strings to lists.
+    transform (optional): if k is "image_embeddings", a torchvision transform
+        to apply to the samples.
+    model_path (optional): if k is "text_embeddings", the name of a model on
+        the HuggingFace hub.
+    ns (optional): if k is "ngram_overlap", the values of n to calculate.
+    batch_size (optional): batch size to use if k is "text_embeddings" or
+        "image_embeddings".
+    device (optional): a string (e.g. "cuda", "cpu") or torch.device identifying
+        the device to use if k is "text_embeddings" or "image_embeddings".
 Returns:
-    accuracy: description of the first score,
-    another_score: description of the second score,
+    VS: The Vendi Score.
 Examples:
-    Examples should be written in doctest format, and should illustrate how
-    to use the function.
-
-    >>> my_new_module = evaluate.load("my_new_module")
-    >>> results = my_new_module.compute(references=[0, 1], predictions=[0, 1])
+    >>> vendi_score = evaluate.load("vendi_score")
+    >>> samples = ["Look, Jane.",
    ...            "See Spot.",
    ...            "See Spot run.",
    ...            "Run, Spot, run.",
    ...            "Jane sees Spot run."]
+    >>> results = vendi_score.compute(samples=samples, k="ngram_overlap", ns=[1, 2])
     >>> print(results)
-    {'accuracy': 1.0}
+    {'VS': 3.90657...}
 """
 
-# TODO: Define external resources urls if needed
-BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
-
 
-@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+@evaluate.utils.file_utils.add_start_docstrings(
+    _DESCRIPTION, _KWARGS_DESCRIPTION
+)
 class VendiScore(evaluate.Metric):
     """TODO: Short description of my evaluation module."""
 
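A note on the `k` argument documented above: besides the four predefined options, any pairwise similarity function is accepted and falls through to `vendi.score(samples, k)` in the `_compute` dispatch below. A hedged sketch of that path with a toy exact-match kernel (the expected value is hand-computed under these assumptions; a kernel should be symmetric with k(x, x) = 1):

from vendi_score import vendi

def exact_match(a, b):
    # symmetric similarity with exact_match(x, x) = 1
    return 1.0 if a == b else 0.0

samples = ["a", "a", "b", "c"]
print(vendi.score(samples, exact_match))  # ~2.83: four samples, one duplicated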
@@ -69,27 +82,65 @@ class VendiScore(evaluate.Metric):
             description=_DESCRIPTION,
             citation=_CITATION,
             inputs_description=_KWARGS_DESCRIPTION,
-            # This defines the format of each prediction and reference
-            features=datasets.Features({
-                'predictions': datasets.Value('int64'),
-                'references': datasets.Value('int64'),
-            }),
-            # Homepage of the module for documentation
-            homepage="http://module.homepage",
-            # Additional links to the codebase or references
-            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
-            reference_urls=["http://path.to.reference.url/new_module"]
+            features=datasets.Features(
+                {
+                    "samples": datasets.Value("string"),
+                }
+            ),
+            homepage="http://github.com/Vertaix/Vendi-Score",
+            codebase_urls=["http://github.com/Vertaix/Vendi-Score"],
+            reference_urls=[],
         )
 
     def _download_and_prepare(self, dl_manager):
         """Optional: download external resources useful to compute the scores"""
-        # TODO: Download external resources if needed
         pass
 
-    def _compute(self, predictions, references):
-        """Returns the scores"""
-        # TODO: Compute the different scores of the module
-        accuracy = sum(i == j for i, j in zip(predictions, references)) / len(predictions)
-        return {
-            "accuracy": accuracy,
-        }
+    def _compute(
+        self,
+        samples,
+        k="ngram_overlap",
+        score_K=False,
+        score_X=False,
+        score_dual=False,
+        normalize=False,
+        model=None,
+        tokenizer=None,
+        transform=None,
+        model_path=None,
+        ns=[1, 2],
+        batch_size=16,
+        device="cpu",
+    ):
+        if score_K:
+            vs = vendi.score_K(samples, normalize=normalize)
+        elif score_dual:
+            vs = vendi.score_dual(samples, normalize=normalize)
+        elif score_X:
+            vs = vendi.score_X(samples, normalize=normalize)
+        elif type(k) == str and k == "ngram_overlap":
+            vs = text_utils.ngram_vendi_score(
+                samples, ns=ns, tokenizer=tokenizer
+            )
+        elif type(k) == str and k == "text_embeddings":
+            vs = text_utils.embedding_vendi_score(
+                samples,
+                model=model,
+                tokenizer=tokenizer,
+                batch_size=batch_size,
+                device=device,
+                model_path=model_path,
+            )
+        elif type(k) == str and k == "pixels":
+            vs = image_utils.pixel_vendi_score(samples)
+        elif type(k) == str and k == "image_embeddings":
+            vs = image_utils.embedding_vendi_score(
+                samples,
+                batch_size=batch_size,
+                device=device,
+                model=model,
+                transform=transform,
+            )
+        else:
+            vs = vendi.score(samples, k)
+        return {"VS": vs}
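
For intuition about what these branches ultimately compute: the Vendi Score of an n x n similarity matrix K with K[i, i] = 1 is the exponential of the Shannon entropy of the eigenvalues of K / n, i.e. an effective number of distinct samples. A minimal numpy sketch for illustration only, not the library's implementation (which `vendi.score_K` provides):

import numpy as np

def vendi_score_sketch(K):
    lam = np.linalg.eigvalsh(K / K.shape[0])  # eigenvalues of K/n sum to 1
    lam = lam[lam > 1e-12]                    # convention: 0 * log 0 = 0
    return float(np.exp(-np.sum(lam * np.log(lam))))

print(vendi_score_sketch(np.eye(5)))        # 5.0: five mutually dissimilar items
print(vendi_score_sketch(np.ones((5, 5))))  # 1.0: five identical items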