|
import matplotlib.pyplot as plt |
|
import annotator.uniformer.mmcv as mmcv |
|
import torch |
|
from annotator.uniformer.mmcv.parallel import collate, scatter |
|
from annotator.uniformer.mmcv.runner import load_checkpoint |
|
|
|
from annotator.uniformer.mmseg.datasets.pipelines import Compose |
|
from annotator.uniformer.mmseg.models import build_segmentor |
|
|
|
|
|
def init_segmentor(config, checkpoint=None, device='cuda:0'):
    """Build a segmentor from a config and optionally load its weights.

    Args:
        config (str or :obj:`mmcv.Config`): Path to a config file, or an
            already-loaded config object.
        checkpoint (str, optional): Path of the checkpoint to load. When
            None (the default), no weights are loaded.
        device (str, optional): Device the model is moved to, e.g.
            'cuda:0' (the default) or 'cpu'.

    Returns:
        nn.Module: The constructed segmentor in eval mode, with the config
        attached as ``model.cfg``.

    Raises:
        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
    """
    if not isinstance(config, (str, mmcv.Config)):
        raise TypeError('config must be a filename or Config object, '
                        'but got {}'.format(type(config)))
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)

    # Inference-only build: drop the pretrained-weights path and the
    # training config before constructing the model.
    config.model.pretrained = None
    config.model.train_cfg = None
    segmentor = build_segmentor(config.model, test_cfg=config.get('test_cfg'))

    if checkpoint is not None:
        loaded = load_checkpoint(segmentor, checkpoint, map_location='cpu')
        # The checkpoint's 'meta' entry must provide CLASSES and PALETTE;
        # a missing key raises KeyError here.
        meta = loaded['meta']
        segmentor.CLASSES = meta['CLASSES']
        segmentor.PALETTE = meta['PALETTE']

    # Keep the config on the model; inference_segmentor reads it back.
    segmentor.cfg = config
    segmentor.to(device)
    segmentor.eval()
    return segmentor
|
|
|
|
|
class LoadImage:
    """Pipeline step that reads an image (file path or array) into results."""

    def __call__(self, results):
        """Load ``results['img']`` and record its filename and shapes.

        Args:
            results (dict): Must contain the key ``'img'``. It holds either
                an image file path (str) or any other input accepted by
                ``mmcv.imread``, such as an already-loaded array.

        Returns:
            dict: The same ``results`` dict. ``'img'`` is replaced by the
            loaded image, and ``'filename'``, ``'ori_filename'``,
            ``'img_shape'`` and ``'ori_shape'`` are filled in. The two
            filename keys are None when the input was not a str.
        """
        source = results['img']
        # Only a str input has a meaningful filename.
        name = source if isinstance(source, str) else None
        results['filename'] = name
        results['ori_filename'] = name

        loaded = mmcv.imread(source)
        results['img'] = loaded
        results['img_shape'] = loaded.shape
        results['ori_shape'] = loaded.shape
        return results
|
|
|
|
|
def inference_segmentor(model, img):
    """Run the segmentor on a single image.

    Args:
        model (nn.Module): The loaded segmentor. It must carry a ``cfg``
            attribute holding its config (as set by ``init_segmentor``).
        img (str or np.ndarray): One image, given either as a file path or
            as an already-loaded image. Only a single image is handled
            per call.

    Returns:
        The segmentation result returned by the model's test-mode forward
        pass (``return_loss=False, rescale=True``).
    """
    cfg = model.cfg
    # One parameter lookup gives both the target device and whether the
    # model lives on a GPU.
    first_param = next(model.parameters())
    device = first_param.device

    # Replace the first step of the configured test pipeline (presumably
    # its file-loading stage) with LoadImage, so both paths and arrays work.
    test_pipeline = Compose([LoadImage()] + cfg.data.test.pipeline[1:])

    data = test_pipeline(dict(img=img))
    # Wrap the single sample as a batch of one.
    data = collate([data], samples_per_gpu=1)
    if first_param.is_cuda:
        # Move the batch to the model's GPU.
        data = scatter(data, [device])[0]
    else:
        # Without scatter, take ``.data[0]`` of each img_metas entry by
        # hand (presumably unwrapping the containers produced by collate).
        data['img_metas'] = [meta.data[0] for meta in data['img_metas']]

    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **data)
    return result
|
|
|
|
|
def show_result_pyplot(model,
                       img,
                       result,
                       palette=None,
                       fig_size=(15, 10),
                       opacity=0.5,
                       title='',
                       block=True):
    """Render the segmentation result over the image and return it as RGB.

    Nothing is plotted or shown. The rendering comes from
    ``model.show_result(..., show=False)`` and is returned after a
    BGR-to-RGB conversion.

    Args:
        model (nn.Module): The loaded segmentor. It may be wrapped; a
            ``.module`` attribute is unwrapped first.
        img (str or np.ndarray): Image filename or loaded image.
        result (list): The segmentation result.
        palette (list[list[int]] | None): Palette of the segmentation map,
            forwarded to ``model.show_result``. Default: None.
        fig_size (tuple): Unused; kept for backward compatibility.
        opacity (float): Opacity of the painted segmentation map, forwarded
            to ``model.show_result``. Default: 0.5.
        title (str): Unused; kept for backward compatibility.
        block (bool): Unused; kept for backward compatibility.

    Returns:
        The rendered image with channels converted by ``mmcv.bgr2rgb``.
    """
    # Unwrap a wrapper module (e.g. data-parallel) if one is present.
    model = getattr(model, 'module', model)
    img = model.show_result(
        img, result, palette=palette, show=False, opacity=opacity)
    # The pyplot display is intentionally skipped; callers consume the
    # returned array instead.
    return mmcv.bgr2rgb(img)
|
|