KyanChen's picture
init
f549064
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Union
import numpy as np
import torch
from mmcls.registry import BATCH_AUGMENTS
class RandomBatchAugment:
    """Randomly choose one batch augmentation to apply.

    Args:
        augments (Callable | dict | list): A batch augmentation or a
            list/tuple of them. Each item is either a callable invoked as
            ``aug(batch_input, batch_score)`` or a config dict, which is
            built through the ``BATCH_AUGMENTS`` registry.
        probs (float | List[float] | None): The probabilities of each batch
            augmentations. If None, choose evenly. Otherwise it must have
            one entry per augmentation and sum to at most 1; the remaining
            probability is assigned to "apply no augmentation". The
            argument is copied and never modified in place.
            Defaults to None.

    Example:
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from mmcls.models import RandomBatchAugment
        >>> augments_cfg = [
        ...     dict(type='CutMix', alpha=1.),
        ...     dict(type='Mixup', alpha=1.)
        ... ]
        >>> batch_augment = RandomBatchAugment(augments_cfg, probs=[0.5, 0.3])
        >>> imgs = torch.rand(16, 3, 32, 32)
        >>> label = F.one_hot(torch.randint(0, 10, (16, )), num_classes=10)
        >>> imgs, label = batch_augment(imgs, label)

    .. note ::

        To decide which batch augmentation will be used, it picks one of
        ``augments`` based on the probabilities. In the example above, the
        probability to use CutMix is 0.5, to use Mixup is 0.3, and to do
        nothing is 0.2.
    """

    # Slack allowed on ``sum(probs) <= 1`` so that float rounding noise in
    # the sum does not reject an otherwise valid distribution. It is kept
    # below the tolerance ``np.random.choice`` itself applies to ``p``.
    _PROB_TOLERANCE = 1e-8

    def __init__(self, augments: Union[Callable, dict, list], probs=None):
        if not isinstance(augments, (tuple, list)):
            augments = [augments]

        # Config dicts are built through the registry; anything else is
        # assumed to already be a ready-to-call augmentation.
        self.augments = [
            BATCH_AUGMENTS.build(aug) if isinstance(aug, dict) else aug
            for aug in augments
        ]

        # Accept a bare scalar (float or int) as a one-element list.
        if isinstance(probs, (int, float)):
            probs = [probs]

        if probs is not None:
            # Work on a private copy: appending to the caller's list would
            # corrupt a shared config and break a second construction from
            # the same ``probs`` object. This also accepts tuples.
            probs = list(probs)
            assert len(augments) == len(probs), \
                '``augments`` and ``probs`` must have same lengths. ' \
                f'Got {len(augments)} vs {len(probs)}.'
            total = sum(probs)
            assert total <= 1 + self._PROB_TOLERANCE, \
                'The total probability of batch augments exceeds 1.'
            # The trailing ``None`` entry stands for "no augmentation" and
            # receives whatever probability is left (never negative).
            self.augments.append(None)
            probs.append(max(0.0, 1 - total))

        self.probs = probs

    def __call__(self, batch_input: torch.Tensor, batch_score: torch.Tensor):
        """Randomly apply batch augmentations to the batch inputs and batch
        data samples.

        Args:
            batch_input (torch.Tensor): The batch inputs.
            batch_score (torch.Tensor): The batch scores/labels.

        Returns:
            tuple: Whatever the selected augmentation returns for
            ``(batch_input, batch_score)``. When the "no augmentation"
            entry is drawn, ``batch_input`` is returned unchanged together
            with ``batch_score`` converted via ``.float()``.
        """
        # ``p=None`` makes numpy draw uniformly over all augmentations.
        aug_index = np.random.choice(len(self.augments), p=self.probs)
        aug = self.augments[aug_index]

        if aug is None:
            return batch_input, batch_score.float()
        return aug(batch_input, batch_score)