Spaces:
Runtime error
Runtime error
File size: 7,969 Bytes
f549064 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple
import mmcv
import numpy as np
from mmengine.dist import master_only
from mmengine.visualization import Visualizer
from mmcls.registry import VISUALIZERS
from mmcls.structures import ClsDataSample
def _get_adaptive_scale(img_shape: Tuple[int, int],
min_scale: float = 0.3,
max_scale: float = 3.0) -> float:
"""Get adaptive scale according to image shape.
The target scale depends on the the short edge length of the image. If the
short edge length equals 224, the output is 1.0. And output linear scales
according the short edge length.
You can also specify the minimum scale and the maximum scale to limit the
linear scale.
Args:
img_shape (Tuple[int, int]): The shape of the canvas image.
min_size (int): The minimum scale. Defaults to 0.3.
max_size (int): The maximum scale. Defaults to 3.0.
Returns:
int: The adaptive scale.
"""
short_edge_length = min(img_shape)
scale = short_edge_length / 224.
return min(max(scale, min_scale), max_scale)
@VISUALIZERS.register_module()
class ClsVisualizer(Visualizer):
    """Universal Visualizer for classification task.

    Args:
        name (str): Name of the instance. Defaults to 'visualizer'.
        image (np.ndarray, optional): the origin image to draw. The format
            should be RGB. Defaults to None.
        vis_backends (list, optional): Visual backend config list.
            Defaults to None.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        fig_save_cfg (dict): Keyword parameters of figure for saving.
            Defaults to empty dict.
        fig_show_cfg (dict): Keyword parameters of figure for showing.
            Defaults to empty dict.

    Examples:
        >>> import torch
        >>> import mmcv
        >>> from pathlib import Path
        >>> from mmcls.visualization import ClsVisualizer
        >>> from mmcls.structures import ClsDataSample
        >>> # Example image
        >>> img = mmcv.imread("./demo/bird.JPEG", channel_order='rgb')
        >>> # Example annotation
        >>> data_sample = ClsDataSample().set_gt_label(1).set_pred_label(1).\
        ...     set_pred_score(torch.tensor([0.1, 0.8, 0.1]))
        >>> # Setup the visualizer
        >>> vis = ClsVisualizer(
        ...     save_dir="./outputs",
        ...     vis_backends=[dict(type='LocalVisBackend')])
        >>> # Set classes names
        >>> vis.dataset_meta = {'classes': ['cat', 'bird', 'dog']}
        >>> # Show the example image with annotation in a figure.
        >>> # And it will ignore all preset storage backends.
        >>> vis.add_datasample('res', img, data_sample, show=True)
        >>> # Save the visualization result by the specified storage backends.
        >>> vis.add_datasample('res', img, data_sample)
        >>> assert Path('./outputs/vis_data/vis_image/res_0.png').exists()
        >>> # Save another visualization result with the same name.
        >>> vis.add_datasample('res', img, data_sample, step=1)
        >>> assert Path('./outputs/vis_data/vis_image/res_1.png').exists()
    """

    @master_only
    def add_datasample(self,
                       name: str,
                       image: np.ndarray,
                       data_sample: Optional[ClsDataSample] = None,
                       draw_gt: bool = True,
                       draw_pred: bool = True,
                       draw_score: bool = True,
                       rescale_factor: Optional[float] = None,
                       show: bool = False,
                       text_cfg: dict = dict(),
                       wait_time: float = 0,
                       out_file: Optional[str] = None,
                       step: int = 0) -> None:
        """Draw datasample and save to all backends.

        - If ``out_file`` is specified, all storage backends are ignored
          and save the image to the ``out_file``.
        - If ``show`` is True, plot the result image in a window, please
          confirm you are able to access the graphical interface.

        Args:
            name (str): The image identifier.
            image (np.ndarray): The image to draw.
            data_sample (:obj:`ClsDataSample`, optional): The annotation of the
                image. If None, no ground truth or prediction text is
                collected for drawing. Defaults to None.
            draw_gt (bool): Whether to draw ground truth labels.
                Defaults to True.
            draw_pred (bool): Whether to draw prediction labels.
                Defaults to True.
            draw_score (bool): Whether to draw the prediction scores
                of prediction categories. Defaults to True.
            rescale_factor (float, optional): Rescale the image by the rescale
                factor before visualization. Defaults to None.
            show (bool): Whether to display the drawn image. Defaults to False.
            text_cfg (dict): Extra text setting, which accepts
                arguments of :attr:`mmengine.Visualizer.draw_texts`.
                Entries given here override the built-in defaults.
                Defaults to an empty dict.
            wait_time (float): The interval of show (s). Defaults to 0, which
                means "forever".
            out_file (str, optional): Extra path to save the visualization
                result. If specified, the visualizer will only save the result
                image to the out_file and ignore its storage backends.
                Defaults to None.
            step (int): Global step value to record. Defaults to 0.
        """
        classes = None
        if self.dataset_meta is not None:
            classes = self.dataset_meta.get('classes', None)

        if rescale_factor is not None:
            image = mmcv.imrescale(image, rescale_factor)

        texts = []
        self.set_image(image)

        # ``data_sample`` defaults to None; a membership test on None would
        # raise TypeError, so only inspect it when it was actually provided.
        has_sample = data_sample is not None

        if draw_gt and has_sample and 'gt_label' in data_sample:
            gt_label = data_sample.gt_label
            idx = gt_label.label.tolist()
            class_labels = [''] * len(idx)
            if classes is not None:
                class_labels = [f' ({classes[i]})' for i in idx]
            labels = [
                str(label_idx) + class_name
                for label_idx, class_name in zip(idx, class_labels)
            ]
            prefix = 'Ground truth: '
            # Continuation lines are padded so that they line up after the
            # prefix of the first line.
            texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))

        if draw_pred and has_sample and 'pred_label' in data_sample:
            pred_label = data_sample.pred_label
            idx = pred_label.label.tolist()
            score_labels = [''] * len(idx)
            class_labels = [''] * len(idx)
            if draw_score and 'score' in pred_label:
                # ``score`` is indexed by each predicted label index, so the
                # resulting list is aligned with ``idx``.
                score_labels = [
                    f', {pred_label.score[i].item():.2f}' for i in idx
                ]
            if classes is not None:
                class_labels = [f' ({classes[i]})' for i in idx]
            labels = [
                str(label_idx) + score_text + class_name
                for label_idx, score_text, class_name in zip(
                    idx, score_labels, class_labels)
            ]
            prefix = 'Prediction: '
            texts.append(prefix + ('\n' + ' ' * len(prefix)).join(labels))

        # Scale text position and font size with the image's short edge.
        img_scale = _get_adaptive_scale(image.shape[:2])
        # Build a new dict rather than updating ``text_cfg`` in place, so the
        # shared default argument is never mutated; user entries come last
        # and therefore override the defaults.
        text_cfg = {
            'positions': np.array([(img_scale * 5, ) * 2]).astype(np.int32),
            'font_sizes': int(img_scale * 7),
            'font_families': 'monospace',
            'colors': 'white',
            'bboxes': dict(facecolor='black', alpha=0.5, boxstyle='Round'),
            **text_cfg
        }
        self.draw_texts('\n'.join(texts), **text_cfg)
        drawn_img = self.get_image()

        if show:
            self.show(drawn_img, win_name=name, wait_time=wait_time)

        if out_file is not None:
            # save the image to the target file instead of vis_backends;
            # the channel axis is reversed (RGB -> BGR) before writing.
            mmcv.imwrite(drawn_img[..., ::-1], out_file)
        else:
            self.add_image(name, drawn_img, step=step)
|