File size: 4,782 Bytes
7062b81 981d2df 7062b81 2bbc3ee 506e597 7062b81 77c21de 7062b81 2bbc3ee 4fa279e b5a9347 2bbc3ee 7062b81 2bbc3ee 7062b81 2bbc3ee 7062b81 2bbc3ee 7062b81 2bbc3ee 7062b81 2bbc3ee 7062b81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import glob
import os

import cv2
import torch
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download

from networks.amts import Model as AMTS
from networks.amtl import Model as AMTL
from networks.amtg import Model as AMTG
from utils import (
    img2tensor, tensor2img,
    InputPadder,
    check_dim_and_resize
)
# Inference device: the default CUDA GPU when one is visible, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Model names offered in the UI, mapped to the network class implementing each.
model_dict = dict(zip(('AMT-S', 'AMT-L', 'AMT-G'), (AMTS, AMTL, AMTG)))
def _load_model(model_type):
    """Build the network named ``model_type`` and load its pretrained weights.

    The checkpoint ``<model_type lowercased>.pth`` is fetched from the
    ``lalala125/AMT`` Hugging Face repo and loaded on the CPU.  Its
    ``'state_dict'`` entry is then copied into the model, which has already
    been moved to ``device``.  The model is returned in eval mode.
    """
    model = model_dict[model_type]()
    model.to(device)
    ckpt_path = hf_hub_download(repo_id='lalala125/AMT', filename=f'{model_type.lower()}.pth')
    ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    return model


def _compute_scale(height, width):
    """Return the model's internal rescaling factor for a ``height`` x ``width`` input.

    The factor is clamped to at most 1 and snapped to ``16 / k`` for an
    integer ``k``.

    On CUDA it comes from an empirical memory heuristic.  NOTE(review): the
    anchors appear to be a calibration point (a 1024x512 frame needing about
    1500 MiB on top of a fixed 2500 MiB overhead) -- confirm.  On CPU the
    anchors are chosen so that no resizing happens for inputs up to
    8192 x 8192 pixels.

    Raises:
        RuntimeError: If the GPU's total memory does not exceed the fixed
            overhead.  The heuristic would otherwise take the square root of a
            negative number and yield NaN.
    """
    # ``device`` is a torch.device; comparing it to the string 'cuda' is always
    # False, so the device kind must be read from ``.type``.
    if device.type == 'cuda':
        anchor_resolution = 1024 * 512
        anchor_memory = 1500 * 1024**2
        anchor_memory_bias = 2500 * 1024**2
        vram_avail = torch.cuda.get_device_properties(device).total_memory
        if vram_avail <= anchor_memory_bias:
            raise RuntimeError(
                f'GPU has {vram_avail / 1024**2:.0f} MiB of memory, which does not exceed '
                f'the {anchor_memory_bias / 1024**2:.0f} MiB fixed overhead assumed by '
                'the resolution-scaling heuristic.'
            )
    else:
        # Do not resize in cpu mode
        anchor_resolution = 8192 * 8192
        anchor_memory = 1
        anchor_memory_bias = 0
        vram_avail = 1
    scale = anchor_resolution / (height * width) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
    scale = min(scale, 1)
    # Convert the area ratio to a linear ratio and quantise it to 16 / k.
    return 1 / np.floor(1 / np.sqrt(scale) * 16) * 16


def _write_video(frames, video_path, fps):
    """Encode ``frames`` as an mp4v video at ``video_path`` with ``fps`` frames per second.

    NOTE(review): assumes every frame is a tensor whose ``shape[2:]`` is
    ``(H, W)`` (e.g. ``(1, C, H, W)``) and that ``tensor2img`` yields an RGB
    image -- confirm.

    Raises:
        RuntimeError: If OpenCV cannot open the output file for writing.
    """
    height, width = frames[0].shape[2:]
    writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (int(width), int(height)))
    # VideoWriter does not raise on failure; it just silently writes nothing.
    if not writer.isOpened():
        raise RuntimeError(f'Could not open video writer for {video_path}')
    try:
        for frame in frames:
            frame = tensor2img(frame)
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()


