Spaces:
Sleeping
Sleeping
from models import EfficientNet | |
from utils import get_device | |
import torch | |
import json | |
import gradio as gr | |
import torch | |
from torchvision import transforms | |
from PIL import Image | |
import json | |
import timm | |
from torch import nn | |
import torch.nn.functional as F | |
def load_efficientnet_model(model_path: str, device=None):
    """
    Load an EfficientNet checkpoint for inference.

    Args:
        model_path: Path of a checkpoint file that stores the model weights
            under the ``'model_state_dict'`` key.
        device: Device to place the model on. Defaults to ``get_device()``,
            resolved when this function is called rather than at import time.

    Returns:
        The ``EfficientNet`` model on ``device``, in evaluation mode.
    """
    if device is None:
        device = get_device()
    model = EfficientNet()
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # load_state_dict copies values into the existing parameters, which stay
    # wherever EfficientNet() created them, so move the module explicitly.
    model.to(device)
    # Evaluation mode (affects layers such as dropout and batch norm).
    model.eval()
    return model


# Maps a predicted class index to its label. JSON object keys are always
# strings, so lookups must use str(index).
with open('idx_to_class.json', 'r', encoding='utf-8') as f:
    idx_to_class = json.load(f)

# Checkpoint used by predict_image.
MODEL_PATH = 'efficientnet_epoch=18_loss=0.0020_val_f1score=0.8993.pth'

# Preprocessing applied to every input image; presumably mirrors the training
# pipeline (150x150 inputs) -- keep in sync with training.
_TRANSFORM = transforms.Compose([
    transforms.Resize(size=(150, 150)),
    transforms.ToTensor(),
    # Maps ToTensor's [0, 1] range to [-1, 1] per channel.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Loaded models keyed by checkpoint path, so each checkpoint is read from disk
# once instead of on every request.
_model_cache = {}


def _get_model(model_path=MODEL_PATH):
    """Return the model for ``model_path``, loading it on first use only."""
    model = _model_cache.get(model_path)
    if model is None:
        model = load_efficientnet_model(model_path)
        _model_cache[model_path] = model
    return model


def predict_image(array):
    """
    Predict the class of an image.

    Args:
        array: Image pixels in any form accepted by ``PIL.Image.fromarray``
            (Gradio's ``"image"`` input passes a NumPy array by default).

    Returns:
        The predicted class label looked up in ``idx_to_class``.

    Raises:
        gr.Error: If no image was provided.
    """
    if array is None:
        raise gr.Error("Please upload an image.")
    # Force 3 channels: Normalize expects RGB, so grayscale or RGBA uploads
    # would otherwise fail with a shape error.
    input_image = Image.fromarray(array).convert('RGB')
    model = _get_model()
    # Tensor.to is out-of-place, so keep its result; put the batch on the same
    # device as the model's weights.
    device = next(model.parameters()).device
    image = _TRANSFORM(input_image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image)
    # Softmax converts the logits to probabilities; argmax picks the top class
    # of the single-image batch.
    probabilities = F.softmax(output, dim=1)
    predicted = probabilities.argmax(dim=1).item()
    return idx_to_class[str(predicted)]
# Web UI: a single image input feeding predict_image, whose label is shown as
# plain text. The flagging button is turned off.
image_classifier = gr.Interface(
    fn=predict_image,
    inputs="image",
    outputs="text",
    allow_flagging='never',
)

# Serve the app locally without creating a public share link.
image_classifier.launch(share=False)