Samuel Stevens commited on
Commit
d1c1a86
·
1 Parent(s): d86aa61

add app.py

Browse files
Files changed (2) hide show
  1. app.py +73 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision import transforms
5
+
6
+ from open_clip import create_model, get_tokenizer
7
+ from open_clip.training.imagenet_zeroshot_data import openai_imagenet_template
8
+
9
+ model_str = "ViT-B-16"
10
+ pretrained = "/fs/ess/PAS2136/foundation_model/model/10m/2023_09_22-21_14_04-model_ViT-B-16-lr_0.0001-b_4096-j_8-p_amp/checkpoints/epoch_99.pt"
11
+
12
+ preprocess_img = transforms.Compose(
13
+ [
14
+ transforms.ToTensor(),
15
+ transforms.Normalize(
16
+ mean=(0.48145466, 0.4578275, 0.40821073),
17
+ std=(0.26862954, 0.26130258, 0.27577711),
18
+ ),
19
+ ]
20
+ )
21
+
22
+
23
+ @torch.no_grad()
24
+ def get_txt_features(classnames, templates):
25
+ all_features = []
26
+ for classname in classnames:
27
+ txts = [template(classname) for template in templates]
28
+ txts = tokenizer(txts)
29
+ txt_features = model.encode_text(txts)
30
+ txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
31
+ txt_features /= txt_features.norm()
32
+ all_features.append(txt_features)
33
+ all_features = torch.stack(all_features, dim=1)
34
+ return all_features
35
+
36
+
37
+ @torch.no_grad()
38
+ def predict(img, cls_str: str) -> dict[str, float]:
39
+ classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
40
+ txt_features = get_txt_features(classes, openai_imagenet_template)
41
+
42
+ img = preprocess_img(img)
43
+
44
+ img_features = model.encode_image(img.unsqueeze(0))
45
+ img_features = F.normalize(img_features, dim=-1)
46
+ logits = (img_features @ txt_features).squeeze()
47
+ probs = F.softmax(logits, dim=0).tolist()
48
+ return {cls: prob for cls, prob in zip(classes, probs)}
49
+
50
+
51
+ if __name__ == "__main__":
52
+ print("Starting.")
53
+ model = create_model(model_str, pretrained, output_dict=True)
54
+ print("Created model.")
55
+
56
+ model = torch.compile(model)
57
+ print("Compiled model.")
58
+
59
+ tokenizer = get_tokenizer(model_str)
60
+
61
+ demo = gr.Interface(
62
+ fn=predict,
63
+ inputs=[
64
+ gr.Image(shape=(224, 224)),
65
+ gr.Textbox(
66
+ placeholder="dog\ncat\n...", lines=3, label="Classes", show_label=True
67
+ ),
68
+ ],
69
+ outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
70
+ )
71
+
72
+ demo.launch()
73
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ open_clip_torch
2
+ torchvision
3
+ torch
4
+ gradio