Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import ViTImageProcessor, ViTForImageClassification | |
from PIL import Image | |
import torch | |
# Load a ViT checkpoint that was fine-tuned on ImageNet-1k, so it has a real
# 1000-class classification head and a matching id2label mapping.
# NOTE: the "-in21k" variant is pre-trained only and ships no fine-tuned head;
# loading it into ViTForImageClassification attaches a newly initialised
# classifier, which makes the predicted labels meaningless.
MODEL_NAME = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)
def predict(image):
    """Classify an image with the ViT model and return a human-readable label.

    Args:
        image: A ``PIL.Image.Image`` (what ``gr.Image(type="pil")`` delivers),
            a filesystem path to an image as ``str``, or ``None`` when the
            user submits the form without uploading anything. Other
            array-like inputs are passed to the processor unchanged.

    Returns:
        ``"Predicted class: <label>"`` on success, a prompt to upload an
        image when *image* is ``None``, or ``"Error: <message>"`` if loading,
        preprocessing or inference raises.
    """
    import logging

    # Gradio passes None when nothing was uploaded; answer clearly instead of
    # surfacing an opaque processor error.
    if image is None:
        return "Please upload an image to classify."

    try:
        if isinstance(image, str):
            # The context manager closes the file handle; convert() forces the
            # pixel data to load before the file is closed.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        elif isinstance(image, Image.Image) and image.mode != "RGB":
            # ViT expects 3-channel input: normalise RGBA / palette /
            # grayscale uploads to RGB before preprocessing.
            image = image.convert("RGB")

        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            logits = model(**inputs).logits

        # Pick the class with the highest logit and map it to its label.
        predicted_class_idx = logits.argmax(-1).item()
        predicted_class_label = model.config.id2label[predicted_class_idx]
        return f"Predicted class: {predicted_class_label}"
    except Exception as e:
        # Top-level UI boundary: keep the app responsive and show the message,
        # but log the traceback so failures are debuggable from server logs.
        logging.getLogger(__name__).exception("Prediction failed")
        return f"Error: {str(e)}"
# Build the Gradio UI: one image input (delivered to `predict` as a PIL image)
# and one text output. The class count in the description is read from the
# loaded model's config so the text always matches the checkpoint in use.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Text(),
    title="ViT Image Classification (ImageNet)",
    description=(
        "Upload an image to classify it into one of the "
        f"{model.config.num_labels} ImageNet classes."
    ),
)
# Start the Gradio server.
interface.launch()