# Multimodal-SAE / app.py
# kcz358 — "Add steering models" (commit 1fe4523)
import gradio as gr
from sae_auto_interp.sae import Sae
from sae_auto_interp.utils import maybe_load_llava_model, load_single_sae
from sae_auto_interp.features.features import upsample_mask
import torch
from transformers import AutoTokenizer
from PIL import Image
# BibTeX entry rendered inside the "Citation" accordion of the UI.
CITATION_BUTTON_TEXT = """
@misc{zhang2024largemultimodalmodelsinterpret,
title={Large Multi-modal Models Can Interpret Features in Large Multi-modal Models},
author={Kaichen Zhang and Yifei Shen and Bo Li and Ziwei Liu},
year={2024},
eprint={2411.14982},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2411.14982},
}
"""

# Shared state between the Gradio callbacks. Both start empty and are
# overwritten by generate_activations() each time an image is submitted.
cached_tensor = topk_indices = None

# Example images bundled with the app.
sunglasses_file_path = "assets/sunglasses.jpg"
greedy_file_path = "assets/greedy.jpg"
railway_file_path = "assets/railway.jpg"
happy_file_path = "assets/happy.jpg"
def generate_activations(image):
prompt = "<image>"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
global cached_tensor, topk_indices
def hook(module: torch.nn.Module, _, outputs):
global cached_tensor, topk_indices
# Maybe unpack tuple outputs
if isinstance(outputs, tuple):
unpack_outputs = list(outputs)
else:
unpack_outputs = list(outputs)
latents = sae.pre_acts(unpack_outputs[0])
# When the tokenizer is llama and text is None (image only)
# I skip the first bos tokens
if "llama" in tokenizer.name_or_path:
latents = latents[:, 1:, :]
topk = torch.topk(
latents, k=sae.cfg.k, dim=-1
)
# make all other values 0
result = torch.zeros_like(latents)
# results (bs, seq, num_latents)
result.scatter_(-1, topk.indices, topk.values)
cached_tensor = result.detach().cpu()
topk_indices = (
latents.squeeze(0).mean(dim=0).topk(k=100).indices.detach().cpu()
)
handles = [hooked_module.register_forward_hook(hook)]
try:
with torch.no_grad():
outputs = model(
input_ids=inputs["input_ids"].to("cuda"),
pixel_values=inputs["pixel_values"].to("cuda"),
image_sizes=inputs["image_sizes"].to("cuda"),
attention_mask=inputs["attention_mask"].to("cuda"),
)
finally:
for handle in handles:
handle.remove()
torch.cuda.empty_cache()
return topk_indices
def visualize_activations(image, feature_num):
    """Overlay one SAE feature's activation map on the input image.

    Reads the activations cached by the last ``generate_activations`` call.
    It takes the first 576 sequence positions, viewed as a 24x24 grid, for
    latent ``feature_num``. That grid is upsampled to 336x336 with
    ``upsample_mask``. Finally, a black background and the resized image are
    composited using the upsampled map as the mask.

    Args:
        image: PIL image shown in the UI. It should be the image the cached
            activations were computed from.
        feature_num: Index of the SAE latent to show. Gradio may deliver it as
            a float, so it is cast to int.

    Returns:
        An RGB PIL image of size 336x336.

    Raises:
        gr.Error: if no activations are cached yet or no image is given.
    """
    patch_size = 24
    base_img_tokens = patch_size * patch_size  # 576 base-image tokens
    resolution = (336, 336)

    if cached_tensor is None:
        raise gr.Error("Click 'Submit' to compute activations before visualizing.")
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Tensor indexing rejects floats; slider values may arrive as floats.
    feature_idx = int(feature_num)

    # Select the chosen feature over the first 576 positions of the cache.
    base_image_activations = cached_tensor[0, :base_img_tokens, feature_idx].view(
        patch_size, patch_size
    )
    upsampled_image_mask = upsample_mask(base_image_activations, resolution)
    background = Image.new("L", resolution, 0).convert("RGB")
    # llava-hf builds the base image feature from a plain resize of the image,
    # not from the padded image used by the original llava. We follow llava-hf
    # here because the model is llava-hf.
    # Image.composite needs both images in the same mode, so force RGB to
    # match the background.
    resized_image = image.convert("RGB").resize(resolution)
    return Image.composite(background, resized_image, upsampled_image_mask).convert("RGB")
