File size: 13,087 Bytes
f85532f
 
7e6371a
f85532f
 
 
 
 
 
 
 
 
 
94ca202
f85532f
 
 
94ca202
f85532f
 
 
 
 
 
 
 
94ca202
7e6371a
f85532f
 
 
 
 
 
 
7e6371a
f85532f
 
 
 
 
 
 
7e6371a
f85532f
 
 
 
 
 
 
7e6371a
f85532f
 
 
94ca202
7e6371a
 
 
 
 
 
 
 
 
 
 
 
 
 
574ab91
7e6371a
 
 
 
 
 
 
 
 
 
9186441
f85532f
7e6371a
9186441
f85532f
9186441
f85532f
deaf693
7e6371a
deaf693
7e6371a
9186441
7e6371a
f85532f
 
 
 
e7c964f
 
 
 
 
 
 
 
 
7e6371a
e7c964f
 
 
 
7e6371a
e7c964f
 
 
 
 
7e6371a
e7c964f
 
 
 
f85532f
7e6371a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7c964f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f85532f
e7c964f
f85532f
94ca202
 
 
 
f85532f
574ab91
f85532f
e7c964f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f85532f
e7c964f
 
 
 
94ca202
e7c964f
94ca202
f85532f
574ab91
94ca202
574ab91
e7c964f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574ab91
f85532f
 
 
574ab91
f85532f
 
e7c964f
 
9186441
e7c964f
9186441
e7c964f
f85532f
e7c964f
f85532f
94ca202
f85532f
deaf693
 
 
 
 
 
 
 
 
7e6371a
deaf693
574ab91
f85532f
 
574ab91
ad0839d
574ab91
ad0839d
94ca202
7e6371a
94ca202
ad0839d
574ab91
ad0839d
94ca202
ad0839d
 
 
 
 
574ab91
94ca202
ad0839d
94ca202
9186441
574ab91
ad0839d
 
 
 
 
 
 
 
574ab91
7e6371a
 
 
 
 
 
 
 
 
 
 
 
 
 
ad0839d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e6371a
ad0839d
 
e78ab36
ad0839d
 
 
 
e78ab36
 
 
 
ad0839d
 
 
 
 
e78ab36
ad0839d
 
7e6371a
ad0839d
 
7e6371a
ad0839d
7e6371a
574ab91
9186441
f85532f
94ca202
f85532f
 
94ca202
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional
import logging

# Module-level logger shared by every class and function in this file.
logger = logging.getLogger(__name__)
# Configure the root logger at INFO level (basicConfig is a no-op if the root
# logger already has handlers).
logging.basicConfig(level=logging.INFO)


@dataclass
class MarketingFeature:
    """Plain record pairing an SAE feature with marketing-oriented labels.

    Nothing is computed here; the dataclass only stores the values it is given.
    """

    # Numeric identifier of the feature (e.g. 35, 6680, 2 in MARKETING_FEATURES).
    feature_id: int
    # Short human-readable label for the feature.
    name: str
    # Grouping label such as "technical" or "seo".
    category: str
    # What the feature is meant to detect.
    description: str
    # Advice on how to read the feature's activation level.
    interpretation_guide: str
    # Model layer the feature belongs to.
    layer: int
    # Presumably the activation level considered significant; not read anywhere
    # else in this file -- NOTE(review): confirm intended use.
    threshold: float = 0.1


# Curated layer-20 features with marketing-oriented labels.
# NOTE(review): nothing else in this file reads MARKETING_FEATURES; the
# analyzer currently reports generic "Feature N" entries instead.
MARKETING_FEATURES = [
    MarketingFeature(
        feature_id=feature_id,
        name=label,
        category=group,
        description=summary,
        interpretation_guide=guide,
        layer=20,
    )
    for feature_id, label, group, summary, guide in (
        (
            35,
            "Technical Term Detector",
            "technical",
            "Detects technical and specialized terminology",
            "High activation indicates strong technical focus",
        ),
        (
            6680,
            "Compound Technical Terms",
            "technical",
            "Identifies complex technical concepts",
            "Consider simplifying language if activation is too high",
        ),
        (
            2,
            "SEO Keyword Detector",
            "seo",
            "Identifies potential SEO keywords",
            "High activation suggests strong SEO potential",
        ),
    )
]


