|
import torch |
|
import sys |
|
import subprocess |
|
|
|
def check_gpu_status():
    """Print interpreter, PyTorch, CUDA and ``nvidia-smi`` diagnostics to stdout.

    Reports the Python and PyTorch versions and whether CUDA is usable. When
    it is, also reports the CUDA version and the name and total memory (GiB)
    of device 0. Finally tries to run ``nvidia-smi`` and prints its output.
    A missing binary, a non-zero exit status or a timeout is reported as
    "nvidia-smi not available" instead of being raised.

    Returns:
        None. Output goes to stdout only.
    """
    cuda_ok = torch.cuda.is_available()

    print("Python version:", sys.version)
    print("PyTorch version:", torch.__version__)
    print("CUDA available:", cuda_ok)
    print("CUDA version:", torch.version.cuda if cuda_ok else "Not available")

    if cuda_ok:
        print("GPU Device:", torch.cuda.get_device_name(0))
        total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print("GPU Memory:", round(total_gib, 2), "GB")

    try:
        # The timeout stops a wedged driver from hanging module import forever;
        # errors="replace" keeps odd bytes in the output from raising.
        nvidia_smi = subprocess.check_output(["nvidia-smi"], timeout=30).decode(
            errors="replace"
        )
    except (OSError, subprocess.SubprocessError):
        # OSError covers a missing binary (FileNotFoundError). SubprocessError
        # covers CalledProcessError (non-zero exit) and TimeoutExpired.
        print("nvidia-smi not available")
    else:
        print("nvidia-smi output:")
        print(nvidia_smi)
|
|
|
|
|
# Run the GPU diagnostics once at import time, framed by opening/closing banners.
_GPU_STATUS_BANNERS = ("=== GPU Status Check ===", "======================")
print(_GPU_STATUS_BANNERS[0])
check_gpu_status()
print(_GPU_STATUS_BANNERS[-1])
|
|
|
|
|
import gradio as gr |
|
import easyocr |
|
from transformers import pipeline, DistilBertTokenizer, DistilBertForSequenceClassification |
|
import numpy as np |
|
from PIL import Image |
|
import json |
|
from compliance_rules import ComplianceRules |
|
|
|
|
|
# Report which accelerator PyTorch can see, then publish the choice as the
# module-level ``device`` string ("cuda" or "cpu").
_has_cuda = torch.cuda.is_available()
print(f"Is CUDA available: {_has_cuda}")
if not _has_cuda:
    print("Running on CPU")
else:
    _active_gpu = torch.cuda.current_device()
    print(f"CUDA device: {torch.cuda.get_device_name(_active_gpu)}")

device = "cpu" if not _has_cuda else "cuda"
print(f"Using device: {device}")
|
|
|
|
|
# One shared EasyOCR reader for English text. GPU mode is requested only when
# PyTorch reports CUDA. If that initialisation raises, retry on the CPU.
# The reader is built exactly once: rebuilding it afterwards would reload the
# models and discard the fallback decision made here.
try:
    reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
except Exception as e:
    print(f"Error initializing EasyOCR: {str(e)}")
    reader = easyocr.Reader(['en'], gpu=False)
    print("Falling back to CPU for EasyOCR")
else:
    print("EasyOCR initialized successfully")

# Rule source used by check_compliance (project-local ComplianceRules).
compliance_rules = ComplianceRules()
|
|
|
def extract_text_from_image(image):
    """Run OCR over ``image`` and return the recognised text as one string.

    The image is converted with ``np.array`` and passed to the module-level
    EasyOCR ``reader``. Element ``[1]`` of each detection (the recognised
    text) is collected in detection order and joined with single spaces.

    Returns:
        The joined text. If anything raises, the error is printed and the
        fixed string "Error extracting text from image" is returned instead.
    """
    try:
        detections = reader.readtext(np.array(image))
        fragments = [detection[1] for detection in detections]
        return " ".join(fragments)
    except Exception as exc:
        print(f"Error in text extraction: {str(exc)}")
        return "Error extracting text from image"
