# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Sequence, Union

import numpy as np
import torch

from mmcls.registry import BATCH_AUGMENTS


class RandomBatchAugment:
    """Randomly choose one batch augmentation to apply.

    Args:
        augments (Callable | dict | list): configs of batch
            augmentations. Dict entries are built through
            ``BATCH_AUGMENTS``; any other entry is used as-is.
        probs (float | List[float] | None): The probabilities of each batch
            augmentations. If None, choose evenly. Defaults to None.

    Example:
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from mmcls.models import RandomBatchAugment
        >>> augments_cfg = [
        ...     dict(type='CutMix', alpha=1.),
        ...     dict(type='Mixup', alpha=1.)
        ... ]
        >>> batch_augment = RandomBatchAugment(augments_cfg, probs=[0.5, 0.3])
        >>> imgs = torch.rand(16, 3, 32, 32)
        >>> label = F.one_hot(torch.randint(0, 10, (16, )), num_classes=10)
        >>> imgs, label = batch_augment(imgs, label)

    .. note ::

        To decide which batch augmentation will be used, it picks one of
        ``augments`` based on the probabilities. In the example above, the
        probability to use CutMix is 0.5, to use Mixup is 0.3, and to do
        nothing is 0.2.
    """

    def __init__(self,
                 augments: Union[Callable, dict, list],
                 probs: Union[float, Sequence[float], None] = None):
        if not isinstance(augments, (tuple, list)):
            augments = [augments]

        self.augments = []
        for aug in augments:
            if isinstance(aug, dict):
                self.augments.append(BATCH_AUGMENTS.build(aug))
            else:
                self.augments.append(aug)

        # A single scalar probability is shorthand for a one-element list.
        if isinstance(probs, (int, float)):
            probs = [probs]

        if probs is not None:
            # Work on a private copy: the "do nothing" slot appended below
            # must never leak into (or fail on) the caller's sequence.
            probs = [float(p) for p in probs]
            assert len(augments) == len(probs), \
                '``augments`` and ``probs`` must have same lengths. ' \
                f'Got {len(augments)} vs {len(probs)}.'
            total = sum(probs)
            # Tolerate floating-point accumulation error so that
            # probabilities which nominally sum to 1 are not rejected.
            assert total <= 1 + 1e-8, \
                'The total probability of batch augments exceeds 1.'

            # The leftover probability mass means "apply no augmentation";
            # clamp so rounding never yields a negative probability.
            self.augments.append(None)
            probs.append(max(0., 1. - total))

        self.probs = probs

    def __call__(self, batch_input: torch.Tensor, batch_score: torch.Tensor):
        """Randomly apply batch augmentations to the batch inputs and batch
        data samples.

        Args:
            batch_input (torch.Tensor): The batch inputs.
            batch_score (torch.Tensor): The batch scores / labels.

        Returns:
            tuple: Whatever the chosen augmentation returns for
            ``(batch_input, batch_score)``. When the "no augmentation" slot
            is picked, ``batch_input`` unchanged and ``batch_score`` cast
            with ``.float()``.
        """
        # ``p=None`` (no probs given) makes numpy pick uniformly.
        aug_index = np.random.choice(len(self.augments), p=self.probs)
        aug = self.augments[aug_index]

        if aug is not None:
            return aug(batch_input, batch_score)
        else:
            return batch_input, batch_score.float()