Spaces:
Running
on
T4
Running
on
T4
""" Random Erasing (Cutout) | |
Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 | |
Copyright Zhun Zhong & Liang Zheng | |
Hacked together by / Copyright 2020 Ross Wightman | |
""" | |
import random | |
import math | |
import torch | |
def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'): | |
# NOTE I've seen CUDA illegal memory access errors being caused by the normal_() | |
# paths, flip the order so normal is run on CPU if this becomes a problem | |
# Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 | |
if per_pixel: | |
return torch.empty(patch_size, dtype=dtype, device=device).normal_() | |
elif rand_color: | |
return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() | |
else: | |
return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) | |
class RandomErasing: | |
""" Randomly selects a rectangle region in an image and erases its pixels. | |
'Random Erasing Data Augmentation' by Zhong et al. | |
See https://arxiv.org/pdf/1708.04896.pdf | |
This variant of RandomErasing is intended to be applied to either a batch | |
or single image tensor after it has been normalized by dataset mean and std. | |
Args: | |
probability: Probability that the Random Erasing operation will be performed. | |
min_area: Minimum percentage of erased area wrt input image area. | |
max_area: Maximum percentage of erased area wrt input image area. | |
min_aspect: Minimum aspect ratio of erased area. | |
mode: pixel color mode, one of 'const', 'rand', or 'pixel' | |
'const' - erase block is constant color of 0 for all channels | |
'rand' - erase block is same per-channel random (normal) color | |
'pixel' - erase block is per-pixel random (normal) color | |
max_count: maximum number of erasing blocks per image, area per box is scaled by count. | |
per-image count is randomly chosen between 1 and this value. | |
""" | |
def __init__( | |
self, | |
probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, | |
mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): | |
self.probability = probability | |
self.min_area = min_area | |
self.max_area = max_area | |
max_aspect = max_aspect or 1 / min_aspect | |
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) | |
self.min_count = min_count | |
self.max_count = max_count or min_count | |
self.num_splits = num_splits | |
mode = mode.lower() | |
self.rand_color = False | |
self.per_pixel = False | |
if mode == 'rand': | |
self.rand_color = True # per block random normal | |
elif mode == 'pixel': | |
self.per_pixel = True # per pixel random normal | |
else: | |
assert not mode or mode == 'const' | |
self.device = device | |
def _erase(self, img, chan, img_h, img_w, dtype): | |
if random.random() > self.probability: | |
return | |
area = img_h * img_w | |
count = self.min_count if self.min_count == self.max_count else \ | |
random.randint(self.min_count, self.max_count) | |
for _ in range(count): | |
for attempt in range(10): | |
target_area = random.uniform(self.min_area, self.max_area) * area / count | |
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) | |
h = int(round(math.sqrt(target_area * aspect_ratio))) | |
w = int(round(math.sqrt(target_area / aspect_ratio))) | |
if w < img_w and h < img_h: | |
top = random.randint(0, img_h - h) | |
left = random.randint(0, img_w - w) | |
img[:, top:top + h, left:left + w] = _get_pixels( | |
self.per_pixel, self.rand_color, (chan, h, w), | |
dtype=dtype, device=self.device) | |
break | |
def __call__(self, input): | |
if len(input.size()) == 3: | |
self._erase(input, *input.size(), input.dtype) | |
else: | |
batch_size, chan, img_h, img_w = input.size() | |
# skip first slice of batch if num_splits is set (for clean portion of samples) | |
batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 | |
for i in range(batch_start, batch_size): | |
self._erase(input[i], chan, img_h, img_w, input.dtype) | |
return input | |