adarsh commited on
Commit
18337e4
·
1 Parent(s): fa4665e

added model

Browse files
app.py CHANGED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # detect.py
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from torchvision.models import resnet50
5
+ from PIL import Image
6
+ import torch.nn as nn
7
+
8
+ # Define the class names - make sure these match your training classes
9
+ CLASS_NAMES = [
10
+ "Apple___Apple_scab",
11
+ "Apple___Black_rot",
12
+ # Add all your class names here...
13
+ ]
14
+
15
+ def load_model(model_path):
16
+ # Initialize the model architecture
17
+ model = resnet50(pretrained=False)
18
+ num_classes = len(CLASS_NAMES)
19
+ model.fc = nn.Linear(model.fc.in_features, num_classes)
20
+
21
+ # Load the state dict
22
+ state_dict = torch.load(model_path, map_location=torch.device('cpu'))
23
+ model.load_state_dict(state_dict)
24
+ model.eval()
25
+ return model
26
+
27
+ def predict_image(image_path, model):
28
+ """Predict the class of a given image"""
29
+ # Define the same transform as used during training
30
+ transform = transforms.Compose([
31
+ transforms.Resize((224, 224)),
32
+ transforms.ToTensor(),
33
+ ])
34
+
35
+ # Load and preprocess the image
36
+ image = Image.open(image_path).convert('RGB')
37
+ image_tensor = transform(image).unsqueeze(0)
38
+
39
+ # Make prediction
40
+ with torch.no_grad():
41
+ outputs = model(image_tensor)
42
+ _, predicted = torch.max(outputs, 1)
43
+
44
+ return CLASS_NAMES[predicted.item()]
45
+
46
+ # streamlit_app.py
47
+ import streamlit as st
48
+ import torch
49
+ import torchvision.transforms as transforms
50
+ from PIL import Image
51
+ import os
52
+ from detect import load_model, predict_image, CLASS_NAMES
53
+
54
+ # Set page config
55
+ st.set_page_config(page_title="Plant Disease Predictor", page_icon="🍃", layout="wide")
56
+
57
+ # Load the model
58
+ @st.cache_resource
59
+ def load_model_cached():
60
+ model_path = 'models/leaf_disease_res50_model_epoch_10.pth'
61
+ model = load_model(model_path)
62
+ return model
63
+
64
+ # Load model at startup
65
+ model = load_model_cached()
66
+
67
+ # Streamlit app
68
+ st.title("Plant Disease Predictor")
69
+ st.write("Upload an image of a plant leaf to predict if it has a disease.")
70
+
71
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
72
+
73
+ if uploaded_file is not None:
74
+ image = Image.open(uploaded_file).convert('RGB')
75
+ st.image(image, caption='Uploaded Image', use_column_width=True)
76
+
77
+ if st.button('Predict'):
78
+ # Show prediction in progress
79
+ with st.spinner('Analyzing image...'):
80
+ # Save the uploaded file temporarily
81
+ with open("temp_image.jpg", "wb") as f:
82
+ f.write(uploaded_file.getbuffer())
83
+
84
+ # Make prediction
85
+ prediction = predict_image("temp_image.jpg", model)
86
+
87
+ # Remove temporary file
88
+ os.remove("temp_image.jpg")
89
+
90
+ # Display result
91
+ st.success(f"Prediction: {prediction}")
92
+
93
+ # Display confidence scores
94
+ transform = transforms.Compose([
95
+ transforms.Resize((224, 224)), # Match the training size
96
+ transforms.ToTensor(),
97
+ ])
98
+
99
+ with torch.no_grad():
100
+ img_tensor = transform(image).unsqueeze(0)
101
+ outputs = model(img_tensor)
102
+ probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
103
+
104
+ # Display top 5 predictions
105
+ top5_prob, top5_catid = torch.topk(probabilities, 5)
106
+ st.write("Top 5 Predictions:")
107
+ for i in range(top5_prob.size(0)):
108
+ st.write(f"{CLASS_NAMES[top5_catid[i]]}: {top5_prob[i].item()*100:.2f}%")
109
+
110
+ # Display list of detectable diseases
111
+ st.write("## List of Detectable Plant Diseases")
112
+ st.write("This model can detect the following plant diseases:")
113
+ for disease in CLASS_NAMES:
114
+ st.write(f"- {disease.replace('___', ' - ')}")
models/leaf_disease_res50_model_epoch_10.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6550ff3036e8f9503a549a70ef3b6790a283211fd2e87f97a306194e2f5d6eda
3
+ size 94663438