# Hunyuan3D-1.0 / basicsr / data / imagent_dataset.py
# (Hugging Face file-viewer residue: uploader gokaygokay, "Upload 556 files",
#  commit 0324143 verified, 18.2 kB)
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
"""Streaming images and labels from datasets created with dataset_tool.py."""
import os
import numpy as np
import zipfile
import PIL.Image
import json
import torch
import random
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.data.transforms import augment
from basicsr.utils import img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
try:
import pyspng
except ImportError:
pyspng = None
# Default kernel-sampling options. IRImageFolderDataset falls back to this
# mapping when it is constructed without an explicit `opt`.
KERNEL_OPT = dict(
    # Blur settings for the first degradation.
    blur_kernel_size=21,
    kernel_list=['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
    kernel_prob=[0.45, 0.25, 0.12, 0.03, 0.12, 0.03],  # one probability per kernel_list entry
    sinc_prob=0.1,  # probability of drawing a sinc filter instead of a mixed kernel
    blur_sigma=[0.2, 3],
    betag_range=[0.5, 4],  # betag used in generalized Gaussian blur kernels
    betap_range=[1, 2],  # betap used in plateau blur kernels
    # Blur settings for the second degradation (same keys with a `2` suffix).
    blur_kernel_size2=21,
    kernel_list2=['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
    kernel_prob2=[0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
    sinc_prob2=0.1,
    blur_sigma2=[0.2, 1.5],
    betag_range2=[0.5, 4],
    betap_range2=[1, 2],
    # Probability of sampling a final sinc kernel.
    final_sinc_prob=0.8,
    # Augmentation switches; in this file they are only mentioned by a
    # disabled `augment(...)` call, so they currently have no effect here.
    use_hflip=False,
    use_rot=False,
)
# Default degradation options. Nothing in the visible part of this file reads
# this mapping; key meanings beyond the inline notes are inferred from the
# key names only -- confirm against the consumer before relying on them.
DEGRADE_OPT = dict(
    # the first degradation process
    resize_prob=[0.2, 0.7, 0.1],  # up, down, keep
    resize_range=[0.15, 1.5],
    gaussian_noise_prob=0.5,
    noise_range=[1, 30],
    poisson_scale_range=[0.05, 3],
    gray_noise_prob=0.4,
    jpeg_range=[30, 95],
    # the second degradation process
    second_blur_prob=0.8,
    resize_prob2=[0.3, 0.4, 0.3],  # up, down, keep
    resize_range2=[0.3, 1.2],
    gaussian_noise_prob2=0.5,
    noise_range2=[1, 25],
    poisson_scale_range2=[0.05, 2.5],
    gray_noise_prob2=0.4,
    jpeg_range2=[30, 95],
    # general settings
    gt_size=512,
    no_degradation_prob=0.01,
    use_usm=True,
    sf=4,
    random_size=False,
    resize_lq=False,
)
#----------------------------------------------------------------------------
# Abstract base class for datasets.
class Dataset(torch.utils.data.Dataset):
    """Abstract base class for image datasets with optional conditioning labels.

    Subclasses supply the data by overriding ``_load_raw_image`` and
    ``_load_raw_labels`` (and ``close`` when they hold resources).

    Indexing returns ``(image, label)``: ``image`` is a copy of the
    ``np.ndarray`` produced by ``_load_raw_image`` (mirrored along its last
    axis for x-flipped entries), and ``label`` is a float32 one-hot vector
    for int64 raw labels or the raw label row otherwise.
    """

    def __init__(self,
        name,                   # Name of the dataset.
        raw_shape,              # Shape of the raw image data (NCHW).
        use_labels  = True,     # Enable conditioning labels? False = label dimension is zero.
        max_size    = None,     # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
        xflip       = False,    # Artificially double the size of the dataset via x-flips. Applied after max_size.
        random_seed = 0,        # Random seed to use when applying max_size.
        cache       = False,    # Cache images in CPU memory?
    ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._cache = cache
        self._cached_images = dict() # {raw_idx: np.ndarray, ...}
        self._raw_labels = None
        self._label_shape = None

        # Apply max_size: keep a seeded random subset, restored to sorted order.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Apply xflip: the second half of the index repeats the first half
        # with the flip flag set.
        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            self._raw_idx = np.tile(self._raw_idx, 2)
            self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])

    def _get_raw_labels(self):
        """Return the raw label array, loading and validating it on first use."""
        if self._raw_labels is None:
            self._raw_labels = self._load_raw_labels() if self._use_labels else None
            if self._raw_labels is None:
                # No labels available or labels disabled: zero-width label per image.
                self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
            assert isinstance(self._raw_labels, np.ndarray)
            assert self._raw_labels.shape[0] == self._raw_shape[0]
            assert self._raw_labels.dtype in [np.float32, np.int64]
            if self._raw_labels.dtype == np.int64:
                assert self._raw_labels.ndim == 1
                assert np.all(self._raw_labels >= 0)
        return self._raw_labels

    def close(self): # to be overridden by subclass
        """Release any resources held by the dataset (no-op in the base class)."""
        pass

    def _load_raw_image(self, raw_idx): # to be overridden by subclass
        """Return the image at ``raw_idx`` as an ``np.ndarray``."""
        raise NotImplementedError

    def _load_raw_labels(self): # to be overridden by subclass
        """Return all labels as an ``np.ndarray``, or ``None`` if there are none."""
        raise NotImplementedError

    def __getstate__(self):
        # The loaded labels are left out of the pickled state; they are
        # reloaded lazily by _get_raw_labels().
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        # Best-effort cleanup: ordinary errors from close() are ignored, but
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        try:
            self.close()
        except Exception:
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        """Return ``(image, label)`` for dataset position ``idx``."""
        raw_idx = self._raw_idx[idx]
        image = self._cached_images.get(raw_idx, None)
        if image is None:
            image = self._load_raw_image(raw_idx)
            if self._cache:
                self._cached_images[raw_idx] = image
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self._raw_shape[1:]
        if self._xflip[idx]:
            assert image.ndim == 3 # CHW
            image = image[:, :, ::-1]
        # copy() keeps the cached array untouched and materialises the flipped view.
        return image.copy(), self.get_label(idx)

    def get_label(self, idx):
        """Return the label for position ``idx`` (one-hot float32 for int64 labels)."""
        label = self._get_raw_labels()[self._raw_idx[idx]]
        if label.dtype == np.int64:
            onehot = np.zeros(self.label_shape, dtype=np.float32)
            onehot[label] = 1
            label = onehot
        return label.copy()

    def get_details(self, idx):
        """Return a dict with ``raw_idx``, ``xflip`` and ``raw_label`` for ``idx``."""
        d = dict()
        d['raw_idx'] = int(self._raw_idx[idx])
        d['xflip'] = (int(self._xflip[idx]) != 0)
        d['raw_label'] = self._get_raw_labels()[d['raw_idx']].copy()
        return d

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self): # [CHW]
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 3 # CHW
        return self.image_shape[0]

    @property
    def resolution(self):
        assert len(self.image_shape) == 3 # CHW
        assert self.image_shape[1] == self.image_shape[2]
        return self.image_shape[1]

    @property
    def label_shape(self):
        # Computed once: int64 labels become one-hot vectors of length
        # max+1, other labels keep their per-item shape.
        if self._label_shape is None:
            raw_labels = self._get_raw_labels()
            if raw_labels.dtype == np.int64:
                self._label_shape = [int(np.max(raw_labels)) + 1]
            else:
                self._label_shape = raw_labels.shape[1:]
        return list(self._label_shape)

    @property
    def label_dim(self):
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return self._get_raw_labels().dtype == np.int64
#----------------------------------------------------------------------------
# Dataset subclass that loads images recursively from the specified directory
# or ZIP file.
class ImageFolderDataset(Dataset):
    """Dataset that reads images from a directory tree or from a ZIP archive.

    Every file whose extension is known to PIL, or is ``.npy``, counts as an
    image; images are indexed in sorted file-name order. Labels, when
    present, come from the ``labels`` entry of a top-level ``dataset.json``.
    """

    def __init__(self,
        path,                   # Path to directory or zip.
        resolution      = None, # Ensure specific resolution, None = anything goes.
        **super_kwargs,         # Additional arguments for the Dataset base class.
    ):
        self._path = path
        self._zipfile = None

        # Collect every file name, relative to the dataset root.
        if os.path.isdir(self._path):
            self._type = 'dir'
            found = set()
            for root, _dirs, files in os.walk(self._path):
                for fname in files:
                    found.add(os.path.relpath(os.path.join(root, fname), start=self._path))
            self._all_fnames = found
        elif self._file_ext(self._path) == '.zip':
            self._type = 'zip'
            self._all_fnames = set(self._get_zipfile().namelist())
        else:
            raise IOError('Path must point to a directory or zip')

        # Keep only the files that can be decoded as images.
        PIL.Image.init()
        known_exts = PIL.Image.EXTENSION.keys() | {'.npy'}
        candidates = [fname for fname in self._all_fnames if self._file_ext(fname) in known_exts]
        self._image_fnames = sorted(candidates)
        if not self._image_fnames:
            raise IOError('No image files found in the specified path')

        # The first image defines the per-item shape of the whole dataset.
        stem, _ext = os.path.splitext(os.path.basename(self._path))
        first_shape = list(self._load_raw_image(0).shape)
        raw_shape = [len(self._image_fnames)] + first_shape
        if resolution is not None:
            if raw_shape[2] != resolution or raw_shape[3] != resolution:
                raise IOError('Image files do not match the specified resolution')
        super().__init__(name=stem, raw_shape=raw_shape, **super_kwargs)

    @staticmethod
    def _file_ext(fname):
        """Return the lower-cased extension of ``fname``, including the dot."""
        _root, ext = os.path.splitext(fname)
        return ext.lower()

    def _get_zipfile(self):
        """Return the archive handle, opening it on first use."""
        assert self._type == 'zip'
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _open_file(self, fname):
        """Open ``fname`` (relative to the dataset root) for binary reading."""
        kind = self._type
        if kind == 'zip':
            return self._get_zipfile().open(fname, 'r')
        if kind == 'dir':
            return open(os.path.join(self._path, fname), 'rb')
        return None

    def close(self):
        """Close the ZIP handle, if one is open, and forget it."""
        try:
            archive = self._zipfile
            if archive is not None:
                archive.close()
        finally:
            self._zipfile = None

    def __getstate__(self):
        # The open archive handle is dropped from the pickled state;
        # _get_zipfile() reopens it on demand.
        state = super().__getstate__()
        return dict(state, _zipfile=None)

    def _load_raw_image(self, raw_idx):
        """Decode image number ``raw_idx`` into a channel-first ``np.ndarray``."""
        fname = self._image_fnames[raw_idx]
        ext = self._file_ext(fname)
        with self._open_file(fname) as f:
            if ext == '.npy':
                # .npy data: keep the last two axes, fold the rest into the first.
                data = np.load(f)
                return data.reshape(-1, *data.shape[-2:])
            if ext == '.png' and pyspng is not None:
                pixels = pyspng.load(f.read())
            else:
                pixels = np.array(PIL.Image.open(f))
            # Decoded pixels: force a third axis, then move it to the front.
            return pixels.reshape(*pixels.shape[:2], -1).transpose(2, 0, 1)

    def _load_raw_labels(self):
        """Load labels from ``dataset.json``; return ``None`` when unavailable."""
        meta_name = 'dataset.json'
        if meta_name not in self._all_fnames:
            return None
        with self._open_file(meta_name) as f:
            entries = json.load(f)['labels']
        if entries is None:
            return None
        by_name = dict(entries)
        # Label keys use forward slashes, so normalise backslashes in file names.
        ordered = np.array([by_name[img_name.replace('\\', '/')] for img_name in self._image_fnames])
        # 1-D label arrays are class indices, 2-D arrays are float vectors.
        dtype_for_ndim = {1: np.int64, 2: np.float32}
        return ordered.astype(dtype_for_ndim[ordered.ndim])
#----------------------------------------------------------------------------
@DATASET_REGISTRY.register(suffix='basicsr')
class IRImageFolderDataset(ImageFolderDataset):
    """Image-folder dataset that pairs each image with random blur kernels.

    Each item is a dict with:
      * ``'image'``       - the image as a float tensor, pixel values divided by 255;
      * ``'kernel1'``     - blur kernel for the first degradation, zero-padded to 21x21;
      * ``'kernel2'``     - blur kernel for the second degradation, zero-padded to 21x21;
      * ``'sinc_kernel'`` - a final sinc kernel, or a centred pulse when no sinc
        filter is drawn.

    Labels are not returned. Image decoding is inherited from
    ``ImageFolderDataset``.
    """

    # All sampled kernels are zero-padded to this side length.
    _PADDED_KERNEL_SIZE = 21

    def __init__(self,
        opt=None,               # Degradation kernel config.
        **super_kwargs,         # Additional arguments for the Dataset base class.
    ):
        if opt is None:
            opt = KERNEL_OPT
        self.opt = opt
        super().__init__(**super_kwargs)

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']  # stored only; sizes are drawn from kernel_range
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        # TODO: kernel range is now hard-coded, should be in the configure file
        size = self._PADDED_KERNEL_SIZE
        self.pulse_tensor = torch.zeros(size, size).float()  # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor[size // 2, size // 2] = 1

    def _sample_blur_kernel(self, sinc_prob, kernel_list, kernel_prob, blur_sigma, betag_range, betap_range):
        """Draw one random blur kernel and zero-pad it to the common size.

        With probability ``sinc_prob`` the kernel is a circular low-pass
        (sinc) filter; otherwise it is drawn by ``random_mixed_kernels``
        from ``kernel_list`` with the given parameters.
        """
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < sinc_prob:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                kernel_list,
                kernel_prob,
                kernel_size,
                blur_sigma,
                blur_sigma, [-np.pi, np.pi],
                betag_range,
                betap_range,
                noise_range=None)
        # pad kernel
        pad_size = (self._PADDED_KERNEL_SIZE - kernel_size) // 2
        return np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

    def __getitem__(self, idx):
        """Return the image at ``idx`` together with freshly sampled kernels."""
        raw_idx = self._raw_idx[idx]
        image = self._cached_images.get(raw_idx, None)
        if image is None:
            image = self._load_raw_image(raw_idx)
            if self._cache:
                self._cached_images[raw_idx] = image
        assert isinstance(image, np.ndarray), type(image)
        assert list(image.shape) == self._raw_shape[1:], image.shape

        # Honour the x-flip flags built by the base class; without this,
        # xflip=True would only duplicate every image unflipped.
        if self._xflip[idx]:
            assert image.ndim == 3  # CHW
            # torch.from_numpy cannot wrap negatively-strided views, so materialise the flip.
            image = np.ascontiguousarray(image[:, :, ::-1])

        # TODO: opt['use_hflip'] / opt['use_rot'] augmentation is not applied here.
        image = image.astype(np.float32) / 255.

        # Kernels are drawn in a fixed order (first, second, final sinc) so the
        # sequence of random draws is stable for a given RNG state.
        kernel = self._sample_blur_kernel(
            self.opt['sinc_prob'], self.kernel_list, self.kernel_prob,
            self.blur_sigma, self.betag_range, self.betap_range)
        kernel2 = self._sample_blur_kernel(
            self.opt['sinc_prob2'], self.kernel_list2, self.kernel_prob2,
            self.blur_sigma2, self.betag_range2, self.betap_range2)

        # the final sinc kernel
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        # numpy to tensor
        img_gt = torch.from_numpy(image).float()
        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        return {'image': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel}
