# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Streaming images and labels from datasets created with dataset_tool.py."""

import os
import numpy as np
import zipfile
import PIL.Image
import json
import torch
import random

from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import img2tensor
from basicsr.utils.registry import DATASET_REGISTRY

try:
    import pyspng
except ImportError:
    pyspng = None

# Default blur-kernel sampling configuration used by IRImageFolderDataset
# when no `opt` is supplied.
KERNEL_OPT = {
    'blur_kernel_size': 21,
    'kernel_list': ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
    'kernel_prob': [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
    'sinc_prob': 0.1,
    'blur_sigma': [0.2, 3],
    'betag_range': [0.5, 4],
    'betap_range': [1, 2],

    'blur_kernel_size2': 21,
    'kernel_list2': ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
    'kernel_prob2': [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
    'sinc_prob2': 0.1,
    'blur_sigma2': [0.2, 1.5],
    'betag_range2': [0.5, 4],
    'betap_range2': [1, 2],

    'final_sinc_prob': 0.8,

    'use_hflip': False,
    'use_rot': False,
}

# Degradation pipeline configuration. Not read by any code in this module;
# presumably consumed by the training code that applies the degradations.
DEGRADE_OPT = {
    'resize_prob': [0.2, 0.7, 0.1],  # up, down, keep
    'resize_range': [0.15, 1.5],
    'gaussian_noise_prob': 0.5,
    'noise_range': [1, 30],
    'poisson_scale_range': [0.05, 3],
    'gray_noise_prob': 0.4,
    'jpeg_range': [30, 95],

    # the second degradation process
    'second_blur_prob': 0.8,
    'resize_prob2': [0.3, 0.4, 0.3],  # up, down, keep
    'resize_range2': [0.3, 1.2],
    'gaussian_noise_prob2': 0.5,
    'noise_range2': [1, 25],
    'poisson_scale_range2': [0.05, 2.5],
    'gray_noise_prob2': 0.4,
    'jpeg_range2': [30, 95],

    'gt_size': 512,
    'no_degradation_prob': 0.01,
    'use_usm': True,
    'sf': 4,
    'random_size': False,
    'resize_lq': False,
}

# Side length every sampled kernel is zero-padded to, and the size of the
# identity ("pulse") kernel. Matches the largest entry of `kernel_range`.
_MAX_KERNEL_SIZE = 21

#----------------------------------------------------------------------------
# Abstract base class for datasets.

class Dataset(torch.utils.data.Dataset):
    """Abstract base class for image datasets with optional labels.

    Subclasses must override `_load_raw_image()` and `_load_raw_labels()`,
    and may override `close()` to release resources.
    """

    def __init__(self,
        name,                   # Name of the dataset.
        raw_shape,              # Shape of the raw image data (NCHW).
        use_labels  = True,     # Enable conditioning labels? False = label dimension is zero.
        max_size    = None,     # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
        xflip       = False,    # Artificially double the size of the dataset via x-flips. Applied after max_size.
        random_seed = 0,        # Random seed to use when applying max_size.
        cache       = False,    # Cache images in CPU memory?
    ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._cache = cache
        self._cached_images = dict() # {raw_idx: np.ndarray, ...}
        self._raw_labels = None
        self._label_shape = None

        # Apply max_size: keep a seeded random subset, stored in sorted order.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Apply xflip: the second half of the index list refers to flipped copies.
        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            self._raw_idx = np.tile(self._raw_idx, 2)
            self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])

    def _get_raw_labels(self):
        """Return the raw label array, loading and validating it on first use.

        When labels are disabled or unavailable, a zero-width float32 array
        with one row per raw image is used instead.
        """
        if self._raw_labels is None:
            self._raw_labels = self._load_raw_labels() if self._use_labels else None
            if self._raw_labels is None:
                self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
            assert isinstance(self._raw_labels, np.ndarray)
            assert self._raw_labels.shape[0] == self._raw_shape[0]
            assert self._raw_labels.dtype in [np.float32, np.int64]
            if self._raw_labels.dtype == np.int64:
                assert self._raw_labels.ndim == 1
                assert np.all(self._raw_labels >= 0)
        return self._raw_labels

    def close(self): # to be overridden by subclass
        """Release any resources held by the dataset. No-op in the base class."""
        pass

    def _load_raw_image(self, raw_idx): # to be overridden by subclass
        """Load the image at `raw_idx` as an `np.ndarray` shaped like `raw_shape[1:]`."""
        raise NotImplementedError

    def _load_raw_labels(self): # to be overridden by subclass
        """Load all labels as an `np.ndarray`, or return None if there are none."""
        raise NotImplementedError

    def __getstate__(self):
        """Return the pickling state with the raw label array dropped (reloaded lazily)."""
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        # Best-effort cleanup: errors while closing during teardown are ignored.
        try:
            self.close()
        except Exception:
            pass

    def __len__(self):
        """Return the number of samples after `max_size` and `xflip` are applied."""
        return self._raw_idx.size

    def __getitem__(self, idx):
        """Return `(image, label)` for sample `idx`.

        The image is a copy of the (optionally cached, optionally x-flipped)
        raw array; the label comes from `get_label()`.
        """
        raw_idx = self._raw_idx[idx]
        image = self._cached_images.get(raw_idx, None)
        if image is None:
            image = self._load_raw_image(raw_idx)
            if self._cache:
                self._cached_images[raw_idx] = image
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self._raw_shape[1:]
        if self._xflip[idx]:
            assert image.ndim == 3 # CHW
            image = image[:, :, ::-1]
        return image.copy(), self.get_label(idx)

    def get_label(self, idx):
        """Return the label for sample `idx`; int64 class indices become one-hot float32."""
        label = self._get_raw_labels()[self._raw_idx[idx]]
        if label.dtype == np.int64:
            onehot = np.zeros(self.label_shape, dtype=np.float32)
            onehot[label] = 1
            label = onehot
        return label.copy()

    def get_details(self, idx):
        """Return a dict with `raw_idx`, `xflip` and `raw_label` for sample `idx`."""
        d = dict()
        d['raw_idx'] = int(self._raw_idx[idx])
        d['xflip'] = (int(self._xflip[idx]) != 0)
        d['raw_label'] = self._get_raw_labels()[d['raw_idx']].copy()
        return d

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self): # [CHW]
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 3 # CHW
        return self.image_shape[0]

    @property
    def resolution(self):
        assert len(self.image_shape) == 3 # CHW
        assert self.image_shape[1] == self.image_shape[2]
        return self.image_shape[1]

    @property
    def label_shape(self):
        # Computed once: [max class index + 1] for int64 labels, otherwise the
        # per-sample shape of the raw label array.
        if self._label_shape is None:
            raw_labels = self._get_raw_labels()
            if raw_labels.dtype == np.int64:
                self._label_shape = [int(np.max(raw_labels)) + 1]
            else:
                self._label_shape = raw_labels.shape[1:]
        return list(self._label_shape)

    @property
    def label_dim(self):
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return self._get_raw_labels().dtype == np.int64

