nexus-main / cataract.py
ariankhalfani's picture
Create cataract.py
ea7d15b verified
import os
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
import sqlite3
from io import BytesIO
from scipy.stats import norm
# Load YOLO models.
# Both names are bound to None up front so that a failed load leaves them
# defined: callers then hit an ordinary "not callable" error inside their own
# try/except blocks instead of a NameError on an unbound module global.
yolo_model_cataract = None
yolo_model_object_detection = None
try:
    yolo_model_cataract = YOLO('best-cataract-seg.pt')
    yolo_model_object_detection = YOLO('best-cataract-od.pt')
    print("YOLO models loaded successfully.")
except Exception as e:
    # Best-effort: the module must still import when the weights are missing.
    print(f"Error loading YOLO models: {e}")
def calculate_ratios(red_values, green_values, blue_values, total_pixels):
    """Return the relative share of each colour channel, scaled to 0-255.

    Each channel is averaged over ``total_pixels``; the three averages are
    then normalised by their sum and multiplied by 255, so a non-degenerate
    result always adds up to 255.

    Args:
        red_values: Red samples of the region of interest.
        green_values: Green samples of the region of interest.
        blue_values: Blue samples of the region of interest.
        total_pixels: Number of pixels the samples were taken from.

    Returns:
        A ``(red, green, blue)`` tuple. ``(0, 0, 0)`` is returned when there
        are no pixels or when the channel averages do not sum to a positive
        number.
    """
    if total_pixels == 0:
        return 0, 0, 0

    channel_means = [
        np.sum(samples) / total_pixels
        for samples in (red_values, green_values, blue_values)
    ]
    mean_total = channel_means[0] + channel_means[1] + channel_means[2]

    # Nothing to normalise against (all-black region, or NaN input).
    if not mean_total > 0:
        return 0, 0, 0

    red_quantity, green_quantity, blue_quantity = (
        (mean / mean_total) * 255 for mean in channel_means
    )
    return red_quantity, green_quantity, blue_quantity
# Per-stage Gaussian parameters for the (red, green, blue) quantities, stored
# as (mean, std) pairs. Each std is written as (upper - lower) / 4, exactly as
# in the original constants -- presumably a range-based estimate; confirm
# against the data these figures were derived from.
_STAGE_CHANNEL_STATS = {
    "Mature": (
        (73.37, (90.12 - 41.49) / 4),
        (89.48, (97.67 - 83.39) / 4),
        (92.15, (117.82 - 75.37) / 4),
    ),
    "Normal": (
        (67.84, (107.02 - 56.19) / 4),
        (84.85, (89.89 - 80.74) / 4),
        (102.31, (111.34 - 65.58) / 4),
    ),
    "Immature": (
        (68.83, (85.95 - 41.49) / 4),
        (89.43, (97.67 - 83.39) / 4),
        (96.74, (117.82 - 78.41) / 4),
    ),
}


def cataract_staging(red_quantity, green_quantity, blue_quantity):
    """Classify a pupil colour signature as "Mature", "Normal" or "Immature".

    A naive-Bayes rule is applied: every stage models each colour quantity
    as an independent Gaussian, all stages share the same prior (1/3), and
    the stage with the largest posterior wins.

    The comparison is carried out on log-densities. Multiplying the three
    raw densities underflows to exactly 0.0 for inputs far from every mean;
    the previous implementation additionally used the posteriors themselves
    as dictionary keys, so such ties collapsed into a single entry and the
    answer was always "Immature" regardless of the input.

    Args:
        red_quantity: Red share as produced by ``calculate_ratios``.
        green_quantity: Green share as produced by ``calculate_ratios``.
        blue_quantity: Blue share as produced by ``calculate_ratios``.

    Returns:
        The name of the most probable stage. Exact ties resolve to the
        first stage in the order Mature, Normal, Immature.
    """
    observed = (red_quantity, green_quantity, blue_quantity)
    # Equal priors: a constant offset that cannot change the ranking, kept
    # explicit so unequal priors can be introduced without restructuring.
    log_prior = np.log(1.0 / 3.0)

    log_posteriors = {}
    for stage, channel_stats in _STAGE_CHANNEL_STATS.items():
        log_likelihood = sum(
            norm.logpdf(value, mean, std)
            for value, (mean, std) in zip(observed, channel_stats)
        )
        log_posteriors[stage] = log_likelihood + log_prior

    # Keyed by stage name, so equal scores can no longer overwrite each other.
    return max(log_posteriors, key=log_posteriors.get)
