import logging
import math
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import tqdm.auto
from torch.utils.data import DataLoader

from saicinpainting.evaluation.utils import move_to_device

LOGGER = logging.getLogger(__name__)


class InpaintingEvaluator:
    def __init__(self, dataset, scores, area_grouping=True, bins=10, batch_size=32, device='cuda',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param dataset: torch.utils.data.Dataset which contains images and masks
        :param scores: dict {score_name: EvaluatorScore object}
        :param area_grouping: in addition to the overall scores, compute scores for groups of samples
            defined by the share of the area occluded by the mask
        :param bins: number of groups; the partition is generated by np.linspace(0., 1., bins + 1)
        :param batch_size: batch size for the dataloader
        :param device: device to use
        :param integral_func: optional callable that maps the results dict to a single scalar summary metric
        :param integral_title: key under which the value of integral_func is stored in the results
        :param clamp_image_range: optional (min, max) pair; input images are clamped to it before evaluation
        """
        self.scores = scores
        self.dataset = dataset

        self.area_grouping = area_grouping
        self.bins = bins

        self.device = torch.device(device)

        self.dataloader = DataLoader(self.dataset, shuffle=False, batch_size=batch_size)

        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range

    def _get_bin_edges(self):
        """Assign every dataset sample to a mask-area bin; return per-sample bin indices
        and human-readable bin names (e.g. '20-30%' for bin 2 when bins == 10)."""
        bin_edges = np.linspace(0, 1, self.bins + 1)

        num_digits = max(0, math.ceil(math.log10(self.bins)) - 1)
        interval_names = []
        for idx_bin in range(self.bins):
            start_percent, end_percent = round(100 * bin_edges[idx_bin], num_digits), \
                                         round(100 * bin_edges[idx_bin + 1], num_digits)
            start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
            end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
            interval_names.append("{0}-{1}%".format(start_percent, end_percent))

        groups = []
        for batch in self.dataloader:
            mask = batch['mask']
            batch_size = mask.shape[0]
            area = mask.to(self.device).reshape(batch_size, -1).mean(dim=-1)
            bin_indices = np.searchsorted(bin_edges, area.detach().cpu().numpy(), side='right') - 1
            # fully occluded samples (area == 1.) land one past the last bin; clamp them back
            bin_indices[bin_indices == self.bins] = self.bins - 1
            groups.append(bin_indices)
        groups = np.hstack(groups)

        return groups, interval_names

    def evaluate(self, model=None):
        """
        :param model: callable with signature (image_batch, mask_batch); should return inpainted_batch
        :return: dict with (score_name, group_type) as keys, where group_type is either 'total' or
            the name of a particular group arranged by mask area (e.g. '10-20%'),
            and score statistics for the group as values.
        """
        results = dict()
        if self.area_grouping:
            groups, interval_names = self._get_bin_edges()
        else:
            groups = None

        for score_name, score in tqdm.auto.tqdm(self.scores.items(), desc='scores'):
            score.to(self.device)
            with torch.no_grad():
                score.reset()
                for batch in tqdm.auto.tqdm(self.dataloader, desc=score_name, leave=False):
                    batch = move_to_device(batch, self.device)
                    image_batch, mask_batch = batch['image'], batch['mask']
                    if self.clamp_image_range is not None:
                        image_batch = torch.clamp(image_batch,
                                                  min=self.clamp_image_range[0],
                                                  max=self.clamp_image_range[1])
                    if model is None:
                        assert 'inpainted' in batch, \
                            'Model is None, so precomputed inpainting results are expected at key "inpainted"'
                        inpainted_batch = batch['inpainted']
                    else:
                        inpainted_batch = model(image_batch, mask_batch)
                    score(inpainted_batch, image_batch, mask_batch)
                total_results, group_results = score.get_value(groups=groups)

            results[(score_name, 'total')] = total_results
            if groups is not None:
                for group_index, group_values in group_results.items():
                    group_name = interval_names[group_index]
                    results[(score_name, group_name)] = group_values

        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        return results
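

# A minimal usage sketch (hypothetical names: MyInpaintingDataset and the score
# classes below are placeholders for a real dataset and EvaluatorScore objects):
#
#   dataset = MyInpaintingDataset(...)
#   evaluator = InpaintingEvaluator(dataset,
#                                   scores={'ssim': SSIMScore(), 'fid': FIDScore()},
#                                   integral_func=ssim_fid100_f1,
#                                   integral_title='ssim_fid100_f1')
#   results = evaluator.evaluate(model)  # or evaluator.evaluate() if each batch
#                                        # already carries an 'inpainted' key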


def ssim_fid100_f1(metrics, fid_scale=100):
    """F1-style harmonic mean of SSIM and a rescaled FID (higher is better)."""
    ssim = metrics[('ssim', 'total')]['mean']
    fid = metrics[('fid', 'total')]['mean']
    fid_rel = max(0, fid_scale - fid) / fid_scale
    f1 = 2 * ssim * fid_rel / (ssim + fid_rel + 1e-3)
    return f1


def lpips_fid100_f1(metrics, fid_scale=100):
    """F1-style harmonic mean of (1 - LPIPS) and a rescaled FID (higher is better)."""
    neg_lpips = 1 - metrics[('lpips', 'total')]['mean']
    fid = metrics[('fid', 'total')]['mean']
    fid_rel = max(0, fid_scale - fid) / fid_scale
    f1 = 2 * neg_lpips * fid_rel / (neg_lpips + fid_rel + 1e-3)
    return f1
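

# Both aggregates are harmonic means of a similarity term and a FID term rescaled
# to [0, 1], with 1e-3 in the denominator for numerical stability. A worked
# example with assumed metric values:
#
#   metrics = {('ssim', 'total'): {'mean': 0.9}, ('fid', 'total'): {'mean': 10.0}}
#   ssim_fid100_f1(metrics)  # fid_rel = 0.9; 2 * 0.9 * 0.9 / (0.9 + 0.9 + 1e-3) ~= 0.8995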


class InpaintingEvaluatorOnline(nn.Module):
    def __init__(self, scores, bins=10, image_key='image', inpainted_key='inpainted',
                 integral_func=None, integral_title=None, clamp_image_range=None):
        """
        :param scores: dict {score_name: EvaluatorScore object}
        :param bins: number of groups; the partition is generated by np.linspace(0., 1., bins + 1)
        :param image_key: key under which input images are stored in the batch dict
        :param inpainted_key: key under which inpainting results are stored in the batch dict
        :param integral_func: optional callable that maps the results dict to a single scalar summary metric
        :param integral_title: key under which the value of integral_func is stored in the results
        :param clamp_image_range: optional (min, max) pair; input images are clamped to it before evaluation
        """
        super().__init__()
        LOGGER.info(f'{type(self)} init called')
        self.scores = nn.ModuleDict(scores)
        self.image_key = image_key
        self.inpainted_key = inpainted_key
        self.bins_num = bins
        self.bin_edges = np.linspace(0, 1, self.bins_num + 1)

        num_digits = max(0, math.ceil(math.log10(self.bins_num)) - 1)
        self.interval_names = []
        for idx_bin in range(self.bins_num):
            start_percent, end_percent = round(100 * self.bin_edges[idx_bin], num_digits), \
                                         round(100 * self.bin_edges[idx_bin + 1], num_digits)
            start_percent = '{:.{n}f}'.format(start_percent, n=num_digits)
            end_percent = '{:.{n}f}'.format(end_percent, n=num_digits)
            self.interval_names.append("{0}-{1}%".format(start_percent, end_percent))

        self.groups = []

        self.integral_func = integral_func
        self.integral_title = integral_title
        self.clamp_image_range = clamp_image_range

        LOGGER.info(f'{type(self)} init done')

    def _get_bins(self, mask_batch):
        """Map each mask in the batch to the index of its mask-area bin."""
        batch_size = mask_batch.shape[0]
        area = mask_batch.view(batch_size, -1).mean(dim=-1).detach().cpu().numpy()
        bin_indices = np.clip(np.searchsorted(self.bin_edges, area) - 1, 0, self.bins_num - 1)
        return bin_indices

    def forward(self, batch: Dict[str, torch.Tensor]):
        """
        Calculate and accumulate metrics for a batch. To finalize evaluation and obtain
        the final metrics, call evaluation_end.
        :param batch: batch dict with mandatory fields 'mask', 'image' and 'inpainted'
            (the latter two keys can be overridden by self.image_key and self.inpainted_key)
        """
        result = {}
        with torch.no_grad():
            image_batch, mask_batch, inpainted_batch = batch[self.image_key], batch['mask'], batch[self.inpainted_key]
            if self.clamp_image_range is not None:
                image_batch = torch.clamp(image_batch,
                                          min=self.clamp_image_range[0],
                                          max=self.clamp_image_range[1])
            self.groups.extend(self._get_bins(mask_batch))

            for score_name, score in self.scores.items():
                result[score_name] = score(inpainted_batch, image_batch, mask_batch)
        return result

    def process_batch(self, batch: Dict[str, torch.Tensor]):
        return self(batch)

    def evaluation_end(self, states=None):
        """
        :param states: optional list of dicts {score_name: score state} as returned by forward,
            e.g. collected from several workers in distributed evaluation
        :return: dict with (score_name, group_type) as keys, where group_type is either 'total' or
            the name of a particular group arranged by mask area (e.g. '10-20%'),
            and score statistics for the group as values.
        """
        LOGGER.info(f'{type(self)}: evaluation_end called')

        self.groups = np.array(self.groups)

        results = {}
        for score_name, score in self.scores.items():
            LOGGER.info(f'Getting value of {score_name}')
            cur_states = [s[score_name] for s in states] if states is not None else None
            total_results, group_results = score.get_value(groups=self.groups, states=cur_states)
            LOGGER.info(f'Getting value of {score_name} done')
            results[(score_name, 'total')] = total_results

            for group_index, group_values in group_results.items():
                group_name = self.interval_names[group_index]
                results[(score_name, group_name)] = group_values

        if self.integral_func is not None:
            results[(self.integral_title, 'total')] = dict(mean=self.integral_func(results))

        LOGGER.info(f'{type(self)}: reset scores')
        self.groups = []
        for sc in self.scores.values():
            sc.reset()
        LOGGER.info(f'{type(self)}: reset scores done')

        LOGGER.info(f'{type(self)}: evaluation_end done')
        return results
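

# A minimal usage sketch of the online evaluator (the score objects and batch
# source are placeholders):
#
#   evaluator = InpaintingEvaluatorOnline(scores={'ssim': SSIMScore(), 'fid': FIDScore()})
#   for batch in dataloader:   # each batch dict holds 'image', 'mask' and 'inpainted'
#       evaluator(batch)       # accumulates per-batch metric state
#   results = evaluator.evaluation_end()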