Spaces:
Runtime error
Runtime error
File size: 2,940 Bytes
284f370 09762c2 b1af9b5 09762c2 b1af9b5 09762c2 88b0bf6 09762c2 b1af9b5 09762c2 b1af9b5 09762c2 b1af9b5 09762c2 7b3a932 09762c2 b1af9b5 09762c2 b1af9b5 09762c2 b1af9b5 09762c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import os
from io import BytesIO
import matplotlib.pyplot as plt
import streamlit as st
import torch
import wandb
from PIL import Image
from torchvision.models.efficientnet import EfficientNet_B0_Weights
from torchvision.models.efficientnet import efficientnet_b0
# Local directory the checkpoint is downloaded into; the trailing slash is
# kept because callers may build paths from it by plain concatenation.
MODELS_PATH: str = "./models/"
# File name of the checkpoint as stored in the W&B run.
MODEL_FILE: str = "EFFB0.pb"
# W&B run-path prefix (presumably "entity/project/") plus the id of the run
# whose saved model the app serves.
WANDB_PATH: str = "gmbreemer/food-project-captum/"
WANDB_RUN: str = "207pa091"
def pred_and_plot_image(
    model: torch.nn.Module,
    image_file: BytesIO,
    class_names,
    transform,
):
    """Run inference on an image and plot it with its predicted label.

    The image is drawn on a newly created matplotlib figure, which is left as
    pyplot's current figure so callers can render it (e.g. via ``plt.gcf()``).

    Args:
        model: Classifier; presumably returns one logit per entry of
            ``class_names`` along dim 1.
        image_file: File-like object holding the encoded image (for example a
            Streamlit upload).
        class_names: Indexable sequence mapping class index -> label.
        transform: Preprocessing callable applied to the PIL image; its output
            gets a leading batch dimension added before inference.

    Returns:
        The matplotlib figure containing the plot.
    """
    # JPEGs can be greyscale or CMYK; the preprocessing normalises with
    # 3-channel statistics, so force RGB (this also stops imshow from showing
    # single-channel images with a false-colour colormap).
    img = Image.open(image_file).convert("RGB")
    batch = transform(img).unsqueeze(dim=0)  # model expects a batch

    model.eval()
    with torch.inference_mode():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)
    # Batch of one: take the top class and its probability as Python scalars.
    top_prob, top_idx = probs.max(dim=1)
    label = class_names[top_idx.item()]

    fig = plt.figure()
    plt.imshow(img)
    plt.title(f"Prediction: {label} | Probability: {top_prob.item():.3f}")
    plt.axis(False)
    return fig
@st.cache_resource
def load_saved_model(model_local_path: str):
    """Rebuild the fine-tuned EfficientNet-B0 from a downloaded checkpoint.

    Cached with ``st.cache_resource`` (Streamlit >= 1.18), so the model is
    built once and shared across reruns. The deprecated ``st.cache`` hashed the
    returned ``nn.Module`` on every call to detect mutation.

    Args:
        model_local_path: Path to a ``torch.save``'d dict containing the keys
            ``"model_state_dict"`` and ``"class_names"``.

    Returns:
        ``(model, transforms, class_names)``: the model in eval mode, the
        preprocessing pipeline of the default EfficientNet-B0 weights, and the
        class names stored in the checkpoint.
    """
    # NOTE(review): torch.load unpickles arbitrary objects. The file comes from
    # our own W&B run, but consider weights_only=True (torch >= 1.13) to harden it.
    checkpoint = torch.load(model_local_path, map_location="cpu")
    class_names = checkpoint["class_names"]

    model = efficientnet_b0()
    # Take the feature width from the stock head instead of hard-coding 1280.
    in_features = model.classifier[-1].in_features
    model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=0.2, inplace=True),
        torch.nn.Linear(
            in_features=in_features, out_features=len(class_names), bias=True
        ),
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    model_transforms = EfficientNet_B0_Weights.DEFAULT.transforms()
    return model, model_transforms, class_names
@st.cache_data
def load_model_from_wandb(
    wandb_path: str, wandb_run: str, models_path: str, file_name: str
) -> str:
    """Download a trained model file created by a W&B run and store it locally.

    Cached with ``st.cache_data`` (replacing the deprecated ``st.cache``), so
    the download happens once per distinct set of arguments.

    Args:
        wandb_path: Run-path prefix, with or without a trailing slash.
        wandb_run: Run id appended to ``wandb_path``.
        models_path: Local directory to download into.
        file_name: Name of the file saved by the run.

    Returns:
        Local path of the downloaded file.
    """
    wandb.login(key=os.getenv("WANDB_API_KEY"))
    api = wandb.Api()
    # Join explicitly so a prefix without a trailing "/" still forms a valid path.
    run_path = wandb_path.rstrip("/") + "/" + wandb_run
    run_file = api.run(run_path).file(name=file_name)
    downloaded = run_file.download(root=models_path, replace=True)
    # NOTE(review): some wandb versions return an open handle to the local copy;
    # only the path is needed, so close it rather than leak the descriptor.
    if hasattr(downloaded, "close"):
        downloaded.close()
    return os.path.join(models_path, file_name)
"""
# Is it a pizza, steak or sushi?
### Let's find out:
"""
uploaded_file = st.file_uploader(
label="Select an image",
type=["jpg", "jpeg"],
help="Select an image to perform inference on",
accept_multiple_files=False,
)
if (
st.button("Tell me what it is", disabled=uploaded_file is None)
and uploaded_file is not None
):
model_local_path = load_model_from_wandb(
wandb_path=WANDB_PATH,
wandb_run=WANDB_RUN,
models_path=MODELS_PATH,
file_name=MODEL_FILE,
)
"Using the model trained by W&B run ", WANDB_RUN, "from: ", WANDB_PATH
model_inf, model_transforms, class_names = load_saved_model(
model_local_path=model_local_path
)
pred_and_plot_image(
model=model_inf,
image_file=uploaded_file,
class_names=class_names,
transform=model_transforms,
)
"""
### The image shows a:
"""
st.pyplot(plt.figure(1))
|