huy-ha commited on
Commit
995510e
·
1 Parent(s): 4181a28

add option to change saliency config

Browse files
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -6,9 +6,6 @@ from matplotlib import pyplot as plt
6
  import io
7
  from PIL import Image
8
 
9
- # TODO
10
- # - [ ] Expose more of multi-scale relevancy extractor API
11
-
12
 
13
  def plot_to_png(fig):
14
  buf = io.BytesIO()
@@ -18,7 +15,7 @@ def plot_to_png(fig):
18
  return img
19
 
20
 
21
- def generate_relevancy(img: np.array, labels: str, prompt: str):
22
  labels = labels.split(",")
23
  prompts = [prompt]
24
  assert img.dtype == np.uint8
@@ -28,7 +25,7 @@ def generate_relevancy(img: np.array, labels: str, prompt: str):
28
  img=img,
29
  text_labels=np.array(labels),
30
  prompts=prompts,
31
- **saliency_configs["ours"](h),
32
  )[0]
33
  grads -= grads.mean(axis=0)
34
  grads = grads.cpu().numpy()
@@ -59,6 +56,7 @@ iface = gr.Interface(
59
  gr.Image(type="numpy"),
60
  gr.Textbox(),
61
  gr.Textbox(),
 
62
  ],
63
  outputs=gr.Image(type="numpy"),
64
  examples=[
 
6
  import io
7
  from PIL import Image
8
 
 
 
 
9
 
10
  def plot_to_png(fig):
11
  buf = io.BytesIO()
 
15
  return img
16
 
17
 
18
+ def generate_relevancy(img: np.array, labels: str, prompt: str, saliency_config: str):
19
  labels = labels.split(",")
20
  prompts = [prompt]
21
  assert img.dtype == np.uint8
 
25
  img=img,
26
  text_labels=np.array(labels),
27
  prompts=prompts,
28
+ **saliency_configs[saliency_config](h),
29
  )[0]
30
  grads -= grads.mean(axis=0)
31
  grads = grads.cpu().numpy()
 
56
  gr.Image(type="numpy"),
57
  gr.Textbox(),
58
  gr.Textbox(),
59
+ gr.Dropdown(value="ours", choices=["ours", "chefer_et_al"]),
60
  ],
61
  outputs=gr.Image(type="numpy"),
62
  examples=[