danf0 commited on
Commit
74d77bc
1 Parent(s): 85ed75d

Try updating possible types.

Browse files
Files changed (1) hide show
  1. vendiscore.py +18 -12
vendiscore.py CHANGED
@@ -85,7 +85,10 @@ class VendiScore(evaluate.Metric):
85
  inputs_description=_KWARGS_DESCRIPTION,
86
  features=datasets.Features(
87
  {
88
- "samples": datasets.Value("string"),
 
 
 
89
  }
90
  ),
91
  homepage="http://github.com/Vertaix/Vendi-Score",
@@ -100,7 +103,10 @@ class VendiScore(evaluate.Metric):
100
 
101
  def _compute(
102
  self,
103
- samples,
 
 
 
104
  k="ngram_overlap",
105
  score_K=False,
106
  score_X=False,
@@ -115,18 +121,16 @@ class VendiScore(evaluate.Metric):
115
  device="cpu",
116
  ):
117
  if score_K:
118
- vs = vendi.score_K(samples, normalize=normalize)
119
  elif score_dual:
120
- vs = vendi.score_dual(samples, normalize=normalize)
121
  elif score_X:
122
- vs = vendi.score_X(samples, normalize=normalize)
123
  elif type(k) == str and k == "ngram_overlap":
124
- vs = text_utils.ngram_vendi_score(
125
- samples, ns=ns, tokenizer=tokenizer
126
- )
127
  elif type(k) == str and k == "text_embeddings":
128
  vs = text_utils.embedding_vendi_score(
129
- samples,
130
  model=model,
131
  tokenizer=tokenizer,
132
  batch_size=batch_size,
@@ -134,15 +138,17 @@ class VendiScore(evaluate.Metric):
134
  model_path=model_path,
135
  )
136
  elif type(k) == str and k == "pixels":
137
- vs = image_utils.pixel_vendi_score(samples)
138
  elif type(k) == str and k == "image_embeddings":
139
  vs = image_utils.embedding_vendi_score(
140
- samples,
141
  batch_size=batch_size,
142
  device=device,
143
  model=model,
144
  transform=transform,
145
  )
 
 
146
  else:
147
- vs = vendi.score(samples, k)
148
  return {"VS": vs}
 
85
  inputs_description=_KWARGS_DESCRIPTION,
86
  features=datasets.Features(
87
  {
88
+ "sents": datasets.Value("string"),
89
+ "imgs": datasets.Image,
90
+ "X": datasets.Array2D,
91
+ "K": datasets.Array2D,
92
  }
93
  ),
94
  homepage="http://github.com/Vertaix/Vendi-Score",
 
103
 
104
  def _compute(
105
  self,
106
+ sents=None,
107
+ imgs=None,
108
+ X=None,
109
+ K=None,
110
  k="ngram_overlap",
111
  score_K=False,
112
  score_X=False,
 
121
  device="cpu",
122
  ):
123
  if score_K:
124
+ vs = vendi.score_K(K, normalize=normalize)
125
  elif score_dual:
126
+ vs = vendi.score_dual(X, normalize=normalize)
127
  elif score_X:
128
+ vs = vendi.score_X(X, normalize=normalize)
129
  elif type(k) == str and k == "ngram_overlap":
130
+ vs = text_utils.ngram_vendi_score(sents, ns=ns, tokenizer=tokenizer)
 
 
131
  elif type(k) == str and k == "text_embeddings":
132
  vs = text_utils.embedding_vendi_score(
133
+ sents,
134
  model=model,
135
  tokenizer=tokenizer,
136
  batch_size=batch_size,
 
138
  model_path=model_path,
139
  )
140
  elif type(k) == str and k == "pixels":
141
+ vs = image_utils.pixel_vendi_score(imgs)
142
  elif type(k) == str and k == "image_embeddings":
143
  vs = image_utils.embedding_vendi_score(
144
+ imgs,
145
  batch_size=batch_size,
146
  device=device,
147
  model=model,
148
  transform=transform,
149
  )
150
+ elif sents is not None or imgs is not None or X is not None:
151
+ vs = vendi.score(sents or imgs or X, k)
152
  else:
153
+ raise ValueError(f"Must provide one of `sents` or `imgs` or `X`.")
154
  return {"VS": vs}