Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,60 +1,22 @@
|
|
1 |
-
Hugging Face's logo
|
2 |
-
Hugging Face
|
3 |
-
Search models, datasets, users...
|
4 |
-
Models
|
5 |
-
Datasets
|
6 |
-
Spaces
|
7 |
-
Posts
|
8 |
-
Docs
|
9 |
-
Enterprise
|
10 |
-
Pricing
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
Spaces:
|
15 |
-
|
16 |
-
WordLift
|
17 |
-
/
|
18 |
-
brand-llms
|
19 |
-
|
20 |
-
|
21 |
-
like
|
22 |
-
0
|
23 |
-
|
24 |
-
Logs
|
25 |
-
App
|
26 |
-
Files
|
27 |
-
Community
|
28 |
-
1
|
29 |
-
Settings
|
30 |
-
brand-llms
|
31 |
-
/
|
32 |
-
app.py
|
33 |
-
|
34 |
-
cyberandy's picture
|
35 |
-
cyberandy
|
36 |
-
Update app.py
|
37 |
-
3bc7e87
|
38 |
-
verified
|
39 |
-
21 days ago
|
40 |
-
raw
|
41 |
-
|
42 |
-
Copy download link
|
43 |
-
history
|
44 |
-
blame
|
45 |
-
edit
|
46 |
-
delete
|
47 |
-
|
48 |
-
4.55 kB
|
49 |
import requests
|
50 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
53 |
url = "https://www.neuronpedia.org/api/search-with-topk"
|
54 |
payload = {
|
55 |
-
"modelId":
|
56 |
"text": text,
|
57 |
-
"layer":
|
58 |
}
|
59 |
try:
|
60 |
response = requests.post(url, headers={"Content-Type": "application/json"}, json=payload)
|
@@ -63,12 +25,15 @@ def get_features(text: str):
|
|
63 |
except Exception as e:
|
64 |
return None
|
65 |
|
66 |
-
def create_dashboard(feature_id: int) -> str:
|
|
|
|
|
|
|
67 |
return f"""
|
68 |
<div class="dashboard-container p-4">
|
69 |
<h3 class="text-lg font-semibold mb-4">Feature {feature_id} Dashboard</h3>
|
70 |
<iframe
|
71 |
-
src="https://www.neuronpedia.org/
|
72 |
width="100%"
|
73 |
height="600"
|
74 |
frameborder="0"
|
@@ -77,14 +42,16 @@ def create_dashboard(feature_id: int) -> str:
|
|
77 |
</div>
|
78 |
"""
|
79 |
|
80 |
-
def handle_feature_click(feature_id):
|
81 |
-
return create_dashboard(feature_id)
|
82 |
|
83 |
-
def analyze_text(text: str):
|
|
|
|
|
84 |
if not text:
|
85 |
return [], ""
|
86 |
|
87 |
-
features_data = get_features(text)
|
88 |
if not features_data:
|
89 |
return [], ""
|
90 |
|
@@ -111,7 +78,7 @@ def analyze_text(text: str):
|
|
111 |
|
112 |
features.append({"token": token, "features": token_features})
|
113 |
|
114 |
-
return features, create_dashboard(first_feature_id) if first_feature_id else ""
|
115 |
|
116 |
css = """
|
117 |
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
|
@@ -139,6 +106,11 @@ body { font-family: 'Open Sans', sans-serif !important; }
|
|
139 |
.feature-button:hover {
|
140 |
background-color: #e5e7eb;
|
141 |
}
|
|
|
|
|
|
|
|
|
|
|
142 |
"""
|
143 |
|
144 |
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
@@ -146,6 +118,11 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
|
146 |
gr.Markdown("*Analyze text using interpretable neural features*", elem_classes="text-gray-600 mb-6")
|
147 |
|
148 |
features_state = gr.State([])
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
with gr.Row():
|
151 |
with gr.Column(scale=1):
|
@@ -161,11 +138,13 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
|
161 |
)
|
162 |
|
163 |
with gr.Column(scale=2):
|
164 |
-
@gr.render(inputs=features_state)
|
165 |
-
def render_features(features):
|
166 |
if not features:
|
167 |
return
|
168 |
|
|
|
|
|
169 |
for token_group in features:
|
170 |
gr.Markdown(f"### {token_group['token']}")
|
171 |
with gr.Row():
|
@@ -175,17 +154,32 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
|
175 |
elem_classes=["feature-button"]
|
176 |
)
|
177 |
btn.click(
|
178 |
-
fn=lambda fid=feature['id']: handle_feature_click(fid),
|
179 |
outputs=dashboard
|
180 |
)
|
181 |
|
182 |
dashboard = gr.HTML()
|
183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
analyze_btn.click(
|
185 |
fn=analyze_text,
|
186 |
-
inputs=[input_text],
|
187 |
outputs=[features_state, dashboard]
|
188 |
)
|
189 |
|
190 |
if __name__ == "__main__":
|
191 |
-
demo.launch(share=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import requests
|
2 |
import gradio as gr
|
3 |
+
from enum import Enum
|
4 |
+
|
5 |
+
class Model(Enum):
|
6 |
+
GEMMA = "gemma-2-2b"
|
7 |
+
GPT2 = "gpt2-small"
|
8 |
|
9 |
+
MODEL_CONFIGS = {
|
10 |
+
Model.GEMMA: "20-gemmascope-res-16k",
|
11 |
+
Model.GPT2: "9-res-jb"
|
12 |
+
}
|
13 |
+
|
14 |
+
def get_features(text: str, model: Model):
|
15 |
url = "https://www.neuronpedia.org/api/search-with-topk"
|
16 |
payload = {
|
17 |
+
"modelId": model.value,
|
18 |
"text": text,
|
19 |
+
"layer": MODEL_CONFIGS[model]
|
20 |
}
|
21 |
try:
|
22 |
response = requests.post(url, headers={"Content-Type": "application/json"}, json=payload)
|
|
|
25 |
except Exception as e:
|
26 |
return None
|
27 |
|
28 |
+
def create_dashboard(feature_id: int, model: Model) -> str:
|
29 |
+
model_path = model.value.lower()
|
30 |
+
layer_name = MODEL_CONFIGS[model].lower()
|
31 |
+
|
32 |
return f"""
|
33 |
<div class="dashboard-container p-4">
|
34 |
<h3 class="text-lg font-semibold mb-4">Feature {feature_id} Dashboard</h3>
|
35 |
<iframe
|
36 |
+
src="https://www.neuronpedia.org/{model_path}/{layer_name}/{feature_id}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
|
37 |
width="100%"
|
38 |
height="600"
|
39 |
frameborder="0"
|
|
|
42 |
</div>
|
43 |
"""
|
44 |
|
45 |
+
def handle_feature_click(feature_id: int, model: Model):
|
46 |
+
return create_dashboard(feature_id, model)
|
47 |
|
48 |
+
def analyze_text(text: str, selected_model: str):
|
49 |
+
model = Model.GEMMA if selected_model == "Gemini" else Model.GPT2
|
50 |
+
|
51 |
if not text:
|
52 |
return [], ""
|
53 |
|
54 |
+
features_data = get_features(text, model)
|
55 |
if not features_data:
|
56 |
return [], ""
|
57 |
|
|
|
78 |
|
79 |
features.append({"token": token, "features": token_features})
|
80 |
|
81 |
+
return features, create_dashboard(first_feature_id, model) if first_feature_id else ""
|
82 |
|
83 |
css = """
|
84 |
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
|
|
|
106 |
.feature-button:hover {
|
107 |
background-color: #e5e7eb;
|
108 |
}
|
109 |
+
.model-selector {
|
110 |
+
display: flex;
|
111 |
+
gap: 1rem;
|
112 |
+
margin-bottom: 1rem;
|
113 |
+
}
|
114 |
"""
|
115 |
|
116 |
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
|
|
118 |
gr.Markdown("*Analyze text using interpretable neural features*", elem_classes="text-gray-600 mb-6")
|
119 |
|
120 |
features_state = gr.State([])
|
121 |
+
selected_model = gr.State("Gemini") # Default to Gemini
|
122 |
+
|
123 |
+
with gr.Row(elem_classes="model-selector"):
|
124 |
+
gemini_btn = gr.Button("🧬 Gemini", variant="primary" if selected_model.value == "Gemini" else "secondary")
|
125 |
+
openai_btn = gr.Button("🤖 OpenAI", variant="secondary")
|
126 |
|
127 |
with gr.Row():
|
128 |
with gr.Column(scale=1):
|
|
|
138 |
)
|
139 |
|
140 |
with gr.Column(scale=2):
|
141 |
+
@gr.render(inputs=[features_state, selected_model])
|
142 |
+
def render_features(features, current_model):
|
143 |
if not features:
|
144 |
return
|
145 |
|
146 |
+
model = Model.GEMMA if current_model == "Gemini" else Model.GPT2
|
147 |
+
|
148 |
for token_group in features:
|
149 |
gr.Markdown(f"### {token_group['token']}")
|
150 |
with gr.Row():
|
|
|
154 |
elem_classes=["feature-button"]
|
155 |
)
|
156 |
btn.click(
|
157 |
+
fn=lambda fid=feature['id']: handle_feature_click(fid, model),
|
158 |
outputs=dashboard
|
159 |
)
|
160 |
|
161 |
dashboard = gr.HTML()
|
162 |
|
163 |
+
def update_model(new_model):
|
164 |
+
return new_model
|
165 |
+
|
166 |
+
gemini_btn.click(
|
167 |
+
fn=lambda: update_model("Gemini"),
|
168 |
+
outputs=selected_model,
|
169 |
+
queue=False
|
170 |
+
)
|
171 |
+
|
172 |
+
openai_btn.click(
|
173 |
+
fn=lambda: update_model("OpenAI"),
|
174 |
+
outputs=selected_model,
|
175 |
+
queue=False
|
176 |
+
)
|
177 |
+
|
178 |
analyze_btn.click(
|
179 |
fn=analyze_text,
|
180 |
+
inputs=[input_text, selected_model],
|
181 |
outputs=[features_state, dashboard]
|
182 |
)
|
183 |
|
184 |
if __name__ == "__main__":
|
185 |
+
demo.launch(share=False)
|