"""Gradio demo for "Large Multi-modal Models Can Interpret Features in Large
Multi-modal Models".

An uploaded image is run through a LLaVA-NeXT model. A forward hook on one
decoder layer encodes that layer's output with a sparse autoencoder (SAE). The
user can then visualize where a chosen SAE feature activates on the image.
"""

import gradio as gr
from sae_auto_interp.sae import Sae
from sae_auto_interp.utils import maybe_load_llava_model, load_single_sae
from sae_auto_interp.features.features import upsample_mask
import torch
from transformers import AutoTokenizer
from PIL import Image

CITATION_BUTTON_TEXT = """
@misc{zhang2024largemultimodalmodelsinterpret,
      title={Large Multi-modal Models Can Interpret Features in Large Multi-modal Models},
      author={Kaichen Zhang and Yifei Shen and Bo Li and Ziwei Liu},
      year={2024},
      eprint={2411.14982},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2411.14982},
}
"""

MODEL_NAME = "llava-hf/llama3-llava-next-8b-hf"
SAE_NAME = "lmms-lab/llama3-llava-next-8b-hf-sae-131k"
# The SAE is trained on this layer's output; the forward hook must target the same layer.
HOOK_LAYER = "model.layers.24"
# Number of SAE latents. Assumed from the "131k" SAE name and the original slider
# maximum (131072) -- TODO confirm against the loaded SAE's config.
NUM_SAE_LATENTS = 131072

# Module-level state shared between the two callbacks:
# generate_activations() writes these and visualize_activations() reads them.
cached_tensor = None  # top-k sparse SAE activations of the last image, on CPU
topk_indices = None  # indices of the 100 latents with the highest mean pre-activation

sunglasses_file_path = "assets/sunglasses.jpg"
greedy_file_path = "assets/greedy.jpg"
railway_file_path = "assets/railway.jpg"


def generate_activations(image):
    """Run ``image`` through the model and cache the SAE activations of the hooked layer.

    Side effects: sets the module-level ``cached_tensor`` (sparse activations in
    which only the top ``sae.cfg.k`` latents per token are non-zero, shaped like
    the SAE pre-activations, moved to CPU) and ``topk_indices``.

    Args:
        image: PIL image from the Gradio image component.

    Returns:
        CPU tensor with the indices of the 100 SAE latents whose pre-activation,
        averaged over the sequence, is highest.

    Raises:
        gr.Error: if no image was provided.
    """
    global cached_tensor, topk_indices

    if image is None:
        raise gr.Error("Please upload an image first.")

    # Image-only request: the text prompt is empty.
    prompt = ""
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)

    def hook(module: torch.nn.Module, _, outputs):
        global cached_tensor, topk_indices
        # The layer may return a tuple whose first element is the hidden states,
        # or the hidden-states tensor itself (depends on the transformers
        # version). Only unpack tuples: list(tensor) would split the batch dim.
        hidden_states = outputs[0] if isinstance(outputs, tuple) else outputs
        latents = sae.pre_acts(hidden_states)
        # With a llama tokenizer and no text (image only), skip the first (BOS) token.
        if "llama" in tokenizer.name_or_path:
            latents = latents[:, 1:, :]
        topk = torch.topk(latents, k=sae.cfg.k, dim=-1)
        # Keep only the top-k values per token; every other latent becomes 0.
        result = torch.zeros_like(latents)  # (bs, seq, num_latents)
        result.scatter_(-1, topk.indices, topk.values)
        cached_tensor = result.detach().cpu()
        topk_indices = latents.squeeze(0).mean(dim=0).topk(k=100).indices.detach().cpu()

    handle = hooked_module.register_forward_hook(hook)
    try:
        with torch.no_grad():
            # `inputs` is already on model.device; the forward pass only serves
            # to trigger the hook, so its return value is discarded.
            model(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                image_sizes=inputs["image_sizes"],
                attention_mask=inputs["attention_mask"],
            )
    finally:
        # Always detach the hook so later forward passes are not intercepted.
        handle.remove()
    print(cached_tensor.shape)
    torch.cuda.empty_cache()
    return topk_indices


