File size: 1,033 Bytes
7803ddd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import torch
import gradio as gr
import json
from torchvision import transforms
import torch.nn.functional as F
# Locations of the serialized TorchScript classifier and of the JSON file
# holding its label map.
TORCHSCRIPT_PATH = "res/screenclassification-resnet-noisystudent+web350k.torchscript"
LABELS_PATH = "res/class_map_enrico.json"
# Size handed to transforms.Resize below; given a single int, torchvision
# scales the shorter image edge to this length and keeps the aspect ratio.
IMG_SIZE = 128

# Load onto the CPU explicitly: the tensors produced by img_transforms are
# CPU tensors, and without map_location a checkpoint saved from a GPU would
# fail to load on a CPU-only host.
model = torch.jit.load(TORCHSCRIPT_PATH, map_location="cpu")
# This script only runs inference, so make sure train-time behavior
# (dropout, batch-norm statistics updates) is switched off.
model.eval()

# The "label2Idx" entry is used by predict() as a mapping from label name to
# an index into the model output (values are passed through int() there).
with open(LABELS_PATH, "r", encoding="utf-8") as f:
    label2Idx = json.load(f)["label2Idx"]

# Preprocessing applied to every input image: resize, convert to a float
# tensor, then normalize each of the three channels with mean 0.5 / std 0.5.
img_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
def predict(img):
    """Classify an image and return a confidence for every known label.

    Args:
        img: The input image, in whatever form ``img_transforms`` accepts
            (the Gradio interface in this file is configured to supply a
            PIL image).

    Returns:
        A dict with one entry per key of ``label2Idx``; each value is the
        softmax probability (a Python float) found at that label's index
        in the model output.
    """
    # The model is called on a batch of one, so add a leading batch axis.
    img_input = img_transforms(img).unsqueeze(0)
    # Pure inference: skip autograd bookkeeping so no graph is built and
    # kept alive for a result that is never back-propagated.
    with torch.no_grad():
        predictions = F.softmax(model(img_input), dim=-1)[0]
    # Convert to Python floats once instead of one tensor lookup per label.
    probabilities = predictions.tolist()
    return {label: probabilities[int(idx)] for label, idx in label2Idx.items()}
# Sample images passed to the interface as ready-made example inputs.
example_imgs = [
    "examples/example.jpg",
    "examples/example_pair1.jpg",
    "examples/example_pair2.jpg",
]

# Input component hands predict() a PIL image; output component shows the
# three highest-scoring labels from the dict predict() returns.
image_input = gr.Image(type="pil")
label_output = gr.Label(num_top_classes=3)

interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    examples=example_imgs,
)
# Start the web app as soon as the script is executed.
interface.launch()
|