""" |
|
Use torchvision instead of transformers to perform resize and center crop. |
|
|
|
This is because transformers' version is sometimes 1-pixel off. |
|
|
|
For example, if the image size is 640x480, both results are consistent. |
|
(e.g., "http://images.cocodataset.org/val2017/000000039769.jpg") |
|
|
|
However, if the image size is 500x334, the following happens: |
|
(e.g., "http://images.cocodataset.org/val2014/COCO_val2014_000000324158.jpg") |
|
|
|
>>> # Results' shape: (h, w) |
|
>>> torch.allclose(torchvision_result[:, :-1], transformers_result[:, 1:]) |
|
... True |
|
|
|
Note that if only resize is performed with torchvision, |
|
the inconsistency remains. |
|
Therefore, center crop must also be done with torchvision. |
|
""" |
|
|
|
from torchvision.transforms import CenterCrop, InterpolationMode, Resize
from transformers import AutoImageProcessor, CLIPImageProcessor
from transformers.image_processing_utils import BatchFeature, get_size_dict
from transformers.image_utils import ImageInput, PILImageResampling, make_list_of_images


def PILImageResampling_to_InterpolationMode(
    resample: PILImageResampling,
) -> InterpolationMode:
    # PILImageResampling and torchvision's InterpolationMode use the same
    # member names (NEAREST, BOX, BILINEAR, HAMMING, BICUBIC, LANCZOS),
    # so the conversion is a plain name lookup.
    return getattr(InterpolationMode, PILImageResampling(resample).name)
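
# For example (illustrative only; BICUBIC is an arbitrary member):
#   PILImageResampling_to_InterpolationMode(PILImageResampling.BICUBIC)
#   -> InterpolationMode.BICUBIC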


class CustomCLIPImageProcessor(CLIPImageProcessor):
    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: dict[str, int] = None,
        **kwargs,
    ) -> BatchFeature:
        # Fall back to the processor's configured defaults, mirroring
        # CLIPImageProcessor.preprocess.
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_center_crop = (
            do_center_crop if do_center_crop is not None else self.do_center_crop
        )
        crop_size = crop_size if crop_size is not None else self.crop_size

        # NOTE: the torchvision transforms below expect PIL images (or
        # tensors), not numpy arrays.
        images = make_list_of_images(images)

        if do_resize:
            _size = get_size_dict(
                size,
                param_name="size",
                default_to_square=getattr(self, "use_square_size", False),
            )
            if set(_size) == {"shortest_edge"}:
                # Resize the shortest edge with torchvision, then clear the
                # flag so the parent class does not resize a second time.
                resize = Resize(
                    size=_size["shortest_edge"],
                    interpolation=PILImageResampling_to_InterpolationMode(resample),
                )
                images = [resize(image) for image in images]
                do_resize = False

        if do_center_crop:
            _crop_size = get_size_dict(
                crop_size, param_name="crop_size", default_to_square=True
            )
            # Crop with torchvision as well; as the module docstring notes,
            # a torchvision resize alone does not remove the one-pixel shift.
            center_crop = CenterCrop(
                size=tuple(map(_crop_size.get, ["height", "width"]))
            )
            images = [center_crop(image) for image in images]
            do_center_crop = False

        # Delegate the remaining steps (rescale, normalize, tensor conversion)
        # to transformers; resize and center crop were already handled above.
        return super().preprocess(
            images=images,
            do_resize=do_resize,
            size=size,
            resample=resample,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            **kwargs,
        )


AutoImageProcessor.register("CustomCLIPImageProcessor", CustomCLIPImageProcessor) |
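

# Minimal usage / reproduction sketch (runs only when executed directly).
# The checkpoint name is an illustrative choice, and the URL is the 500x334
# example from the module docstring; neither is required by this module.
if __name__ == "__main__":
    import requests
    import torch
    from PIL import Image

    url = "http://images.cocodataset.org/val2014/COCO_val2014_000000324158.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    checkpoint = "openai/clip-vit-base-patch32"
    custom = CustomCLIPImageProcessor.from_pretrained(checkpoint)
    stock = CLIPImageProcessor.from_pretrained(checkpoint)

    custom_px = custom(images=image, return_tensors="pt")["pixel_values"]
    stock_px = stock(images=image, return_tensors="pt")["pixel_values"]

    # Per the module docstring, the stock output should be shifted by one
    # pixel along the width for this image.
    print(torch.allclose(custom_px[..., :-1], stock_px[..., 1:]))  # expect True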