|
|
|
def check_compliance(text, rules_provider=None):
    """Check ``text`` for compliance across all regions.

    Matching is a case-insensitive substring test. A prohibited term is hit
    when the term itself, or any of its ``variations``, occurs anywhere in the
    text. A required disclaimer is satisfied when any of its ``text`` entries
    occurs anywhere in the text.

    Args:
        text: Ad copy to inspect. ``None`` is treated as empty text.
        rules_provider: Object providing ``get_all_rules()`` and
            ``calculate_risk_score(violations, warnings, region)``. Defaults to
            the module-level ``compliance_rules`` instance, so existing callers
            are unaffected.

    Returns:
        A dict with:
          * ``compliant``: False once any prohibited term is found. Missing
            disclaimers only add warnings and do not change this flag.
          * ``violations``: one dict per prohibited-term hit, with keys
            ``region``, ``type``, ``term`` (lower-cased) and ``severity``.
          * ``warnings``: one dict per missing disclaimer, with keys
            ``region``, ``type``, ``disclaimer_type`` and ``severity``.
          * ``channel_risks``: for ``email``, ``social`` and ``print``, an
            accumulated ``score`` and a list of human-readable ``details``.
    """
    if rules_provider is None:
        rules_provider = compliance_rules
    rules = rules_provider.get_all_rules()

    report = {
        "compliant": True,
        "violations": [],
        "warnings": [],
        "channel_risks": {
            "email": {"score": 0, "details": []},
            "social": {"score": 0, "details": []},
            "print": {"score": 0, "details": []},
        },
    }
    channel_risks = report["channel_risks"]

    # Lower-case once instead of on every membership test below.
    haystack = (text or "").lower()

    for region, region_rules in rules.items():
        for term_info in region_rules["prohibited_terms"]:
            term = term_info["term"].lower()
            # A rule without a "variations" entry matches only the term itself.
            variations = term_info.get("variations", ())
            if not (term in haystack or any(var.lower() in haystack for var in variations)):
                continue

            report["compliant"] = False
            report["violations"].append({
                "region": region,
                "type": "prohibited_term",
                "term": term,
                "severity": term_info["severity"],
            })

            violation = f"{region}: Prohibited term '{term}' found"
            for channel, risk in channel_risks.items():
                risk["score"] += rules_provider.calculate_risk_score([violation], [], region)
                risk["details"].append(
                    f"Prohibited term '{term}' increases {channel} risk"
                )

        for disclaimer in region_rules["required_disclaimers"]:
            # NOTE(review): assumes disclaimer["text"] is a list of accepted
            # phrasings. A bare string would be iterated character by
            # character and almost always "match". Confirm the rules schema.
            if any(phrase.lower() in haystack for phrase in disclaimer["text"]):
                continue

            report["warnings"].append({
                "region": region,
                "type": "missing_disclaimer",
                "disclaimer_type": disclaimer["type"],
                "severity": disclaimer["severity"],
            })

            warning = f"{region}: Missing {disclaimer['type']} disclaimer"
            for channel, risk in channel_risks.items():
                risk["score"] += rules_provider.calculate_risk_score([], [warning], region)
                risk["details"].append(
                    f"Missing {disclaimer['type']} disclaimer affects {channel} risk"
                )

    return report
|
|
|
# Exact string extract_text_from_image returns when OCR fails. It must not be
# passed to the compliance checker as though it were ad copy.
_OCR_FAILURE_TEXT = "Error extracting text from image"


def _format_compliance_report(compliance_report):
    """Render a check_compliance() result dict as the plain-text UI report."""
    status = '✅ Compliant' if compliance_report['compliant'] else '❌ Non-Compliant'
    parts = [
        "Compliance Analysis Report\n\n",
        f"Overall Status: {status}\n\n",
    ]

    if compliance_report["violations"]:
        parts.append("Violations Found:\n")
        parts.extend(
            f"• {violation['region']}: {violation['type']} - '{violation['term']}' "
            f"(Severity: {violation['severity']})\n"
            for violation in compliance_report["violations"]
        )
        parts.append("\n")

    if compliance_report["warnings"]:
        parts.append("Warnings:\n")
        parts.extend(
            f"• {warning['region']}: {warning['disclaimer_type']} (Severity: {warning['severity']})\n"
            for warning in compliance_report["warnings"]
        )
        parts.append("\n")

    parts.append("Channel Risk Assessment:\n")
    for channel, risk_info in compliance_report["channel_risks"].items():
        score = risk_info["score"]
        # Bands: score < 3 is Low, 3 <= score < 6 is Medium, 6+ is High.
        risk_level = "Low" if score < 3 else "Medium" if score < 6 else "High"
        parts.append(f"• {channel.capitalize()}: {risk_level} Risk (Score: {score})\n")
        parts.extend(f" - {detail}\n" for detail in risk_info["details"])

    return "".join(parts)


def analyze_ad_copy(image):
    """Main function to analyze ad copy.

    OCRs the uploaded image, checks the extracted text against every region's
    compliance rules, and renders the result as a plain-text report.

    Args:
        image: PIL image from the upload widget. May be None when the form is
            submitted without an upload.

    Returns:
        The report text, or a short explanatory message when there is no
        image or text extraction failed.
    """
    if image is None:
        return "Please upload an image to analyze."

    text = extract_text_from_image(image)
    if text == _OCR_FAILURE_TEXT:
        # Without this check, the error message itself would be scored as
        # ad copy (e.g. flagged for missing disclaimers).
        return "Could not extract text from the image. Please try a different file."

    compliance_report = check_compliance(text)
    return _format_compliance_report(compliance_report)
|
|
|
|
|
# Upload widget for the marketing creative. analyze_ad_copy receives it as a
# PIL image (type="pil").
# NOTE(review): `source=` and `tool=` are Gradio 3.x gr.Image arguments that
# Gradio 4 removed (upload sources moved to `sources=[...]`). Newer Gradio
# releases also renamed Interface's `allow_flagging` to `flagging_mode`.
# Confirm the pinned Gradio version before upgrading.
_marketing_image_input = gr.Image(
    type="pil",
    label="Upload Marketing Material",
    image_mode="RGB",
    height=300,
    width=400,
    scale=1,
    source="upload",
    tool="select",
)

# Single-input, single-output UI: image in, plain-text compliance report out.
iface = gr.Interface(
    fn=analyze_ad_copy,
    inputs=[_marketing_image_input],
    outputs=gr.Textbox(label="Compliance Report", lines=10),
    title="Marketing Campaign Compliance Checker",
    description="Upload marketing material to check compliance with US (SEC), UK (FCA), and EU financial regulations.",
    examples=[],
    theme=gr.themes.Base(),
    allow_flagging="never",
)
|
|
|
|
|
# Start the Gradio app (debug=True). A launch failure is reported and then
# re-raised, so the process exits non-zero instead of looking successful.
try:
    iface.launch(debug=True)
except Exception as e:
    print(f"Error launching interface: {str(e)}")
    raise