#----------------------------------------------------------------------------
# Dataset subclass that loads images recursively from the specified directory
# or ZIP file.

class ImageFolderDataset(Dataset):
    """Dataset that reads images (and optional `dataset.json` labels) from a directory or ZIP."""

    def __init__(self,
        path,                   # Path to directory or zip.
        resolution      = None, # Ensure specific resolution, None = anything goes.
        **super_kwargs,         # Additional arguments for the Dataset base class.
    ):
        self._path = path
        self._zipfile = None

        # Collect every file name relative to the dataset root.
        if os.path.isdir(self._path):
            self._type = 'dir'
            self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
        elif self._file_ext(self._path) == '.zip':
            self._type = 'zip'
            self._all_fnames = set(self._get_zipfile().namelist())
        else:
            raise IOError('Path must point to a directory or zip')

        # Keep only files with an extension PIL knows about, plus '.npy'.
        PIL.Image.init()
        supported_ext = PIL.Image.EXTENSION.keys() | {'.npy'}
        self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in supported_ext)
        if len(self._image_fnames) == 0:
            raise IOError('No image files found in the specified path')

        # The first image determines the per-sample shape for the whole dataset.
        name = os.path.splitext(os.path.basename(self._path))[0]
        raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
        if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
            raise IOError('Image files do not match the specified resolution')
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    @staticmethod
    def _file_ext(fname):
        """Return the lower-cased extension of `fname`, including the leading dot."""
        return os.path.splitext(fname)[1].lower()

    def _get_zipfile(self):
        """Return the ZIP archive handle, opening it on first use."""
        assert self._type == 'zip'
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _open_file(self, fname):
        """Open `fname` (relative to the dataset root) for binary reading."""
        if self._type == 'dir':
            return open(os.path.join(self._path, fname), 'rb')
        if self._type == 'zip':
            return self._get_zipfile().open(fname, 'r')
        return None

    def close(self):
        """Close the ZIP archive if one is open; the handle is always reset to None."""
        try:
            if self._zipfile is not None:
                self._zipfile.close()
        finally:
            self._zipfile = None

    def __getstate__(self):
        """Return the pickling state without the open ZIP handle (reopened lazily)."""
        return dict(super().__getstate__(), _zipfile=None)

    def _load_raw_image(self, raw_idx):
        """Load image number `raw_idx` (index into the sorted file list) as a channel-first array."""
        fname = self._image_fnames[raw_idx]
        ext = self._file_ext(fname)
        with self._open_file(fname) as f:
            if ext == '.npy':
                # Collapse all leading axes into one; the last two axes are kept.
                image = np.load(f)
                image = image.reshape(-1, *image.shape[-2:])
            elif ext == '.png' and pyspng is not None:
                # Force a trailing channel axis, then move it to the front.
                image = pyspng.load(f.read())
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
            else:
                image = np.array(PIL.Image.open(f))
                image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
        return image

    def _load_raw_labels(self):
        """Load labels from `dataset.json`, ordered like the image file list.

        Returns None when the file is absent or its 'labels' entry is null.
        1-D label arrays are cast to int64 and 2-D arrays to float32.
        """
        fname = 'dataset.json'
        if fname not in self._all_fnames:
            return None
        with self._open_file(fname) as f:
            labels = json.load(f)['labels']
        if labels is None:
            return None
        labels = dict(labels)
        # Label keys use forward slashes; normalise Windows-style separators.
        labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
        labels = np.array(labels)
        labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
        return labels

