ViT_Team-A / app.py
iamomtiwari's picture
Update app.py
bb9954b verified
raw
history blame
1.73 kB
import gradio as gr
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
# Load the ViT checkpoint that was fine-tuned on ImageNet-1k.
# The previously used 'google/vit-base-patch16-224-in21k' checkpoint does not
# ship a fine-tuned classification head, so ViTForImageClassification would
# initialise the head randomly and expose placeholder labels instead of the
# ImageNet classes that the interface description promises.
MODEL_NAME = 'google/vit-base-patch16-224'
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)
# Inference function for predicting with ViT model
def predict(image):
    """Classify an image with the module-level ViT ``model``/``processor``.

    Args:
        image: A PIL image, or a filesystem path (``str``) to an image file.
            ``None`` is accepted and reported as an error.

    Returns:
        ``"Predicted class: <label>"`` on success, or ``"Error: <message>"``
        when no image was supplied or any step of inference fails.
    """
    # The UI can submit without an upload, in which case no image arrives.
    if image is None:
        return "Error: No image provided."
    try:
        if isinstance(image, str):
            # Close the file handle deterministically; convert() returns a
            # fully loaded copy that stays valid after the file is closed.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        elif isinstance(image, Image.Image) and image.mode != "RGB":
            # RGBA, palette and grayscale uploads are normalised to three
            # channels before preprocessing.
            image = image.convert("RGB")

        # Preprocess the input image into PyTorch tensors.
        inputs = processor(images=image, return_tensors="pt")

        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = model(**inputs).logits

        # The class with the highest logit is the prediction.
        predicted_class_idx = logits.argmax(-1).item()
        predicted_class_label = model.config.id2label[predicted_class_idx]
        return f"Predicted class: {predicted_class_label}"
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"Error: {e}"
# Build the Gradio UI: a PIL image upload as input, plain text as output.
image_input = gr.Image(type="pil", label="Upload Image")
text_output = gr.Text()
interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=text_output,
    title="ViT Image Classification (ImageNet)",
    description="Upload an image to classify it into one of the 1000 ImageNet classes.",
)

# Launch the interface.
interface.launch()