import os
from io import BytesIO

import matplotlib.pyplot as plt
import streamlit as st
import torch
import wandb
from PIL import Image
from torchvision.models.efficientnet import EfficientNet_B0_Weights
from torchvision.models.efficientnet import efficientnet_b0

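# Where the trained model lives on W&B (project path + run ID) and where to cache the downloaded file locally.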
MODELS_PATH = "./models/"
MODEL_FILE = "EFFB0.pb"
WANDB_PATH = "gmbreemer/food-project-captum/"
WANDB_RUN = "207pa091"


def pred_and_plot_image(
    model: torch.nn.Module,
    image_file: BytesIO,
    class_names,
    transform,
):
    """Run inference and show the results."""
    img = Image.open(image_file)
    img_transformed = transform(img).unsqueeze(dim=0)

    model.eval()
    with torch.inference_mode():
        preds = model(img_transformed)

    probs = torch.softmax(preds, dim=1)
    class_label = torch.argmax(probs, dim=1)

    # Plot the original image with the predicted class and its probability as the title.
    fig = plt.figure()
    plt.imshow(img)
    plt.title(
        f"Prediction: {class_names[class_label]} | Probability: {probs.max():.3f}"
    )
    plt.axis(False)

    return fig


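# Cache the loaded model so Streamlit does not reload it from disk on every rerun.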
@st.cache
def load_saved_model(model_local_path: str):
    """Load the downloaded model."""
    saved_model = torch.load(model_local_path, map_location="cpu")

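    # Recreate the EfficientNet-B0 architecture with a classifier head sized to the
    # checkpoint's class list, then load the trained weights into it.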
    model = efficientnet_b0()
    model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=0.2, inplace=True),
        torch.nn.Linear(
            in_features=1280, out_features=len(saved_model["class_names"]), bias=True
        ),
    )

    model_transforms = EfficientNet_B0_Weights.DEFAULT.transforms()

    model.load_state_dict(saved_model["model_state_dict"])

    return model, model_transforms, saved_model["class_names"]


@st.cache
def load_model_from_wandb(
    wandb_path: str, wandb_run: str, models_path: str, file_name: str
) -> str:
    """Download a trained model created by a W&B run and store it locally"""
    wandb.login(key=os.getenv("WANDB_API_KEY"))
    api = wandb.Api()
    file = api.run(wandb_path + wandb_run).file(name=file_name)
    file.download(root=models_path, replace=True)

    return models_path + file_name


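# Streamlit "magic": bare string literals in the main script are rendered as Markdown.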
"""
# Is it a pizza, steak or sushi?

### Let's find out:
"""

uploaded_file = st.file_uploader(
    label="Select an image",
    type=["jpg", "jpeg"],
    help="Select an image to perform inference on",
    accept_multiple_files=False,
)


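# Run the full pipeline only once a file has been uploaded and the button is pressed.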
if (
    st.button("Tell me what it is", disabled=uploaded_file is None)
    and uploaded_file is not None
):
    model_local_path = load_model_from_wandb(
        wandb_path=WANDB_PATH,
        wandb_run=WANDB_RUN,
        models_path=MODELS_PATH,
        file_name=MODEL_FILE,
    )
    "Using the model trained by W&B run ", WANDB_RUN, "from: ", WANDB_PATH

    model_inf, model_transforms, class_names = load_saved_model(
        model_local_path=model_local_path
    )

    fig = pred_and_plot_image(
        model=model_inf,
        image_file=uploaded_file,
        class_names=class_names,
        transform=model_transforms,
    )
    """
    ### The image shows a:
    """
    st.pyplot(fig)