kcz358 committed
Commit 1d06677 · 1 Parent(s): c64cd4d

Update image visualization
Files changed (7):
  1. .gitignore +1 -0
  2. Makefile +7 -1
  3. app.py +117 -41
  4. assets/greedy.jpg +0 -0
  5. assets/railway.jpg +0 -0
  6. assets/sunglasses.jpg +0 -0
  7. requirements.txt +2 -1
.gitignore CHANGED
@@ -1,2 +1,3 @@
 
 __pycache__
+.vscode
Makefile CHANGED
@@ -1,4 +1,4 @@
-.PHONY: style format
+.PHONY: style format start clean
 
 
 style:
@@ -11,3 +11,9 @@ quality:
 	python -m black --check --line-length 119 .
 	python -m isort --check-only .
 	ruff check .
+
+start:
+	gradio app.py
+
+clean:
+	ps aux | grep "app" | grep -v "grep" | awk '{print $$2}' | xargs kill -9
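The two new targets are conveniences for running the Space: `make start` launches the app through the Gradio CLI (`gradio app.py`), and `make clean` force-kills any leftover process whose command line matches "app". The `grep -v "grep"` filter drops the grep process itself from the match, and `$$2` is the Makefile escape for awk's `$2`, the PID column of `ps aux`.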
app.py CHANGED
@@ -1,45 +1,10 @@
 import gradio as gr
-from huggingface_hub import InferenceClient
 from sae_auto_interp.sae import Sae
-
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
+from sae_auto_interp.utils import maybe_load_llava_model, load_single_sae
+from sae_auto_interp.features.features import upsample_mask
+import torch
+from transformers import AutoTokenizer
+from PIL import Image
 
 CITATION_BUTTON_TEXT = """
 @misc{zhang2024largemultimodalmodelsinterpret,
@@ -53,6 +18,84 @@ CITATION_BUTTON_TEXT = """
 }
 """
 
+cached_tensor = None
+topk_indices = None
+
+sunglasses_file_path = "assets/sunglasses.jpg"
+greedy_file_path = "assets/greedy.jpg"
+railway_file_path = "assets/railway.jpg"
+
+
+def generate_activations(image):
+    prompt = "<image>"
+    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
+    global cached_tensor, topk_indices
+
+    def hook(module: torch.nn.Module, _, outputs):
+        global cached_tensor, topk_indices
+        # Maybe unpack tuple outputs
+        if isinstance(outputs, tuple):
+            unpack_outputs = list(outputs)
+        else:
+            unpack_outputs = list(outputs)
+        latents = sae.pre_acts(unpack_outputs[0])
+        # When the tokenizer is llama and text is None (image only)
+        # I skip the first bos tokens
+        if "llama" in tokenizer.name_or_path:
+            latents = latents[:, 1:, :]
+
+        topk = torch.topk(
+            latents, k=sae.cfg.k, dim=-1
+        )
+        # make all other values 0
+        result = torch.zeros_like(latents)
+        # results (bs, seq, num_latents)
+        result.scatter_(-1, topk.indices, topk.values)
+        cached_tensor = result.detach().cpu()
+        topk_indices = (
+            latents.squeeze(0).mean(dim=0).topk(k=100).indices.detach().cpu()
+        )
+
+    handles = [hooked_module.register_forward_hook(hook)]
+    try:
+        with torch.no_grad():
+            outputs = model(
+                input_ids=inputs["input_ids"].to("cuda"),
+                pixel_values=inputs["pixel_values"].to("cuda"),
+                image_sizes=inputs["image_sizes"].to("cuda"),
+                attention_mask=inputs["attention_mask"].to("cuda"),
+            )
+    finally:
+        for handle in handles:
+            handle.remove()
+
+    print(cached_tensor.shape)
+    torch.cuda.empty_cache()
+    return topk_indices
+
+
+def visualize_activations(image, feature_num):
+    base_img_tokens = 576
+    patch_size = 24
+    # Using Cached tensor
+    # select the feature_num-th feature
+    # Then keeping the first 576 tokens
+    base_image_activations = cached_tensor[0, :base_img_tokens, feature_num].view(patch_size, patch_size)
+
+    upsampled_image_mask = upsample_mask(base_image_activations, (336, 336))
+
+
+    background = Image.new("L", (336, 336), 0).convert("RGB")
+
+    # Somehow as I looked closer into the llava-hf preprocessing code,
+    # I found out that they don't use the padded image as the base image feat
+    # but use the simple resized image. This is different from original llava but
+    # we align to llava-hf for now as we use llava-hf
+    resized_image = image.resize((336, 336))
+    activation_images = Image.composite(background, resized_image, upsampled_image_mask).convert("RGB")
+
+    return activation_images
+
 
 with gr.Blocks() as demo:
     gr.Markdown(
@@ -65,7 +108,30 @@ with gr.Blocks() as demo:
 
     with gr.Tabs(elem_classes="tab-buttons") as tabs:
         with gr.TabItem("Visualization of Activations", elem_id="visualization", id=0):
-            image = gr.Image()
+            with gr.Row():
+                with gr.Column():
+                    image = gr.Image(type="pil", interactive=True, label="Sample Image")
+                    topk_features = gr.Textbox(value=topk_indices, placeholder="Top 100 Features", label="Top 100 Features")
+                    with gr.Row():
+                        clear_btn = gr.ClearButton([image, topk_features], value="Clear")
+                        submit_btn = gr.Button("Submit", variant="primary")
+                    submit_btn.click(generate_activations, inputs=[image], outputs=[topk_features])
+                with gr.Column():
+                    output = gr.Image(label="Activation Visualization")
+                    feature_num = gr.Slider(1, 131072, 1, 1, label="Feature Number", interactive=True)
+                    visualize_btn = gr.Button("Visualize", variant="primary")
+                    visualize_btn.click(visualize_activations, inputs=[image, feature_num], outputs=[output])
+
+            dummy_text = gr.Textbox(visible=False, label="Explanation")
+            gr.Examples(
+                [
+                    ["assets/sunglasses.jpg", 10, "Sunglasses"],
+                    ["assets/greedy.jpg", 14, "Greedy eating"],
+                    ["assets/railway.jpg", 28, "Railway tracks"],
+                ],
+                inputs=[image, feature_num, dummy_text],
+                label="Examples",
+            )
 
         with gr.TabItem("Steering Model", elem_id="steering", id=2):
             chatbot = gr.Chatbot()
@@ -76,4 +142,14 @@ with gr.Blocks() as demo:
 
 
 if __name__ == "__main__":
+    tokenizer = AutoTokenizer.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
+    sae = load_single_sae("lmms-lab/llama3-llava-next-8b-hf-sae-131k", "model.layers.24")
+    model, processor = maybe_load_llava_model(
+        "llava-hf/llama3-llava-next-8b-hf",
+        rank=0,
+        dtype=torch.bfloat16,
+        hf_token=None
+    )
+    hooked_module = model.language_model.get_submodule("model.layers.24")
+
     demo.launch()
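Note on the new `generate_activations` path: it registers a PyTorch forward hook on the hooked decoder layer, runs a single image-only forward pass, and inside the hook encodes the hidden states with the SAE, keeps only the top-k latents per token (scattered back into a zeroed tensor), and caches the result for later visualization. Below is a minimal, self-contained sketch of that hook-and-cache pattern; the `FakeSae` class, the toy two-layer model, and all tensor sizes are illustrative assumptions standing in for the real LLaVA model, `Sae.pre_acts`, and `model.layers.24`, and are not part of this commit.

import torch
import torch.nn as nn

# Toy stand-ins (assumptions for illustration only): a linear "SAE" encoder with a
# top-k sparsity step instead of the real Sae, and one layer of a tiny MLP instead
# of the hooked LLaVA submodule model.layers.24.
class FakeSae(nn.Module):
    def __init__(self, d_model=16, num_latents=64, k=4):
        super().__init__()
        self.encoder = nn.Linear(d_model, num_latents)
        self.k = k

    def pre_acts(self, hidden):
        # (batch, seq, d_model) -> dense latent pre-activations (batch, seq, num_latents)
        return self.encoder(hidden)


toy_model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
sae = FakeSae()
cached_tensor = None


def hook(module, _inputs, outputs):
    global cached_tensor
    hidden = outputs[0] if isinstance(outputs, tuple) else outputs
    latents = sae.pre_acts(hidden)
    topk = torch.topk(latents, k=sae.k, dim=-1)     # keep the k largest latents per token
    sparse = torch.zeros_like(latents)
    sparse.scatter_(-1, topk.indices, topk.values)  # zero out everything else
    cached_tensor = sparse.detach().cpu()           # cache for later visualization


handle = toy_model[0].register_forward_hook(hook)
try:
    with torch.no_grad():
        toy_model(torch.randn(1, 576, 16))          # 576 mirrors the base image token count
finally:
    handle.remove()                                 # always detach the hook

print(cached_tensor.shape)                          # torch.Size([1, 576, 64])

`visualize_activations` then reshapes the cached activations of one feature over the first 576 image tokens into a 24x24 grid, upsamples it to 336x336 with `upsample_mask`, and uses it as the mask in `PIL.Image.composite` to blend the resized input image with a black background.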
assets/greedy.jpg ADDED
assets/railway.jpg ADDED
assets/sunglasses.jpg ADDED
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 huggingface_hub==0.25.2
 gradio
-sae_auto_interp @ git+https://github.com/EvolvingLMMs-Lab/multimodal-sae
+sae_auto_interp @ git+https://github.com/EvolvingLMMs-Lab/multimodal-sae
+fastapi==0.112.2