# Spaces: Running
import logging
import os
import random
from typing import Any, Dict, List, Union
from zipfile import ZipFile

import gradio as gr
from PIL import Image
from transformers import pipeline
class Image_classification:
    """Gradio demo that runs image-classification pipelines on an uploaded image.

    The app extracts ``image_dataset.zip`` (looked up in the current working
    directory) into ``dataset/`` and offers the extracted images as clickable
    examples next to an upload widget and a model selector.
    """

    def __init__(self):
        # Pipelines already built, keyed by model id. Building a pipeline
        # loads the model, so doing it once per model instead of once per
        # click keeps repeated classifications fast.
        self._classifiers: Dict[Any, Any] = {}

    def unzip_image_data(self) -> str:
        """
        Unzips the image dataset archive into the ``dataset`` directory.

        Returns:
            str: The path to the directory containing the extracted files,
            or an empty string if the archive could not be extracted.
        """
        try:
            with ZipFile("image_dataset.zip", "r") as extract:
                directory_path = "dataset"
                # exist_ok=True: a folder left over from a previous run must
                # not abort the extraction (os.mkdir raised FileExistsError).
                os.makedirs(directory_path, exist_ok=True)
                extract.extractall(directory_path)
                return directory_path
        except Exception as e:
            logging.error(f"An error occurred during extraction: {e}")
            return ""

    def example_images(self) -> List[str]:
        """
        Unzips the image dataset and lists the extracted image files so they
        can be shown as examples.

        Only regular files directly inside the dataset folder whose extension
        is a known image extension are returned; each file appears once and
        the list is sorted by file name.

        Returns:
            List[str]: Paths to the example images, or an empty list when the
            dataset is unavailable or cannot be read.
        """
        try:
            image_dataset_folder = self.unzip_image_data()
            if not image_dataset_folder:
                # Extraction failed and was already logged; nothing to list.
                return []
            image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp')
            example = []
            for name in sorted(os.listdir(image_dataset_folder)):
                path = os.path.join(image_dataset_folder, name)
                is_image = os.path.splitext(name)[1].lower() in image_extensions
                if is_image and os.path.isfile(path):
                    example.append(path)
            return example
        except Exception as e:
            logging.error(f"An error occurred in example images: {e}")
            return []

    def classify(self, image: Union[str, "Image.Image"], model: Any) -> List[Dict[str, Any]]:
        """
        Classifies an image using a specified model.

        Args:
            image: The image to classify. The UI in ``interface`` passes a
                file path because the image component uses ``type="filepath"``.
            model (Any): The model used for classification, normally a model
                id chosen in the dropdown. ``None`` or a blank string means
                "no model chosen" and is forwarded as ``model=None``.

        Returns:
            List[Dict[str, Any]]: The pipeline output; ``format_the_result``
            reads the ``'label'`` and ``'score'`` keys of each entry.

        Raises:
            Exception: Any error from building or running the pipeline is
            logged and re-raised.
        """
        try:
            # Treat the blank dropdown entry like "nothing selected".
            blank = isinstance(model, str) and not model.strip()
            model_id = None if blank else model

            # Only plain ids are used as cache keys; any other object is
            # passed straight through without caching.
            cacheable = model_id is None or isinstance(model_id, str)
            classifier = self._classifiers.get(model_id) if cacheable else None
            if classifier is None:
                classifier = pipeline("image-classification", model=model_id)
                if cacheable:
                    self._classifiers[model_id] = classifier
            return classifier(image)
        except Exception as e:
            logging.error(f"An error occurred during image classification: {e}")
            raise

    def format_the_result(self, image: Union[str, "Image.Image"], model: Any) -> Dict[str, float]:
        """
        Formats the classification result by retaining the highest score for each label.

        Args:
            image: The image to classify (see ``classify``).
            model (Any): The model used for classification.

        Returns:
            Dict[str, float]: A dictionary with unique labels and the highest
            score seen for each label.

        Raises:
            Exception: Any error from classification or from reading the
            result entries is logged and re-raised.
        """
        try:
            best_scores: Dict[str, float] = {}
            for item in self.classify(image, model):
                label = item['label']
                score = item['score']
                if label not in best_scores or best_scores[label] < score:
                    best_scores[label] = score
            return best_scores
        except Exception as e:
            logging.error(f"An error occurred while formatting the results: {e}")
            raise

    def interface(self):
        """Build the Gradio Blocks UI and launch it with ``debug=True``."""
        # NOTE(review): the second rule below looks malformed (the selector
        # `.block svelte-90oupt padded` and the colour `314755` without `#`),
        # so it probably has no effect. Left unchanged to keep the current look.
        css = """
        .gradio-container {background: #314755;
        background: -webkit-linear-gradient(to right, #26a0da, #314755);
        background: linear-gradient(to right, #26a0da, #314755);}
        .block svelte-90oupt padded{background:314755;
        margin:0;
        padding:0;}"""

        # NOTE(review): the source this was restored from had lost its
        # indentation; the row/column nesting below is the most plausible
        # reading - confirm against the running app.
        with gr.Blocks(css=css) as demo:
            gr.HTML("""
            <center><h1 style="color:#fff">Image Classification</h1></center>""")
            exam_img = self.example_images()
            with gr.Row():
                model = gr.Dropdown(
                    ["facebook/regnet-x-040", "google/vit-large-patch16-384", "microsoft/resnet-50", ""],
                    label="Choose a model",
                )
            with gr.Row():
                image = gr.Image(type="filepath", sources=["upload"])
                with gr.Column():
                    output = gr.Label()
            with gr.Row():
                button = gr.Button()
            button.click(self.format_the_result, [image, model], output)
            # Skip the examples panel when no example image is available
            # (for instance when the dataset archive is missing) instead of
            # handing gr.Examples an empty value.
            if exam_img:
                gr.Examples(
                    examples=exam_img,
                    inputs=[image],
                    outputs=output,
                    fn=self.format_the_result,
                    cache_examples=False,
                )
        demo.launch(debug=True)
if __name__ == "__main__":
    # Script entry point: create the app object and start the UI.
    # interface() has no return statement, so `result` is always None.
    result = Image_classification().interface()