File size: 1,018 Bytes
7803ddd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14eb6a5
 
 
7803ddd
 
e97b182
7803ddd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import gradio as gr
import json
from torchvision import transforms
import torch.nn.functional as F

# Paths to the exported TorchScript classifier and its label map, resolved
# relative to the working directory the app is launched from.
TORCHSCRIPT_PATH = "res/screenclassification-resnet-noisystudent+web350k.torchscript"
LABELS_PATH = "res/class_map_enrico.json"
# Passed to transforms.Resize as a single int, so it sets the length of the
# image's shorter edge (aspect ratio is preserved) rather than a square size.
IMG_SIZE = 128

# predict() always feeds CPU tensors, so load the weights onto the CPU even if
# the module was serialized from a GPU.
model = torch.jit.load(TORCHSCRIPT_PATH, map_location="cpu")
# TorchScript modules keep the training flag they were saved with; force
# inference behavior for layers such as dropout / batch norm.
model.eval()

# label2Idx maps each class name to its output index in the model's logits.
# predict() casts the index with int(), so the JSON may store indices as
# strings or ints -- NOTE(review): confirm against the file.
with open(LABELS_PATH, "r", encoding="utf-8") as f:
    label2Idx = json.load(f)["label2Idx"]

# Preprocessing applied to each input image before inference: resize the
# shorter edge to IMG_SIZE, convert to a float tensor (values in [0, 1] for
# 8-bit images), then map each of the three channels to [-1, 1] via
# (x - 0.5) / 0.5.
img_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
	
def predict(img):
    """Classify a screenshot and return per-label confidences.

    Args:
        img: PIL image supplied by the Gradio ``gr.Image(type="pil")`` input,
            or ``None`` when the user submits without an image.

    Returns:
        A dict mapping every label in ``label2Idx`` to its softmax
        probability (a Python float), suitable for ``gr.Label``; ``None``
        when no image was provided.
    """
    if img is None:
        return None
    # Uploaded PNGs are often RGBA (or L / P mode); the Normalize step expects
    # exactly three channels, so coerce to RGB before transforming.
    img_input = img_transforms(img.convert("RGB")).unsqueeze(0)  # add batch dim
    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        logits = model(img_input)
    predictions = F.softmax(logits, dim=-1)[0]
    return {
        label: float(predictions[int(idx)])
        for label, idx in label2Idx.items()
    }
	
# Sample screenshots handed to gr.Interface as clickable example inputs.
example_imgs = [
    "res/example.jpg",
    "res/example_pair1.jpg",
    "res/example_pair2.jpg",
]

# UI wiring: a PIL image goes into predict(), and the label output shows the
# five highest-confidence classes from the returned dict.
image_input = gr.Image(type="pil")
label_output = gr.Label(num_top_classes=5)
interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    examples=example_imgs,
)

interface.launch()