import os
import cv2
import glob
import torch
import gradio as gr
from huggingface_hub import hf_hub_download

from networks.amts import Model as AMTS
from networks.amtl import Model as AMTL
from networks.amtg import Model as AMTG
from utils import img2tensor, tensor2img, InputPadder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dict = {
    'AMT-S': AMTS, 'AMT-L': AMTL, 'AMT-G': AMTG
}

def vid2vid(model_type, video, iters):
    # Build the selected model and load its pretrained weights from the Hub.
    model = model_dict[model_type]()
    model.to(device)
    ckpt_path = hf_hub_download(repo_id='lalala125/AMT', filename=f'{model_type.lower()}.pth')
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt['state_dict'])
    model.eval()

    # Decode every frame, convert BGR -> RGB, and pad each frame tensor so
    # its spatial dimensions are divisible by 16 before running the model.
    vcap = cv2.VideoCapture(video)
    ori_frame_rate = vcap.get(cv2.CAP_PROP_FPS)
    inputs = []
    w = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    padder = InputPadder((h, w), 16)
    while True:
        ret, frame = vcap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_t = img2tensor(frame).to(device)
        frame_t = padder.pad(frame_t)
        inputs.append(frame_t)
    vcap.release()
    # embt = 0.5 requests the temporal midpoint between each frame pair.
    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)

    # Each pass inserts one interpolated frame between every adjacent pair,
    # growing n frames to 2n - 1. Frames stay padded across passes so they
    # remain valid model inputs; the padding is removed only when writing.
    for i in range(iters):
        print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
        outputs = [inputs[0]]
        for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
            with torch.no_grad():
                imgt_pred = model(in_0, in_1, embt, eval=True)['imgt_pred']
            outputs += [imgt_pred, in_1]
        inputs = outputs

    # Encode the result, cropping each frame back to its original size.
    # The frame rate scales by 2 ** iters so the clip keeps its duration.
    out_path = 'results'
    os.makedirs(out_path, exist_ok=True)
    size = (w, h)
    writer = cv2.VideoWriter(f'{out_path}/demo_vfi.mp4',
                             cv2.VideoWriter_fourcc(*'mp4v'),
                             ori_frame_rate * 2 ** iters, size)
    for imgt_pred in outputs:
        imgt_pred = padder.unpad(imgt_pred)
        imgt_pred = tensor2img(imgt_pred)
        imgt_pred = cv2.cvtColor(imgt_pred, cv2.COLOR_RGB2BGR)
        writer.write(imgt_pred)
    writer.release()
    return f'{out_path}/demo_vfi.mp4'


def demo_vid():
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Video Demo')
        with gr.Row():
            gr.HTML(
                """
                <div style="text-align: left;">
                    <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                    Description: Upload a video and the model will interpolate new in-between frames, multiplying its frame rate.
                    </h2>
                </div>
                """)

        with gr.Row():
            with gr.Column():
                video = gr.Video(label='Video Input')
            with gr.Column():
                result = gr.Video(label="Generated Video")
        with gr.Accordion('Advanced options', open=False):
            ratio = gr.Slider(label='Interpolation Passes (each pass doubles the frame rate)',
                              minimum=1,
                              maximum=4,
                              value=2,
                              step=1)
            model_type = gr.Radio(['AMT-S', 'AMT-L', 'AMT-G'],
                                  label='Model Select',
                                  value='AMT-S')
        run_button = gr.Button(value='Run')
        inputs = [
            model_type,
            video,
            ratio,
        ]

        gr.Examples(examples=glob.glob("examples/*.mp4"),
                    inputs=video,
                    label='Example videos (drag one into the input window)',
                    run_on_click=False)

        run_button.click(fn=vid2vid,
                         inputs=inputs,
                         outputs=result)
    return demo
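

# A minimal sketch of a local entry point; the `__main__` guard and the
# launch() call below are additions, not part of the original demo code,
# but Blocks.launch() is Gradio's standard way to serve an interface.
if __name__ == '__main__':
    demo_vid().launch()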