import os
import logging
from dataclasses import dataclass
from typing import Dict, List

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class MarketingFeature:
    """Structure to hold marketing-relevant feature information"""

    feature_id: int
    name: str
    category: str
    description: str
    interpretation_guide: str
    layer: int
    threshold: float = 0.1  # reserved for per-feature gating; the JumpReLU SAE applies its own learned threshold


# Define marketing-relevant features from Gemma Scope
MARKETING_FEATURES = [
    MarketingFeature(
        feature_id=35,
        name="Technical Term Detector",
        category="technical",
        description="Detects technical and specialized terminology",
        interpretation_guide="High activation indicates strong technical focus",
        layer=20,
    ),
    MarketingFeature(
        feature_id=6680,
        name="Compound Technical Terms",
        category="technical",
        description="Identifies complex technical concepts",
        interpretation_guide="Consider simplifying language if activation is too high",
        layer=20,
    ),
    MarketingFeature(
        feature_id=2,
        name="SEO Keyword Detector",
        category="seo",
        description="Identifies potential SEO keywords",
        interpretation_guide="High activation suggests strong SEO potential",
        layer=20,
    ),
    # Add more relevant features as we discover them
]


class MarketingAnalyzer:
    """Main class for analyzing marketing content using Gemma Scope"""

    def __init__(self, model_size: str = "2b"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_size = model_size  # needed later to select the matching SAE repo
        self._initialize_model(model_size)
        self._load_saes()

    def _initialize_model(self, model_size: str):
        """Initialize the Gemma 2 model and tokenizer"""
        try:
            # Gemma Scope SAEs were trained on Gemma 2, so load the matching base model
            model_name = f"google/gemma-2-{model_size}"

            # Access HF token from environment variable
            hf_token = os.environ.get("HF_TOKEN")
            if not hf_token:
                logger.warning("HF_TOKEN not found in environment variables")

            # Initialize model and tokenizer with token
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                token=hf_token,
                device_map="auto",  # Automatically handle device placement
            )
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
            self.model.eval()  # Set to evaluation mode

            logger.info(f"Initialized model: {model_name}")
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def _load_saes(self):
        """Load relevant SAEs from Gemma Scope"""
        self.saes = {}
        for feature in MARKETING_FEATURES:
            try:
                # Download the JumpReLU SAE parameters for the feature's layer
                path = hf_hub_download(
                    repo_id=f"google/gemma-scope-{self.model_size}-pt-res",
                    filename=f"layer_{feature.layer}/width_16k/average_l0_71/params.npz",
                )
                params = np.load(path)
                self.saes[feature.feature_id] = {
                    "params": {
                        k: torch.from_numpy(v).to(self.device)
                        for k, v in params.items()
                    },
                    "feature": feature,
                }
                logger.info(f"Loaded SAE for feature {feature.feature_id}")
            except Exception as e:
                logger.error(
                    f"Error loading SAE for feature {feature.feature_id}: {str(e)}"
                )
                continue

    def analyze_content(self, text: str) -> Dict:
        """Analyze marketing content using the loaded SAEs"""
        results = {
            "text": text,
            "features": {},
            "categories": {},
            "recommendations": [],
        }

        try:
            # Get model activations
            inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)

            # Analyze each feature
            for feature_id, sae_data in self.saes.items():
                feature = sae_data["feature"]
                # hidden_states[0] is the embedding output, so the residual
                # stream after layer N sits at index N + 1
                layer_output = outputs.hidden_states[feature.layer + 1]

                # Apply the SAE, then isolate the single feature we care about
                sae_acts = self._apply_sae(layer_output, sae_data["params"])
                activations = sae_acts[0, :, feature.feature_id]

                # Record results
                feature_result = {
                    "name": feature.name,
                    "category": feature.category,
                    "activation_score": float(activations.mean()),
                    "max_activation": float(activations.max()),
                    "interpretation": self._interpret_activation(
                        activations, feature
                    ),
                }
                results["features"][feature_id] = feature_result

                # Aggregate by category
                if feature.category not in results["categories"]:
                    results["categories"][feature.category] = []
                results["categories"][feature.category].append(feature_result)

            # Generate recommendations
            results["recommendations"] = self._generate_recommendations(results)
        except Exception as e:
            logger.error(f"Error analyzing content: {str(e)}")
            raise

        return results

    def _apply_sae(
        self,
        activations: torch.Tensor,
        sae_params: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Apply a JumpReLU SAE to get feature activations"""
        # Encode, then gate with the SAE's learned per-feature threshold
        # (the JumpReLU activation shipped in the Gemma Scope params)
        pre_acts = activations @ sae_params["W_enc"] + sae_params["b_enc"]
        mask = pre_acts > sae_params["threshold"]
        acts = mask * torch.nn.functional.relu(pre_acts)
        return acts

    def _interpret_activation(
        self, activations: torch.Tensor, feature: MarketingFeature
    ) -> str:
        """Interpret activation patterns for a feature"""
        mean_activation = float(activations.mean())
        if mean_activation > 0.8:
            return f"Very strong presence of {feature.name.lower()}"
        elif mean_activation > 0.5:
            return f"Moderate presence of {feature.name.lower()}"
        else:
            return f"Limited presence of {feature.name.lower()}"

    def _generate_recommendations(self, results: Dict) -> List[str]:
        """Generate content recommendations based on the analysis"""
        recommendations = []

        # Analyze technical complexity (guard against no technical features loading)
        tech_scores = [
            f["activation_score"]
            for f in results["features"].values()
            if f["category"] == "technical"
        ]
        if tech_scores:
            tech_score = np.mean(tech_scores)
            if tech_score > 0.8:
                recommendations.append(
                    "Consider simplifying technical language for a broader audience"
                )
            elif tech_score < 0.3:
                recommendations.append(
                    "Could benefit from more specific technical details"
                )

        # Add more recommendation logic as needed
        return recommendations


def create_gradio_interface():
    """Create the Gradio interface for marketing analysis"""
    try:
        analyzer = MarketingAnalyzer()
    except Exception as e:
        logger.error(f"Failed to initialize analyzer: {str(e)}")
        # Provide a graceful fallback interface instead of crashing
        return gr.Interface(
            fn=lambda x: "Error: Failed to initialize model. Please check authentication.",
            inputs=gr.Textbox(),
            outputs=gr.Textbox(),
            title="Marketing Content Analyzer (Error)",
            description="Failed to initialize. Please check if HF_TOKEN is properly set.",
        )

    def analyze(text):
        results = analyzer.analyze_content(text)

        # Format results for display
        output = "Content Analysis Results\n\n"

        # Overall category scores
        output += "Category Scores:\n"
        for category, features in results["categories"].items():
            avg_score = np.mean([f["activation_score"] for f in features])
            output += f"{category.title()}: {avg_score:.2f}\n"

        # Feature details
        output += "\nFeature Details:\n"
        for feature_id, feature in results["features"].items():
            output += f"\n{feature['name']}:\n"
            output += f"Score: {feature['activation_score']:.2f}\n"
            output += f"Interpretation: {feature['interpretation']}\n"

        # Recommendations
        output += "\nRecommendations:\n"
        for rec in results["recommendations"]:
            output += f"- {rec}\n"

        return output

    iface = gr.Interface(
        fn=analyze,
        inputs=gr.Textbox(
            lines=5,
            placeholder="Enter your marketing content here...",
        ),
        outputs=gr.Textbox(),
        title="Marketing Content Analyzer",
        description="Analyze your marketing content using Gemma Scope's neural features",
        examples=[
            ["WordLift is an AI-powered SEO tool"],
            ["Our advanced machine learning algorithms optimize your content"],
            ["Simple and effective website optimization"],
        ],
    )

    return iface


if __name__ == "__main__":
    iface = create_gradio_interface()
    iface.launch()
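
# --- Usage note (assumptions, not from the original script) ---
# A minimal sketch of how this app would be launched, assuming it is saved as
# app.py and that your Hugging Face account has accepted the Gemma license:
#
#   export HF_TOKEN=hf_your_token_here      # hypothetical token value
#   pip install gradio torch transformers huggingface_hub numpy
#   python app.py
#
# Gradio then serves the interface locally (http://127.0.0.1:7860 by default).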