gradio_demo / image_classifier /batch_classifier.py
soutrik
added: testing app and also the workflow file
853a5c8
raw
history blame
2.72 kB
import gradio as gr
import torch
import timm
from PIL import Image
import time
from tqdm import tqdm
import numpy as np
import requests
class ImageClassifier:
    """Batch image classifier built on timm's ``resnet50.a1_in1k`` model.

    Intended as the ``fn`` of a ``gr.Interface(batch=True)``: ``predict_batch``
    receives a list of images and returns one top-5 ``{label: probability}``
    dict per image, in input order.
    """

    def __init__(self):
        """Load the pretrained model, its eval transform, and the class names.

        Raises:
            requests.HTTPError: if the label list cannot be downloaded.
            requests.Timeout: if the label server does not respond in time.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = timm.create_model("resnet50.a1_in1k", pretrained=True)
        self.model = self.model.to(self.device)
        self.model.eval()

        # Build the eval-time preprocessing that matches the model's
        # pretrained data configuration.
        data_config = timm.data.resolve_model_data_config(self.model)
        self.transform = timm.data.create_transform(**data_config, is_training=False)

        # One class name per line; predict_batch indexes this list by the
        # model's output class id. Fail fast (with a timeout) instead of
        # hanging or silently using an HTTP error page as the label list.
        url = "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        self.labels = response.text.strip().splitlines()

    @torch.no_grad()
    def predict_batch(self, image_list, progress=gr.Progress(track_tqdm=True)):
        """Classify a batch of images and return the top-5 predictions for each.

        Args:
            image_list: Sequence of images as numpy arrays (converted with
                ``PIL.Image.fromarray``); entries may be ``None`` for empty slots.
            progress: Gradio progress tracker; Gradio injects it because the
                default value is a ``gr.Progress`` instance.

        Returns:
            A one-element list wrapping a list of ``{label: probability}``
            dicts, one per input slot and in input order, because Gradio pairs
            batched outputs with inputs by index. Slots whose image is ``None``
            (or an empty input) get the placeholder ``{"none": 1.0}``.
        """
        if isinstance(image_list, tuple):
            image_list = list(image_list)
        if not image_list:
            return [[{"none": 1.0}]]

        progress(0.1, desc="Starting preprocessing...")
        tensors = []
        positions = []  # index in image_list of each entry in `tensors`
        for pos, image in enumerate(image_list):
            if image is None:
                continue
            # Convert numpy array to PIL Image
            img = Image.fromarray(image).convert("RGB")
            tensors.append(self.transform(img))
            positions.append(pos)

        # Pre-fill every slot with the placeholder so missing images keep
        # their position and the output length always matches the input.
        batch_results = [{"none": 1.0} for _ in image_list]
        if not tensors:
            return [batch_results]

        progress(0.4, desc="Batching tensors...")
        batch = torch.stack(tensors).to(self.device)

        progress(0.6, desc="Running inference...")
        outputs = self.model(batch)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

        progress(0.8, desc="Processing results...")
        k = min(5, probabilities.shape[1])
        top_probs, top_ids = torch.topk(probabilities, k, dim=1)
        for pos, probs, ids in zip(positions, top_probs.tolist(), top_ids.tolist()):
            batch_results[pos] = {
                self.labels[idx]: float(prob) for prob, idx in zip(probs, ids)
            }

        progress(1.0, desc="Done!")
        # Return results in the required format: list of list of dicts
        return [batch_results]
# Create the classifier once at import time. This loads the model and
# downloads the label list.
classifier = ImageClassifier()

# Gradio UI: batch=True makes Gradio collect up to max_batch_size requests
# and call predict_batch once with a list of images.
demo = gr.Interface(
    fn=classifier.predict_batch,
    # predict_batch converts inputs with Image.fromarray, so request numpy
    # arrays explicitly (this is also gr.Image's default).
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=5),
    title="Advanced Image Classification with ResNet50",
    description="Upload images for batch classification with the resnet50.a1_in1k model",
    batch=True,
    max_batch_size=4,
)

if __name__ == "__main__":
    # Bind on all interfaces so the app is reachable from outside a container.
    demo.launch(server_name="0.0.0.0", server_port=7860)