# Copyright (c) OpenMMLab. All rights reserved.
import math
import os.path as osp
from typing import Optional, Sequence

from mmengine.fileio import join_path
from mmengine.hooks import Hook
from mmengine.runner import EpochBasedTrainLoop, Runner
from mmengine.visualization import Visualizer

from mmcls.registry import HOOKS
from mmcls.structures import ClsDataSample


@HOOKS.register_module()
class VisualizationHook(Hook):
    """Classification Visualization Hook. Used to visualize validation and
    testing prediction results.

    - If ``out_dir`` is specified, all storage backends are ignored
      and save the image to the ``out_dir``.
    - If ``show`` is True, plot the result image in a window, please
      confirm you are able to access the graphical interface.

    Args:
        enable (bool): Whether to enable this hook. Defaults to False.
        interval (int): The interval of samples to visualize. Defaults to 5000.
        show (bool): Whether to display the drawn image. Defaults to False.
        out_dir (str, optional): directory where painted images will be saved
            in the testing process. If None, handle with the backends of the
            visualizer. Defaults to None.
        **kwargs: other keyword arguments of
            :meth:`mmcls.visualization.ClsVisualizer.add_datasample`.
    """

    def __init__(self,
                 enable: bool = False,
                 interval: int = 5000,
                 show: bool = False,
                 out_dir: Optional[str] = None,
                 **kwargs):
        self._visualizer: Visualizer = Visualizer.get_current_instance()

        self.enable = enable
        self.interval = interval
        self.show = show
        self.out_dir = out_dir

        # Base keyword arguments forwarded to ``add_datasample`` for every
        # drawn sample. Treated as read-only after construction.
        self.draw_args = {**kwargs, 'show': show}

    def _draw_samples(self,
                      batch_idx: int,
                      data_batch: dict,
                      data_samples: Sequence[ClsDataSample],
                      step: int = 0) -> None:
        """Visualize every ``self.interval`` samples from a data batch.

        Args:
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (dict): Data from dataloader. Its ``'inputs'`` entry
                is indexed per sample to obtain the image to draw.
            data_samples (Sequence[:obj:`ClsDataSample`]): Outputs from model.
            step (int): Global step value to record. Defaults to 0.
        """
        if not self.enable:
            return

        batch_size = len(data_samples)
        images = data_batch['inputs']

        # NOTE: the global index of the first sample is derived from the
        # size of the *current* batch, i.e. it assumes all previous batches
        # had this same size.
        start_idx = batch_size * batch_idx
        end_idx = start_idx + batch_size

        # The first index divisible by the interval, after the start index
        first_sample_id = math.ceil(start_idx / self.interval) * self.interval

        for sample_id in range(first_sample_id, end_idx, self.interval):
            local_idx = sample_id - start_idx

            image = images[local_idx]
            # Move the first axis to the end (presumably CHW -> HWC) and
            # convert to a uint8 numpy array for the visualizer.
            image = image.permute(1, 2, 0).cpu().numpy().astype('uint8')

            data_sample = data_samples[local_idx]
            if 'img_path' in data_sample:
                # osp.basename works on different platforms even file clients.
                sample_name = osp.basename(data_sample.get('img_path'))
            else:
                sample_name = str(sample_id)

            # Work on a per-sample copy so the sample-specific ``out_file``
            # never leaks into the hook's shared ``self.draw_args``.
            draw_args = dict(self.draw_args)
            if self.out_dir is not None:
                draw_args['out_file'] = join_path(
                    self.out_dir, f'{sample_name}_{step}.png')

            self._visualizer.add_datasample(
                sample_name,
                image=image,
                data_sample=data_sample,
                step=step,
                **draw_args,
            )

    def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                       outputs: Sequence[ClsDataSample]) -> None:
        """Visualize every ``self.interval`` samples during validation.

        The recorded step is the current epoch when training is epoch-based
        and the current iteration otherwise.

        Args:
            runner (:obj:`Runner`): The runner of the validation process.
            batch_idx (int): The index of the current batch in the val loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model.
        """
        if isinstance(runner.train_loop, EpochBasedTrainLoop):
            step = runner.epoch
        else:
            step = runner.iter

        self._draw_samples(batch_idx, data_batch, outputs, step=step)

    def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict,
                        outputs: Sequence[ClsDataSample]) -> None:
        """Visualize every ``self.interval`` samples during test.

        Testing always records with ``step=0``.

        Args:
            runner (:obj:`Runner`): The runner of the testing process.
            batch_idx (int): The index of the current batch in the test loop.
            data_batch (dict): Data from dataloader.
            outputs (Sequence[:obj:`ClsDataSample`]): Outputs from model.
        """
        self._draw_samples(batch_idx, data_batch, outputs, step=0)