"""Gradio demo: zero-shot group tagging of clothing images.

Scores an uploaded image against fixed text prompts ("mens clothing",
"womens clothing", ...) with the Marqo fashionSigLIP model loaded through
open_clip, and reports the best-matching group with its softmax probability.
"""
import gradio as gr
import open_clip
import torch
import requests  # NOTE(review): not used anywhere in this file
import numpy as np  # NOTE(review): no longer used; argmax is done in torch
from PIL import Image  # NOTE(review): not referenced directly in this file

MODEL_ID = "hf-hub:Marqo/marqo-fashionSigLIP"

model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(MODEL_ID)
tokenizer = open_clip.get_tokenizer(MODEL_ID)

# open_clip creates models in train mode; switch to eval for inference so
# layers such as dropout / stochastic depth are disabled.
model.eval()

# Use the GPU when available. Mixed-precision autocast is enabled only there;
# the original torch.cuda.amp.autocast() never applied because the model stayed
# on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)


def _autocast():
    """Return an autocast context that is active only when running on CUDA."""
    return torch.autocast(device_type=device, enabled=(device == "cuda"))


# Group words; the prediction is one of these.
GROUPS = ["men", "women", "boy", "girl"]
# Subject noun appended to every prompt.
SUBJECT = "clothing"
# Prompt text is kept exactly as before, e.g. "mens clothing".
PROMPTS = [f"{group}s {SUBJECT}" for group in GROUPS]

# The prompts never change, so encode and L2-normalise them once at start-up
# instead of on every request.
with torch.no_grad(), _autocast():
    _TEXT_FEATURES = model.encode_text(tokenizer(PROMPTS).to(device))
    _TEXT_FEATURES /= _TEXT_FEATURES.norm(dim=-1, keepdim=True)


def predict(inp):
    """Predict which group (men/women/boy/girl) a clothing image matches best.

    Args:
        inp: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        A string ``"<group> <NN.N%>"`` naming the best-matching group and its
        softmax probability over the fixed prompts.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if inp is None:
        raise gr.Error("Please upload an image.")

    # Batch of one image, on the same device as the model.
    image = preprocess_val(inp).unsqueeze(0).to(device)
    with torch.no_grad(), _autocast():
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        # Cosine similarities scaled by 100, then softmax across the prompts.
        text_probs = (100.0 * image_features @ _TEXT_FEATURES.T).softmax(dim=-1)

    # text_probs has a single row (batch of one); pick the best prompt in it.
    best_idx = int(text_probs[0].argmax())
    best_prob = text_probs[0, best_idx].item()
    return f"{GROUPS[best_idx]} <{100.0 * best_prob:.1f}%>"


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
    examples=[
        "imgs/cargo.jpg",
        "imgs/palazzo.jpg",
        "imgs/leggings.jpg",
        "imgs/dresspants.jpg",
    ],
)

if __name__ == "__main__":
    demo.launch(share=True)