|
|
|
import itertools |
|
import logging |
|
import numpy as np |
|
from collections import OrderedDict |
|
from collections.abc import Mapping |
|
from typing import Dict, List, Optional, Tuple, Union |
|
import torch |
|
from omegaconf import DictConfig, OmegaConf |
|
from torch import Tensor, nn |
|
|
|
from annotator.oneformer.detectron2.layers import ShapeSpec |
|
from annotator.oneformer.detectron2.structures import BitMasks, Boxes, ImageList, Instances |
|
from annotator.oneformer.detectron2.utils.events import get_event_storage |
|
|
|
from .backbone import Backbone |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def _to_container(cfg):
    """
    Convert an omegaconf object into a plain mmcv ``ConfigDict``.

    mmdet asserts that configs are real dict/list instances, so a
    ``DictConfig`` must be resolved to builtin containers first.
    """
    from mmcv.utils import ConfigDict

    if isinstance(cfg, DictConfig):
        cfg = OmegaConf.to_container(cfg, resolve=True)
    return ConfigDict(cfg)
|
|
|
|
|
class MMDetBackbone(Backbone):
    """
    Adapter that lets an mmdetection backbone be used as a detectron2 backbone.

    mmdet backbones emit a list/tuple of feature tensors, whereas detectron2
    expects a name->tensor dict. This wrapper converts the outputs to
    detectron2's convention so the wrapped model is a drop-in replacement.
    """

    def __init__(
        self,
        backbone: Union[nn.Module, Mapping],
        neck: Union[nn.Module, Mapping, None] = None,
        *,
        output_shapes: List[ShapeSpec],
        output_names: Optional[List[str]] = None,
    ):
        """
        Args:
            backbone: either a backbone module or a mmdet config dict that defines a
                backbone. The backbone takes a 4D image tensor and returns a
                sequence of tensors.
            neck: either a backbone module or a mmdet config dict that defines a
                neck. The neck takes outputs of backbone and returns a
                sequence of tensors. If None, no neck is used.
            output_shapes: shape for every output of the backbone (or neck, if given).
                stride and channels are often needed.
            output_names: names for every output of the backbone (or neck, if given).
                By default, will use "out0", "out1", ...
        """
        super().__init__()
        # A Mapping is treated as an mmdet config and instantiated lazily,
        # so mmdet is only imported when actually needed.
        if isinstance(backbone, Mapping):
            from mmdet.models import build_backbone

            backbone = build_backbone(_to_container(backbone))
        self.backbone = backbone

        if isinstance(neck, Mapping):
            from mmdet.models import build_neck

            neck = build_neck(_to_container(neck))
        self.neck = neck

        logger.info("Initializing mmdet backbone weights...")
        self.backbone.init_weights()
        # NOTE(review): train() is restored right after init_weights(); this
        # mirrors upstream usage — presumably init_weights can flip the module
        # into eval mode. Confirm against the mmdet version in use.
        self.backbone.train()
        if self.neck is not None:
            logger.info("Initializing mmdet neck weights ...")
            # Some mmdet necks are nn.Sequential containers whose children
            # each expose their own init_weights().
            if isinstance(self.neck, nn.Sequential):
                for module in self.neck:
                    module.init_weights()
            else:
                self.neck.init_weights()
            self.neck.train()

        self._output_shapes = output_shapes
        # Fall back to generated names when none (or an empty list) is given.
        self._output_names = output_names or [f"out{i}" for i in range(len(output_shapes))]

    def forward(self, x) -> Dict[str, Tensor]:
        """Run backbone (and neck, if any) and return a name->tensor dict."""
        feats = self.backbone(x)
        if self.neck is not None:
            feats = self.neck(feats)
        assert isinstance(
            feats, (list, tuple)
        ), "mmdet backbone should return a list/tuple of tensors!"
        if len(feats) != len(self._output_shapes):
            raise ValueError(
                "Length of output_shapes does not match outputs from the mmdet backbone: "
                f"{len(feats)} != {len(self._output_shapes)}"
            )
        return dict(zip(self._output_names, feats))

    def output_shape(self) -> Dict[str, ShapeSpec]:
        """Return the declared ShapeSpec for each named output."""
        return dict(zip(self._output_names, self._output_shapes))
