Spaces:
Running
Running
import gradio as gr | |
import torch | |
import torch.nn as nn | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from huggingface_hub import hf_hub_download | |
import numpy as np | |
from dataclasses import dataclass | |
from typing import List, Dict, Optional | |
import logging | |
# Configure root logging at INFO once at import time, then grab this module's
# named logger for use throughout the app.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
@dataclass
class MarketingFeature:
    """Structure to hold marketing-relevant feature information.

    Instances are created with keyword arguments (see ``MARKETING_FEATURES``),
    which needs the ``@dataclass``-generated ``__init__``; without the
    decorator the class accepts no constructor arguments at all.

    Attributes:
        feature_id: SAE feature id the entry describes.
        name: Human-readable feature name.
        category: Grouping label such as "technical" or "seo".
        description: What the feature is meant to detect.
        interpretation_guide: How to read the feature's activation level.
        layer: Model layer whose SAE contains the feature.
        threshold: Activation threshold, 0.1 unless overridden.
    """

    feature_id: int
    name: str
    category: str
    description: str
    interpretation_guide: str
    layer: int
    threshold: float = 0.1
# Hand-picked layer-20 SAE features with marketing-oriented descriptions.
# NOTE(review): nothing else in this file reads MARKETING_FEATURES;
# MarketingAnalyzer.analyze_content scores randomly sampled feature ids instead.
MARKETING_FEATURES = [
    MarketingFeature(
        layer=20,
        feature_id=35,
        category="technical",
        name="Technical Term Detector",
        description="Detects technical and specialized terminology",
        interpretation_guide="High activation indicates strong technical focus",
    ),
    MarketingFeature(
        layer=20,
        feature_id=6680,
        category="technical",
        name="Compound Technical Terms",
        description="Identifies complex technical concepts",
        interpretation_guide="Consider simplifying language if activation is too high",
    ),
    MarketingFeature(
        layer=20,
        feature_id=2,
        category="seo",
        name="SEO Keyword Detector",
        description="Identifies potential SEO keywords",
        interpretation_guide="High activation suggests strong SEO potential",
    ),
]
class JumpReLUSAE(nn.Module):
    """JumpReLU sparse autoencoder using the Gemma Scope parameter names.

    All parameters start as zeros; real weights are expected to be loaded
    afterwards via ``load_state_dict`` with keys ``W_enc``, ``W_dec``,
    ``threshold``, ``b_enc`` and ``b_dec``.
    """

    def __init__(self, d_model, d_sae):
        super().__init__()

        def zero_param(*shape):
            return nn.Parameter(torch.zeros(*shape))

        # Registration order is W_enc, W_dec, threshold, b_enc, b_dec, which
        # fixes the ordering of state_dict() / parameters().
        self.W_enc = zero_param(d_model, d_sae)
        self.W_dec = zero_param(d_sae, d_model)
        self.threshold = zero_param(d_sae)
        self.b_enc = zero_param(d_sae)
        self.b_dec = zero_param(d_model)

    def encode(self, input_acts):
        """Project inputs to feature space and apply the JumpReLU gate.

        A feature keeps its ReLU'd pre-activation only where that
        pre-activation is strictly above its learned threshold; elsewhere it
        is zero.
        """
        pre_acts = torch.matmul(input_acts, self.W_enc) + self.b_enc
        above_threshold = pre_acts > self.threshold
        return torch.nn.functional.relu(pre_acts) * above_threshold

    def decode(self, acts):
        """Map feature activations back to the model's activation space."""
        return torch.matmul(acts, self.W_dec) + self.b_dec

    def forward(self, acts):
        """Encode then decode ``acts``, returning the reconstruction."""
        return self.decode(self.encode(acts))
