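"""Extract Causal Video VAE latents for a list of videos.

Reads a JSONL annotation file, decodes and preprocesses the listed videos,
encodes them with CausalVideoVAELossWrapper, and saves each latent tensor to
the output path given in the annotation. Intended to run with one process per
GPU; distributed setup is handled by init_distributed_mode.
"""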
import os
import sys
sys.path.append(os.path.abspath('.'))

import argparse
import datetime
import numpy as np
import time
import torch
import io
import json
import jsonlines

import cv2
import math
import random
from pathlib import Path
from tqdm import tqdm

from concurrent import futures
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from collections import OrderedDict
from torchvision import transforms as pth_transforms
from torchvision.transforms.functional import InterpolationMode

from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from trainer_misc import init_distributed_mode
from video_vae import CausalVideoVAELossWrapper


def get_transform(width, height, new_width=None, new_height=None, resize=False):
    transform_list = []

    if resize:
        # Scale so the frame fully covers the target size, then center-crop
        # to exactly (new_height, new_width).
        scale = max(new_width / width, new_height / height)
        resized_width = round(width * scale)
        resized_height = round(height * scale)

        transform_list.append(pth_transforms.Resize((resized_height, resized_width), InterpolationMode.BICUBIC, antialias=True))
        transform_list.append(pth_transforms.CenterCrop((new_height, new_width)))

    transform_list.extend([
        # Map pixel values from [0, 1] to [-1, 1].
        pth_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    transform_list = pth_transforms.Compose(transform_list)

    return transform_list


def load_video_and_transform(video_path, frame_indexs, frame_number, new_width=None, new_height=None, resize=False):
    video_capture = None
    frame_indexs_set = set(frame_indexs)

    try:
        video_capture = cv2.VideoCapture(video_path)
        frames = []
        frame_index = 0
        while True:
            flag, frame = video_capture.read()
            if not flag:
                break
            if frame_index > frame_indexs[-1]:
                break
            if frame_index in frame_indexs_set:
                # OpenCV decodes frames as BGR; convert to RGB and to (C, H, W).
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = torch.from_numpy(frame)
                frame = frame.permute(2, 0, 1)
                frames.append(frame)
            frame_index += 1

        video_capture.release()

        if len(frames) == 0:
            print(f"Empty video {video_path}")
            return None

        # Keep at most frame_number frames, then truncate to 8k + 1 frames.
        frames = frames[:frame_number]
        duration = ((len(frames) - 1) // 8) * 8 + 1
        frames = frames[:duration]
        frames = torch.stack(frames).float() / 255
        video_transform = get_transform(frames.shape[-1], frames.shape[-2], new_width, new_height, resize=resize)
        # (T, C, H, W) -> (C, T, H, W)
        frames = video_transform(frames).permute(1, 0, 2, 3)
        return frames

    except Exception as e:
        print(f"Loading video {video_path} failed with exception: {e}")
        if video_capture is not None:
            video_capture.release()
        return None


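# Each line of --anno_file is a JSON object with a 'video' path, a 'latent'
# output path, and optionally a 'frames' list of frame indices to sample.
# Illustrative example (hypothetical paths):
# {"video": "/data/clip.mp4", "latent": "/data/clip_latent.pt", "frames": [0, 1, 2]}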
class VideoDataset(Dataset):
    def __init__(self, anno_file, width, height, num_frames):
        super().__init__()
        self.annotation = []
        self.width = width
        self.height = height
        self.num_frames = num_frames

        with jsonlines.open(anno_file, 'r') as reader:
            for item in tqdm(reader):
                self.annotation.append(item)

        tot_len = len(self.annotation)
        print(f"Loaded {tot_len} videos in total")

    def process_one_video(self, video_item):
        videos_per_task = []
        video_path = video_item['video']
        output_latent_path = video_item['latent']

        # Use the frame indices from the annotation if given, otherwise the
        # first num_frames frames.
        frame_indexs = video_item['frames'] if 'frames' in video_item else list(range(self.num_frames))

        try:
            video_frames_tensors = load_video_and_transform(
                video_path,
                frame_indexs,
                frame_number=self.num_frames,
                new_width=self.width,
                new_height=self.height,
                resize=True
            )

            if video_frames_tensors is None:
                return videos_per_task

            # Add a batch dimension: (C, T, H, W) -> (1, C, T, H, W).
            video_frames_tensors = video_frames_tensors.unsqueeze(0)
            videos_per_task.append({'video': video_path, 'input': video_frames_tensors, 'output': output_latent_path})

        except Exception as e:
            print(f"Load video tensor ERROR: {e}")

        return videos_per_task

    def __getitem__(self, index):
        try:
            video_item = self.annotation[index]
            videos_per_task = self.process_one_video(video_item)
        except Exception as e:
            print(f"Error with {e}")
            videos_per_task = []

        return videos_per_task

    def __len__(self):
        return len(self.annotation)


def get_args():
    parser = argparse.ArgumentParser('PyTorch multi-process video VAE latent extraction script', add_help=False)
    parser.add_argument('--batch_size', default=4, type=int)
    parser.add_argument('--model_path', default='', type=str, help='The pre-trained weight path')
    parser.add_argument('--model_dtype', default='bf16', type=str, help="The model dtype: bf16 or fp16")
    parser.add_argument('--anno_file', type=str, default='', help="The video annotation file")
    parser.add_argument('--width', type=int, default=640, help="The video width")
    parser.add_argument('--height', type=int, default=384, help="The video height")
    parser.add_argument('--num_frames', type=int, default=121, help="The number of frames to encode")
    parser.add_argument('--save_memory', action='store_true', help="Enable VAE tiling to reduce memory usage")
    return parser.parse_args()


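# Example invocation for the flags defined above (assuming init_distributed_mode
# reads the usual torchrun environment variables; script name and paths are
# placeholders):
#   torchrun --nproc_per_node=8 extract_video_vae_latents.py \
#       --model_path /path/to/vae_checkpoint --anno_file /path/to/videos.jsonl \
#       --width 640 --height 384 --num_frames 121 --save_memory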
def build_model(args):
    model_path = args.model_path
    model_dtype = args.model_dtype
    model = CausalVideoVAELossWrapper(model_path, model_dtype=model_dtype, interpolate=False, add_discriminator=False)
    model = model.eval()
    return model


def build_data_loader(args):

    def collate_fn(batch):
        # Flatten the per-sample lists produced by VideoDataset.__getitem__
        # into parallel lists of input tensors and output latent paths.
        return_batch = {'input': [], 'output': []}
        for videos_ in batch:
            for video_input in videos_:
                return_batch['input'].append(video_input['input'])
                return_batch['output'].append(video_input['output'])
        return return_batch

    dataset = VideoDataset(args.anno_file, args.width, args.height, args.num_frames)
    sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=args.rank, shuffle=False)
    loader = DataLoader(
        dataset, batch_size=args.batch_size, num_workers=6, pin_memory=True,
        sampler=sampler, shuffle=False, collate_fn=collate_fn, drop_last=False, prefetch_factor=2,
    )
    return loader


def save_tensor(tensor, output_path):
    try:
        torch.save(tensor.clone(), output_path)
    except Exception as e:
        print(f"Saving latent to {output_path} failed with exception: {e}")


def main():
    args = get_args()

    init_distributed_mode(args)

    device = torch.device('cuda')
    rank = args.rank

    model = build_model(args)
    model.to(device)

    if args.model_dtype == "bf16":
        torch_dtype = torch.bfloat16
    elif args.model_dtype == "fp16":
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    data_loader = build_data_loader(args)
    torch.distributed.barrier()

    window_size = 16
    temporal_chunk = True
    task_queue = []

    if args.save_memory:
        model.vae.enable_tiling()

    # Encode latents on the GPU and hand saving off to a thread pool so disk
    # writes overlap with encoding.
    with futures.ThreadPoolExecutor(max_workers=16) as executor:
        for sample in tqdm(data_loader):
            input_video_list = sample['input']
            output_path_list = sample['output']

            with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
                for video_input, output_path in zip(input_video_list, output_path_list):
                    video_latent = model.encode_latent(video_input.to(device), sample=True, window_size=window_size, temporal_chunk=temporal_chunk, tile_sample_min_size=256)
                    video_latent = video_latent.to(torch_dtype).cpu()
                    task_queue.append(executor.submit(save_tensor, video_latent, output_path))

        # Wait for all pending saves and surface any exceptions they raised.
        for future in futures.as_completed(task_queue):
            future.result()

    torch.distributed.barrier()


if __name__ == "__main__":
    main()