Spaces:
Build error
Build error
add option to change saliency config
Browse files
app.py
CHANGED
@@ -6,9 +6,6 @@ from matplotlib import pyplot as plt
|
|
6 |
import io
|
7 |
from PIL import Image
|
8 |
|
9 |
-
# TODO
|
10 |
-
# - [ ] Expose more of multi-scale relevancy extractor API
|
11 |
-
|
12 |
|
13 |
def plot_to_png(fig):
|
14 |
buf = io.BytesIO()
|
@@ -18,7 +15,7 @@ def plot_to_png(fig):
|
|
18 |
return img
|
19 |
|
20 |
|
21 |
-
def generate_relevancy(img: np.array, labels: str, prompt: str):
|
22 |
labels = labels.split(",")
|
23 |
prompts = [prompt]
|
24 |
assert img.dtype == np.uint8
|
@@ -28,7 +25,7 @@ def generate_relevancy(img: np.array, labels: str, prompt: str):
|
|
28 |
img=img,
|
29 |
text_labels=np.array(labels),
|
30 |
prompts=prompts,
|
31 |
-
**saliency_configs[
|
32 |
)[0]
|
33 |
grads -= grads.mean(axis=0)
|
34 |
grads = grads.cpu().numpy()
|
@@ -59,6 +56,7 @@ iface = gr.Interface(
|
|
59 |
gr.Image(type="numpy"),
|
60 |
gr.Textbox(),
|
61 |
gr.Textbox(),
|
|
|
62 |
],
|
63 |
outputs=gr.Image(type="numpy"),
|
64 |
examples=[
|
|
|
6 |
import io
|
7 |
from PIL import Image
|
8 |
|
|
|
|
|
|
|
9 |
|
10 |
def plot_to_png(fig):
|
11 |
buf = io.BytesIO()
|
|
|
15 |
return img
|
16 |
|
17 |
|
18 |
+
def generate_relevancy(img: np.array, labels: str, prompt: str, saliency_config: str):
|
19 |
labels = labels.split(",")
|
20 |
prompts = [prompt]
|
21 |
assert img.dtype == np.uint8
|
|
|
25 |
img=img,
|
26 |
text_labels=np.array(labels),
|
27 |
prompts=prompts,
|
28 |
+
**saliency_configs[saliency_config](h),
|
29 |
)[0]
|
30 |
grads -= grads.mean(axis=0)
|
31 |
grads = grads.cpu().numpy()
|
|
|
56 |
gr.Image(type="numpy"),
|
57 |
gr.Textbox(),
|
58 |
gr.Textbox(),
|
59 |
+
gr.Dropdown(value="ours", choices=["ours", "chefer_et_al"]),
|
60 |
],
|
61 |
outputs=gr.Image(type="numpy"),
|
62 |
examples=[
|