"""Gradio demo: classify an uploaded image with a ViT model fine-tuned on ImageNet-1k."""

import logging

import gradio as gr
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

logger = logging.getLogger(__name__)

# NOTE: 'google/vit-base-patch16-224-in21k' (used previously) is the ImageNet-21k
# *pre-trained* backbone and ships without a fine-tuned classification head, so
# ViTForImageClassification would attach a randomly initialised classifier and the
# predicted labels would be meaningless. 'google/vit-base-patch16-224' is the same
# architecture fine-tuned on ImageNet-1k, matching the 1000-class UI description.
MODEL_NAME = "google/vit-base-patch16-224"

# Loaded once at import time so every request reuses the same processor and weights.
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)
model.eval()  # inference only: make sure dropout is disabled


def _to_rgb_image(image):
    """Return *image* as a 3-channel RGB PIL image.

    Accepts a PIL image or a filesystem path (str). Conversion to RGB matters
    because the processor normalises with per-channel RGB statistics; RGBA
    (e.g. PNG with alpha) or greyscale inputs would otherwise fail or be
    misprocessed.
    """
    if isinstance(image, str):
        # convert() returns an in-memory copy, so the file can be closed right away.
        with Image.open(image) as img:
            return img.convert("RGB")
    return image.convert("RGB")


def predict(image):
    """Classify *image* and return a human-readable result string.

    Args:
        image: A PIL image (what ``gr.Image(type="pil")`` supplies), a file
            path, or ``None`` when the user submits without uploading anything.

    Returns:
        ``"Predicted class: <label>"`` on success, otherwise an ``"Error: ..."``
        message. Errors are reported in the output box rather than raised so
        the UI stays usable.
    """
    if image is None:
        return "Error: please upload an image."

    try:
        rgb_image = _to_rgb_image(image)
        inputs = processor(images=rgb_image, return_tensors="pt")
        with torch.no_grad():
            logits = model(**inputs).logits
    except Exception as e:  # UI boundary: log the traceback, report the message
        logger.exception("Image classification failed")
        return f"Error: {e}"

    # Index of the class with the highest logit, mapped to its label via the model config.
    predicted_class_idx = logits.argmax(-1).item()
    predicted_class_label = model.config.id2label[predicted_class_idx]
    return f"Predicted class: {predicted_class_label}"


interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Text(),
    title="ViT Image Classification (ImageNet)",
    description="Upload an image to classify it into one of the 1000 ImageNet classes.",
)

if __name__ == "__main__":
    interface.launch()