class JumpReLUSAE(nn.Module):
    """Sparse autoencoder with a JumpReLU (per-feature thresholded ReLU) encoder.

    Args:
        d_model: Width of the activations being encoded.
        d_sae: Number of SAE features (dictionary size).

    All parameters start at zero; real weights are expected to be supplied via
    ``load_state_dict``, so the attribute names below must match the keys of the
    parameter file being loaded.
    """

    def __init__(self, d_model, d_sae):
        super().__init__()
        # Registration order is kept stable so ``parameters()`` order does not change.
        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
        self.threshold = nn.Parameter(torch.zeros(d_sae))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, input_acts):
        """Map activations to feature activations, zeroing any at or below threshold."""
        pre = input_acts @ self.W_enc + self.b_enc
        # JumpReLU gate: a feature fires only when it strictly exceeds its threshold.
        gate = pre > self.threshold
        return gate * torch.nn.functional.relu(pre)

    def decode(self, acts):
        """Reconstruct activations from feature activations."""
        return self.b_dec + acts @ self.W_dec

    def forward(self, acts):
        """Encode then decode ``acts``, returning the reconstruction."""
        return self.decode(self.encode(acts))


class MarketingAnalyzer:
    """Run Gemma-2-2B over marketing text and report its most active SAE features.

    The output of one decoder layer is captured with a forward hook, encoded by
    the Gemma Scope JumpReLU SAE for that layer (downloaded from
    ``google/gemma-scope-2b-pt-res``), and the features with the highest
    activation over the text's tokens are reported with simple interpretations.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch.set_grad_enabled(False)  # Avoid memory issues
        # Loaded SAEs keyed by layer. One SAE holds every feature of its layer,
        # so a single download/load serves all features of that layer.
        self._sae_cache = {}
        self._initialize_model()

    def _initialize_model(self):
        """Load the Gemma-2-2B causal LM and its tokenizer.

        Raises:
            Exception: any error from loading is logged and re-raised.
        """
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                "google/gemma-2-2b", device_map="auto"
            )
            self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
            self.model.eval()
            logger.info("Model initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def _load_sae(self, feature_id: Optional[int] = None, layer: int = 20):
        """Load the 16k-wide residual SAE for ``layer``, caching it per layer.

        Args:
            feature_id: Feature the caller is interested in; only used to make
                the error log specific (every feature of a layer shares one SAE).
            layer: Decoder layer whose SAE to load.

        Returns:
            The ``JumpReLUSAE`` on ``self.device``, or None if loading failed.
            Failures are not cached, so a later call retries.
        """
        cache = getattr(self, "_sae_cache", None)
        if cache is None:  # instance constructed without __init__
            cache = self._sae_cache = {}
        if layer in cache:
            return cache[layer]

        try:
            path = hf_hub_download(
                repo_id="google/gemma-scope-2b-pt-res",
                filename=f"layer_{layer}/width_16k/average_l0_71/params.npz",
                force_download=False,
            )
            # NpzFile keeps the archive open; the context manager closes it.
            with np.load(path) as params:
                d_model, d_sae = params["W_enc"].shape
                sae = JumpReLUSAE(d_model, d_sae).to(self.device)
                sae_params = {
                    k: torch.from_numpy(v).to(self.device) for k, v in params.items()
                }
            sae.load_state_dict(sae_params)
        except Exception as e:
            logger.error(f"Error loading SAE for feature {feature_id}: {str(e)}")
            return None

        cache[layer] = sae
        return sae

    def _gather_activations(self, text: str, layer: int):
        """Run the model on ``text`` and capture the output of decoder layer ``layer``.

        Returns:
            ``(hidden_states, inputs)``: the tensor captured by the hook and the
            tokenizer encoding of ``text``.
        """
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        target_act = None

        def hook(mod, hook_inputs, outputs):
            nonlocal target_act
            # The hidden state is the first element when the layer returns a
            # tuple; NOTE(review): some transformers versions may return a bare
            # tensor, which is accepted as-is.
            target_act = outputs[0] if isinstance(outputs, tuple) else outputs
            return outputs

        handle = self.model.model.layers[layer].register_forward_hook(hook)
        try:
            with torch.no_grad():
                self.model(**inputs)
        finally:
            # Always detach the hook so a failed forward pass cannot leave it
            # registered and firing on later calls.
            handle.remove()

        return target_act, inputs

    def _encode_text(self, text: str, sae, layer: int = 20):
        """Return ``sae``'s feature activations for ``text`` at ``layer``, BOS position dropped."""
        activations, _ = self._gather_activations(text, layer)
        sae_acts = sae.encode(activations.to(torch.float32))
        return sae_acts[:, 1:]  # Skip BOS token

    @staticmethod
    def _per_feature_stats(sae_acts):
        """Reduce SAE activations to per-feature mean and max over all positions.

        Args:
            sae_acts: Tensor whose last dimension indexes SAE features
                (``(batch, tokens, d_sae)`` as produced by ``_encode_text``).

        Returns:
            ``(means, maxes)``: two 1-D tensors of length ``d_sae``. Both are all
            zeros when there are no token positions (e.g. text that tokenizes to
            BOS only).
        """
        flat = sae_acts.reshape(-1, sae_acts.shape[-1])
        if flat.shape[0] == 0:
            zeros = flat.new_zeros(flat.shape[-1])
            return zeros, zeros.clone()
        return flat.mean(dim=0), flat.amax(dim=0)

    @classmethod
    def _rank_features(cls, sae_acts, top_k: int = 3) -> List[Dict]:
        """Return the ``top_k`` features with the highest max activation.

        Ties are broken by lower feature index (stable sort).

        Returns:
            List of dicts with ``feature_id`` (int), ``mean_activation`` and
            ``max_activation`` (floats), in descending order of max activation.
        """
        means, maxes = cls._per_feature_stats(sae_acts)
        order = torch.sort(maxes, descending=True, stable=True).indices[: max(top_k, 0)]
        return [
            {
                "feature_id": int(idx),
                "mean_activation": float(means[idx]),
                "max_activation": float(maxes[idx]),
            }
            for idx in order.tolist()
        ]

    def _get_feature_activations(
        self, text: str, sae, layer: int = 20, feature_id: Optional[int] = None
    ):
        """Return ``(mean, max)`` activation for ``text``.

        Args:
            text: Content to encode.
            sae: SAE used to encode the layer activations.
            layer: Decoder layer to read.
            feature_id: If given, statistics are for that single feature;
                otherwise they are taken over all features (previous behaviour).

        Returns:
            ``(mean_activation, max_activation)`` as floats; ``(0.0, 0.0)`` when
            there is nothing to aggregate.
        """
        sae_acts = self._encode_text(text, sae, layer)
        if feature_id is not None:
            sae_acts = sae_acts[..., feature_id]

        if sae_acts.numel() == 0:
            return 0.0, 0.0
        return float(sae_acts.mean()), float(sae_acts.max())

    def analyze_content(self, text: str, top_k: int = 3) -> Dict:
        """Analyze ``text`` and report its most active layer-20 SAE features.

        The text is run through the model once, every position after BOS is
        encoded with the layer-20 SAE, and features are ranked by their max
        activation over those positions.

        Args:
            text: Content to analyze.
            top_k: Number of features to report (default 3).

        Returns:
            Dict with ``"text"``, ``"features"`` (feature_id -> details),
            ``"categories"`` (category -> list of feature details) and
            ``"recommendations"`` (list of str). ``"features"`` stays empty if
            the SAE could not be loaded.

        Raises:
            Exception: errors from running the model/SAE are logged and re-raised.
        """
        results = {
            "text": text,
            "features": {},
            "categories": {},
            "recommendations": [],
        }
        # Layer 20 matches the "20-gemmascope-res-16k" Neuronpedia dashboards.
        layer = 20

        try:
            sae = self._load_sae(layer=layer)
            if sae is None:
                return results

            # One forward pass + one SAE encode serves every feature.
            sae_acts = self._encode_text(text, sae, layer)
            top_features = self._rank_features(sae_acts, top_k)

            for feature_data in top_features:
                feature_id = feature_data["feature_id"]

                # Get neuronpedia data if available (this would be a placeholder)
                feature_name = f"Feature {feature_id}"
                feature_category = "neural"  # Default category

                feature_result = {
                    "name": feature_name,
                    "category": feature_category,
                    "activation_score": feature_data["mean_activation"],
                    "max_activation": feature_data["max_activation"],
                    "interpretation": self._interpret_activation(
                        feature_data["mean_activation"], feature_id
                    ),
                }

                results["features"][feature_id] = feature_result
                results["categories"].setdefault(feature_category, []).append(
                    feature_result
                )

            # Generate recommendations based on activations
            if top_features:
                max_activation = max(f["max_activation"] for f in top_features)
                if max_activation > 0.8:
                    results["recommendations"].append(
                        f"Strong activation detected in feature {top_features[0]['feature_id']}. "
                        "Consider exploring this aspect further."
                    )
                elif max_activation < 0.3:
                    results["recommendations"].append(
                        "Low feature activations overall. Content might benefit from more distinctive elements."
                    )

        except Exception as e:
            logger.error(f"Error analyzing content: {str(e)}")
            raise

        return results

    def _interpret_activation(self, activation: float, feature_id: int) -> str:
        """Map an activation value to a coarse strength description (>0.8, >0.5, else)."""
        if activation > 0.8:
            return f"Very strong activation of feature {feature_id}"
        elif activation > 0.5:
            return f"Moderate activation of feature {feature_id}"
        else:
            return f"Limited activation of feature {feature_id}"


