import cv2 import glob import torch import gradio as gr from huggingface_hub import hf_hub_download from networks.amts import Model as AMTS from networks.amtl import Model as AMTL from networks.amtg import Model as AMTG from utils import img2tensor, tensor2img, InputPadder device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model_dict = { 'AMT-S': AMTS, 'AMT-L': AMTL, 'AMT-G': AMTG } def vid2vid(model_type, video, iters): model = model_dict[model_type]() model.to(device) ckpt_path = hf_hub_download(repo_id='lalala125/AMT', filename=f'{model_type.lower()}.pth') ckpt = torch.load(ckpt_path, map_location=torch.device('cpu')) model.load_state_dict(ckpt['state_dict']) model.eval() vcap = cv2.VideoCapture(video) ori_frame_rate = vcap.get(cv2.CAP_PROP_FPS) inputs = [] h = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH)) w = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) padder = InputPadder((h, w), 16) while True: ret, frame = vcap.read() if ret is False: break frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame_t = img2tensor(frame).to(device) frame_t = padder.pad(frame_t) inputs.append(frame_t) embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device) for i in range(iters): print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}') outputs = [inputs[0]] for in_0, in_1 in zip(inputs[:-1], inputs[1:]): with torch.no_grad(): imgt_pred = model(in_0, in_1, embt, eval=True)['imgt_pred'] imgt_pred = padder.unpad(imgt_pred) in_1 = padder.unpad(in_1) outputs += [imgt_pred, in_1] inputs = outputs out_path = 'results' size = outputs[0].shape[2:][::-1] writer = cv2.VideoWriter(f'{out_path}/demo_vfi.mp4', cv2.VideoWriter_fourcc(*'mp4v'), ori_frame_rate * 2 ** iters, size) for i, imgt_pred in enumerate(outputs): imgt_pred = tensor2img(imgt_pred) imgt_pred = cv2.cvtColor(imgt_pred, cv2.COLOR_RGB2BGR) writer.write(imgt_pred) writer.release() return 'results/demo_vfi.mp4' def demo_vid(): with gr.Blocks() as demo: with gr.Row(): gr.Markdown('## Image Demo') with gr.Row(): gr.HTML( """