Spaces:
Runtime error
Runtime error
import os | |
from io import BytesIO | |
import matplotlib.pyplot as plt | |
import streamlit as st | |
import torch | |
import wandb | |
from PIL import Image | |
from torchvision.models.efficientnet import EfficientNet_B0_Weights | |
from torchvision.models.efficientnet import efficientnet_b0 | |
# Local directory used as the download root for the model file.
MODELS_PATH = "./models/"
# Name of the file fetched from the W&B run and then read with torch.load.
MODEL_FILE = "EFFB0.pb"
# W&B path prefix; the run id below is appended to it to address the run.
WANDB_PATH = "gmbreemer/food-project-captum/"
# Id of the W&B run whose logged model file is used for inference.
WANDB_RUN = "207pa091"
def pred_and_plot_image(
    model: torch.nn.Module,
    image_file: BytesIO,
    class_names,
    transform,
):
    """Run inference on one image and draw it with the predicted label.

    Args:
        model: Classifier that is called on a single-image batch.
        image_file: File-like object holding the encoded image.
        class_names: Sequence of labels, indexed by the predicted class id.
        transform: Callable turning a PIL image into the model's input tensor.

    Returns:
        The matplotlib figure (number 1) showing the image, titled with the
        predicted class name and its softmax probability.
    """
    # Force three channels: grayscale or CMYK JPEGs would otherwise produce a
    # tensor with the wrong channel count for a 3-channel preprocessing
    # pipeline such as the EfficientNet transforms used by this app.
    img = Image.open(image_file).convert("RGB")
    batch = transform(img).unsqueeze(dim=0)

    model.eval()
    with torch.inference_mode():
        logits = model(batch)

    probs = torch.softmax(logits, dim=1)
    # Plain int so that any sequence type can be indexed, not only those
    # accepting a one-element tensor.
    class_idx = int(torch.argmax(probs, dim=1).item())

    # Reuse and clear figure 1 instead of opening a new figure on every call:
    # pyplot state outlives a single script run, so fresh figures would pile
    # up and "figure 1" would keep showing the first image ever plotted.
    fig = plt.figure(num=1, clear=True)
    plt.imshow(img)
    plt.title(
        f"Prediction: {class_names[class_idx]} | "
        f"Probability: {probs.max().item():.3f}"
    )
    plt.axis(False)
    return fig
def load_saved_model(model_local_path: str):
    """Rebuild the EfficientNet-B0 classifier from a local checkpoint file.

    The checkpoint is read as a mapping and must provide the keys
    ``"class_names"`` and ``"model_state_dict"``.

    Args:
        model_local_path: Path of the checkpoint file on disk.

    Returns:
        A tuple ``(model, transforms, class_names)``: the network with the
        stored weights loaded, the default EfficientNet-B0 preprocessing
        transforms, and the class names stored in the checkpoint.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects (depending on
    # the installed PyTorch's ``weights_only`` default) - only feed it
    # checkpoints from a trusted source.
    checkpoint = torch.load(model_local_path, map_location="cpu")

    net = efficientnet_b0()

    # Replace the stock head so its output size matches the stored classes;
    # this has to happen before the state dict is loaded.
    labels = checkpoint["class_names"]
    head_layers = [
        torch.nn.Dropout(p=0.2, inplace=True),
        torch.nn.Linear(in_features=1280, out_features=len(labels), bias=True),
    ]
    net.classifier = torch.nn.Sequential(*head_layers)

    preprocess = EfficientNet_B0_Weights.DEFAULT.transforms()
    net.load_state_dict(checkpoint["model_state_dict"])
    return net, preprocess, checkpoint["class_names"]
def load_model_from_wandb(
    wandb_path: str, wandb_run: str, models_path: str, file_name: str
) -> str:
    """Download a model file logged by a W&B run and return its local path.

    Args:
        wandb_path: W&B path prefix of the run; a trailing slash is optional.
        wandb_run: Id of the run that holds the file.
        models_path: Local directory used as the download root.
        file_name: Name of the file inside the run.

    Returns:
        The path of the downloaded file, i.e. ``models_path`` joined with
        ``file_name``.
    """
    # The API key is read from the environment; when it is unset, ``None`` is
    # passed on and wandb applies its own credential lookup.
    wandb.login(key=os.getenv("WANDB_API_KEY"))
    api = wandb.Api()

    # Normalise the separator so the prefix works with or without a trailing
    # slash instead of silently producing a malformed run path.
    run_path = f"{wandb_path.rstrip('/')}/{wandb_run}"
    run_file = api.run(run_path).file(name=file_name)

    # download() hands back an open file object; close it right away because
    # only the path is needed and the handle would otherwise leak.
    handle = run_file.download(root=models_path, replace=True)
    handle.close()

    return os.path.join(models_path, file_name)
# Page header. Written with an explicit st.write call rather than a bare
# string so that it does not depend on Streamlit's "magic" setting.
st.write(
    """
# Is it a pizza, steak or sushi?
### Let's find out:
"""
)

uploaded_file = st.file_uploader(
    label="Select an image",
    type=["jpg", "jpeg"],
    help="Select an image to perform inference on",
    accept_multiple_files=False,
)

# The button is disabled until a file is present; the extra None check keeps
# the branch safe regardless of the button state.
if (
    st.button("Tell me what it is", disabled=uploaded_file is None)
    and uploaded_file is not None
):
    model_local_path = load_model_from_wandb(
        wandb_path=WANDB_PATH,
        wandb_run=WANDB_RUN,
        models_path=MODELS_PATH,
        file_name=MODEL_FILE,
    )
    st.write("Using the model trained by W&B run ", WANDB_RUN, "from: ", WANDB_PATH)

    model_inf, model_transforms, class_names = load_saved_model(
        model_local_path=model_local_path
    )
    pred_and_plot_image(
        model=model_inf,
        image_file=uploaded_file,
        class_names=class_names,
        transform=model_transforms,
    )

    st.write(
        """
### The image shows a:
"""
    )

    # Render the figure that was just drawn. Looking it up by the fixed
    # number 1 can show a stale plot, because pyplot figures survive script
    # reruns. Close it afterwards so figures do not accumulate.
    fig = plt.gcf()
    st.pyplot(fig)
    plt.close(fig)