def add_watermark(image):
    """Stamp ``image-logo.png`` onto the bottom-right corner of *image*.

    The logo is scaled to a width of 100 px (aspect ratio preserved) and
    pasted with a 10 px margin, using its own alpha channel as the mask.

    Args:
        image: A PIL image. It is not modified.

    Returns:
        A new RGB image carrying the watermark. Watermarking is best-effort:
        if anything goes wrong (logo missing, unreadable, ...) the error is
        printed and the *original* ``image`` object is returned unchanged.
    """
    try:
        # The context manager releases the file handle; convert() forces the
        # pixel data to be read before the file is closed.
        with Image.open('image-logo.png') as logo_file:
            logo = logo_file.convert("RGBA")

        # Resize logo to a fixed width, keeping its aspect ratio.
        basewidth = 100
        wpercent = basewidth / float(logo.size[0])
        # Clamp to 1 so an extremely wide logo cannot produce a zero height.
        hsize = max(1, int(wpercent * logo.size[1]))
        logo = logo.resize((basewidth, hsize), Image.LANCZOS)

        # convert() returns a new image, so the caller's image stays intact
        # and remains available for the fallback path below.
        canvas = image.convert("RGBA")
        position = (canvas.width - logo.width - 10, canvas.height - logo.height - 10)
        canvas.paste(logo, position, mask=logo)
        return canvas.convert("RGB")
    except Exception as e:
        print(f"Error adding watermark: {e}")
        return image
def _pupil_mask_from_results(results, size):
    """Return a boolean mask at ``size`` (width, height), or None if no mask was predicted."""
    if len(results) == 0:
        return None
    masks = getattr(results[0], 'masks', None)
    if masks is None or len(masks) == 0:
        return None

    mask_stack = masks.data.cpu().numpy()
    # With several detections the stack is 3-D; merge them into one mask so
    # Image.fromarray receives a 2-D array (it rejects a 3-D float stack).
    merged = mask_stack.max(axis=0) if mask_stack.ndim == 3 else mask_stack
    # float32 maps onto a PIL mode regardless of the model's tensor dtype.
    merged = merged.astype(np.float32)
    # NOTE(review): the mask is stretched straight to the input size; if the
    # model pads/letterboxes its input this may be slightly misaligned --
    # confirm against the Ultralytics mask documentation.
    resized = np.array(Image.fromarray(merged).resize(size, Image.NEAREST))
    return resized > 0.5


def _draw_stats_caption(pil_image, text, font_size=48):
    """Draw ``text`` in white on a black box near the top-left of ``pil_image`` (in place)."""
    draw = ImageDraw.Draw(pil_image)
    try:
        font = ImageFont.truetype("font.ttf", size=font_size)
    except IOError:
        font = ImageFont.load_default()
        print("Error: cannot open resource, using default font.")

    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    text_width, text_height = right - left, bottom - top
    text_x = 20
    text_y = 40
    padding = 10
    # Filled rectangle first so the text stays readable on any background.
    draw.rectangle(
        [text_x - padding, text_y - padding, text_x + text_width + padding, text_y + text_height + padding],
        fill="black"
    )
    draw.text((text_x, text_y), text, fill=(255, 255, 255, 255), font=font)
    return pil_image


