# Hugging Face Space file: app.py (author: Jsonwu)
# Last commit 99e7319 — "Update app.py" (1.11 kB)
import torch
import gradio as gr
import json
from torchvision import transforms
import torch.nn.functional as F
# Paths (relative to the working directory) to the exported classifier and
# to the JSON file holding its label -> output-index mapping.
TORCHSCRIPT_PATH = "res/screenclassification-resnet-noisystudent+web350k.torchscript"
LABELS_PATH = "res/class_map_enrico.json"
# Target size passed to transforms.Resize; an int resizes the shorter edge
# to this value and keeps the aspect ratio.
IMG_SIZE = 128

# predict() always builds CPU input tensors, so load the weights onto the CPU
# regardless of the device they were serialized from. eval() guards against a
# module exported in training mode (dropout / batch-norm behaviour).
model = torch.jit.load(TORCHSCRIPT_PATH, map_location="cpu")
model.eval()

# JSON is UTF-8 by spec; do not depend on the platform's default encoding.
with open(LABELS_PATH, "r", encoding="utf-8") as labels_file:
    # Maps label name -> index into the model's output vector (values may be
    # stored as strings; predict() casts them with int()).
    label2Idx = json.load(labels_file)["label2Idx"]

# Preprocessing: resize, convert to a [0, 1] float tensor, then map each of
# the three channels to roughly [-1, 1] with mean=0.5, std=0.5.
img_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
def predict(img):
    """Classify a screenshot and return a confidence for every known label.

    Args:
        img: A PIL image (the interface declares ``gr.Image(type="pil")``),
            or ``None`` when the form is submitted without an image.

    Returns:
        A dict mapping each label name in ``label2Idx`` to its softmax
        probability as a Python float. For ``None`` input an empty dict is
        returned instead of failing inside the transform pipeline.
    """
    if img is None:
        return {}
    # Normalize in img_transforms has exactly three channel means/stds, so
    # RGBA (e.g. PNG with alpha) or grayscale images must become RGB first.
    if img.mode != "RGB":
        img = img.convert("RGB")
    img_input = img_transforms(img).unsqueeze(0)  # add a batch dim of size 1
    # Pure inference: don't record an autograd graph for every request.
    with torch.no_grad():
        logits = model(img_input)
    # Softmax over the last dim, then take the single batch row; assumes the
    # model returns (1, num_classes) logits — TODO confirm.
    probs = F.softmax(logits, dim=-1)[0].tolist()
    return {label: float(probs[int(idx)]) for label, idx in label2Idx.items()}
# Bundled sample screenshots, offered as one-click examples under the demo.
example_imgs = [
    "res/example.jpg",
    "res/screenlane-snapchat-profile.jpg",
    "res/screenlane-snapchat-settings.jpg",
    "res/example_pair1.jpg",
    "res/example_pair2.jpg",
]

# Web UI: a single PIL image goes in, and the label widget shows the five
# highest-confidence classes from the dict that predict() returns.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    examples=example_imgs,
)
interface.launch()