Spaces:
Runtime error
Runtime error
import copy | |
import numpy as np | |
from typing import List | |
import torch | |
from fvcore.transforms import NoOpTransform | |
from torch import nn | |
from detectron2.config import configurable | |
from detectron2.data.transforms import ( | |
RandomFlip, | |
ResizeShortestEdge, | |
ResizeTransform, | |
apply_augmentations, | |
) | |
__all__ = ["DatasetMapperTTA"]


class DatasetMapperTTA:
    """
    Implement test-time augmentation for detection data.

    It is a callable which takes a dataset dict from a detection dataset,
    and returns a list of dataset dicts where the images
    are augmented from the input image by the transformations defined in the config.
    This is used for test-time augmentation.

    Instances can be built either with explicit arguments
    (``DatasetMapperTTA(min_sizes=..., max_size=..., flip=...)``) or from a config
    (``DatasetMapperTTA(cfg)``). The config path goes through :meth:`from_config`
    via the ``@configurable`` decorator on ``__init__``.
    """

    @configurable
    def __init__(self, min_sizes: List[int], max_size: int, flip: bool):
        """
        Args:
            min_sizes: list of short-edge size to resize the image to
            max_size: maximum height or width of resized images
            flip: whether to apply flipping augmentation
        """
        self.min_sizes = min_sizes
        self.max_size = max_size
        self.flip = flip

    # ``@configurable`` looks up ``from_config`` on the class and requires it
    # to be a classmethod, so the decorator below is mandatory.
    @classmethod
    def from_config(cls, cfg):
        """
        Translate a config into keyword arguments for ``__init__``.

        Reads ``cfg.TEST.AUG.MIN_SIZES``, ``cfg.TEST.AUG.MAX_SIZE`` and
        ``cfg.TEST.AUG.FLIP``.
        """
        return {
            "min_sizes": cfg.TEST.AUG.MIN_SIZES,
            "max_size": cfg.TEST.AUG.MAX_SIZE,
            "flip": cfg.TEST.AUG.FLIP,
        }

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): a dict in standard model input format. See tutorials
                for details. It must contain:

                - "image": a 3-dim tensor whose first dim is treated as channels.
                  It is presumably on CPU, since ``.numpy()`` is called on it.
                - "height" and "width": the size of the original dataset image.

        Returns:
            list[dict]:
                a list of dicts, which contain augmented version of the input image.
                The total number of dicts is ``len(min_sizes) * (2 if flip else 1)``.
                For each min size, the resize-only variant comes first, followed by
                the resize + horizontal-flip variant when ``flip`` is enabled.
                Each dict has field "transforms" which is a TransformList,
                containing the transforms that are used to generate this image.
        """
        # CHW tensor -> HWC numpy array, the layout the augmentations operate on.
        numpy_image = dataset_dict["image"].permute(1, 2, 0).numpy()
        shape = numpy_image.shape
        orig_shape = (dataset_dict["height"], dataset_dict["width"])
        if shape[:2] != orig_shape:
            # It transforms the "original" image in the dataset to the input image,
            # so the recorded transforms map original coordinates to augmented ones.
            pre_tfm = ResizeTransform(orig_shape[0], orig_shape[1], shape[0], shape[1])
        else:
            pre_tfm = NoOpTransform()

        # Create all combinations of augmentations to use.
        aug_candidates = []  # each element is a list[Augmentation]
        for min_size in self.min_sizes:
            resize = ResizeShortestEdge(min_size, self.max_size)
            aug_candidates.append([resize])  # resize only
            if self.flip:
                flip = RandomFlip(prob=1.0)  # prob=1.0: always flip
                aug_candidates.append([resize, flip])  # resize + flip

        # Apply all the augmentations.
        ret = []
        for aug in aug_candidates:
            # Copy defensively: ``.numpy()`` above shares memory with the input
            # tensor, so the augmentations must not touch it in place.
            new_image, tfms = apply_augmentations(aug, np.copy(numpy_image))
            # The HWC -> CHW transpose (and any flip) yields a non-contiguous view.
            # Make it contiguous before handing it to torch.
            torch_image = torch.from_numpy(np.ascontiguousarray(new_image.transpose(2, 0, 1)))

            dic = copy.deepcopy(dataset_dict)
            dic["transforms"] = pre_tfm + tfms
            dic["image"] = torch_image
            ret.append(dic)
        return ret