|
from typing import Literal, List |
|
|
|
import os |
|
import json |
|
import torch |
|
import torchvision.transforms as T |
|
from torch.utils.data import Dataset, DataLoader |
|
from PIL import Image |
|
from lightning import LightningDataModule |
|
|
|
|
|
class SwimDataset(Dataset):
    """Image dataset over one split of the SWIM data.

    The split directory ``<root_dir>/<split>`` must contain a ``labels.json``
    file (a list of records with at least ``name``, ``weather`` and
    ``timeofday`` keys) and an ``images/`` sub-directory holding the files
    referenced by each record's ``name``.

    Each item gets exactly one style label: its ``weather`` when that is not
    ``"clear"``, otherwise ``"night"`` for night-time images, otherwise
    ``"clear"``. Night images with non-clear weather are dropped at load time
    (presumably to avoid ambiguous two-style samples).
    """

    def __init__(
        self,
        root_dir: str = "./datasets/swim_data",
        split: Literal["train", "val"] = "train",
        img_size: int = 512,
        subsample_factor: int = 100,
    ):
        """Build the transform pipeline and load/filter the split's labels.

        Args:
            root_dir: Dataset root containing one sub-directory per split.
            split: Which split to load; only ``"train"`` and ``"val"`` exist.
            img_size: Side length of the square crop produced per image.
            subsample_factor: ``__len__`` reports ``len(records) // subsample_factor``
                items, so samplers only ever index that leading fraction of the
                split. Defaults to 100 (the previously hard-coded value); pass 1
                to use the whole split.

        Raises:
            ValueError: If ``split`` is not ``"train"``/``"val"`` or
                ``subsample_factor`` is smaller than 1.
        """
        super().__init__()
        # Fail fast: previously an unknown split left ``self.transform`` unset
        # and only blew up later inside ``__getitem__``.
        if split not in ("train", "val"):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")
        if subsample_factor < 1:
            raise ValueError(
                f"subsample_factor must be >= 1, got {subsample_factor!r}"
            )

        self.root_dir = root_dir
        self.split_dir = os.path.join(root_dir, split)
        self.img_size = img_size
        self.subsample_factor = subsample_factor

        # Train uses random crop + horizontal flip; val is deterministic.
        if split == "train":
            spatial = [
                T.Resize(img_size),
                T.RandomCrop(img_size),
                T.RandomHorizontalFlip(),
            ]
        else:
            spatial = [T.Resize(img_size), T.CenterCrop(img_size)]
        # Normalize with mean=std=0.5 maps ToTensor's [0, 1] range to [-1, 1].
        self.transform = T.Compose(
            spatial
            + [T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        )

        labels_path = os.path.join(self.split_dir, "labels.json")
        with open(labels_path, "r", encoding="utf-8") as f:
            records = json.load(f)

        self.data = [
            rec
            for rec in records
            if not (rec["timeofday"] == "night" and rec["weather"] != "clear")
        ]

    def __len__(self):
        """Return the number of indexable items: a ``1/subsample_factor`` prefix."""
        # NOTE(review): the default factor of 100 looks like a debugging
        # shortcut, and it yields 0 for splits with fewer than 100 records.
        return len(self.data) // self.subsample_factor

    def __getitem__(self, idx):
        """Load, transform and label the image at ``idx``.

        Returns:
            dict with ``"image"`` (the transformed tensor), ``"style"`` (a float
            one-hot vector over ``get_stylenames()``, all zeros for ``"clear"``)
            and ``"style_flag"`` (``True`` iff the style is not ``"clear"``).

        Raises:
            ValueError: If the record's style is not one of ``get_stylenames()``.
        """
        record = self.data[idx]

        img_path = os.path.join(self.split_dir, "images", record["name"])
        # The context manager guarantees the file handle is released;
        # convert() returns a fully loaded image independent of it.
        with Image.open(img_path) as handle:
            img = handle.convert("RGB")
        img = self.transform(img)

        style_name = self._style_name(record)
        style_flag = style_name != "clear"

        style_names = self.get_stylenames()
        style = torch.zeros(len(style_names))
        if style_flag:
            try:
                style[style_names.index(style_name)] = 1
            except ValueError:
                raise ValueError(
                    f"unknown style {style_name!r} for image {record['name']!r}; "
                    f"expected one of {style_names}"
                ) from None

        return {
            "image": img,
            "style": style,
            "style_flag": style_flag,
        }

    @staticmethod
    def _style_name(record) -> str:
        """Return the single style label for a label record (weather wins over night)."""
        if record["weather"] != "clear":
            return record["weather"]
        if record["timeofday"] == "night":
            return "night"
        return "clear"

    def get_stylenames(self) -> List[str]:
        """Return the non-clear style names, in one-hot index order."""
        return ["rain", "snow", "fog", "night"]
|
|
|
|
|
class SwimDataModule(LightningDataModule):
    """Lightning data module over the ``train``/``val`` splits of ``SwimDataset``.

    The ``val`` split doubles as the test split: ``test_dataloader`` iterates
    it with batch size 1.
    """

    def __init__(
        self,
        root_dir: str = "./datasets/swim_data",
        batch_size: int = 1,
        img_size: int = 512,
        num_workers: int = 4,
    ):
        """Store loader settings.

        Args:
            root_dir: Dataset root passed to ``SwimDataset``.
            batch_size: Batch size for the train and val loaders.
            img_size: Crop size passed to ``SwimDataset``.
            num_workers: Worker processes per ``DataLoader`` (previously
                hard-coded to 4).
        """
        super().__init__()
        self.root_dir = root_dir
        self.img_size = img_size
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Build the datasets required for ``stage``.

        ``"fit"`` and ``None`` build both splits. ``"validate"`` and ``"test"``
        need only the val split, which both ``val_dataloader`` and
        ``test_dataloader`` read; previously it was never built for those
        stages, so ``trainer.validate()``/``trainer.test()`` hit an
        ``AttributeError``.
        """
        if stage in ("fit", None):
            self.train_dataset = self._make_dataset("train")
        if stage in ("fit", "validate", "test", None):
            self.val_dataset = self._make_dataset("val")

    def train_dataloader(self):
        """Shuffled loader over the train split."""
        return self._make_loader(self.train_dataset, self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Deterministic loader over the val split."""
        return self._make_loader(self.val_dataset, self.batch_size, shuffle=False)

    def test_dataloader(self):
        """Deterministic, one-image-per-batch loader over the val split."""
        return self._make_loader(self.val_dataset, 1, shuffle=False)

    def _make_dataset(self, split):
        """Construct a ``SwimDataset`` for ``split`` with this module's settings."""
        return SwimDataset(
            root_dir=self.root_dir, split=split, img_size=self.img_size
        )

    def _make_loader(self, dataset, batch_size, shuffle):
        """Build a ``DataLoader`` that batches with ``custom_collate_fn``."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=self.custom_collate_fn,
        )

    @staticmethod
    def custom_collate_fn(batch):
        """Merge per-item dicts into one batch dict.

        ``images`` and ``styles`` are stacked along a new leading dimension;
        ``style_flags`` stays a plain Python list of bools.
        """
        images = torch.stack([item["image"] for item in batch])
        styles = torch.stack([item["style"] for item in batch])
        style_flags = [item["style_flag"] for item in batch]
        return {"images": images, "styles": styles, "style_flags": style_flags}
|
|