# Hugging Face Spaces page status at capture time: Build error
import os
import random
import tempfile
import time

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from dotmap import DotMap
from matplotlib.figure import Figure
from PIL import Image

from models import get_model
#os.environ['TERM'] = 'linux' | |
#os.environ['TERMINFO'] = '/etc/terminfo' | |
# args
args = DotMap()
args.deploy = 'vanilla'
args.arch = 'dino_small_patch16'
# NOTE(review): presumably skips fetching separate pretrained backbone weights,
# since the full checkpoint is loaded from args.resume below — confirm in models.get_model.
args.no_pretrain = True
args.resume = 'https://huggingface.co/hushell/pmf_dinosmall_lr1e-4/resolve/main/best_converted.pth'
# Google Custom Search credentials. The API key is a secret and must not live in
# source control: provide it through the environment (e.g. a Space secret).
# The key that used to be hard-coded here was published and should be revoked.
args.api_key = os.environ.get('GIS_API_KEY')
# The search-engine id is less sensitive; it can still be overridden from the environment.
args.cx = os.environ.get('GIS_CX', '06d75168141bc47f1')
if not args.api_key:
    print('Warning: GIS_API_KEY is not set; Google image search requests are expected to fail.')

# model
device = 'cpu'  # torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_model(args)
model.to(device)
checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model'], strict=True)
# This app only runs inference: switch the module to evaluation mode so that
# train-only behaviour (e.g. dropout) is disabled.
model.eval()
# image transforms
def test_transform():
    """Build the evaluation-time preprocessing pipeline.

    The pipeline applies Resize(256), then CenterCrop(224), then converts the
    PIL image to 3-channel RGB. It then turns the image into a tensor and
    normalises each channel with the fixed mean/std below.
    """
    rgb_mean = [0.485, 0.456, 0.406]
    rgb_std = [0.229, 0.224, 0.225]

    def _convert_image_to_rgb(im):
        # Runs after cropping so that grayscale / RGBA / palette images
        # reach ToTensor with exactly three channels.
        return im.convert('RGB')

    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        _convert_image_to_rgb,
        transforms.ToTensor(),
        transforms.Normalize(mean=rgb_mean, std=rgb_std),
    ]
    return transforms.Compose(steps)


# Shared transform, applied to the query image and to every downloaded support image.
preprocess = test_transform()
def denormalize(x, mean, std):
    """Undo channel-wise normalisation of an image tensor and clamp to [0, 1].

    Computes ``x * std + mean`` on a copy of ``x``, so the input is left
    untouched, and clamps the result into the displayable range [0, 1].

    Args:
        x: float image tensor, typically shaped (3, H, W) (or batched (B, 3, H, W)).
        mean: per-channel means. Accepted forms are a scalar, a flat per-channel
            sequence / 1-D tensor (e.g. the list given to ``Normalize``), or a
            tensor already broadcastable against ``x`` (e.g. shaped (3, 1, 1)).
        std: per-channel standard deviations, in the same forms as ``mean``.

    Returns:
        A new tensor with the same shape and dtype as ``x``.
    """
    t = x.clone()

    def _per_channel(stat):
        stat = torch.as_tensor(stat, device=t.device)
        # A flat per-channel vector must become (C, 1, 1) so that it broadcasts
        # over H and W; left 1-D it would broadcast over the last (width) axis.
        if stat.dim() == 1 and t.dim() >= 3:
            return stat.view(-1, 1, 1)
        return stat

    t.mul_(_per_channel(std)).add_(_per_channel(mean))
    return torch.clamp(t, 0, 1)
# Google image search
from google_images_search import GoogleImagesSearch


class MyGIS(GoogleImagesSearch):
    """GoogleImagesSearch that can be used in a ``with`` statement.

    Entering the block yields the search object itself. Leaving it does no
    cleanup and does not suppress exceptions raised inside the block.
    NOTE(review): not referenced by the code visible in this file.
    """

    def __enter__(self):
        # The client itself is the value bound by ``with ... as``.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning a falsy value lets any exception from the block propagate.
        return None