def create_gradio_interface():
    """Build the Gradio Blocks UI for the marketing content analyzer.

    If ``MarketingAnalyzer`` cannot be constructed, a minimal ``gr.Interface``
    that always returns an error message is returned instead.

    Returns:
        A Gradio app object exposing ``launch()``.
    """
    try:
        analyzer = MarketingAnalyzer()
    except Exception as e:
        logger.error(f"Failed to initialize analyzer: {str(e)}")
        return gr.Interface(
            fn=lambda x: "Error: Failed to initialize model. Please check authentication.",
            inputs=gr.Textbox(),
            outputs=gr.Textbox(),
            title="Marketing Content Analyzer (Error)",
            description="Failed to initialize.",
        )

    def analyze(text):
        """Run the analyzer and render its results as Markdown.

        Returns:
            ``(markdown, dashboard_url, feature_id)``; ``dashboard_url`` and
            ``feature_id`` are None when the analysis produced no features
            (e.g. the SAE could not be loaded).
        """
        results = analyzer.analyze_content(text)

        output = "# Content Analysis Results\n\n"

        output += "## Category Scores\n"
        for category, features in results["categories"].items():
            if features:
                avg_score = np.mean([f["activation_score"] for f in features])
                output += f"**{category.title()}**: {avg_score:.2f}\n"

        output += "\n## Feature Details\n"
        for feature_id, feature in results["features"].items():
            output += f"\n### {feature['name']} (Feature {feature_id})\n"
            output += f"**Score**: {feature['activation_score']:.2f}\n\n"
            output += f"**Interpretation**: {feature['interpretation']}\n\n"
            # Add feature explanation from Neuronpedia reference
            output += f"[View feature details on Neuronpedia](https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id})\n\n"

        if results["recommendations"]:
            output += "\n## Recommendations\n"
            for rec in results["recommendations"]:
                output += f"- {rec}\n"

        # max() on an empty sequence raises ValueError; report instead of crashing.
        if not results["features"]:
            output += "\nNo feature activations were found for this content.\n"
            return output, None, None

        feature_id = max(
            results["features"].items(), key=lambda x: x[1]["activation_score"]
        )[0]

        # Build dashboard URL for the highest activating feature
        dashboard_url = f"https://neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k/{feature_id}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

        return output, dashboard_url, feature_id

    with gr.Blocks(
        theme=gr.themes.Default(
            font=[gr.themes.GoogleFont("Open Sans"), "Arial", "sans-serif"],
            primary_hue="indigo",
            secondary_hue="blue",
            neutral_hue="gray",
        )
    ) as interface:
        gr.Markdown("# Marketing Content Analyzer")
        gr.Markdown(
            "Analyze your marketing content using Gemma Scope's neural features"
        )

        with gr.Row():
            with gr.Column(scale=1):
                input_text = gr.Textbox(
                    lines=5,
                    placeholder="Enter your marketing content here...",
                    label="Marketing Content",
                )
                analyze_btn = gr.Button("Analyze", variant="primary")
                gr.Examples(
                    examples=[
                        "WordLift is an AI-powered SEO tool",
                        "Our advanced machine learning algorithms optimize your content",
                        "Simple and effective website optimization",
                    ],
                    inputs=input_text,
                )

            with gr.Column(scale=2):
                output_text = gr.Markdown(label="Analysis Results")
                with gr.Group():
                    gr.Markdown("## Feature Dashboard")
                    feature_id_text = gr.Text(
                        label="Currently viewing feature", show_label=False
                    )
                    dashboard_frame = gr.HTML(
                        value="Analysis results will appear here",
                        label="Feature Dashboard",
                    )

        def update_dashboard(text):
            """Produce (markdown, dashboard HTML, caption) for the three outputs."""
            output, dashboard_url, feature_id = analyze(text)
            if dashboard_url is None:
                return (
                    output,
                    "No feature dashboard available for this content.",
                    "No feature selected",
                )
            return (
                output,
                f"<iframe src='{dashboard_url}' width='100%' height='600px' frameborder='0' style='border: 1px solid #eee; border-radius: 8px;'></iframe>",
                f"Currently viewing Feature {feature_id} - Most active feature in your content",
            )

        analyze_btn.click(
            fn=update_dashboard,
            inputs=input_text,
            outputs=[output_text, dashboard_frame, feature_id_text],
        )

    return interface


if __name__ == "__main__":
    # Script entry point: build the UI (or its error-only fallback) and serve it.
    iface = create_gradio_interface()
    iface.launch()