# RAVE / utils / video_grid_utils.py
# ozgurkara's picture
# first commit
# eb9a9b4
import os
import cv2 as cv
import numpy as np
import torch
import imageio
import glob
from torchvision.utils import make_grid
from torchvision.transforms import transforms
from torchvision.transforms.functional import to_pil_image
def prepare_video_to_grid(path, grid_count, grid_size, pad):
    """Sample frames from a video and tile them into square grid images.

    Frames whose index is a multiple of ``pad`` are kept. Every
    ``grid_size**2`` kept frames are tiled with ``make_grid`` (``grid_size``
    frames per row, no padding) and the resulting grid is resized so that its
    area is roughly ``512 * 512`` pixels per cell.

    Args:
        path: Path of the video file, handed to ``cv.VideoCapture``.
        grid_count: Number of grids worth of frames to read, or ``-1`` to
            read every frame reported by the container.
        grid_size: Number of frames per grid row and per grid column.
        pad: Sampling stride between kept frames.

    Returns:
        A list of PIL images, one per completed grid. Kept frames left over
        at the end that do not fill a whole grid are discarded.

    Raises:
        AssertionError: If a frame cannot be read before the expected number
            of frames has been consumed.
    """
    video = cv.VideoCapture(path)
    try:
        available_frames = int(video.get(cv.CAP_PROP_FRAME_COUNT))
        if grid_count == -1:
            frame_count = available_frames
        else:
            frame_count = min(grid_count * pad * grid_size**2, available_frames)

        transform = transforms.Compose([
            transforms.ConvertImageDtype(dtype=torch.float),
        ])
        max_grid_area = 512 * 512 * grid_size**2
        total_grid = grid_size**2

        grids = []
        frames = []
        for idx in range(frame_count):
            success, image = video.read()
            if not success:
                # Raised explicitly instead of via ``assert`` so the check is
                # not stripped under ``python -O``; the exception type is kept
                # for callers that already rely on it.
                raise AssertionError('Video read failed')

            if idx % pad == 0:
                # OpenCV decodes to BGR in HWC layout; convert to RGB and CHW
                # before turning the frame into a float tensor.
                rgb_img = cv.cvtColor(image, cv.COLOR_BGR2RGB)
                rgb_img = np.transpose(rgb_img, (2, 0, 1))
                frames.append(transform(torch.from_numpy(rgb_img)))

            if len(frames) == total_grid:
                grid = make_grid(frames, nrow=grid_size, padding=0)
                pil_image = to_pil_image(grid)
                w, h = pil_image.size
                # Uniform scale factor that brings the grid area down (or up)
                # to max_grid_area while keeping the aspect ratio.
                a = float(np.sqrt(w * h / max_grid_area))
                # Round each side down to a multiple of grid_size * 8 so that
                # every cell's side is a multiple of 8 pixels.
                w1 = int((w // a) // (grid_size * 8)) * grid_size * 8
                h1 = int((h // a) // (grid_size * 8)) * grid_size * 8
                grids.append(pil_image.resize((w1, h1)))
                frames = []
    finally:
        # Always free the decoder / file handle, including on a failed read.
        video.release()

    return grids  # list of PIL grid images
def prepare_video_to_frames(path, grid_count, grid_size, pad, format='gif'):
    """Dump sampled video frames as PNG files and re-encode them as a clip.

    Frames whose index is a multiple of ``pad`` are kept, and the kept frames
    are truncated to a multiple of ``grid_size**2``. Each remaining frame is
    written to ``<dir>/frames/<video_name>/NNNNN.png`` and the same frames are
    encoded to ``<dir>/video/<video_name>_fc<N>_pad<pad>_grid<grid_size>.<ext>``,
    where ``<dir>`` is the directory containing ``path``.

    Args:
        path: Path of the video file, handed to ``cv.VideoCapture``.
        grid_count: Number of grids worth of frames to read, or ``-1`` to
            read every frame reported by the container.
        grid_size: Grid side length; the number of exported frames is a
            multiple of ``grid_size**2``.
        pad: Sampling stride between kept frames.
        format: ``'gif'`` or ``'mp4'``. Any other value writes the PNG frames
            only and produces no clip.

    Returns:
        The number of frames written.

    Raises:
        AssertionError: If a frame cannot be read before the expected number
            of frames has been consumed.
    """
    dir_path = os.path.dirname(path)
    # basename() matches the dirname() call above and also copes with
    # platform-specific separators; everything after the first dot is dropped.
    video_name = os.path.basename(path).split('.')[0]
    frames_dir = os.path.join(dir_path, 'frames/', video_name)
    os.makedirs(frames_dir, exist_ok=True)
    # Also guarantees that the parent 'video/' directory exists, which is
    # where the encoded clip is written below.
    os.makedirs(os.path.join(dir_path, 'video/', video_name), exist_ok=True)

    video = cv.VideoCapture(path)
    try:
        available_frames = int(video.get(cv.CAP_PROP_FRAME_COUNT))
        if grid_count == -1:
            frame_count = available_frames
        else:
            frame_count = min(grid_count * pad * grid_size**2, available_frames)

        frames = []
        for idx in range(frame_count):
            success, image = video.read()
            if not success:
                # Raised explicitly instead of via ``assert`` so the check is
                # not stripped under ``python -O``; the exception type is kept
                # for callers that already rely on it.
                raise AssertionError('Video read failed')
            if idx % pad == 0:
                frames.append(image)
    finally:
        # Always free the decoder / file handle, including on a failed read.
        video.release()

    # Keep only as many frames as fill whole grids.
    usable = len(frames) // (grid_size**2) * (grid_size**2)
    frames_grid = frames[:usable]
    for position, frame in enumerate(frames_grid):
        cv.imwrite(os.path.join(frames_dir, f'{str(position).zfill(5)}.png'), frame)
    frame_idx = len(frames_grid)

    clip_stem = f'{video_name}_fc{frame_idx}_pad{pad}_grid{grid_size}'
    if format == 'gif':
        with imageio.get_writer(os.path.join(dir_path, 'video/', f'{clip_stem}.gif'), mode='I') as writer:
            for frame in frames_grid:
                # OpenCV frames are BGR; imageio expects RGB.
                writer.append_data(cv.cvtColor(frame, cv.COLOR_BGR2RGB))
    elif format == 'mp4':
        # Encode the frames held in memory rather than globbing the frames
        # directory: PNGs left there by an earlier, longer run would otherwise
        # be appended to this clip.
        images = [cv.cvtColor(frame, cv.COLOR_BGR2RGB) for frame in frames_grid]
        save_file_path = os.path.join(dir_path, 'video/', f'{clip_stem}.mp4')
        imageio.mimsave(save_file_path, images, fps=20)

    return frame_idx  # number of frames