Spaces:
Sleeping
Sleeping
import cv2 | |
import torch | |
import numpy as np | |
import torch.nn.functional as F | |
from torch import nn | |
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation | |
import matplotlib.pyplot as plt | |
import streamlit as st | |
from PIL import Image | |
import io | |
import zipfile | |
import os | |
# --- GlaucomaModel Class ---
class GlaucomaModel(object):
    """Glaucoma screening on RGB fundus images with two Hugging Face models.

    * A Swinv2 image classifier predicts the glaucoma category.
    * A SegFormer semantic-segmentation model labels optic disc / cup pixels
      (label 1 = disc, label 2 = cup, as used throughout this class).

    Both models are moved to ``device`` and switched to eval mode on
    construction.
    """

    def __init__(self,
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
                 seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
                 device=torch.device('cpu')):
        """Load the classification and segmentation models.

        Args:
            cls_model_path: Hub id or local path of the Swinv2 classifier.
            seg_model_path: Hub id or local path of the SegFormer segmenter.
            device: torch device the models and their inputs are placed on.
        """
        self.device = device
        # Classification model for glaucoma
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
        # Segmentation model for optic disc and cup
        self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
        self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()
        # Mapping from class index to human-readable label, taken from the model config
        self.cls_id2label = self.cls_model.config.id2label

    def glaucoma_pred(self, image):
        """Classify ``image``.

        Args:
            image: RGB image as a numpy array (presumably HxWx3 uint8 -- it is
                passed unchanged to the Hugging Face image processor).

        Returns:
            ``(disease_idx, confidence)``: the index of the most probable
            class (a key of ``self.cls_id2label``) as an ``int``, and its
            softmax probability in percent.
        """
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt").to(self.device)
        with torch.no_grad():
            logits = self.cls_model(**inputs).logits
        # Single image in the batch: keep row 0 and move it to the CPU once.
        probs = F.softmax(logits, dim=-1)[0].cpu()
        disease_idx = int(probs.argmax())
        confidence = probs[disease_idx].item() * 100
        return disease_idx, confidence

    def optic_disc_cup_pred(self, image):
        """Segment the optic disc and cup in ``image``.

        Args:
            image: RGB image as a numpy array; its first two dimensions give
                the output mask size.

        Returns:
            ``(mask, cup_confidence, disc_confidence)``:

            * ``mask`` is a uint8 array of shape ``image.shape[:2]`` holding
              the per-pixel argmax label (1 = disc, 2 = cup).
            * Each confidence is the mean softmax probability of that label,
              in percent, over the pixels assigned to it (0.0 if there are
              none).
        """
        inputs = self.seg_extractor(images=image.copy(), return_tensors="pt").to(self.device)
        with torch.no_grad():
            logits = self.seg_model(**inputs).logits.cpu()
        # Resize the logits to the input image size before taking the argmax.
        upsampled_logits = nn.functional.interpolate(
            logits, size=image.shape[:2], mode="bilinear", align_corners=False
        )
        seg_probs = F.softmax(upsampled_logits, dim=1)[0]
        pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
        cup_confidence = self._region_confidence(seg_probs, pred_disc_cup, 2)
        disc_confidence = self._region_confidence(seg_probs, pred_disc_cup, 1)
        return pred_disc_cup.numpy().astype(np.uint8), cup_confidence, disc_confidence

    @staticmethod
    def _region_confidence(probs, labels, label):
        """Mean probability (percent) of ``label`` over pixels labelled ``label``; 0.0 if none.

        Averaging only over the predicted region keeps background pixels from
        diluting the score, which would otherwise reflect region size rather
        than confidence.
        """
        region = labels == label
        if not bool(region.any()):
            return 0.0
        return probs[label][region].mean().item() * 100

    @staticmethod
    def _crop_around_mask(image, mask, min_padding=50, min_size=50):
        """Crop ``image`` to the bounding box of ``mask > 0`` plus padding.

        The padding is the larger of ``min_padding`` and 20% of the box's
        longer side. The box is clipped to the image.

        Returns a full copy of ``image`` when the mask is empty or when the
        clipped crop is smaller than ``min_size`` on either side.
        """
        foreground = (mask > 0).astype(np.uint8)
        if not foreground.any():
            # cv2.boundingRect of an empty mask is (0, 0, 0, 0); padding it
            # would crop an arbitrary top-left corner of the image.
            return image.copy()
        x, y, w, h = cv2.boundingRect(foreground)
        padding = max(min_padding, int(0.2 * max(w, h)))
        img_h, img_w = image.shape[:2]
        # Clip each edge independently so the padding stays symmetric.
        x0, y0 = max(x - padding, 0), max(y - padding, 0)
        x1, y1 = min(x + w + padding, img_w), min(y + h + padding, img_h)
        if x1 - x0 < min_size or y1 - y0 < min_size:
            return image.copy()
        return image[y0:y1, x0:x1]

    def process(self, image):
        """Run classification and segmentation, and build the display artefacts.

        Args:
            image: RGB image as a numpy array.

        Returns:
            ``(disease_idx, disc_cup_image, vcdr, cls_confidence,
            cup_confidence, disc_confidence, cropped_image)``. ``vcdr`` is
            ``np.nan`` if the ratio cannot be computed.
        """
        disease_idx, cls_confidence = self.glaucoma_pred(image)
        disc_cup, cup_confidence, disc_confidence = self.optic_disc_cup_pred(image)
        try:
            vcdr = simple_vcdr(disc_cup)
        except Exception:
            # Best effort: a failed ratio must not abort the whole analysis,
            # but do not swallow KeyboardInterrupt/SystemExit as a bare except would.
            vcdr = np.nan
        cropped_image = self._crop_around_mask(image, disc_cup)
        # NOTE(review): add_mask returns (blended, solid_overlay) and the solid
        # overlay is used here, so the 0.2 alpha only affects the discarded result.
        _, disc_cup_image = add_mask(image, disc_cup, [1, 2], [[0, 255, 0], [255, 0, 0]], 0.2)
        return disease_idx, disc_cup_image, vcdr, cls_confidence, cup_confidence, disc_confidence, cropped_image
