|
|
|
import streamlit as st |
|
from torchvision import transforms |
|
import torch |
|
import PIL |
|
from streamlit_image_select import image_select |
|
import pathlib |
|
import platform |
|
# Name of the host OS as reported by platform.system(), e.g. 'Linux',
# 'Darwin' or 'Windows'.
plt = platform.system()

# HACK: the model checkpoint loaded further down was presumably pickled on
# Windows and references pathlib.WindowsPath, which cannot be instantiated on
# other operating systems. Aliasing it to PosixPath lets torch.load unpickle
# it. Note this patches pathlib for the whole process. It is applied on every
# non-Windows OS (Linux and macOS alike), since all of them share the problem.
if plt != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath
|
|
|
|
|
# Preprocessing applied to every input image before it reaches the model:
# resize to exactly 224x224, convert to a float tensor, then normalise each
# channel. The mean/std values are the standard ImageNet channel statistics
# (presumably what the MobileNet_v3 checkpoint was trained with - confirm).
img_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
# Location of the serialized classifier. Built with pathlib (``Path`` itself
# is not imported in this file) so the separator is correct on every OS; a
# literal backslash path only resolves on Windows.
model_pth = pathlib.Path("model", "MobileNet_v3.pt")

# NOTE(review): torch.load unpickles the whole nn.Module, which can execute
# arbitrary code from the file - only load checkpoints you trust. On
# PyTorch >= 2.6 the default weights_only=True rejects full-module pickles;
# pass weights_only=False there if this file is trusted.
# NOTE(review): Streamlit re-executes the script on every widget interaction,
# so this load repeats each time; consider caching it (st.cache_resource).
model = torch.load(model_pth, map_location=torch.device('cpu'))

# Put layers such as dropout / batch-norm into inference mode.
model.eval()
|
|
|
|
|
# Label for each model output index; get_type picks the entry at the argmax
# of the model output, so this order must match the model's output units.
classes = [
    'Boletus',
    'Lactarius',
    'Russula',
]
|
|
|
|
|
def get_type(image):
    """Classify a mushroom photo and return a human-readable result string.

    Args:
        image: Anything ``PIL.Image.open`` accepts - in this app either the
            sample file path returned by the example picker or the file-like
            object returned by ``st.file_uploader``.

    Returns:
        ``"The Mushroom Type is <label>"``, where ``<label>`` is the entry of
        ``classes`` whose model score is highest.
    """
    # Import the submodule explicitly: ``import PIL`` alone does not
    # guarantee that ``PIL.Image`` has been loaded.
    from PIL import Image

    with Image.open(image) as pil_image:
        # Force three channels: RGBA, greyscale ("L") or palette ("P") images
        # would otherwise clash with the 3-channel Normalize in img_transform.
        tensor = img_transform(pil_image.convert("RGB"))

    # Add a leading batch dimension of size 1.
    batch = torch.unsqueeze(tensor, 0)

    # Inference only: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        logits = model(batch)

    label = classes[logits.argmax(-1)[0].item()]
    return f"The Mushroom Type is {label}"
|
|
|
# Page header rendered at the top of the main area.
st.title(body="Mushroom Classifier")
st.text(body="Upload your Image and find the category of mushroom it belongs to")
|
|
|
|
|
with st.sidebar:
    # User-supplied photo; falsy until a file has been uploaded.
    usr_img = st.file_uploader("Upload the Mushroom Image")

    # NOTE(review): this container appears unused; kept so the sidebar
    # layout stays the same.
    c = st.container()

    # Bundled example images. Forward slashes resolve on Windows and POSIX
    # alike, whereas backslash paths such as samples\Russula.jpg only work on
    # Windows (and "\R", "\B", "\L" are invalid string escapes). Captions and
    # file names come from one list so they cannot drift apart.
    sample_names = ["Russula", "Boletus", "Lactarius"]
    img = image_select(
        label="Select From Examples",
        images=[f"samples/{name}.jpg" for name in sample_names],
        captions=sample_names,
    )
|
|
|
# Preview the image that "Classify" will use: an uploaded file takes
# precedence over the selected example.
st.image(usr_img or img, width=500)
|
|
|
|
|
# Run the classifier only when the button is pressed, on the uploaded image
# if there is one, otherwise on the selected example.
clicked = st.button("Classify")
if clicked:
    text = get_type(usr_img or img)
    st.text(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|