# swim_new/swim/data/swim_data.py
from typing import Literal, List
import os
import json
import torch
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from lightning import LightningDataModule


class SwimDataset(Dataset):
    """Images from the swim_data splits with weather / time-of-day style labels.

    Loads images and labels.json from <root_dir>/<split>, applies resize/crop
    transforms, and drops samples that are both at night and in adverse weather.
    """

def __init__(
self,
root_dir: str = "./datasets/swim_data",
split: Literal["train", "val"] = "train",
img_size: int = 512,
):
super().__init__()
self.root_dir = root_dir
self.split_dir = os.path.join(root_dir, split)
self.img_size = img_size
if split == "train":
self.transform = T.Compose(
[
T.Resize(img_size), # smaller edge of image resized to img_size
T.RandomCrop(img_size), # get a random crop of img_size x img_size
T.RandomHorizontalFlip(),
T.ToTensor(),
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
        elif split == "val":
            self.transform = T.Compose(
                [
                    T.Resize(img_size),
                    T.CenterCrop(img_size),
                    T.ToTensor(),
                    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]
            )
        else:
            raise ValueError(f"unsupported split: {split!r}")
with open(os.path.join(self.split_dir, "labels.json"), "r") as f:
self.data = json.load(f)
# filter out images that are both at night and have adverse weather conditions
self.data = [
img
for img in self.data
if not (img["timeofday"] == "night" and img["weather"] != "clear")
]
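
    # labels.json is expected to hold a list of records with at least the fields
    # used in this class; the example values below are illustrative assumptions:
    #   {"name": "0001.jpg", "weather": "rain", "timeofday": "daytime"}
    # "weather" is either "clear" or one of the names in get_stylenames(), and
    # "timeofday" is compared against "night".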

    def __len__(self):
        # NOTE: only 1/100 of the filtered samples are exposed here; this looks like
        # a debugging shortcut, use len(self.data) to iterate over the full split
        return len(self.data) // 100

    def __getitem__(self, idx):
data = self.data[idx]
# load image
img_path = os.path.join(self.split_dir, "images", data["name"])
img = Image.open(img_path).convert("RGB")
img = self.transform(img)
# load style
if data["weather"] != "clear":
style_name = data["weather"]
elif data["timeofday"] == "night":
style_name = "night"
else:
style_name = "clear"
        # True if the image carries any style (i.e. it is not a clear daytime image)
        style_flag = style_name != "clear"
        # one-hot encode the style; the 4 slots correspond to get_stylenames()
        style = torch.zeros(4)
        if style_flag:
            style[self.get_stylenames().index(style_name)] = 1
return {
"image": img,
"style": style,
"style_flag": style_flag,
}

    def get_stylenames(self) -> List[str]:
        return ["rain", "snow", "fog", "night"]


class SwimDataModule(LightningDataModule):
    """LightningDataModule that wraps the train and val SwimDataset splits."""

def __init__(
self,
root_dir: str = "./datasets/swim_data",
batch_size: int = 1,
img_size: int = 512,
):
super().__init__()
self.root_dir = root_dir
self.img_size = img_size
self.batch_size = batch_size

    def setup(self, stage=None):
if stage == "fit" or stage is None:
self.train_dataset = SwimDataset(
root_dir=self.root_dir, split="train", img_size=self.img_size
)
self.val_dataset = SwimDataset(
root_dir=self.root_dir, split="val", img_size=self.img_size
)

    def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=4,
collate_fn=self.custom_collate_fn,
)

    def val_dataloader(self):
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=4,
collate_fn=self.custom_collate_fn,
)

    def test_dataloader(self):
return DataLoader(
self.val_dataset,
batch_size=1,
shuffle=False,
num_workers=4,
collate_fn=self.custom_collate_fn,
)

    @staticmethod
def custom_collate_fn(batch):
images = torch.stack([item["image"] for item in batch])
styles = torch.stack([item["style"] for item in batch])
style_flags = [item["style_flag"] for item in batch]
return {"images": images, "styles": styles, "style_flags": style_flags}