Spaces:
Running
Running
import os | |
import json | |
import numpy as np | |
import webdataset as wds | |
import pytorch_lightning as pl | |
import torch | |
from torch.utils.data import Dataset | |
from torch.utils.data.distributed import DistributedSampler | |
from PIL import Image | |
from pathlib import Path | |
from src.utils.train_util import instantiate_from_config | |
class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module whose datasets are built from config objects.

    Each of ``train`` / ``validation`` / ``test`` is a config handed to
    ``instantiate_from_config`` in :meth:`setup`; splits passed as ``None``
    are not created.

    Args:
        batch_size: Batch size for the train and test loaders.
        num_workers: Number of loader worker processes for every split.
        train: Config for the training dataset, or ``None``.
        validation: Config for the validation dataset, or ``None``.
        test: Config for the test dataset, or ``None``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        batch_size=8,
        num_workers=4,
        train=None,
        validation=None,
        test=None,
        **kwargs,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset_configs = dict()
        if train is not None:
            self.dataset_configs['train'] = train
        if validation is not None:
            self.dataset_configs['validation'] = validation
        if test is not None:
            self.dataset_configs['test'] = test

    def setup(self, stage=None):
        """Instantiate every configured split into ``self.datasets``.

        Args:
            stage: Lightning stage name (``'fit'``, ``'validate'``, ``'test'``,
                ...). Ignored: all configured splits are built for any stage,
                so ``trainer.validate()`` / ``trainer.test()`` work as well as
                ``trainer.fit()``.
        """
        self.datasets = {
            split: instantiate_from_config(config)
            for split, config in self.dataset_configs.items()
        }

    @staticmethod
    def _distributed_ready():
        """Return True when a ``torch.distributed`` process group is initialized."""
        return torch.distributed.is_available() and torch.distributed.is_initialized()

    def _make_loader(self, dataset, batch_size, shuffle):
        """Build a ``wds.WebLoader`` over ``dataset``.

        With an initialized process group, a ``DistributedSampler`` shards the
        data across ranks and owns the shuffling (torch ``DataLoader`` rejects
        ``sampler`` combined with ``shuffle=True``). Without one, no sampler is
        used and the loader shuffles by itself; constructing a
        ``DistributedSampler`` there would raise because it queries the world
        size of the default process group.
        """
        if self._distributed_ready():
            # NOTE(review): confirm Lightning reaches this sampler's set_epoch
            # through the WebLoader wrapper; otherwise the shuffle order
            # repeats every epoch.
            sampler = DistributedSampler(dataset, shuffle=shuffle)
            return wds.WebLoader(
                dataset,
                batch_size=batch_size,
                num_workers=self.num_workers,
                shuffle=False,
                sampler=sampler,
            )
        return wds.WebLoader(
            dataset,
            batch_size=batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.datasets['train'], self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Loader over the validation split, in a fixed (unshuffled) order."""
        # Validation batch size is fixed at 4, independent of self.batch_size.
        return self._make_loader(self.datasets['validation'], 4, shuffle=False)

    def test_dataloader(self):
        """Unshuffled, unsampled loader over the test split."""
        return wds.WebLoader(
            self.datasets['test'],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
class ObjaverseData(Dataset):
    """Map-style dataset yielding one condition view and six target views per object.

    ``<root_dir>/<meta_fname>`` is a JSON object whose values are lists of
    object paths; the lists are concatenated in key order. The last 16 paths
    form the validation split and all remaining paths the training split.
    Each object directory ``<root_dir>/<image_dir>/<path>`` is expected to
    contain ``000.png`` through ``006.png``.

    Args:
        root_dir: Dataset root directory.
        meta_fname: JSON index file name, relative to ``root_dir``.
        image_dir: Sub-directory of ``root_dir`` holding the per-object images.
        validation: If True, use the 16-object validation split instead of
            the training split.
    """

    def __init__(
        self,
        root_dir='objaverse/',
        meta_fname='valid_paths.json',
        image_dir='rendering_zero123plus',
        validation=False,
    ):
        self.root_dir = Path(root_dir)
        self.image_dir = image_dir

        with open(os.path.join(root_dir, meta_fname), encoding='utf-8') as f:
            lvis_dict = json.load(f)
        # Flatten the per-key path lists into one ordered list.
        paths = [p for group in lvis_dict.values() for p in group]

        # The last 16 objects are held out for validation.
        self.paths = paths[-16:] if validation else paths[:-16]
        print('============= length of dataset %d =============' % len(self.paths))

    def __len__(self):
        return len(self.paths)

    def load_im(self, path, color):
        """Load an RGBA image and alpha-composite it over a solid color.

        Args:
            path: Image file path. The image must decode to an (H, W, 4)
                array; other channel layouts make the compositing raise.
            color: RGB background as a length-3 sequence of floats.

        Returns:
            ``(image, alpha)`` float32 tensors of shape ``(3, H, W)`` and
            ``(1, H, W)``, scaled by 1/255 (so in [0, 1] for 8-bit images).
        """
        # Context manager guarantees the file handle is released.
        with Image.open(path) as pil_img:
            image = np.asarray(pil_img, dtype=np.float32) / 255.
        alpha = image[:, :, 3:]
        image = image[:, :, :3] * alpha + color * (1 - alpha)
        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def __getitem__(self, index):
        """Return the views of one object.

        View ``000.png`` becomes ``'cond_imgs'`` with shape (3, H, W) and views
        ``001.png``-``006.png`` become ``'target_imgs'`` with shape
        (6, 3, H, W), all composited over a white background. If any view
        fails to load, the error is printed and a random object is tried
        instead.
        """
        bkg_color = [1., 1., 1.]  # white background
        while True:
            image_path = os.path.join(self.root_dir, self.image_dir, self.paths[index])
            try:
                img_list = [
                    self.load_im(os.path.join(image_path, '%03d.png' % idx), bkg_color)[0]
                    for idx in range(7)
                ]
                break
            except Exception as e:
                # Best-effort: skip unreadable objects by resampling.
                # NOTE(review): this loops forever if no object can be loaded.
                print(e)
                index = np.random.randint(0, len(self.paths))

        imgs = torch.stack(img_list, dim=0)
        return {
            'cond_imgs': imgs[0],      # (3, H, W)
            'target_imgs': imgs[1:],   # (6, 3, H, W)
        }