# --- Utility Functions ---
def simple_vcdr(mask):
    """Return the vertical cup-to-disc ratio (vCDR) of a disc/cup label mask.

    The vCDR is the cup's vertical extent divided by the disc's vertical
    extent. Labels are 0 = background, 1 = disc, 2 = cup. Each pixel carries
    a single label, so cup pixels are labelled 2 rather than 1. The disc
    region is therefore every non-zero pixel, and the cup region is every
    pixel labelled above 1.

    A region's vertical extent is the largest number of its pixels found in
    any single image column.

    Args:
        mask: 2-D integer label array (rows x columns).

    Returns:
        The vCDR as a float in [0, 1]: 0.0 when there are no cup pixels, or
        ``np.nan`` when the mask contains no disc at all.
    """
    mask = np.asarray(mask)
    disc_height = int(np.sum(mask > 0, axis=0).max(initial=0))
    if disc_height == 0:
        return np.nan
    cup_height = int(np.sum(mask > 1, axis=0).max(initial=0))
    return cup_height / disc_height
def add_mask(image, mask, classes, colors, alpha=0.5):
    """Paint labelled regions of ``mask`` onto a copy of ``image``.

    Args:
        image: Image array; it is not modified.
        mask: Label array indexed against ``image`` (``mask == label`` selects pixels).
        classes: Label values to paint.
        colors: One colour per entry in ``classes``; pairs are matched with ``zip``.
        alpha: Weight of the painted copy when blending with ``image``.

    Returns:
        ``(blended, painted)``:

        * ``painted`` is the copy of ``image`` with each selected region
          replaced by its solid colour.
        * ``blended`` is ``cv2.addWeighted(painted, alpha, image, 1 - alpha, 0)``.
    """
    painted = image.copy()
    for label_value, rgb in zip(classes, colors):
        region = mask == label_value
        painted[region] = rgb
    blended = cv2.addWeighted(painted, alpha, image, 1 - alpha, 0)
    return blended, painted
