danf0 committed on
Commit
a53b506
1 Parent(s): 74d77bc

Try updating possible types.

Browse files
Files changed (1) hide show
  1. vendiscore.py +12 -18
vendiscore.py CHANGED
@@ -85,10 +85,7 @@ class VendiScore(evaluate.Metric):
85
  inputs_description=_KWARGS_DESCRIPTION,
86
  features=datasets.Features(
87
  {
88
- "sents": datasets.Value("string"),
89
- "imgs": datasets.Image,
90
- "X": datasets.Array2D,
91
- "K": datasets.Array2D,
92
  }
93
  ),
94
  homepage="http://github.com/Vertaix/Vendi-Score",
@@ -103,10 +100,7 @@ class VendiScore(evaluate.Metric):
103
 
104
  def _compute(
105
  self,
106
- sents=None,
107
- imgs=None,
108
- X=None,
109
- K=None,
110
  k="ngram_overlap",
111
  score_K=False,
112
  score_X=False,
@@ -121,16 +115,18 @@ class VendiScore(evaluate.Metric):
121
  device="cpu",
122
  ):
123
  if score_K:
124
- vs = vendi.score_K(K, normalize=normalize)
125
  elif score_dual:
126
- vs = vendi.score_dual(X, normalize=normalize)
127
  elif score_X:
128
- vs = vendi.score_X(X, normalize=normalize)
129
  elif type(k) == str and k == "ngram_overlap":
130
- vs = text_utils.ngram_vendi_score(sents, ns=ns, tokenizer=tokenizer)
 
 
131
  elif type(k) == str and k == "text_embeddings":
132
  vs = text_utils.embedding_vendi_score(
133
- sents,
134
  model=model,
135
  tokenizer=tokenizer,
136
  batch_size=batch_size,
@@ -138,17 +134,15 @@ class VendiScore(evaluate.Metric):
138
  model_path=model_path,
139
  )
140
  elif type(k) == str and k == "pixels":
141
- vs = image_utils.pixel_vendi_score(imgs)
142
  elif type(k) == str and k == "image_embeddings":
143
  vs = image_utils.embedding_vendi_score(
144
- imgs,
145
  batch_size=batch_size,
146
  device=device,
147
  model=model,
148
  transform=transform,
149
  )
150
- elif sents is not None or imgs is not None or X is not None:
151
- vs = vendi.score(sents or imgs or X, k)
152
  else:
153
- raise ValueError(f"Must provide one of `sents` or `imgs` or `X`.")
154
  return {"VS": vs}
 
85
  inputs_description=_KWARGS_DESCRIPTION,
86
  features=datasets.Features(
87
  {
88
+ "samples": datasets.Array2D,
 
 
 
89
  }
90
  ),
91
  homepage="http://github.com/Vertaix/Vendi-Score",
 
100
 
101
  def _compute(
102
  self,
103
+ samples,
 
 
 
104
  k="ngram_overlap",
105
  score_K=False,
106
  score_X=False,
 
115
  device="cpu",
116
  ):
117
  if score_K:
118
+ vs = vendi.score_K(samples, normalize=normalize)
119
  elif score_dual:
120
+ vs = vendi.score_dual(samples, normalize=normalize)
121
  elif score_X:
122
+ vs = vendi.score_X(samples, normalize=normalize)
123
  elif type(k) == str and k == "ngram_overlap":
124
+ vs = text_utils.ngram_vendi_score(
125
+ samples, ns=ns, tokenizer=tokenizer
126
+ )
127
  elif type(k) == str and k == "text_embeddings":
128
  vs = text_utils.embedding_vendi_score(
129
+ samples,
130
  model=model,
131
  tokenizer=tokenizer,
132
  batch_size=batch_size,
 
134
  model_path=model_path,
135
  )
136
  elif type(k) == str and k == "pixels":
137
+ vs = image_utils.pixel_vendi_score(samples)
138
  elif type(k) == str and k == "image_embeddings":
139
  vs = image_utils.embedding_vendi_score(
140
+ samples,
141
  batch_size=batch_size,
142
  device=device,
143
  model=model,
144
  transform=transform,
145
  )
 
 
146
  else:
147
+ vs = vendi.score(samples, k)
148
  return {"VS": vs}