# define search params | |
# option for commonly used search param are shown below for easy reference. | |
# For param marked with '##': | |
# - Multiselect is currently not feasible. Choose ONE option only | |
# - This param can also be omitted from _search_params if you do not wish to define any value | |
_search_params = { | |
'q': '...', | |
'num': 10, | |
'fileType': 'png', #'jpg|gif|png', | |
'rights': 'cc_publicdomain', #'cc_publicdomain|cc_attribute|cc_sharealike|cc_noncommercial|cc_nonderived', | |
#'safe': 'active|high|medium|off|safeUndefined', ## | |
'imgType': 'photo', #'clipart|face|lineart|stock|photo|animated|imgTypeUndefined', ## | |
#'imgSize': 'huge|icon|large|medium|small|xlarge|xxlarge|imgSizeUndefined', ## | |
#'imgDominantColor': 'black|blue|brown|gray|green|orange|pink|purple|red|teal|white|yellow|imgDominantColorUndefined', ## | |
'imgColorType': 'color', #'color|gray|mono|trans|imgColorTypeUndefined' ## | |
} | |
# Gradio UI
def _search_support_set(labels, search_params, axs, download_dir):
    """Build the support set for every label via Google image search.

    Each label is used in turn as the search query. Every returned image is
    downloaded into ``download_dir``, drawn on row ``idx`` of ``axs`` and
    preprocessed.

    Returns:
        (supp_x, supp_y): list of preprocessed (3, H, W) tensors and the list
        of their class indices (position of the label in ``labels``).

    Raises:
        RuntimeError: if the search yields no image for some label, since a
            class without support examples cannot be scored.
    """
    supp_x = []
    supp_y = []
    for idx, y in enumerate(labels):
        gis = GoogleImagesSearch(args.api_key, args.cx)
        gis.search(search_params=dict(search_params, q=y), custom_image_name='my_image')
        gis._custom_image_name = 'my_image'  # fix: image name sometimes too long
        n_found = 0
        for j, x in enumerate(gis.results()):
            x.download(download_dir)
            # copy() reads the pixels now, so the file handle can be closed
            # (and the download directory removed) right away.
            with Image.open(x.path) as im:
                x_im = im.copy()
            ax = axs[idx, j]
            ax.imshow(x_im)
            ax.set_title(f'{y}{j}:{x.url}')
            supp_x.append(preprocess(x_im))  # (3, H, W)
            supp_y.append(idx)
            n_found += 1
        if n_found == 0:
            raise RuntimeError(f'Google image search returned no images for "{y}".')
    return supp_x, supp_y


def inference(query, labels, n_supp=10,
              file_type='png', rights='cc_publicdomain',
              image_type='photo', color_type='color'):
    """Classify ``query`` against support images fetched from Google image search.

    Args:
        query: PIL image to classify.
        labels: comma-separated class names, e.g. "dog, cat". Surrounding
            whitespace is stripped and empty entries are ignored.
        n_supp: number of support images requested per class.
        file_type, rights, image_type, color_type: GIS filters, sent as
            'fileType', 'rights', 'imgType' and 'imgColorType'.

    Returns:
        (probs, fig): dict mapping each class name to its predicted probability,
        and a matplotlib Figure with one row of support images per class.

    Raises:
        ValueError: if ``labels`` contains no class name.
        RuntimeError: if no support image is found for some class.
    """
    labels = [name.strip() for name in labels.split(',') if name.strip()]
    if not labels:
        raise ValueError('Please enter at least one class name.')
    n_supp = int(n_supp)

    # Preprocess the query first so that a bad/missing image fails before any
    # search quota is spent.
    query_x = preprocess(query).unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, 3, H, W)

    # Per-call copy: mutating the module-level dict would let concurrent
    # requests overwrite each other's query and filters.
    search_params = dict(_search_params)
    search_params.update({
        'num': n_supp,
        'fileType': file_type,
        'rights': rights,
        'imgType': image_type,
        'imgColorType': color_type,
    })

    # A standalone Figure is not registered with pyplot, so figures from past
    # requests are not kept alive forever. squeeze=False keeps axs 2-D even
    # when there is a single class.
    fig = Figure(figsize=(n_supp * 4, len(labels) * 4))
    axs = fig.subplots(len(labels), n_supp, squeeze=False)
    for ax in axs.flat:
        ax.axis('off')  # also hides slots left empty when GIS returns fewer images

    # Private download directory per request: downloads are requested under the
    # fixed custom name 'my_image', so a shared './' could clash between
    # concurrent requests. The directory is removed once images are in memory.
    with tempfile.TemporaryDirectory() as download_dir:
        supp_x, supp_y = _search_support_set(labels, search_params, axs, download_dir)
    print('Searching for support images is done.')

    with torch.no_grad():
        supp_x = torch.stack(supp_x, dim=0).unsqueeze(0).to(device)  # (1, n_found, 3, H, W)
        supp_y = torch.tensor(supp_y).long().unsqueeze(0).to(device)  # (1, n_found)
        # CUDA autocast does not affect CPU tensors, so only request it on CUDA.
        with torch.cuda.amp.autocast(enabled=str(device).startswith('cuda')):
            output = model(supp_x, supp_y, query_x)  # (1, 1, n_labels)
        # Softmax in float32 in case autocast produced half-precision logits.
        probs = output.float().softmax(dim=-1).cpu().numpy()
    return {k: float(v) for k, v in zip(labels, probs[0, 0])}, fig
