AMT / demo_img.py
zzl
release
981d2df
raw
history blame
4.78 kB
import glob
import os

import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download

from networks.amtg import Model as AMTG
from networks.amtl import Model as AMTL
from networks.amts import Model as AMTS
from utils import (
    img2tensor, tensor2img,
    InputPadder,
    check_dim_and_resize
)
# Inference device: the default CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# UI model name -> network class. The keys match the choices offered by the
# 'Model Select' radio in demo_img() and are looked up by img2vid().
model_dict = dict(zip(
    ('AMT-S', 'AMT-L', 'AMT-G'),
    (AMTS, AMTL, AMTG),
))
def _load_model(model_type):
    """Instantiate the AMT variant ``model_type`` with its pretrained weights.

    The checkpoint ``<model_type lower-cased>.pth`` is fetched from the
    ``lalala125/AMT`` Hugging Face repo. The returned model lives on
    ``device`` and is in eval mode.
    """
    model = model_dict[model_type]()
    model.to(device)
    ckpt_path = hf_hub_download(repo_id='lalala125/AMT', filename=f'{model_type.lower()}.pth')
    # Deserialize on the CPU so loading does not depend on where the
    # checkpoint was saved from.
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    return model


def _inference_scale(h, w):
    """Return the factor (<= 1) by which inference is downscaled for an h x w input.

    On CUDA the factor is derived from the GPU's total memory so that large
    inputs still fit. On the CPU the factor is 1 for any input up to
    8192 x 8192 pixels, i.e. no resizing.
    """
    # NOTE: ``device`` is a torch.device; comparing it to the string 'cuda'
    # is always False, so the device *type* must be checked.
    if device.type == 'cuda':
        anchor_resolution = 1024 * 512
        anchor_memory = 1500 * 1024**2
        anchor_memory_bias = 2500 * 1024**2
        vram_avail = torch.cuda.get_device_properties(device).total_memory
    else:
        # Do not resize in cpu mode
        anchor_resolution = 8192 * 8192
        anchor_memory = 1
        anchor_memory_bias = 0
        vram_avail = 1
    scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
    scale = 1 if scale > 1 else scale
    # Take the per-side factor and round it so that 1/scale is a multiple of 1/16.
    return 1 / np.floor(1 / np.sqrt(scale) * 16) * 16


def _write_video(frames, path, fps):
    """Encode ``frames`` into an mp4v-coded video file at ``path``.

    Each frame goes through ``tensor2img`` and is converted RGB -> BGR
    (OpenCV's channel order) before writing. The frame size is taken from
    the last two dimensions of the first frame.

    Raises:
        RuntimeError: If OpenCV cannot open a writer for ``path``.
    """
    height, width = frames[0].shape[-2:]
    writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (int(width), int(height)))
    if not writer.isOpened():
        # Without this check OpenCV silently writes nothing.
        raise RuntimeError(f'Could not open video writer for {path}')
    try:
        for frame in frames:
            writer.write(cv2.cvtColor(tensor2img(frame), cv2.COLOR_RGB2BGR))
    finally:
        writer.release()


def img2vid(model_type, img0, img1, frame_ratio, iters):
    """Interpolate between two images and save the result as an MP4 video.

    Each pass inserts one predicted middle frame (``embt`` = 1/2) between
    every adjacent pair of frames. ``iters`` passes therefore turn the 2
    input frames into ``2**iters + 1`` frames.

    Args:
        model_type: Key of ``model_dict`` ('AMT-S', 'AMT-L' or 'AMT-G').
        img0, img1: The two input images as delivered by ``gr.Image``
            (presumably H x W x 3 RGB arrays - TODO confirm against
            ``img2tensor``).
        frame_ratio: Frames per second of the written video.
        iters: Number of recursive interpolation passes.

    Returns:
        Path of the written video, ``'results/demo.mp4'``.

    Raises:
        gr.Error: If either input image is missing.
    """
    if img0 is None or img1 is None:
        raise gr.Error('Please provide both input images.')
    # Gradio sliders may deliver floats; range() needs an int.
    iters = int(iters)

    model = _load_model(model_type)
    inputs = [img2tensor(img0).to(device), img2tensor(img1).to(device)]
    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)

    inputs = check_dim_and_resize(inputs)
    h, w = inputs[0].shape[-2:]
    scale = _inference_scale(h, w)
    if scale < 1:
        print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
    padding = int(16 / scale)
    padder = InputPadder(inputs[0].shape, padding)
    inputs = padder.pad(*inputs)

    for i in range(iters):
        print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
        outputs = [inputs[0]]
        for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
            in_0 = in_0.to(device)
            in_1 = in_1.to(device)
            with torch.no_grad():
                imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
            # Keep finished frames on the CPU to bound device memory use.
            outputs += [imgt_pred.cpu(), in_1.cpu()]
        inputs = outputs
    # Unpad ``inputs`` (not ``outputs``) so that iters == 0 still works.
    frames = padder.unpad(*inputs)

    out_path = 'results'
    os.makedirs(out_path, exist_ok=True)
    video_path = f'{out_path}/demo.mp4'
    _write_video(frames, video_path, frame_ratio)
    return video_path
def demo_img():
    """Build the Gradio "Image Demo" tab: two images in, one interpolated video out.

    Returns:
        The ``gr.Blocks`` layout. Clicking ``Run`` calls ``img2vid`` with the
        selected model type, both images, the 'Frame Ratio' slider (video
        fps) and the 'Multiple Ratio' slider (number of interpolation passes).
    """
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Image Demo')
        with gr.Row():
            # Opening and closing heading tags now match, and the stray
            # invalid CSS declaration is gone.
            gr.HTML(
                """
                <div style="text-align: left;">
                <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                Description: With 2 input images, you can generate a short video from them.
                </h2>
                </div>
                """)
        with gr.Row():
            with gr.Column():
                img0 = gr.Image(label='Image0')
                img1 = gr.Image(label='Image1')
            with gr.Column():
                result = gr.Video(label="Generated Video")
                with gr.Accordion('Advanced options', open=False):
                    ratio = gr.Slider(label='Multiple Ratio',
                                      minimum=4,
                                      maximum=7,
                                      value=6,
                                      step=1)
                    frame_ratio = gr.Slider(label='Frame Ratio',
                                            minimum=8,
                                            maximum=64,
                                            value=16,
                                            step=1)
                    model_type = gr.Radio(['AMT-S', 'AMT-L', 'AMT-G'],
                                          label='Model Select',
                                          value='AMT-S')
                # A button's caption is its ``value``; ``label`` is not a
                # Button parameter in newer Gradio releases.
                run_button = gr.Button('Run')
        # Order must match img2vid(model_type, img0, img1, frame_ratio, iters).
        inputs = [
            model_type,
            img0,
            img1,
            frame_ratio,
            ratio,
        ]
        # Sort so the example gallery order does not depend on filesystem order.
        gr.Examples(examples=sorted(glob.glob("examples/*.png")),
                    inputs=img0,
                    label='Example images (drag them to input windows)',
                    run_on_click=False,
                    )
        run_button.click(fn=img2vid,
                         inputs=inputs,
                         outputs=result,)
    return demo