Spaces:
Sleeping
Sleeping
Remove the total-variation (TV) loss; use BCE loss instead
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
|
|
3 |
import numpy as np
|
4 |
from transformers import ViTImageProcessor, ViTForImageClassification
|
5 |
from PIL import Image
|
@@ -16,12 +17,7 @@ def get_encoder_activations(x):
|
|
16 |
final_activations = encoder_output.last_hidden_state[:,0,:]
|
17 |
return final_activations
|
18 |
|
19 |
-
def total_variation_loss(img):
|
20 |
-
pixel_dif1 = img[:, :, 1:, :] - img[:, :, :-1, :]
|
21 |
-
pixel_dif2 = img[:, :, :, 1:] - img[:, :, :, :-1]
|
22 |
-
return (torch.sum(torch.abs(pixel_dif1)) + torch.sum(torch.abs(pixel_dif2)))
|
23 |
-
|
24 |
-
def process_image(input_image, learning_rate, tv_weight, iterations, n_targets, seed):
|
25 |
if input_image is None:
|
26 |
return None
|
27 |
|
@@ -32,20 +28,22 @@ def process_image(input_image, learning_rate, tv_weight, iterations, n_targets,
|
|
32 |
|
33 |
|
34 |
torch.manual_seed(int(seed))
|
35 |
-
|
|
|
|
|
36 |
|
37 |
for iteration in range(int(iterations)):
|
38 |
model.zero_grad()
|
39 |
if pixel_values.grad is not None:
|
40 |
pixel_values.grad.data.zero_()
|
41 |
|
42 |
-
final_activations = get_encoder_activations(pixel_values)
|
43 |
-
|
|
|
|
|
|
|
44 |
|
45 |
-
original_loss
|
46 |
-
tv_loss = total_variation_loss(pixel_values)
|
47 |
-
total_loss = original_loss - tv_weight * tv_loss
|
48 |
-
total_loss.backward()
|
49 |
|
50 |
with torch.no_grad():
|
51 |
pixel_values.data += learning_rate * pixel_values.grad.data
|
@@ -60,11 +58,10 @@ iface = gr.Interface(
|
|
60 |
fn=process_image,
|
61 |
inputs=[
|
62 |
gr.Image(type="pil"),
|
63 |
-
gr.Number(value=
|
64 |
-
gr.Number(value=
|
65 |
-
gr.Number(value=4, minimum=1, label="Iterations"),
|
66 |
gr.Number(value=420, minimum=0, label="Seed"),
|
67 |
-
gr.Number(value=
|
68 |
],
|
69 |
outputs=[gr.Image(type="numpy", label="ViT-Dreamed Image")]
|
70 |
)
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
+
from torch.nn import BCEWithLogitsLoss
|
4 |
import numpy as np
|
5 |
from transformers import ViTImageProcessor, ViTForImageClassification
|
6 |
from PIL import Image
|
|
|
17 |
final_activations = encoder_output.last_hidden_state[:,0,:]
|
18 |
return final_activations
|
19 |
|
20 |
+
def process_image(input_image, learning_rate, iterations, n_targets, seed):
|
|
|
|
|
|
|
|
|
|
|
21 |
if input_image is None:
|
22 |
return None
|
23 |
|
|
|
28 |
|
29 |
|
30 |
torch.manual_seed(int(seed))
|
31 |
+
random_one_logits = torch.zeros(1000)
|
32 |
+
random_one_logits[torch.randperm(1000)[:n_targets]] = 1
|
33 |
+
random_one_logits = random_one_logits.to(pixel_values.device)
|
34 |
|
35 |
for iteration in range(int(iterations)):
|
36 |
model.zero_grad()
|
37 |
if pixel_values.grad is not None:
|
38 |
pixel_values.grad.data.zero_()
|
39 |
|
40 |
+
final_activations = get_encoder_activations(pixel_values.to('cuda'))
|
41 |
+
|
42 |
+
logits = model.classifier(final_activations[0]).to(pixel_values.device)
|
43 |
+
|
44 |
+
original_loss = BCEWithLogitsLoss(reduction='sum')(logits,random_one_logits)
|
45 |
|
46 |
+
original_loss.backward()
|
|
|
|
|
|
|
47 |
|
48 |
with torch.no_grad():
|
49 |
pixel_values.data += learning_rate * pixel_values.grad.data
|
|
|
58 |
fn=process_image,
|
59 |
inputs=[
|
60 |
gr.Image(type="pil"),
|
61 |
+
gr.Number(value=1.0, minimum=0, label="Learning Rate"),
|
62 |
+
gr.Number(value=2, minimum=1, label="Iterations"),
|
|
|
63 |
gr.Number(value=420, minimum=0, label="Seed"),
|
64 |
+
gr.Number(value=250, minimum=1, maximum=1000, label="Number of Random Target Class Activations to Maximise"),
|
65 |
],
|
66 |
outputs=[gr.Image(type="numpy", label="ViT-Dreamed Image")]
|
67 |
)
|