import os
import cv2
import glob
import torch
import numpy as np
import gradio as gr
from huggingface_hub import hf_hub_download

from networks.amts import Model as AMTS
from networks.amtl import Model as AMTL
from networks.amtg import Model as AMTG
from utils import img2tensor, tensor2img, InputPadder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
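# Map the UI model names to their network classes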
model_dict = {
    'AMT-S': AMTS, 'AMT-L': AMTL, 'AMT-G': AMTG
}


def vid2vid(model_type, video, iters):
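    """Interpolate `video` with the chosen AMT model; each of `iters` passes doubles the frame rate."""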
    model = model_dict[model_type]()
    model.to(device)
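    # Fetch pretrained weights from the Hugging Face Hub (cached after the first run)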
    ckpt_path = hf_hub_download(repo_id='lalala125/AMT', filename=f'{model_type.lower()}.pth')
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
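    # Decode the input video and record the metadata needed for padding/scaling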
    vcap = cv2.VideoCapture(video)
    ori_frame_rate = vcap.get(cv2.CAP_PROP_FPS)
    inputs = []
    w = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    if device.type == 'cuda':
        # Estimate how much to downscale so inference fits in VRAM,
        # calibrated against a 1024x512 anchor resolution.
        anchor_resolution = 1024 * 512
        anchor_memory = 1500 * 1024**2
        anchor_memory_bias = 2500 * 1024**2
        vram_avail = torch.cuda.get_device_properties(device).total_memory
    else:
        # Do not resize in CPU mode
        anchor_resolution = 8192 * 8192
        anchor_memory = 1
        anchor_memory_bias = 0
        vram_avail = 1
    
    # Pick a linear downscale factor of the form 16 / n: keep the scaled
    # resolution near the anchor, adjusted by the available VRAM headroom.
    scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
    scale = min(scale, 1)
    scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
    if scale < 1:
        print(f"Due to limited VRAM, the video will be scaled by {scale:.2f}")
    padding = int(16 / scale)
    padder = InputPadder((h, w), padding)
    # Read every frame, convert BGR -> RGB, and pad to a model-friendly size
    while True:
        ret, frame = vcap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_t = img2tensor(frame).to(device)
        frame_t = padder.pad(frame_t)
        inputs.append(frame_t)
    vcap.release()
    # Timestep embedding for the midpoint frame (t = 0.5)
    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)

    # Each pass inserts a model-predicted midpoint frame between every
    # adjacent pair, taking n frames to 2n - 1.
    for i in range(iters):
        print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
        outputs = [inputs[0]]
        for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
            with torch.no_grad():
                imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
            outputs += [imgt_pred, in_1]
        inputs = outputs
    # Unpad only after the final pass, so every pass sees uniformly sized tensors
    outputs = [padder.unpad(im) for im in outputs]

    out_path = 'results'
    os.makedirs(out_path, exist_ok=True)
    # VideoWriter expects (width, height); the tensors are (..., H, W)
    size = outputs[0].shape[2:][::-1]
    writer = cv2.VideoWriter(f'{out_path}/demo_vfi.mp4',
                             cv2.VideoWriter_fourcc(*'mp4v'),
                             ori_frame_rate * 2 ** iters, size)
    for imgt_pred in outputs:
        imgt_pred = tensor2img(imgt_pred)
        imgt_pred = cv2.cvtColor(imgt_pred, cv2.COLOR_RGB2BGR)
        writer.write(imgt_pred)
    writer.release()
    return f'{out_path}/demo_vfi.mp4'

    
def demo_vid():
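    """Build the Gradio Blocks UI for the video frame-interpolation demo."""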
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Video Demo')
        with gr.Row():
            gr.HTML(
                """
                <div style="text-align: left; margin: auto;">
                <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                    Description: increase the frame rate of a video by 2x, 4x, or 8x. (The input video should be shorter than 10 seconds.)
                </h2>
                </div>
                </div>
                """)

        with gr.Row():
            with gr.Column():
                video = gr.Video(label='Video Input')
            with gr.Column():
                result = gr.Video(label="Generated Video")
                with gr.Accordion('Advanced options', open=False):
                    # Number of 2x passes: 1 -> 2x, 2 -> 4x, 3 -> 8x
                    ratio = gr.Slider(label='Multiple Ratio',
                                      minimum=1,
                                      maximum=3,
                                      value=2,
                                      step=1)
                    model_type = gr.Radio(['AMT-S', 'AMT-L', 'AMT-G'],
                                          label='Model Select',
                                          value='AMT-S')
                run_button = gr.Button('Run')
        inputs = [
            model_type,
            video,
            ratio,
        ]

        gr.Examples(examples=glob.glob("examples/*.mp4"),
                    inputs=video,
                    label='Example videos (drag them to the input window)',
                    run_on_click=False)

        run_button.click(fn=vid2vid,
                         inputs=inputs,
                         outputs=result)
    return demo
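

# Minimal entry point: a sketch, assuming this module is run directly rather
# than imported by a separate launcher script. queue() and launch() are
# standard Gradio Blocks calls for serving the app.
if __name__ == '__main__':
    demo = demo_vid()
    demo.queue().launch()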