# Last commit by gmbreemer: 7b3a932 "Revert bump streamlit"
import os
from io import BytesIO
import matplotlib.pyplot as plt
import streamlit as st
import torch
import wandb
from PIL import Image
from torchvision.models.efficientnet import EfficientNet_B0_Weights
from torchvision.models.efficientnet import efficientnet_b0
# Local directory the checkpoint is downloaded into (used as the download root).
MODELS_PATH: str = "./models/"
# Name of the checkpoint file stored with the W&B run.
MODEL_FILE: str = "EFFB0.pb"
# W&B project path; the run id below is appended to it to address the run.
WANDB_PATH: str = "gmbreemer/food-project-captum/"
# Id of the W&B run that produced the checkpoint.
WANDB_RUN: str = "207pa091"
def pred_and_plot_image(
    model: torch.nn.Module,
    image_file: BytesIO,
    class_names,
    transform,
):
    """Run inference on one image and plot it titled with the prediction.

    Args:
        model: Classifier whose output logits are indexed by position in
            ``class_names``.
        image_file: Image file object (anything ``PIL.Image.open`` accepts).
        class_names: Sequence of labels, indexed by predicted class.
        transform: Preprocessing callable mapping a PIL image to a tensor.

    Returns:
        The matplotlib figure that was drawn on. It is also left as pyplot's
        current figure, so callers that render ``plt.gcf()`` keep working.
    """
    # Force 3-channel RGB: e.g. grayscale or CMYK JPEGs would otherwise yield a
    # tensor whose channel count does not match the model's RGB input.
    img = Image.open(image_file).convert("RGB")
    img_transformed = transform(img).unsqueeze(dim=0)  # add a batch dimension

    model.eval()
    with torch.inference_mode():
        logits = model(img_transformed)
        probs = torch.softmax(logits, dim=1)
    # .item() turns the single-element tensors into plain Python numbers, so
    # list indexing and float formatting do not depend on tensor dunders.
    class_idx = torch.argmax(probs, dim=1).item()
    max_prob = probs.max().item()

    fig = plt.figure()
    plt.imshow(img)
    plt.title(f"Prediction: {class_names[class_idx]} | Probability: {max_prob:.3f}")
    plt.axis(False)
    return fig
@st.cache
def load_saved_model(model_local_path: str):
    """Rebuild the EfficientNet-B0 classifier from a saved checkpoint.

    Args:
        model_local_path: Path to a file saved with ``torch.save`` holding a
            dict with keys ``"model_state_dict"`` and ``"class_names"``.

    Returns:
        ``(model, transforms, class_names)``: the model with the checkpoint's
        weights loaded and switched to eval mode, the default EfficientNet-B0
        preprocessing transforms, and the checkpoint's class names.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from a trusted source (in this app, the file
    # downloaded from the configured W&B run).
    saved_model = torch.load(model_local_path, map_location="cpu")
    class_names = saved_model["class_names"]

    model = efficientnet_b0()
    # Swap the stock head for one sized to this checkpoint's classes, reading
    # the backbone's feature width from the existing Linear layer instead of
    # hard-coding it.
    in_features = model.classifier[-1].in_features
    model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=0.2, inplace=True),
        torch.nn.Linear(
            in_features=in_features, out_features=len(class_names), bias=True
        ),
    )
    model.load_state_dict(saved_model["model_state_dict"])
    # Hand back (and cache) an inference-ready model: dropout disabled.
    model.eval()

    model_transforms = EfficientNet_B0_Weights.DEFAULT.transforms()
    return model, model_transforms, class_names
@st.cache
def load_model_from_wandb(
    wandb_path: str, wandb_run: str, models_path: str, file_name: str
) -> str:
    """Download a model file stored with a W&B run and return its local path.

    Args:
        wandb_path: Project path prefix of the run (e.g. ``"entity/project/"``);
            the trailing ``/`` is optional.
        wandb_run: Run id appended to ``wandb_path`` to address the run.
        models_path: Local directory to download into.
        file_name: Name of the file within the run.

    Returns:
        Local path of the downloaded file (``models_path`` joined with
        ``file_name``).
    """
    # Passes the WANDB_API_KEY environment variable (None when unset).
    wandb.login(key=os.getenv("WANDB_API_KEY"))
    api = wandb.Api()
    # Join with exactly one "/" whether or not wandb_path ends with a slash.
    run_path = f"{wandb_path.rstrip('/')}/{wandb_run}"
    run_file = api.run(run_path).file(name=file_name)
    handle = run_file.download(root=models_path, replace=True)
    # NOTE(review): some wandb versions return an open handle to the downloaded
    # file; close it, since only the path is needed.
    if hasattr(handle, "close"):
        handle.close()
    # os.path.join inserts a separator if models_path lacks a trailing "/".
    return os.path.join(models_path, file_name)
"""
# Is it a pizza, steak or sushi?
### Let's find out:
"""
uploaded_file = st.file_uploader(
label="Select an image",
type=["jpg", "jpeg"],
help="Select an image to perform inference on",
accept_multiple_files=False,
)
if (
st.button("Tell me what it is", disabled=uploaded_file is None)
and uploaded_file is not None
):
model_local_path = load_model_from_wandb(
wandb_path=WANDB_PATH,
wandb_run=WANDB_RUN,
models_path=MODELS_PATH,
file_name=MODEL_FILE,
)
"Using the model trained by W&B run ", WANDB_RUN, "from: ", WANDB_PATH
model_inf, model_transforms, class_names = load_saved_model(
model_local_path=model_local_path
)
pred_and_plot_image(
model=model_inf,
image_file=uploaded_file,
class_names=class_names,
transform=model_transforms,
)
"""
### The image shows a:
"""
st.pyplot(plt.figure(1))