Multimodal-SAE / app.py
kcz358's picture
Update image visualization
1d06677
raw
history blame
6.08 kB
import gradio as gr
from sae_auto_interp.sae import Sae
from sae_auto_interp.utils import maybe_load_llava_model, load_single_sae
from sae_auto_interp.features.features import upsample_mask
import torch
from transformers import AutoTokenizer
from PIL import Image
# BibTeX entry rendered inside the "Citation" accordion of the UI.
CITATION_BUTTON_TEXT = """
@misc{zhang2024largemultimodalmodelsinterpret,
title={Large Multi-modal Models Can Interpret Features in Large Multi-modal Models},
author={Kaichen Zhang and Yifei Shen and Bo Li and Ziwei Liu},
year={2024},
eprint={2411.14982},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2411.14982},
}
"""

# State written by generate_activations() and read by visualize_activations().
# Both stay None until the first Submit.
# NOTE(review): these are module-level globals, so every concurrent session of
# the app shares (and overwrites) the same cache.
cached_tensor = topk_indices = None

# Paths of the example images under assets/.
sunglasses_file_path, greedy_file_path, railway_file_path = (
    "assets/sunglasses.jpg",
    "assets/greedy.jpg",
    "assets/railway.jpg",
)
def generate_activations(image):
    """Run the LLaVA model on ``image`` and cache the hooked layer's SAE activations.

    A forward hook on ``hooked_module`` captures that layer's output and encodes
    it with ``sae.pre_acts``. Per token, it keeps only the top ``sae.cfg.k``
    latents and zeroes the rest. The sparse result is stored in the module-level
    ``cached_tensor`` for ``visualize_activations``. The hook also stores, in
    ``topk_indices``, the 100 latents with the highest mean pre-activation over
    the sequence.

    Relies on the module-level ``model``, ``processor``, ``tokenizer``, ``sae``
    and ``hooked_module`` created in the ``__main__`` section.

    Args:
        image: the PIL image from the Gradio ``gr.Image`` input.

    Returns:
        A CPU tensor holding the indices of the 100 selected latents.

    Raises:
        gr.Error: if no image has been uploaded.
    """
    global cached_tensor, topk_indices

    if image is None:
        raise gr.Error("Please upload an image before clicking Submit.")

    prompt = "<image>"
    # Use one device for everything instead of mixing model.device and "cuda".
    device = model.device
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)

    def hook(module: torch.nn.Module, _, outputs):
        global cached_tensor, topk_indices
        # The layer may return a tuple (hidden state first) or a bare tensor.
        # A bare tensor must be wrapped: list(tensor) would split it along the
        # batch dimension and drop the batch axis.
        if isinstance(outputs, tuple):
            unpack_outputs = list(outputs)
        else:
            unpack_outputs = [outputs]
        latents = sae.pre_acts(unpack_outputs[0])
        # With a llama tokenizer and an image-only prompt, skip the leading BOS token.
        if "llama" in tokenizer.name_or_path:
            latents = latents[:, 1:, :]
        topk = torch.topk(latents, k=sae.cfg.k, dim=-1)
        # Keep only the top-k latents of each token; everything else is 0.
        result = torch.zeros_like(latents)
        # result: (bs, seq, num_latents)
        result.scatter_(-1, topk.indices, topk.values)
        cached_tensor = result.detach().cpu()
        # Rank latents by their mean pre-activation across the sequence.
        topk_indices = (
            latents.squeeze(0).mean(dim=0).topk(k=100).indices.detach().cpu()
        )

    handles = [hooked_module.register_forward_hook(hook)]
    try:
        with torch.no_grad():
            model(
                input_ids=inputs["input_ids"].to(device),
                pixel_values=inputs["pixel_values"].to(device),
                image_sizes=inputs["image_sizes"].to(device),
                attention_mask=inputs["attention_mask"].to(device),
            )
    finally:
        # Always detach the hook and release cached GPU memory, even on failure.
        for handle in handles:
            handle.remove()
        torch.cuda.empty_cache()
    print(cached_tensor.shape)
    return topk_indices
