Ozgur Unlu
more tests to see the problem with enabling hardware
4740b6e
raw
history blame
7.17 kB
import torch
import sys
import subprocess
def check_gpu_status():
    """Print Python/PyTorch/CUDA diagnostics and the output of `nvidia-smi`.

    Used as a start-up aid when debugging hardware acceleration. It reports the
    Python and PyTorch versions, whether PyTorch can see a CUDA device, the CUDA
    version the installed PyTorch build targets, details of device 0 (when
    available), and the raw `nvidia-smi` output. Returns None.
    """
    print("Python version:", sys.version)
    print("PyTorch version:", torch.__version__)
    cuda_available = torch.cuda.is_available()
    print("CUDA available:", cuda_available)
    # torch.version.cuda is the CUDA version PyTorch was built against (None for
    # CPU-only builds). Print it even when no GPU is visible: it distinguishes a
    # CPU-only PyTorch install from a GPU that the runtime cannot see.
    print("CUDA version:", torch.version.cuda or "Not available")
    if cuda_available:
        print("GPU count:", torch.cuda.device_count())
        print("GPU Device:", torch.cuda.get_device_name(0))
        print("GPU Memory:", torch.cuda.get_device_properties(0).total_memory / 1024**3, "GB")
    try:
        # Bounded so a hung driver cannot block application start-up forever.
        nvidia_smi = subprocess.check_output(
            ["nvidia-smi"], text=True, errors="replace", timeout=30
        )
    except (OSError, subprocess.SubprocessError):
        # OSError: binary missing or not executable. SubprocessError covers a
        # non-zero exit (CalledProcessError) and the timeout (TimeoutExpired).
        print("nvidia-smi not available")
    else:
        print("nvidia-smi output:")
        print(nvidia_smi)


# Run GPU check at startup
print("=== GPU Status Check ===")
check_gpu_status()
print("======================")
# Rest of your imports
import gradio as gr
import easyocr
from transformers import pipeline, DistilBertTokenizer, DistilBertForSequenceClassification
import numpy as np
from PIL import Image
import json
from compliance_rules import ComplianceRules
# Print GPU information for debugging
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("Running on CPU")

# Set device once so every component below agrees on it.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Initialize the OCR reader on the selected device. If GPU initialization
# raises, fall back to a CPU reader.
# The reader is created exactly once: constructing another Reader afterwards
# would discard this one (including the CPU fallback) and load the models again.
try:
    reader = easyocr.Reader(['en'], gpu=(device == "cuda"))
    print("EasyOCR initialized successfully")
except Exception as e:
    print(f"Error initializing EasyOCR: {str(e)}")
    reader = easyocr.Reader(['en'], gpu=False)
    print("Falling back to CPU for EasyOCR")

# Initialize compliance rules
compliance_rules = ComplianceRules()
def extract_text_from_image(image):
    """Run EasyOCR over *image* and return the recognised text as one string.

    The text field (element [1]) of each detection returned by
    ``reader.readtext`` is joined with single spaces, in the order the
    detections are returned. Any exception is printed, and the fixed string
    "Error extracting text from image" is returned instead of raising.
    """
    try:
        detections = reader.readtext(np.array(image))
        return " ".join(detection[1] for detection in detections)
    except Exception as exc:
        print(f"Error in text extraction: {str(exc)}")
        return "Error extracting text from image"
def _as_list(value):
    """Return *value* as a list; a bare string becomes a one-item list.

    Guards against rule data holding a single string where a list is expected.
    Iterating a string would test each character as its own phrase, so almost
    any text (e.g. one containing a space) would match.
    """
    if isinstance(value, str):
        return [value]
    return list(value)


def _add_channel_risk(report, risk_score, describe):
    """Add *risk_score* to every channel in *report* and record describe(channel)."""
    for channel, channel_info in report["channel_risks"].items():
        channel_info["score"] += risk_score
        channel_info["details"].append(describe(channel))