def clamp_features_max(
    sae: Sae, feature: int, hooked_module: torch.nn.Module, k: float = 10
):
    """Attach a steering hook that clamps one SAE feature on ``hooked_module``.

    The hook encodes the layer output with ``sae.pre_acts``. On a forward pass
    whose sequence length is not 1 (the prompt pass, not a single-token
    step), it sets latent ``feature`` to ``k`` at every position. On every
    pass the layer output is then replaced by the SAE reconstruction built
    from the top-k latents.

    Args:
        sae: SAE used to encode/decode the layer output.
        feature: Index of the latent to clamp.
        hooked_module: Module whose forward output is replaced.
        k: Value written into the clamped latent.

    Returns:
        A list of hook handles. The caller must ``remove()`` them when done.
    """

    def hook(module: torch.nn.Module, _, outputs):
        # The layer may return a tuple (hidden_states, ...) or a bare tensor.
        # A bare tensor must not be passed through list(), which would split
        # it along the batch dimension.
        is_tuple = isinstance(outputs, tuple)
        hidden = outputs[0] if is_tuple else outputs
        latents = sae.pre_acts(hidden)
        # Only clamp on the multi-token (prompt) forward pass.
        if latents.shape[1] != 1:
            latents[:, :, feature] = k
        top_acts, top_indices = sae.select_topk(latents)
        # Only batch element 0 is decoded, so a batch size of 1 is assumed.
        # Cast back to the layer's own dtype instead of a hard-coded float16.
        sae_out = (
            sae.decode(top_acts[0], top_indices[0]).unsqueeze(0).to(hidden.dtype)
        )
        if is_tuple:
            return (sae_out, *outputs[1:])
        return sae_out

    return [hooked_module.register_forward_hook(hook)]
def generate_with_clamp(feature_idx, feature_strength, text, image, chat_history):
    """Answer a chat turn while one SAE feature is clamped, updating the history.

    Appends the user turn to ``chat_history``: the image first when given,
    then the text. It then generates up to 512 new tokens with the clamping
    hook from ``clamp_features_max`` active, appends the decoded reply as an
    assistant message and returns the same list.
    """
    feature_idx = feature_idx if isinstance(feature_idx, int) else int(feature_idx)
    feature_strength = (
        feature_strength
        if isinstance(feature_strength, float)
        else float(feature_strength)
    )

    user_content = [{"type": "text", "text": text}]
    if image is not None:
        user_content.append({"type": "image"})
        chat_history.append({"role": "user", "content": gr.Image(value=image)})
    conversation = [{"role": "user", "content": user_content}]
    chat_history.append({"role": "user", "content": text})

    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    hook_handles = clamp_features_max(
        sae, feature_idx, hooked_module, k=feature_strength
    )
    try:
        with torch.no_grad():
            generated = model.generate(**inputs, max_new_tokens=512)
            new_tokens = generated[:, prompt_len:]
    finally:
        # Always detach the steering hook, even if generation fails.
        for hook_handle in hook_handles:
            hook_handle.remove()

    reply = processor.batch_decode(new_tokens, skip_special_tokens=True)[0]
    chat_history.append({"role": "assistant", "content": reply})
    return chat_history
# Largest latent index offered by the feature sliders. The callbacks index
# latents 0-based, so the valid range is 0 .. num_latents - 1.
# NOTE(review): assumes the "131k" SAE has exactly 131072 latents; confirm
# against the loaded SAE's config.
MAX_FEATURE_INDEX = 131071

