SoggyKiwi committed on
Commit
a4244e1
·
1 Parent(s): 2795721

Maximise activation for n randomly chosen target classes

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -11,7 +11,7 @@ model = ViTForImageClassification.from_pretrained('google/vit-large-patch32-384'
11
  model.to(device)
12
  model.eval()
13
 
14
- def process_image(input_image, learning_rate, iterations):
15
  if input_image is None:
16
  return None
17
 
@@ -25,13 +25,18 @@ def process_image(input_image, learning_rate, iterations):
25
  pixel_values = pixel_values.to(device)
26
  pixel_values.requires_grad_(True)
27
 
 
 
 
 
28
  for iteration in range(int(iterations)):
29
  model.zero_grad()
30
  if pixel_values.grad is not None:
31
  pixel_values.grad.data.zero_()
32
 
33
  final_activations = get_encoder_activations(pixel_values)
34
- target_sum = final_activations.sum()
 
35
  target_sum.backward()
36
 
37
  with torch.no_grad():
@@ -48,7 +53,9 @@ iface = gr.Interface(
48
  inputs=[
49
  gr.Image(type="pil"),
50
  gr.Number(value=4.0, label="Learning Rate"),
51
- gr.Number(value=4, label="Iterations")
 
 
52
  ],
53
  outputs=[gr.Image(type="numpy", label="ViT-Dreamed Image")]
54
  )
 
11
  model.to(device)
12
  model.eval()
13
 
14
+ def process_image(input_image, learning_rate, iterations, n_targets, seed):
15
  if input_image is None:
16
  return None
17
 
 
25
  pixel_values = pixel_values.to(device)
26
  pixel_values.requires_grad_(True)
27
 
28
+
29
+ torch.manual_seed(int(seed))
30
+ random_indices = torch.randperm(1000)[:n_targets].to(pixel_values.device)
31
+
32
  for iteration in range(int(iterations)):
33
  model.zero_grad()
34
  if pixel_values.grad is not None:
35
  pixel_values.grad.data.zero_()
36
 
37
  final_activations = get_encoder_activations(pixel_values)
38
+ logits = model.classifier(final_activations[0])
39
+ target_sum = logits[random_indices].sum()
40
  target_sum.backward()
41
 
42
  with torch.no_grad():
 
53
  inputs=[
54
  gr.Image(type="pil"),
55
  gr.Number(value=4.0, label="Learning Rate"),
56
+ gr.Number(value=4, label="Iterations"),
57
+ gr.Number(value=420, label="Seed"),
58
+ gr.Number(value=50, minimum=1, maximum=1000, label="n target classes"),
59
  ],
60
  outputs=[gr.Image(type="numpy", label="ViT-Dreamed Image")]
61
  )