# Copied from the Hugging Face file-viewer page (author: danyalmalik).
# Commit c7d1130: "added new model, moved net to separate file, updated app to use transformed data"
# Page controls: raw | history | blame -- file size 1.55 kB
import logging
import os

import gradio as gr
import huggingface_hub
import numpy as np
import torch
from torchvision import transforms

from net import Net
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = Net()
net.to(device)

# The token is optional (only needed if the weights repo is private), so a
# missing environment variable must not crash the app at import time; any
# authentication problem is reported by the download call instead.
HF_Token = os.environ.get('HF_Token')

# NOTE(review): `cached_download` is deprecated and removed in recent
# huggingface_hub releases; migrate to `huggingface_hub.hf_hub_download`
# when the pinned dependency versions are upgraded.
model = huggingface_hub.cached_download(huggingface_hub.hf_hub_url(
    'danyalmalik/sceneryclassifier', '1655684183.7481008_Acc0.87_modelweights.pth'), use_auth_token=HF_Token)
# torch.load unpickles the file; this is acceptable only because the weights
# come from the author's own repository.
net.load_state_dict(torch.load(model, map_location=device))

# The network is used for inference only: put any dropout / batch-norm layers
# into evaluation mode so predictions are deterministic.
net.eval()
# Per-channel (R, G, B) normalisation statistics applied after ToTensor().
mean = np.full(3, 0.5)
std = np.full(3, 0.25)

# Preprocessing pipeline: resize to the network's input size, convert the
# PIL image to a float tensor, then normalise each channel.
data_transforms = transforms.Compose(
    [
        transforms.Resize((150, 150)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

# Class names, in the order of the network's output units.
labels = [
    'Buildings',
    'Forest',
    'Glacier',
    'Mountain',
    'Sea',
    'Street',
]
title = "Scenery Classifier"
def examples(directory='examples'):
    """Return the example image paths shown below the Gradio input.

    Args:
        directory: Folder to scan for example files. It defaults to
            ``'examples'``, resolved relative to the working directory.

    Returns:
        Paths of the form ``os.path.join(directory, name)``, one per regular,
        non-hidden file in ``directory``, sorted by file name.
    """
    # Sort so the example gallery order is deterministic. os.listdir returns
    # entries in arbitrary order.
    names = sorted(os.listdir(directory))
    paths = [os.path.join(directory, name) for name in names
             if not name.startswith('.')]
    # Sub-directories (and hidden files such as .DS_Store) are not images
    # and would break the examples gallery.
    return [path for path in paths if os.path.isfile(path)]
def predict(img):
    """Classify a scenery image and return one score per label for Gradio.

    Args:
        img: PIL image supplied by the ``gr.Image(type='pil')`` input, or
            ``None`` when the input has been cleared.

    Returns:
        dict mapping each name in ``labels`` to a score. For an empty input,
        or if preprocessing or inference fails, every score is 0. The
        failure is logged rather than raised, so the UI keeps responding.
    """
    if img is None:
        return {label: 0 for label in labels}
    try:
        # Drop alpha/palette channels: Normalize is configured for exactly 3.
        tensor = data_transforms(img.convert('RGB'))
        # Add a leading batch dimension so output[0] below selects the scores
        # of this single image.
        batch = tensor.unsqueeze(0).to(device)
        with torch.no_grad():
            output = net(batch)
        # NOTE(review): values are passed to the Label output as-is; if Net
        # returns logits rather than probabilities, a softmax may be
        # wanted -- confirm against net.py.
        scores = [output[0][i].item() for i in range(len(labels))]
    except Exception:
        logging.exception("Prediction failed")
        scores = [0] * len(labels)
    return dict(zip(labels, scores))
# Build and start the web UI.
# NOTE(review): `gr.Image(shape=...)` is Gradio 3.x API; the `shape`
# argument was removed in Gradio 4 -- confirm the pinned Gradio version.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(shape=(150, 150), type='pil'),
    outputs='label',
    examples=examples(),
    title=title,  # show the app name in the page header
)
demo.launch()