kcz358 committed
Commit 1fe4523 · 1 Parent(s): 1d06677

Add steering models

Files changed (3):
  1. app.py +98 -3
  2. assets/happy.jpg +0 -0
  3. requirements.txt +3 -1
app.py CHANGED

```diff
@@ -24,6 +24,7 @@ topk_indices = None
 sunglasses_file_path = "assets/sunglasses.jpg"
 greedy_file_path = "assets/greedy.jpg"
 railway_file_path = "assets/railway.jpg"
+happy_file_path = "assets/happy.jpg"
 
 
 def generate_activations(image):
@@ -69,7 +70,6 @@ def generate_activations(image):
     for handle in handles:
         handle.remove()
 
-    print(cached_tensor.shape)
     torch.cuda.empty_cache()
     return topk_indices
 
```
```diff
@@ -96,6 +96,77 @@ def visualize_activations(image, feature_num):
 
     return activation_images
 
+def clamp_features_max(
+    sae: Sae, feature: int, hooked_module: torch.nn.Module, k: float = 10
+):
+    def hook(module: torch.nn.Module, _, outputs):
+        # Maybe unpack tuple outputs
+        if isinstance(outputs, tuple):
+            unpack_outputs = list(outputs)
+        else:
+            unpack_outputs = [outputs]
+        latents = sae.pre_acts(unpack_outputs[0])
+        # Only clamp the feature for the first forward
+        if latents.shape[1] != 1:
+            latents[:, :, feature] = k
+        top_acts, top_indices = sae.select_topk(latents)
+        sae_out = sae.decode(top_acts[0], top_indices[0]).unsqueeze(0).to(torch.float16)
+        unpack_outputs[0] = sae_out
+        if isinstance(outputs, tuple):
+            outputs = tuple(unpack_outputs)
+        else:
+            outputs = unpack_outputs[0]
+        return outputs
+
+    handles = [hooked_module.register_forward_hook(hook)]
+
+    return handles
+
+def generate_with_clamp(feature_idx, feature_strength, text, image, chat_history):
+    if not isinstance(feature_idx, int):
+        feature_idx = int(feature_idx)
+    if not isinstance(feature_strength, float):
+        feature_strength = float(feature_strength)
+
+    conversation = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": text},
+            ],
+        },
+    ]
+    if image is not None:
+        conversation[0]["content"].append(
+            {"type": "image"},
+        )
+
+    chat_history.append({"role": "user", "content": gr.Image(value=image)})
+    chat_history.append({"role": "user", "content": text})
+
+    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+
+    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
+    handles = clamp_features_max(sae, feature_idx, hooked_module, k=feature_strength)
+    try:
+        with torch.no_grad():
+            output = model.generate(**inputs, max_new_tokens=512)
+        cont = output[:, inputs["input_ids"].shape[-1] :]
+    finally:
+        for handle in handles:
+            handle.remove()
+
+    text = processor.batch_decode(cont, skip_special_tokens=True)[0]
+    chat_history.append(
+        {
+            "role": "assistant",
+            "content": text,
+        }
+    )
+
+    return chat_history
+
 
 with gr.Blocks() as demo:
     gr.Markdown(
```
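The core of the new steering path is `clamp_features_max`: it registers a PyTorch forward hook on the hooked decoder layer, encodes the layer's hidden states with the SAE, pins one latent feature to strength `k`, and decodes back; because the hook returns a value, that value replaces the layer's output for the rest of the forward pass. A minimal standalone sketch of the replace-the-output hook pattern (the toy module and names here are illustrative, not from the commit):

```python
import torch

class ToyLayer(torch.nn.Module):
    def forward(self, x):
        return x * 2

layer = ToyLayer()
feature, k = 1, 10.0

def clamp_hook(module, _inputs, output):
    # A non-None return value from a forward hook replaces the module's output.
    output = output.clone()
    output[..., feature] = k  # pin one coordinate to strength k
    return output

handle = layer.register_forward_hook(clamp_hook)
print(layer(torch.ones(1, 3)))  # tensor([[ 2., 10.,  2.]])
handle.remove()  # always detach, as generate_with_clamp does in its finally block
```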
 
```diff
@@ -134,7 +205,31 @@ with gr.Blocks() as demo:
         )
 
         with gr.TabItem("Steering Model", elem_id="steering", id=2):
-            chatbot = gr.Chatbot()
+            chatbot = gr.Chatbot(type="messages")
+            with gr.Row(variant="compact", equal_height=True):
+                feature_num = gr.Slider(1, 131072, 1, 1, label="Feature Number", interactive=True)
+                feature_strength = gr.Number(value=50, label="Feature Strength", interactive=True)
+            with gr.Row(variant="compact", equal_height=True):
+                text_input = gr.Textbox(label="Text Input", placeholder="Type here", interactive=True)
+                image_input = gr.Image(type="pil", label="Image Input", interactive=True, height=250)
+            with gr.Row():
+                chatbot_clear = gr.ClearButton([text_input, image_input, chatbot], value="Clear")
+                chatbot_submit = gr.Button("Submit", variant="primary")
+            chatbot_submit.click(
+                generate_with_clamp,
+                inputs=[feature_num, feature_strength, text_input, image_input, chatbot],
+                outputs=[chatbot],
+            )
+            gr.Examples(
+                [
+                    [19379, 50, "Look at this image, what is your feeling right now?", happy_file_path],
+                    [14, 50, "Tell me a story about Alice and Bob", None],
+                    [108692, 50, "What is your feeling right now?", None],
+                ],
+                inputs=[feature_num, feature_strength, text_input, image_input],
+                label="Examples",
+            )
+
 
     with gr.Row():
         with gr.Accordion("📙 Citation", open=False):
```
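The Submit button simply feeds the widget values into `generate_with_clamp`, so the same function can be driven without the UI. A hypothetical headless call mirroring the second `gr.Examples` row, assuming `model`, `processor`, `sae`, and `hooked_module` are initialized as in the `__main__` block:

```python
# Hypothetical headless use of generate_with_clamp (no Gradio UI).
history = generate_with_clamp(
    feature_idx=14,           # SAE feature to clamp
    feature_strength=50,      # clamp value k
    text="Tell me a story about Alice and Bob",
    image=None,               # text-only prompt
    chat_history=[],
)
print(history[-1]["content"])  # the steered assistant reply
```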
 
```diff
@@ -147,7 +242,7 @@ if __name__ == "__main__":
     model, processor = maybe_load_llava_model(
         "llava-hf/llama3-llava-next-8b-hf",
         rank=0,
-        dtype=torch.bfloat16,
+        dtype=torch.float16,
         hf_token=None
     )
     hooked_module = model.language_model.get_submodule("model.layers.24")
```
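The dtype switch is presumably tied to the new hook, which rebuilds the layer output with `.to(torch.float16)`: if the rest of the model stayed in bfloat16, the next layer would see mismatched dtypes. A quick illustration of that failure mode (our reading of the change, not stated in the commit):

```python
import torch

# A bfloat16 layer fed the hook's float16 output raises a dtype error.
layer = torch.nn.Linear(8, 8).to(torch.bfloat16)
x = torch.randn(1, 8, dtype=torch.float16)
try:
    layer(x)
except RuntimeError as err:
    print(err)  # e.g. mismatched dtypes between input and weight
```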
assets/happy.jpg ADDED
requirements.txt CHANGED

```diff
@@ -1,4 +1,6 @@
 huggingface_hub==0.25.2
 gradio
 sae_auto_interp @ git+https://github.com/EvolvingLMMs-Lab/multimodal-sae
-fastapi==0.112.2
+fastapi==0.112.2
+gradio==4.44.1
+httpx==0.23.3
```