Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Dict, List, Optional, Tuple, Union | |
import cv2 | |
import mmcv | |
import numpy as np | |
try: | |
import seaborn as sns | |
except ImportError: | |
sns = None | |
import torch | |
from mmengine.dist import master_only | |
from mmengine.structures import InstanceData, PixelData | |
from mmengine.visualization import Visualizer | |
from ..evaluation import INSTANCE_OFFSET | |
from ..registry import VISUALIZERS | |
from ..structures import DetDataSample | |
from ..structures.mask import BitmapMasks, PolygonMasks, bitmap_to_polygon | |
from .palette import _get_adaptive_scales, get_palette, jitter_color | |
class DetLocalVisualizer(Visualizer): | |
"""MMDetection Local Visualizer. | |
Args: | |
name (str): Name of the instance. Defaults to 'visualizer'. | |
image (np.ndarray, optional): the origin image to draw. The format | |
should be RGB. Defaults to None. | |
vis_backends (list, optional): Visual backend config list. | |
Defaults to None. | |
save_dir (str, optional): Save file dir for all storage backends. | |
If it is None, the backend storage will not save any data. | |
bbox_color (str, tuple(int), optional): Color of bbox lines. | |
The tuple of color should be in BGR order. Defaults to None. | |
text_color (str, tuple(int), optional): Color of texts. | |
The tuple of color should be in BGR order. | |
Defaults to (200, 200, 200). | |
mask_color (str, tuple(int), optional): Color of masks. | |
The tuple of color should be in BGR order. | |
Defaults to None. | |
line_width (int, float): The linewidth of lines. | |
Defaults to 3. | |
alpha (int, float): The transparency of bboxes or mask. | |
Defaults to 0.8. | |
Examples: | |
>>> import numpy as np | |
>>> import torch | |
>>> from mmengine.structures import InstanceData | |
>>> from mmdet.structures import DetDataSample | |
>>> from mmdet.visualization import DetLocalVisualizer | |
>>> det_local_visualizer = DetLocalVisualizer() | |
>>> image = np.random.randint(0, 256, | |
... size=(10, 12, 3)).astype('uint8') | |
>>> gt_instances = InstanceData() | |
>>> gt_instances.bboxes = torch.Tensor([[1, 2, 2, 5]]) | |
>>> gt_instances.labels = torch.randint(0, 2, (1,)) | |
>>> gt_det_data_sample = DetDataSample() | |
>>> gt_det_data_sample.gt_instances = gt_instances | |
>>> det_local_visualizer.add_datasample('image', image, | |
... gt_det_data_sample) | |
>>> det_local_visualizer.add_datasample( | |
... 'image', image, gt_det_data_sample, | |
... out_file='out_file.jpg') | |
>>> det_local_visualizer.add_datasample( | |
... 'image', image, gt_det_data_sample, | |
... show=True) | |
>>> pred_instances = InstanceData() | |
>>> pred_instances.bboxes = torch.Tensor([[2, 4, 4, 8]]) | |
>>> pred_instances.labels = torch.randint(0, 2, (1,)) | |
>>> pred_det_data_sample = DetDataSample() | |
>>> pred_det_data_sample.pred_instances = pred_instances | |
>>> det_local_visualizer.add_datasample('image', image, | |
... gt_det_data_sample, | |
... pred_det_data_sample) | |
""" | |
def __init__(self, | |
name: str = 'visualizer', | |
image: Optional[np.ndarray] = None, | |
vis_backends: Optional[Dict] = None, | |
save_dir: Optional[str] = None, | |
bbox_color: Optional[Union[str, Tuple[int]]] = None, | |
text_color: Optional[Union[str, | |
Tuple[int]]] = (200, 200, 200), | |
mask_color: Optional[Union[str, Tuple[int]]] = None, | |
line_width: Union[int, float] = 3, | |
alpha: float = 0.8) -> None: | |
super().__init__( | |
name=name, | |
image=image, | |
vis_backends=vis_backends, | |
save_dir=save_dir) | |
self.bbox_color = bbox_color | |
self.text_color = text_color | |
self.mask_color = mask_color | |
self.line_width = line_width | |
self.alpha = alpha | |
# Set default value. When calling | |
# `DetLocalVisualizer().dataset_meta=xxx`, | |
# it will override the default value. | |
self.dataset_meta = {} | |
def _draw_instances(self, image: np.ndarray, instances: ['InstanceData'], | |
classes: Optional[List[str]], | |
palette: Optional[List[tuple]]) -> np.ndarray: | |
"""Draw instances of GT or prediction. | |
Args: | |
image (np.ndarray): The image to draw. | |
instances (:obj:`InstanceData`): Data structure for | |
instance-level annotations or predictions. | |
classes (List[str], optional): Category information. | |
palette (List[tuple], optional): Palette information | |
corresponding to the category. | |
Returns: | |
np.ndarray: the drawn image which channel is RGB. | |
""" | |
self.set_image(image) | |
if 'bboxes' in instances and instances.bboxes.sum() > 0: | |
bboxes = instances.bboxes | |
labels = instances.labels | |
max_label = int(max(labels) if len(labels) > 0 else 0) | |
text_palette = get_palette(self.text_color, max_label + 1) | |
text_colors = [text_palette[label] for label in labels] | |
bbox_color = palette if self.bbox_color is None \ | |
else self.bbox_color | |
bbox_palette = get_palette(bbox_color, max_label + 1) | |
colors = [bbox_palette[label] for label in labels] | |
self.draw_bboxes( | |
bboxes, | |
edge_colors=colors, | |
alpha=self.alpha, | |
line_widths=self.line_width) | |
positions = bboxes[:, :2] + self.line_width | |
areas = (bboxes[:, 3] - bboxes[:, 1]) * ( | |
bboxes[:, 2] - bboxes[:, 0]) | |
scales = _get_adaptive_scales(areas) | |
for i, (pos, label) in enumerate(zip(positions, labels)): | |
if 'label_names' in instances: | |
label_text = instances.label_names[i] | |
else: | |
label_text = classes[ | |
label] if classes is not None else f'class {label}' | |
if 'scores' in instances: | |
score = round(float(instances.scores[i]) * 100, 1) | |
label_text += f': {score}' | |
self.draw_texts( | |
label_text, | |
pos, | |
colors=text_colors[i], | |
font_sizes=int(13 * scales[i]), | |
bboxes=[{ | |
'facecolor': 'black', | |
'alpha': 0.8, | |
'pad': 0.7, | |
'edgecolor': 'none' | |
}]) | |
if 'masks' in instances: | |
labels = instances.labels | |
masks = instances.masks | |
if isinstance(masks, torch.Tensor): | |
masks = masks.numpy() | |
elif isinstance(masks, (PolygonMasks, BitmapMasks)): | |
masks = masks.to_ndarray() | |
masks = masks.astype(bool) | |
max_label = int(max(labels) if len(labels) > 0 else 0) | |
mask_color = palette if self.mask_color is None \ | |
else self.mask_color | |
mask_palette = get_palette(mask_color, max_label + 1) | |
colors = [jitter_color(mask_palette[label]) for label in labels] | |
text_palette = get_palette(self.text_color, max_label + 1) | |
text_colors = [text_palette[label] for label in labels] | |
polygons = [] | |
for i, mask in enumerate(masks): | |
contours, _ = bitmap_to_polygon(mask) | |
polygons.extend(contours) | |
self.draw_polygons(polygons, edge_colors='w', alpha=self.alpha) | |
self.draw_binary_masks(masks, colors=colors, alphas=self.alpha) | |
if len(labels) > 0 and \ | |
('bboxes' not in instances or | |
instances.bboxes.sum() == 0): | |
# instances.bboxes.sum()==0 represent dummy bboxes. | |
# A typical example of SOLO does not exist bbox branch. | |
areas = [] | |
positions = [] | |
for mask in masks: | |
_, _, stats, centroids = cv2.connectedComponentsWithStats( | |
mask.astype(np.uint8), connectivity=8) | |
if stats.shape[0] > 1: | |
largest_id = np.argmax(stats[1:, -1]) + 1 | |
positions.append(centroids[largest_id]) | |
areas.append(stats[largest_id, -1]) | |
areas = np.stack(areas, axis=0) | |
scales = _get_adaptive_scales(areas) | |
for i, (pos, label) in enumerate(zip(positions, labels)): | |
if 'label_names' in instances: | |
label_text = instances.label_names[i] | |
else: | |
label_text = classes[ | |
label] if classes is not None else f'class {label}' | |
if 'scores' in instances: | |
score = round(float(instances.scores[i]) * 100, 1) | |
label_text += f': {score}' | |
self.draw_texts( | |
label_text, | |
pos, | |
colors=text_colors[i], | |
font_sizes=int(13 * scales[i]), | |
horizontal_alignments='center', | |
bboxes=[{ | |
'facecolor': 'black', | |
'alpha': 0.8, | |
'pad': 0.7, | |
'edgecolor': 'none' | |
}]) | |
return self.get_image() | |
def _draw_panoptic_seg(self, image: np.ndarray, | |
panoptic_seg: ['PixelData'], | |
classes: Optional[List[str]], | |
palette: Optional[List]) -> np.ndarray: | |
"""Draw panoptic seg of GT or prediction. | |
Args: | |
image (np.ndarray): The image to draw. | |
panoptic_seg (:obj:`PixelData`): Data structure for | |
pixel-level annotations or predictions. | |
classes (List[str], optional): Category information. | |
Returns: | |
np.ndarray: the drawn image which channel is RGB. | |
""" | |
# TODO: Is there a way to bypass? | |
num_classes = len(classes) | |
panoptic_seg_data = panoptic_seg.sem_seg[0] | |
ids = np.unique(panoptic_seg_data)[::-1] | |
if 'label_names' in panoptic_seg: | |
# open set panoptic segmentation | |
classes = panoptic_seg.metainfo['label_names'] | |
ignore_index = panoptic_seg.metainfo.get('ignore_index', | |
len(classes)) | |
ids = ids[ids != ignore_index] | |
else: | |
# for VOID label | |
ids = ids[ids != num_classes] | |
labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64) | |
segms = (panoptic_seg_data[None] == ids[:, None, None]) | |
max_label = int(max(labels) if len(labels) > 0 else 0) | |
mask_color = palette if self.mask_color is None \ | |
else self.mask_color | |
mask_palette = get_palette(mask_color, max_label + 1) | |
colors = [mask_palette[label] for label in labels] | |
self.set_image(image) | |
# draw segm | |
polygons = [] | |
for i, mask in enumerate(segms): | |
contours, _ = bitmap_to_polygon(mask) | |
polygons.extend(contours) | |
self.draw_polygons(polygons, edge_colors='w', alpha=self.alpha) | |
self.draw_binary_masks(segms, colors=colors, alphas=self.alpha) | |
# draw label | |
areas = [] | |
positions = [] | |
for mask in segms: | |
_, _, stats, centroids = cv2.connectedComponentsWithStats( | |
mask.astype(np.uint8), connectivity=8) | |
max_id = np.argmax(stats[1:, -1]) + 1 | |
positions.append(centroids[max_id]) | |
areas.append(stats[max_id, -1]) | |
areas = np.stack(areas, axis=0) | |
scales = _get_adaptive_scales(areas) | |
text_palette = get_palette(self.text_color, max_label + 1) | |
text_colors = [text_palette[label] for label in labels] | |
for i, (pos, label) in enumerate(zip(positions, labels)): | |
label_text = classes[label] | |
self.draw_texts( | |
label_text, | |
pos, | |
colors=text_colors[i], | |
font_sizes=int(13 * scales[i]), | |
bboxes=[{ | |
'facecolor': 'black', | |
'alpha': 0.8, | |
'pad': 0.7, | |
'edgecolor': 'none' | |
}], | |
horizontal_alignments='center') | |
return self.get_image() | |
def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData, | |
classes: Optional[List], | |
palette: Optional[List]) -> np.ndarray: | |
"""Draw semantic seg of GT or prediction. | |
Args: | |
image (np.ndarray): The image to draw. | |
sem_seg (:obj:`PixelData`): Data structure for pixel-level | |
annotations or predictions. | |
classes (list, optional): Input classes for result rendering, as | |
the prediction of segmentation model is a segment map with | |
label indices, `classes` is a list which includes items | |
responding to the label indices. If classes is not defined, | |
visualizer will take `cityscapes` classes by default. | |
Defaults to None. | |
palette (list, optional): Input palette for result rendering, which | |
is a list of color palette responding to the classes. | |
Defaults to None. | |
Returns: | |
np.ndarray: the drawn image which channel is RGB. | |
""" | |
sem_seg_data = sem_seg.sem_seg | |
if isinstance(sem_seg_data, torch.Tensor): | |
sem_seg_data = sem_seg_data.numpy() | |
# 0 ~ num_class, the value 0 means background | |
ids = np.unique(sem_seg_data) | |
ignore_index = sem_seg.metainfo.get('ignore_index', 255) | |
ids = ids[ids != ignore_index] | |
if 'label_names' in sem_seg: | |
# open set semseg | |
label_names = sem_seg.metainfo['label_names'] | |
else: | |
label_names = classes | |
labels = np.array(ids, dtype=np.int64) | |
colors = [palette[label] for label in labels] | |
self.set_image(image) | |
# draw semantic masks | |
for i, (label, color) in enumerate(zip(labels, colors)): | |
masks = sem_seg_data == label | |
self.draw_binary_masks(masks, colors=[color], alphas=self.alpha) | |
label_text = label_names[label] | |
_, _, stats, centroids = cv2.connectedComponentsWithStats( | |
masks[0].astype(np.uint8), connectivity=8) | |
if stats.shape[0] > 1: | |
largest_id = np.argmax(stats[1:, -1]) + 1 | |
centroids = centroids[largest_id] | |
areas = stats[largest_id, -1] | |
scales = _get_adaptive_scales(areas) | |
self.draw_texts( | |
label_text, | |
centroids, | |
colors=(255, 255, 255), | |
font_sizes=int(13 * scales), | |
horizontal_alignments='center', | |
bboxes=[{ | |
'facecolor': 'black', | |
'alpha': 0.8, | |
'pad': 0.7, | |
'edgecolor': 'none' | |
}]) | |
return self.get_image() | |
def add_datasample( | |
self, | |
name: str, | |
image: np.ndarray, | |
data_sample: Optional['DetDataSample'] = None, | |
draw_gt: bool = True, | |
draw_pred: bool = True, | |
show: bool = False, | |
wait_time: float = 0, | |
# TODO: Supported in mmengine's Viusalizer. | |
out_file: Optional[str] = None, | |
pred_score_thr: float = 0.3, | |
step: int = 0) -> None: | |
"""Draw datasample and save to all backends. | |
- If GT and prediction are plotted at the same time, they are | |
displayed in a stitched image where the left image is the | |
ground truth and the right image is the prediction. | |
- If ``show`` is True, all storage backends are ignored, and | |
the images will be displayed in a local window. | |
- If ``out_file`` is specified, the drawn image will be | |
saved to ``out_file``. t is usually used when the display | |
is not available. | |
Args: | |
name (str): The image identifier. | |
image (np.ndarray): The image to draw. | |
data_sample (:obj:`DetDataSample`, optional): A data | |
sample that contain annotations and predictions. | |
Defaults to None. | |
draw_gt (bool): Whether to draw GT DetDataSample. Default to True. | |
draw_pred (bool): Whether to draw Prediction DetDataSample. | |
Defaults to True. | |
show (bool): Whether to display the drawn image. Default to False. | |
wait_time (float): The interval of show (s). Defaults to 0. | |
out_file (str): Path to output file. Defaults to None. | |
pred_score_thr (float): The threshold to visualize the bboxes | |
and masks. Defaults to 0.3. | |
step (int): Global step value to record. Defaults to 0. | |
""" | |
image = image.clip(0, 255).astype(np.uint8) | |
classes = self.dataset_meta.get('classes', None) | |
palette = self.dataset_meta.get('palette', None) | |
gt_img_data = None | |
pred_img_data = None | |
if data_sample is not None: | |
data_sample = data_sample.cpu() | |
if draw_gt and data_sample is not None: | |
gt_img_data = image | |
if 'gt_instances' in data_sample: | |
gt_img_data = self._draw_instances(image, | |
data_sample.gt_instances, | |
classes, palette) | |
if 'gt_sem_seg' in data_sample: | |
gt_img_data = self._draw_sem_seg(gt_img_data, | |
data_sample.gt_sem_seg, | |
classes, palette) | |
if 'gt_panoptic_seg' in data_sample: | |
assert classes is not None, 'class information is ' \ | |
'not provided when ' \ | |
'visualizing panoptic ' \ | |
'segmentation results.' | |
gt_img_data = self._draw_panoptic_seg( | |
gt_img_data, data_sample.gt_panoptic_seg, classes, palette) | |
if draw_pred and data_sample is not None: | |
pred_img_data = image | |
if 'pred_instances' in data_sample: | |
pred_instances = data_sample.pred_instances | |
pred_instances = pred_instances[ | |
pred_instances.scores > pred_score_thr] | |
pred_img_data = self._draw_instances(image, pred_instances, | |
classes, palette) | |
if 'pred_sem_seg' in data_sample: | |
pred_img_data = self._draw_sem_seg(pred_img_data, | |
data_sample.pred_sem_seg, | |
classes, palette) | |
if 'pred_panoptic_seg' in data_sample: | |
assert classes is not None, 'class information is ' \ | |
'not provided when ' \ | |
'visualizing panoptic ' \ | |
'segmentation results.' | |
pred_img_data = self._draw_panoptic_seg( | |
pred_img_data, data_sample.pred_panoptic_seg.numpy(), | |
classes, palette) | |
if gt_img_data is not None and pred_img_data is not None: | |
drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1) | |
elif gt_img_data is not None: | |
drawn_img = gt_img_data | |
elif pred_img_data is not None: | |
drawn_img = pred_img_data | |
else: | |
# Display the original image directly if nothing is drawn. | |
drawn_img = image | |
# It is convenient for users to obtain the drawn image. | |
# For example, the user wants to obtain the drawn image and | |
# save it as a video during video inference. | |
self.set_image(drawn_img) | |
if show: | |
self.show(drawn_img, win_name=name, wait_time=wait_time) | |
if out_file is not None: | |
mmcv.imwrite(drawn_img[..., ::-1], out_file) | |
else: | |
self.add_image(name, drawn_img, step) | |
def random_color(seed): | |
"""Random a color according to the input seed.""" | |
if sns is None: | |
raise RuntimeError('motmetrics is not installed,\ | |
please install it by: pip install seaborn') | |
np.random.seed(seed) | |
colors = sns.color_palette() | |
color = colors[np.random.choice(range(len(colors)))] | |
color = tuple([int(255 * c) for c in color]) | |
return color | |
class TrackLocalVisualizer(Visualizer): | |
"""Tracking Local Visualizer for the MOT, VIS tasks. | |
Args: | |
name (str): Name of the instance. Defaults to 'visualizer'. | |
image (np.ndarray, optional): the origin image to draw. The format | |
should be RGB. Defaults to None. | |
vis_backends (list, optional): Visual backend config list. | |
Defaults to None. | |
save_dir (str, optional): Save file dir for all storage backends. | |
If it is None, the backend storage will not save any data. | |
line_width (int, float): The linewidth of lines. | |
Defaults to 3. | |
alpha (int, float): The transparency of bboxes or mask. | |
Defaults to 0.8. | |
""" | |
def __init__(self, | |
name: str = 'visualizer', | |
image: Optional[np.ndarray] = None, | |
vis_backends: Optional[Dict] = None, | |
save_dir: Optional[str] = None, | |
line_width: Union[int, float] = 3, | |
alpha: float = 0.8) -> None: | |
super().__init__(name, image, vis_backends, save_dir) | |
self.line_width = line_width | |
self.alpha = alpha | |
# Set default value. When calling | |
# `TrackLocalVisualizer().dataset_meta=xxx`, | |
# it will override the default value. | |
self.dataset_meta = {} | |
def _draw_instances(self, image: np.ndarray, | |
instances: InstanceData) -> np.ndarray: | |
"""Draw instances of GT or prediction. | |
Args: | |
image (np.ndarray): The image to draw. | |
instances (:obj:`InstanceData`): Data structure for | |
instance-level annotations or predictions. | |
Returns: | |
np.ndarray: the drawn image which channel is RGB. | |
""" | |
self.set_image(image) | |
classes = self.dataset_meta.get('classes', None) | |
# get colors and texts | |
# for the MOT and VIS tasks | |
colors = [random_color(_id) for _id in instances.instances_id] | |
categories = [ | |
classes[label] if classes is not None else f'cls{label}' | |
for label in instances.labels | |
] | |
if 'scores' in instances: | |
texts = [ | |
f'{category_name}\n{instance_id} | {score:.2f}' | |
for category_name, instance_id, score in zip( | |
categories, instances.instances_id, instances.scores) | |
] | |
else: | |
texts = [ | |
f'{category_name}\n{instance_id}' for category_name, | |
instance_id in zip(categories, instances.instances_id) | |
] | |
# draw bboxes and texts | |
if 'bboxes' in instances: | |
# draw bboxes | |
bboxes = instances.bboxes.clone() | |
self.draw_bboxes( | |
bboxes, | |
edge_colors=colors, | |
alpha=self.alpha, | |
line_widths=self.line_width) | |
# draw texts | |
if texts is not None: | |
positions = bboxes[:, :2] + self.line_width | |
areas = (bboxes[:, 3] - bboxes[:, 1]) * ( | |
bboxes[:, 2] - bboxes[:, 0]) | |
scales = _get_adaptive_scales(areas.cpu().numpy()) | |
for i, pos in enumerate(positions): | |
self.draw_texts( | |
texts[i], | |
pos, | |
colors='black', | |
font_sizes=int(13 * scales[i]), | |
bboxes=[{ | |
'facecolor': [c / 255 for c in colors[i]], | |
'alpha': 0.8, | |
'pad': 0.7, | |
'edgecolor': 'none' | |
}]) | |
# draw masks | |
if 'masks' in instances: | |
masks = instances.masks | |
polygons = [] | |
for i, mask in enumerate(masks): | |
contours, _ = bitmap_to_polygon(mask) | |
polygons.extend(contours) | |
self.draw_polygons(polygons, edge_colors='w', alpha=self.alpha) | |
self.draw_binary_masks(masks, colors=colors, alphas=self.alpha) | |
return self.get_image() | |
def add_datasample( | |
self, | |
name: str, | |
image: np.ndarray, | |
data_sample: DetDataSample = None, | |
draw_gt: bool = True, | |
draw_pred: bool = True, | |
show: bool = False, | |
wait_time: int = 0, | |
# TODO: Supported in mmengine's Viusalizer. | |
out_file: Optional[str] = None, | |
pred_score_thr: float = 0.3, | |
step: int = 0) -> None: | |
"""Draw datasample and save to all backends. | |
- If GT and prediction are plotted at the same time, they are | |
displayed in a stitched image where the left image is the | |
ground truth and the right image is the prediction. | |
- If ``show`` is True, all storage backends are ignored, and | |
the images will be displayed in a local window. | |
- If ``out_file`` is specified, the drawn image will be | |
saved to ``out_file``. t is usually used when the display | |
is not available. | |
Args: | |
name (str): The image identifier. | |
image (np.ndarray): The image to draw. | |
data_sample (OptTrackSampleList): A data | |
sample that contain annotations and predictions. | |
Defaults to None. | |
draw_gt (bool): Whether to draw GT TrackDataSample. | |
Default to True. | |
draw_pred (bool): Whether to draw Prediction TrackDataSample. | |
Defaults to True. | |
show (bool): Whether to display the drawn image. Default to False. | |
wait_time (int): The interval of show (s). Defaults to 0. | |
out_file (str): Path to output file. Defaults to None. | |
pred_score_thr (float): The threshold to visualize the bboxes | |
and masks. Defaults to 0.3. | |
step (int): Global step value to record. Defaults to 0. | |
""" | |
gt_img_data = None | |
pred_img_data = None | |
if data_sample is not None: | |
data_sample = data_sample.cpu() | |
if draw_gt and data_sample is not None: | |
assert 'gt_instances' in data_sample | |
gt_img_data = self._draw_instances(image, data_sample.gt_instances) | |
if draw_pred and data_sample is not None: | |
assert 'pred_track_instances' in data_sample | |
pred_instances = data_sample.pred_track_instances | |
if 'scores' in pred_instances: | |
pred_instances = pred_instances[ | |
pred_instances.scores > pred_score_thr].cpu() | |
pred_img_data = self._draw_instances(image, pred_instances) | |
if gt_img_data is not None and pred_img_data is not None: | |
drawn_img = np.concatenate((gt_img_data, pred_img_data), axis=1) | |
elif gt_img_data is not None: | |
drawn_img = gt_img_data | |
else: | |
drawn_img = pred_img_data | |
if show: | |
self.show(drawn_img, win_name=name, wait_time=wait_time) | |
if out_file is not None: | |
mmcv.imwrite(drawn_img[..., ::-1], out_file) | |
else: | |
self.add_image(name, drawn_img, step) | |