# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
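"""
Image retrieval evaluation on the revisited Oxford and Paris benchmarks
(roxford5k / rparis6k): extract frozen features with a pretrained backbone,
rank database images by cosine similarity to each query, and report mAP and
mP@k under the Medium and Hard evaluation protocols.

Example invocation (the script filename is an assumption; adapt it to how this
file is saved locally, and launch with torchrun / torch.distributed.launch if
your setup requires it):

    python eval_image_retrieval.py \
        --data_path /path/to/revisited_paris_oxford/ \
        --dataset roxford5k \
        --pretrained_weights /path/to/checkpoint.pth \
        --arch vit_small --patch_size 16
"""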
import os
import sys
import pickle
import argparse

import torch
from torch import nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torchvision import models as torchvision_models
from torchvision import transforms as pth_transforms
from PIL import Image, ImageFile
import numpy as np

import utils
import vision_transformer as vits
from eval_knn import extract_features
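

# OxfordParisDataset wraps the revisited Oxford/Paris ground-truth pickle
# (gnd_{dataset}.pkl): `imlist` holds the database images, `qimlist` the queries,
# and `gnd` the per-query relevance annotations consumed by utils.compute_map below.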
class OxfordParisDataset(torch.utils.data.Dataset):
    def __init__(self, dir_main, dataset, split, transform=None, imsize=None):
        if dataset not in ['roxford5k', 'rparis6k']:
            raise ValueError('Unknown dataset: {}!'.format(dataset))

        # load imlist, qimlist, and gnd into cfg as a dict
        gnd_fname = os.path.join(dir_main, dataset, 'gnd_{}.pkl'.format(dataset))
        with open(gnd_fname, 'rb') as f:
            cfg = pickle.load(f)
        cfg['gnd_fname'] = gnd_fname
        cfg['ext'] = '.jpg'
        cfg['qext'] = '.jpg'
        cfg['dir_data'] = os.path.join(dir_main, dataset)
        cfg['dir_images'] = os.path.join(cfg['dir_data'], 'jpg')
        cfg['n'] = len(cfg['imlist'])
        cfg['nq'] = len(cfg['qimlist'])
        cfg['im_fname'] = config_imname
        cfg['qim_fname'] = config_qimname
        cfg['dataset'] = dataset
        self.cfg = cfg

        self.samples = cfg["qimlist"] if split == "query" else cfg["imlist"]
        self.transform = transform
        self.imsize = imsize

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path = os.path.join(self.cfg["dir_images"], self.samples[index] + ".jpg")
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        with open(path, 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')
        if self.imsize is not None:
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the equivalent filter
            img.thumbnail((self.imsize, self.imsize), Image.LANCZOS)
        if self.transform is not None:
            img = self.transform(img)
        return img, index
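

# Path helpers stored into cfg['im_fname'] / cfg['qim_fname'] above, so that any
# code holding the cfg dict can resolve a database or query image path by index.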
def config_imname(cfg, i):
    return os.path.join(cfg['dir_images'], cfg['imlist'][i] + cfg['ext'])


def config_qimname(cfg, i):
    return os.path.join(cfg['dir_images'], cfg['qimlist'][i] + cfg['qext'])


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Image Retrieval on revisited Paris and Oxford')
    parser.add_argument('--data_path', default='/path/to/revisited_paris_oxford/', type=str)
    parser.add_argument('--dataset', default='roxford5k', type=str, choices=['roxford5k', 'rparis6k'])
    parser.add_argument('--multiscale', default=False, type=utils.bool_flag)
    parser.add_argument('--imsize', default=224, type=int, help='Image size')
    parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
    parser.add_argument('--use_cuda', default=True, type=utils.bool_flag)
    parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
    parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
    parser.add_argument("--checkpoint_key", default="teacher", type=str,
        help='Key to use in the checkpoint (example: "teacher")')
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    args = parser.parse_args()

    utils.init_distributed_mode(args)
    print("git:\n {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True
    # ============ preparing data ... ============
    transform = pth_transforms.Compose([
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
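    # Normalization uses the standard ImageNet mean/std; resizing is handled inside
    # OxfordParisDataset via PIL's thumbnail(), so the transform itself is resolution-agnostic.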
    dataset_train = OxfordParisDataset(args.data_path, args.dataset, split="train", transform=transform, imsize=args.imsize)
    dataset_query = OxfordParisDataset(args.data_path, args.dataset, split="query", transform=transform, imsize=args.imsize)
    sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler,
        batch_size=1,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    data_loader_query = torch.utils.data.DataLoader(
        dataset_query,
        batch_size=1,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    print(f"train: {len(dataset_train)} imgs / query: {len(dataset_query)} imgs")
    # ============ building network ... ============
    if "vit" in args.arch:
        model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
        print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
    elif "xcit" in args.arch:
        model = torch.hub.load('facebookresearch/xcit', args.arch, num_classes=0)
    elif args.arch in torchvision_models.__dict__.keys():
        model = torchvision_models.__dict__[args.arch](num_classes=0)
    else:
print(f"Architecture {args.arch} non supported") | |
        sys.exit(1)
    if args.use_cuda:
        model.cuda()
    model.eval()
    # load pretrained weights
    if os.path.isfile(args.pretrained_weights):
        state_dict = torch.load(args.pretrained_weights, map_location="cpu")
        if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
            print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[args.checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model.load_state_dict(state_dict, strict=False)
        print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg))
    elif args.arch == "vit_small" and args.patch_size == 16:
        print("Since no pretrained weights have been provided, we load DINO weights pretrained on Google Landmarks v2.")
        model.load_state_dict(torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/dino_vitsmall16_googlelandmark_pretrain/dino_vitsmall16_googlelandmark_pretrain.pth"))
    else:
        print("Warning: We use random weights.")
    ############################################################################
    # Step 1: extract features
    train_features = extract_features(model, data_loader_train, args.use_cuda, multiscale=args.multiscale)
    query_features = extract_features(model, data_loader_query, args.use_cuda, multiscale=args.multiscale)

    if utils.get_rank() == 0:  # only rank 0 will work from now on
        # normalize features
        train_features = nn.functional.normalize(train_features, dim=1, p=2)
        query_features = nn.functional.normalize(query_features, dim=1, p=2)

        ############################################################################
        # Step 2: similarity
        sim = torch.mm(train_features, query_features.T)
        ranks = torch.argsort(-sim, dim=0).cpu().numpy()
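        # `ranks` has shape (num_database_images, num_queries): column q lists database
        # indices sorted by decreasing cosine similarity to query q (the features were
        # L2-normalized, so the dot product equals cosine similarity).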
        ############################################################################
        # Step 3: evaluate
        gnd = dataset_train.cfg['gnd']
        # evaluate ranks
        ks = [1, 5, 10]
        # Medium protocol: both easy and hard images count as positives
        gnd_t = []
        for i in range(len(gnd)):
            g = {}
            g['ok'] = np.concatenate([gnd[i]['easy'], gnd[i]['hard']])
            g['junk'] = np.concatenate([gnd[i]['junk']])
            gnd_t.append(g)
        mapM, apsM, mprM, prsM = utils.compute_map(ranks, gnd_t, ks)
        # Hard protocol: only hard images count as positives; easy ones are treated as junk
        gnd_t = []
        for i in range(len(gnd)):
            g = {}
            g['ok'] = np.concatenate([gnd[i]['hard']])
            g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['easy']])
            gnd_t.append(g)
        mapH, apsH, mprH, prsH = utils.compute_map(ranks, gnd_t, ks)
        print('>> {}: mAP M: {}, H: {}'.format(args.dataset, np.around(mapM*100, decimals=2), np.around(mapH*100, decimals=2)))
        print('>> {}: mP@k{} M: {}, H: {}'.format(args.dataset, np.array(ks), np.around(mprM*100, decimals=2), np.around(mprH*100, decimals=2)))
    dist.barrier()