Spaces: Running on Zero
import math | |
import os | |
import json | |
import re | |
import cv2 | |
from dataclasses import dataclass, field | |
import pytorch_lightning as pl | |
import torch | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader | |
from craftsman import register | |
from craftsman.utils.typing import * | |
from craftsman.utils.config import parse_structured | |
from .base import BaseDataModuleConfig, BaseDataset | |
@dataclass
class ObjaverseDataModuleConfig(BaseDataModuleConfig):
    """Structured config for ``ObjaverseDataModule``.

    Declares no fields of its own; every option is inherited from
    ``BaseDataModuleConfig``. It is decorated explicitly so the class is a
    dataclass in its own right (as ``parse_structured`` is handed this class)
    instead of relying only on the dataclass machinery inherited from the base.
    """
class ObjaverseDataset(BaseDataset):
    """Dataset for one Objaverse split.

    Defines nothing of its own: construction and every method (including the
    ``collate`` method the training loader passes to ``DataLoader``) are
    inherited unchanged from ``BaseDataset``.
    """
class ObjaverseDataModule(pl.LightningDataModule):
    """Lightning data module serving the Objaverse train/val/test splits.

    Every split is an ``ObjaverseDataset`` built from the same parsed config;
    only the split name (``"train"``, ``"val"`` or ``"test"``) differs.
    """

    # NOTE(review): `register` is imported at the top of this file but never
    # used. If this data module is meant to be discoverable through the
    # craftsman registry it probably needs a `@register(...)` decorator —
    # confirm the registry name before adding it.

    cfg: ObjaverseDataModuleConfig

    def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
        """Parse ``cfg`` into an ``ObjaverseDataModuleConfig``.

        Args:
            cfg: raw configuration (plain ``dict`` or ``DictConfig``) or
                ``None``; it is passed to ``parse_structured`` as-is.
        """
        super().__init__()
        self.cfg = parse_structured(ObjaverseDataModuleConfig, cfg)

    def setup(self, stage=None) -> None:
        """Build the dataset splits required by ``stage``.

        ``"fit"`` builds train and val, ``"validate"`` builds val,
        ``"test"`` and ``"predict"`` both build the test split, and ``None``
        builds all three.
        """
        if stage in [None, "fit"]:
            self.train_dataset = ObjaverseDataset(self.cfg, "train")
        if stage in [None, "fit", "validate"]:
            self.val_dataset = ObjaverseDataset(self.cfg, "val")
        if stage in [None, "test", "predict"]:
            self.test_dataset = ObjaverseDataset(self.cfg, "test")

    def prepare_data(self):
        """Nothing to download or preprocess; this hook is a no-op."""
        pass

    def general_loader(
        self, dataset, batch_size, collate_fn=None, num_workers=0, shuffle=False
    ) -> DataLoader:
        """Wrap ``dataset`` in a ``DataLoader``.

        Args:
            dataset: dataset to draw samples from.
            batch_size: number of samples per batch.
            collate_fn: batch collation function; ``None`` selects torch's
                default collate.
            num_workers: number of loader worker processes (0 loads in the
                main process).
            shuffle: reshuffle the samples every epoch. Defaults to ``False``
                so callers that do not pass it keep a fixed order.

        Returns:
            The configured ``DataLoader``.
        """
        return DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=collate_fn,
            num_workers=num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader:
        """Training loader using the configured batch size and worker count.

        Batches are assembled with the dataset's own ``collate`` method.
        """
        dataset = self.train_dataset
        # Shuffle so each epoch sees a fresh sample order. DataLoader raises a
        # ValueError if shuffle=True is combined with an IterableDataset, so
        # only request it for map-style datasets.
        shuffle = not isinstance(dataset, torch.utils.data.IterableDataset)
        return self.general_loader(
            dataset,
            batch_size=self.cfg.batch_size,
            collate_fn=dataset.collate,
            num_workers=self.cfg.num_workers,
            shuffle=shuffle,
        )

    def val_dataloader(self) -> DataLoader:
        """Validation loader: one sample per batch, unshuffled, main process."""
        # NOTE(review): unlike training, this uses torch's default collate
        # rather than `self.val_dataset.collate`; confirm that yields the batch
        # layout the validation step expects.
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self) -> DataLoader:
        """Test loader: one sample per batch over the test split."""
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self) -> DataLoader:
        """Prediction loader: reuses the test split, one sample per batch."""
        return self.general_loader(self.test_dataset, batch_size=1)