cyberandy commited on
Commit
a5ec26c
·
verified ·
1 Parent(s): 3d2eab7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -63
app.py CHANGED
@@ -1,60 +1,22 @@
1
- Hugging Face's logo
2
- Hugging Face
3
- Search models, datasets, users...
4
- Models
5
- Datasets
6
- Spaces
7
- Posts
8
- Docs
9
- Enterprise
10
- Pricing
11
-
12
-
13
-
14
- Spaces:
15
-
16
- WordLift
17
- /
18
- brand-llms
19
-
20
-
21
- like
22
- 0
23
-
24
- Logs
25
- App
26
- Files
27
- Community
28
- 1
29
- Settings
30
- brand-llms
31
- /
32
- app.py
33
-
34
- cyberandy's picture
35
- cyberandy
36
- Update app.py
37
- 3bc7e87
38
- verified
39
- 21 days ago
40
- raw
41
-
42
- Copy download link
43
- history
44
- blame
45
- edit
46
- delete
47
-
48
- 4.55 kB
49
  import requests
50
  import gradio as gr
 
 
 
 
 
51
 
52
- def get_features(text: str):
 
 
 
 
 
53
  url = "https://www.neuronpedia.org/api/search-with-topk"
54
  payload = {
55
- "modelId": "gemma-2-2b",
56
  "text": text,
57
- "layer": "20-gemmascope-res-16k"
58
  }
59
  try:
60
  response = requests.post(url, headers={"Content-Type": "application/json"}, json=payload)
@@ -63,12 +25,15 @@ def get_features(text: str):
63
  except Exception as e:
64
  return None
65
 
66
- def create_dashboard(feature_id: int) -> str:
 
 
 
67
  return f"""
68
  <div class="dashboard-container p-4">
69
  <h3 class="text-lg font-semibold mb-4">Feature {feature_id} Dashboard</h3>
70
  <iframe
71
- src="https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
72
  width="100%"
73
  height="600"
74
  frameborder="0"
@@ -77,14 +42,16 @@ def create_dashboard(feature_id: int) -> str:
77
  </div>
78
  """
79
 
80
- def handle_feature_click(feature_id):
81
- return create_dashboard(feature_id)
82
 
83
- def analyze_text(text: str):
 
 
84
  if not text:
85
  return [], ""
86
 
87
- features_data = get_features(text)
88
  if not features_data:
89
  return [], ""
90
 
@@ -111,7 +78,7 @@ def analyze_text(text: str):
111
 
112
  features.append({"token": token, "features": token_features})
113
 
114
- return features, create_dashboard(first_feature_id) if first_feature_id else ""
115
 
116
  css = """
117
  @import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
@@ -139,6 +106,11 @@ body { font-family: 'Open Sans', sans-serif !important; }
139
  .feature-button:hover {
140
  background-color: #e5e7eb;
141
  }
 
 
 
 
 
142
  """
143
 
144
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
@@ -146,6 +118,11 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
146
  gr.Markdown("*Analyze text using interpretable neural features*", elem_classes="text-gray-600 mb-6")
147
 
148
  features_state = gr.State([])
 
 
 
 
 
149
 
150
  with gr.Row():
151
  with gr.Column(scale=1):
@@ -161,11 +138,13 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
161
  )
162
 
163
  with gr.Column(scale=2):
164
- @gr.render(inputs=features_state)
165
- def render_features(features):
166
  if not features:
167
  return
168
 
 
 
169
  for token_group in features:
170
  gr.Markdown(f"### {token_group['token']}")
171
  with gr.Row():
@@ -175,17 +154,32 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
175
  elem_classes=["feature-button"]
176
  )
177
  btn.click(
178
- fn=lambda fid=feature['id']: handle_feature_click(fid),
179
  outputs=dashboard
180
  )
181
 
182
  dashboard = gr.HTML()
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  analyze_btn.click(
185
  fn=analyze_text,
186
- inputs=[input_text],
187
  outputs=[features_state, dashboard]
188
  )
189
 
190
  if __name__ == "__main__":