with gr.Blocks() as demo:
    gr.Markdown(
        "# Large Multi-modal Models Can Interpret Features in Large Multi-modal Models\n"
        "🔍 [ArXiv Paper](https://arxiv.org/abs/2411.14982) | "
        "🏠 [LMMs-Lab Homepage](https://lmms-lab.framer.ai) | "
        "🤗 [Huggingface Collections](https://huggingface.co/collections/lmms-lab/llava-sae-674026e4e7bc8c29c70bc3a3)"
    )
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        # Tab 1: compute SAE activations for an image, then overlay one feature.
        with gr.TabItem("Visualization of Activations", elem_id="visualization", id=0):
            with gr.Row():
                with gr.Column():
                    image = gr.Image(type="pil", interactive=True, label="Sample Image")
                    topk_features = gr.Textbox(
                        value=topk_indices,
                        placeholder="Top 100 Features",
                        label="Top 100 Features",
                    )
                    with gr.Row():
                        clear_btn = gr.ClearButton([image, topk_features], value="Clear")
                        submit_btn = gr.Button("Submit", variant="primary")
                        submit_btn.click(
                            generate_activations, inputs=[image], outputs=[topk_features]
                        )
                with gr.Column():
                    output = gr.Image(label="Activation Visualization")
                    feature_num = gr.Slider(
                        0, MAX_FEATURE_INDEX, 1, 1, label="Feature Number", interactive=True
                    )
                    visualize_btn = gr.Button("Visualize", variant="primary")
                    visualize_btn.click(
                        visualize_activations,
                        inputs=[image, feature_num],
                        outputs=[output],
                    )
            # Hidden textbox: lets each example carry a short description.
            dummy_text = gr.Textbox(visible=False, label="Explanation")
            gr.Examples(
                [
                    [sunglasses_file_path, 10, "Sunglasses"],
                    [greedy_file_path, 14, "Greedy eating"],
                    [railway_file_path, 28, "Railway tracks"],
                ],
                inputs=[image, feature_num, dummy_text],
                label="Examples",
            )
        # Tab 2: chat with the model while one SAE feature is clamped.
        with gr.TabItem("Steering Model", elem_id="steering", id=2):
            chatbot = gr.Chatbot(type="messages")
            with gr.Row(variant="compact", equal_height=True):
                steer_feature_num = gr.Slider(
                    0, MAX_FEATURE_INDEX, 1, 1, label="Feature Number", interactive=True
                )
                feature_strength = gr.Number(
                    value=50, label="Feature Strength", interactive=True
                )
            with gr.Row(variant="compact", equal_height=True):
                text_input = gr.Textbox(
                    label="Text Input", placeholder="Type here", interactive=True
                )
                image_input = gr.Image(
                    type="pil", label="Image Input", interactive=True, height=250
                )
            with gr.Row():
                chatbot_clear = gr.ClearButton(
                    [text_input, image_input, chatbot], value="Clear"
                )
                chatbot_submit = gr.Button("Submit", variant="primary")
                chatbot_submit.click(
                    generate_with_clamp,
                    inputs=[steer_feature_num, feature_strength, text_input, image_input, chatbot],
                    outputs=[chatbot],
                )
            gr.Examples(
                [
                    [19379, 50, "Look at this image, what is your feeling right now?", happy_file_path],
                    [14, 50, "Tell me a story about Alice and Bob", None],
                    [108692, 50, "What is your feeling right now?", None],
                ],
                inputs=[steer_feature_num, feature_strength, text_input, image_input],
                label="Examples",
            )
    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            gr.Markdown("```bib\n" + CITATION_BUTTON_TEXT + "\n```")
if __name__ == "__main__":
    # The tokenizer and the LLaVA model share one checkpoint id, and the SAE
    # and the steering/recording hook share one layer path.
    base_model_id = "llava-hf/llama3-llava-next-8b-hf"
    sae_layer = "model.layers.24"

    # These become module-level globals that the callbacks in this file read.
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    sae = load_single_sae("lmms-lab/llama3-llava-next-8b-hf-sae-131k", sae_layer)
    model, processor = maybe_load_llava_model(
        base_model_id, rank=0, dtype=torch.float16, hf_token=None
    )
    hooked_module = model.language_model.get_submodule(sae_layer)
    demo.launch()