File size: 3,791 Bytes
c5bd7aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
from torchvision import transforms
from PIL import Image
import model_builder
import matplotlib.pyplot as plt
import argparse

def predict_image(image_path: str, 
                  model_path: str,
                  class_names: list):
    """
    Predict the class label of an image with a saved model and plot the result.

    The image is converted to RGB, resized to 112x112 and turned into a tensor.
    A ``model_builder.TrashClassificationCNNModel`` is built with
    ``input_shape=3``, ``hidden_units=15`` and one output per entry in
    ``class_names``, the state dict at ``model_path`` is loaded into it, and a
    single forward pass is run in evaluation / inference mode. The image is
    then shown with matplotlib, titled with the predicted class and its
    softmax probability (``plt.show()`` is called).

    NOTE(review): the 112x112 resize and ``hidden_units=15`` are hard-coded;
    presumably they mirror the training configuration - confirm against the
    training script before changing either.

    Args:
        image_path (str): Path to the input image.
        model_path (str): Path to the saved PyTorch model state dict.
        class_names (list): List of class names, indexed by model output.

    Returns:
        predicted_class (str): Predicted class label.
        probability_percentage (float): Probability percentage of the predicted class.
    """
    # Load the image. The context manager releases the file handle as soon as
    # the RGB copy has been created, instead of leaving it open.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    # Define transformations for the image
    transform = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.ToTensor()
    ])

    # Preprocess the image and add a leading batch dimension of size 1
    image_tensor = transform(image).unsqueeze(0)

    # Build the model architecture the state dict will be loaded into
    model = model_builder.TrashClassificationCNNModel(input_shape=3,
                                                      hidden_units=15,
                                                      output_shape=len(class_names)
                                                      )

    # The model and the input both live on the CPU, so map the checkpoint
    # there as well; without this, a checkpoint saved from a CUDA device
    # cannot be loaded on a machine that has no GPU.
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()

    # Make predictions
    with torch.inference_mode():
        output = model(image_tensor)
        # Index the single batch row explicitly rather than squeezing, so the
        # result stays one-dimensional even when there is only one class.
        probabilities = torch.softmax(output, dim=1)[0]
        predicted_index = torch.argmax(output, dim=1).item()
        probability_percentage = probabilities[predicted_index].item() * 100
        predicted_class = class_names[predicted_index]

    # Plot the image with predicted class and probability as title.
    # imshow expects channels last, hence the permute from (C, H, W).
    plt.imshow(image_tensor[0].permute(1, 2, 0))
    plt.title(f'Predicted Class: {predicted_class}  |  Probability: {probability_percentage:.2f}%', fontdict={'family': 'serif',
                                                                                                              'color':  'black',
                                                                                                              'weight': 'normal',
                                                                                                              'size': 16,})
    plt.axis(False)
    plt.show()

    return predicted_class, probability_percentage

# Command-line entry point: parse the image and model paths, validate them,
# then run a single prediction over the fixed list of class names.
parser = argparse.ArgumentParser(description="For Prediction of an Image using a Loaded Model")
parser.add_argument("--image", type=str, default=None, help="Path to the image to make predictions on.")
parser.add_argument("--model_path", type=str, default=None, help="Path to the Trained Models State Dict.")
args = parser.parse_args()

IMAGE_PATH = args.image
MODEL_PATH = args.model_path

# Both paths are mandatory. Exit with a non-zero status so that shells and
# calling scripts can tell the run did not succeed.
if not IMAGE_PATH:
    print("Please Enter Image Path using --image")
    raise SystemExit(1)

if not MODEL_PATH:
    print("Please Enter Model Path using --model_path")
    raise SystemExit(1)

# Top-level boundary: report any failure (most commonly a wrong path) together
# with the underlying error text, then exit with a failure status instead of
# finishing as if the prediction had worked.
try:
    # Run the function
    predict_image(image_path=IMAGE_PATH,
                model_path=MODEL_PATH,
                class_names=['Cardboard',
                            'Food Organics',
                            'Glass',
                            'Metal',
                            'Miscellaneous Trash',
                            'Paper',
                            'Plastic',
                            'Textile Trash',
                            'Vegetation'])
except Exception as exception:
    print("INVALID MODEL_PATH OR INVALID IMAGE PATH")
    print(f"\n{exception}")
    raise SystemExit(1) from exception