Spaces:
Sleeping
Sleeping
use n random target classes to maximise activation for
Browse files
app.py
CHANGED
@@ -11,7 +11,7 @@ model = ViTForImageClassification.from_pretrained('google/vit-large-patch32-384'
|
|
11 |
model.to(device)
|
12 |
model.eval()
|
13 |
|
14 |
-
def process_image(input_image, learning_rate, iterations):
|
15 |
if input_image is None:
|
16 |
return None
|
17 |
|
@@ -25,13 +25,18 @@ def process_image(input_image, learning_rate, iterations):
|
|
25 |
pixel_values = pixel_values.to(device)
|
26 |
pixel_values.requires_grad_(True)
|
27 |
|
|
|
|
|
|
|
|
|
28 |
for iteration in range(int(iterations)):
|
29 |
model.zero_grad()
|
30 |
if pixel_values.grad is not None:
|
31 |
pixel_values.grad.data.zero_()
|
32 |
|
33 |
final_activations = get_encoder_activations(pixel_values)
|
34 |
-
|
|
|
35 |
target_sum.backward()
|
36 |
|
37 |
with torch.no_grad():
|
@@ -48,7 +53,9 @@ iface = gr.Interface(
|
|
48 |
inputs=[
|
49 |
gr.Image(type="pil"),
|
50 |
gr.Number(value=4.0, label="Learning Rate"),
|
51 |
-
gr.Number(value=4, label="Iterations")
|
|
|
|
|
52 |
],
|
53 |
outputs=[gr.Image(type="numpy", label="ViT-Dreamed Image")]
|
54 |
)
|
|
|
11 |
model.to(device)
|
12 |
model.eval()
|
13 |
|
14 |
+
def process_image(input_image, learning_rate, iterations, n_targets, seed):
|
15 |
if input_image is None:
|
16 |
return None
|
17 |
|
|
|
25 |
pixel_values = pixel_values.to(device)
|
26 |
pixel_values.requires_grad_(True)
|
27 |
|
28 |
+
|
29 |
+
torch.manual_seed(int(seed))
|
30 |
+
random_indices = torch.randperm(1000)[:n_targets].to(pixel_values.device)
|
31 |
+
|
32 |
for iteration in range(int(iterations)):
|
33 |
model.zero_grad()
|
34 |
if pixel_values.grad is not None:
|
35 |
pixel_values.grad.data.zero_()
|
36 |
|
37 |
final_activations = get_encoder_activations(pixel_values)
|
38 |
+
logits = model.classifier(final_activations[0])
|
39 |
+
target_sum = logits[random_indices].sum()
|
40 |
target_sum.backward()
|
41 |
|
42 |
with torch.no_grad():
|
|
|
53 |
inputs=[
|
54 |
gr.Image(type="pil"),
|
55 |
gr.Number(value=4.0, label="Learning Rate"),
|
56 |
+
gr.Number(value=4, label="Iterations"),
|
57 |
+
gr.Number(value=420, label="Seed"),
|
58 |
+
gr.Number(value=50, minimum=1, maximum=1000, label="n target classes"),
|
59 |
],
|
60 |
outputs=[gr.Image(type="numpy", label="ViT-Dreamed Image")]
|
61 |
)
|