import time

import gradio as gr
import torch
from transformers import CLIPModel, CLIPProcessor


def get_zero_shot_classification_tab():
    # Load both CLIP checkpoints once at startup so they are reused across requests
    openai_model_name = "openai/clip-vit-large-patch14"
    openai_model = CLIPModel.from_pretrained(openai_model_name)
    openai_processor = CLIPProcessor.from_pretrained(openai_model_name)

    patrickjohncyh_model_name = "patrickjohncyh/fashion-clip"
    patrickjohncyh_model = CLIPModel.from_pretrained(patrickjohncyh_model_name)
    patrickjohncyh_processor = CLIPProcessor.from_pretrained(patrickjohncyh_model_name)

    model_map = {
        openai_model_name: (openai_model, openai_processor),
        patrickjohncyh_model_name: (patrickjohncyh_model, patrickjohncyh_processor)
    }

    def gradio_process(model_name, image, text):
        model, processor = model_map[model_name]
        # Split the comma-separated labels and drop surrounding whitespace
        labels = [label.strip() for label in text.split(",") if label.strip()]

        start = time.time()
        inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
        # Inference only, so gradient tracking is unnecessary
        with torch.no_grad():
            outputs = model(**inputs)
        # Softmax over the label dimension yields one probability per candidate label
        probs = outputs.logits_per_image.softmax(dim=1)[0]
        time_spent = time.time() - start

        results = [f"{label} - {prob.item():.4f}" for label, prob in zip(labels, probs)]
        result = "\n".join(results)

        return [result, time_spent]
    
    with gr.TabItem("Zero-Shot Classification") as zero_shot_image_classification_tab:
        gr.Markdown("# Zero-Shot Image Classification")

        with gr.Row():
            with gr.Column():
                # Input components
                input_image = gr.Image(label="Upload Image", type="pil")
                input_text = gr.Textbox(label="Labels (comma separated)")
                model_selector = gr.Dropdown([openai_model_name, patrickjohncyh_model_name],
                                             label="Select Model")

                # Process button
                process_btn = gr.Button("Classify")

            with gr.Column():
                # Output components
                elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
                output_text = gr.Textbox(label="Classification")

        # Connect the input components to the processing function
        process_btn.click(
            fn=gradio_process,
            inputs=[
                model_selector,
                input_image,
                input_text
            ],
            outputs=[output_text, elapsed_result]
        )
    
    return zero_shot_image_classification_tab
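

# Minimal usage sketch (assumption: in the original project this tab is assembled into a
# larger multi-tab app; the Blocks/Tabs wrapper below is illustrative, not part of the file).
if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Tabs():
            get_zero_shot_classification_tab()
    demo.launch()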