KyanChen's picture
init
f549064
raw
history blame contribute delete
No virus
7.97 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple
import mmcv
import numpy as np
from mmengine.dist import master_only
from mmengine.visualization import Visualizer
from mmcls.registry import VISUALIZERS
from mmcls.structures import ClsDataSample
def _get_adaptive_scale(img_shape: Tuple[int, int],
min_scale: float = 0.3,
max_scale: float = 3.0) -> float:
"""Get adaptive scale according to image shape.
The target scale depends on the the short edge length of the image. If the
short edge length equals 224, the output is 1.0. And output linear scales
according the short edge length.
You can also specify the minimum scale and the maximum scale to limit the
linear scale.
Args:
img_shape (Tuple[int, int]): The shape of the canvas image.
min_size (int): The minimum scale. Defaults to 0.3.
max_size (int): The maximum scale. Defaults to 3.0.
Returns:
int: The adaptive scale.
"""
short_edge_length = min(img_shape)
scale = short_edge_length / 224.
return min(max(scale, min_scale), max_scale)
@VISUALIZERS.register_module()
class ClsVisualizer(Visualizer):
    """Universal Visualizer for classification task.

    Args:
        name (str): Name of the instance. Defaults to 'visualizer'.
        image (np.ndarray, optional): the origin image to draw. The format
            should be RGB. Defaults to None.
        vis_backends (list, optional): Visual backend config list.
            Defaults to None.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        fig_save_cfg (dict): Keyword parameters of figure for saving.
            Defaults to empty dict.
        fig_show_cfg (dict): Keyword parameters of figure for showing.
            Defaults to empty dict.

    Examples:
        >>> import torch
        >>> import mmcv
        >>> from pathlib import Path
        >>> from mmcls.visualization import ClsVisualizer
        >>> from mmcls.structures import ClsDataSample
        >>> # Example image
        >>> img = mmcv.imread("./demo/bird.JPEG", channel_order='rgb')
        >>> # Example annotation
        >>> data_sample = (
        ...     ClsDataSample().set_gt_label(1).set_pred_label(1)
        ...     .set_pred_score(torch.tensor([0.1, 0.8, 0.1])))
        >>> # Setup the visualizer
        >>> vis = ClsVisualizer(
        ...     save_dir="./outputs",
        ...     vis_backends=[dict(type='LocalVisBackend')])
        >>> # Set classes names
        >>> vis.dataset_meta = {'classes': ['cat', 'bird', 'dog']}
        >>> # Show the example image with annotation in a figure.
        >>> # And it will ignore all preset storage backends.
        >>> vis.add_datasample('res', img, data_sample, show=True)
        >>> # Save the visualization result by the specified storage backends.
        >>> vis.add_datasample('res', img, data_sample)
        >>> assert Path('./outputs/vis_data/vis_image/res_0.png').exists()
        >>> # Save another visualization result with the same name.
        >>> vis.add_datasample('res', img, data_sample, step=1)
        >>> assert Path('./outputs/vis_data/vis_image/res_1.png').exists()
    """

    @master_only
    def add_datasample(self,
                       name: str,
                       image: np.ndarray,
                       data_sample: Optional[ClsDataSample] = None,
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       draw_score: bool = True,
                       rescale_factor: Optional[float] = None,
                       show: bool = False,
                       text_cfg: dict = dict(),
                       wait_time: float = 0,
                       out_file: Optional[str] = None,
                       step: int = 0) -> None:
        """Draw datasample and save to all backends.

        - If ``out_file`` is specified, all storage backends are ignored
          and save the image to the ``out_file``.
        - If ``show`` is True, plot the result image in a window, please
          confirm you are able to access the graphical interface.

        Args:
            name (str): The image identifier.
            image (np.ndarray): The image to draw.
            data_sample (:obj:`ClsDataSample`, optional): The annotation of the
                image. If None, no ground truth or prediction text is
                collected. Defaults to None.
            draw_gt (bool): Whether to draw ground truth labels.
                Defaults to True.
            draw_pred (bool): Whether to draw prediction labels.
                Defaults to True.
            draw_score (bool): Whether to draw the prediction scores
                of prediction categories. Defaults to True.
            rescale_factor (float, optional): Rescale the image by the rescale
                factor before visualization. Defaults to None.
            show (bool): Whether to display the drawn image. Defaults to False.
            text_cfg (dict): Extra text setting, which accepts
                arguments of :attr:`mmengine.Visualizer.draw_texts`.
                Entries given here override the built-in text settings.
                Defaults to an empty dict.
            wait_time (float): The interval of show (s). Defaults to 0, which
                means "forever".
            out_file (str, optional): Extra path to save the visualization
                result. If specified, the visualizer will only save the result
                image to the out_file and ignore its storage backends.
                Defaults to None.
            step (int): Global step value to record. Defaults to 0.
        """
        # NOTE: the ``dict()`` default of ``text_cfg`` is never mutated; it is
        # only merged into a freshly built dict further down.
        classes = None
        if self.dataset_meta is not None:
            classes = self.dataset_meta.get('classes', None)

        if rescale_factor is not None:
            image = mmcv.imrescale(image, rescale_factor)

        texts = []
        self.set_image(image)

        # ``data_sample`` is optional: a membership test on None would raise
        # TypeError, so check for its presence before looking up any field.
        has_sample = data_sample is not None

        if draw_gt and has_sample and 'gt_label' in data_sample:
            idx = data_sample.gt_label.label.tolist()
            class_labels = [''] * len(idx)
            if classes is not None:
                class_labels = [f' ({classes[i]})' for i in idx]
            labels = [
                str(i) + class_label
                for i, class_label in zip(idx, class_labels)
            ]
            prefix = 'Ground truth: '
            # Continuation lines are indented to line up after the prefix.
            texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))

        if draw_pred and has_sample and 'pred_label' in data_sample:
            pred_label = data_sample.pred_label
            idx = pred_label.label.tolist()
            score_labels = [''] * len(idx)
            class_labels = [''] * len(idx)
            if draw_score and 'score' in pred_label:
                # The score vector is indexed by the predicted label index.
                score_labels = [
                    f', {pred_label.score[i].item():.2f}' for i in idx
                ]
            if classes is not None:
                class_labels = [f' ({classes[i]})' for i in idx]
            labels = [
                str(i) + score_label + class_label
                for i, score_label, class_label in zip(
                    idx, score_labels, class_labels)
            ]
            prefix = 'Prediction: '
            texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))

        # Text position and font size follow the image's short edge, so the
        # overlay stays proportionate on small and large images.
        img_scale = _get_adaptive_scale(image.shape[:2])
        text_cfg = {
            'positions': np.array([(img_scale * 5, ) * 2]).astype(np.int32),
            'font_sizes': int(img_scale * 7),
            'font_families': 'monospace',
            'colors': 'white',
            'bboxes': dict(facecolor='black', alpha=0.5, boxstyle='Round'),
            # User-supplied settings come last so they take precedence.
            **text_cfg
        }
        self.draw_texts('\n'.join(texts), **text_cfg)
        drawn_img = self.get_image()

        if show:
            self.show(drawn_img, win_name=name, wait_time=wait_time)

        if out_file is not None:
            # save the image to the target file instead of vis_backends;
            # the channel axis is reversed before writing (the image is
            # documented as RGB; presumably mmcv.imwrite expects BGR).
            mmcv.imwrite(drawn_img[..., ::-1], out_file)
        else:
            self.add_image(name, drawn_img, step=step)