def check_compliance(text, rules_engine=None):
    """Check *text* against every region's prohibited terms and required disclaimers.

    Matching is a case-insensitive substring test, so a term also matches when
    it appears inside a longer word.

    Args:
        text: The text to check (e.g. OCR output).
        rules_engine: Object providing ``get_all_rules()`` and
            ``calculate_risk_score(violations, warnings, region)``. Defaults to
            the module-level ``compliance_rules`` instance.

    Returns:
        A dict with these keys:
          - "compliant": False if any prohibited term was found. Missing
            disclaimers only add warnings.
          - "violations": list of {"region", "type", "term", "severity"}.
          - "warnings": list of {"region", "type", "disclaimer_type", "severity"}.
          - "channel_risks": {"email" | "social" | "print": {"score", "details"}}.
    """
    engine = compliance_rules if rules_engine is None else rules_engine
    rules = engine.get_all_rules()
    # Lower-case once rather than once per term / variation / disclaimer test.
    lowered = text.lower()
    report = {
        "compliant": True,
        "violations": [],
        "warnings": [],
        "channel_risks": {
            "email": {"score": 0, "details": []},
            "social": {"score": 0, "details": []},
            "print": {"score": 0, "details": []},
        },
    }

    for region, region_rules in rules.items():
        # Check prohibited terms
        for term_info in region_rules["prohibited_terms"]:
            term = term_info["term"].lower()
            # Short-circuit: variations are only consulted when the term itself is absent.
            found = term in lowered or any(
                variation.lower() in lowered
                for variation in _as_list(term_info["variations"])
            )
            if not found:
                continue
            report["compliant"] = False
            report["violations"].append({
                "region": region,
                "type": "prohibited_term",
                "term": term,
                "severity": term_info["severity"],
            })
            # The score call does not take the channel, so compute it once and
            # apply it to every channel.
            risk_score = engine.calculate_risk_score(
                [f"{region}: Prohibited term '{term}' found"], [], region
            )
            # describe() is invoked immediately, so the current `term` is used.
            _add_channel_risk(
                report,
                risk_score,
                lambda channel: f"Prohibited term '{term}' increases {channel} risk",
            )

        # Check required disclaimers
        for disclaimer in region_rules["required_disclaimers"]:
            if any(phrase.lower() in lowered for phrase in _as_list(disclaimer["text"])):
                continue
            report["warnings"].append({
                "region": region,
                "type": "missing_disclaimer",
                "disclaimer_type": disclaimer["type"],
                "severity": disclaimer["severity"],
            })
            risk_score = engine.calculate_risk_score(
                [], [f"{region}: Missing {disclaimer['type']} disclaimer"], region
            )
            _add_channel_risk(
                report,
                risk_score,
                lambda channel: f"Missing {disclaimer['type']} disclaimer affects {channel} risk",
            )

    return report
def analyze_ad_copy(image):
    """Run OCR and compliance checks on an uploaded image and format a text report.

    Args:
        image: The PIL image from the upload component, or None when the form
            is submitted without an image.

    Returns:
        A human-readable, multi-line report string. Sections for violations and
        warnings appear only when non-empty. Each channel's risk is labelled Low
        (score < 3), Medium (score < 6) or High.
    """
    # Nothing was uploaded: OCR cannot run, so say so instead of reporting on
    # whatever the OCR error path returns.
    if image is None:
        return "No image provided. Please upload marketing material to analyze."

    # Extract text from image
    text = extract_text_from_image(image)
    # extract_text_from_image signals failure with this fixed string. Checking
    # that string for compliance would produce a meaningless report.
    if text == "Error extracting text from image":
        return (
            "Compliance Analysis Report\n\n"
            "Could not extract text from the image, so no compliance check was run.\n"
        )

    compliance_report = check_compliance(text)

    parts = ["Compliance Analysis Report\n\n"]
    status = "✅ Compliant" if compliance_report["compliant"] else "❌ Non-Compliant"
    parts.append(f"Overall Status: {status}\n\n")

    if compliance_report["violations"]:
        parts.append("Violations Found:\n")
        for violation in compliance_report["violations"]:
            parts.append(
                f"• {violation['region']}: {violation['type']} - '{violation['term']}' "
                f"(Severity: {violation['severity']})\n"
            )
        parts.append("\n")

    if compliance_report["warnings"]:
        parts.append("Warnings:\n")
        for warning in compliance_report["warnings"]:
            parts.append(
                f"• {warning['region']}: {warning['disclaimer_type']} "
                f"(Severity: {warning['severity']})\n"
            )
        parts.append("\n")

    parts.append("Channel Risk Assessment:\n")
    for channel, risk_info in compliance_report["channel_risks"].items():
        score = risk_info["score"]
        risk_level = "Low" if score < 3 else "Medium" if score < 6 else "High"
        parts.append(f"• {channel.capitalize()}: {risk_level} Risk (Score: {score})\n")
        parts.extend(f" - {detail}\n" for detail in risk_info["details"])

    return "".join(parts)
# Create Gradio interface: one image-upload input and one plain-text report
# output, wired to analyze_ad_copy.
# NOTE(review): `source=` and `tool=` are gr.Image arguments from Gradio 3.x.
# Gradio 4 replaced `source` with `sources=[...]` and dropped `tool`, so this
# call appears to require a 3.x release. Confirm against the pinned gradio version.
upload_input = gr.Image(
    type="pil",
    label="Upload Marketing Material",
    height=300,
    width=400,
    image_mode="RGB",
    scale=1,
    source="upload",
    tool="select",
)
report_output = gr.Textbox(label="Compliance Report", lines=10)
iface = gr.Interface(
    fn=analyze_ad_copy,
    inputs=[upload_input],
    outputs=report_output,
    title="Marketing Campaign Compliance Checker",
    description="Upload marketing material to check compliance with US (SEC), UK (FCA), and EU financial regulations.",
    examples=[],
    theme=gr.themes.Base(),
    allow_flagging="never",
)
# Launch the app. A launch failure is reported and then re-raised, so the
# process ends with a traceback and a non-zero exit status instead of
# appearing to finish cleanly.
try:
    iface.launch(debug=True)
except Exception as e:
    print(f"Error launching interface: {str(e)}")
    raise