# AMT / demo_vid.py
# zzl — [Release] Demo v1.0 (commit 7062b81)
# (file-viewer residue: raw | history | blame | 3.89 kB)
import cv2
import glob
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from networks.amts import Model as AMTS
from networks.amtl import Model as AMTL
from networks.amtg import Model as AMTG
from utils import img2tensor, tensor2img, InputPadder
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Maps the model names offered in the UI's "Model Select" radio to network classes.
model_dict = {
    name: net_cls
    for name, net_cls in (('AMT-S', AMTS), ('AMT-L', AMTL), ('AMT-G', AMTG))
}
def vid2vid(model_type, video, iters):
    """Raise a video's frame rate with an AMT model and write the result to disk.

    Each pass inserts one synthesized frame between every pair of
    neighbouring frames, so N input frames become 2**iters * (N - 1) + 1
    output frames. The output frame rate is multiplied by 2**iters.

    Args:
        model_type: Key of ``model_dict`` ('AMT-S', 'AMT-L' or 'AMT-G'). It also
            selects the checkpoint ``<model_type.lower()>.pth`` from the
            ``lalala125/AMT`` Hugging Face Hub repo.
        video: Path of the input video (anything ``cv2.VideoCapture`` opens).
        iters: Number of 2x interpolation passes (may arrive as a float from
            the Gradio slider).

    Returns:
        The path of the written video, ``'results/demo_vfi.mp4'``.

    Raises:
        ValueError: If no frame could be read from ``video``.
    """
    import os  # only needed here, to create the output directory

    iters = int(iters)
    model = model_dict[model_type]()
    model.to(device)
    ckpt_path = hf_hub_download(repo_id='lalala125/AMT', filename=f'{model_type.lower()}.pth')
    print(model_type)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; acceptable only because it
    # comes from a fixed, trusted Hub repo.
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()

    # --- Read every frame, converted to a padded RGB tensor on `device`. ---
    vcap = cv2.VideoCapture(video)
    try:
        ori_frame_rate = vcap.get(cv2.CAP_PROP_FPS)
        h = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        w = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
        # assumes InputPadder takes (height, width) — TODO confirm in utils.
        padder = InputPadder((h, w), 16)
        inputs = []
        while True:
            ret, frame = vcap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            inputs.append(padder.pad(img2tensor(frame).to(device)))
    finally:
        vcap.release()
    if not inputs:
        raise ValueError(f'No frames could be read from video {video!r}')

    # --- Interpolate. Frames stay padded between passes because each pass's
    # output is the next pass's model input; unpadding happens once, at write time.
    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
    frames = inputs
    with torch.no_grad():
        for i in range(iters):
            print(f'Iter {i+1}. input_frames={len(frames)} output_frames={2*len(frames)-1}')
            outputs = [frames[0]]
            for in_0, in_1 in zip(frames[:-1], frames[1:]):
                imgt_pred = model(in_0, in_1, embt, eval=True)['imgt_pred']
                outputs += [imgt_pred, in_1]
            frames = outputs

    # --- Write the unpadded frames. ---
    out_path = 'results'
    # VideoWriter does not create missing directories.
    os.makedirs(out_path, exist_ok=True)
    height, width = padder.unpad(frames[0]).shape[-2:]
    writer = cv2.VideoWriter(f'{out_path}/demo_vfi.mp4',
                             cv2.VideoWriter_fourcc(*'mp4v'),
                             ori_frame_rate * 2 ** iters, (width, height))
    try:
        for frame_t in frames:
            img = tensor2img(padder.unpad(frame_t))
            writer.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()
    return 'results/demo_vfi.mp4'
def demo_vid():
    """Build the Gradio Blocks UI for the video frame-interpolation demo.

    The UI collects a video, an AMT model variant and a number of 2x
    interpolation passes, and wires them to ``vid2vid``. The generated video
    is shown in the output player.

    Returns:
        The ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Video Demo')
        with gr.Row():
            gr.HTML(
                """
                <div style="text-align: left;">
                <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                Description: Upload a video to raise its frame rate. Each pass
                ("Multiple Ratio" under Advanced options) inserts one new frame
                between every pair of neighbouring frames.
                </h2>
                </div>
                """)
        with gr.Row():
            with gr.Column():
                video = gr.Video(label='Video Input')
            with gr.Column():
                result = gr.Video(label="Generated Video")
        with gr.Accordion('Advanced options', open=False):
            # Number of 2x interpolation passes; passed to vid2vid as `iters`.
            ratio = gr.Slider(label='Multiple Ratio',
                              minimum=1,
                              maximum=4,
                              value=2,
                              step=1)
            model_type = gr.Radio(['AMT-S', 'AMT-L', 'AMT-G'],
                                  label='Model Select',
                                  value='AMT-S')
        # A Button's caption is its `value`; `label` is not a Button parameter
        # in current Gradio releases.
        run_button = gr.Button('Run')
        # Order must match vid2vid(model_type, video, iters).
        inputs = [
            model_type,
            video,
            ratio,
        ]
        # Sorted for a stable display order. Skipped when there are no example
        # files, since an empty example list is not useful to show.
        example_videos = sorted(glob.glob("examples/*.mp4"))
        if example_videos:
            gr.Examples(examples=example_videos,
                        inputs=video,
                        label='Example videos (click one to load it)',
                        run_on_click=False,
                        )
        run_button.click(fn=vid2vid,
                         inputs=inputs,
                         outputs=result,)
    return demo