|
|
|
|
|
class MMDetDetector(nn.Module):
    """
    Wrapper of a mmdetection detector model, for detection and instance segmentation.
    Input/output formats of this class follow detectron2's convention, so a
    mmdetection model can be trained and evaluated in detectron2.
    """

    def __init__(
        self,
        detector: Union[nn.Module, Mapping],
        *,
        size_divisibility=32,
        pixel_mean: Tuple[float],
        pixel_std: Tuple[float],
    ):
        """
        Args:
            detector: a mmdet detector, or a mmdet config dict that defines a detector.
            size_divisibility: pad input images to multiple of this number
            pixel_mean: per-channel mean to normalize input image
            pixel_std: per-channel stddev to normalize input image
        """
        super().__init__()
        # A Mapping is treated as an mmdet config dict and built lazily so
        # mmdet is only imported when needed.
        if isinstance(detector, Mapping):
            from mmdet.models import build_detector

            detector = build_detector(_to_container(detector))
        self.detector = detector
        self.detector.init_weights()
        self.size_divisibility = size_divisibility

        # Registered as non-persistent buffers (persistent=False): they follow
        # .to(device) but are not saved in the state_dict, since they are
        # configuration rather than learned parameters.
        self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)
        assert (
            self.pixel_mean.shape == self.pixel_std.shape
        ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!"

    def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
        """
        Run the wrapped mmdet detector on detectron2-format inputs.

        In training mode, returns a dict of losses (see ``_parse_losses``);
        in eval mode, returns a list of ``{"instances": Instances}`` dicts.
        """
        # Normalize each image, then pad the batch into one tensor whose
        # spatial size is a multiple of `size_divisibility`.
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, size_divisibility=self.size_divisibility).tensor
        metas = []
        # Either every input declares an original "height"/"width" (outputs are
        # rescaled back to it) or none does; a mixed batch is rejected.
        rescale = {"height" in x for x in batched_inputs}
        if len(rescale) != 1:
            raise ValueError("Some inputs have original height/width, but some don't!")
        rescale = list(rescale)[0]
        output_shapes = []
        # Build one mmdet-style "img_meta" dict per image.
        for input in batched_inputs:
            meta = {}
            c, h, w = input["image"].shape
            meta["img_shape"] = meta["ori_shape"] = (h, w, c)
            if rescale:
                # [w_scale, h_scale, w_scale, h_scale] — presumably mmdet's
                # 4-element scale_factor convention; confirm against mmdet docs.
                scale_factor = np.array(
                    [w / input["width"], h / input["height"]] * 2, dtype="float32"
                )
                ori_shape = (input["height"], input["width"])
                output_shapes.append(ori_shape)
                meta["ori_shape"] = ori_shape + (c,)
            else:
                scale_factor = 1.0
                output_shapes.append((h, w))
            meta["scale_factor"] = scale_factor
            meta["flip"] = False
            # Padded size is shared by the whole batch (from ImageList padding).
            padh, padw = images.shape[-2:]
            meta["pad_shape"] = (padh, padw, c)
            metas.append(meta)

        if self.training:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
            if gt_instances[0].has("gt_masks"):
                from mmdet.core import PolygonMasks as mm_PolygonMasks, BitmapMasks as mm_BitMasks

                def convert_mask(m, shape):
                    # Convert detectron2 mask structures to their mmdet
                    # counterparts (bitmap or polygon), keeping (h, w) size.
                    if isinstance(m, BitMasks):
                        return mm_BitMasks(m.tensor.cpu().numpy(), shape[0], shape[1])
                    else:
                        return mm_PolygonMasks(m.polygons, shape[0], shape[1])

                gt_masks = [convert_mask(x.gt_masks, x.image_size) for x in gt_instances]
                losses_and_metrics = self.detector.forward_train(
                    images,
                    metas,
                    [x.gt_boxes.tensor for x in gt_instances],
                    [x.gt_classes for x in gt_instances],
                    gt_masks=gt_masks,
                )
            else:
                losses_and_metrics = self.detector.forward_train(
                    images,
                    metas,
                    [x.gt_boxes.tensor for x in gt_instances],
                    [x.gt_classes for x in gt_instances],
                )
            return _parse_losses(losses_and_metrics)
        else:
            results = self.detector.simple_test(images, metas, rescale=rescale)
            results = [
                {"instances": _convert_mmdet_result(r, shape)}
                for r, shape in zip(results, output_shapes)
            ]
            return results

    @property
    def device(self):
        # Device is inferred from a buffer so the module has no explicit
        # device attribute to keep in sync.
        return self.pixel_mean.device
|
|
|
|
|
|
|
|
|
def _convert_mmdet_result(result, shape: Tuple[int, int]) -> Instances:
    """
    Convert one image's mmdet test result into a detectron2 ``Instances``.

    Args:
        result: either ``bbox_result`` (a per-class list of Nx5 arrays) or a
            ``(bbox_result, segm_result)`` tuple from an mmdet detector.
        shape: (height, width) of the image the boxes refer to.
    """
    if isinstance(result, tuple):
        bbox_result, segm_result = result
        # Some mmdet models wrap the segmentation output in another tuple.
        if isinstance(segm_result, tuple):
            segm_result = segm_result[0]
    else:
        bbox_result, segm_result = result, None

    # Each per-class array is Nx5: four box coordinates followed by the score.
    stacked = torch.from_numpy(np.vstack(bbox_result))
    boxes, scores = stacked[:, :4], stacked[:, -1]
    # Class id of each detection is its index in the per-class list.
    classes = torch.cat(
        [
            torch.full((arr.shape[0],), cls, dtype=torch.int32)
            for cls, arr in enumerate(bbox_result)
        ]
    )
    inst = Instances(shape)
    inst.pred_boxes = Boxes(boxes)
    inst.scores = scores
    inst.pred_classes = classes

    if segm_result is not None and len(classes) > 0:
        flat_masks = list(itertools.chain.from_iterable(segm_result))
        tensors = [
            torch.from_numpy(m) if isinstance(m, np.ndarray) else m for m in flat_masks
        ]
        inst.pred_masks = torch.stack(tensors, dim=0)
    return inst
|
|
|
|
|
|
|
def _parse_losses(losses: Dict[str, Tensor]) -> Dict[str, Tensor]: |
|
log_vars = OrderedDict() |
|
for loss_name, loss_value in losses.items(): |
|
if isinstance(loss_value, torch.Tensor): |
|
log_vars[loss_name] = loss_value.mean() |
|
elif isinstance(loss_value, list): |
|
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) |
|
else: |
|
raise TypeError(f"{loss_name} is not a tensor or list of tensors") |
|
|
|
if "loss" not in loss_name: |
|
|
|
storage = get_event_storage() |
|
value = log_vars.pop(loss_name).cpu().item() |
|
storage.put_scalar(loss_name, value) |
|
return log_vars |
|
|