|
import torch |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
import gradio as gr |
|
import warnings |
|
from torchvision import models |
|
import torch.nn as nn |
|
|
|
|
|
class TransferNet(nn.Module):
    """ResNet-18 feature extractor topped with a new linear classifier.

    The attribute names ``features`` and ``fc`` are also the prefixes of the
    state-dict keys, so they must not be renamed if saved checkpoints are to
    keep loading.
    """

    def __init__(self, num_classes=2):
        """Assemble the backbone and the classification layer.

        Args:
            num_classes: Number of output scores produced by the final
                linear layer.
        """
        super().__init__()

        backbone = models.resnet18(pretrained=True)

        # Keep every child module of the ResNet except the last one (its
        # own classification head); a new head is attached below.
        backbone_layers = list(backbone.children())[:-1]
        self.features = nn.Sequential(*backbone_layers)

        # The new head takes the same input width as the one it replaces.
        self.fc = nn.Linear(backbone.fc.in_features, num_classes)

    def forward(self, x):
        """Return unnormalised class scores for ``x``.

        Args:
            x: Input tensor whose first dimension is the batch dimension.

        Returns:
            The output of the final linear layer, one row per batch element.
        """
        feats = self.features(x)
        # Collapse everything after the batch dimension into one vector.
        flat = feats.view(feats.size(0), -1)
        return self.fc(flat)
|
|
|
|
|
# Deterministic preprocessing applied to every image before inference.
#
# Random flips and rotations are training-time augmentations; applying them
# at inference makes the prediction for the same image change from call to
# call, so only the deterministic steps are kept here.
transform = transforms.Compose([
    # Fixed 224x224 input size fed to the network.
    transforms.Resize((224, 224)),
    # PIL image -> float tensor.
    transforms.ToTensor(),
    # Per-channel normalisation (the ImageNet statistics commonly used with
    # torchvision's pretrained models).
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
|
|
# Human-readable label for each output index of the classifier.
# NOTE(review): the index order must match the label encoding used when the
# checkpoint was trained - confirm against the training code.
class_labels = {
    0: "Diabetic Retinopathy",
    1: "No Diabetic Retinopathy",
}

# Silence every UserWarning raised after this point, process-wide.
warnings.filterwarnings(action="ignore", category=UserWarning)
|
|
|
def load_model(model_path):
    """Create a ``TransferNet`` and restore its weights from ``model_path``.

    Args:
        model_path: Location of the saved state dict, handed to
            ``torch.load``.

    Returns:
        A ``(model, device)`` tuple: the network, moved to ``device`` and
        switched to evaluation mode, and the ``torch.device`` that was
        chosen (CUDA when available, otherwise CPU).
    """
    use_cuda = torch.cuda.is_available()
    target = torch.device('cuda' if use_cuda else 'cpu')

    net = TransferNet()
    # map_location remaps the stored tensors onto the chosen device.
    state = torch.load(model_path, map_location=target)
    net.load_state_dict(state)

    # Module.to() and Module.eval() both return the module itself.
    return net.to(target).eval(), target
|
|
|
# Checkpoint file name, resolved relative to the current working directory.
model_path = "transfer_model.pth"

# Load the network once at import time; predict_image reuses these globals.
model, device = load_model(model_path=model_path)
|
|
|
def predict_image(image):
    """Classify a single image file.

    Args:
        image: Path to the image file. The Gradio input component is
            declared with ``type="filepath"``, so a path string is expected.

    Returns:
        A ``(class_name, probabilities)`` tuple. ``class_name`` is the entry
        of ``class_labels`` for the highest-scoring class, and
        ``probabilities`` is a NumPy array with the softmax score of every
        class for this image.

    Raises:
        ValueError: If ``image`` is ``None`` (no image was supplied).
    """
    # Fail with a clear message instead of an opaque error inside PIL when
    # the form is submitted without an upload.
    if image is None:
        raise ValueError("No image provided; please upload an image to classify.")

    # Release the file handle as soon as the pixels are read; convert()
    # returns an independent image that stays valid after the file closes.
    with Image.open(image) as source:
        rgb = source.convert('RGB')

    # Add a leading batch dimension of size 1 and move to the model's device.
    batch = transform(rgb).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.softmax(logits, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()

    class_name = class_labels[predicted_class]
    return class_name, probabilities[0].cpu().numpy()
|
|
|
|
|
# Web UI: one image upload in, two text fields out (label and probabilities).
interface = gr.Interface(
    title="Diabetic Retinopathy Classification",
    description="Upload an image to classify it as Diabetic Retinopathy or No Diabetic Retinopathy.",
    fn=predict_image,
    # "filepath" makes Gradio hand predict_image a path rather than pixel data.
    inputs=gr.Image(type="filepath"),
    # Order matches the (class_name, probabilities) tuple from predict_image.
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Textbox(label="Probability of Prediction"),
    ],
)

# Start the Gradio server as soon as this module is executed.
interface.launch()