# MedRPG / med_rpg / transforms.py
# (repository header residue: author zy5830850, "First model version", commits 91ef820)
import math
import torch
import random
from PIL import Image, ImageEnhance, ImageFilter
import numpy as np
import torchvision.transforms as T
import torchvision.transforms.functional as F
from utils.box_utils import xyxy2xywh
from utils.misc import interpolate
def crop(image, box, region, negBoxs=None):
    """Crop *image* to *region* and shift/clip the boxes into the new frame.

    Args:
        image: image accepted by torchvision's ``F.crop`` (PIL image or tensor).
        box: xyxy box tensor of shape (4,) for the positive target.
        region: ``(top, left, height, width)`` crop rectangle.
        negBoxs: optional iterable of xyxy tensors for negative boxes.

    Returns:
        ``(cropped_image, cropped_box)``, or
        ``(cropped_image, cropped_box, cropped_negBoxs)`` when *negBoxs* is given.
    """
    cropped_image = F.crop(image, *region)
    i, j, h, w = region
    max_size = torch.as_tensor([w, h], dtype=torch.float32)

    def _shift_and_clip(b):
        # Translate into crop coordinates, then clamp into [0, w] x [0, h].
        shifted = b - torch.as_tensor([j, i, j, i])
        clipped = torch.min(shifted.reshape(2, 2), max_size)
        return clipped.clamp(min=0).reshape(-1)

    cropped_box = _shift_and_clip(box)
    if negBoxs is not None:
        # Same transform applied to each negative box (previously duplicated inline).
        return cropped_image, cropped_box, [_shift_and_clip(nb) for nb in negBoxs]
    return cropped_image, cropped_box
def resize_according_to_long_side(img, box, size, negBoxs=None):
    """Resize *img* so its longer side equals *size*, scaling boxes to match.

    Returns ``(img, box)`` or ``(img, box, negBoxs)`` when *negBoxs* is given.
    """
    height, width = img.height, img.width
    scale = float(size / float(max(height, width)))
    resized = F.resize(img, (round(height * scale), round(width * scale)))
    scaled_box = box * scale
    if negBoxs is None:
        return resized, scaled_box
    return resized, scaled_box, [nb * scale for nb in negBoxs]
def resize_according_to_short_side(img, box, size, negBoxs=None):
    """Resize *img* so its shorter side equals *size*, scaling boxes to match.

    Returns ``(img, box)`` or ``(img, box, negBoxs)`` when *negBoxs* is given.
    """
    height, width = img.height, img.width
    scale = float(size / float(min(height, width)))
    resized = F.resize(img, (round(height * scale), round(width * scale)))
    scaled_box = box * scale
    if negBoxs is None:
        return resized, scaled_box
    return resized, scaled_box, [nb * scale for nb in negBoxs]