#----------------------------------------------------------------------------
# Image-restoration dataset: yields the ground-truth image together with
# randomly sampled blur kernels instead of (image, label) pairs.

@DATASET_REGISTRY.register(suffix='basicsr')
class IRImageFolderDataset(ImageFolderDataset):
    """ImageFolderDataset variant that also samples degradation blur kernels per item.

    Image loading is inherited unchanged from ImageFolderDataset.
    """

    def __init__(self,
        opt=None,               # Degradation kernel config. None = KERNEL_OPT.
        **super_kwargs,         # Additional arguments for the Dataset base class.
    ):
        if opt is None:
            opt = KERNEL_OPT
        self.opt = opt
        super().__init__(**super_kwargs)

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        # Odd kernel sizes from 7 to 21.
        # TODO: kernel range is hard-coded; it should come from the config.
        self.kernel_range = [2 * v + 1 for v in range(3, 11)]

        # Single 1 at the centre: convolving with the pulse tensor brings no blurry effect.
        self.pulse_tensor = torch.zeros(_MAX_KERNEL_SIZE, _MAX_KERNEL_SIZE).float()
        self.pulse_tensor[_MAX_KERNEL_SIZE // 2, _MAX_KERNEL_SIZE // 2] = 1

    def _sample_blur_kernel(self, kernel_list, kernel_prob, blur_sigma, betag_range, betap_range, sinc_prob):
        """Sample one blur kernel and zero-pad it to `_MAX_KERNEL_SIZE` per side.

        With probability `sinc_prob` a circular low-pass (sinc) kernel is
        drawn; otherwise a kernel from `random_mixed_kernels`. Returns a 2-D
        numpy array.
        """
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < sinc_prob:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                kernel_list,
                kernel_prob,
                kernel_size,
                blur_sigma,
                blur_sigma, [-np.pi, np.pi],
                betag_range,
                betap_range,
                noise_range=None)
        # pad kernel symmetrically up to the common size
        pad_size = (_MAX_KERNEL_SIZE - kernel_size) // 2
        return np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

    def __getitem__(self, idx):
        """Return a dict with the ground-truth image and sampled kernels.

        Keys: 'image' (float32 tensor, raw pixel values divided by 255),
        'kernel1' and 'kernel2' (blur kernels for the first and second
        degradation), and 'sinc_kernel' (final sinc kernel, or the shared
        identity pulse tensor). Unlike the base class, no label is returned.
        """
        raw_idx = self._raw_idx[idx]
        image = self._cached_images.get(raw_idx, None)
        if image is None:
            image = self._load_raw_image(raw_idx)
            if self._cache:
                self._cached_images[raw_idx] = image
        assert isinstance(image, np.ndarray), type(image)
        assert list(image.shape) == self._raw_shape[1:], image.shape

        # FIXME: flip/rotate augmentation is not applied here: opt['use_hflip'] and
        # opt['use_rot'] are unused, and the base-class xflip flag is ignored too.
        image = image.astype(np.float32) / 255.

        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel = self._sample_blur_kernel(
            self.kernel_list, self.kernel_prob, self.blur_sigma,
            self.betag_range, self.betap_range, self.opt['sinc_prob'])

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel2 = self._sample_blur_kernel(
            self.kernel_list2, self.kernel_prob2, self.blur_sigma2,
            self.betag_range2, self.betap_range2, self.opt['sinc_prob2'])

        # ------------------------------------- the final sinc kernel ------------------------------------- #
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=_MAX_KERNEL_SIZE)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        # numpy to tensor
        img_gt = torch.from_numpy(image).float()
        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        return {'image': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel}