def predict_and_visualize(image):
    """Segment the pupil, compute its colour signature and stage the cataract.

    Args:
        image: Input picture as a numpy array. Greyscale and RGBA arrays are
            normalised to RGB before inference.

    Returns:
        A 6-tuple ``(annotated_image, red_quantity, green_quantity,
        blue_quantity, raw_response, stage)``:

        * on success, the annotated image is the input with the mask tinted
          red, a statistics caption and the watermark;
        * when the model yields no mask, or the mask covers no pixel, the
          untouched ``image`` is returned with zeros,
          ``"No mask detected."`` and stage ``"Unknown"``;
        * on any exception, an all-zero array shaped like ``image`` is
          returned with zeros, the error text and stage ``"Error"``.
    """
    try:
        # Let Pillow infer the mode from the array shape, then normalise;
        # forcing 'RGB' on a 2-D or 4-channel array misreads the buffer.
        pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
        orig_size = pil_image.size
        results = yolo_model_cataract(pil_image)
        raw_response = str(results)
        rgb_pixels = np.array(pil_image)

        pupil = _pupil_mask_from_results(results, orig_size)
        # An empty mask has no pixels to measure; staging (0, 0, 0) would
        # produce a meaningless label.
        if pupil is None or not pupil.any():
            return image, 0, 0, 0, "No mask detected.", "Unknown"

        # Tint the segmented region red at 50 % opacity.
        red_mask = np.zeros_like(rgb_pixels)
        red_mask[pupil] = [255, 0, 0]
        alpha = 0.5
        blended_image = cv2.addWeighted(rgb_pixels, 1 - alpha, red_mask, alpha, 0)

        # Colour statistics come from the original, untinted pixels.
        pupil_pixels = rgb_pixels[pupil]
        total_pixels = pupil_pixels.shape[0]
        red_quantity, green_quantity, blue_quantity = calculate_ratios(
            pupil_pixels[:, 0], pupil_pixels[:, 1], pupil_pixels[:, 2], total_pixels
        )
        stage = cataract_staging(red_quantity, green_quantity, blue_quantity)

        text = f"Red quantity: {red_quantity:.2f}\nGreen quantity: {green_quantity:.2f}\nBlue quantity: {blue_quantity:.2f}\nStage: {stage}"
        combined_pil_image = _draw_stats_caption(Image.fromarray(blended_image), text)
        combined_pil_image_with_watermark = add_watermark(combined_pil_image)
        return np.array(combined_pil_image_with_watermark), red_quantity, green_quantity, blue_quantity, raw_response, stage
    except Exception as e:
        print("Error:", e)
        return np.zeros_like(image), 0, 0, 0, str(e), "Error"
def check_duplicate_entry(conn, red_quantity, green_quantity, blue_quantity, stage):
    """Tell whether an identical result row is already stored.

    Args:
        conn: Open sqlite3 connection holding the ``cataract_results`` table.
        red_quantity: Red quantity to look for (exact match).
        green_quantity: Green quantity to look for (exact match).
        blue_quantity: Blue quantity to look for (exact match).
        stage: Stage label to look for.

    Returns:
        True when at least one row matches all four values, else False.
    """
    query = '''SELECT COUNT(*) FROM cataract_results WHERE red_quantity=? AND green_quantity=? AND blue_quantity=? AND stage=?'''
    wanted = (red_quantity, green_quantity, blue_quantity, stage)
    # COUNT(*) always yields exactly one single-column row.
    (match_count,) = conn.cursor().execute(query, wanted).fetchone()
    return match_count > 0
def save_cataract_prediction_to_db(image, red_quantity, green_quantity, blue_quantity, stage):
    """Persist one prediction in ``cataract_results.db`` unless already stored.

    Args:
        image: PIL image to store; it is encoded as PNG bytes.
        red_quantity: Red quantity of the prediction.
        green_quantity: Green quantity of the prediction.
        blue_quantity: Blue quantity of the prediction.
        stage: Predicted stage label.

    Returns:
        A ``(status_message, debug_info)`` pair of strings describing whether
        the row was saved, skipped as a duplicate, or could not be saved
        because no connection was available.

    Raises:
        Whatever the duplicate check, PNG encoding or insert raises; the
        connection is closed before the exception propagates.
    """
    database = "cataract_results.db"
    conn = create_connection(database)
    if conn is None:
        return "Failed to save data", "No connection to the database."

    # try/finally guarantees the connection is released on every path,
    # including failures while encoding the image or writing the row.
    try:
        create_cataract_table(conn)

        # Check for duplicate entries
        if check_duplicate_entry(conn, red_quantity, green_quantity, blue_quantity, stage):
            return "Duplicate entry found, not saving.", "Duplicate entry detected."

        # Convert the image to bytes
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_bytes = buffered.getvalue()

        sql = '''INSERT INTO cataract_results(image, red_quantity, green_quantity, blue_quantity, stage) VALUES(?,?,?,?,?)'''
        conn.cursor().execute(sql, (img_bytes, red_quantity, green_quantity, blue_quantity, stage))
        conn.commit()
        return "Data saved successfully", f"Red: {red_quantity}, Green: {green_quantity}, Blue: {blue_quantity}, Stage: {stage}"
    finally:
        conn.close()
