# app.py -- "Create app.py" (uploaded by asmaa1, commit 82ba409, verified, 2.27 kB)
import torch
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
import warnings
from torchvision import models
import torch.nn as nn
# Transfer-learning classifier: ResNet-18 backbone plus a new linear head.
class TransferNet(nn.Module):
    """ResNet-18 feature extractor followed by a fresh linear classifier.

    A torchvision ResNet-18 is created with ``pretrained=True``; every child
    module except the last one is kept as ``self.features`` and ``self.fc``
    maps the flattened features to ``num_classes`` logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Keep all child modules but the final one (the backbone's own head).
        *feature_layers, _ = backbone.children()
        self.features = nn.Sequential(*feature_layers)
        # The new head takes the same input width as the head it replaces.
        self.fc = nn.Linear(backbone.fc.in_features, num_classes)

    def forward(self, x):
        """Return the raw class logits for the batch ``x``."""
        pooled = self.features(x)
        # Collapse everything after the batch dimension into a single axis.
        flat = pooled.view(pooled.shape[0], -1)
        return self.fc(flat)
# Inference-time preprocessing.
# Only deterministic steps belong here: random flips/rotations are training
# augmentations and would make the same upload yield different predictions
# from one request to the next.
transform = transforms.Compose([
    # Resize to the 224x224 input size used for the network.
    transforms.Resize((224, 224)),
    # PIL image -> float tensor in [0, 1], channels first.
    transforms.ToTensor(),
    # Per-channel standardisation with the ImageNet statistics that
    # torchvision's pretrained models are trained with.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Human-readable label for each output index of the classifier.
class_labels = dict(enumerate([
    'Diabetic Retinopathy',
    'No Diabetic Retinopathy',
]))

# Silence every UserWarning emitted from this point on.
warnings.filterwarnings(action="ignore", category=UserWarning)
def load_model(model_path):
    """Build a TransferNet, restore its weights and ready it for inference.

    Args:
        model_path: Location of the saved ``state_dict`` checkpoint.

    Returns:
        A ``(model, device)`` tuple: the network, moved to ``device`` and
        switched to eval mode, and the ``torch.device`` that was selected
        (CUDA when available, otherwise CPU).
    """
    if torch.cuda.is_available():
        target = torch.device('cuda')
    else:
        target = torch.device('cpu')

    net = TransferNet()
    # NOTE: torch.load unpickles the file, so only trusted checkpoints
    # should be passed in. map_location lets GPU-saved weights load on CPU.
    weights = torch.load(model_path, map_location=target)
    net.load_state_dict(weights)
    net.to(target)
    # Eval mode: fixes batch-norm statistics and disables dropout.
    net.eval()
    return net, target
# Checkpoint with the TransferNet weights; being a relative path, it is
# resolved against the working directory of the process.
model_path = "transfer_model.pth" # Path to your model file
# Load once at import time so predict_image can reuse the same network and
# device for every request.
model, device = load_model(model_path)
def predict_image(image):
    """Classify one uploaded image.

    Args:
        image: Path of the image file to classify, as supplied by the
            ``gr.Image(type="filepath")`` input component.

    Returns:
        A ``(class_name, probabilities)`` tuple: the label of the most
        likely class and a NumPy array holding the softmax probability of
        every class, ordered by class index.

    Raises:
        gr.Error: If no image was supplied.
    """
    # Gradio passes None when the form is submitted without an image; tell
    # the user what is wrong instead of failing with an opaque PIL error.
    if image is None:
        raise gr.Error("Please upload an image before requesting a prediction.")

    # The context manager closes the underlying file once the pixel data has
    # been copied into the converted RGB image.
    with Image.open(image) as source:
        rgb = source.convert('RGB')

    # Add the batch dimension the network expects and move to its device.
    batch = transform(rgb).unsqueeze(0).to(device)

    # No gradients are needed for inference.
    with torch.no_grad():
        output = model(batch)
        probabilities = torch.softmax(output, dim=1)

    predicted_class = torch.argmax(probabilities, dim=1).item()
    class_name = class_labels[predicted_class]
    return class_name, probabilities[0].cpu().numpy()
# Web UI: one image upload in; predicted label and the per-class
# probability array out, both shown as text.
interface = gr.Interface(
    title="Diabetic Retinopathy Classification",
    description="Upload an image to classify it as Diabetic Retinopathy or No Diabetic Retinopathy.",
    fn=predict_image,
    # "filepath" makes Gradio hand predict_image a path on disk.
    inputs=gr.Image(type="filepath"),
    outputs=[gr.Textbox(label="Prediction"), gr.Textbox(label="Probability of Prediction")],
)

# Start serving as soon as this module is executed.
interface.launch()