|
import abc |
|
from typing import Dict, List |
|
|
|
import numpy as np |
|
import torch |
|
from skimage import color |
|
from skimage.segmentation import mark_boundaries |
|
|
|
from . import colors |
|
|
|
# Module-wide colour palette. visualize_mask_and_images passes it to
# skimage's label2rgb to colour per-pixel class indices. The second value
# returned by colors.generate_colors is intentionally discarded.
# NOTE(review): 151 is presumably the number of palette entries requested --
# confirm against colors.generate_colors.
COLORS, _ = colors.generate_colors(151)
|
|
|
|
|
class BaseVisualizer:
    """Common interface for objects that visualize a batch when called.

    Subclasses are expected to override ``__call__``.

    NOTE: this class does not use ``abc.ABCMeta`` as its metaclass, so the
    ``@abc.abstractmethod`` marker below is not enforced at instantiation
    time; a subclass that forgets to override ``__call__`` only fails when
    it is actually called.
    """

    @abc.abstractmethod
    def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None):
        """Take a batch, make an image from it and visualize.

        NOTE(review): the argument semantics are defined by subclasses;
        ``epoch_i`` / ``batch_i`` are presumably the current epoch and batch
        indices -- confirm against the concrete visualizers.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError
|
|
|
|
|
def visualize_mask_and_images(images_dict: Dict[str, np.ndarray], keys: List[str],
                              last_without_mask=True, rescale_keys=None, mask_only_first=None,
                              black_mask=False) -> np.ndarray:
    """Render the images named by ``keys`` side by side, outlining the mask on some of them.

    Each image is converted to a channel-last, 3-channel panel:

    * a 2-D array is treated as a single-channel image;
    * any other array is transposed with axes ``(1, 2, 0)`` (channel-first to
      channel-last);
    * a single-channel panel is repeated to 3 channels;
    * a panel with more than 3 channels is reduced with ``argmax`` over the
      channel axis and coloured with ``label2rgb`` using the module palette;
    * panels with 2 or 3 channels are passed through unchanged.

    Args:
        images_dict: Arrays keyed by name. Must contain ``'mask'``, which is
            thresholded at 0.5; only ``mask[0]`` is used (assumes the mask is
            laid out as ``(1, H, W)`` -- TODO confirm with callers).
        keys: Names of the entries to render, left to right.
        last_without_mask: When true (and ``mask_only_first`` is falsy), the
            last panel is left without mask boundaries.
        rescale_keys: Optional collection of keys whose panels are min-max
            rescaled (subtract the minimum, divide by ``max + 1e-5``).
        mask_only_first: When truthy, only the first panel gets mask
            boundaries, regardless of ``last_without_mask``.
        black_mask: When true, pixels inside the mask are zeroed before the
            boundaries are drawn.

    Returns:
        The panels concatenated along axis 1 (side by side).

    Raises:
        KeyError: If ``'mask'`` or one of ``keys`` is missing from ``images_dict``.
        ValueError: From ``np.concatenate`` when ``keys`` is empty.
    """
    mask = images_dict['mask'] > 0.5
    last_index = len(keys) - 1

    panels = []
    for i, k in enumerate(keys):
        img = images_dict[k]

        # Bring the image to channel-last layout. The dimensionality check has
        # to come *before* the transpose: a 3-axis transpose raises on a 2-D
        # array, which previously made the 2-D handling unreachable.
        if img.ndim == 2:
            img = np.expand_dims(img, 2)
        else:
            img = np.transpose(img, (1, 2, 0))

        if rescale_keys is not None and k in rescale_keys:
            # The subtraction allocates a new array, so the in-place division
            # below never modifies the caller's data.
            img = img - img.min()
            img /= img.max() + 1e-5

        n_channels = img.shape[2]
        if n_channels == 1:
            img = np.repeat(img, 3, axis=2)
        elif n_channels > 3:
            # Treat the channels as per-class scores and colour the winners.
            img_classes = img.argmax(2)
            img = color.label2rgb(img_classes, colors=COLORS)

        if mask_only_first:
            need_mark_boundaries = i == 0
        else:
            need_mark_boundaries = i < last_index or not last_without_mask

        if need_mark_boundaries:
            mask_plane = mask[0]
            if black_mask:
                # Zero out everything inside the mask.
                img = img * (1 - mask_plane[..., None])
            # Red boundary with a white outline.
            img = mark_boundaries(img,
                                  mask_plane,
                                  color=(1., 0., 0.),
                                  outline_color=(1., 1., 1.),
                                  mode='thick')
        panels.append(img)

    return np.concatenate(panels, axis=1)
|
|
|
|
|
def visualize_mask_and_images_batch(batch: Dict[str, torch.Tensor], keys: List[str], max_items=10,
                                    last_without_mask=True, rescale_keys=None) -> np.ndarray:
    """Render up to ``max_items`` samples of a batch, one row per sample.

    Only the entries listed in ``keys`` plus ``'mask'`` are kept; each is
    detached, moved to the CPU and converted to a numpy array. Every sample
    is rendered with ``visualize_mask_and_images`` and the resulting rows are
    concatenated along axis 0.

    Args:
        batch: Tensors keyed by name; indexing along the first axis selects
            one sample.
        keys: Names of the entries to render in each row.
        max_items: Upper bound on the number of samples rendered.
        last_without_mask: Forwarded to ``visualize_mask_and_images``.
        rescale_keys: Forwarded to ``visualize_mask_and_images``.

    Returns:
        The per-sample rows concatenated along axis 0.

    Note:
        ``mask_only_first`` and ``black_mask`` are not forwarded, so the
        per-sample call uses its defaults for them.
    """
    arrays = {}
    for name, tensor in batch.items():
        if name in keys or name == 'mask':
            arrays[name] = tensor.detach().cpu().numpy()

    # The batch size is read from whichever retained entry comes first.
    first_array = next(iter(arrays.values()))
    n_rows = min(first_array.shape[0], max_items)

    rows = [
        visualize_mask_and_images(
            {name: arr[idx] for name, arr in arrays.items()},
            keys,
            last_without_mask=last_without_mask,
            rescale_keys=rescale_keys,
        )
        for idx in range(n_rows)
    ]
    return np.concatenate(rows, axis=0)
|
|