# Spaces: Sleeping
import gradio as gr | |
import torch | |
import timm | |
from PIL import Image | |
import time | |
from tqdm import tqdm | |
import numpy as np | |
import requests | |
class ImageClassifier:
    """Top-5 ImageNet classifier built on a pretrained timm ResNet-50.

    The constructor loads the ``resnet50.a1_in1k`` weights, builds the
    matching evaluation transform and downloads the class-label list.
    ``predict_batch`` is written to be used as a Gradio ``batch=True``
    handler.
    """

    def __init__(self):
        """Load the model, its preprocessing transform and the label names.

        Raises:
            requests.RequestException: If the label file cannot be downloaded
                (network failure, timeout or non-2xx HTTP status).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = timm.create_model("resnet50.a1_in1k", pretrained=True)
        self.model = self.model.to(self.device)
        self.model.eval()

        # Use the preprocessing the checkpoint was configured with.
        data_config = timm.data.resolve_model_data_config(self.model)
        self.transform = timm.data.create_transform(**data_config, is_training=False)

        url = "https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt"
        # A timeout keeps start-up from hanging forever, and the status check
        # prevents an HTTP error page from being split into bogus "labels".
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        self.labels = response.text.strip().split("\n")

    def predict_batch(self, image_list, progress=gr.Progress(track_tqdm=True)):
        """Classify a batch of images and return the top classes for each.

        Args:
            image_list: Sequence of images as arrays accepted by
                ``PIL.Image.fromarray``; individual entries may be ``None``.
            progress: Gradio progress tracker, injected by Gradio through the
                default argument.

        Returns:
            A one-element list (one entry per output component) holding a
            list with exactly one ``{label: probability}`` dict per input
            image, in input order. Missing (``None``) images map to
            ``{"none": 1.0}``. An empty batch yields ``[[{"none": 1.0}]]``.
        """
        if isinstance(image_list, tuple):
            image_list = list(image_list)
        if not image_list:
            return [[{"none": 1.0}]]

        progress(0.1, desc="Starting preprocessing...")
        tensors = []
        valid_positions = []
        for position, image in enumerate(image_list):
            if image is None:
                continue
            # Convert numpy array to PIL Image
            img = Image.fromarray(image).convert("RGB")
            tensors.append(self.transform(img))
            valid_positions.append(position)

        # Gradio's batch mode expects one result per input sample, so start
        # with a placeholder in every slot and overwrite the valid ones.
        batch_results = [{"none": 1.0} for _ in image_list]
        if not tensors:
            return [batch_results]

        progress(0.4, desc="Batching tensors...")
        batch = torch.stack(tensors).to(self.device)

        progress(0.6, desc="Running inference...")
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = self.model(batch)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)

        progress(0.8, desc="Processing results...")
        # Guard against heads with fewer than five classes.
        k = min(5, probabilities.shape[1])
        top_probs, top_ids = torch.topk(probabilities, k, dim=1)
        for position, probs, ids in zip(
            valid_positions, top_probs.tolist(), top_ids.tolist()
        ):
            batch_results[position] = {
                self.labels[idx]: prob for prob, idx in zip(probs, ids)
            }

        progress(1.0, desc="Done!")
        # Return results in the required format: list of list of dicts
        return [batch_results]
# Instantiate the classifier once at import time; the same instance (and
# therefore the same loaded model) backs the interface below.
classifier = ImageClassifier()

# Gradio UI: a single image input and a top-5 label output, with batching
# enabled so up to four queued requests are passed to the handler together.
# NOTE(review): the title mentions "Mamba" while the description names the
# resnet50.a1_in1k model — confirm which wording is intended.
demo = gr.Interface(
    classifier.predict_batch,
    gr.Image(),
    gr.Label(num_top_classes=5),
    batch=True,
    max_batch_size=4,
    title="Advanced Image Classification with Mamba",
    description="Upload images for batch classification with the resnet50.a1_in1k model",
)

if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")