import gradio as gr
from sae_auto_interp.sae import Sae
from sae_auto_interp.utils import maybe_load_llava_model, load_single_sae
from sae_auto_interp.features.features import upsample_mask
import torch
from transformers import AutoTokenizer
from PIL import Image

CITATION_BUTTON_TEXT = """
@misc{zhang2024largemultimodalmodelsinterpret,
      title={Large Multi-modal Models Can Interpret Features in Large Multi-modal Models},
      author={Kaichen Zhang and Yifei Shen and Bo Li and Ziwei Liu},
      year={2024},
      eprint={2411.14982},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2411.14982},
}
"""

cached_tensor = None
topk_indices = None

sunglasses_file_path = "assets/sunglasses.jpg"
greedy_file_path = "assets/greedy.jpg"
railway_file_path = "assets/railway.jpg"
happy_file_path = "assets/happy.jpg"


def generate_activations(image):
    prompt = ""
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)

    global cached_tensor, topk_indices

    def hook(module: torch.nn.Module, _, outputs):
        global cached_tensor, topk_indices
        # Unpack tuple outputs if necessary
        if isinstance(outputs, tuple):
            unpack_outputs = list(outputs)
        else:
            unpack_outputs = [outputs]
        latents = sae.pre_acts(unpack_outputs[0])
        # With a llama tokenizer and an image-only prompt, skip the leading
        # BOS token so that only image-patch positions remain.
        if "llama" in tokenizer.name_or_path:
            latents = latents[:, 1:, :]
        topk = torch.topk(latents, k=sae.cfg.k, dim=-1)
        # Zero out everything except the top-k activations.
        result = torch.zeros_like(latents)  # (batch, seq, num_latents)
        result.scatter_(-1, topk.indices, topk.values)
        cached_tensor = result.detach().cpu()
        # Average each latent over all positions, then keep the 100 most active.
        topk_indices = (
            latents.squeeze(0).mean(dim=0).topk(k=100).indices.detach().cpu()
        )

    handles = [hooked_module.register_forward_hook(hook)]
    try:
        with torch.no_grad():
            outputs = model(
                input_ids=inputs["input_ids"].to("cuda"),
                pixel_values=inputs["pixel_values"].to("cuda"),
                image_sizes=inputs["image_sizes"].to("cuda"),
                attention_mask=inputs["attention_mask"].to("cuda"),
            )
    finally:
        for handle in handles:
            handle.remove()

    torch.cuda.empty_cache()
    return topk_indices


def visualize_activations(image, feature_num):
    base_img_tokens = 576
    patch_size = 24
    feature_num = int(feature_num)  # slider values may arrive as floats
    # From the cached activations, select the chosen feature and keep the
    # first 576 tokens (the 24x24 grid of base-image patches).
    base_image_activations = cached_tensor[0, :base_img_tokens, feature_num].view(
        patch_size, patch_size
    )

    upsampled_image_mask = upsample_mask(base_image_activations, (336, 336))

    background = Image.new("L", (336, 336), 0).convert("RGB")

    # Note: the llava-hf preprocessing uses the plain resized image, not the
    # padded image, for the base image features. This differs from the original
    # LLaVA, but we align with llava-hf here since we use the llava-hf checkpoints.
    resized_image = image.resize((336, 336))

    activation_images = Image.composite(
        background, resized_image, upsampled_image_mask
    ).convert("RGB")

    return activation_images


def clamp_features_max(
    sae: Sae, feature: int, hooked_module: torch.nn.Module, k: float = 10
):
    def hook(module: torch.nn.Module, _, outputs):
        # Unpack tuple outputs if necessary
        if isinstance(outputs, tuple):
            unpack_outputs = list(outputs)
        else:
            unpack_outputs = [outputs]
        latents = sae.pre_acts(unpack_outputs[0])
        # Only clamp the feature during the prefill forward pass;
        # subsequent decoding steps have sequence length 1.
        if latents.shape[1] != 1:
            latents[:, :, feature] = k
        top_acts, top_indices = sae.select_topk(latents)
        sae_out = sae.decode(top_acts[0], top_indices[0]).unsqueeze(0).to(torch.float16)
        # Replace the layer output with the SAE reconstruction.
        unpack_outputs[0] = sae_out
        if isinstance(outputs, tuple):
            outputs = tuple(unpack_outputs)
        else:
            outputs = unpack_outputs[0]
        return outputs

    handles = [hooked_module.register_forward_hook(hook)]
    return handles


def generate_with_clamp(feature_idx, feature_strength, text, image, chat_history):
    if not isinstance(feature_idx, int):
        feature_idx = int(feature_idx)
    if not isinstance(feature_strength, float):
        feature_strength = float(feature_strength)

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
            ],
        },
    ]
    if image is not None:
        conversation[0]["content"].append(
            {"type": "image"},
        )
        chat_history.append({"role": "user", "content": gr.Image(value=image)})
    chat_history.append({"role": "user", "content": text})

    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)

    handles = clamp_features_max(sae, feature_idx, hooked_module, k=feature_strength)
    try:
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=512)
            cont = output[:, inputs["input_ids"].shape[-1] :]
    finally:
        for handle in handles:
            handle.remove()

    text = processor.batch_decode(cont, skip_special_tokens=True)[0]
    chat_history.append(
        {
            "role": "assistant",
            "content": text,
        }
    )
    return chat_history


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Large Multi-modal Models Can Interpret Features in Large Multi-modal Models

        🔍 [ArXiv Paper](https://arxiv.org/abs/2411.14982) | 🏠 [LMMs-Lab Homepage](https://lmms-lab.framer.ai) | 🤗 [Huggingface Collections](https://huggingface.co./collections/lmms-lab/llava-sae-674026e4e7bc8c29c70bc3a3)
        """
    )
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("Visualization of Activations", elem_id="visualization", id=0):
            with gr.Row():
                with gr.Column():
                    image = gr.Image(type="pil", interactive=True, label="Sample Image")
                    topk_features = gr.Textbox(
                        value=topk_indices,
                        placeholder="Top 100 Features",
                        label="Top 100 Features",
                    )
                    with gr.Row():
                        clear_btn = gr.ClearButton([image, topk_features], value="Clear")
                        submit_btn = gr.Button("Submit", variant="primary")
                    submit_btn.click(
                        generate_activations, inputs=[image], outputs=[topk_features]
                    )
                with gr.Column():
                    output = gr.Image(label="Activation Visualization")
                    feature_num = gr.Slider(
                        1, 131072, 1, 1, label="Feature Number", interactive=True
                    )
                    visualize_btn = gr.Button("Visualize", variant="primary")
                    visualize_btn.click(
                        visualize_activations,
                        inputs=[image, feature_num],
                        outputs=[output],
                    )
            dummy_text = gr.Textbox(visible=False, label="Explanation")
            gr.Examples(
                [
                    ["assets/sunglasses.jpg", 10, "Sunglasses"],
                    ["assets/greedy.jpg", 14, "Greedy eating"],
                    ["assets/railway.jpg", 28, "Railway tracks"],
                ],
                inputs=[image, feature_num, dummy_text],
                label="Examples",
            )
        with gr.TabItem("Steering Model", elem_id="steering", id=2):
            chatbot = gr.Chatbot(type="messages")
            with gr.Row(variant="compact", equal_height=True):
                feature_num = gr.Slider(
                    1, 131072, 1, 1, label="Feature Number", interactive=True
                )
                feature_strength = gr.Number(
                    value=50, label="Feature Strength", interactive=True
                )
            with gr.Row(variant="compact", equal_height=True):
                text_input = gr.Textbox(
                    label="Text Input", placeholder="Type here", interactive=True
                )
                image_input = gr.Image(
                    type="pil", label="Image Input", interactive=True, height=250
                )
            with gr.Row():
                chatbot_clear = gr.ClearButton(
                    [text_input, image_input, chatbot], value="Clear"
                )
                chatbot_submit = gr.Button("Submit", variant="primary")
            chatbot_submit.click(
                generate_with_clamp,
                inputs=[feature_num, feature_strength, text_input, image_input, chatbot],
                outputs=[chatbot],
            )
            gr.Examples(
                [
                    [
                        19379,
                        50,
                        "Look at this image, what is your feeling right now?",
                        happy_file_path,
                    ],
                    [14, 50, "Tell me a story about Alice and Bob", None],
                    [108692, 50, "What is your feeling right now?", None],
                ],
                inputs=[feature_num, feature_strength, text_input, image_input],
                label="Examples",
            )

    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            gr.Markdown("```bib\n" + CITATION_BUTTON_TEXT + "\n```")


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
    sae = load_single_sae("lmms-lab/llama3-llava-next-8b-hf-sae-131k", "model.layers.24")
    model, processor = maybe_load_llava_model(
        "llava-hf/llama3-llava-next-8b-hf", rank=0, dtype=torch.float16, hf_token=None
    )
    hooked_module = model.language_model.get_submodule("model.layers.24")
    demo.launch()
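
# A minimal sketch of reusing the steering hook outside the Gradio UI, kept as
# comments so the script's behavior is unchanged. It assumes the globals loaded
# above (model, processor, sae, hooked_module); the feature index/strength are
# taken from the text-only example in the steering tab and are illustrative only:
#
#   handles = clamp_features_max(sae, feature=14, hooked_module=hooked_module, k=50.0)
#   try:
#       prompt = processor.apply_chat_template(
#           [{"role": "user", "content": [{"type": "text", "text": "Tell me a story about Alice and Bob"}]}],
#           add_generation_prompt=True,
#       )
#       inputs = processor(images=None, text=prompt, return_tensors="pt").to(model.device)
#       out = model.generate(**inputs, max_new_tokens=64)
#   finally:
#       for handle in handles:
#           handle.remove()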