def visualize_activations(image, feature_num):
    """Overlay one SAE feature's activation map on the base-resolution image.

    Reads the activations cached by ``generate_activations``. It takes the first
    576 token positions of ``cached_tensor`` for feature ``feature_num`` and
    reshapes them to a 24x24 grid. The grid is upsampled to 336x336 with
    ``upsample_mask`` and used as the mask of ``Image.composite`` between a
    black background and the resized input image.

    Args:
        image: the PIL image that was passed to ``generate_activations``.
        feature_num: 0-based index of the SAE latent to show. It may arrive as a
            float from the Gradio slider.

    Returns:
        A 336x336 RGB PIL image.

    Raises:
        gr.Error: if no image is given, activations have not been generated
            yet, or ``feature_num`` is outside the SAE's latent range.
    """
    base_img_tokens = 576  # 24 * 24 base-image tokens
    patch_size = 24
    image_size = 336

    if image is None:
        raise gr.Error("Please upload an image first.")
    if cached_tensor is None:
        raise gr.Error("Please click Submit to compute activations first.")

    # Tensor indexing needs an int; slider values can be floats.
    feature_idx = int(feature_num)
    num_latents = cached_tensor.shape[-1]
    if not 0 <= feature_idx < num_latents:
        raise gr.Error(
            f"Feature number must be between 0 and {num_latents - 1}, got {feature_idx}."
        )

    # Select the chosen feature and keep only the first 576 tokens.
    base_image_activations = cached_tensor[0, :base_img_tokens, feature_idx].view(
        patch_size, patch_size
    )
    upsampled_image_mask = upsample_mask(base_image_activations, (image_size, image_size))
    background = Image.new("L", (image_size, image_size), 0).convert("RGB")
    # The llava-hf preprocessing uses the plainly resized image (not the padded
    # one) as the base image feature. This differs from the original LLaVA, but
    # we align with llava-hf since that is the model used here.
    # Convert to RGB so both composite inputs share a mode; RGBA/L/P uploads
    # would otherwise make Image.composite raise.
    resized_image = image.convert("RGB").resize((image_size, image_size))
    activation_images = Image.composite(
        background, resized_image, upsampled_image_mask
    ).convert("RGB")
    return activation_images
# Gradio UI. Tab 1 computes SAE activations for an uploaded image and overlays a
# chosen feature's activation map. Tab 2 is a placeholder chatbot for steering.
with gr.Blocks() as demo:
    gr.Markdown(
"""
# Large Multi-modal Models Can Interpret Features in Large Multi-modal Models
🔍 [ArXiv Paper](https://arxiv.org/abs/2411.14982) | 🏠 [LMMs-Lab Homepage](https://lmms-lab.framer.ai) | 🤗 [Huggingface Collections](https://huggingface.co/collections/lmms-lab/llava-sae-674026e4e7bc8c29c70bc3a3)
"""
    )

    # Dictionary size of the 131k SAE. Feature indices are 0-based, so the
    # valid slider range is [0, num_sae_latents - 1].
    num_sae_latents = 131072

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("Visualization of Activations", elem_id="visualization", id=0):
            with gr.Row():
                with gr.Column():
                    image = gr.Image(type="pil", interactive=True, label="Sample Image")
                    topk_features = gr.Textbox(
                        value=topk_indices,
                        placeholder="Top 100 Features",
                        label="Top 100 Features",
                    )
                    with gr.Row():
                        clear_btn = gr.ClearButton([image, topk_features], value="Clear")
                        submit_btn = gr.Button("Submit", variant="primary")
                    submit_btn.click(generate_activations, inputs=[image], outputs=[topk_features])

                with gr.Column():
                    output = gr.Image(label="Activation Visualization")
                    # Keyword arguments: `step` is keyword-only in gr.Slider.
                    feature_num = gr.Slider(
                        minimum=0,
                        maximum=num_sae_latents - 1,
                        value=1,
                        step=1,
                        label="Feature Number",
                        interactive=True,
                    )
                    visualize_btn = gr.Button("Visualize", variant="primary")
                    visualize_btn.click(
                        visualize_activations,
                        inputs=[image, feature_num],
                        outputs=[output],
                    )

            # Hidden textbox that only carries the human-readable example label.
            dummy_text = gr.Textbox(visible=False, label="Explanation")
            gr.Examples(
                [
                    [sunglasses_file_path, 10, "Sunglasses"],
                    [greedy_file_path, 14, "Greedy eating"],
                    [railway_file_path, 28, "Railway tracks"],
                ],
                inputs=[image, feature_num, dummy_text],
                label="Examples",
            )

        with gr.TabItem("Steering Model", elem_id="steering", id=2):
            chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            gr.Markdown("```bib\n" + CITATION_BUTTON_TEXT + "\n```")
if __name__ == "__main__":
    # The heavy objects are created only when run as a script. The Gradio
    # callbacks look up tokenizer/sae/model/processor/hooked_module as
    # module-level globals.
    llava_model_id = "llava-hf/llama3-llava-next-8b-hf"
    sae_layer_name = "model.layers.24"

    tokenizer = AutoTokenizer.from_pretrained(llava_model_id)
    sae = load_single_sae("lmms-lab/llama3-llava-next-8b-hf-sae-131k", sae_layer_name)
    model, processor = maybe_load_llava_model(
        llava_model_id, rank=0, dtype=torch.bfloat16, hf_token=None
    )
    # Hook the same decoder layer the SAE was loaded for.
    hooked_module = model.language_model.get_submodule(sae_layer_name)

    demo.launch()