John6666 committed · Commit e77e97e (verified) · Parent(s): 3bb0ac8

Update app.py

Files changed (1): app.py (+57, -58)
app.py CHANGED
@@ -1,59 +1,58 @@
-from transformers import pipeline
-import gradio as gr
-
-
-clip_models = [
-    "zer0int/CLIP-GmP-ViT-L-14",
-    "John6666/zer0int_CLIP-GmP-ViT-L-14",
-    "openai/clip-vit-large-patch14",
-    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
-]
-
-clip_checkpoint = clip_models[0]
-clip_detector = pipeline(model=clip_checkpoint, task="zero-shot-image-classification")
-
-def postprocess(output):
-    return {out["label"]: float(out["score"]) for out in output}
-
-
-def infer(image, candidate_labels):
-    candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
-    clip_out = clip_detector(image, candidate_labels=candidate_labels)
-    return postprocess(clip_out)
-
-
-def load_clip_model(modelname):
-    global clip_detector
-    try:
-        clip_detector = pipeline(model=modelname, task="zero-shot-image-classification")
-    except Exception as e:
-        raise gr.Error(f"Model load error: {modelname} {e}")
-    return modelname
-
-
-with gr.Blocks() as demo:
-    gr.Markdown("# Test CLIP")
-    with gr.Row():
-        with gr.Column():
-            image_input = gr.Image(type="pil")
-            text_input = gr.Textbox(label="Input a list of labels")
-            model_input = gr.Dropdown(label="CLIP model", choices=clip_models, value=clip_models[0], allow_custom_value=True, interactive=True)
-            run_button = gr.Button("Run", visible=True)
-
-        with gr.Column():
-            clip_output = gr.Label(label = "CLIP Output", num_top_classes=3)
-
-    examples = [["./baklava.jpg", "baklava, souffle, tiramisu"], ["./cheetah.jpg", "cat, dog"], ["./cat.png", "cat, dog"]]
-    gr.Examples(
-        examples = examples,
-        inputs=[image_input, text_input],
-        outputs=[clip_output],
-        fn=infer,
-        cache_examples=True
-    )
-    run_button.click(fn=infer,
-                     inputs=[image_input, text_input],
-                     outputs=[clip_output])
-    model_input.change(load_clip_model, [model_input], [model_input])
-
+from transformers import pipeline
+import gradio as gr
+
+
+clip_models = [
+    "zer0int/CLIP-GmP-ViT-L-14",
+    "openai/clip-vit-large-patch14",
+    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
+]
+
+clip_checkpoint = clip_models[0]
+clip_detector = pipeline(model=clip_checkpoint, task="zero-shot-image-classification")
+
+def postprocess(output):
+    return {out["label"]: float(out["score"]) for out in output}
+
+
+def infer(image, candidate_labels):
+    candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
+    clip_out = clip_detector(image, candidate_labels=candidate_labels)
+    return postprocess(clip_out)
+
+
+def load_clip_model(modelname):
+    global clip_detector
+    try:
+        clip_detector = pipeline(model=modelname, task="zero-shot-image-classification")
+    except Exception as e:
+        raise gr.Error(f"Model load error: {modelname} {e}")
+    return modelname
+
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Test CLIP")
+    with gr.Row():
+        with gr.Column():
+            image_input = gr.Image(type="pil")
+            text_input = gr.Textbox(label="Input a list of labels")
+            model_input = gr.Dropdown(label="CLIP model", choices=clip_models, value=clip_models[0], allow_custom_value=True, interactive=True)
+            run_button = gr.Button("Run", visible=True)
+
+        with gr.Column():
+            clip_output = gr.Label(label = "CLIP Output", num_top_classes=3)
+
+    examples = [["./baklava.jpg", "baklava, souffle, tiramisu"], ["./cheetah.jpg", "cat, dog"], ["./cat.png", "cat, dog"]]
+    gr.Examples(
+        examples = examples,
+        inputs=[image_input, text_input],
+        outputs=[clip_output],
+        fn=infer,
+        cache_examples=True
+    )
+    run_button.click(fn=infer,
+                     inputs=[image_input, text_input],
+                     outputs=[clip_output])
+    model_input.change(load_clip_model, [model_input], [model_input])
+
 demo.launch()
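
For reference, a minimal sketch (not part of the commit) of what the zero-shot-image-classification pipeline used above returns and how postprocess() reshapes it for gr.Label. The choice of checkpoint, image path, and the printed scores are illustrative.

from transformers import pipeline

# Same task as app.py, with one of the checkpoints from clip_models (illustrative choice).
detector = pipeline(model="openai/clip-vit-large-patch14", task="zero-shot-image-classification")

# The pipeline returns a list of {"score": float, "label": str} dicts, highest score first,
# e.g. [{"score": 0.98, "label": "cat"}, {"score": 0.02, "label": "dog"}] (scores illustrative).
out = detector("./cat.png", candidate_labels=["cat", "dog"])

# postprocess() in app.py flattens this into the {label: score} mapping that gr.Label expects.
print({o["label"]: float(o["score"]) for o in out})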