def collate_fn(examples, with_prior_preservation=False):
    """Collate a list of example dicts into one batch dict.

    Args:
        examples: sequence of dicts, each holding ``"image"``, ``"kernel1"``,
            ``"kernel2"`` and ``"sinc_kernel"``, and optionally
            ``"img_tensor"``. When ``"img_tensor"`` is absent (as in the items
            produced by ``IRImageFolderDataset``), ``"image"`` is used as the
            tensor to stack instead.
        with_prior_preservation: not supported; must be falsy.

    Returns:
        A dict with ``"image"`` (the un-stacked list of per-example
        ``"image"`` values) and ``"img_tensor"``, ``"kernel1"``, ``"kernel2"``,
        ``"sinc_kernel"`` (each stacked along a new leading dimension,
        contiguous, float).

    Raises:
        NotImplementedError: if ``with_prior_preservation`` is truthy.
    """
    if with_prior_preservation:
        raise NotImplementedError("Prior preservation not implemented.")

    def _stack(tensors):
        # Stack along a new batch dimension as a contiguous float tensor.
        return torch.stack(tensors).to(memory_format=torch.contiguous_format).float()

    pil_image = [example["image"] for example in examples]
    # Prefer an explicit "img_tensor"; otherwise reuse "image" so examples
    # that only carry the image tensor can still be collated.
    pixel_values = [
        example["img_tensor"] if "img_tensor" in example else example["image"]
        for example in examples
    ]

    batch = {
        "image": pil_image,
        "img_tensor": _stack(pixel_values),
        "kernel1": _stack([example["kernel1"] for example in examples]),
        "kernel2": _stack([example["kernel2"] for example in examples]),
        "sinc_kernel": _stack([example["sinc_kernel"] for example in examples]),
    }
    return batch