"""Base Lightning data module: lazy dataset construction and dataloader wiring."""
import pytorch_lightning as pl
from torch.utils.data import DataLoader
class BASEDataModule(pl.LightningDataModule):
    """Base data module with lazily-built datasets and shared dataloader wiring.

    Subclasses are expected to provide (not visible in this file — confirm):
      * ``self.Dataset`` / ``self.DatasetEval`` — dataset classes,
      * ``self.cfg`` — config object with TRAIN / EVAL / TEST sections,
      * ``self.hparams`` — mapping of kwargs forwarded to dataset constructors.
    """

    def __init__(self, collate_fn):
        """Store dataloader options; datasets are created lazily on first access.

        Args:
            collate_fn: batch collation callable passed to every DataLoader.
        """
        super().__init__()
        self.dataloader_options = {"collate_fn": collate_fn}
        self.persistent_workers = True
        # Multimodal-evaluation flag: when True, test/predict batch size is 1.
        self.is_mm = False
        # Caches for the lazily-built datasets (see the *_dataset properties).
        self._train_dataset = None
        self._val_dataset = None
        self._test_dataset = None

    def get_sample_set(self, overrides=None):
        """Build an evaluation dataset with ``overrides`` merged over hparams.

        ``None`` replaces the original mutable ``{}`` default argument;
        passing an explicit dict behaves exactly as before.
        """
        sample_params = self.hparams.copy()
        sample_params.update(overrides or {})
        return self.DatasetEval(**sample_params)

    @property
    def train_dataset(self):
        """Training dataset, built on first access and cached.

        Declared as a property: ``setup()`` and the dataloader methods
        access it as an attribute, not a call.
        """
        if self._train_dataset is None:
            self._train_dataset = self.Dataset(split=self.cfg.TRAIN.SPLIT,
                                               **self.hparams)
        return self._train_dataset

    @property
    def val_dataset(self):
        """Validation dataset, built on first access and cached."""
        if self._val_dataset is None:
            params = self.hparams.copy()
            # code_path is disabled for evaluation datasets.
            params['code_path'] = None
            params['split'] = self.cfg.EVAL.SPLIT
            self._val_dataset = self.DatasetEval(**params)
        return self._val_dataset

    @property
    def test_dataset(self):
        """Test dataset, built on first access and cached."""
        if self._test_dataset is None:
            params = self.hparams.copy()
            # code_path is disabled for evaluation datasets.
            params['code_path'] = None
            params['split'] = self.cfg.TEST.SPLIT
            self._test_dataset = self.DatasetEval(**params)
        return self._test_dataset

    def setup(self, stage=None):
        """Eagerly touch the lazy dataset properties for the given stage."""
        if stage in (None, "fit"):
            _ = self.train_dataset
            _ = self.val_dataset
        if stage in (None, "test"):
            _ = self.test_dataset

    def _make_dataloader(self, dataset, batch_size, num_workers, shuffle):
        """Build a DataLoader from the shared options plus per-stage settings."""
        options = self.dataloader_options.copy()
        options["batch_size"] = batch_size
        options["num_workers"] = num_workers
        options["shuffle"] = shuffle
        return DataLoader(
            dataset,
            # persistent_workers=True with num_workers=0 raises ValueError
            # in PyTorch, so only enable it when workers actually exist.
            persistent_workers=num_workers > 0,
            **options,
        )

    def _test_batch_size(self):
        """Batch size for test/predict; multimodal metrics need one sample per batch."""
        return 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE

    def train_dataloader(self):
        """DataLoader over the training split.

        NOTE(review): shuffle=False is preserved from the original code;
        training loaders usually shuffle — confirm this is intentional.
        """
        return self._make_dataloader(
            self.train_dataset,
            self.cfg.TRAIN.BATCH_SIZE,
            self.cfg.TRAIN.NUM_WORKERS,
            shuffle=False,
        )

    def predict_dataloader(self):
        """DataLoader over the test split for prediction (unshuffled)."""
        return self._make_dataloader(
            self.test_dataset,
            self._test_batch_size(),
            self.cfg.TEST.NUM_WORKERS,
            shuffle=False,
        )

    def val_dataloader(self):
        """DataLoader over the validation split (unshuffled)."""
        return self._make_dataloader(
            self.val_dataset,
            self.cfg.EVAL.BATCH_SIZE,
            self.cfg.EVAL.NUM_WORKERS,
            shuffle=False,
        )

    def test_dataloader(self):
        """DataLoader over the test split (unshuffled)."""
        return self._make_dataloader(
            self.test_dataset,
            self._test_batch_size(),
            self.cfg.TEST.NUM_WORKERS,
            shuffle=False,
        )