Spaces:
Sleeping
Sleeping
import os

import gradio as gr
import torch
import torch.nn.functional as F
from open_clip import create_model, get_tokenizer
from torchvision import transforms

from templates import openai_imagenet_template
# open_clip architecture name; also used to pick the matching tokenizer.
model_str = "ViT-B-16"

# Checkpoint loaded by create_model. The default is an absolute path on the
# training cluster, so allow overriding it from the environment when the demo
# runs anywhere else.
pretrained = os.environ.get(
    "PRETRAINED_CHECKPOINT",
    "/fs/ess/PAS2136/foundation_model/model/10m/2023_09_22-21_14_04-model_ViT-B-16-lr_0.0001-b_4096-j_8-p_amp/checkpoints/epoch_99.pt",
)

# Image preprocessing applied before model.encode_image.
# Resize explicitly to 224x224 (the size the Gradio input is configured for)
# instead of relying on the UI component to do it; for inputs that are
# already 224x224 this is a no-op. Normalization uses the OpenAI CLIP
# channel mean/std.
preprocess_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
@torch.no_grad()
def get_txt_features(classnames, templates):
    """Build one unit-length zero-shot text embedding per class name.

    Each class name is expanded with every prompt template. The prompts are
    tokenized and encoded with the module-level ``tokenizer`` and ``model``,
    L2-normalized per prompt, averaged over the templates, and re-normalized
    to unit length.

    Runs under ``torch.no_grad()``: this is inference only, so no autograd
    graph is built.

    Args:
        classnames: Non-empty sequence of class-name strings.
        templates: Callables mapping a class name to a prompt string.

    Returns:
        Tensor with one column per class (stacked along dim 1), suitable for
        ``img_features @ result``. Presumably of shape
        ``(embed_dim, len(classnames))``, assuming ``encode_text`` returns
        ``(n_prompts, embed_dim)``.

    Raises:
        ValueError: If ``classnames`` is empty.
    """
    if not classnames:
        # torch.stack on an empty list fails with an opaque error; fail clearly.
        raise ValueError("get_txt_features needs at least one class name")

    all_features = []
    for classname in classnames:
        txts = tokenizer([template(classname) for template in templates])
        # Normalize each prompt embedding before averaging so every template
        # contributes equally, then re-normalize the averaged vector.
        txt_features = F.normalize(model.encode_text(txts), dim=-1).mean(dim=0)
        all_features.append(txt_features / txt_features.norm())
    return torch.stack(all_features, dim=1)
@torch.no_grad()
def predict(img, cls_str: str) -> dict[str, float]:
    """Zero-shot classify ``img`` against newline-separated class names.

    Args:
        img: Image from the Gradio ``gr.Image`` input (anything
            ``preprocess_img`` accepts), or ``None`` when no image is provided.
        cls_str: Candidate class names, one per line; blank lines and
            surrounding whitespace are ignored.

    Returns:
        Mapping from class name to softmax probability over the given classes.
        Empty when there is no image or no class names.
    """
    classes = [cls.strip() for cls in (cls_str or "").split("\n") if cls.strip()]
    if img is None or not classes:
        # Nothing to classify; an empty dict renders as an empty gr.Label.
        return {}

    txt_features = get_txt_features(classes, openai_imagenet_template)
    img_features = model.encode_image(preprocess_img(img).unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # Scale cosine similarities by CLIP's learned temperature; softmax over raw
    # similarities in [-1, 1] would give nearly uniform probabilities.
    logits = model.logit_scale.exp() * (img_features @ txt_features)
    # Take the single batch row rather than squeeze(): with one class,
    # squeeze() yields a 0-d tensor whose tolist() is a float, breaking zip().
    probs = F.softmax(logits[0], dim=0).tolist()
    return dict(zip(classes, probs))
if __name__ == "__main__":
    # `model` and `tokenizer` are assigned here as module globals;
    # get_txt_features and predict read them at call time.
    print("Starting.")
    model = create_model(model_str, pretrained, output_dict=True)
    # open_clip creates models in training mode; switch to inference mode
    # before compiling so any dropout / stochastic-depth layers are disabled.
    model.eval()
    print("Created model.")
    model = torch.compile(model)
    print("Compiled model.")
    tokenizer = get_tokenizer(model_str)

    demo = gr.Interface(
        fn=predict,
        inputs=[
            # NOTE(review): `shape` is a Gradio 3.x argument (removed in 4.x).
            gr.Image(shape=(224, 224)),
            gr.Textbox(
                placeholder="dog\ncat\n...", lines=3, label="Classes", show_label=True
            ),
        ],
        outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
    )
    demo.launch()