# Spaces: Sleeping
import functools

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
def get_model_name(name, batch_size, learning_rate, epoch):
    """Build the checkpoint file name that encodes a run's hyperparameters.

    Args:
        name: Short identifier of the network (e.g. ``"large"``).
        batch_size: Batch size used for the training run.
        learning_rate: Learning rate used for the training run.
        epoch: Epoch index of the checkpoint.

    Returns:
        A string of the form ``model_<name>_bs<batch_size>_lr<learning_rate>_epoch<epoch>``.

    Example:
        >>> get_model_name("large", 128, 0.001, 29)
        'model_large_bs128_lr0.001_epoch29'
    """
    return f"model_{name}_bs{batch_size}_lr{learning_rate}_epoch{epoch}"
class LargeNet(nn.Module):
    """Small two-stage CNN that maps an RGB image batch to 8 raw class scores.

    The layer sizes fix the expected input to ``(N, 3, 128, 128)``: each
    5x5 convolution trims 4 pixels and each 2x2 pooling halves the size, so
    128 -> 124 -> 62 -> 58 -> 29, which is the 29x29 map ``fc1`` is sized for.

    Attribute names (``conv1``, ``conv2``, ``fc1``, ``fc2``) are the keys of
    the saved ``state_dict`` and must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Identifier that load_model() feeds into get_model_name() to build
        # the checkpoint file name.
        self.name = "large"
        self.conv1 = nn.Conv2d(3, 5, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(5, 10, 5)
        self.fc1 = nn.Linear(10 * 29 * 29, 32)
        self.fc2 = nn.Linear(32, 8)

    def forward(self, x):
        """Return unnormalised class scores of shape ``(N, 8)``.

        Args:
            x: Float tensor of shape ``(N, 3, 128, 128)``.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not reduce to the
                29x29 feature map ``fc1`` expects.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # view(-1, 10 * 29 * 29), this can never silently fold a wrong-sized
        # input into a different batch size; fc1 reports the mismatch instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # The former x.squeeze(1) was a no-op on an (N, 8) tensor (left over
        # from a single-output head), so the scores are returned directly.
        return self.fc2(x)
# Preprocessing applied to every image before it is fed to the network.
transform = transforms.Compose(
    [
        # 128x128 is the input size the LargeNet layer dimensions are built for.
        transforms.Resize(size=(128, 128)),
        # Convert the image to a tensor.
        transforms.ToTensor(),
        # (x - 0.5) / 0.5: normalize to [-1, 1].
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
def load_model():
    """Instantiate LargeNet and load its trained weights for inference.

    The checkpoint path is built by get_model_name() from the network name
    and the fixed hyperparameters batch_size=128, learning_rate=0.001,
    epoch=29, and is resolved relative to the current working directory.

    Returns:
        A LargeNet instance with the weights loaded, switched to eval mode.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If the checkpoint's keys/shapes do not match LargeNet.
    """
    net = LargeNet()
    model_path = get_model_name(net.name, batch_size=128, learning_rate=0.001, epoch=29)
    # map_location pins every tensor to the CPU so a checkpoint saved from a
    # GPU run still loads on a host without CUDA.
    # NOTE(review): torch.load unpickles the file; only ship checkpoints you
    # trust (consider weights_only=True if the installed torch supports it).
    state = torch.load(model_path, map_location=torch.device("cpu"))
    net.load_state_dict(state)
    net.eval()
    return net
# Display names indexed by the class id the model predicts.
# NOTE(review): LargeNet.fc2 produces 8 scores but only 7 names are listed
# here; confirm which class is missing. Spelling is kept as-is because these
# strings are shown to the user at runtime.
class_names = ["Gasoline_Can", "Pebbels", "pliers", "Screw_Driver", "Toolbox", "Wrench", "other"]


@functools.lru_cache(maxsize=1)
def _get_model():
    """Load the network on first use and reuse that instance afterwards."""
    return load_model()


def _label_for_index(index):
    """Return the class name for ``index``, or a placeholder if it has none."""
    if 0 <= index < len(class_names):
        return class_names[index]
    return f"unknown_class_{index}"


def predict(image):
    """Classify a single image and return the predicted class name.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None`` when
            the user submits without uploading anything.

    Returns:
        The entry of ``class_names`` for the highest-scoring class, or
        ``"unknown_class_<i>"`` when the model picks an index that has no
        name in ``class_names``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # Reuse the cached model instead of re-reading the checkpoint from disk
    # on every request.
    model = _get_model()

    # Force 3 channels (the network's first convolution takes 3-channel
    # input), then add the batch dimension.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        output = model(batch)
    pred = output.argmax(dim=1)
    return _label_for_index(pred.item())
# Gradio UI: one image upload wired to predict(), result rendered as a label.
interface = gr.Interface(
    title="Mechanical Tools Classifier",
    description="Upload an image to classify it as one of the mechanical tools.",
    fn=predict,
    inputs=gr.Image(type="pil"),  # type="pil": pass the upload to predict() as a PIL image
    outputs="label",
)

# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()