Spaces:
Build error
Build error
Try updating possible types.
Browse files- vendiscore.py +18 -12
vendiscore.py
CHANGED
@@ -85,7 +85,10 @@ class VendiScore(evaluate.Metric):
|
|
85 |
inputs_description=_KWARGS_DESCRIPTION,
|
86 |
features=datasets.Features(
|
87 |
{
|
88 |
-
"
|
|
|
|
|
|
|
89 |
}
|
90 |
),
|
91 |
homepage="http://github.com/Vertaix/Vendi-Score",
|
@@ -100,7 +103,10 @@ class VendiScore(evaluate.Metric):
|
|
100 |
|
101 |
def _compute(
|
102 |
self,
|
103 |
-
|
|
|
|
|
|
|
104 |
k="ngram_overlap",
|
105 |
score_K=False,
|
106 |
score_X=False,
|
@@ -115,18 +121,16 @@ class VendiScore(evaluate.Metric):
|
|
115 |
device="cpu",
|
116 |
):
|
117 |
if score_K:
|
118 |
-
vs = vendi.score_K(
|
119 |
elif score_dual:
|
120 |
-
vs = vendi.score_dual(
|
121 |
elif score_X:
|
122 |
-
vs = vendi.score_X(
|
123 |
elif type(k) == str and k == "ngram_overlap":
|
124 |
-
vs = text_utils.ngram_vendi_score(
|
125 |
-
samples, ns=ns, tokenizer=tokenizer
|
126 |
-
)
|
127 |
elif type(k) == str and k == "text_embeddings":
|
128 |
vs = text_utils.embedding_vendi_score(
|
129 |
-
|
130 |
model=model,
|
131 |
tokenizer=tokenizer,
|
132 |
batch_size=batch_size,
|
@@ -134,15 +138,17 @@ class VendiScore(evaluate.Metric):
|
|
134 |
model_path=model_path,
|
135 |
)
|
136 |
elif type(k) == str and k == "pixels":
|
137 |
-
vs = image_utils.pixel_vendi_score(
|
138 |
elif type(k) == str and k == "image_embeddings":
|
139 |
vs = image_utils.embedding_vendi_score(
|
140 |
-
|
141 |
batch_size=batch_size,
|
142 |
device=device,
|
143 |
model=model,
|
144 |
transform=transform,
|
145 |
)
|
|
|
|
|
146 |
else:
|
147 |
-
|
148 |
return {"VS": vs}
|
|
|
85 |
inputs_description=_KWARGS_DESCRIPTION,
|
86 |
features=datasets.Features(
|
87 |
{
|
88 |
+
"sents": datasets.Value("string"),
|
89 |
+
"imgs": datasets.Image,
|
90 |
+
"X": datasets.Array2D,
|
91 |
+
"K": datasets.Array2D,
|
92 |
}
|
93 |
),
|
94 |
homepage="http://github.com/Vertaix/Vendi-Score",
|
|
|
103 |
|
104 |
def _compute(
|
105 |
self,
|
106 |
+
sents=None,
|
107 |
+
imgs=None,
|
108 |
+
X=None,
|
109 |
+
K=None,
|
110 |
k="ngram_overlap",
|
111 |
score_K=False,
|
112 |
score_X=False,
|
|
|
121 |
device="cpu",
|
122 |
):
|
123 |
if score_K:
|
124 |
+
vs = vendi.score_K(K, normalize=normalize)
|
125 |
elif score_dual:
|
126 |
+
vs = vendi.score_dual(X, normalize=normalize)
|
127 |
elif score_X:
|
128 |
+
vs = vendi.score_X(X, normalize=normalize)
|
129 |
elif type(k) == str and k == "ngram_overlap":
|
130 |
+
vs = text_utils.ngram_vendi_score(sents, ns=ns, tokenizer=tokenizer)
|
|
|
|
|
131 |
elif type(k) == str and k == "text_embeddings":
|
132 |
vs = text_utils.embedding_vendi_score(
|
133 |
+
sents,
|
134 |
model=model,
|
135 |
tokenizer=tokenizer,
|
136 |
batch_size=batch_size,
|
|
|
138 |
model_path=model_path,
|
139 |
)
|
140 |
elif type(k) == str and k == "pixels":
|
141 |
+
vs = image_utils.pixel_vendi_score(imgs)
|
142 |
elif type(k) == str and k == "image_embeddings":
|
143 |
vs = image_utils.embedding_vendi_score(
|
144 |
+
imgs,
|
145 |
batch_size=batch_size,
|
146 |
device=device,
|
147 |
model=model,
|
148 |
transform=transform,
|
149 |
)
|
150 |
+
elif sents is not None or imgs is not None or X is not None:
|
151 |
+
vs = vendi.score(sents or imgs or X, k)
|
152 |
else:
|
153 |
+
raise ValueError(f"Must provide one of `sents` or `imgs` or `X`.")
|
154 |
return {"VS": vs}
|