# DEBUG | |
##query = Image.open('../labrador-puppy.jpg') | |
#query = Image.open('/Users/hushell/Documents/Dan_tr.png') | |
##labels = 'dog, cat' | |
#labels = 'girl, sussie' | |
#output = inference(query, labels, n_supp=2) | |
#print(output) | |
title = "P>M>F few-shot learning pipeline with Google Image Search (GIS)"
description = "Short description: We take a ViT-small backbone, which is pre-trained with DINO, and meta-trained on Meta-Dataset; for few-shot classification, we use a ProtoNet classifier. The demo can be viewed as zero-shot since the support set is built by searching images from Google. Note that you may need to play with GIS parameters to get good support examples. Besides, GIS is not very stable as search requests may fail for many reasons (e.g., number of requests reaches the limit of the day)."
article = "<p style='text-align: center'><a href='http://arxiv.org/abs/2204.07305' target='_blank'>Arxiv</a></p>"

# Input widgets, in the same order as inference()'s positional parameters:
# query, labels, n_supp, file_type, rights, image_type, color_type.
input_components = [
    gr.inputs.Image(label="Image to classify", type="pil"),
    gr.inputs.Textbox(
        lines=1,
        label="Class hypotheses:",
        placeholder="Enter class names separated by ','",
    ),
    gr.inputs.Slider(
        minimum=2, maximum=10, step=1,
        label="GIS: Number of support examples per class",
    ),
    gr.inputs.Dropdown(['png', 'jpg'], default='png', label='GIS: Image file type'),
    gr.inputs.Dropdown(
        ['cc_publicdomain', 'cc_attribute', 'cc_sharealike', 'cc_noncommercial', 'cc_nonderived'],
        default='cc_publicdomain',
        label='GIS: Copy rights',
    ),
    gr.inputs.Dropdown(
        ['clipart', 'face', 'lineart', 'stock', 'photo', 'animated', 'imgTypeUndefined'],
        default='photo',
        label='GIS: Image type',
    ),
    gr.inputs.Dropdown(
        ['color', 'gray', 'mono', 'trans', 'imgColorTypeUndefined'],
        default='color',
        label='GIS: Image color type',
    ),
]

# Output widgets, matching inference()'s (probabilities, figure) return pair.
output_components = [
    gr.outputs.Label(label="Predicted class probabilities"),
    gr.outputs.Image(type='plot', label="Support examples from Google image search"),
]

demo = gr.Interface(
    fn=inference,
    inputs=input_components,
    outputs=output_components,
    theme="grass",
    title=title,
    description=description,
    article=article,
)
demo.launch(debug=True)