# Hugging Face Spaces page header: "Running on Zero".
import gradio as gr | |
from sae_auto_interp.sae import Sae | |
from sae_auto_interp.utils import maybe_load_llava_model, load_single_sae | |
from sae_auto_interp.features.features import upsample_mask | |
import torch | |
from transformers import AutoTokenizer | |
from PIL import Image | |
# BibTeX entry rendered inside the citation accordion of the demo.
CITATION_BUTTON_TEXT = """
@misc{zhang2024largemultimodalmodelsinterpret,
title={Large Multi-modal Models Can Interpret Features in Large Multi-modal Models},
author={Kaichen Zhang and Yifei Shen and Bo Li and Ziwei Liu},
year={2024},
eprint={2411.14982},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2411.14982},
}
"""

# Bundled example images.
railway_file_path = "assets/railway.jpg"
greedy_file_path = "assets/greedy.jpg"
sunglasses_file_path = "assets/sunglasses.jpg"

# Module-level state shared between the two callbacks. Both values are
# written by the forward hook in generate_activations(); cached_tensor is
# read back by visualize_activations(). They stay None until the first
# successful forward pass.
topk_indices = None
cached_tensor = None
def generate_activations(image):
    """Run the model on ``image`` and cache the SAE activations of the hooked layer.

    A forward hook on ``hooked_module`` feeds the layer output through
    ``sae.pre_acts``, keeps only the ``sae.cfg.k`` largest values per token
    (everything else is zeroed) and stores the result on the CPU in the
    module-level ``cached_tensor``. The indices of the 100 latents with the
    highest mean pre-activation over the sequence are stored in the
    module-level ``topk_indices``.

    Relies on the module-level ``processor``, ``model``, ``sae``,
    ``tokenizer`` and ``hooked_module`` being initialised before the call.

    Args:
        image: The image to analyse (the UI passes a PIL image).

    Returns:
        The tensor of the 100 top latent indices (also kept in
        ``topk_indices``).

    Raises:
        gr.Error: If no image was provided.
    """
    global cached_tensor, topk_indices

    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    prompt = "<image>"
    # Everything is moved to the model's device here, so no further
    # per-tensor device transfer is needed when calling the model.
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)

    def hook(module: torch.nn.Module, _, outputs):
        global cached_tensor, topk_indices
        # Maybe unpack tuple outputs. A bare tensor must be wrapped rather
        # than iterated: list(tensor) would split it along dim 0 and the
        # first element would lose its leading dimension.
        if isinstance(outputs, tuple):
            unpack_outputs = list(outputs)
        else:
            unpack_outputs = [outputs]
        latents = sae.pre_acts(unpack_outputs[0])
        # When the tokenizer is llama and text is None (image only)
        # the first (bos) token is skipped.
        if "llama" in tokenizer.name_or_path:
            latents = latents[:, 1:, :]
        topk = torch.topk(latents, k=sae.cfg.k, dim=-1)
        # Make all values outside the top-k zero.
        # result: (bs, seq, num_latents)
        result = torch.zeros_like(latents)
        result.scatter_(-1, topk.indices, topk.values)
        cached_tensor = result.detach().cpu()
        # NOTE(review): squeeze(0) followed by mean(dim=0) assumes a batch
        # of exactly one image -- confirm if batching is ever added.
        topk_indices = (
            latents.squeeze(0).mean(dim=0).topk(k=100).indices.detach().cpu()
        )

    handles = [hooked_module.register_forward_hook(hook)]
    try:
        with torch.no_grad():
            # Only the hook's side effects matter; the model output itself
            # is discarded.
            model(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                image_sizes=inputs["image_sizes"],
                attention_mask=inputs["attention_mask"],
            )
    finally:
        # Always detach the hook, even when the forward pass fails.
        for handle in handles:
            handle.remove()

    print(cached_tensor.shape)
    torch.cuda.empty_cache()
    return topk_indices