def _stack_float(tensors):
    """Stack a list of tensors into one contiguous float32 batch tensor."""
    return torch.stack(tensors).to(memory_format=torch.contiguous_format).float()


def collate_fn(examples, with_prior_preservation=False):
    """Collate per-sample dicts into a batch dict.

    Args:
        examples: list of dicts with keys 'image', 'kernel1', 'kernel2' and
            'sinc_kernel', and optionally 'img_tensor'. When 'img_tensor' is
            missing (as for samples from IRImageFolderDataset), the 'image'
            entry is used as the pixel tensor instead.
        with_prior_preservation: not supported; must be False.

    Returns:
        dict with 'image' (list of the per-sample 'image' entries, not
        stacked) and stacked float tensors 'img_tensor', 'kernel1',
        'kernel2', 'sinc_kernel'.

    Raises:
        NotImplementedError: if `with_prior_preservation` is true.
    """
    if with_prior_preservation:
        raise NotImplementedError("Prior preservation not implemented.")

    pixel_values = [
        example["img_tensor"] if "img_tensor" in example else example["image"]
        for example in examples
    ]
    pil_image = [example["image"] for example in examples]

    batch = {
        "image": pil_image,
        "img_tensor": _stack_float(pixel_values),
        "kernel1": _stack_float([example["kernel1"] for example in examples]),
        "kernel2": _stack_float([example["kernel2"] for example in examples]),
        "sinc_kernel": _stack_float([example["sinc_kernel"] for example in examples]),
    }
    return batch