File size: 5,348 Bytes
5a510e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
# pylint: disable=R0801
"""
This module defines FaceMaskDataset, a PyTorch Dataset that loads training samples for face-mask
conditioned image models. Each sample pairs a randomly chosen reference frame and target frame
from the same video directory with that video's mask image and a face embedding loaded from disk.
The metadata describing the videos is read from one or more JSON files. The module relies on json,
random, pathlib, torch, PIL, torchvision, and transformers, and can be run as a script to print
the shape of one batch of target masks.
"""
import json
import random
from pathlib import Path
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import CLIPImageProcessor
class FaceMaskDataset(Dataset):
    """
    Dataset yielding a target frame, its mask, a reference frame and a face embedding.

    Every metadata entry describes one video through three keys:
    ``image_path`` (a directory whose entries are the video frames),
    ``mask_path`` (a single mask image file) and ``face_emb`` (a file read
    with ``torch.load``). For each item a reference frame and a target frame
    are drawn at random from the frame directory.

    Args:
        img_size (int or tuple): Size handed to ``transforms.Resize`` for every image.
        drop_ratio (float, optional): Stored on the instance; no method of this
            class reads it. Defaults to 0.1.
        data_meta_paths (list, optional): Paths of JSON files, each holding a list
            of metadata entries. A single str/Path is accepted as well.
            Defaults to ["./data/HDTF_meta.json"].
        sample_margin (int, optional): Preferred minimum index distance between
            the reference frame and the target frame. Defaults to 30.

    Attributes:
        img_size (int or tuple): The resize target.
        sample_margin (int): The preferred frame distance.
        drop_ratio (float): The stored drop ratio.
        vid_meta (list): Concatenated metadata entries from all metadata files.
        length (int): Number of metadata entries.
        clip_image_processor (CLIPImageProcessor): Created in ``__init__``; not
            used by any method of this class.
        transform (transforms.Compose): Resize -> ToTensor -> Normalize(0.5, 0.5),
            applied to reference and target frames.
        cond_transform (transforms.Compose): Resize -> ToTensor, applied to the mask.
    """

    def __init__(
        self,
        img_size,
        drop_ratio=0.1,
        data_meta_paths=None,
        sample_margin=30,
    ):
        super().__init__()
        self.img_size = img_size
        self.sample_margin = sample_margin

        # Apply the documented default instead of failing on ``for ... in None``.
        if data_meta_paths is None:
            data_meta_paths = ["./data/HDTF_meta.json"]
        elif isinstance(data_meta_paths, (str, Path)):
            # A lone path would otherwise be iterated character by character.
            data_meta_paths = [data_meta_paths]

        vid_meta = []
        for data_meta_path in data_meta_paths:
            with open(data_meta_path, "r", encoding="utf-8") as f:
                vid_meta.extend(json.load(f))
        self.vid_meta = vid_meta
        self.length = len(self.vid_meta)

        self.clip_image_processor = CLIPImageProcessor()

        # Frames: resize, convert to a tensor, then shift/scale with mean 0.5, std 0.5.
        self.transform = transforms.Compose(
            [
                transforms.Resize(self.img_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        # Mask: resize and convert to a tensor, without normalization.
        self.cond_transform = transforms.Compose(
            [
                transforms.Resize(self.img_size),
                transforms.ToTensor(),
            ]
        )

        self.drop_ratio = drop_ratio

    def augmentation(self, image, transform, state=None):
        """
        Apply ``transform`` to ``image``, optionally restoring the torch RNG state first.

        Restoring the same state before each call makes any torch-driven
        randomness inside ``transform`` repeat identically across calls.

        Args:
            image: The input passed to ``transform`` (a PIL image in this class).
            transform (callable): The transform to apply.
            state (torch.Tensor, optional): RNG state as returned by
                ``torch.get_rng_state()``. Defaults to None (RNG left untouched).

        Returns:
            Whatever ``transform(image)`` returns (a tensor for the transforms
            built in ``__init__``).
        """
        if state is not None:
            torch.set_rng_state(state)
        return transform(image)

    def _pick_frame_indices(self, video_length):
        """Draw (reference index, target index) for a video with ``video_length`` frames.

        The target is taken at least ``margin`` frames after the reference when
        that fits, otherwise at least ``margin`` frames before it; when neither
        side has room, any frame may be chosen.
        """
        margin = min(self.sample_margin, video_length)
        ref_img_idx = random.randint(0, video_length - 1)
        if ref_img_idx + margin < video_length:
            tgt_img_idx = random.randint(ref_img_idx + margin, video_length - 1)
        elif ref_img_idx - margin >= 0:
            # ``>= 0``: frame 0 is a valid target when it lies exactly ``margin`` away.
            tgt_img_idx = random.randint(0, ref_img_idx - margin)
        else:
            tgt_img_idx = random.randint(0, video_length - 1)
        return ref_img_idx, tgt_img_idx

    def __getitem__(self, index):
        """
        Build the sample for metadata entry ``index``.

        Returns:
            dict: ``video_dir`` (frame directory from the metadata), ``img``
            (transformed target frame), ``tgt_mask`` (transformed mask repeated
            three times along dim 0), ``ref_img`` (transformed reference frame)
            and ``face_emb`` (object loaded from the ``face_emb`` file).

        Raises:
            ValueError: If the frame directory contains no entries.
        """
        video_meta = self.vid_meta[index]
        video_path = video_meta["image_path"]
        mask_path = video_meta["mask_path"]
        face_emb_path = video_meta["face_emb"]

        # Every directory entry counts as a frame; sorting fixes the frame order.
        video_frames = sorted(Path(video_path).iterdir())
        if not video_frames:
            raise ValueError(
                f"No frames found in video directory: {video_path}")

        ref_img_idx, tgt_img_idx = self._pick_frame_indices(len(video_frames))

        # Image.open raises on failure and never returns None, so no None checks.
        ref_img_pil = Image.open(video_frames[ref_img_idx])
        tgt_img_pil = Image.open(video_frames[tgt_img_idx])
        tgt_mask_pil = Image.open(mask_path)

        # Reuse one RNG state so all three transforms see identical randomness.
        state = torch.get_rng_state()
        tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
        tgt_mask_img = self.augmentation(
            tgt_mask_pil, self.cond_transform, state)
        # Assumes the mask file is single-channel so this yields 3 channels - TODO confirm.
        tgt_mask_img = tgt_mask_img.repeat(3, 1, 1)
        ref_img_vae = self.augmentation(
            ref_img_pil, self.transform, state)

        # NOTE(security): torch.load unpickles the file; only load trusted embeddings.
        face_emb = torch.load(face_emb_path)

        sample = {
            "video_dir": video_path,
            "img": tgt_img,
            "tgt_mask": tgt_mask_img,
            "ref_img": ref_img_vae,
            "face_emb": face_emb,
        }
        return sample

    def __len__(self):
        """Return the number of metadata entries."""
        return len(self.vid_meta)
if __name__ == "__main__":
    # Smoke test: build the dataset, pull one batch and print the mask shape.
    # The metadata path is passed explicitly so the script does not depend on
    # the constructor's handling of an omitted ``data_meta_paths``.
    data = FaceMaskDataset(
        img_size=(512, 512),
        data_meta_paths=["./data/HDTF_meta.json"],
    )
    train_dataloader = torch.utils.data.DataLoader(
        data, batch_size=4, shuffle=True, num_workers=1
    )
    # Only the first batch is inspected.
    for batch in train_dataloader:
        print(batch["tgt_mask"].shape)
        break
|