cyberandy commited on
Commit
89cfd4d
·
verified ·
1 Parent(s): 9ec0501

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -82
app.py CHANGED
@@ -1,43 +1,21 @@
1
  import requests
2
  import gradio as gr
3
 
4
- def parse_api_response(data):
5
- features_by_token = {}
6
- for result in data['results']:
7
- token = result['token']
8
- if token == '<bos>':
9
- continue
10
- features_by_token[token] = []
11
- for feature in result['top_features']:
12
- features_by_token[token].append({
13
- 'id': feature['feature_index'],
14
- 'activation': feature['activation_value']
15
- })
16
- return features_by_token
17
-
18
- def analyze_text(text: str):
19
- if not text:
20
- return None, ""
21
-
22
  try:
23
- response = requests.post(
24
- "https://www.neuronpedia.org/api/search-with-topk",
25
- headers={"Content-Type": "application/json"},
26
- json={
27
- "modelId": "gemma-2-2b",
28
- "text": text,
29
- "layer": "20-gemmascope-res-16k"
30
- }
31
- )
32
  response.raise_for_status()
33
- features = parse_api_response(response.json())
34
- first_feature = next(iter(next(iter(features.values()))))
35
- dashboard = create_dashboard(first_feature['id'])
36
- return features, dashboard
37
  except Exception as e:
38
- return None, str(e)
39
 
40
- def create_dashboard(feature_id: int):
41
  return f"""
42
  <div class="dashboard-container p-4">
43
  <h3 class="text-lg font-semibold mb-4">Feature {feature_id} Dashboard</h3>
@@ -51,77 +29,119 @@ def create_dashboard(feature_id: int):
51
  </div>
52
  """
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  css = """
55
- .token-section { margin-bottom: 1.5rem; }
56
- .token-header {
 
 
 
 
 
 
 
 
 
57
  font-size: 1.25rem;
58
  font-weight: 600;
59
- margin-bottom: 0.75rem;
 
60
  }
 
61
  .feature-button {
62
- background: #f8fafc;
63
- border: 1px solid #e2e8f0;
64
- border-radius: 0.5rem;
65
- padding: 0.5rem 1rem;
66
  margin: 0.25rem;
 
 
 
 
67
  font-size: 0.875rem;
68
- cursor: pointer;
69
  }
70
- .feature-button:hover { background: #f1f5f9; }
71
- .feature-button.selected {
72
- background: #e0e7ff;
73
- border-color: #6366f1;
74
  }
75
  """
76
 
77
- def update_dashboard(feature_id, features_state):
78
- return create_dashboard(feature_id)
79
-
80
- with gr.Blocks(css=css) as demo:
81
- gr.Markdown("# Brand Analyzer")
82
- gr.Markdown("*Analyze text using Gemma's interpretable neural features*")
83
 
84
- features_state = gr.State()
85
- selected_feature = gr.State()
86
 
87
  with gr.Row():
88
  with gr.Column(scale=1):
89
- input_text = gr.Textbox(lines=5, label="Input Text")
 
 
 
 
90
  analyze_btn = gr.Button("Analyze Features", variant="primary")
91
- gr.Examples(["WordLift", "Think Different", "Just Do It"], inputs=input_text)
 
 
 
92
 
93
  with gr.Column(scale=2):
94
- features_container = gr.HTML()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  dashboard = gr.HTML()
96
-
97
- def render_features(features_dict, selected=None):
98
- if not features_dict:
99
- return ""
100
-
101
- html = ""
102
- for token, features in features_dict.items():
103
- html += f'<div class="token-section">'
104
- html += f'<div class="token-header">{token}</div>'
105
- for feature in features[:3]:
106
- selected_class = "selected" if selected == feature['id'] else ""
107
- html += f"""
108
- <button
109
- onclick='updateDashboard({feature["id"]})'
110
- class='feature-button {selected_class}'>
111
- Feature {feature["id"]} (Activation: {feature["activation"]:.2f})
112
- </button>
113
- """
114
- html += '</div>'
115
- return html
116
-
117
  analyze_btn.click(
118
- analyze_text,
119
  inputs=[input_text],
120
  outputs=[features_state, dashboard]
121
- ).then(
122
- render_features,
123
- inputs=[features_state, selected_feature],
124
- outputs=[features_container]
125
  )
126
 
127
  if __name__ == "__main__":
 
1
  import requests
2
  import gradio as gr
3
 