class Compose(object):
    """Chain transforms; each takes and returns the shared input dict."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, input_dict):
        for transform in self.transforms:
            input_dict = transform(input_dict)
        return input_dict

    def __repr__(self):
        pieces = ["{0}(".format(self.__class__.__name__)]
        pieces.extend("\n {0}".format(t) for t in self.transforms)
        pieces.append("\n)")
        return "".join(pieces)
class RandomBrightness(object):
    """Scale image brightness by a random factor in [1-b, 1+b]."""

    def __init__(self, brightness=0.4):
        assert brightness >= 0.0
        assert brightness <= 1.0
        self.brightness = brightness

    def __call__(self, img):
        factor = random.uniform(1 - self.brightness, 1 + self.brightness)
        return ImageEnhance.Brightness(img).enhance(factor)
class RandomContrast(object):
    """Scale image contrast by a random factor in [1-c, 1+c]."""

    def __init__(self, contrast=0.4):
        assert contrast >= 0.0
        assert contrast <= 1.0
        self.contrast = contrast

    def __call__(self, img):
        factor = random.uniform(1 - self.contrast, 1 + self.contrast)
        return ImageEnhance.Contrast(img).enhance(factor)
class RandomSaturation(object):
    """Scale image color saturation by a random factor in [1-s, 1+s]."""

    def __init__(self, saturation=0.4):
        assert saturation >= 0.0
        assert saturation <= 1.0
        self.saturation = saturation

    def __call__(self, img):
        factor = random.uniform(1 - self.saturation, 1 + self.saturation)
        return ImageEnhance.Color(img).enhance(factor)
class ColorJitter(object):
    """With probability 0.8, apply brightness/contrast/saturation jitter
    in a random order (order drawn via ``np.random.permutation``)."""

    def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
        self.rand_brightness = RandomBrightness(brightness)
        self.rand_contrast = RandomContrast(contrast)
        self.rand_saturation = RandomSaturation(saturation)

    def __call__(self, input_dict):
        if random.random() < 0.8:
            image = input_dict['img']
            # Index 0/1/2 maps to brightness/contrast/saturation, as before.
            ops = (self.rand_brightness, self.rand_contrast, self.rand_saturation)
            for op_idx in np.random.permutation(3):
                image = ops[op_idx](image)
            input_dict['img'] = image
        return input_dict
class GaussianBlur(object):
    """Optionally blur the image with a Gaussian kernel of random sigma."""

    def __init__(self, sigma=[.1, 2.], aug_blur=False):
        self.sigma = sigma
        # Blur only fires when the augmentation flag is on.
        self.p = 0.5 if aug_blur else 0.

    def __call__(self, input_dict):
        if random.random() >= self.p:
            return input_dict
        radius = random.uniform(self.sigma[0], self.sigma[1])
        input_dict['img'] = input_dict['img'].filter(
            ImageFilter.GaussianBlur(radius=radius))
        return input_dict
class RandomHorizontalFlip(object):
    """With probability 0.5: mirror the image horizontally, swap the words
    'left'/'right' in the caption, and reflect all boxes across x = w."""

    def __call__(self, input_dict):
        if random.random() >= 0.5:
            return input_dict
        flipped = F.hflip(input_dict['img'])
        width = flipped.width

        def _mirror(b):
            # x1' = w - x2, x2' = w - x1; y coordinates unchanged.
            return b[[2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([width, 0, width, 0])

        # Swap 'left' <-> 'right' via a temporary placeholder token.
        swapped = input_dict['text'].replace('right', '*&^special^&*').replace('left', 'right').replace('*&^special^&*', 'left')
        input_dict['img'] = flipped
        input_dict['box'] = _mirror(input_dict['box'])
        input_dict['text'] = swapped
        if 'NegBBoxs' in input_dict.keys():
            input_dict['NegBBoxs'] = [_mirror(nb) for nb in input_dict['NegBBoxs']]
        return input_dict
class RandomResize(object):
    """Resize to a size drawn uniformly from *sizes*, preserving aspect ratio
    via the long side (default) or the short side."""

    def __init__(self, sizes, with_long_side=True):
        assert isinstance(sizes, (list, tuple))
        self.sizes = sizes
        self.with_long_side = with_long_side

    def __call__(self, input_dict):
        size = random.choice(self.sizes)
        resize_fn = (resize_according_to_long_side if self.with_long_side
                     else resize_according_to_short_side)
        img, box = input_dict['img'], input_dict['box']
        if 'NegBBoxs' in input_dict.keys():
            img, box, neg_boxes = resize_fn(img, box, size, input_dict['NegBBoxs'])
            input_dict['NegBBoxs'] = neg_boxes
        else:
            img, box = resize_fn(img, box, size)
        input_dict['img'] = img
        input_dict['box'] = box
        return input_dict
class RandomSizeCrop(object):
    """Randomly crop the image (up to *max_try* attempts), accepting only
    crops whose origin lies above and to the left of the box center."""

    def __init__(self, min_size: int, max_size: int, max_try: int = 20):
        self.min_size = min_size  # lower bound for crop width/height
        self.max_size = max_size  # upper bound, also capped by the image size
        self.max_try = max_try    # attempts before giving up unchanged

    def __call__(self, input_dict):
        img = input_dict['img']
        box = input_dict['box']
        num_try = 0
        while num_try < self.max_try:
            num_try += 1
            w = random.randint(self.min_size, min(img.width, self.max_size))
            h = random.randint(self.min_size, min(img.height, self.max_size))
            # get_params returns (i, j, h, w) = (top, left, height, width).
            region = T.RandomCrop.get_params(img, [h, w])
            box_xywh = xyxy2xywh(box)
            box_x, box_y = box_xywh[0], box_xywh[1]
            # Compare box-center x with region left (region[1]) and
            # box-center y with region top (region[0]).
            if box_x > region[1] and box_y > region[0]:
                if 'NegBBoxs' in input_dict.keys():
                    img, box, neg_boxes = crop(img, box, region, input_dict['NegBBoxs'])
                    input_dict['NegBBoxs'] = neg_boxes
                else:
                    # BUG FIX: the original cropped img/box a second time here
                    # even when they had already been cropped together with the
                    # negative boxes, shifting everything by the region twice.
                    img, box = crop(img, box, region)
                input_dict['img'] = img
                input_dict['box'] = box
                return input_dict
        # All attempts rejected: return the input unchanged.
        return input_dict
class RandomSelect(object):
    """Pick transforms1 whenever the caption contains a direction word
    (direction-sensitive samples must not be flipped/cropped arbitrarily);
    otherwise choose transforms2 with probability *p*, else transforms1."""

    def __init__(self, transforms1, transforms2, p=0.5):
        self.transforms1 = transforms1
        self.transforms2 = transforms2
        self.p = p

    def __call__(self, input_dict):
        caption = input_dict['text']
        if any(wd in caption for wd in ('left', 'right', 'top', 'bottom', 'middle')):
            return self.transforms1(input_dict)
        if random.random() < self.p:
            return self.transforms2(input_dict)
        return self.transforms1(input_dict)
class ToTensor(object):
    """Convert ``input_dict['img']`` to a float tensor via torchvision."""

    def __call__(self, input_dict):
        input_dict['img'] = F.to_tensor(input_dict['img'])
        return input_dict
class NormalizeAndPad(object):
    """Normalize the image tensor, paste it onto a ``size`` x ``size`` canvas
    (centered, or randomly placed when *aug_translate*), emit a padding mask,
    and convert boxes to canvas-normalized cxcywh."""

    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], size=640, aug_translate=False):
        self.mean = mean
        self.std = std
        self.size = size
        self.aug_translate = aug_translate

    def __call__(self, input_dict):
        img = F.normalize(input_dict['img'], mean=self.mean, std=self.std)
        h, w = img.shape[1:]
        dh = self.size - h
        dw = self.size - w
        if self.aug_translate:
            # Random placement anywhere inside the canvas.
            top = random.randint(0, dh)
            left = random.randint(0, dw)
        else:
            # Centered placement; the -0.1 biases rounding down for odd gaps.
            top = round(dh / 2.0 - 0.1)
            left = round(dw / 2.0 - 0.1)

        canvas = torch.zeros((3, self.size, self.size)).float()
        pad_mask = torch.ones((self.size, self.size)).int()
        canvas[:, top:top + h, left:left + w] = img
        pad_mask[top:top + h, left:left + w] = 0  # 0 marks real pixels
        input_dict['img'] = canvas
        input_dict['mask'] = pad_mask

        def _shift_normalize(b):
            # Shift into canvas coordinates (in place, as before), then
            # convert xyxy -> cxcywh normalized by the canvas size.
            b[0], b[2] = b[0] + left, b[2] + left
            b[1], b[3] = b[1] + top, b[3] + top
            ch, cw = canvas.shape[-2:]
            return xyxy2xywh(b) / torch.tensor([cw, ch, cw, ch], dtype=torch.float32)

        if 'box' in input_dict.keys():
            input_dict['box'] = _shift_normalize(input_dict['box'])
        if 'NegBBoxs' in input_dict.keys():
            input_dict['NegBBoxs'] = [_shift_normalize(nb) for nb in input_dict['NegBBoxs']]
        return input_dict