"""Gradio demo that classifies retinal images for diabetic retinopathy.

Loads a fine-tuned ResNet-18 classifier (``TransferNet``) from
``transfer_model.pth`` and serves a web UI that shows the predicted class and
the per-class softmax probabilities for an uploaded image.
"""
import warnings

import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models


class TransferNet(nn.Module):
    """ResNet-18 backbone whose final classifier is replaced by a new linear head.

    Args:
        num_classes: Number of output logits produced by the new head.
        pretrained: If True (the default, matching the original behaviour),
            initialise the backbone with torchvision's IMAGENET1K_V1 weights.
            Pass False when a complete state dict will be loaded right
            afterwards, so the ImageNet weights are not downloaded for nothing.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super().__init__()
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        resnet = models.resnet18(weights=weights)
        # Keep every child module except the original final `fc` classifier.
        self.features = nn.Sequential(*list(resnet.children())[:-1])
        self.fc = nn.Linear(resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Return raw (pre-softmax) logits of shape (batch, num_classes)."""
        x = self.features(x)
        # Collapse everything after the batch dimension so it matches fc's input.
        x = torch.flatten(x, 1)
        return self.fc(x)


# Deterministic preprocessing for inference. The original pipeline applied
# random flips/rotations here (training-time augmentation), which made the same
# image produce different predictions from one request to the next.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Maps the model's output index to a human-readable label.
# NOTE(review): must match the class ordering used at training time — confirm.
class_labels = {0: 'Diabetic Retinopathy', 1: 'No Diabetic Retinopathy'}

# Suppress UserWarnings globally. Note this also hides any unrelated
# UserWarning raised by torch/torchvision/gradio.
warnings.filterwarnings("ignore", category=UserWarning)


def load_model(model_path):
    """Load TransferNet weights from ``model_path`` and prepare it for inference.

    Args:
        model_path: Path to a state-dict checkpoint saved with ``torch.save``.

    Returns:
        A ``(model, device)`` tuple; the model is on ``device`` and in eval mode.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The checkpoint supplies every parameter and buffer, so skip the ImageNet download.
    model = TransferNet(pretrained=False)
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, device


model_path = "transfer_model.pth"  # Path to your model file
model, device = load_model(model_path)


def predict_image(image):
    """Classify the image at path ``image``.

    Args:
        image: Filesystem path supplied by ``gr.Image(type="filepath")``;
            ``None`` when the user submits without uploading anything.

    Returns:
        ``(class_name, probability_text)`` where ``probability_text`` lists
        every class label with its softmax probability.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # `convert` returns a new image, so the file handle can be closed right away.
    with Image.open(image) as pil_img:
        rgb_img = pil_img.convert('RGB')
    batch = transform(rgb_img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(batch)
    probabilities = torch.softmax(output, dim=1)[0].cpu()
    predicted_class = int(torch.argmax(probabilities))
    class_name = class_labels[predicted_class]
    probability_text = ", ".join(
        f"{class_labels[idx]}: {prob:.4f}"
        for idx, prob in enumerate(probabilities.tolist())
    )
    return class_name, probability_text


# Example images (paths relative to the working directory) shown in the UI.
examples = [
    ["dr.jpg"],
    ["No_DR.png"],
]

# Create Gradio interface
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="filepath"),
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Textbox(label="Probability of Prediction"),
    ],
    title="Diabetic Retinopathy Classification",
    description="Upload an image to classify it as Diabetic Retinopathy or No Diabetic Retinopathy.",
    theme="gstaff/xkcd",
    examples=examples,
)

if __name__ == "__main__":
    interface.launch()