kcz358 commited on
Commit
37dda2c
·
1 Parent(s): bb033d4
Files changed (2) hide show
  1. app.py +4 -1
  2. requirements.txt +1 -1
app.py CHANGED
@@ -5,6 +5,7 @@ from sae_auto_interp.features.features import upsample_mask
5
  import torch
6
  from transformers import AutoTokenizer
7
  from PIL import Image
 
8
 
9
  CITATION_BUTTON_TEXT = """
10
  @misc{zhang2024largemultimodalmodelsinterpret,
@@ -46,7 +47,7 @@ greedy_file_path = "assets/greedy.jpg"
46
  railway_file_path = "assets/railway.jpg"
47
  happy_file_path = "assets/happy.jpg"
48
 
49
-
50
  def generate_activations(image):
51
  prompt = "<image>"
52
  inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
@@ -116,6 +117,7 @@ def visualize_activations(image, feature_num):
116
 
117
  return activation_images
118
 
 
119
  def clamp_features_max(
120
  sae: Sae, feature: int, hooked_module: torch.nn.Module, k: float = 10
121
  ):
@@ -142,6 +144,7 @@ def clamp_features_max(
142
 
143
  return handles
144
 
 
145
  def generate_with_clamp(feature_idx, feature_strength, text, image, chat_history):
146
  if not isinstance(feature_idx, int):
147
  feature_idx = int(feature_idx)
 
5
  import torch
6
  from transformers import AutoTokenizer
7
  from PIL import Image
8
+ import spaces
9
 
10
  CITATION_BUTTON_TEXT = """
11
  @misc{zhang2024largemultimodalmodelsinterpret,
 
47
  railway_file_path = "assets/railway.jpg"
48
  happy_file_path = "assets/happy.jpg"
49
 
50
+ @spaces.GPU
51
  def generate_activations(image):
52
  prompt = "<image>"
53
  inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
 
117
 
118
  return activation_images
119
 
120
+ @spaces.GPU
121
  def clamp_features_max(
122
  sae: Sae, feature: int, hooked_module: torch.nn.Module, k: float = 10
123
  ):
 
144
 
145
  return handles
146
 
147
+ @spaces.GPU
148
  def generate_with_clamp(feature_idx, feature_strength, text, image, chat_history):
149
  if not isinstance(feature_idx, int):
150
  feature_idx = int(feature_idx)
requirements.txt CHANGED
@@ -3,4 +3,4 @@ gradio
3
  sae_auto_interp @ git+https://github.com/EvolvingLMMs-Lab/multimodal-sae
4
  fastapi==0.112.2
5
  gradio==4.44.1
6
- httpx==0.23.3
 
3
  sae_auto_interp @ git+https://github.com/EvolvingLMMs-Lab/multimodal-sae
4
  fastapi==0.112.2
5
  gradio==4.44.1
6
+ httpx==0.24.1