Spaces:
Build error
Build error
import os | |
import numpy as np | |
import time | |
import random | |
import torch | |
import torchvision.transforms as transforms | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
from models import get_model | |
from dotmap import DotMap | |
from PIL import Image | |
#os.environ['TERM'] = 'linux' | |
#os.environ['TERMINFO'] = '/etc/terminfo' | |
# args
args = DotMap()
args.deploy = 'vanilla'
args.arch = 'dino_small_patch16'
args.resume = 'https://huggingface.co./hushell/pmf_dinosmall_lr1e-4/resolve/main/best_converted.pth'
# NOTE(review): a Google API key is committed in this file. Prefer supplying
# the credentials through the GOOGLE_API_KEY / GOOGLE_CX environment variables
# (e.g. as deployment secrets) and rotate the exposed key; the literals remain
# only as a fallback so existing deployments keep working unchanged.
args.api_key = os.environ.get('GOOGLE_API_KEY', 'AIzaSyAFkOGnXhy-2ZB0imDvNNqf2rHb98vR_qY')
args.cx = os.environ.get('GOOGLE_CX', '06d75168141bc47f1')

# model
device = 'cpu'  # torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_model(args)
model.to(device)
checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model'], strict=True)
# nn.Modules start in training mode; this model is only used for inference, so
# switch off training-only behaviour (e.g. Dropout) once, up front.
model.eval()
# image transforms | |
# image transforms
def test_transform(resize=256, crop=224):
    """Build the evaluation preprocessing pipeline for PIL images.

    The image is converted to RGB, resized so that its shorter side is
    ``resize`` pixels, center-cropped to ``crop`` x ``crop``, converted to a
    float tensor and normalized with the ImageNet per-channel mean/std.

    Args:
        resize: length of the shorter image side after resizing.
        crop: side length of the square center crop.

    Returns:
        A ``transforms.Compose`` that maps a PIL image to a
        ``(3, crop, crop)`` float tensor.
    """
    def _convert_image_to_rgb(im):
        return im.convert('RGB')

    return transforms.Compose([
        # Convert before resizing: Pillow always resamples mode "1"/"P"
        # (palette, common for PNGs) images with NEAREST, so resizing them
        # first would bypass the intended interpolation. Converting up front
        # also hands every later step a uniform 3-channel image.
        _convert_image_to_rgb,
        transforms.Resize(resize),
        transforms.CenterCrop(crop),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


# Shared preprocessing for both the query and the support images.
preprocess = test_transform()
def denormalize(x, mean, std):
    """Undo channel-wise normalization and clamp the result to ``[0, 1]``.

    Computes ``x * std + mean`` on a copy of ``x`` (the input is not modified)
    and clamps it to the displayable range.

    Args:
        x: normalized image tensor of shape ``(C, H, W)`` (a leading batch
            dimension also broadcasts correctly).
        mean: per-channel means as a scalar, a sequence / 1-D tensor of
            length ``C``, or a tensor already shaped ``(C, 1, 1)``.
        std: per-channel standard deviations, same accepted forms as ``mean``.

    Returns:
        A new tensor shaped like ``x`` with values in ``[0, 1]``.
    """
    t = x.clone()
    mean = torch.as_tensor(mean, dtype=t.dtype, device=t.device)
    std = torch.as_tensor(std, dtype=t.dtype, device=t.device)
    # A flat per-channel vector would otherwise broadcast against the last
    # (width) axis instead of the channel axis; give it shape (C, 1, 1).
    if mean.ndim == 1:
        mean = mean.view(-1, 1, 1)
    if std.ndim == 1:
        std = std.view(-1, 1, 1)
    t.mul_(std).add_(mean)
    return torch.clamp(t, 0, 1)
# Google image search | |
from google_images_search import GoogleImagesSearch | |
class MyGIS(GoogleImagesSearch):
    """``GoogleImagesSearch`` whose context-manager hooks do no teardown.

    These hooks replace the base class's ``__enter__``/``__exit__`` (if it
    defines any): entering yields the instance itself, and leaving does
    nothing and returns a falsy value, so an exception raised inside the
    ``with`` body still propagates.
    """

    def __enter__(self):
        # The search object itself is the value bound by ``with ... as``.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Deliberately a no-op; None is falsy, so nothing is suppressed.
        return None
# Default Google Custom Search parameters for the support-image lookup.
# 'q' and 'num' are placeholders; the remaining keys narrow results to colour,
# public-domain PNG photos. Other accepted values, for reference:
#   fileType:          'jpg|gif|png'
#   rights:            'cc_publicdomain|cc_attribute|cc_sharealike|cc_noncommercial|cc_nonderived'
# Parameters below marked (single) accept exactly ONE of the listed options
# (multi-select is not supported) and may be left out entirely to apply no
# restriction:
#   safe:              active | high | medium | off | safeUndefined            (single)
#   imgType:           clipart | face | lineart | stock | photo | animated | imgTypeUndefined  (single)
#   imgSize:           huge | icon | large | medium | small | xlarge | xxlarge | imgSizeUndefined  (single)
#   imgDominantColor:  black | blue | brown | gray | green | orange | pink | purple | red | teal | white | yellow | imgDominantColorUndefined  (single)
#   imgColorType:      color | gray | mono | trans | imgColorTypeUndefined     (single)
_search_params = dict(
    q='...',
    num=10,
    fileType='png',
    rights='cc_publicdomain',
    imgType='photo',
    imgColorType='color',
)
# Gradio UI | |
def inference(query, labels, n_supp=10):
    """Classify an image against class names, using Google images as supports.

    For each comma-separated class name, up to ``n_supp`` images are fetched
    with the Google Custom Search API, drawn into a grid figure, preprocessed
    and stacked into the labelled support set that is passed to ``model``
    together with the query image.

    Args:
        query: PIL image to classify.
        labels: comma-separated class names, e.g. ``"dog, cat"``. Surrounding
            whitespace is stripped and empty entries are ignored.
        n_supp: number of support images requested per class; passed through
            ``int()`` because the UI slider may deliver a float.

    Returns:
        ``(probs, fig)``: ``probs`` maps each class name to its softmax
        probability, ``fig`` is the matplotlib figure of the support images.

    Raises:
        ValueError: if no query image or no class name is given, or if the
            search returns no image for some class.
    """
    if query is None:
        raise ValueError('Please provide an image to classify.')
    # "dog, cat" must yield "cat" rather than " cat"; stray commas would
    # otherwise produce empty class names (and empty search queries).
    labels = [name.strip() for name in labels.split(',') if name.strip()]
    if not labels:
        raise ValueError("Please enter at least one class name separated by ','.")
    n_supp = int(n_supp)
    # squeeze=False keeps `axs` 2-D even for a single class, so the
    # axs[row, col] indexing below always works.
    fig, axs = plt.subplots(len(labels), n_supp,
                            figsize=(n_supp * 4, len(labels) * 4),
                            squeeze=False)
    with torch.no_grad():
        # query image
        query = preprocess(query).unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, 3, H, W)
        supp_x = []
        supp_y = []
        # search support images
        for idx, y in enumerate(labels):
            # Per-call copy: mutating the module-level dict would leak this
            # request's query into every other request served by the process.
            search_params = dict(_search_params, q=y, num=n_supp)
            gis = GoogleImagesSearch(args.api_key, args.cx)
            gis.search(search_params=search_params, custom_image_name='my_image')
            # NOTE(review): sets a private library attribute, presumably as a
            # workaround — confirm it is still required.
            gis._custom_image_name = 'my_image'
            n_found = 0
            for j, x in enumerate(gis.results()):
                x.download('./')
                # copy() loads the pixels, so the file can be closed right away
                # instead of leaking one open handle per downloaded image.
                with Image.open(x.path) as im:
                    x_im = im.copy()
                # vis
                axs[idx, j].imshow(x_im)
                axs[idx, j].set_title(f'{y}{j}:{x.url}')
                axs[idx, j].axis('off')
                supp_x.append(preprocess(x_im))  # (3, H, W)
                supp_y.append(idx)
                n_found += 1
            if n_found == 0:
                raise ValueError(f'Google image search returned no images for "{y}".')
        print('Searching for support images is done.')
        supp_x = torch.stack(supp_x, dim=0).unsqueeze(0).to(device)  # (1, n_supp*n_labels, 3, H, W)
        supp_y = torch.tensor(supp_y).long().unsqueeze(0).to(device)  # (1, n_supp*n_labels)
        # CUDA autocast only matters when running on a CUDA device; on a
        # CPU-only host autocast(True) merely warns and disables itself.
        with torch.cuda.amp.autocast(enabled=str(device).startswith('cuda')):
            output = model(supp_x, supp_y, query)  # (1, 1, n_labels)
        probs = output.softmax(dim=-1).detach().cpu().numpy()
    return {k: float(v) for k, v in zip(labels, probs[0, 0])}, fig
# DEBUG | |
#query = Image.open('../labrador-puppy.jpg') | |
##labels = 'dog, cat' | |
#labels = 'girl, boy' | |
#output = inference(query, labels, n_supp=2) | |
#print(output) | |
# Wire `inference` into a Gradio UI: an image, the comma-separated class
# names and the number of Google support images per class go in; the class
# probabilities and a plot of the downloaded support images come out.
input_components = [
    gr.inputs.Image(label="Image to classify", type="pil"),
    gr.inputs.Textbox(
        lines=1,
        label="Class hypotheses:",
        placeholder="Enter class names separated by ','",
    ),
    gr.inputs.Slider(
        minimum=2,
        maximum=10,
        step=1,
        label="Number of support examples from Google",
    ),
]
output_components = [
    gr.outputs.Label(label="Predicted class probabilities"),
    gr.outputs.Image(
        type='plot',
        label="Support examples from Google image search",
    ),
]
demo = gr.Interface(
    fn=inference,
    inputs=input_components,
    theme="grass",
    outputs=output_components,
    description="PMF few-shot learning with Google image search",
)
demo.launch(debug=True)