""" |
|
talking_video_dataset.py |
|
|
|
This module defines the TalkingVideoDataset class, a custom PyTorch dataset |
|
for handling talking video data. The dataset uses video files, masks, and |
|
embeddings to prepare data for tasks such as video generation and |
|
speech-driven video animation. |
|
|
|
Classes: |
|
TalkingVideoDataset |
|
|
|
Dependencies: |
|
json |
|
random |
|
torch |
|
decord.VideoReader, decord.cpu |
|
PIL.Image |
|
torch.utils.data.Dataset |
|
torchvision.transforms |
|
|
|
Example: |
|
from talking_video_dataset import TalkingVideoDataset |
|
from torch.utils.data import DataLoader |
|
|
|
# Example configuration for the Wav2Vec model |
|
class Wav2VecConfig: |
|
def __init__(self, audio_type, model_scale, features): |
|
self.audio_type = audio_type |
|
self.model_scale = model_scale |
|
self.features = features |
|
|
|
wav2vec_cfg = Wav2VecConfig(audio_type="wav2vec2", model_scale="base", features="feature") |
|
|
|
# Initialize dataset |
|
dataset = TalkingVideoDataset( |
|
img_size=(512, 512), |
|
sample_rate=16000, |
|
audio_margin=2, |
|
n_motion_frames=0, |
|
n_sample_frames=16, |
|
data_meta_paths=["path/to/meta1.json", "path/to/meta2.json"], |
|
wav2vec_cfg=wav2vec_cfg, |
|
) |
|
|
|
# Initialize dataloader |
|
dataloader = DataLoader(dataset, batch_size=4, shuffle=True) |
|
|
|
# Fetch one batch of data |
|
batch = next(iter(dataloader)) |
|
print(batch["pixel_values_vid"].shape) # Example output: (4, 16, 3, 512, 512) |
|
|
|
The TalkingVideoDataset class provides methods for loading video frames, masks, |
|
audio embeddings, and other relevant data, applying transformations, and preparing |
|
the data for training and evaluation in a deep learning pipeline. |
|
|
|
Attributes: |
|
img_size (tuple): The dimensions to resize the video frames to. |
|
sample_rate (int): The audio sample rate. |
|
audio_margin (int): The margin for audio sampling. |
|
n_motion_frames (int): The number of motion frames. |
|
n_sample_frames (int): The number of sample frames. |
|
data_meta_paths (list): List of paths to the JSON metadata files. |
|
wav2vec_cfg (object): Configuration for the Wav2Vec model. |
|
|
|
Methods: |
|
augmentation(images, transform, state=None): Apply transformation to input images. |
|
__getitem__(index): Get a sample from the dataset at the specified index. |
|
__len__(): Return the length of the dataset. |
|
""" |
|
|
|

import json
import random
from typing import List

import torch
from decord import VideoReader, cpu
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class TalkingVideoDataset(Dataset):
    """
    A dataset class for processing talking video data.

    Args:
        img_size (tuple, optional): The size of the output images. Defaults to (512, 512).
        sample_rate (int, optional): The sample rate of the audio data. Defaults to 16000.
        audio_margin (int, optional): The margin for the audio data. Defaults to 2.
        n_motion_frames (int, optional): The number of motion frames. Defaults to 0.
        n_sample_frames (int, optional): The number of sample frames. Defaults to 16.
        data_meta_paths (list, optional): The paths to the data metadata. Defaults to None.
        wav2vec_cfg (object, optional): The configuration for the wav2vec model, providing
            the attributes audio_type, model_scale, and features. Defaults to None.

    Attributes:
        img_size (tuple): The size of the output images.
        sample_rate (int): The sample rate of the audio data.
        audio_margin (int): The margin for the audio data.
        n_motion_frames (int): The number of motion frames.
        n_sample_frames (int): The number of sample frames.
        data_meta_paths (list): The paths to the data metadata.
        wav2vec_cfg (object): The configuration for the wav2vec model.
    """

    def __init__(
        self,
        img_size=(512, 512),
        sample_rate=16000,
        audio_margin=2,
        n_motion_frames=0,
        n_sample_frames=16,
        data_meta_paths=None,
        wav2vec_cfg=None,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.img_size = img_size
        self.audio_margin = audio_margin
        self.n_motion_frames = n_motion_frames
        self.n_sample_frames = n_sample_frames
        self.audio_type = wav2vec_cfg.audio_type
        self.audio_model = wav2vec_cfg.model_scale
        self.audio_features = wav2vec_cfg.features
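        # These wav2vec settings form the metadata key
        # f"{audio_type}_emb_{model_scale}_{features}" used in __getitem__ to
        # look up each video's precomputed audio-embedding file.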

        vid_meta = []
        for data_meta_path in data_meta_paths:
            with open(data_meta_path, "r", encoding="utf-8") as f:
                vid_meta.extend(json.load(f))
        self.vid_meta = vid_meta
        self.length = len(self.vid_meta)
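
        # Video/reference frames are normalized to [-1, 1]; the mask and
        # attention-mask transforms below keep values in [0, 1].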
        self.pixel_transform = transforms.Compose(
            [
                transforms.Resize(self.img_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        self.cond_transform = transforms.Compose(
            [
                transforms.Resize(self.img_size),
                transforms.ToTensor(),
            ]
        )
        self.attn_transform_64 = transforms.Compose(
            [
                transforms.Resize((64, 64)),
                transforms.ToTensor(),
            ]
        )
        self.attn_transform_32 = transforms.Compose(
            [
                transforms.Resize((32, 32)),
                transforms.ToTensor(),
            ]
        )
        self.attn_transform_16 = transforms.Compose(
            [
                transforms.Resize((16, 16)),
                transforms.ToTensor(),
            ]
        )
        self.attn_transform_8 = transforms.Compose(
            [
                transforms.Resize((8, 8)),
                transforms.ToTensor(),
            ]
        )

    def augmentation(self, images, transform, state=None):
        """
        Apply the given transformation to the input images.

        Args:
            images (List[PIL.Image] or PIL.Image): The input images to be transformed.
            transform (torchvision.transforms.Compose): The transformation to apply to the images.
            state (torch.ByteTensor, optional): The state of the random number generator.
                If provided, the RNG state is set to this value before the transformation
                is applied. Defaults to None.

        Returns:
            torch.Tensor: The transformed images as a tensor.
                If the input was a list of images, the tensor has shape (f, c, h, w),
                where f is the number of images, c the number of channels, h the height,
                and w the width. If the input was a single image, the tensor has shape
                (c, h, w).
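
        Example:
            # Illustrative sketch only: `dataset` is assumed to be an already
            # initialized TalkingVideoDataset instance.
            frames = [Image.new("RGB", (512, 512)) for _ in range(4)]
            out = dataset.augmentation(frames, dataset.pixel_transform)
            # With the default img_size, out.shape == torch.Size([4, 3, 512, 512])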
""" |
|
if state is not None: |
|
torch.set_rng_state(state) |
|
if isinstance(images, List): |
|
transformed_images = [transform(img) for img in images] |
|
ret_tensor = torch.stack(transformed_images, dim=0) |
|
else: |
|
ret_tensor = transform(images) |
|
return ret_tensor |
|
|
|
    def __getitem__(self, index):
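        """
        Assemble one training sample: the sampled video frames, a reference
        frame (plus optional motion frames), the target mask, multi-resolution
        face/lip/full masks, the face embedding, and the audio-embedding window
        aligned with the sampled frames.
        """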
        video_meta = self.vid_meta[index]
        video_path = video_meta["video_path"]
        mask_path = video_meta["mask_path"]
        lip_mask_union_path = video_meta.get("sep_mask_lip", None)
        face_mask_union_path = video_meta.get("sep_mask_face", None)
        full_mask_union_path = video_meta.get("sep_mask_border", None)
        face_emb_path = video_meta["face_emb_path"]
        audio_emb_path = video_meta[
            f"{self.audio_type}_emb_{self.audio_model}_{self.audio_features}"
        ]
        tgt_mask_pil = Image.open(mask_path)
        video_frames = VideoReader(video_path, ctx=cpu(0))
        assert tgt_mask_pil is not None, "Fail to load target mask."
        assert (video_frames is not None and len(video_frames) > 0), "Fail to load video frames."
        video_length = len(video_frames)

        assert (
            video_length
            > self.n_sample_frames + self.n_motion_frames + 2 * self.audio_margin
        ), "Video too short for the requested sample, motion, and audio-margin frames."
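
        # Pick a window start that leaves room for the motion frames before it
        # and for the audio margin after the sampled frames.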
        start_idx = random.randint(
            self.n_motion_frames,
            video_length - self.n_sample_frames - self.audio_margin - 1,
        )

        videos = video_frames[start_idx : start_idx + self.n_sample_frames]

        frame_list = [
            Image.fromarray(video).convert("RGB") for video in videos.asnumpy()
        ]

        face_masks_list = [Image.open(face_mask_union_path)] * self.n_sample_frames
        lip_masks_list = [Image.open(lip_mask_union_path)] * self.n_sample_frames
        full_masks_list = [Image.open(full_mask_union_path)] * self.n_sample_frames
        assert face_masks_list[0] is not None, "Fail to load face mask."
        assert lip_masks_list[0] is not None, "Fail to load lip mask."
        assert full_masks_list[0] is not None, "Fail to load full mask."

        face_emb = torch.load(face_emb_path)
        audio_emb = torch.load(audio_emb_path)
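
        # For every sampled frame, gather a window of 2 * audio_margin + 1
        # audio embeddings centered on that frame; audio_tensor therefore has
        # shape (n_sample_frames, 2 * audio_margin + 1, ...).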
        indices = torch.arange(2 * self.audio_margin + 1) - self.audio_margin
        center_indices = torch.arange(
            start_idx,
            start_idx + self.n_sample_frames,
        ).unsqueeze(1) + indices.unsqueeze(0)
        audio_tensor = audio_emb[center_indices]

        ref_img_idx = random.randint(
            self.n_motion_frames,
            video_length - self.n_sample_frames - self.audio_margin - 1,
        )
        ref_img = video_frames[ref_img_idx].asnumpy()
        ref_img = Image.fromarray(ref_img)
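
        # Optionally collect the motion frames that immediately precede the
        # sampled window.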
        if self.n_motion_frames > 0:
            motions = video_frames[start_idx - self.n_motion_frames : start_idx]
            motion_list = [
                Image.fromarray(motion).convert("RGB") for motion in motions.asnumpy()
            ]
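
        # Capture one RNG state and reuse it for every stream so that any
        # random behavior in the transforms is applied consistently.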
        state = torch.get_rng_state()
        pixel_values_vid = self.augmentation(frame_list, self.pixel_transform, state)

        pixel_values_mask = self.augmentation(tgt_mask_pil, self.cond_transform, state)
        pixel_values_mask = pixel_values_mask.repeat(3, 1, 1)
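
        # Build 64x64, 32x32, 16x16, and 8x8 versions of the face, lip, and
        # full masks, one entry per resolution in each list.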
        pixel_values_face_mask = [
            self.augmentation(face_masks_list, self.attn_transform_64, state),
            self.augmentation(face_masks_list, self.attn_transform_32, state),
            self.augmentation(face_masks_list, self.attn_transform_16, state),
            self.augmentation(face_masks_list, self.attn_transform_8, state),
        ]
        pixel_values_lip_mask = [
            self.augmentation(lip_masks_list, self.attn_transform_64, state),
            self.augmentation(lip_masks_list, self.attn_transform_32, state),
            self.augmentation(lip_masks_list, self.attn_transform_16, state),
            self.augmentation(lip_masks_list, self.attn_transform_8, state),
        ]
        pixel_values_full_mask = [
            self.augmentation(full_masks_list, self.attn_transform_64, state),
            self.augmentation(full_masks_list, self.attn_transform_32, state),
            self.augmentation(full_masks_list, self.attn_transform_16, state),
            self.augmentation(full_masks_list, self.attn_transform_8, state),
        ]
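
        # The reference frame gets the same pixel transform; when motion frames
        # are used they are appended to it along the frame dimension.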
        pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
        pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
        if self.n_motion_frames > 0:
            pixel_values_motion = self.augmentation(
                motion_list, self.pixel_transform, state
            )
            pixel_values_ref_img = torch.cat(
                [pixel_values_ref_img, pixel_values_motion], dim=0
            )

        sample = {
            "video_dir": video_path,
            "pixel_values_vid": pixel_values_vid,
            "pixel_values_mask": pixel_values_mask,
            "pixel_values_face_mask": pixel_values_face_mask,
            "pixel_values_lip_mask": pixel_values_lip_mask,
            "pixel_values_full_mask": pixel_values_full_mask,
            "audio_tensor": audio_tensor,
            "pixel_values_ref_img": pixel_values_ref_img,
            "face_emb": face_emb,
        }

        return sample

    def __len__(self):
        return len(self.vid_meta)