def img2vid(model_type, img0, img1, frame_ratio, iters):
    """Interpolate between two images with an AMT model and save the result as a video.

    Each pass inserts a predicted midpoint frame between every pair of
    neighbouring frames, so ``n`` frames become ``2 * n - 1``.  After
    ``iters`` passes the two inputs have become ``2 ** iters + 1`` frames.

    Args:
        model_type: Key of ``model_dict`` ('AMT-S', 'AMT-L' or 'AMT-G').
        img0: First image, in whatever form ``img2tensor`` accepts (presumably
            the RGB numpy array delivered by ``gr.Image`` -- confirm).
        img1: Second image, in the same form as ``img0``.
        frame_ratio: Frame rate (fps) of the written video.
        iters: Number of interpolation passes.  It is coerced to ``int``
            because Gradio sliders may deliver floats.

    Returns:
        Path of the written video, ``'results/demo.mp4'``.

    Raises:
        KeyError: If ``model_type`` is not a key of ``model_dict``.
        RuntimeError: If the GPU is too small for the scaling heuristic, or the
            video file cannot be opened for writing.
    """
    model = _load_model(model_type)

    inputs = [img2tensor(img0).to(device), img2tensor(img1).to(device)]
    # Timestep 1/2: predict the frame exactly halfway between each pair.
    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
    inputs = check_dim_and_resize(inputs)
    h, w = inputs[0].shape[-2:]

    scale = _compute_scale(h, w)
    if scale < 1:
        print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
    # scale == 16 / k by construction; round so float error cannot truncate k to k - 1.
    padding = int(round(16 / scale))
    padder = InputPadder(inputs[0].shape, padding)
    inputs = padder.pad(*inputs)

    # With zero passes the (padded) inputs themselves are the output frames.
    outputs = inputs
    for i in range(int(iters)):
        print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
        outputs = [inputs[0]]
        for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
            # Frames produced by earlier passes live on the CPU; move them back for inference.
            in_0 = in_0.to(device)
            in_1 = in_1.to(device)
            with torch.no_grad():
                imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
            outputs += [imgt_pred.cpu(), in_1.cpu()]
        inputs = outputs
    outputs = padder.unpad(*outputs)

    # NOTE(review): every request writes to the same fixed path, so concurrent
    # requests would overwrite each other's video.
    out_path = 'results'
    os.makedirs(out_path, exist_ok=True)
    video_path = f'{out_path}/demo.mp4'
    _write_video(outputs, video_path, frame_ratio)
    return video_path
def demo_img():
    """Build the Gradio "Image Demo" UI.

    Two image inputs plus advanced options (number of interpolation passes,
    output frame rate, model choice) are wired through the Run button to
    ``img2vid``.  Its returned video path is displayed in a video component.

    Returns:
        The ``gr.Blocks`` container holding the demo.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Image Demo')
        with gr.Row():
            # Opening and closing heading tags must match (the original closed <h2> with </h3>).
            gr.HTML(
                """
                <div style="text-align: left;">
                <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                    Description: With 2 input images, you can generate a short video from them.
                </h2>
                </div>
                """)

        with gr.Row():
            with gr.Column():
                img0 = gr.Image(label='Image0')
                img1 = gr.Image(label='Image1')
            with gr.Column():
                result = gr.Video(label="Generated Video")
                with gr.Accordion('Advanced options', open=False):
                    # Passed to img2vid as ``iters`` (number of 2x interpolation passes).
                    ratio = gr.Slider(label='Multiple Ratio',
                                      minimum=4,
                                      maximum=7,
                                      value=6,
                                      step=1)
                    # Passed to img2vid as the output video's frame rate.
                    frame_ratio = gr.Slider(label='Frame Ratio',
                                            minimum=8,
                                            maximum=64,
                                            value=16,
                                            step=1)
                    model_type = gr.Radio(['AMT-S', 'AMT-L', 'AMT-G'],
                                          label='Model Select',
                                          value='AMT-S')
                # A Button's text is its ``value``; ``label`` is not a Button parameter.
                run_button = gr.Button('Run')

        # Order must match img2vid(model_type, img0, img1, frame_ratio, iters).
        inputs = [
            model_type,
            img0,
            img1,
            frame_ratio,
            ratio,
        ]

        # Sort so the example gallery order does not depend on filesystem order.
        gr.Examples(examples=sorted(glob.glob("examples/*.png")),
                    inputs=img0,
                    label='Example images (drag them to input windows)',
                    run_on_click=False,
                    )

        run_button.click(fn=img2vid,
                         inputs=inputs,
                         outputs=result,)
    return demo