4
+ def get_features(text: str):
5
+ url = "https://www.neuronpedia.org/api/search-with-topk"
6
+ payload = {
7
+ "modelId": "gemma-2-2b",
8
+ "text": text,
9
+ "layer": "20-gemmascope-res-16k"
10
+ }
 
 
 
 
 
 
 
 
 
 
 
11
  try:
12
+ response = requests.post(url, headers={"Content-Type": "application/json"}, json=payload)
 
 
 
 
 
 
 
 
13
  response.raise_for_status()
14
+ return response.json()
 
 
 
15
  except Exception as e:
16
+ return None
17
 
18
+ def create_dashboard(feature_id: int) -> str:
19
  return f"""
20
  <div class="dashboard-container p-4">
21
  <h3 class="text-lg font-semibold mb-4">Feature {feature_id} Dashboard</h3>
 
29
  </div>
30
  """
31
 
32
+ def handle_feature_click(feature_id):
33
+ return create_dashboard(feature_id)
34
+
35
+ def analyze_text(text: str):
36
+ if not text:
37
+ return [], ""
38
+
39
+ features_data = get_features(text)
40
+ if not features_data:
41
+ return [], ""
42
+
43
+ features = []
44
+ first_feature_id = None
45
+
46
+ for result in features_data['results']:
47
+ if result['token'] == '<bos>':
48
+ continue
49
+
50
+ token = result['token']
51
+ token_features = []
52
+
53
+ for feature in result['top_features'][:3]:
54
+ feature_id = feature['feature_index']
55
+ if first_feature_id is None:
56
+ first_feature_id = feature_id
57
+
58
+ token_features.append({
59
+ "token": token,
60
+ "id": feature_id,
61
+ "activation": feature['activation_value']
62
+ })
63
+
64
+ features.append({"token": token, "features": token_features})
65
+
66
+ return features, create_dashboard(first_feature_id) if first_feature_id else ""
67
+
68
  css = """
69
+ @import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
70
+
71
+ body { font-family: 'Open Sans', sans-serif !important; }
72
+
73
+ .dashboard-container {
74
+ border: 1px solid #e0e5ff;
75
+ border-radius: 8px;
76
+ background-color: #ffffff;
77
+ }
78
+
79
+ .token-header {
80
  font-size: 1.25rem;
81
  font-weight: 600;
82
+ margin-top: 1rem;
83
+ margin-bottom: 0.5rem;
84
  }
85
+
86
  .feature-button {
87
+ display: inline-block;
 
 
 
88
  margin: 0.25rem;
89
+ padding: 0.5rem 1rem;
90
+ background-color: #f3f4f6;
91
+ border: 1px solid #e5e7eb;
92
+ border-radius: 0.375rem;
93
  font-size: 0.875rem;
 
94
  }
95
+
96
+ .feature-button:hover {
97
+ background-color: #e5e7eb;
 
98
  }
99
  """
100
 
101
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
102
+ gr.Markdown("# Brand Analyzer", elem_classes="text-2xl font-bold mb-2")
103
+ gr.Markdown("*Analyze text using Gemma's interpretable neural features*", elem_classes="text-gray-600 mb-6")
 
 
 
104
 
105
+ features_state = gr.State([])
 
106
 
107
  with gr.Row():
108
  with gr.Column(scale=1):
109
+ input_text = gr.Textbox(
110
+ lines=5,
111
+ placeholder="Enter text to analyze...",
112
+ label="Input Text"
113
+ )
114
  analyze_btn = gr.Button("Analyze Features", variant="primary")
115
+ gr.Examples(
116
+ examples=["WordLift", "Think Different", "Just Do It"],
117
+ inputs=input_text
118
+ )
119
 
120
  with gr.Column(scale=2):
121
+ @gr.render(inputs=features_state)
122
+ def render_features(features):
123
+ if not features:
124
+ return
125
+
126
+ for token_group in features:
127
+ gr.Markdown(f"### {token_group['token']}")
128
+ with gr.Row():
129
+ for feature in token_group['features']:
130
+ btn = gr.Button(
131
+ f"Feature {feature['id']} (Activation: {feature['activation']:.2f})",
132
+ elem_classes=["feature-button"]
133
+ )
134
+ btn.click(
135
+ fn=lambda fid=feature['id']: handle_feature_click(fid),
136
+ outputs=dashboard
137
+ )
138
+
139
  dashboard = gr.HTML()
140
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  analyze_btn.click(
142
+ fn=analyze_text,
143
  inputs=[input_text],
144
  outputs=[features_state, dashboard]
 
 
 
 
145
  )
146
 
147
  if __name__ == "__main__":