def visualize_activations(image, feature_num):
    """Overlay the activation map of SAE feature ``feature_num`` on ``image``.

    Uses the activations cached by the last ``generate_activations`` call.

    Args:
        image: PIL image from the Gradio image component.
        feature_num: 0-based SAE latent index (may arrive as a float from the slider).

    Returns:
        A 336x336 RGB PIL image combining the resized input and a black
        background, mixed by the upsampled activation mask.

    Raises:
        gr.Error: if no image is given, activations have not been computed yet,
            or ``feature_num`` is out of range.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if cached_tensor is None:
        raise gr.Error("Click 'Submit' to compute activations before visualizing.")

    # Tensor indexing requires an int; the slider may deliver a float.
    feature_num = int(feature_num)
    num_latents = cached_tensor.shape[-1]
    if not 0 <= feature_num < num_latents:
        raise gr.Error(f"Feature number must be in [0, {num_latents - 1}].")

    patch_size = 24
    base_img_tokens = 576  # == patch_size * patch_size
    # Take the selected feature on the first 576 tokens (the base image
    # features) and lay them out as a 24 x 24 grid.
    base_image_activations = cached_tensor[0, :base_img_tokens, feature_num].view(patch_size, patch_size)
    upsampled_image_mask = upsample_mask(base_image_activations, (336, 336))

    background = Image.new("L", (336, 336), 0).convert("RGB")
    # llava-hf builds the base image features from a plain resize of the input
    # rather than from the padded image used by the original LLaVA. We follow
    # llava-hf because the model is a llava-hf checkpoint.
    # Convert to RGB first: Image.composite needs both images in the same mode
    # (e.g. RGBA PNG uploads would otherwise fail).
    resized_image = image.convert("RGB").resize((336, 336))
    activation_images = Image.composite(background, resized_image, upsampled_image_mask).convert("RGB")
    return activation_images


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Large Multi-modal Models Can Interpret Features in Large Multi-modal Models

        🔍 [ArXiv Paper](https://arxiv.org/abs/2411.14982) | 🏠 [LMMs-Lab Homepage](https://lmms-lab.framer.ai) | 🤗 [Huggingface Collections](https://huggingface.co/collections/lmms-lab/llava-sae-674026e4e7bc8c29c70bc3a3)
        """
    )

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("Visualization of Activations", elem_id="visualization", id=0):
            with gr.Row():
                with gr.Column():
                    image = gr.Image(type="pil", interactive=True, label="Sample Image")
                    topk_features = gr.Textbox(value=topk_indices, placeholder="Top 100 Features", label="Top 100 Features")
                    with gr.Row():
                        clear_btn = gr.ClearButton([image, topk_features], value="Clear")
                        submit_btn = gr.Button("Submit", variant="primary")
                    submit_btn.click(generate_activations, inputs=[image], outputs=[topk_features])

                with gr.Column():
                    output = gr.Image(label="Activation Visualization")
                    # SAE latent indices are 0-based, so the valid range is
                    # [0, NUM_SAE_LATENTS - 1]. Keywords are used so that `step`
                    # is passed correctly.
                    feature_num = gr.Slider(
                        minimum=0,
                        maximum=NUM_SAE_LATENTS - 1,
                        value=1,
                        step=1,
                        label="Feature Number",
                        interactive=True,
                    )
                    visualize_btn = gr.Button("Visualize", variant="primary")
                    visualize_btn.click(visualize_activations, inputs=[image, feature_num], outputs=[output])

            # Hidden textbox that holds each example's human-readable label.
            dummy_text = gr.Textbox(visible=False, label="Explanation")
            gr.Examples(
                [
                    [sunglasses_file_path, 10, "Sunglasses"],
                    [greedy_file_path, 14, "Greedy eating"],
                    [railway_file_path, 28, "Railway tracks"],
                ],
                inputs=[image, feature_num, dummy_text],
                label="Examples",
            )

        with gr.TabItem("Steering Model", elem_id="steering", id=2):
            chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            gr.Markdown("```bib\n" + CITATION_BUTTON_TEXT + "\n```")


if __name__ == "__main__":
    # tokenizer, sae, model, processor and hooked_module are module globals
    # read by the callbacks above, so they are bound here before launch.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    sae = load_single_sae(SAE_NAME, HOOK_LAYER)
    model, processor = maybe_load_llava_model(
        MODEL_NAME, rank=0, dtype=torch.bfloat16, hf_token=None
    )
    hooked_module = model.language_model.get_submodule(HOOK_LAYER)
    demo.launch()