from __future__ import annotations

from typing import Literal

import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.transforms.functional as F
from PIL import Image
from timm.data import (
    IMAGENET_DEFAULT_MEAN,
    IMAGENET_DEFAULT_STD,
    IMAGENET_INCEPTION_MEAN,
    IMAGENET_INCEPTION_STD,
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
)
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import ImageInput, make_list_of_images
from transformers.utils import TensorType

NormalizationType = Literal["imagenet", "imagenet_inception", "openai_clip"]
|
|
class CLYPImageProcessor(BaseImageProcessor):
    def __init__(
        self,
        image_size: int = 224,
        normalization_type: NormalizationType = "imagenet",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_size = image_size
        self.normalization_type: NormalizationType = normalization_type

    def preprocess(
        self,
        images: ImageInput | list[ImageInput],
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> BatchFeature:
        images = make_list_of_images(images, expected_ndims=3)

        transforms = TestTransform(
            self.image_size, normalization_type=self.normalization_type
        )
        images = [transforms(image).numpy() for image in images]
        return BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
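
# A minimal usage sketch (the file name is hypothetical, assuming any PIL-readable image):
#
#   processor = CLYPImageProcessor(image_size=224, normalization_type="openai_clip")
#   batch = processor.preprocess(Image.open("example.jpg"), return_tensors="pt")
#   batch["pixel_values"].shape  # torch.Size([1, 3, 224, 224])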
|
|
class TrainTransform:
    def __init__(
        self,
        image_size: int,
        scale_range_min: float,
        scale_range_max: float,
        normalization_type: NormalizationType = "imagenet",
    ) -> None:
        """
        Args:
            image_size (int): output image size.
            scale_range_min (float): lower bound of the area fraction kept by the random crop.
            scale_range_max (float): upper bound of the area fraction kept by the random crop.
            normalization_type (str): selects the mean and std used for normalization (see get_mean_and_std).
        """
        scale = (scale_range_min, scale_range_max)
        mean_and_std = get_mean_and_std(normalization_type)

        self.transform = T.Compose(
            [
                T.RandomResizedCrop(
                    image_size, scale=scale, interpolation=T.InterpolationMode.BICUBIC
                ),
                _convert_to_rgb,
                T.ToTensor(),
                T.Normalize(**mean_and_std),
            ]
        )

    def __call__(self, img):
        return self.transform(img)
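
# For example, a training pipeline that keeps 50-100% of the image area (the concrete
# bounds here are illustrative, not from the original code):
#
#   train_transform = TrainTransform(224, scale_range_min=0.5, scale_range_max=1.0)
#   tensor = train_transform(some_pil_image)  # shape (3, 224, 224)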
|
|
class TestTransform:
    def __init__(
        self, image_size: int, normalization_type: NormalizationType = "imagenet"
    ) -> None:
        """
        Args:
            image_size (int): output image size.
            normalization_type (str): selects the mean and std used for normalization (see get_mean_and_std).
        """
        mean_and_std = get_mean_and_std(normalization_type)

        self.transform = T.Compose(
            [
                ResizeMaxSize(image_size, fill=0),
                T.CenterCrop(image_size),
                _convert_to_rgb,
                T.ToTensor(),
                T.Normalize(**mean_and_std),
            ]
        )

    def __call__(self, img):
        return self.transform(img)
|
|
class SmallestMaxSize(T.Resize):
    """Resize the shorter side of an input image.

    The shorter side of an input image is resized to max_size while preserving the
    aspect ratio. Note that a large part of the input image is discarded by a
    subsequent crop when the aspect ratio of the input image is extremely small or large.
    """

    def __init__(self, max_size: int, **kwargs):
        super().__init__(max_size, **kwargs)

    @staticmethod
    def target_size(w: int, h: int, size: int) -> tuple[int, int]:
        # Scale the shorter side to `size` and the longer side proportionally.
        if h < w:
            w, h = int(size * w / h), size
        else:
            w, h = size, int(size * h / w)
        return (h, w)

    def __call__(self, img):
        size = self.size
        assert isinstance(size, int)
        w, h = img.size
        target_size = self.target_size(w, h, size)
        return F.resize(img, list(target_size), self.interpolation)
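
# For example, target_size(400, 300, 224) == (224, 298): a 400x300 (WxH) input is
# resized so its shorter side becomes 224 while the aspect ratio is preserved.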
|
|
class ResizeMaxSize(nn.Module):
    """Resize the longer side of an input image.

    The longer side of an input image is resized to max_size and the result is
    zero-padded to a square. Note that a large part of the output image is padding
    when the aspect ratio of the input image is extremely small or large.
    Adapted from https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transform.py
    """

    def __init__(
        self,
        max_size: int,
        interpolation: T.InterpolationMode = T.InterpolationMode.BICUBIC,
        fn: str = "max",
        fill: int = 0,
    ):
        super().__init__()
        if not isinstance(max_size, int):
            raise TypeError(f"Size should be int. Got {type(max_size)}")
        self.max_size = max_size
        self.interpolation = interpolation
        self.fn = min if fn == "min" else max
        self.fill = fill

    def forward(self, img):
        if isinstance(img, torch.Tensor):
            # Tensor images are (..., H, W); the last two dims are the spatial ones.
            height, width = img.shape[-2:]
        else:
            width, height = img.size
        scale = self.max_size / float(self.fn(height, width))
        if scale != 1.0:
            new_size = tuple(round(dim * scale) for dim in (height, width))
            img = F.resize(img, new_size, self.interpolation)
            # Pad the resized image to max_size x max_size, splitting any odd
            # remainder between the two sides (left/top get the smaller half).
            pad_h = self.max_size - new_size[0]
            pad_w = self.max_size - new_size[1]
            img = F.pad(
                img,
                padding=[
                    pad_w // 2,
                    pad_h // 2,
                    pad_w - pad_w // 2,
                    pad_h - pad_h // 2,
                ],
                fill=self.fill,
            )
        return img
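
# Worked example: a 200x300 (WxH) input with max_size=224 scales by 224/300, giving
# new_size == (224, 149); then pad_h == 0 and pad_w == 75, i.e. 37 px of padding on
# the left and 38 px on the right of the output.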
|
|
def get_mean_and_std(normalization_type: NormalizationType) -> dict:
    """Return mean and std tensors for T.Normalize().

    NOTE:
        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
        IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
        IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
        OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
        OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
    """
    if normalization_type == "imagenet":
        return {
            "mean": torch.tensor(IMAGENET_DEFAULT_MEAN),
            "std": torch.tensor(IMAGENET_DEFAULT_STD),
        }
    elif normalization_type == "imagenet_inception":
        return {
            "mean": torch.tensor(IMAGENET_INCEPTION_MEAN),
            "std": torch.tensor(IMAGENET_INCEPTION_STD),
        }
    elif normalization_type == "openai_clip":
        return {
            "mean": torch.tensor(OPENAI_CLIP_MEAN),
            "std": torch.tensor(OPENAI_CLIP_STD),
        }
    else:
        raise ValueError(f"Unknown normalization_type: {normalization_type}")
|
|
def _convert_to_rgb(image: Image.Image) -> Image.Image:
    return image.convert("RGB")
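

# A quick smoke test, not part of the original module: it exercises the processor and
# transforms on a synthetic image so that no file on disk is assumed.
if __name__ == "__main__":
    dummy = Image.new("RGB", (320, 180), color=(128, 64, 32))

    processor = CLYPImageProcessor(image_size=224, normalization_type="openai_clip")
    batch = processor.preprocess(dummy, return_tensors="pt")
    print(batch["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])

    train_transform = TrainTransform(224, scale_range_min=0.5, scale_range_max=1.0)
    print(train_transform(dummy).shape)  # torch.Size([3, 224, 224])

    test_transform = TestTransform(224)
    print(test_transform(dummy).shape)  # torch.Size([3, 224, 224])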