191
- demo.launch(share=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import requests
2
  import gradio as gr
3
+ from enum import Enum
4
+
5
+ class Model(Enum):
6
+ GEMMA = "gemma-2-2b"
7
+ GPT2 = "gpt2-small"
8
 
9
+ MODEL_CONFIGS = {
10
+ Model.GEMMA: "20-gemmascope-res-16k",
11
+ Model.GPT2: "9-res-jb"
12
+ }
13
+
14
+ def get_features(text: str, model: Model):
15
  url = "https://www.neuronpedia.org/api/search-with-topk"
16
  payload = {
17
+ "modelId": model.value,
18
  "text": text,
19
+ "layer": MODEL_CONFIGS[model]
20
  }
21
  try:
22
  response = requests.post(url, headers={"Content-Type": "application/json"}, json=payload)
 
25
  except Exception as e:
26
  return None
27
 
28
+ def create_dashboard(feature_id: int, model: Model) -> str:
29
+ model_path = model.value.lower()
30
+ layer_name = MODEL_CONFIGS[model].lower()
31
+
32
  return f"""
33
  <div class="dashboard-container p-4">
34
  <h3 class="text-lg font-semibold mb-4">Feature {feature_id} Dashboard</h3>
35
  <iframe
36
+ src="https://www.neuronpedia.org/{model_path}/{layer_name}/{feature_id}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
37
  width="100%"
38
  height="600"
39
  frameborder="0"
 
42
  </div>
43
  """
44
 
45
+ def handle_feature_click(feature_id: int, model: Model):
46
+ return create_dashboard(feature_id, model)
47
 
48
+ def analyze_text(text: str, selected_model: str):
49
+ model = Model.GEMMA if selected_model == "Gemini" else Model.GPT2
50
+
51
  if not text:
52
  return [], ""
53
 
54
+ features_data = get_features(text, model)
55
  if not features_data:
56
  return [], ""
57
 
 
78
 
79
  features.append({"token": token, "features": token_features})
80
 
81
+ return features, create_dashboard(first_feature_id, model) if first_feature_id else ""
82
 
83
  css = """
84
  @import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
 
106
  .feature-button:hover {
107
  background-color: #e5e7eb;
108
  }
109
+ .model-selector {
110
+ display: flex;
111
+ gap: 1rem;
112
+ margin-bottom: 1rem;
113
+ }
114
  """
115
 
116
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
 
118
  gr.Markdown("*Analyze text using interpretable neural features*", elem_classes="text-gray-600 mb-6")
119
 
120
  features_state = gr.State([])
121
+ selected_model = gr.State("Gemini") # Default to Gemini
122
+
123
+ with gr.Row(elem_classes="model-selector"):
124
+ gemini_btn = gr.Button("🧬 Gemini", variant="primary" if selected_model.value == "Gemini" else "secondary")
125
+ openai_btn = gr.Button("🤖 OpenAI", variant="secondary")
126
 
127
  with gr.Row():
128
  with gr.Column(scale=1):
 
138
  )
139
 
140
  with gr.Column(scale=2):
141
+ @gr.render(inputs=[features_state, selected_model])
142
+ def render_features(features, current_model):
143
  if not features:
144
  return
145
 
146
+ model = Model.GEMMA if current_model == "Gemini" else Model.GPT2
147
+
148
  for token_group in features:
149
  gr.Markdown(f"### {token_group['token']}")
150
  with gr.Row():
 
154
  elem_classes=["feature-button"]
155
  )
156
  btn.click(
157
+ fn=lambda fid=feature['id']: handle_feature_click(fid, model),
158
  outputs=dashboard
159
  )
160
 
161
  dashboard = gr.HTML()
162
 
163
+ def update_model(new_model):
164
+ return new_model
165
+
166
+ gemini_btn.click(
167
+ fn=lambda: update_model("Gemini"),
168
+ outputs=selected_model,
169
+ queue=False
170
+ )
171
+
172
+ openai_btn.click(
173
+ fn=lambda: update_model("OpenAI"),
174
+ outputs=selected_model,
175
+ queue=False
176
+ )
177
+
178
  analyze_btn.click(
179
  fn=analyze_text,
180
+ inputs=[input_text, selected_model],
181
  outputs=[features_state, dashboard]
182
  )
183
 
184
  if __name__ == "__main__":
185
+ demo.launch(share=False)