import gradio as gr
from transformers import pipeline

def classify(image, model_name):
    """Classify an image with any timm checkpoint hosted on the Hugging Face Hub."""
    try:
        # Build an image-classification pipeline for the requested model.
        pipe = pipeline("image-classification", model=model_name)
        results = pipe(image)
        # Map each predicted label to its rounded score, the format gr.Label expects.
        return {result["label"]: round(result["score"], 2) for result in results}
    except Exception as e:
        # Surface loading/inference errors (e.g. an invalid model name) in the UI.
        return {"Error": str(e)}
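
# Note: classify() rebuilds the pipeline on every request, which keeps the demo
# minimal but reloads model weights each time. If repeated requests for the same
# model matter, one optional sketch (hypothetical helper, not used above) is to
# cache the loaded pipeline:
#
#     from functools import lru_cache
#
#     @lru_cache(maxsize=4)
#     def load_pipe(model_name):
#         return pipeline("image-classification", model=model_name)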

# Gradio Blocks Interface
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Custom timm Model Image Classifier 🚀
        
        Explore the power of [timm](https://github.com/rwightman/pytorch-image-models) models for image classification using 
        the Hugging Face [Transformers pipeline](https://huggingface.co./docs/transformers/main_classes/pipelines).
        
        With just a few lines of code, you can load any timm model hosted on the Hugging Face Hub and classify images effortlessly. 
        This application demonstrates how the pipeline API can be used to build a minimal yet capable image classification tool.
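        
        Under the hood, each prediction is just a call to the pipeline API. A minimal sketch (assuming a valid Hub model name and a local `cat.jpg`, as in the examples below):
        
        ```python
        from transformers import pipeline
        
        pipe = pipeline("image-classification", model="timm/resnet50.a1_in1k")
        predictions = pipe("cat.jpg")  # list of {"label": ..., "score": ...} dicts
        ```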
        
        ## How to Use
        
        1. Upload an image or use one of the provided examples.
        2. Enter a valid timm model name from the Hugging Face Hub (e.g., `timm/resnet50.a1_in1k`).
        3. View the top predictions and confidence scores!
        """
    )
    
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload an Image")
            model_name_input = gr.Textbox(
                label="Enter timm Model Name",
                placeholder="e.g., timm/mobilenetv3_large_100.ra_in1k"
            )
        with gr.Column():
            output_label = gr.Label(num_top_classes=3, label="Top Predictions")

    submit_button = gr.Button("Classify")
    submit_button.click(fn=classify, inputs=[image_input, model_name_input], outputs=output_label)
    
    gr.Examples(
        examples=[
            ["cat.jpg", "timm/mobilenetv3_small_100.lamb_in1k"],
            ["cat.jpg", "timm/resnet50.a1_in1k"],
        ],
        inputs=[image_input, model_name_input]
    )

    gr.Markdown(
        """
        ## Learn More
        - Check out the implementation in the `app.py` file to see how easy it is to integrate timm models.
        - Dive into the [official blog post on timm integration](https://huggingface.co./blog/timm-transformers) for more insights.
        """
    )
demo.launch()