"""Gradio app that inspects text with Neuronpedia sparse-autoencoder features.

The user enters text and picks a model; the app calls Neuronpedia's
``search-with-topk`` endpoint, lists the top-activating features for each
token as buttons, and fills a dashboard panel when a feature is clicked.
"""

import logging
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import requests

logger = logging.getLogger(__name__)


class Model(Enum):
    """Neuronpedia model ids this app can query."""

    GEMMA = "gemma-2-2b"
    GPT2 = "gpt2-small"


# SAE source ("layer" field of the API payload) queried for each model.
MODEL_CONFIGS = {
    Model.GEMMA: "20-gemmascope-res-16k",
    Model.GPT2: "9-res-jb",
}

# UI radio label -> model. Unrecognised labels fall back to GPT-2, which
# preserves the original `GEMMA if label == "Gemini" else GPT2` behaviour.
MODEL_LABELS = {
    "Gemini": Model.GEMMA,
    "OpenAI": Model.GPT2,
}

NEURONPEDIA_SEARCH_URL = "https://www.neuronpedia.org/api/search-with-topk"

# Seconds to wait for Neuronpedia; without a timeout a stalled connection
# would block the Gradio worker handling the request forever.
REQUEST_TIMEOUT = 30.0

# Number of top-activating features listed per token.
TOP_FEATURES_PER_TOKEN = 3


def _model_from_label(label: str) -> Model:
    """Map a UI label ("Gemini"/"OpenAI") to a Model; unknown labels -> GPT2."""
    return MODEL_LABELS.get(label, Model.GPT2)