class MarketingAnalyzer:
    """Score text against Gemma Scope SAE features of ``google/gemma-2-2b``.

    The language model and tokenizer are loaded on construction. SAE weights
    are downloaded from the Hugging Face Hub on first use and cached per layer.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch.set_grad_enabled(False)  # Avoid memory issues
        # layer -> loaded SAE. One params.npz holds every feature of a layer,
        # so it only needs to be downloaded and loaded once per layer.
        self._sae_cache = {}
        self._initialize_model()

    def _initialize_model(self):
        """Load the Gemma-2-2B causal LM and tokenizer; log and re-raise on failure."""
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                "google/gemma-2-2b", device_map="auto"
            )
            self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
            self.model.eval()
            logger.info("Model initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def _load_sae(self, feature_id: Optional[int] = None, layer: int = 20):
        """Return the 16k-wide residual SAE for ``layer``, or None on failure.

        Every feature of a layer lives in the same parameter file, so
        ``feature_id`` is only used to label the error message. Successful
        loads are cached per layer.
        """
        sae = self._sae_cache.get(layer)
        if sae is not None:
            return sae
        try:
            path = hf_hub_download(
                repo_id="google/gemma-scope-2b-pt-res",
                filename=f"layer_{layer}/width_16k/average_l0_71/params.npz",
                force_download=False,
            )
            # NpzFile holds the archive open; close it once tensors are copied out.
            with np.load(path) as params:
                d_model, d_sae = params["W_enc"].shape
                sae = JumpReLUSAE(d_model, d_sae).to(self.device)
                sae.load_state_dict(
                    {k: torch.from_numpy(v).to(self.device) for k, v in params.items()}
                )
        except Exception as e:
            target = f"feature {feature_id}" if feature_id is not None else f"layer {layer}"
            logger.error(f"Error loading SAE for {target}: {str(e)}")
            return None
        self._sae_cache[layer] = sae
        return sae

    def _gather_activations(self, text: str, layer: int):
        """Run the LM on ``text`` and capture the output of decoder layer ``layer``.

        Returns:
            ``(hidden_states, inputs)``, where ``inputs`` is the tokenizer output.
        """
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        target_act = None

        def hook(mod, hook_inputs, outputs):
            nonlocal target_act
            # Depending on the transformers version a decoder layer may return a
            # tuple (hidden states first) or the hidden-states tensor itself.
            target_act = outputs[0] if isinstance(outputs, tuple) else outputs

        handle = self.model.model.layers[layer].register_forward_hook(hook)
        try:
            with torch.no_grad():
                self.model(**inputs)
        finally:
            # Always detach the hook so a failed forward pass doesn't leave it behind.
            handle.remove()
        return target_act, inputs

    def _encode_text(self, text: str, sae, layer: int = 20):
        """Return SAE activations for ``text`` at ``layer`` without the BOS position.

        Assumes the captured hidden states are (batch, seq, d_model), so the
        result is (batch, seq - 1, d_sae).
        """
        activations, _ = self._gather_activations(text, layer)
        # The SAE's weights live on self.device in float32.
        sae_acts = sae.encode(activations.to(device=self.device, dtype=torch.float32))
        return sae_acts[:, 1:]  # Skip BOS token

    @staticmethod
    def _summarize_activations(sae_acts, feature_id: Optional[int] = None):
        """Return ``(mean, max)`` of ``sae_acts`` as Python floats.

        With ``feature_id`` only that feature (last axis) is summarized;
        otherwise all features are. Empty input, e.g. text that tokenizes to BOS
        only, yields ``(0.0, 0.0)``.
        """
        if feature_id is not None:
            sae_acts = sae_acts[..., int(feature_id)]
        if sae_acts.numel() == 0:
            return 0.0, 0.0
        return float(sae_acts.mean()), float(sae_acts.max())

    def _get_feature_activations(
        self, text: str, sae, layer: int = 20, feature_id: Optional[int] = None
    ):
        """Get ``(mean, max)`` activation of one SAE feature on ``text``.

        ``feature_id`` selects the feature; when omitted, statistics are taken
        over all SAE features (the previous behaviour).
        """
        return self._summarize_activations(
            self._encode_text(text, sae, layer), feature_id
        )

    def _score_sampled_features(
        self, text: str, sample_size: int, top_k: int, layer: int = 20
    ) -> List[Dict]:
        """Score a random sample of SAE features on ``text``.

        Returns the ``top_k`` samples sorted by max activation (descending),
        or ``[]`` when the SAE could not be loaded.
        """
        sae = self._load_sae(layer=layer)
        if sae is None:
            return []  # failure already logged by _load_sae
        # Valid feature ids are 0 .. d_sae - 1 (the SAE's latent width).
        d_sae = sae.W_enc.shape[1]
        sampled = np.random.choice(d_sae, min(sample_size, d_sae), replace=False)
        # The LM forward pass and SAE encode do not depend on the feature, so
        # run them once and index each sampled feature from the result.
        sae_acts = self._encode_text(text, sae, layer)
        scored = []
        for fid in sampled:
            mean_activation, max_activation = self._summarize_activations(
                sae_acts, int(fid)
            )
            scored.append(
                {
                    "feature_id": int(fid),
                    "mean_activation": mean_activation,
                    "max_activation": max_activation,
                }
            )
        scored.sort(key=lambda f: f["max_activation"], reverse=True)
        return scored[:top_k]

    @staticmethod
    def _recommend(top_features: List[Dict]) -> List[str]:
        """Build recommendations from features sorted by max activation (descending)."""
        if not top_features:
            return []
        max_activation = max(f["max_activation"] for f in top_features)
        if max_activation > 0.8:
            return [
                f"Strong activation detected in feature {top_features[0]['feature_id']}. "
                "Consider exploring this aspect further."
            ]
        if max_activation < 0.3:
            return [
                "Low feature activations overall. Content might benefit from more distinctive elements."
            ]
        return []

    def analyze_content(
        self, text: str, *, sample_size: int = 50, top_k: int = 3
    ) -> Dict:
        """Analyze content and find the most relevant features.

        A random sample of ``sample_size`` layer-20 SAE features is scored on
        ``text``, and the ``top_k`` with the highest max activation are reported.

        Returns:
            Dict with keys ``text``, ``features`` (feature_id -> details),
            ``categories`` (category -> list of feature details) and
            ``recommendations`` (list of strings). ``features`` is empty when
            the SAE could not be loaded.

        Raises:
            Any exception from tokenization or the model forward pass, after
            logging it.
        """
        results = {
            "text": text,
            "features": {},
            "categories": {},
            "recommendations": [],
        }
        try:
            top_features = self._score_sampled_features(text, sample_size, top_k)
            for feature_data in top_features:
                feature_id = feature_data["feature_id"]
                # Neuronpedia labels are not fetched; use generic placeholders.
                feature_name = f"Feature {feature_id}"
                feature_category = "neural"  # Default category
                feature_result = {
                    "name": feature_name,
                    "category": feature_category,
                    "activation_score": feature_data["mean_activation"],
                    "max_activation": feature_data["max_activation"],
                    "interpretation": self._interpret_activation(
                        feature_data["mean_activation"], feature_id
                    ),
                }
                results["features"][feature_id] = feature_result
                results["categories"].setdefault(feature_category, []).append(
                    feature_result
                )
            results["recommendations"].extend(self._recommend(top_features))
        except Exception as e:
            logger.error(f"Error analyzing content: {str(e)}")
            raise
        return results

    def _interpret_activation(self, activation: float, feature_id: int) -> str:
        """Describe an activation level: >0.8 very strong, >0.5 moderate, else limited."""
        if activation > 0.8:
            return f"Very strong activation of feature {feature_id}"
        elif activation > 0.5:
            return f"Moderate activation of feature {feature_id}"
        else:
            return f"Limited activation of feature {feature_id}"
def create_gradio_interface(): | |
try: | |
analyzer = MarketingAnalyzer() | |
except Exception as e: | |
logger.error(f"Failed to initialize analyzer: {str(e)}") | |
return gr.Interface( | |
fn=lambda x: "Error: Failed to initialize model. Please check authentication.", | |
inputs=gr.Textbox(), | |
outputs=gr.Textbox(), | |
title="Marketing Content Analyzer (Error)", | |
description="Failed to initialize.", | |
) | |
def analyze(text): | |
results = analyzer.analyze_content(text) | |
output = "# Content Analysis Results\n\n" | |
output += "## Category Scores\n" | |
for category, features in results["categories"].items(): | |
if features: | |
avg_score = np.mean([f["activation_score"] for f in features]) | |
output += f"**{category.title()}**: {avg_score:.2f}\n" | |
output += "\n## Feature Details\n" | |
for feature_id, feature in results["features"].items(): | |
output += f"\n### {feature['name']} (Feature {feature_id})\n" | |
output += f"**Score**: {feature['activation_score']:.2f}\n\n" | |
output += f"**Interpretation**: {feature['interpretation']}\n\n" | |
# Add feature explanation from Neuronpedia reference | |
output += f"[View feature details on Neuronpedia](https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id})\n\n" | |
if results["recommendations"]: | |
output += "\n## Recommendations\n" | |
for rec in results["recommendations"]: | |
output += f"- {rec}\n" | |
feature_id = max( | |
results["features"].items(), key=lambda x: x[1]["activation_score"] | |
)[0] | |
# Build dashboard URL for the highest activating feature | |
dashboard_url = f"https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300" | |
return output, dashboard_url, feature_id | |
with gr.Blocks( | |
theme=gr.themes.Default( | |
font=[gr.themes.GoogleFont("Open Sans"), "Arial", "sans-serif"], | |
primary_hue="indigo", | |
secondary_hue="blue", | |
neutral_hue="gray", | |
) | |
) as interface: | |
gr.Markdown("# Marketing Content Analyzer") | |
gr.Markdown( | |
"Analyze your marketing content using Gemma Scope's neural features" | |
) | |
with gr.Row(): | |
with gr.Column(scale=1): | |
input_text = gr.Textbox( | |
lines=5, | |
placeholder="Enter your marketing content here...", | |
label="Marketing Content", | |
) | |
analyze_btn = gr.Button("Analyze", variant="primary") | |
gr.Examples( | |
examples=[ | |
"WordLift is an AI-powered SEO tool", | |
"Our advanced machine learning algorithms optimize your content", | |
"Simple and effective website optimization", | |
], | |
inputs=input_text, | |
) | |
with gr.Column(scale=2): | |
output_text = gr.Markdown(label="Analysis Results") | |
with gr.Group(): | |
gr.Markdown("## Feature Dashboard") | |
feature_id_text = gr.Text( | |
label="Currently viewing feature", show_label=False | |
) | |
dashboard_frame = gr.HTML( | |
value="Analysis results will appear here", | |
label="Feature Dashboard", | |
) | |
def update_dashboard(text): | |
output, dashboard_url, feature_id = analyze(text) | |
return ( | |
output, | |
f"<iframe src='{dashboard_url}' width='100%' height='600px' frameborder='0' style='border: 1px solid #eee; border-radius: 8px;'></iframe>", | |
f"Currently viewing Feature {feature_id} - Most active feature in your content", | |
) | |
analyze_btn.click( | |
fn=update_dashboard, | |
inputs=input_text, | |
outputs=[output_text, dashboard_frame, feature_id_text], | |
) | |
return interface | |
if __name__ == "__main__":
    # Script entry point: build the UI (or its error fallback) and serve it.
    iface = create_gradio_interface()
    iface.launch()