TheNetherWatcher's picture
Upload folder using huggingface_hub
d0ffe9c verified
import glob
import logging
import os
from pathlib import Path
import cv2
import numpy as np
import torch
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from PIL import Image
from segment_anything_hq import (SamPredictor, build_sam_vit_b,
build_sam_vit_h, build_sam_vit_l)
from segment_anything_hq.build_sam import build_sam_vit_t
from tqdm.rich import tqdm
logger = logging.getLogger(__name__)
build_sam_table={
"sam_hq_vit_l":build_sam_vit_l,
"sam_hq_vit_h":build_sam_vit_h,
"sam_hq_vit_b":build_sam_vit_b,
"sam_hq_vit_tiny":build_sam_vit_t,
}
# adapted from https://github.com/IDEA-Research/Grounded-Segment-Anything/blob/main/grounded_sam_demo.py
class MaskPredictor:
def __init__(self,model_config_path, model_checkpoint_path,device, sam_checkpoint, box_threshold=0.3, text_threshold=0.25 ):
self.groundingdino_model = None
self.sam_predictor = None
self.model_config_path = model_config_path
self.model_checkpoint_path = model_checkpoint_path
self.device = device
self.sam_checkpoint = sam_checkpoint
self.box_threshold = box_threshold
self.text_threshold = text_threshold
def load_groundingdino_model(self):
args = SLConfig.fromfile(self.model_config_path)
args.device = self.device
model = build_model(args)
checkpoint = torch.load(self.model_checkpoint_path, map_location="cpu")
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
#print(load_res)
_ = model.eval()
self.groundingdino_model = model
def load_sam_predictor(self):
s = Path(self.sam_checkpoint)
self.sam_predictor = SamPredictor(build_sam_table[ s.stem ](checkpoint=self.sam_checkpoint).to(self.device))
def transform_image(self,image_pil):
import groundingdino.datasets.transforms as T
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
image, _ = transform(image_pil, None) # 3, h, w
return image
def get_grounding_output(self, image, caption, with_logits=True):
model = self.groundingdino_model
device = self.device
caption = caption.lower()
caption = caption.strip()
if not caption.endswith("."):
caption = caption + "."
model = model.to(device)
image = image.to(device)
with torch.no_grad():
outputs = model(image[None], captions=[caption])
logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
logits.shape[0]
# filter output
logits_filt = logits.clone()
boxes_filt = boxes.clone()
filt_mask = logits_filt.max(dim=1)[0] > self.box_threshold
logits_filt = logits_filt[filt_mask] # num_filt, 256
boxes_filt = boxes_filt[filt_mask] # num_filt, 4
logits_filt.shape[0]
# get phrase
tokenlizer = model.tokenizer
tokenized = tokenlizer(caption)
# build pred
pred_phrases = []
for logit, box in zip(logits_filt, boxes_filt):
pred_phrase = get_phrases_from_posmap(logit > self.text_threshold, tokenized, tokenlizer)
if with_logits:
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
else:
pred_phrases.append(pred_phrase)
return boxes_filt, pred_phrases
def __call__(self, image_pil:Image, text_prompt):
if self.groundingdino_model is None:
self.load_groundingdino_model()
self.load_sam_predictor()
transformed_img = self.transform_image(image_pil)
# run grounding dino model
boxes_filt, pred_phrases = self.get_grounding_output(
transformed_img, text_prompt
)
if boxes_filt.shape[0] == 0:
logger.info(f"object not found")
w, h = image_pil.size
return np.zeros(shape=(1,h,w), dtype=bool)
img_array = np.array(image_pil)
self.sam_predictor.set_image(img_array)
size = image_pil.size
H, W = size[1], size[0]
for i in range(boxes_filt.size(0)):
boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
boxes_filt[i][2:] += boxes_filt[i][:2]
boxes_filt = boxes_filt.cpu()
transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(boxes_filt, img_array.shape[:2]).to(self.device)
masks, _, _ = self.sam_predictor.predict_torch(
point_coords = None,
point_labels = None,
boxes = transformed_boxes.to(self.device),
multimask_output = False,
)
result = None
for m in masks:
if result is None:
result = m
else:
result |= m
result = result.cpu().detach().numpy().copy()
return result
def load_mask_list(mask_dir, masked_area_list, mask_padding):
mask_frame_list = sorted(glob.glob( os.path.join(mask_dir, "[0-9]*.png"), recursive=False))
kernel = np.ones((abs(mask_padding),abs(mask_padding)),np.uint8)
for m in mask_frame_list:
cur = int(Path(m).stem)
tmp = np.asarray(Image.open(m))
if mask_padding < 0:
tmp = cv2.erode(tmp, kernel,iterations = 1)
elif mask_padding > 0:
tmp = cv2.dilate(tmp, kernel,iterations = 1)
masked_area_list[cur] = tmp[None,...]
return masked_area_list
def crop_mask_list(mask_list):
area_list = []
max_h = 0
max_w = 0
for m in mask_list:
if m is None:
area_list.append(None)
continue
m = m > 127
area = np.where(m[0] == True)
if area[0].size == 0:
area_list.append(None)
continue
ymin = min(area[0])
ymax = max(area[0])
xmin = min(area[1])
xmax = max(area[1])
h = ymax+1 - ymin
w = xmax+1 - xmin
max_h = max(max_h, h)
max_w = max(max_w, w)
area_list.append( (ymin, ymax, xmin, xmax) )
#crop = m[ymin:ymax+1,xmin:xmax+1]
logger.info(f"{max_h=}")
logger.info(f"{max_w=}")
border_h = mask_list[0].shape[1]
border_w = mask_list[0].shape[2]
mask_pos_list=[]
cropped_mask_list=[]
for a, m in zip(area_list, mask_list):
if m is None or a is None:
mask_pos_list.append(None)
cropped_mask_list.append(None)
continue
ymin,ymax,xmin,xmax = a
h = ymax+1 - ymin
w = xmax+1 - xmin
# H
diff_h = max_h - h
dh1 = diff_h//2
dh2 = diff_h - dh1
y1 = ymin - dh1
y2 = ymax + dh2
if y1 < 0:
y1 = 0
y2 = max_h-1
elif y2 >= border_h:
y1 = (border_h-1) - (max_h - 1)
y2 = (border_h-1)
# W
diff_w = max_w - w
dw1 = diff_w//2
dw2 = diff_w - dw1
x1 = xmin - dw1
x2 = xmax + dw2
if x1 < 0:
x1 = 0
x2 = max_w-1
elif x2 >= border_w:
x1 = (border_w-1) - (max_w - 1)
x2 = (border_w-1)
mask_pos_list.append( (int(x1),int(y1)) )
m = m[0][y1:y2+1,x1:x2+1]
cropped_mask_list.append( m[None,...] )
return cropped_mask_list, mask_pos_list, (max_h,max_w)
def crop_frames(pos_list, crop_size_hw, frame_dir):
h,w = crop_size_hw
for i,pos in tqdm(enumerate(pos_list),total=len(pos_list)):
filename = f"{i:08d}.png"
frame_path = frame_dir / filename
if not frame_path.is_file():
logger.info(f"{frame_path=} not found. skip")
continue
if pos is None:
continue
x, y = pos
tmp = np.asarray(Image.open(frame_path))
tmp = tmp[y:y+h,x:x+w,...]
Image.fromarray(tmp).save(frame_path)
def save_crop_info(mask_pos_list, crop_size_hw, frame_size_hw, save_path):
import json
pos_map = {}
for i, pos in enumerate(mask_pos_list):
if pos is not None:
pos_map[str(i)]=pos
info = {
"frame_height" : int(frame_size_hw[0]),
"frame_width" : int(frame_size_hw[1]),
"height": int(crop_size_hw[0]),
"width": int(crop_size_hw[1]),
"pos_map" : pos_map,
}
with open(save_path, mode="wt", encoding="utf-8") as f:
json.dump(info, f, ensure_ascii=False, indent=4)
def restore_position(mask_list, crop_info):
f_h = crop_info["frame_height"]
f_w = crop_info["frame_width"]
h = crop_info["height"]
w = crop_info["width"]
pos_map = crop_info["pos_map"]
for i in pos_map:
x,y = pos_map[i]
i = int(i)
m = mask_list[i]
if m is None:
continue
m = cv2.resize( m, (w,h) )
if len(m.shape) == 2:
m = m[...,None]
frame = np.zeros(shape=(f_h,f_w,m.shape[2]), dtype=np.uint8)
frame[y:y+h,x:x+w,...] = m
mask_list[i] = frame
return mask_list
def load_frame_list(frame_dir, frame_array_list, crop_info):
frame_list = sorted(glob.glob( os.path.join(frame_dir, "[0-9]*.png"), recursive=False))
for f in frame_list:
cur = int(Path(f).stem)
frame_array_list[cur] = np.asarray(Image.open(f))
if not crop_info:
logger.info(f"crop_info is not exists -> skip restore")
return frame_array_list
for i,f in enumerate(frame_array_list):
if f is None:
continue
frame_array_list[i] = f
frame_array_list = restore_position(frame_array_list, crop_info)
return frame_array_list
def create_fg(mask_token, frame_dir, output_dir, output_mask_dir, masked_area_list,
box_threshold=0.3,
text_threshold=0.25,
bg_color=(0,255,0),
mask_padding=0,
groundingdino_config="config/GroundingDINO/GroundingDINO_SwinB_cfg.py",
groundingdino_checkpoint="data/models/GroundingDINO/groundingdino_swinb_cogcoor.pth",
sam_checkpoint="data/models/SAM/sam_hq_vit_l.pth",
device="cuda",
):
frame_list = sorted(glob.glob( os.path.join(frame_dir, "[0-9]*.png"), recursive=False))
with torch.no_grad():
predictor = MaskPredictor(
model_config_path=groundingdino_config,
model_checkpoint_path=groundingdino_checkpoint,
device=device,
sam_checkpoint=sam_checkpoint,
box_threshold=box_threshold,
text_threshold=text_threshold,
)
if mask_padding != 0:
kernel = np.ones((abs(mask_padding),abs(mask_padding)),np.uint8)
kernel2 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
for i, frame in tqdm(enumerate(frame_list),total=len(frame_list), desc=f"creating mask from {mask_token=}"):
frame = Path(frame)
file_name = frame.name
cur_frame_no = int(frame.stem)
img = Image.open(frame)
mask_array = predictor(img, mask_token)
mask_array = mask_array[0].astype(np.uint8) * 255
if mask_padding < 0:
mask_array = cv2.erode(mask_array.astype(np.uint8),kernel,iterations = 1)
elif mask_padding > 0:
mask_array = cv2.dilate(mask_array.astype(np.uint8),kernel,iterations = 1)
mask_array = cv2.morphologyEx(mask_array.astype(np.uint8), cv2.MORPH_OPEN, kernel2)
mask_array = cv2.GaussianBlur(mask_array, (7, 7), sigmaX=3, sigmaY=3, borderType=cv2.BORDER_DEFAULT)
if masked_area_list[cur_frame_no] is not None:
masked_area_list[cur_frame_no] = np.where(masked_area_list[cur_frame_no] > mask_array[None,...], masked_area_list[cur_frame_no], mask_array[None,...])
#masked_area_list[cur_frame_no] = masked_area_list[cur_frame_no] | mask_array[None,...]
else:
masked_area_list[cur_frame_no] = mask_array[None,...]
if output_mask_dir:
#mask_array2 = mask_array.astype(np.uint8).clip(0,1)
#mask_array2 *= 255
Image.fromarray(mask_array).save( output_mask_dir / file_name )
img_array = np.asarray(img).copy()
if bg_color is not None:
img_array[mask_array == 0] = bg_color
img = Image.fromarray(img_array)
img.save( output_dir / file_name )
return masked_area_list
def dilate_mask(masked_area_list, flow_mask_dilates=8, mask_dilates=5):
kernel = np.ones((flow_mask_dilates,flow_mask_dilates),np.uint8)
flow_masks = [ cv2.dilate(mask[0].astype(np.uint8),kernel,iterations = 1) for mask in masked_area_list ]
flow_masks = [ Image.fromarray(mask * 255) for mask in flow_masks ]
kernel = np.ones((mask_dilates,mask_dilates),np.uint8)
dilated_masks = [ cv2.dilate(mask[0].astype(np.uint8),kernel,iterations = 1) for mask in masked_area_list ]
dilated_masks = [ Image.fromarray(mask * 255) for mask in dilated_masks ]
return flow_masks, dilated_masks
# adapted from https://github.com/sczhou/ProPainter/blob/main/inference_propainter.py
def resize_frames(frames, size=None):
if size is not None:
out_size = size
process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8)
frames = [f.resize(process_size) for f in frames]
else:
out_size = frames[0].size
process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8)
if not out_size == process_size:
frames = [f.resize(process_size) for f in frames]
return frames, process_size, out_size
def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1):
ref_index = []
if ref_num == -1:
for i in range(0, length, ref_stride):
if i not in neighbor_ids:
ref_index.append(i)
else:
start_idx = max(0, mid_neighbor_id - ref_stride * (ref_num // 2))
end_idx = min(length, mid_neighbor_id + ref_stride * (ref_num // 2))
for i in range(start_idx, end_idx, ref_stride):
if i not in neighbor_ids:
if len(ref_index) > ref_num:
break
ref_index.append(i)
return ref_index
def create_bg(frame_dir, output_dir, masked_area_list,
use_half = True,
raft_iter = 20,
subvideo_length=80,
neighbor_length=10,
ref_stride=10,
device="cuda",
low_vram = False,
):
import sys
repo_path = Path("src/animatediff/repo/ProPainter").absolute()
repo_path = str(repo_path)
sys.path.append(repo_path)
from animatediff.repo.ProPainter.core.utils import to_tensors
from animatediff.repo.ProPainter.model.modules.flow_comp_raft import \
RAFT_bi
from animatediff.repo.ProPainter.model.propainter import InpaintGenerator
from animatediff.repo.ProPainter.model.recurrent_flow_completion import \
RecurrentFlowCompleteNet
from animatediff.repo.ProPainter.utils.download_util import \
load_file_from_url
pretrain_model_url = 'https://github.com/sczhou/ProPainter/releases/download/v0.1.0/'
model_dir = Path("data/models/ProPainter")
model_dir.mkdir(parents=True, exist_ok=True)
frame_list = sorted(glob.glob( os.path.join(frame_dir, "[0-9]*.png"), recursive=False))
frames = [Image.open(f) for f in frame_list]
if low_vram:
org_size = frames[0].size
_w, _h = frames[0].size
if max(_w, _h) > 512:
_w = int(_w * 0.75)
_h = int(_h * 0.75)
frames, size, out_size = resize_frames(frames, (_w, _h))
out_size = org_size
masked_area_list = [m[0] for m in masked_area_list]
masked_area_list = [cv2.resize(m.astype(np.uint8), dsize=size) for m in masked_area_list]
masked_area_list = [ m>127 for m in masked_area_list]
masked_area_list = [m[None,...] for m in masked_area_list]
else:
frames, size, out_size = resize_frames(frames, None)
masked_area_list = [ m>127 for m in masked_area_list]
w, h = size
flow_masks,masks_dilated = dilate_mask(masked_area_list)
frames_inp = [np.array(f).astype(np.uint8) for f in frames]
frames = to_tensors()(frames).unsqueeze(0) * 2 - 1
flow_masks = to_tensors()(flow_masks).unsqueeze(0)
masks_dilated = to_tensors()(masks_dilated).unsqueeze(0)
frames, flow_masks, masks_dilated = frames.to(device), flow_masks.to(device), masks_dilated.to(device)
##############################################
# set up RAFT and flow competition model
##############################################
ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'raft-things.pth'),
model_dir=model_dir, progress=True, file_name=None)
fix_raft = RAFT_bi(ckpt_path, device)
ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'recurrent_flow_completion.pth'),
model_dir=model_dir, progress=True, file_name=None)
fix_flow_complete = RecurrentFlowCompleteNet(ckpt_path)
for p in fix_flow_complete.parameters():
p.requires_grad = False
fix_flow_complete.to(device)
fix_flow_complete.eval()
##############################################
# set up ProPainter model
##############################################
ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'ProPainter.pth'),
model_dir=model_dir, progress=True, file_name=None)
model = InpaintGenerator(model_path=ckpt_path).to(device)
model.eval()
##############################################
# ProPainter inference
##############################################
video_length = frames.size(1)
logger.info(f'\nProcessing: [{video_length} frames]...')
with torch.no_grad():
# ---- compute flow ----
if max(w,h) <= 640:
short_clip_len = 12
elif max(w,h) <= 720:
short_clip_len = 8
elif max(w,h) <= 1280:
short_clip_len = 4
else:
short_clip_len = 2
# use fp32 for RAFT
if frames.size(1) > short_clip_len:
gt_flows_f_list, gt_flows_b_list = [], []
for f in range(0, video_length, short_clip_len):
end_f = min(video_length, f + short_clip_len)
if f == 0:
flows_f, flows_b = fix_raft(frames[:,f:end_f], iters=raft_iter)
else:
flows_f, flows_b = fix_raft(frames[:,f-1:end_f], iters=raft_iter)
gt_flows_f_list.append(flows_f)
gt_flows_b_list.append(flows_b)
torch.cuda.empty_cache()
gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
gt_flows_bi = (gt_flows_f, gt_flows_b)
else:
gt_flows_bi = fix_raft(frames, iters=raft_iter)
torch.cuda.empty_cache()
if use_half:
frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half())
fix_flow_complete = fix_flow_complete.half()
model = model.half()
# ---- complete flow ----
flow_length = gt_flows_bi[0].size(1)
if flow_length > subvideo_length:
pred_flows_f, pred_flows_b = [], []
pad_len = 5
for f in range(0, flow_length, subvideo_length):
s_f = max(0, f - pad_len)
e_f = min(flow_length, f + subvideo_length + pad_len)
pad_len_s = max(0, f) - s_f
pad_len_e = e_f - min(flow_length, f + subvideo_length)
pred_flows_bi_sub, _ = fix_flow_complete.forward_bidirect_flow(
(gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
flow_masks[:, s_f:e_f+1])
pred_flows_bi_sub = fix_flow_complete.combine_flow(
(gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
pred_flows_bi_sub,
flow_masks[:, s_f:e_f+1])
pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
torch.cuda.empty_cache()
pred_flows_f = torch.cat(pred_flows_f, dim=1)
pred_flows_b = torch.cat(pred_flows_b, dim=1)
pred_flows_bi = (pred_flows_f, pred_flows_b)
else:
pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
pred_flows_bi = fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
torch.cuda.empty_cache()
# ---- image propagation ----
masked_frames = frames * (1 - masks_dilated)
subvideo_length_img_prop = min(100, subvideo_length) # ensure a minimum of 100 frames for image propagation
if video_length > subvideo_length_img_prop:
updated_frames, updated_masks = [], []
pad_len = 10
for f in range(0, video_length, subvideo_length_img_prop):
s_f = max(0, f - pad_len)
e_f = min(video_length, f + subvideo_length_img_prop + pad_len)
pad_len_s = max(0, f) - s_f
pad_len_e = e_f - min(video_length, f + subvideo_length_img_prop)
b, t, _, _, _ = masks_dilated[:, s_f:e_f].size()
pred_flows_bi_sub = (pred_flows_bi[0][:, s_f:e_f-1], pred_flows_bi[1][:, s_f:e_f-1])
prop_imgs_sub, updated_local_masks_sub = model.img_propagation(masked_frames[:, s_f:e_f],
pred_flows_bi_sub,
masks_dilated[:, s_f:e_f],
'nearest')
updated_frames_sub = frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) + \
prop_imgs_sub.view(b, t, 3, h, w) * masks_dilated[:, s_f:e_f]
updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w)
updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
torch.cuda.empty_cache()
updated_frames = torch.cat(updated_frames, dim=1)
updated_masks = torch.cat(updated_masks, dim=1)
else:
b, t, _, _, _ = masks_dilated.size()
prop_imgs, updated_local_masks = model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest')
updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
updated_masks = updated_local_masks.view(b, t, 1, h, w)
torch.cuda.empty_cache()
ori_frames = frames_inp
comp_frames = [None] * video_length
neighbor_stride = neighbor_length // 2
if video_length > subvideo_length:
ref_num = subvideo_length // ref_stride
else:
ref_num = -1
# ---- feature propagation + transformer ----
for f in tqdm(range(0, video_length, neighbor_stride)):
neighbor_ids = [
i for i in range(max(0, f - neighbor_stride),
min(video_length, f + neighbor_stride + 1))
]
ref_ids = get_ref_index(f, neighbor_ids, video_length, ref_stride, ref_num)
selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :]
selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :]
selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :]
selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :])
with torch.no_grad():
# 1.0 indicates mask
l_t = len(neighbor_ids)
# pred_img = selected_imgs # results of image propagation
pred_img = model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t)
pred_img = pred_img.view(-1, 3, h, w)
pred_img = (pred_img + 1) / 2
pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
binary_masks = masks_dilated[0, neighbor_ids, :, :, :].cpu().permute(
0, 2, 3, 1).numpy().astype(np.uint8)
for i in range(len(neighbor_ids)):
idx = neighbor_ids[i]
img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
+ ori_frames[idx] * (1 - binary_masks[i])
if comp_frames[idx] is None:
comp_frames[idx] = img
else:
comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5
comp_frames[idx] = comp_frames[idx].astype(np.uint8)
torch.cuda.empty_cache()
# save each frame
for idx in range(video_length):
f = comp_frames[idx]
f = cv2.resize(f, out_size, interpolation = cv2.INTER_CUBIC)
f = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
dst_img_path = output_dir.joinpath( f"{idx:08d}.png" )
cv2.imwrite(str(dst_img_path), f)
sys.path.remove(repo_path)