Spaces:
Runtime error
Runtime error
File size: 8,209 Bytes
f549064 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
# Copyright (c) OpenMMLab. All rights reserved.
import math
from numbers import Number
from typing import Optional, Sequence
import torch
import torch.nn.functional as F
from mmengine.model import BaseDataPreprocessor, stack_batch
from mmcls.registry import MODELS
from mmcls.structures import (ClsDataSample, MultiTaskDataSample,
batch_label_to_onehot, cat_batch_labels,
stack_batch_scores, tensor_split)
from .batch_augments import RandomBatchAugment
@MODELS.register_module()
class ClsDataPreprocessor(BaseDataPreprocessor):
    """Image pre-processor for classification tasks.

    Comparing with the :class:`mmengine.model.ImgDataPreprocessor`,

    1. It won't do normalization if ``mean`` is not specified.
    2. It does normalization and color space conversion after stacking batch.
    3. It supports batch augmentations like mixup and cutmix.

    It provides the data pre-processing as follows

    - Collate and move data to the target device.
    - Pad inputs to the maximum size of current batch with defined
      ``pad_value``. The padding size can be divisible by a defined
      ``pad_size_divisor``
    - Stack inputs to batch_inputs.
    - Convert inputs from bgr to rgb if the shape of input is (3, H, W).
    - Normalize image with defined std and mean.
    - Do batch augmentations like Mixup and Cutmix during training.

    Args:
        mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
            Normalization is enabled only when ``mean`` is given, in which
            case ``std`` must be given too. Defaults to None.
        std (Sequence[Number], optional): The pixel standard deviation of
            R, G, B channels. Defaults to None.
        pad_size_divisor (int): The size of padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (Number): The padded pixel value. Defaults to 0.
        to_rgb (bool): whether to convert image from BGR to RGB.
            Defaults to False.
        to_onehot (bool): Whether to generate one-hot format gt-labels and set
            to data samples. It is forced to True when ``batch_augments`` is
            specified. Defaults to False.
        num_classes (int, optional): The number of classes. Defaults to None.
        batch_augments (dict, optional): The batch augmentations settings,
            including "augments" and "probs". For more details, see
            :class:`mmcls.models.RandomBatchAugment`.
    """

    def __init__(self,
                 mean: Optional[Sequence[Number]] = None,
                 std: Optional[Sequence[Number]] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Number = 0,
                 to_rgb: bool = False,
                 to_onehot: bool = False,
                 num_classes: Optional[int] = None,
                 batch_augments: Optional[dict] = None):
        super().__init__()
        self.pad_size_divisor = pad_size_divisor
        self.pad_value = pad_value
        self.to_rgb = to_rgb
        self.to_onehot = to_onehot
        self.num_classes = num_classes

        if mean is not None:
            assert std is not None, 'To enable the normalization in ' \
                'preprocessing, please specify both `mean` and `std`.'
            # Enable the normalization in preprocessing.
            self._enable_normalize = True
            # Reshape to (C, 1, 1) so the statistics broadcast over H and W.
            # The buffers are non-persistent, i.e. kept out of the
            # state dict.
            self.register_buffer(
                'mean', torch.tensor(mean).view(-1, 1, 1), persistent=False)
            self.register_buffer(
                'std', torch.tensor(std).view(-1, 1, 1), persistent=False)
        else:
            self._enable_normalize = False

        if batch_augments is not None:
            self.batch_augments = RandomBatchAugment(**batch_augments)
            if not self.to_onehot:
                from mmengine.logging import MMLogger
                MMLogger.get_current_instance().info(
                    'Because batch augmentations are enabled, the data '
                    'preprocessor automatically enables the `to_onehot` '
                    'option to generate one-hot format labels.')
                self.to_onehot = True
        else:
            self.batch_augments = None

    def forward(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, padding, bgr2rgb conversion and batch
        augmentation based on ``BaseDataPreprocessor``.

        Args:
            data (dict): data sampled from dataloader. The ``'inputs'`` key
                is required; the ``'data_samples'`` key is optional.
            training (bool): Whether to enable training time augmentation.

        Returns:
            dict: Data in the same format as the model input, with the keys
            ``'inputs'`` and ``'data_samples'``.
        """
        inputs = self.cast_data(data['inputs'])

        if isinstance(inputs, torch.Tensor):
            # The branch if use `default_collate` as the collate_fn in the
            # dataloader.

            # ------ To RGB ------
            # Reverse the channel dimension (dim 1 of the batched tensor).
            if self.to_rgb and inputs.size(1) == 3:
                inputs = inputs.flip(1)

            # -- Normalization ---
            inputs = inputs.float()
            if self._enable_normalize:
                inputs = (inputs - self.mean) / self.std

            # ------ Padding -----
            if self.pad_size_divisor > 1:
                h, w = inputs.shape[-2:]

                target_h = math.ceil(
                    h / self.pad_size_divisor) * self.pad_size_divisor
                target_w = math.ceil(
                    w / self.pad_size_divisor) * self.pad_size_divisor
                pad_h = target_h - h
                pad_w = target_w - w
                # Pad only at the end of the last two dimensions.
                inputs = F.pad(inputs, (0, pad_w, 0, pad_h), 'constant',
                               self.pad_value)
        else:
            # The branch if use `pseudo_collate` as the collate_fn in the
            # dataloader.
            processed_inputs = []
            for input_ in inputs:
                # ------ To RGB ------
                # Each item is unbatched, so the channel dimension is dim 0.
                if self.to_rgb and input_.size(0) == 3:
                    input_ = input_.flip(0)

                # -- Normalization ---
                input_ = input_.float()
                if self._enable_normalize:
                    input_ = (input_ - self.mean) / self.std

                processed_inputs.append(input_)

            # Combine padding and stack
            inputs = stack_batch(processed_inputs, self.pad_size_divisor,
                                 self.pad_value)

        data_samples = data.get('data_samples', None)
        # Guard against an empty sequence as well as ``None`` so that an
        # empty batch does not raise ``IndexError`` when peeking at the
        # first sample.
        has_samples = data_samples is not None and len(data_samples) > 0
        sample_item = data_samples[0] if has_samples else None

        if isinstance(sample_item,
                      ClsDataSample) and 'gt_label' in sample_item:
            gt_labels = [sample.gt_label for sample in data_samples]
            batch_label, label_indices = cat_batch_labels(
                gt_labels, device=self.device)

            batch_score = stack_batch_scores(gt_labels, device=self.device)
            if batch_score is None and self.to_onehot:
                assert batch_label is not None, \
                    'Cannot generate onehot format labels because no labels.'
                # Prefer the configured value, fall back to the sample's own
                # ``num_classes`` entry.
                num_classes = self.num_classes or data_samples[0].get(
                    'num_classes')
                assert num_classes is not None, \
                    'Cannot generate one-hot format labels because not set ' \
                    '`num_classes` in `data_preprocessor`.'
                batch_score = batch_label_to_onehot(batch_label, label_indices,
                                                    num_classes)

            # ----- Batch Augmentations ----
            if training and self.batch_augments is not None:
                inputs, batch_score = self.batch_augments(inputs, batch_score)

            # ----- scatter labels and scores to data samples ---
            if batch_label is not None:
                for sample, label in zip(
                        data_samples, tensor_split(batch_label,
                                                   label_indices)):
                    sample.set_gt_label(label)

            if batch_score is not None:
                for sample, score in zip(data_samples, batch_score):
                    sample.set_gt_score(score)
        elif isinstance(sample_item, MultiTaskDataSample):
            data_samples = self.cast_data(data_samples)

        return {'inputs': inputs, 'data_samples': data_samples}
|