# --- Streamlit Interface ---
def main():
    """Streamlit entry point: batch glaucoma screening of uploaded fundus images.

    For every uploaded image it shows:

    * the classification and its confidence,
    * the disc/cup segmentation and the segmentation confidences,
    * the vertical cup-to-disc ratio,
    * a crop around the optic disc.

    Images whose classification confidence reaches the sidebar threshold are
    offered as a ZIP of cropped PNGs through a sidebar download button.
    """

    # Cache the models so they are loaded once rather than once per uploaded
    # image (and again on every Streamlit rerun).
    @st.cache_resource
    def _load_model():
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        return GlaucomaModel(device=device)

    st.set_page_config(layout="wide")
    st.title("Batch Glaucoma Screening from Retinal Fundus Images")
    # Explanation for the confidence threshold
    st.sidebar.write("**Confidence Threshold** (optional): Set a threshold to filter images based on the model's confidence in glaucoma classification.")
    confidence_threshold = st.sidebar.slider("Confidence Threshold (%)", 0, 100, 70)
    uploaded_files = st.sidebar.file_uploader("Upload Images", type=['png', 'jpeg', 'jpg'], accept_multiple_files=True)

    if not uploaded_files:
        st.sidebar.info("Upload images to begin analysis.")
        return

    model = _load_model()
    download_confident_images = []
    for uploaded_file in uploaded_files:
        try:
            image = Image.open(uploaded_file).convert('RGB')
        except (OSError, ValueError) as err:
            # A single unreadable upload should not abort the whole batch.
            st.error(f"Could not read {uploaded_file.name}: {err}")
            continue
        image_np = np.array(image).astype(np.uint8)
        with st.spinner(f'Processing {uploaded_file.name}...'):
            disease_idx, disc_cup_image, vcdr, cls_conf, cup_conf, disc_conf, cropped_image = model.process(image_np)

        # Confidence-based grouping
        is_confident = cls_conf >= confidence_threshold
        if is_confident:
            download_confident_images.append((cropped_image, uploaded_file.name))

        # Display Results
        with st.container():
            st.subheader(f"Results for {uploaded_file.name}")
            cols = st.columns(4)
            cols[0].image(image_np, caption="Input Image", use_column_width=True)
            cols[1].image(disc_cup_image, caption="Disc/Cup Segmentation", use_column_width=True)
            # NOTE(review): no class activation map is computed; this column
            # currently shows the unmodified input image.
            cols[2].image(image_np, caption="Class Activation Map", use_column_width=True)
            cols[3].image(cropped_image, caption="Cropped Image", use_column_width=True)
            st.write(f"**Vertical cup-to-disc ratio:** {vcdr:.04f}")
            st.write(f"**Category:** {model.cls_id2label[disease_idx]} ({cls_conf:.02f}% confidence)")
            st.write(f"**Optic Cup Segmentation Confidence:** {cup_conf:.02f}%")
            st.write(f"**Optic Disc Segmentation Confidence:** {disc_conf:.02f}%")
            st.write(f"**Confidence Group:** {'Confident' if is_confident else 'Not Confident'}")

    # Download Button for Confident Images
    if download_confident_images:
        # Build the archive in memory. A file written to the server's working
        # directory is not served by a relative markdown link, and a shared
        # on-disk path would be overwritten by concurrent sessions.
        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, "w") as zf:
            for cropped_image, name in download_confident_images:
                img_buffer = io.BytesIO()
                Image.fromarray(cropped_image).save(img_buffer, format="PNG")
                zf.writestr(f"{name}_cropped.png", img_buffer.getvalue())
        st.sidebar.download_button(
            "Download Confident Cropped Images",
            data=zip_buffer.getvalue(),
            file_name="confident_cropped_images.zip",
            mime="application/zip",
        )
if __name__ == "__main__":
    # Run the app only when this file is executed as the main script.
    main()