File size: 2,717 Bytes
853a5c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
import torch
import timm
from PIL import Image
import time
from tqdm import tqdm
import numpy as np
import requests


class ImageClassifier:
    """Top-5 image classifier built on the timm ``resnet50.a1_in1k`` model.

    On construction the pretrained model is loaded onto CUDA when available
    (CPU otherwise), the matching evaluation transform is built from the
    model's data config, and the class-label list is downloaded.
    """

    def __init__(self):
        """Load the model, its preprocessing transform and the label list.

        Raises:
            requests.RequestException: If the label file cannot be fetched
                (connection problem, timeout, or non-2xx HTTP status).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = timm.create_model("resnet50.a1_in1k", pretrained=True)
        self.model = self.model.to(self.device)
        self.model.eval()

        data_config = timm.data.resolve_model_data_config(self.model)
        self.transform = timm.data.create_transform(**data_config, is_training=False)
        url = "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
        # A timeout keeps start-up from hanging forever on a stalled
        # connection; raise_for_status stops an HTTP error page from being
        # silently parsed as class labels.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        self.labels = response.text.strip().split("\n")

    @torch.no_grad()
    def predict_batch(self, image_list, progress=gr.Progress(track_tqdm=True)):
        """Classify a batch of images and return the top-5 labels for each.

        Args:
            image_list: Batch of images. Each entry is passed to
                ``Image.fromarray`` (so presumably a numpy array, as produced
                by ``gr.Image`` — verify against the caller) or is ``None``
                for a missing image. A 1-tuple is unwrapped into a
                one-element list.
            progress: Gradio progress callback, updated at each stage.

        Returns:
            A one-element list wrapping a list with exactly one
            ``{label: probability}`` dict per entry of ``image_list``, in
            input order. Entries that were ``None`` get the placeholder
            ``{"none": 1.0}``. An empty batch yields ``[[{"none": 1.0}]]``.
        """
        if isinstance(image_list, tuple) and len(image_list) == 1:
            image_list = [image_list[0]]

        if not image_list:
            return [[{"none": 1.0}]]

        # One output slot per input sample: batched Gradio functions must
        # return as many results as they received inputs, so missing images
        # keep a placeholder instead of being dropped from the output.
        batch_results = [{"none": 1.0} for _ in image_list]
        if all(image is None for image in image_list):
            return [batch_results]

        progress(0.1, desc="Starting preprocessing...")
        tensors = []
        # Original batch index of every tensor, used to put results back
        # into the right slot after inference.
        valid_positions = []
        for position, image in enumerate(image_list):
            if image is None:
                continue
            # Convert numpy array to PIL Image
            img = Image.fromarray(image).convert("RGB")
            tensors.append(self.transform(img))
            valid_positions.append(position)

        progress(0.4, desc="Batching tensors...")
        batch = torch.stack(tensors).to(self.device)

        progress(0.6, desc="Running inference...")
        outputs = self.model(batch)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

        progress(0.8, desc="Processing results...")
        for position, probs in zip(valid_positions, probabilities):
            # Guard against models exposing fewer than five classes.
            top_prob, top_catid = torch.topk(probs, min(5, probs.shape[-1]))
            batch_results[position] = {
                self.labels[idx]: prob
                for prob, idx in zip(top_prob.tolist(), top_catid.tolist())
            }

        progress(1.0, desc="Done!")
        # Return results in the required format: list of list of dicts
        return [batch_results]


# One shared classifier object backs every request handled by the UI.
classifier = ImageClassifier()

# Gradio front end: a single image input feeding a top-5 label widget, with
# request batching enabled (up to four samples per call).
# NOTE(review): the title mentions "Mamba" while the description names the
# resnet50.a1_in1k model — confirm which wording is intended.
demo = gr.Interface(
    title="Advanced Image Classification with Mamba",
    description="Upload images for batch classification with the resnet50.a1_in1k model",
    fn=classifier.predict_batch,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=5),
    max_batch_size=4,
    batch=True,
)

# Start the web server only when executed as a script; listens on all
# interfaces at port 7860.
if __name__ == "__main__":
    demo.launch(server_port=7860, server_name="0.0.0.0")