def combined_prediction(image):
    """Run the segmentation pipeline on *image* and store its outcome.

    Returns:
        The six values produced by ``predict_and_visualize`` followed by the
        two status strings produced by ``save_cataract_prediction_to_db``,
        as one flat 8-tuple.
    """
    prediction = predict_and_visualize(image)
    annotated, red_quantity, green_quantity, blue_quantity, _raw_response, stage = prediction
    save_message, debug_info = save_cataract_prediction_to_db(
        Image.fromarray(annotated), red_quantity, green_quantity, blue_quantity, stage
    )
    return prediction + (save_message, debug_info)
def create_connection(db_file):
    """Open the SQLite database at *db_file*.

    Returns:
        The open connection, or None when ``sqlite3`` reports an error (the
        error is printed, not raised).
    """
    try:
        return sqlite3.connect(db_file)
    except sqlite3.Error as exc:
        print(exc)
    return None
def create_cataract_table(conn):
    """Make sure the ``cataract_results`` table exists on *conn*.

    The statement uses ``IF NOT EXISTS``, so repeated calls are harmless.
    Any ``sqlite3.Error`` raised while running it is printed and swallowed;
    the function always returns None.
    """
    ddl = """ CREATE TABLE IF NOT EXISTS cataract_results (
id integer PRIMARY KEY,
image blob,
red_quantity real,
green_quantity real,
blue_quantity real,
stage text
); """
    try:
        conn.cursor().execute(ddl)
    except sqlite3.Error as exc:
        print(exc)
def predict_object_detection(image):
    """Detect eyes in *image*, draw labelled boxes and add the watermark.

    Class id 1 is labelled "Normal"; every other class id is labelled
    "Cataract".

    Args:
        image: Input picture; anything ``np.array`` can turn into an image
            array.

    Returns:
        A pair ``(annotated_image, raw_predictions)``. ``raw_predictions`` is
        one line per detection with label, confidence and box corners. On any
        exception an all-zero array shaped like ``image`` is returned together
        with the error text.
    """
    try:
        image_np = np.array(image)
        results = yolo_model_object_detection(image_np)
        image_with_boxes = image_np.copy()
        raw_predictions = []

        # Text settings are the same for every box, so set them once.
        font_face = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 1.0
        thickness = 2

        # An empty result list simply means nothing to draw.
        boxes = results[0].boxes if len(results) > 0 else []
        for box in boxes:
            label = "Normal" if box.cls.item() == 1 else "Cataract"
            confidence = box.conf.item()
            xmin, ymin, xmax, ymax = map(int, box.xyxy[0])
            cv2.rectangle(image_with_boxes, (xmin, ymin), (xmax, ymax), (255, 0, 0), 2)

            text = f'{label} {confidence:.2f}'
            (text_width, text_height), baseline = cv2.getTextSize(text, font_face, font_scale, thickness)
            label_height = text_height + baseline
            # The label normally sits on top of the box. When the box touches
            # the top edge that spot is off-image, so draw it just inside the
            # box instead of letting it be clipped.
            label_bottom = ymin if ymin >= label_height else ymin + label_height
            cv2.rectangle(image_with_boxes, (xmin, label_bottom - label_height), (xmin + text_width, label_bottom), (0, 0, 0), cv2.FILLED)
            cv2.putText(image_with_boxes, text, (xmin, label_bottom - baseline), font_face, font_scale, (255, 255, 255), thickness)

            raw_predictions.append(f"Label: {label}, Confidence: {confidence:.2f}, Box: [{xmin}, {ymin}, {xmax}, {ymax}]")

        raw_predictions_str = "\n".join(raw_predictions)

        # Convert image_with_boxes to PIL image and add watermark
        image_with_boxes_pil = Image.fromarray(image_with_boxes)
        image_with_boxes_pil_with_watermark = add_watermark(image_with_boxes_pil)
        return np.array(image_with_boxes_pil_with_watermark), raw_predictions_str
    except Exception as e:
        print("Error in object detection:", e)
        return np.zeros_like(image), str(e)