"""Streamlit app that classifies a mushroom photo as Boletus, Lactarius or Russula."""

from pathlib import Path

import streamlit as st
import torch
from PIL import Image
from streamlit_image_select import image_select
from torchvision import transforms

# Preprocessing applied to every input image: resize to 224x224, convert to a
# tensor, then normalize each channel with the mean/std values below.
img_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Build the path with pathlib so it works on every OS. A backslash literal
# like "model\MobileNet_v3.pt" only resolves on Windows and contains an
# invalid escape sequence.
model_pth = Path("model") / "MobileNet_v3.pt"

# Mushroom classes, indexed by the model's output position.
# NOTE(review): assumes this order matches the label order used in training.
classes = ["Boletus", "Lactarius", "Russula"]


@st.cache_resource
def load_model(path):
    """Load the serialized model at *path* onto the CPU and switch it to eval mode.

    Streamlit re-executes the whole script on every widget interaction.
    Caching the loaded model keeps it from being re-read from disk each time.
    """
    # NOTE(review): torch.load unpickles a full model object, which can run
    # arbitrary code, so only load trusted files. On torch >= 2.6 the default
    # weights_only=True rejects full-model pickles. Pass weights_only=False
    # there, or switch to saving/loading a state_dict.
    loaded = torch.load(path, map_location=torch.device("cpu"))
    loaded.eval()
    return loaded


model = load_model(model_pth)


def get_type(image):
    """Classify *image* and return a sentence naming the predicted mushroom type.

    Args:
        image: A file path or file-like object (e.g. a Streamlit
            ``UploadedFile``) accepted by ``PIL.Image.open``.

    Returns:
        The string ``"The Mushroom Type is <class name>"``.
    """
    # The context manager closes files PIL opened itself (path inputs).
    # convert("RGB") drops an alpha channel and expands grayscale, so the
    # 3-channel Normalize in img_transform always applies.
    with Image.open(image) as pil_img:
        tensor = img_transform(pil_img.convert("RGB"))
    batch = torch.unsqueeze(tensor, 0)  # add a leading batch dimension of 1
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        out = model(batch)
    ans = classes[out.argmax(-1)[0].item()]
    return "The Mushroom Type is " + ans


st.title("Mushroom Classifier")
st.text("Upload your Image and find the category of mushroom it belongs to")

# Sidebar: upload your own image or select one of the bundled examples.
# Forward slashes work as path separators on every OS, including Windows.
with st.sidebar:
    usr_img = st.file_uploader("Upload the Mushroom Image")
    img = image_select(
        label="Select From Examples",
        images=[
            "samples/Russula.jpg",
            "samples/Boletus.jpg",
            "samples/Lactarius.jpg",
        ],
        captions=["Russula", "Boletus", "Lactarius"],
    )

# An uploaded image takes precedence over the selected example.
selected = usr_img if usr_img is not None else img

# Display the image that will be classified.
st.image(selected, width=500)

# Button to classify the displayed image.
if st.button("Classify"):
    st.text(get_type(selected))