#!/usr/bin/env python3
import math
import numbers
import random
import warnings

import numpy as np
import PIL
import torch
import torchvision
import torchvision.transforms.functional as F
from PIL import Image
from torchvision import transforms

import vbench.third_party.umt.functional as FF
from .rand_augment import rand_augment_transform
from .random_erasing import RandomErasing


_pil_interpolation_to_str = {
    Image.NEAREST: "PIL.Image.NEAREST",
    Image.BILINEAR: "PIL.Image.BILINEAR",
    Image.BICUBIC: "PIL.Image.BICUBIC",
    Image.LANCZOS: "PIL.Image.LANCZOS",
    Image.HAMMING: "PIL.Image.HAMMING",
    Image.BOX: "PIL.Image.BOX",
}

_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)


def _pil_interp(method):
    if method == "bicubic":
        return Image.BICUBIC
    elif method == "lanczos":
        return Image.LANCZOS
    elif method == "hamming":
        return Image.HAMMING
    else:
        # Default (and "bilinear") interpolation.
        return Image.BILINEAR


def random_short_side_scale_jitter(
    images, min_size, max_size, boxes=None, inverse_uniform_sampling=False
):
    """
    Perform a spatial short-side scale jittering on the given images and
    corresponding boxes.
    Args:
        images (tensor): images to perform scale jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
        min_size (int): the minimal size to scale the frames.
        max_size (int): the maximal size to scale the frames.
        boxes (ndarray): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        inverse_uniform_sampling (bool): if True, sample uniformly in
            [1 / max_size, 1 / min_size] and take a reciprocal to get the
            scale. If False, take a uniform sample from [min_size, max_size].
    Returns:
        (tensor): the scaled images with dimension of
            `num frames` x `channel` x `new height` x `new width`.
        (ndarray or None): the scaled boxes with dimension of
            `num boxes` x 4.
    """
    if inverse_uniform_sampling:
        size = int(
            round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size))
        )
    else:
        size = int(round(np.random.uniform(min_size, max_size)))

    height = images.shape[2]
    width = images.shape[3]
    # The short side already matches the target size; nothing to do.
    if (width <= height and width == size) or (
        height <= width and height == size
    ):
        return images, boxes
    new_width = size
    new_height = size
    if width < height:
        new_height = int(math.floor((float(height) / width) * size))
        if boxes is not None:
            boxes = boxes * float(new_height) / height
    else:
        new_width = int(math.floor((float(width) / height) * size))
        if boxes is not None:
            boxes = boxes * float(new_width) / width

    return (
        torch.nn.functional.interpolate(
            images,
            size=(new_height, new_width),
            mode="bilinear",
            align_corners=False,
        ),
        boxes,
    )
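

# Usage sketch (illustrative only, synthetic data; the `_example_*` helpers in
# this file are not part of the original module): jitter the short side of a
# 16-frame RGB clip into the range [256, 320].
def _example_short_side_jitter():
    frames = torch.rand(16, 3, 240, 320)  # T x C x H x W
    frames, _ = random_short_side_scale_jitter(
        frames, min_size=256, max_size=320
    )
    return frames.shape  # short side is now in [256, 320]

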
def crop_boxes(boxes, x_offset, y_offset):
    """
    Perform crop on the bounding boxes given the offsets.
    Args:
        boxes (ndarray or None): bounding boxes to perform crop. The dimension
            is `num boxes` x 4.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    cropped_boxes = boxes.copy()
    # Shift (x1, x2) by the x offset and (y1, y2) by the y offset.
    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
    return cropped_boxes


def random_crop(images, size, boxes=None):
    """
    Perform random spatial crop on the given images and corresponding boxes.
    Args:
        images (tensor): images to perform random crop. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): the size of height and width to crop on the image.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
    Returns:
        cropped (tensor): cropped images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    if images.shape[2] == size and images.shape[3] == size:
        # Nothing to crop; return the boxes too so the signature is consistent.
        return images, boxes
    height = images.shape[2]
    width = images.shape[3]
    y_offset = 0
    if height > size:
        y_offset = int(np.random.randint(0, height - size))
    x_offset = 0
    if width > size:
        x_offset = int(np.random.randint(0, width - size))
    cropped = images[
        :, :, y_offset : y_offset + size, x_offset : x_offset + size
    ]
    cropped_boxes = (
        crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    )
    return cropped, cropped_boxes


def horizontal_flip(prob, images, boxes=None):
    """
    Perform horizontal flip on the given images and corresponding boxes.
    Args:
        prob (float): probability to flip the images.
        images (tensor): images to perform horizontal flip, the dimension is
            `num frames` x `channel` x `height` x `width`.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
    Returns:
        images (tensor): images with dimension of
            `num frames` x `channel` x `height` x `width`.
        flipped_boxes (ndarray or None): the flipped boxes with dimension of
            `num boxes` x 4.
    """
    if boxes is None:
        flipped_boxes = None
    else:
        flipped_boxes = boxes.copy()

    if np.random.uniform() < prob:
        images = images.flip((-1))
        if len(images.shape) == 3:
            width = images.shape[2]
        elif len(images.shape) == 4:
            width = images.shape[3]
        else:
            raise NotImplementedError("Dimension not supported")
        if boxes is not None:
            flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1
    return images, flipped_boxes


def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Perform uniform spatial sampling on the images and corresponding boxes.
    Args:
        images (tensor): images to perform uniform crop. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): size of height and width to crop the images.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
            is larger than height. Or 0, 1, or 2 for top, center, and bottom
            crop if height is larger than width.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        scale_size (int): optional. If not None, resize the images to scale_size before
            performing any crop.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    assert spatial_idx in [0, 1, 2]
    ndim = len(images.shape)
    if ndim == 3:
        images = images.unsqueeze(0)
    height = images.shape[2]
    width = images.shape[3]

    if scale_size is not None:
        if width <= height:
            width, height = scale_size, int(height / width * scale_size)
        else:
            width, height = int(width / height * scale_size), scale_size
        images = torch.nn.functional.interpolate(
            images,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))

    if height > width:
        if spatial_idx == 0:
            y_offset = 0
        elif spatial_idx == 2:
            y_offset = height - size
    else:
        if spatial_idx == 0:
            x_offset = 0
        elif spatial_idx == 2:
            x_offset = width - size
    cropped = images[
        :, :, y_offset : y_offset + size, x_offset : x_offset + size
    ]
    cropped_boxes = (
        crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    )
    if ndim == 3:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes
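

# Usage sketch (illustrative only): take the left/center/right 224x224 crops
# of a clip whose width exceeds its height.
def _example_uniform_crops():
    frames = torch.rand(16, 3, 256, 320)
    return [uniform_crop(frames, 224, idx)[0] for idx in range(3)]

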
def clip_boxes_to_image(boxes, height, width):
    """
    Clip an array of boxes to an image with the given height and width.
    Args:
        boxes (ndarray): bounding boxes to perform clipping.
            Dimension is `num boxes` x 4.
        height (int): given image height.
        width (int): given image width.
    Returns:
        clipped_boxes (ndarray): the clipped boxes with dimension of
            `num boxes` x 4.
    """
    clipped_boxes = boxes.copy()
    clipped_boxes[:, [0, 2]] = np.minimum(
        width - 1.0, np.maximum(0.0, boxes[:, [0, 2]])
    )
    clipped_boxes[:, [1, 3]] = np.minimum(
        height - 1.0, np.maximum(0.0, boxes[:, [1, 3]])
    )
    return clipped_boxes
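

# Usage sketch (illustrative only): shift boxes by a crop offset, then clamp
# them back into a 224x224 image.
def _example_box_pipeline():
    boxes = np.array([[10.0, 20.0, 300.0, 260.0]])  # x1, y1, x2, y2
    boxes = crop_boxes(boxes, x_offset=16, y_offset=8)
    return clip_boxes_to_image(boxes, height=224, width=224)

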
def blend(images1, images2, alpha):
    """
    Blend two images with a given weight alpha.
    Args:
        images1 (tensor): the first images to be blended, the dimension is
            `num frames` x `channel` x `height` x `width`.
        images2 (tensor): the second images to be blended, the dimension is
            `num frames` x `channel` x `height` x `width`.
        alpha (float): the blending weight.
    Returns:
        (tensor): blended images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    return images1 * alpha + images2 * (1 - alpha)


def grayscale(images):
    """
    Get the grayscale for the input images. The channels of images should be
    in order BGR.
    Args:
        images (tensor): the input images for getting grayscale. Dimension is
            `num frames` x `channel` x `height` x `width`.
    Returns:
        img_gray (tensor): grayscale images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    # R -> 0.299, G -> 0.587, B -> 0.114.
    img_gray = images.clone()
    gray_channel = (
        0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0]
    )
    img_gray[:, 0] = gray_channel
    img_gray[:, 1] = gray_channel
    img_gray[:, 2] = gray_channel
    return img_gray


def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0):
    """
    Perform a color jittering on the input images. The channels of images
    should be in order BGR.
    Args:
        images (tensor): images to perform color jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
        img_brightness (float): jitter ratio for brightness.
        img_contrast (float): jitter ratio for contrast.
        img_saturation (float): jitter ratio for saturation.
    Returns:
        images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    jitter = []
    if img_brightness != 0:
        jitter.append("brightness")
    if img_contrast != 0:
        jitter.append("contrast")
    if img_saturation != 0:
        jitter.append("saturation")

    if len(jitter) > 0:
        # Apply the enabled jitters in a random order.
        order = np.random.permutation(np.arange(len(jitter)))
        for idx in range(0, len(jitter)):
            if jitter[order[idx]] == "brightness":
                images = brightness_jitter(img_brightness, images)
            elif jitter[order[idx]] == "contrast":
                images = contrast_jitter(img_contrast, images)
            elif jitter[order[idx]] == "saturation":
                images = saturation_jitter(img_saturation, images)
    return images
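

# Usage sketch (illustrative only): apply brightness/contrast/saturation
# jitter with ratio 0.4 each, in a random order, to a BGR clip.
def _example_color_jitter():
    frames = torch.rand(16, 3, 224, 224)
    return color_jitter(
        frames, img_brightness=0.4, img_contrast=0.4, img_saturation=0.4
    )

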
def brightness_jitter(var, images):
    """
    Perform brightness jittering on the input images. The channels of images
    should be in order BGR.
    Args:
        var (float): jitter ratio for brightness.
        images (tensor): images to perform color jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
    Returns:
        images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    alpha = 1.0 + np.random.uniform(-var, var)
    # Blending toward a black image scales the brightness by alpha.
    img_bright = torch.zeros_like(images)
    images = blend(images, img_bright, alpha)
    return images


def contrast_jitter(var, images):
    """
    Perform contrast jittering on the input images. The channels of images
    should be in order BGR.
    Args:
        var (float): jitter ratio for contrast.
        images (tensor): images to perform color jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
    Returns:
        images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    alpha = 1.0 + np.random.uniform(-var, var)
    # Blending toward the per-frame mean intensity scales the contrast.
    img_gray = grayscale(images)
    img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True)
    images = blend(images, img_gray, alpha)
    return images


def saturation_jitter(var, images):
    """
    Perform saturation jittering on the input images. The channels of images
    should be in order BGR.
    Args:
        var (float): jitter ratio for saturation.
        images (tensor): images to perform color jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
    Returns:
        images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    alpha = 1.0 + np.random.uniform(-var, var)
    # Blending toward the grayscale image scales the saturation.
    img_gray = grayscale(images)
    images = blend(images, img_gray, alpha)
    return images


def lighting_jitter(images, alphastd, eigval, eigvec):
    """
    Perform AlexNet-style PCA jitter on the given images.
    Args:
        images (tensor): images to perform lighting jitter. Dimension is
            `num frames` x `channel` x `height` x `width`.
        alphastd (float): jitter ratio for PCA jitter.
        eigval (list): eigenvalues for PCA jitter.
        eigvec (list[list]): eigenvectors for PCA jitter.
    Returns:
        out_images (tensor): the jittered images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    if alphastd == 0:
        return images
    # Generate alpha1, alpha2, alpha3.
    alpha = np.random.normal(0, alphastd, size=(1, 3))
    eig_vec = np.array(eigvec)
    eig_val = np.reshape(eigval, (1, 3))
    rgb = np.sum(
        eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0),
        axis=1,
    )
    out_images = torch.zeros_like(images)
    if len(images.shape) == 3:
        # C H W
        channel_dim = 0
    elif len(images.shape) == 4:
        # T C H W
        channel_dim = 1
    else:
        raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")

    for idx in range(images.shape[channel_dim]):
        # C H W
        if len(images.shape) == 3:
            out_images[idx] = images[idx] + rgb[2 - idx]
        # T C H W
        elif len(images.shape) == 4:
            out_images[:, idx] = images[:, idx] + rgb[2 - idx]
        else:
            raise NotImplementedError(
                f"Unsupported dimension {len(images.shape)}"
            )
    return out_images
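

# Usage sketch (illustrative only): AlexNet-style lighting noise, using the
# PCA statistics commonly quoted for ImageNet (here purely for illustration).
def _example_lighting_jitter():
    frames = torch.rand(16, 3, 224, 224)
    eigval = [0.2175, 0.0188, 0.0045]
    eigvec = [
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ]
    return lighting_jitter(frames, alphastd=0.1, eigval=eigval, eigvec=eigvec)

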
def color_normalization(images, mean, stddev):
    """
    Perform color normalization on the given images.
    Args:
        images (tensor): images to perform color normalization. Dimension is
            `num frames` x `channel` x `height` x `width`.
        mean (list): mean values for normalization.
        stddev (list): standard deviations for normalization.
    Returns:
        out_images (tensor): the normalized images, the dimension is
            `num frames` x `channel` x `height` x `width`.
    """
    if len(images.shape) == 3:
        assert (
            len(mean) == images.shape[0]
        ), "channel mean not computed properly"
        assert (
            len(stddev) == images.shape[0]
        ), "channel stddev not computed properly"
    elif len(images.shape) == 4:
        assert (
            len(mean) == images.shape[1]
        ), "channel mean not computed properly"
        assert (
            len(stddev) == images.shape[1]
        ), "channel stddev not computed properly"
    else:
        raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")

    out_images = torch.zeros_like(images)
    for idx in range(len(mean)):
        # C H W
        if len(images.shape) == 3:
            out_images[idx] = (images[idx] - mean[idx]) / stddev[idx]
        elif len(images.shape) == 4:
            out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx]
        else:
            raise NotImplementedError(
                f"Unsupported dimension {len(images.shape)}"
            )
    return out_images
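

# Usage sketch (illustrative only): normalize a clip with the standard
# ImageNet mean and standard deviation.
def _example_color_normalization():
    frames = torch.rand(16, 3, 224, 224)
    return color_normalization(
        frames, mean=[0.485, 0.456, 0.406], stddev=[0.229, 0.224, 0.225]
    )

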
def _get_param_spatial_crop(
    scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False
):
    """
    Given scale, ratio, height and width, return sampled coordinates of the videos.
    """
    for _ in range(num_repeat):
        area = height * width
        target_area = random.uniform(*scale) * area
        if log_scale:
            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
            aspect_ratio = math.exp(random.uniform(*log_ratio))
        else:
            aspect_ratio = random.uniform(*ratio)
        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))
        if np.random.uniform() < 0.5 and switch_hw:
            w, h = h, w
        if 0 < w <= width and 0 < h <= height:
            i = random.randint(0, height - h)
            j = random.randint(0, width - w)
            return i, j, h, w

    # Fallback to central crop.
    in_ratio = float(width) / float(height)
    if in_ratio < min(ratio):
        w = width
        h = int(round(w / min(ratio)))
    elif in_ratio > max(ratio):
        h = height
        w = int(round(h * max(ratio)))
    else:  # whole image
        w = width
        h = height
    i = (height - h) // 2
    j = (width - w) // 2
    return i, j, h, w
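

# Usage sketch (illustrative only): sample an Inception-style crop window for
# a 240x320 frame; the helper falls back to a central crop after `num_repeat`
# failed attempts.
def _example_crop_params():
    i, j, h, w = _get_param_spatial_crop(
        scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0), height=240, width=320
    )
    return i, j, h, w  # top, left, height, width

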
def random_resized_crop(
    images,
    target_height,
    target_width,
    scale=(0.8, 1.0),
    ratio=(3.0 / 4.0, 4.0 / 3.0),
):
    """
    Crop the given images to a random size and aspect ratio. A crop of random
    area (default: 0.8 to 1.0 of the original area) and a random aspect ratio
    (default: 3/4 to 4/3 of the original aspect ratio) is made. This crop is
    finally resized to the given size. This is popularly used to train the
    Inception networks.
    Args:
        images: Images to perform resizing and cropping. The dimension is
            `num frames` x `channel` x `height` x `width`.
        target_height: Desired height after cropping.
        target_width: Desired width after cropping.
        scale: Scale range of Inception-style area based random resizing.
        ratio: Aspect ratio range of Inception-style area based random resizing.
    """
    height = images.shape[2]
    width = images.shape[3]
    i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
    cropped = images[:, :, i : i + h, j : j + w]
    return torch.nn.functional.interpolate(
        cropped,
        size=(target_height, target_width),
        mode="bilinear",
        align_corners=False,
    )


def random_resized_crop_with_shift(
    images,
    target_height,
    target_width,
    scale=(0.8, 1.0),
    ratio=(3.0 / 4.0, 4.0 / 3.0),
):
    """
    This is similar to random_resized_crop. However, it samples two different
    boxes (for cropping) for the first and last frame. It then linearly
    interpolates the two boxes for the other frames.
    Args:
        images: Images to perform resizing and cropping. The dimension is
            `channel` x `num frames` x `height` x `width`.
        target_height: Desired height after cropping.
        target_width: Desired width after cropping.
        scale: Scale range of Inception-style area based random resizing.
        ratio: Aspect ratio range of Inception-style area based random resizing.
    """
    t = images.shape[1]
    height = images.shape[2]
    width = images.shape[3]
    i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
    i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width)
    # Linearly interpolate the two crop boxes across time.
    i_s = [int(x) for x in torch.linspace(i, i_, steps=t).tolist()]
    j_s = [int(x) for x in torch.linspace(j, j_, steps=t).tolist()]
    h_s = [int(x) for x in torch.linspace(h, h_, steps=t).tolist()]
    w_s = [int(x) for x in torch.linspace(w, w_, steps=t).tolist()]
    out = torch.zeros((images.shape[0], t, target_height, target_width))
    for ind in range(t):
        out[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate(
            images[
                :,
                ind : ind + 1,
                i_s[ind] : i_s[ind] + h_s[ind],
                j_s[ind] : j_s[ind] + w_s[ind],
            ],
            size=(target_height, target_width),
            mode="bilinear",
            align_corners=False,
        )
    return out
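

# Usage sketch (illustrative only): temporally smooth moving crop over a
# C x T x H x W clip.
def _example_shifted_crop():
    clip = torch.rand(3, 16, 240, 320)
    return random_resized_crop_with_shift(
        clip, target_height=224, target_width=224
    )

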
def create_random_augment(
    input_size,
    auto_augment=None,
    interpolation="bilinear",
):
    """
    Get the video RandAugment transform.
    Args:
        input_size: The size of the input video, as a tuple.
        auto_augment: Parameters for RandAugment. An example:
            "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number
            of operations to apply).
        interpolation: Interpolation method.
    """
    if isinstance(input_size, tuple):
        img_size = input_size[-2:]
    else:
        img_size = input_size

    if auto_augment:
        assert isinstance(auto_augment, str)
        if isinstance(img_size, tuple):
            img_size_min = min(img_size)
        else:
            img_size_min = img_size
        aa_params = {"translate_const": int(img_size_min * 0.45)}
        if interpolation and interpolation != "random":
            aa_params["interpolation"] = _pil_interp(interpolation)
        if auto_augment.startswith("rand"):
            return transforms.Compose(
                [rand_augment_transform(auto_augment, aa_params)]
            )
    raise NotImplementedError
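

# Usage sketch (illustrative only; assumes frames are supplied as a list of
# PIL Images, which is what the underlying rand_augment_transform expects in
# this video pipeline):
def _example_random_augment():
    frames = [Image.new("RGB", (224, 224)) for _ in range(4)]
    aug = create_random_augment(
        (224, 224), auto_augment="rand-m7-n4-mstd0.5-inc1"
    )
    return aug(frames)

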
def random_sized_crop_img(
    im,
    size,
    jitter_scale=(0.08, 1.0),
    jitter_aspect=(3.0 / 4.0, 4.0 / 3.0),
    max_iter=10,
):
    """
    Performs Inception-style cropping (used for training).
    """
    assert (
        len(im.shape) == 3
    ), "random_sized_crop_img currently only supports a single C x H x W image"
    h, w = im.shape[1:3]
    i, j, h, w = _get_param_spatial_crop(
        scale=jitter_scale,
        ratio=jitter_aspect,
        height=h,
        width=w,
        num_repeat=max_iter,
        log_scale=False,
        switch_hw=True,
    )
    cropped = im[:, i : i + h, j : j + w]
    return torch.nn.functional.interpolate(
        cropped.unsqueeze(0),
        size=(size, size),
        mode="bilinear",
        align_corners=False,
    ).squeeze(0)
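

# Usage sketch (illustrative only): Inception-style crop of a single
# C x H x W image, resized to 224x224.
def _example_random_sized_crop_img():
    im = torch.rand(3, 240, 320)
    return random_sized_crop_img(im, size=224)

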
# The following code is adapted from the timm library; we will replace these
# contents with a dependency on PyTorchVideo.
# https://github.com/facebookresearch/pytorchvideo
class RandomResizedCropAndInterpolation:
    """Crop the given PIL Image to random size and aspect ratio with random interpolation.
    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to the given size.
    This is popularly used to train the Inception networks.
    Args:
        size: expected output size of each edge
        scale: range of size of the origin size cropped
        ratio: range of aspect ratio of the origin aspect ratio cropped
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(
        self,
        size,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation="bilinear",
    ):
        if isinstance(size, tuple):
            self.size = size
        else:
            self.size = (size, size)
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("range should be of kind (min, max)")
        if interpolation == "random":
            self.interpolation = _RANDOM_INTERPOLATION
        else:
            self.interpolation = _pil_interp(interpolation)
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.
        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        area = img.size[0] * img.size[1]

        for _ in range(10):
            target_area = random.uniform(*scale) * area
            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
            aspect_ratio = math.exp(random.uniform(*log_ratio))
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            if w <= img.size[0] and h <= img.size[1]:
                i = random.randint(0, img.size[1] - h)
                j = random.randint(0, img.size[0] - w)
                return i, j, h, w

        # Fallback to central crop.
        in_ratio = img.size[0] / img.size[1]
        if in_ratio < min(ratio):
            w = img.size[0]
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = img.size[1]
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = img.size[0]
            h = img.size[1]
        i = (img.size[1] - h) // 2
        j = (img.size[0] - w) // 2
        return i, j, h, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.
        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        if isinstance(self.interpolation, (tuple, list)):
            interpolation = random.choice(self.interpolation)
        else:
            interpolation = self.interpolation
        return F.resized_crop(img, i, j, h, w, self.size, interpolation)

    def __repr__(self):
        if isinstance(self.interpolation, (tuple, list)):
            interpolate_str = " ".join(
                [_pil_interpolation_to_str[x] for x in self.interpolation]
            )
        else:
            interpolate_str = _pil_interpolation_to_str[self.interpolation]
        format_string = self.__class__.__name__ + "(size={0}".format(self.size)
        format_string += ", scale={0}".format(
            tuple(round(s, 4) for s in self.scale)
        )
        format_string += ", ratio={0}".format(
            tuple(round(r, 4) for r in self.ratio)
        )
        format_string += ", interpolation={0})".format(interpolate_str)
        return format_string


def transforms_imagenet_train(
    img_size=224,
    scale=None,
    ratio=None,
    hflip=0.5,
    vflip=0.0,
    color_jitter=0.4,
    auto_augment=None,
    interpolation="random",
    use_prefetcher=False,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    re_prob=0.0,
    re_mode="const",
    re_count=1,
    re_num_splits=0,
    separate=False,
):
    """
    If separate==True, the transforms are returned as a tuple of 3 separate transforms
    for use in a mixing dataset that passes
    * all data through the first (primary) transform, called the 'clean' data,
    * a portion of the data through the secondary transform, and
    * normalizes and converts the branches above with the third, final transform.
    """
    if isinstance(img_size, tuple):
        img_size = img_size[-2:]
    scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
    ratio = tuple(
        ratio or (3.0 / 4.0, 4.0 / 3.0)
    )  # default imagenet ratio range

    primary_tfl = [
        RandomResizedCropAndInterpolation(
            img_size, scale=scale, ratio=ratio, interpolation=interpolation
        )
    ]
    if hflip > 0.0:
        primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
    if vflip > 0.0:
        primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]

    secondary_tfl = []
    if auto_augment:
        assert isinstance(auto_augment, str)
        if isinstance(img_size, tuple):
            img_size_min = min(img_size)
        else:
            img_size_min = img_size
        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
        )
        if interpolation and interpolation != "random":
            aa_params["interpolation"] = _pil_interp(interpolation)
        if auto_augment.startswith("rand"):
            secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
        elif auto_augment.startswith("augmix"):
            raise NotImplementedError("Augmix not implemented")
        else:
            raise NotImplementedError("Auto aug not implemented")
    elif color_jitter is not None:
        # Color jitter is enabled when not using AA.
        if isinstance(color_jitter, (list, tuple)):
            # Color jitter should be a 3-tuple/list if specifying
            # brightness/contrast/saturation, or 4 if also augmenting hue.
            assert len(color_jitter) in (3, 4)
        else:
            # If it's a scalar, duplicate for brightness, contrast, and
            # saturation; no hue.
            color_jitter = (float(color_jitter),) * 3
        secondary_tfl += [transforms.ColorJitter(*color_jitter)]

    final_tfl = [
        transforms.ToTensor(),
        transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
    ]
    if re_prob > 0.0:
        final_tfl.append(
            RandomErasing(
                re_prob,
                mode=re_mode,
                max_count=re_count,
                num_splits=re_num_splits,
                device="cpu",
                cube=False,
            )
        )

    if separate:
        return (
            transforms.Compose(primary_tfl),
            transforms.Compose(secondary_tfl),
            transforms.Compose(final_tfl),
        )
    else:
        return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
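

# Usage sketch (illustrative only): build a single composed training transform
# for 224x224 images with RandAugment and random erasing, then apply it to a
# blank PIL image.
def _example_imagenet_train_transform():
    tfm = transforms_imagenet_train(
        img_size=224, auto_augment="rand-m9-mstd0.5-inc1", re_prob=0.25
    )
    return tfm(Image.new("RGB", (320, 240)))  # float tensor, 3 x 224 x 224

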
############################################################################################################
############################################################################################################


class Compose(object):
    """Composes several transforms
    Args:
        transforms (list of ``Transform`` objects): list of transforms
            to compose
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, clip):
        for t in self.transforms:
            clip = t(clip)
        return clip


class RandomHorizontalFlip(object):
    """Horizontally flip the list of given images randomly
    with a probability 0.5
    """

    def __call__(self, clip):
        """
        Args:
            img (PIL.Image or numpy.ndarray): List of images to be flipped
                in format (h, w, c) in numpy.ndarray
        Returns:
            PIL.Image or numpy.ndarray: Randomly flipped clip
        """
        if random.random() < 0.5:
            if isinstance(clip[0], np.ndarray):
                return [np.fliplr(img) for img in clip]
            elif isinstance(clip[0], PIL.Image.Image):
                return [
                    img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip
                ]
            else:
                raise TypeError('Expected numpy.ndarray or PIL.Image'
                                ' but got list of {0}'.format(type(clip[0])))
        return clip


class RandomResize(object):
    """Resizes a list of (H x W x C) numpy.ndarray or PIL.Image by a randomly
    sampled scaling factor. The larger the original image is, the longer the
    interpolation takes.
    Args:
        ratio (tuple): range of the scaling factor to sample from
        interpolation (str): Can be one of 'nearest', 'bilinear';
            defaults to 'nearest'
    """

    def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
        self.ratio = ratio
        self.interpolation = interpolation

    def __call__(self, clip):
        scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
        if isinstance(clip[0], np.ndarray):
            im_h, im_w, im_c = clip[0].shape
        elif isinstance(clip[0], PIL.Image.Image):
            im_w, im_h = clip[0].size
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image'
                            ' but got list of {0}'.format(type(clip[0])))
        new_w = int(im_w * scaling_factor)
        new_h = int(im_h * scaling_factor)
        new_size = (new_w, new_h)
        resized = FF.resize_clip(
            clip, new_size, interpolation=self.interpolation)
        return resized


class Resize(object):
    """Resizes a list of (H x W x C) numpy.ndarray to the final size.
    The larger the original image is, the longer the interpolation takes.
    Args:
        size (tuple): (width, height)
        interpolation (str): Can be one of 'nearest', 'bilinear';
            defaults to 'nearest'
    """

    def __init__(self, size, interpolation='nearest'):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, clip):
        resized = FF.resize_clip(
            clip, self.size, interpolation=self.interpolation)
        return resized


class RandomCrop(object):
    """Extract a random crop at the same location for a list of images
    Args:
        size (sequence or int): Desired output size for the
            crop in format (h, w)
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            size = (size, size)
        self.size = size

    def __call__(self, clip):
        """
        Args:
            img (PIL.Image or numpy.ndarray): List of images to be cropped
                in format (h, w, c) in numpy.ndarray
        Returns:
            PIL.Image or numpy.ndarray: Cropped list of images
        """
        h, w = self.size
        if isinstance(clip[0], np.ndarray):
            im_h, im_w, im_c = clip[0].shape
        elif isinstance(clip[0], PIL.Image.Image):
            im_w, im_h = clip[0].size
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image'
                            ' but got list of {0}'.format(type(clip[0])))
        if w > im_w or h > im_h:
            error_msg = (
                'Initial image size should be larger than '
                'cropped size but got cropped sizes : ({w}, {h}) while '
                'initial image is ({im_w}, {im_h})'.format(
                    im_w=im_w, im_h=im_h, w=w, h=h))
            raise ValueError(error_msg)

        x1 = random.randint(0, im_w - w)
        y1 = random.randint(0, im_h - h)
        cropped = FF.crop_clip(clip, y1, x1, h, w)
        return cropped


class ThreeCrop(object):
    """Extract three crops at fixed locations along the longer side
    for a list of images
    Args:
        size (sequence or int): Desired output size for the
            crop in format (h, w)
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            size = (size, size)
        self.size = size

    def __call__(self, clip):
        """
        Args:
            img (PIL.Image or numpy.ndarray): List of images to be cropped
                in format (h, w, c) in numpy.ndarray
        Returns:
            PIL.Image or numpy.ndarray: Cropped list of images
        """
        h, w = self.size
        if isinstance(clip[0], np.ndarray):
            im_h, im_w, im_c = clip[0].shape
        elif isinstance(clip[0], PIL.Image.Image):
            im_w, im_h = clip[0].size
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image'
                            ' but got list of {0}'.format(type(clip[0])))
        if w != im_w and h != im_h:
            clip = FF.resize_clip(clip, self.size, interpolation="bilinear")
            # Re-measure the frame size in a type-agnostic way.
            if isinstance(clip[0], np.ndarray):
                im_h, im_w, im_c = clip[0].shape
            else:
                im_w, im_h = clip[0].size

        step = max((max(im_w, im_h) - self.size[0]) // 2, 0)
        cropped = []
        for i in range(3):
            # Step along the taller axis if height exceeds the crop size,
            # otherwise along the wider axis.
            if im_h > self.size[0]:
                x1 = 0
                y1 = i * step
            else:
                x1 = i * step
                y1 = 0
            cropped.extend(FF.crop_clip(clip, y1, x1, h, w))
        return cropped


class RandomRotation(object):
    """Rotate the entire clip randomly by a random angle within
    the given bounds
    Args:
        degrees (sequence or int): Range of degrees to select from.
            If degrees is a number instead of a sequence like (min, max),
            the range of degrees will be (-degrees, +degrees).
    """

    def __init__(self, degrees):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError('If degrees is a single number,'
                                 ' it must be positive')
            degrees = (-degrees, degrees)
        else:
            if len(degrees) != 2:
                raise ValueError('If degrees is a sequence,'
                                 ' it must be of len 2.')
        self.degrees = degrees

    def __call__(self, clip):
        """
        Args:
            img (PIL.Image or numpy.ndarray): List of images to be rotated
                in format (h, w, c) in numpy.ndarray
        Returns:
            PIL.Image or numpy.ndarray: Rotated list of images
        """
        # Imported lazily so scikit-image stays an optional dependency.
        import skimage.transform
        angle = random.uniform(self.degrees[0], self.degrees[1])
        if isinstance(clip[0], np.ndarray):
            rotated = [skimage.transform.rotate(img, angle) for img in clip]
        elif isinstance(clip[0], PIL.Image.Image):
            rotated = [img.rotate(angle) for img in clip]
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image'
                            ' but got list of {0}'.format(type(clip[0])))
        return rotated


class CenterCrop(object):
    """Extract a center crop at the same location for a list of images
    Args:
        size (sequence or int): Desired output size for the
            crop in format (h, w)
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            size = (size, size)
        self.size = size

    def __call__(self, clip):
        """
        Args:
            img (PIL.Image or numpy.ndarray): List of images to be cropped
                in format (h, w, c) in numpy.ndarray
        Returns:
            PIL.Image or numpy.ndarray: Cropped list of images
        """
        h, w = self.size
        if isinstance(clip[0], np.ndarray):
            im_h, im_w, im_c = clip[0].shape
        elif isinstance(clip[0], PIL.Image.Image):
            im_w, im_h = clip[0].size
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image'
                            ' but got list of {0}'.format(type(clip[0])))
        if w > im_w or h > im_h:
            error_msg = (
                'Initial image size should be larger than '
                'cropped size but got cropped sizes : ({w}, {h}) while '
                'initial image is ({im_w}, {im_h})'.format(
                    im_w=im_w, im_h=im_h, w=w, h=h))
            raise ValueError(error_msg)

        x1 = int(round((im_w - w) / 2.))
        y1 = int(round((im_h - h) / 2.))
        cropped = FF.crop_clip(clip, y1, x1, h, w)
        return cropped


class ColorJitter(object):
    """Randomly change the brightness, contrast, saturation and hue of the clip
    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue (float): How much to jitter hue. hue_factor is chosen uniformly from
            [-hue, hue]. Should be >= 0 and <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    def get_params(self, brightness, contrast, saturation, hue):
        if brightness > 0:
            brightness_factor = random.uniform(
                max(0, 1 - brightness), 1 + brightness)
        else:
            brightness_factor = None
        if contrast > 0:
            contrast_factor = random.uniform(
                max(0, 1 - contrast), 1 + contrast)
        else:
            contrast_factor = None
        if saturation > 0:
            saturation_factor = random.uniform(
                max(0, 1 - saturation), 1 + saturation)
        else:
            saturation_factor = None
        if hue > 0:
            hue_factor = random.uniform(-hue, hue)
        else:
            hue_factor = None
        return brightness_factor, contrast_factor, saturation_factor, hue_factor

    def __call__(self, clip):
        """
        Args:
            clip (list): list of PIL.Image
        Returns:
            list PIL.Image : list of transformed PIL.Image
        """
        if isinstance(clip[0], np.ndarray):
            raise TypeError(
                'Color jitter not yet implemented for numpy arrays')
        elif isinstance(clip[0], PIL.Image.Image):
            brightness, contrast, saturation, hue = self.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)

            # Create the image transform function sequence.
            img_transforms = []
            if brightness is not None:
                img_transforms.append(
                    lambda img: F.adjust_brightness(img, brightness))
            if saturation is not None:
                img_transforms.append(
                    lambda img: F.adjust_saturation(img, saturation))
            if hue is not None:
                img_transforms.append(lambda img: F.adjust_hue(img, hue))
            if contrast is not None:
                img_transforms.append(
                    lambda img: F.adjust_contrast(img, contrast))
            random.shuffle(img_transforms)

            # Apply the composed transforms to every image in the clip,
            # chaining the output of one into the next.
            jittered_clip = []
            for img in clip:
                jittered_img = img
                for func in img_transforms:
                    jittered_img = func(jittered_img)
                jittered_clip.append(jittered_img)
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image'
                            ' but got list of {0}'.format(type(clip[0])))
        return jittered_clip


class Normalize(object):
    """Normalize a clip with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor``, i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
    .. note::
        This transform acts out of place, i.e., it does not mutate the input tensor.
    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, clip):
        """
        Args:
            clip (Tensor): Tensor clip of size (T, C, H, W) to be normalized.
        Returns:
            Tensor: Normalized Tensor clip.
        """
        return FF.normalize(clip, self.mean, self.std)

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(
            self.mean, self.std)
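

# Usage sketch (illustrative only): a simple clip pipeline over a list of PIL
# frames using the clip-level transforms above. The exact size semantics of
# FF.resize_clip follow vbench.third_party.umt.functional; a square size is
# used here to sidestep any (w, h) vs (h, w) ambiguity.
def _example_clip_pipeline():
    frames = [Image.new("RGB", (320, 256)) for _ in range(8)]
    pipeline = Compose([
        Resize((256, 256), interpolation='bilinear'),
        RandomCrop(224),
        RandomHorizontalFlip(),
        ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    ])
    return pipeline(frames)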