"""Gradio demo that classifies an uploaded image with a SwinV2 checkpoint.

The script loads an image processor and a classification model from the
Hugging Face Hub at import time, exposes ``predict_image`` as the inference
function, and launches a Gradio interface around it.
"""

from typing import Dict, Optional

import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Load the feature extractor and model from Hugging Face.
# NOTE(review): the processor comes from the base SwinV2 checkpoint while the
# weights come from the fine-tuned one — presumably they share preprocessing;
# confirm against the model card.
feature_extractor = AutoImageProcessor.from_pretrained(
    "microsoft/swinv2-base-patch4-window16-256"
)
model = AutoModelForImageClassification.from_pretrained(
    "amaye15/SwinV2-Base-Image-Orientation-Fixer"
)


def predict_image(image: Optional[Image.Image]) -> Dict[str, float]:
    """Classify *image* and return every label with its probability.

    Args:
        image: The image to classify (a PIL image when called by the Gradio
            interface below). ``None`` means nothing was uploaded.

    Returns:
        A dict mapping each label in ``model.config.id2label`` to its softmax
        probability, ordered from most to least likely.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Surface a readable message in the UI instead of an opaque failure
        # from inside the image processor.
        raise gr.Error("Please upload an image before submitting.")

    # Assumes the model expects 3-channel input — TODO confirm. Only PIL
    # images are converted; any other input type is passed through untouched.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # Convert the image to the required format and preprocess it.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class dimension, then drop only the batch dimension.
    # A bare ``squeeze()`` would also collapse a single-class output to 0-d.
    probabilities = torch.softmax(logits, dim=-1).squeeze(0)

    # Map every class label to its probability.
    result = {
        model.config.id2label[idx]: prob.item()
        for idx, prob in enumerate(probabilities)
    }

    # Sort the results by probability in descending order.
    return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))


# Markdown shown under the title; headings and list items need their own
# lines to render correctly.
description = """
### Overview
This application is a web-based interface built using Gradio that allows users to upload images and receive class predictions with probabilities. It utilizes a pre-trained SwinV2 model from Hugging Face.

### How It Works
1. **Image Upload**: Users upload an image which is then processed and classified by the model.
2. **Feature Extraction**: The image is preprocessed using a feature extractor that converts it into a format suitable for the model.
3. **Prediction**: The model predicts the class probabilities using a softmax function on the output logits.
4. **Results**: The results are displayed as a sorted list of classes with their corresponding probabilities, showing the most likely class first.

Enjoy exploring the capabilities of this advanced image classification model!
"""

# Create the Gradio interface around the prediction function.
iface = gr.Interface(
    fn=predict_image,  # The prediction function
    inputs=gr.Image(type="pil"),  # Accepts images in PIL format
    outputs=gr.Label(num_top_classes=None),  # Outputs all predicted classes
    title="Image Orientation",  # Optional title
    description=description,  # Detailed app overview (markdown)
)

# Launch the Gradio app.
iface.launch()