import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet50
def create_model_from_checkpoint():
    # Rebuild the ResNet-50 architecture with a 3-class head, then load the
    # fine-tuned weights from the checkpoint file.
    model = resnet50()
    model.fc = nn.Linear(model.fc.in_features, 3)
    # map_location keeps this working on CPU-only hosts even if the
    # checkpoint was saved from a GPU run.
    model.load_state_dict(torch.load("best_model", map_location=torch.device("cpu")))
    model.eval()
    return model
def prep_image(img):
    # Standard ImageNet preprocessing: resize, center-crop to 224x224,
    # convert to a tensor, then normalize with the ImageNet statistics.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor()
    ])
    transform_normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    transformed_img = transform(img)
    input_tensor = transform_normalize(transformed_img)
    # Add a batch dimension: the model expects shape (N, C, H, W).
    return input_tensor.unsqueeze(0)
model = create_model_from_checkpoint()
labels = ["benign", "malignant", "normal"]
def predict(img):
    input_tensor = prep_image(img)
    with torch.no_grad():
        # Softmax over the logits of the single image in the batch.
        prediction = torch.nn.functional.softmax(model(input_tensor)[0], dim=0)
    # Gradio's Label component renders a {label: probability} dict directly.
    confidences = {labels[i]: float(prediction[i]) for i in range(3)}
    return confidences
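
# Quick local sanity check (hypothetical usage, kept commented out so the
# Space only starts the web UI; "benign (52).png" is one of the example
# images listed below):
#
#     from PIL import Image
#     print(predict(Image.open("benign (52).png")))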
ui = gr.Interface(fn=predict,
                  inputs=gr.Image(type="pil"),
                  outputs=gr.Label(num_top_classes=3),
                  examples=["benign (52).png", "benign (243).png",
                            "malignant (127).png", "malignant (201).png",
                            "normal (81).png", "normal (101).png"])
ui.launch(share=True)
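
# A Space running this app also needs a requirements.txt next to app.py so
# the runtime can install the Python dependencies. A minimal sketch (the
# unpinned versions are an assumption; pin to whatever the checkpoint was
# trained with):
#
#     gradio
#     torch
#     torchvision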