import json

import torch
import torch.nn.functional as F
import gradio as gr
from torchvision import transforms
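
# Bundled model checkpoint, label map, and the input size the model expects.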
TORCHSCRIPT_PATH = "res/screenclassification-resnet-noisystudent+web350k.torchscript"
LABELS_PATH = "res/class_map_enrico.json"
IMG_SIZE = 128
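
# Load the exported TorchScript screen classifier.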
model = torch.jit.load(TORCHSCRIPT_PATH)
model.eval()  # ensure inference mode (a no-op if the model was already exported in eval mode)
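
# label2Idx maps each class label name to its index in the model output.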
with open(LABELS_PATH, "r") as f:
    label2Idx = json.load(f)["label2Idx"]
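
# Preprocessing: resize the shorter side to IMG_SIZE, convert to a tensor, and normalize to [-1, 1].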
img_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
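
# Classify one screenshot and return a {label: confidence} dict for Gradio's Label output.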
def predict(img):
    img_input = img_transforms(img).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():  # no gradients needed at inference time
        predictions = F.softmax(model(img_input), dim=-1)[0]
    confidences = {}
    for label in label2Idx:
        confidences[label] = float(predictions[int(label2Idx[label])])
    return confidences
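
# Example screenshots bundled with the demo.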
example_imgs = [
    "res/example.jpg",
    "res/screenlane-snapchat-profile.jpg",
    "res/screenlane-snapchat-settings.jpg",
    "res/example_pair1.jpg",
    "res/example_pair2.jpg"
]
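
# Gradio UI: a PIL image in, the top-5 class confidences out.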
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5),
    examples=example_imgs,
)

interface.launch()
|