def get_features(
    text: str, model: Model, timeout: float = REQUEST_TIMEOUT
) -> Optional[Dict[str, Any]]:
    """Query Neuronpedia's search-with-topk endpoint for *text*.

    Args:
        text: Text to tokenize and search.
        model: Model whose SAE source (from MODEL_CONFIGS) is queried.
        timeout: Seconds to wait for the HTTP request.

    Returns:
        The decoded JSON response, or None if the request failed, returned a
        non-2xx status, or the body was not valid JSON. Failures are logged
        rather than raised so the UI degrades to an empty result.
    """
    payload = {
        "modelId": model.value,
        "text": text,
        "layer": MODEL_CONFIGS[model],
    }
    try:
        # `json=` serialises the payload and sets Content-Type: application/json.
        response = requests.post(NEURONPEDIA_SEARCH_URL, json=payload, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers JSON decoding errors from `response.json()`.
        logger.exception("Neuronpedia request failed for model %s", model.value)
        return None


def create_dashboard(feature_id: int, model: Model) -> str:
    """Return the text/HTML shown in the dashboard panel for *feature_id*.

    NOTE(review): the original computed `model.value.lower()` and
    `MODEL_CONFIGS[model].lower()` here without using them; the dashboard may
    have been meant to embed a per-model/per-layer view — confirm. *model* is
    kept in the signature for existing callers.
    """
    return f"\n\nFeature {feature_id} Dashboard\n\n"


def handle_feature_click(feature_id: int, model: str) -> str:
    """Build the dashboard for a clicked feature; *model* is the UI label."""
    return create_dashboard(feature_id, _model_from_label(model))


def analyze_text(
    text: str, selected_model: str, top_k: int = TOP_FEATURES_PER_TOKEN
) -> Tuple[List[Dict[str, Any]], str]:
    """Fetch per-token top features for *text* and build the initial dashboard.

    Args:
        text: Input text from the textbox.
        selected_model: UI label of the selected model ("Gemini"/"OpenAI").
        top_k: How many of each token's top features to keep.

    Returns:
        A pair ``(features, dashboard_html)``. ``features`` is a list of
        ``{"token": str, "features": [{"token", "id", "activation"}, ...]}``
        groups, one per non-empty token. ``dashboard_html`` is the dashboard
        for the first feature found, or "" when there is none (or on empty
        input / request failure).
    """
    model = _model_from_label(selected_model)
    # Whitespace-only input would only produce empty/blank tokens; skip the call.
    if not text or not text.strip():
        return [], ""

    features_data = get_features(text, model)
    if not features_data:
        return [], ""

    features: List[Dict[str, Any]] = []
    first_feature_id: Optional[int] = None
    for result in features_data.get("results", []):
        token = result.get("token", "")
        if token == "":
            continue

        token_features = []
        for feature in result.get("top_features", [])[:top_k]:
            feature_id = feature["feature_index"]
            if first_feature_id is None:
                first_feature_id = feature_id
            token_features.append(
                {
                    "token": token,
                    "id": feature_id,
                    "activation": feature["activation_value"],
                }
            )
        features.append({"token": token, "features": token_features})

    # Compare with None explicitly: feature index 0 is valid but falsy.
    dashboard_html = (
        create_dashboard(first_feature_id, model) if first_feature_id is not None else ""
    )
    return features, dashboard_html


# NOTE(review): the radio icons use relative URLs ('img/...'); confirm they
# resolve when the app is served by Gradio.
css = """
@import url('https://fonts.googleapis.com/css2?family=Open+Sans:wght@300;400;600;700&display=swap');
body { font-family: 'Open Sans', sans-serif !important; }
.dashboard-container { border: 1px solid #e0e5ff; border-radius: 8px; background-color: #ffffff; }
.token-header { font-size: 1.25rem; font-weight: 600; margin-top: 1rem; margin-bottom: 0.5rem; }
.feature-button { display: inline-block; margin: 0.25rem; padding: 0.5rem 1rem; background-color: #f3f4f6; border: 1px solid #e5e7eb; border-radius: 0.375rem; font-size: 0.875rem; }
.feature-button:hover { background-color: #e5e7eb; }
.model-selector { display: flex; gap: 8px; margin-bottom: 1rem; }
#model-buttons .gr-form { background: transparent !important; border: none !important; box-shadow: none !important; }
#model-buttons .gr-radio-row { gap: 8px !important; }
#model-buttons label { display: flex !important; align-items: center !important; gap: 4px !important; padding: 4px 12px !important; border: 1px solid #e5e7eb !important; border-radius: 6px !important; font-size: 14px !important; cursor: pointer !important; transition: all 0.2s !important; }
#model-buttons label:hover { background-color: #f3f4f6 !important; }
#model-buttons label.selected { background-color: #4c4ce3 !important; color: white !important; border-color: #4c4ce3 !important; }
#model-buttons label:before { content: "" !important; width: 20px !important; height: 20px !important; background-size: contain !important; background-repeat: no-repeat !important; background-position: center !important; }
#model-buttons label:nth-child(1):before { background-image: url('img/gemini-icon.png') !important; }
#model-buttons label:nth-child(2):before { background-image: url('img/openai-icon.png') !important; }
"""


with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("# Brand Analyzer", elem_classes="text-2xl font-bold mb-2")
    gr.Markdown(
        "*Analyze text using interpretable neural features*",
        elem_classes="text-gray-600 mb-6",
    )

    # UI label of the currently selected model ("Gemini"/"OpenAI").
    current_model = gr.State("Gemini")
    # Feature groups produced by analyze_text; drives render_features below.
    features_state = gr.State([])

    with gr.Row(elem_classes="model-selector"):
        with gr.Column(scale=1):
            with gr.Row():
                model_choice = gr.Radio(
                    choices=["Gemini", "OpenAI"],
                    value="Gemini",
                    label="",
                    elem_classes="model-selector",
                    elem_id="model-buttons",
                    container=False,
                    interactive=True,
                )

    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                lines=5,
                placeholder="Enter text to analyze...",
                label="Input Text",
            )
            analyze_btn = gr.Button("Analyze Features", variant="primary")
            gr.Examples(
                examples=["WordLift", "Think Different", "Just Do It"],
                inputs=input_text,
            )

        with gr.Column(scale=2):

            @gr.render(inputs=[features_state, current_model])
            def render_features(features, model):
                """Re-render one header plus a row of feature buttons per token."""
                if not features:
                    return
                for token_group in features:
                    gr.Markdown(f"### {token_group['token']}")
                    with gr.Row():
                        for feature in token_group["features"]:
                            btn = gr.Button(
                                f"Feature {feature['id']} (Activation: {feature['activation']:.2f})",
                                elem_classes=["feature-button"],
                            )
                            # `fid` is bound as a default argument so each button
                            # keeps its own feature id (avoids late binding).
                            # `dashboard` is assigned below; it is only referenced
                            # here once features exist, i.e. after an analyze event.
                            btn.click(
                                fn=lambda fid=feature["id"]: handle_feature_click(fid, model),
                                outputs=dashboard,
                            )

            dashboard = gr.HTML()

    def update_and_analyze(text: str, model: str):
        """Event handler: run analyze_text and return (features, dashboard)."""
        return analyze_text(text, model)

    # Mirror the radio selection into state so render/analyze see it.
    model_choice.change(fn=lambda x: x, inputs=[model_choice], outputs=[current_model])

    # The button and pressing Enter in the textbox trigger the same analysis.
    for trigger in (analyze_btn.click, input_text.submit):
        trigger(
            fn=update_and_analyze,
            inputs=[input_text, current_model],
            outputs=[features_state, dashboard],
        )


if __name__ == "__main__":
    demo.launch(share=False)