File size: 1,859 Bytes
853a5c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import gradio as gr
import torch
import timm
from PIL import Image
import requests


class ImageClassifier:
    """Top-k image classifier backed by a pretrained timm model.

    The model, its evaluation-time preprocessing transform and the ImageNet
    class labels are loaded once in ``__init__``; ``predict`` can then be
    called repeatedly (e.g. as a Gradio callback).
    """

    # Plain-text class-name list, one name per line. ``predict`` relies on
    # line i naming output class i of the model.
    LABELS_URL = "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"

    def __init__(self, *, model_name="resnet50.a1_in1k", top_k=5, labels_timeout=10.0):
        """Load the model, its preprocessing transform and the class labels.

        Args:
            model_name: timm model identifier. Its output classes must line up
                with the label list at ``LABELS_URL`` (ImageNet-1k order).
            top_k: Number of highest-probability classes ``predict`` returns.
            labels_timeout: Seconds to wait for the label-list download.

        Raises:
            ValueError: If ``top_k`` is not positive, or if the number of
                downloaded labels differs from the model's number of classes.
            requests.RequestException: If the label list cannot be fetched
                (timeout, connection error, or non-2xx HTTP status).
        """
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        self.top_k = top_k

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Create model, move it to the chosen device, and switch to inference mode.
        self.model = timm.create_model(model_name, pretrained=True)
        self.model = self.model.to(self.device)
        self.model.eval()

        # Use the preprocessing (resize/crop/normalize) the pretrained weights expect.
        data_config = timm.data.resolve_model_data_config(self.model)
        self.transform = timm.data.create_transform(**data_config, is_training=False)

        # The timeout keeps startup from hanging on a stalled connection, and
        # raise_for_status() stops an HTTP error page from being parsed as labels.
        response = requests.get(self.LABELS_URL, timeout=labels_timeout)
        response.raise_for_status()
        # splitlines() also strips '\r' if the server answers with CRLF endings.
        labels = response.text.strip().splitlines()

        num_classes = getattr(self.model, "num_classes", None)
        if num_classes is not None and len(labels) != num_classes:
            raise ValueError(
                f"label list has {len(labels)} entries but model "
                f"{model_name!r} has {num_classes} output classes"
            )
        # The lemma list repeats some names (e.g. "crane" for both the bird and
        # the machine); make them unique so predict's dict keys cannot collide.
        self.labels = self._dedupe_labels(labels)

    @staticmethod
    def _dedupe_labels(labels):
        """Return ``labels`` with repeated names made unique.

        ``predict`` uses labels as dict keys, so two classes sharing a name
        would collapse into a single entry and one prediction would be lost.
        Every occurrence of a repeated name gets its class index appended
        (e.g. ``"name (12)"``); names that occur once are left unchanged.
        """
        from collections import Counter

        counts = Counter(labels)
        return [
            f"{label} ({index})" if counts[label] > 1 else label
            for index, label in enumerate(labels)
        ]

    @torch.no_grad()
    def predict(self, image):
        """Classify one image and return its most likely classes.

        Args:
            image: Image as a NumPy array (the format ``gr.Image(type="numpy")``
                delivers), or ``None`` when no image has been provided.

        Returns:
            ``None`` if ``image`` is ``None``; otherwise a dict mapping class
            label to softmax probability (a Python float) for the ``top_k``
            most likely classes, ordered from most to least likely.
        """
        if image is None:
            return None

        # convert("RGB") normalizes grayscale / RGBA arrays to 3-channel RGB.
        pil_image = Image.fromarray(image).convert("RGB")
        batch = self.transform(pil_image).unsqueeze(0).to(self.device)  # add batch dim

        logits = self.model(batch)[0]
        probabilities = torch.softmax(logits, dim=0)

        # Never ask topk for more classes than the model produces.
        k = min(self.top_k, probabilities.numel())
        top_probs, top_ids = torch.topk(probabilities, k)

        # One device->host transfer for all k values instead of a sync per item.
        return {
            self.labels[class_id]: prob
            for prob, class_id in zip(top_probs.tolist(), top_ids.tolist())
        }


# Build the classifier once at import time so every request reuses the same
# model. Constructing it loads the pretrained model and fetches the label list.
classifier = ImageClassifier()

# Gradio UI: a NumPy image in, the top-5 class probabilities out.
demo = gr.Interface(
    fn=classifier.predict,
    inputs=gr.Image(type="numpy", label="Input Image"),
    outputs=gr.Label(num_top_classes=5, label="Top 5 Predictions"),
    # The model is ResNet-50 (resnet50.a1_in1k), not a Mamba model.
    title="Basic Image Classification with ResNet-50",
    description="Upload an image to classify it using the resnet50.a1_in1k model",
)

if __name__ == "__main__":
    demo.launch()