huy-ha commited on
Commit
ac86774
·
1 Parent(s): 995510e

option to subtract mean

Browse files
Files changed (1) hide show
  1. app.py +18 -7
app.py CHANGED
@@ -15,7 +15,9 @@ def plot_to_png(fig):
15
  return img
16
 
17
 
18
- def generate_relevancy(img: np.array, labels: str, prompt: str, saliency_config: str):
 
 
19
  labels = labels.split(",")
20
  prompts = [prompt]
21
  assert img.dtype == np.uint8
@@ -27,11 +29,15 @@ def generate_relevancy(img: np.array, labels: str, prompt: str, saliency_config:
27
  prompts=prompts,
28
  **saliency_configs[saliency_config](h),
29
  )[0]
30
- grads -= grads.mean(axis=0)
 
31
  grads = grads.cpu().numpy()
32
  num_axes = int(np.ceil(np.sqrt(len(labels))))
33
  fig, axes = plt.subplots(num_axes, num_axes)
34
- axes = axes.flatten()
 
 
 
35
  vmin = 0.002
36
  cmap = plt.get_cmap("jet")
37
  vmax = 0.008
@@ -53,10 +59,15 @@ def generate_relevancy(img: np.array, labels: str, prompt: str, saliency_config:
53
  iface = gr.Interface(
54
  fn=generate_relevancy,
55
  inputs=[
56
- gr.Image(type="numpy"),
57
- gr.Textbox(),
58
- gr.Textbox(),
59
- gr.Dropdown(value="ours", choices=["ours", "chefer_et_al"]),
 
 
 
 
 
60
  ],
61
  outputs=gr.Image(type="numpy"),
62
  examples=[
 
15
  return img
16
 
17
 
18
+ def generate_relevancy(
19
+ img: np.array, labels: str, prompt: str, saliency_config: str, subtract_mean: bool
20
+ ):
21
  labels = labels.split(",")
22
  prompts = [prompt]
23
  assert img.dtype == np.uint8
 
29
  prompts=prompts,
30
  **saliency_configs[saliency_config](h),
31
  )[0]
32
+ if subtract_mean:
33
+ grads -= grads.mean(axis=0)
34
  grads = grads.cpu().numpy()
35
  num_axes = int(np.ceil(np.sqrt(len(labels))))
36
  fig, axes = plt.subplots(num_axes, num_axes)
37
+ if num_axes == 1:
38
+ axes = [axes]
39
+ else:
40
+ axes = axes.flatten()
41
  vmin = 0.002
42
  cmap = plt.get_cmap("jet")
43
  vmax = 0.008
 
59
  iface = gr.Interface(
60
  fn=generate_relevancy,
61
  inputs=[
62
+ gr.Image(type="numpy", label="Image"),
63
+ gr.Textbox(label="Labels (comma separated)"),
64
+ gr.Textbox(label="Prompt"),
65
+ gr.Dropdown(
66
+ value="ours",
67
+ choices=["ours", "chefer_et_al"],
68
+ label="Relevancy Configuration",
69
+ ),
70
+ gr.Checkbox(value=True, label="subtract mean"),
71
  ],
72
  outputs=gr.Image(type="numpy"),
73
  examples=[