from __future__ import absolute_import import sys import numpy as np import torch from torch import nn import os from collections import OrderedDict from torch.autograd import Variable import itertools from .base_model import BaseModel from scipy.ndimage import zoom import fractions import functools import skimage.transform from tqdm import tqdm from IPython import embed from . import networks_basic as networks import lpips as util class DistModel(BaseModel): def name(self): return self.model_name def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, use_gpu=True, printNet=False, spatial=False, is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): ''' INPUTS model - ['net-lin'] for linearly calibrated network ['net'] for off-the-shelf network ['L2'] for L2 distance in Lab colorspace ['SSIM'] for ssim in RGB colorspace net - ['squeeze','alex','vgg'] model_path - if None, will look in weights/[NET_NAME].pth colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM use_gpu - bool - whether or not to use a GPU printNet - bool - whether or not to print network architecture out spatial - bool - whether to output an array containing varying distances across spatial dimensions spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). is_train - bool - [True] for training mode lr - float - initial learning rate beta1 - float - initial momentum term for adam version - 0.1 for latest, 0.0 was original (with a bug) gpu_ids - int array - [0] by default, gpus to use ''' BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) self.model = model self.net = net self.is_train = is_train self.spatial = spatial self.gpu_ids = gpu_ids self.model_name = '%s [%s]'%(model,net) if(self.model == 'net-lin'): # pretrained net + linear layer self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, use_dropout=True, spatial=spatial, version=version, lpips=True) kw = {} if not use_gpu: kw['map_location'] = 'cpu' if(model_path is None): import inspect model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) if(not is_train): print('Loading model from: %s'%model_path) self.net.load_state_dict(torch.load(model_path, **kw), strict=False) elif(self.model=='net'): # pretrained network self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) elif(self.model in ['L2','l2']): self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing self.model_name = 'L2' elif(self.model in ['DSSIM','dssim','SSIM','ssim']): self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) self.model_name = 'SSIM' else: raise ValueError("Model [%s] not recognized." % self.model) self.parameters = list(self.net.parameters()) if self.is_train: # training mode # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) self.rankLoss = networks.BCERankingLoss() self.parameters += list(self.rankLoss.net.parameters()) self.lr = lr self.old_lr = lr self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) else: # test mode self.net.eval() if(use_gpu): self.net.to(gpu_ids[0]) self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) if(self.is_train): self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 if(printNet): print('---------- Networks initialized -------------') networks.print_network(self.net) print('-----------------------------------------------') def forward(self, in0, in1, retPerLayer=False): ''' Function computes the distance between image patches in0 and in1 INPUTS in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] OUTPUT computed distances between in0 and in1 ''' return self.net.forward(in0, in1, retPerLayer=retPerLayer) # ***** TRAINING FUNCTIONS ***** def optimize_parameters(self): self.forward_train() self.optimizer_net.zero_grad() self.backward_train() self.optimizer_net.step() self.clamp_weights() def clamp_weights(self): for module in self.net.modules(): if(hasattr(module, 'weight') and module.kernel_size==(1,1)): module.weight.data = torch.clamp(module.weight.data,min=0) def set_input(self, data): self.input_ref = data['ref'] self.input_p0 = data['p0'] self.input_p1 = data['p1'] self.input_judge = data['judge'] if(self.use_gpu): self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) self.var_ref = Variable(self.input_ref,requires_grad=True) self.var_p0 = Variable(self.input_p0,requires_grad=True) self.var_p1 = Variable(self.input_p1,requires_grad=True) def forward_train(self): # run forward pass # print(self.net.module.scaling_layer.shift) # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) self.d0 = self.forward(self.var_ref, self.var_p0) self.d1 = self.forward(self.var_ref, self.var_p1) self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) return self.loss_total def backward_train(self): torch.mean(self.loss_total).backward() def compute_accuracy(self,d0,d1,judge): ''' d0, d1 are Variables, judge is a Tensor ''' d1_lt_d0 = (d1 %f' % (type,self.old_lr, lr)) self.old_lr = lr def score_2afc_dataset(data_loader, func, name=''): ''' Function computes Two Alternative Forced Choice (2AFC) score using distance function 'func' in dataset 'data_loader' INPUTS data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside func - callable distance function - calling d=func(in0,in1) should take 2 pytorch tensors with shape Nx3xXxY, and return numpy array of length N OUTPUTS [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators [1] - dictionary with following elements d0s,d1s - N arrays containing distances between reference patch to perturbed patches gts - N array in [0,1], preferred patch selected by human evaluators (closer to "0" for left patch p0, "1" for right patch p1, "0.6" means 60pct people preferred right patch, 40pct preferred left) scores - N array in [0,1], corresponding to what percentage function agreed with humans CONSTS N - number of test triplets in data_loader ''' d0s = [] d1s = [] gts = [] for data in tqdm(data_loader.load_data(), desc=name): d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() gts+=data['judge'].cpu().numpy().flatten().tolist() d0s = np.array(d0s) d1s = np.array(d1s) gts = np.array(gts) scores = (d0s