def visualize_activations(image, feature_num):
    """Overlay the cached activations of one SAE feature on ``image``.

    Reads the module-level ``cached_tensor`` filled by
    ``generate_activations``, takes the activations of latent
    ``feature_num`` for the first 576 tokens, arranges them as a 24x24 grid,
    upsamples that grid to 336x336 with ``upsample_mask`` and uses it as the
    mask when compositing a black background over the resized image.

    Args:
        image: The image to draw on (the UI passes a PIL image).
        feature_num: Index of the SAE latent to visualise.

    Returns:
        The composited 336x336 RGB image.

    Raises:
        gr.Error: If no image was provided, if activations have not been
            generated yet, or if ``feature_num`` is outside the range of
            cached latents.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    if cached_tensor is None:
        raise gr.Error("Please click Submit to generate activations before visualizing.")

    # The slider may hand over a float; tensor indexing needs an int.
    feature_idx = int(feature_num)
    num_latents = cached_tensor.shape[-1]
    if not 0 <= feature_idx < num_latents:
        raise gr.Error(f"Feature number must be between 0 and {num_latents - 1}.")

    patch_size = 24
    base_img_tokens = patch_size * patch_size  # 576
    target_size = (336, 336)

    # Using the cached tensor: select the requested feature, then keep the
    # first 576 tokens and lay them out as a square grid.
    base_image_activations = cached_tensor[0, :base_img_tokens, feature_idx].reshape(
        patch_size, patch_size
    )
    upsampled_image_mask = upsample_mask(base_image_activations, target_size)

    background = Image.new("L", target_size, 0).convert("RGB")

    # Somehow as I looked closer into the llava-hf preprocessing code,
    # I found out that they don't use the padded image as the base image feat
    # but use the simple resized image. This is different from original llava
    # but we align to llava-hf for now as we use llava-hf.
    resized_image = image.resize(target_size)

    activation_images = Image.composite(
        background, resized_image, upsampled_image_mask
    ).convert("RGB")
    return activation_images
# Gradio UI: one tab for visualising SAE feature activations, one
# placeholder tab for steering, and a citation accordion at the bottom.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Large Multi-modal Models Can Interpret Features in Large Multi-modal Models
📑 [ArXiv Paper](https://arxiv.org/abs/2411.14982) | 🏠 [LMMs-Lab Homepage](https://lmms-lab.framer.ai) | 🤗 [Huggingface Collections](https://huggingface.co/collections/lmms-lab/llava-sae-674026e4e7bc8c29c70bc3a3)
"""
    )
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("Visualization of Activations", elem_id="visualization", id=0):
            with gr.Row():
                # Left column: image upload and the list of top features.
                with gr.Column():
                    image = gr.Image(type="pil", interactive=True, label="Sample Image")
                    topk_features = gr.Textbox(
                        value=topk_indices,
                        placeholder="Top 100 Features",
                        label="Top 100 Features",
                    )
                    with gr.Row():
                        clear_btn = gr.ClearButton([image, topk_features], value="Clear")
                        submit_btn = gr.Button("Submit", variant="primary")
                    submit_btn.click(generate_activations, inputs=[image], outputs=[topk_features])
                # Right column: feature selection and the rendered overlay.
                with gr.Column():
                    output = gr.Image(label="Activation Visualization")
                    # Feature indices are 0-based. The range assumes the SAE
                    # has 131072 latents, matching the maximum the slider
                    # previously used -- TODO confirm against the checkpoint.
                    # Keyword arguments are used because `step` is not
                    # accepted positionally by every Gradio release.
                    feature_num = gr.Slider(
                        minimum=0,
                        maximum=131071,
                        value=1,
                        step=1,
                        label="Feature Number",
                        interactive=True,
                    )
                    visualize_btn = gr.Button("Visualize", variant="primary")
                    visualize_btn.click(
                        visualize_activations,
                        inputs=[image, feature_num],
                        outputs=[output],
                    )
            # Hidden textbox that only exists so each example row can carry
            # a short explanation of the feature.
            dummy_text = gr.Textbox(visible=False, label="Explanation")
            gr.Examples(
                [
                    [sunglasses_file_path, 10, "Sunglasses"],
                    [greedy_file_path, 14, "Greedy eating"],
                    [railway_file_path, 28, "Railway tracks"],
                ],
                inputs=[image, feature_num, dummy_text],
                label="Examples",
            )
        with gr.TabItem("Steering Model", elem_id="steering", id=2):
            chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            gr.Markdown("```bib\n" + CITATION_BUTTON_TEXT + "\n```")
if __name__ == "__main__":
    # Load everything the callbacks read from module scope, then start the UI.
    model_id = "llava-hf/llama3-llava-next-8b-hf"
    layer_name = "model.layers.24"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    sae = load_single_sae("lmms-lab/llama3-llava-next-8b-hf-sae-131k", layer_name)
    model, processor = maybe_load_llava_model(
        model_id,
        rank=0,
        dtype=torch.bfloat16,
        hf_token=None,
    )
    # The SAE is loaded for this same layer of the language model, so the
    # forward hook is attached to it.
    hooked_module = model.language_model.get_submodule(layer_name)

    demo.launch()