Spaces:
Running
Running
update
Browse files
app.py
CHANGED
@@ -52,7 +52,7 @@ MARKETING_FEATURES = [
|
|
52 |
|
53 |
class MarketingAnalyzer:
|
54 |
"""Main class for analyzing marketing content using Gemma Scope"""
|
55 |
-
|
56 |
def __init__(self):
|
57 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
58 |
# Store model size as instance variable
|
@@ -64,17 +64,17 @@ class MarketingAnalyzer:
|
|
64 |
"""Initialize Gemma model and tokenizer"""
|
65 |
try:
|
66 |
model_name = f"google/gemma-{self.model_size}"
|
67 |
-
|
68 |
# Initialize model and tokenizer with token from environment
|
69 |
self.model = AutoModelForCausalLM.from_pretrained(
|
70 |
model_name,
|
71 |
device_map='auto'
|
72 |
)
|
73 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
74 |
-
|
75 |
self.model.eval()
|
76 |
logger.info(f"Initialized model: {model_name}")
|
77 |
-
|
78 |
except Exception as e:
|
79 |
logger.error(f"Error initializing model: {str(e)}")
|
80 |
raise
|
@@ -107,25 +107,25 @@ class MarketingAnalyzer:
|
|
107 |
'categories': {},
|
108 |
'recommendations': []
|
109 |
}
|
110 |
-
|
111 |
try:
|
112 |
# Get model activations
|
113 |
inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
|
114 |
with torch.no_grad():
|
115 |
outputs = self.model(**inputs, output_hidden_states=True)
|
116 |
-
|
117 |
# Analyze each feature
|
118 |
for feature_id, sae_data in self.saes.items():
|
119 |
feature = sae_data['feature']
|
120 |
layer_output = outputs.hidden_states[feature.layer]
|
121 |
-
|
122 |
# Apply SAE
|
123 |
activations = self._apply_sae(
|
124 |
layer_output,
|
125 |
sae_data['params'],
|
126 |
feature.threshold
|
127 |
)
|
128 |
-
|
129 |
# Skip BOS token and handle empty activations
|
130 |
activations = activations[:, 1:] # Skip BOS token
|
131 |
if activations.numel() > 0:
|
@@ -134,7 +134,7 @@ class MarketingAnalyzer:
|
|
134 |
else:
|
135 |
mean_activation = 0.0
|
136 |
max_activation = 0.0
|
137 |
-
|
138 |
# Record results
|
139 |
feature_result = {
|
140 |
'name': feature.name,
|
@@ -146,21 +146,21 @@ class MarketingAnalyzer:
|
|
146 |
feature
|
147 |
)
|
148 |
}
|
149 |
-
|
150 |
results['features'][feature_id] = feature_result
|
151 |
-
|
152 |
# Aggregate by category
|
153 |
if feature.category not in results['categories']:
|
154 |
results['categories'][feature.category] = []
|
155 |
results['categories'][feature.category].append(feature_result)
|
156 |
-
|
157 |
# Generate recommendations
|
158 |
results['recommendations'] = self._generate_recommendations(results)
|
159 |
-
|
160 |
except Exception as e:
|
161 |
logger.error(f"Error analyzing content: {str(e)}")
|
162 |
raise
|
163 |
-
|
164 |
return results
|
165 |
|
166 |
def _apply_sae(
|
@@ -191,18 +191,18 @@ class MarketingAnalyzer:
|
|
191 |
def _generate_recommendations(self, results: Dict) -> List[str]:
|
192 |
"""Generate content recommendations based on analysis"""
|
193 |
recommendations = []
|
194 |
-
|
195 |
try:
|
196 |
# Get technical features
|
197 |
tech_features = [
|
198 |
f for f in results['features'].values()
|
199 |
if f['category'] == 'technical'
|
200 |
]
|
201 |
-
|
202 |
# Calculate average technical score if we have features
|
203 |
if tech_features:
|
204 |
tech_score = np.mean([f['activation_score'] for f in tech_features])
|
205 |
-
|
206 |
if tech_score > 0.8:
|
207 |
recommendations.append(
|
208 |
"Consider simplifying technical language for broader audience"
|
@@ -213,7 +213,7 @@ class MarketingAnalyzer:
|
|
213 |
)
|
214 |
except Exception as e:
|
215 |
logger.error(f"Error generating recommendations: {str(e)}")
|
216 |
-
|
217 |
return recommendations
|
218 |
|
219 |
def create_gradio_interface():
|
@@ -229,42 +229,42 @@ def create_gradio_interface():
|
|
229 |
title="Marketing Content Analyzer (Error)",
|
230 |
description="Failed to initialize. Please check if HF_TOKEN is properly set."
|
231 |
)
|
232 |
-
|
233 |
def analyze(text):
|
234 |
results = analyzer.analyze_content(text)
|
235 |
-
|
236 |
# Format results for display
|
237 |
output = "Content Analysis Results\n\n"
|
238 |
-
|
239 |
# Overall category scores
|
240 |
output += "Category Scores:\n"
|
241 |
for category, features in results['categories'].items():
|
242 |
if features: # Check if we have features for this category
|
243 |
avg_score = np.mean([f['activation_score'] for f in features])
|
244 |
output += f"{category.title()}: {avg_score:.2f}\n"
|
245 |
-
|
246 |
# Feature details
|
247 |
output += "\nFeature Details:\n"
|
248 |
for feature_id, feature in results['features'].items():
|
249 |
output += f"\n{feature['name']}:\n"
|
250 |
output += f"Score: {feature['activation_score']:.2f}\n"
|
251 |
output += f"Interpretation: {feature['interpretation']}\n"
|
252 |
-
|
253 |
# Recommendations
|
254 |
if results['recommendations']:
|
255 |
output += "\nRecommendations:\n"
|
256 |
for rec in results['recommendations']:
|
257 |
output += f"- {rec}\n"
|
258 |
-
|
259 |
return output
|
260 |
-
|
261 |
# Create interface with custom theming
|
262 |
custom_theme = gr.themes.Soft(
|
263 |
primary_hue="indigo",
|
264 |
secondary_hue="blue",
|
265 |
neutral_hue="gray"
|
266 |
)
|
267 |
-
|
268 |
interface = gr.Interface(
|
269 |
fn=analyze,
|
270 |
inputs=gr.Textbox(
|
@@ -283,7 +283,7 @@ def create_gradio_interface():
|
|
283 |
theme=custom_theme
|
284 |
)
|
285 |
)
|
286 |
-
|
287 |
return interface
|
288 |
|
289 |
if __name__ == "__main__":
|
|
|
52 |
|
53 |
class MarketingAnalyzer:
|
54 |
"""Main class for analyzing marketing content using Gemma Scope"""
|
55 |
+
|
56 |
def __init__(self):
|
57 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
58 |
# Store model size as instance variable
|
|
|
64 |
"""Initialize Gemma model and tokenizer"""
|
65 |
try:
|
66 |
model_name = f"google/gemma-{self.model_size}"
|
67 |
+
|
68 |
# Initialize model and tokenizer with token from environment
|
69 |
self.model = AutoModelForCausalLM.from_pretrained(
|
70 |
model_name,
|
71 |
device_map='auto'
|
72 |
)
|
73 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
74 |
+
|
75 |
self.model.eval()
|
76 |
logger.info(f"Initialized model: {model_name}")
|
77 |
+
|
78 |
except Exception as e:
|
79 |
logger.error(f"Error initializing model: {str(e)}")
|
80 |
raise
|
|
|
107 |
'categories': {},
|
108 |
'recommendations': []
|
109 |
}
|
110 |
+
|
111 |
try:
|
112 |
# Get model activations
|
113 |
inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
|
114 |
with torch.no_grad():
|
115 |
outputs = self.model(**inputs, output_hidden_states=True)
|
116 |
+
|
117 |
# Analyze each feature
|
118 |
for feature_id, sae_data in self.saes.items():
|
119 |
feature = sae_data['feature']
|
120 |
layer_output = outputs.hidden_states[feature.layer]
|
121 |
+
|
122 |
# Apply SAE
|
123 |
activations = self._apply_sae(
|
124 |
layer_output,
|
125 |
sae_data['params'],
|
126 |
feature.threshold
|
127 |
)
|
128 |
+
|
129 |
# Skip BOS token and handle empty activations
|
130 |
activations = activations[:, 1:] # Skip BOS token
|
131 |
if activations.numel() > 0:
|
|
|
134 |
else:
|
135 |
mean_activation = 0.0
|
136 |
max_activation = 0.0
|
137 |
+
|
138 |
# Record results
|
139 |
feature_result = {
|
140 |
'name': feature.name,
|
|
|
146 |
feature
|
147 |
)
|
148 |
}
|
149 |
+
|
150 |
results['features'][feature_id] = feature_result
|
151 |
+
|
152 |
# Aggregate by category
|
153 |
if feature.category not in results['categories']:
|
154 |
results['categories'][feature.category] = []
|
155 |
results['categories'][feature.category].append(feature_result)
|
156 |
+
|
157 |
# Generate recommendations
|
158 |
results['recommendations'] = self._generate_recommendations(results)
|
159 |
+
|
160 |
except Exception as e:
|
161 |
logger.error(f"Error analyzing content: {str(e)}")
|
162 |
raise
|
163 |
+
|
164 |
return results
|
165 |
|
166 |
def _apply_sae(
|
|
|
191 |
def _generate_recommendations(self, results: Dict) -> List[str]:
|
192 |
"""Generate content recommendations based on analysis"""
|
193 |
recommendations = []
|
194 |
+
|
195 |
try:
|
196 |
# Get technical features
|
197 |
tech_features = [
|
198 |
f for f in results['features'].values()
|
199 |
if f['category'] == 'technical'
|
200 |
]
|
201 |
+
|
202 |
# Calculate average technical score if we have features
|
203 |
if tech_features:
|
204 |
tech_score = np.mean([f['activation_score'] for f in tech_features])
|
205 |
+
|
206 |
if tech_score > 0.8:
|
207 |
recommendations.append(
|
208 |
"Consider simplifying technical language for broader audience"
|
|
|
213 |
)
|
214 |
except Exception as e:
|
215 |
logger.error(f"Error generating recommendations: {str(e)}")
|
216 |
+
|
217 |
return recommendations
|
218 |
|
219 |
def create_gradio_interface():
|
|
|
229 |
title="Marketing Content Analyzer (Error)",
|
230 |
description="Failed to initialize. Please check if HF_TOKEN is properly set."
|
231 |
)
|
232 |
+
|
233 |
def analyze(text):
|
234 |
results = analyzer.analyze_content(text)
|
235 |
+
|
236 |
# Format results for display
|
237 |
output = "Content Analysis Results\n\n"
|
238 |
+
|
239 |
# Overall category scores
|
240 |
output += "Category Scores:\n"
|
241 |
for category, features in results['categories'].items():
|
242 |
if features: # Check if we have features for this category
|
243 |
avg_score = np.mean([f['activation_score'] for f in features])
|
244 |
output += f"{category.title()}: {avg_score:.2f}\n"
|
245 |
+
|
246 |
# Feature details
|
247 |
output += "\nFeature Details:\n"
|
248 |
for feature_id, feature in results['features'].items():
|
249 |
output += f"\n{feature['name']}:\n"
|
250 |
output += f"Score: {feature['activation_score']:.2f}\n"
|
251 |
output += f"Interpretation: {feature['interpretation']}\n"
|
252 |
+
|
253 |
# Recommendations
|
254 |
if results['recommendations']:
|
255 |
output += "\nRecommendations:\n"
|
256 |
for rec in results['recommendations']:
|
257 |
output += f"- {rec}\n"
|
258 |
+
|
259 |
return output
|
260 |
+
|
261 |
# Create interface with custom theming
|
262 |
custom_theme = gr.themes.Soft(
|
263 |
primary_hue="indigo",
|
264 |
secondary_hue="blue",
|
265 |
neutral_hue="gray"
|
266 |
)
|
267 |
+
|
268 |
interface = gr.Interface(
|
269 |
fn=analyze,
|
270 |
inputs=gr.Textbox(
|
|
|
283 |
theme=custom_theme
|
284 |
)
|
285 |
)
|
286 |
+
|
287 |
return interface
|
288 |
|
289 |
if __name__ == "__main__":
|