# FreeTraj / app.py
# Anonymous · add spaces · 47058ed
# (file-viewer page residue: raw | history | blame | contribute | delete | No virus | 27.9 kB)
import sys
import random
import gradio as gr
import matplotlib.pyplot as plt
import os
import argparse
import random
from omegaconf import OmegaConf
import torch
import torchvision
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
import spaces
sys.path.insert(0, "scripts/evaluation")
from funcs import (
batch_ddim_sampling,
batch_ddim_sampling_freetraj,
load_model_checkpoint,
)
from utils.utils import instantiate_from_config
from utils.utils_freetraj import plan_path
# Size of every generated clip; infer_gpu_part derives the latent size as (height // 8, width // 8).
video_length, width, height = 16, 512, 320
# Number of keyframe slots exposed in the DIY trajectory UI.
MAX_KEYS = 5

ckpt_dir_512 = "checkpoints/base_512_v2"
ckpt_path_512 = "checkpoints/base_512_v2/model.ckpt"

# Download the VideoCrafter2 base checkpoint only when no local copy exists yet.
if not os.path.exists(ckpt_path_512):
    os.makedirs(ckpt_dir_512, exist_ok=True)
    hf_hub_download(
        repo_id="VideoCrafter/VideoCrafter2",
        filename="model.ckpt",
        local_dir=ckpt_dir_512,
        force_download=True,
    )
print('Model Loaded.')
def check_move(trajectory, video_length=16):
    """Sanity-check a sorted keyframe trajectory.

    Each entry is indexed as ``[frame_index, h_position, w_position]``. The
    trajectory is rejected when it has fewer than two keyframes, when a keyframe
    after the first lies beyond the last frame (``video_length - 1``), or when
    any consecutive pair has ``frame_gap * euclidean_distance < 0.02``. The last
    case also prints a hint to use 'ori' mode.

    Returns:
        True if the trajectory passes every check, False otherwise.
    """
    if len(trajectory) < 2:
        return False
    last_frame = video_length - 1
    for before, after in zip(trajectory, trajectory[1:]):
        if after[0] > last_frame:
            return False
        frame_gap = after[0] - before[0]
        distance = ((after[1] - before[1]) ** 2 + (after[2] - before[2]) ** 2) ** 0.5
        if frame_gap * distance < 0.02:
            print("Too small movement, please use ori mode.")
            return False
    return True
def check(radio_mode):
    """Return the ``(video_path, video_bbox_path)`` that ``infer`` writes for a mode.

    'ori' mode draws no box, so both entries name the same plain video; every
    other mode has a separate FreeTraj video and its bbox overlay.
    """
    if radio_mode == 'ori':
        return "output.mp4", "output.mp4"
    return "output_freetraj.mp4", "output_freetraj_bbox.mp4"
def infer(*user_args):
    """Gradio "Generate" callback: sample a video and write it (plus a bbox overlay) to disk.

    ``user_args`` follow the order of the ``inputs`` list wired to the submit
    button: prompt, target token indices, editing steps, seed, DDIM steps,
    guidance scale, conditioning FPS, save FPS, bbox height ratio, bbox width
    ratio, trajectory mode ('demo' / 'diy' / 'ori'), number of keyframes, and
    finally ``MAX_KEYS`` frame indices, ``MAX_KEYS`` height positions and
    ``MAX_KEYS`` width positions.

    Returns:
        ``(video_path, video_bbox_path)`` as given by ``check(radio_mode)``. In
        'ori' mode no trajectory is used and both paths name the same video.

    Raises:
        gr.Error: if the keyframe count, a frame index or the target indices
            cannot be parsed.
    """
    (prompt_in, target_indices, ddim_edit, seed, ddim_steps,
     unconditional_guidance_scale, video_fps, save_fps,
     height_ratio, width_ratio, radio_mode, dropdown_diy) = user_args[:12]
    # The keyframe widgets are always the last 3 * MAX_KEYS inputs.
    frame_indices = user_args[-3 * MAX_KEYS: -2 * MAX_KEYS]
    h_positions = user_args[-2 * MAX_KEYS: -MAX_KEYS]
    w_positions = user_args[-MAX_KEYS:]
    print(user_args)

    is_ori = radio_mode == 'ori'
    if is_ori:
        # Plain sampling ignores the keyframe widgets, which may still be unset.
        config_path = "configs/inference_t2v_512_v2.0.yaml"
        input_traj = []
    else:
        config_path = "configs/inference_t2v_freetraj_512_v2.0.yaml"
        input_traj = _build_input_traj(
            dropdown_diy, frame_indices, h_positions, w_positions, height_ratio, width_ratio
        )
    idx_list = _parse_target_indices(target_indices)

    config_512 = OmegaConf.load(config_path)
    model_config_512 = config_512.pop("model", OmegaConf.create())
    args = argparse.Namespace(
        mode="base",
        savefps=save_fps,
        n_samples=1,
        ddim_steps=ddim_steps,
        ddim_eta=0.0,
        bs=1,
        fps=video_fps,
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_guidance_scale_temporal=None,
        cond_input=None,
        prompt_in=prompt_in,
        seed=seed,
        ddim_edit=ddim_edit,
        model_config_512=model_config_512,
        idx_list=idx_list,
        input_traj=input_traj,
        # Passed explicitly: at module level `radio_mode` is the gr.Radio
        # component, not the selected value.
        radio_mode=radio_mode,
    )
    print('GPU starts')
    video = infer_gpu_part(args)
    print('GPU ends')

    video = torch.clamp(video.float(), -1.0, 1.0)
    video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w
    grid = _video_to_grid(video, args.n_samples)

    video_path, video_bbox_path = check(radio_mode)
    _write_mp4(video_path, grid, args.savefps)
    if not is_ori:
        _draw_trajectory_boxes(grid, input_traj)
        _write_mp4(video_bbox_path, grid, args.savefps)
    return video_path, video_bbox_path


def _build_input_traj(num_keys, frame_indices, h_positions, w_positions, height_ratio, width_ratio):
    """Convert keyframe widget values into FreeTraj boxes ``[frame, h0, h1, w0, w1]``.

    Each (h, w) slider value in [0, 1] is scaled by the room left beside the box
    (``1 - ratio``), so ``h0 + height_ratio`` and ``w0 + width_ratio`` never exceed 1.
    """
    try:
        num_keys = int(num_keys)
    except (TypeError, ValueError):
        raise gr.Error("Please choose a demo trajectory or the number of keyframes first.") from None

    trajectory = []
    for i in range(num_keys):
        frame_text = frame_indices[i]
        try:
            frame = int(frame_text)
        except (TypeError, ValueError):
            raise gr.Error(f"Invalid frame index for keyframe #{i}: {frame_text!r}.") from None
        trajectory.append([frame, h_positions[i], w_positions[i]])
    trajectory.sort()
    print(trajectory)
    if not check_move(trajectory):
        print("Error trajectory.")

    h_remain = 1 - height_ratio
    w_remain = 1 - width_ratio
    input_traj = []
    for frame, h_pos, w_pos in trajectory:
        h_relative = h_pos * h_remain
        w_relative = w_pos * w_remain
        input_traj.append([frame, h_relative, h_relative + height_ratio, w_relative, w_relative + width_ratio])
    return input_traj


def _parse_target_indices(target_indices):
    """Parse 1-based prompt-word indices such as "1,2" into a list of ints.

    Empty, None or separator-only input falls back to ``[1, 2]``.
    """
    tokens = [tok.strip() for tok in (target_indices or "").split(',')]
    tokens = [tok for tok in tokens if tok]
    if not tokens:
        return [1, 2]
    try:
        return [int(tok) for tok in tokens]
    except ValueError:
        raise gr.Error(
            f"Target indices must be comma-separated integers, got {target_indices!r}."
        ) from None


def _video_to_grid(video, n_samples):
    """Tile each frame's samples into one image and convert to uint8 ``[t, H, W, 3]``.

    ``video`` is expected clamped to [-1, 1] with layout ``[t, n, c, h, w]``.
    """
    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=int(n_samples))
        for framesheet in video
    ]  # [3, 1*h, n*w]
    grid = torch.stack(frame_grids, dim=0)  # stack in temporal dim [t, 3, n*h, w]
    grid = (grid + 1.0) / 2.0
    return (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)


def _write_mp4(path, grid, fps):
    """Encode a uint8 ``[t, H, W, 3]`` frame tensor to an H.264 MP4 at CRF 10."""
    torchvision.io.write_video(
        path,
        grid,
        fps=fps,
        video_codec="h264",
        options={"crf": "10"},
    )


def _draw_trajectory_boxes(grid, input_traj):
    """Draw the planned box on every frame of ``grid`` in place, with a 3-px light-green border.

    The box size comes from the first keyframe of ``input_traj``. The per-frame
    position comes from ``plan_path(input_traj)``, whose row ``j`` is read as
    ``[h_start, ?, w_start, ...]`` fractions of the frame size.
    """
    box_h = input_traj[0][2] - input_traj[0][1]
    box_w = input_traj[0][4] - input_traj[0][3]
    paths = plan_path(input_traj)
    h_len = grid.shape[1]
    w_len = grid.shape[2]
    sub_h = int(box_h * h_len)
    sub_w = int(box_w * w_len)
    color = torch.tensor([127, 255, 127], dtype=grid.dtype)
    for j in range(grid.shape[0]):
        h_start = int(paths[j][0] * h_len)
        w_start = int(paths[j][2] * w_len)
        # Clamp so the border slices (start-1 .. end) stay inside the frame.
        h_end = min(h_len - 1, h_start + sub_h)
        w_end = min(w_len - 1, w_start + sub_w)
        h_start = max(1, h_start)
        w_start = max(1, w_start)
        grid[j, h_start - 1:h_end + 1, w_start - 1:w_start + 2, :] = color  # left edge
        grid[j, h_start - 1:h_end + 1, w_end - 2:w_end + 1, :] = color  # right edge
        grid[j, h_start - 1:h_start + 2, w_start - 1:w_end + 1, :] = color  # top edge
        grid[j, h_end - 2:h_end + 1, w_start - 1:w_end + 1, :] = color  # bottom edge


@spaces.GPU(duration=250)
def infer_gpu_part(args):
    """Load the VideoCrafter2 model on the GPU and sample one video for ``args``.

    ``args`` is the Namespace built by ``infer``. ``args.radio_mode == 'ori'``
    selects plain DDIM sampling. Any other value, including a missing
    attribute, runs FreeTraj sampling with ``args.idx_list`` and
    ``args.input_traj``.

    Returns:
        The first element of the sampled batch, detached and moved to CPU.
    """
    model = instantiate_from_config(args.model_config_512)
    model = model.cuda()
    model = load_model_checkpoint(model, ckpt_path_512)
    model.eval()

    seed = int.from_bytes(os.urandom(2), "big") if args.seed is None else args.seed
    print(f"Using seed: {seed}")
    seed_everything(seed)

    ## latent noise shape
    batch_size = 1
    noise_shape = [batch_size, model.channels, video_length, height // 8, width // 8]
    fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
    text_emb = model.get_learned_conditioning([args.prompt_in])
    cond = {"c_crossattn": [text_emb], "fps": fps}

    ## inference
    sampling_args = (
        model,
        cond,
        noise_shape,
        args.n_samples,
        args.ddim_steps,
        args.ddim_eta,
        args.unconditional_guidance_scale,
    )
    # Read the mode from args: the module-level `radio_mode` is the gr.Radio
    # component and never equals 'ori'.
    if getattr(args, "radio_mode", None) == 'ori':
        batch_samples = batch_ddim_sampling(*sampling_args, args=args)
    else:
        batch_samples = batch_ddim_sampling_freetraj(
            *sampling_args,
            idx_list=args.idx_list,
            input_traj=args.input_traj,
            args=args,
        )
    return batch_samples[0].detach().cpu()
# Sample prompts for gr.Examples; each row supplies only the first input (the prompt).
examples = [
    [sample_prompt]
    for sample_prompt in (
        "A squirrel jumping from one tree to another.",
        "A bear climbing down a tree after spotting a threat.",
        "A corgi running on the grassland on the grassland.",
        "A barrel floating in a river.",
        "A horse galloping on a street.",
        "A majestic eagle soaring high above the treetops, surveying its territory.",
    )
]
# Custom stylesheet handed to gr.Blocks(css=css): centers the main column
# (#col-container, max 1024px wide), and styles the share button and the footer,
# including dark-theme variants of the footer.
css = """
#col-container {max-width: 1024px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
.animate-spin {
animation: spin 1s linear infinite;
}
#share-btn-container {
display: flex;
padding-left: 0.5rem !important;
padding-right: 0.5rem !important;
background-color: #000000;
justify-content: center;
align-items: center;
border-radius: 9999px !important;
max-width: 15rem;
height: 36px;
}
div#share-btn-container > div {
flex-direction: row;
background: black;
align-items: center;
}
#share-btn-container:hover {
background-color: #060606;
}
#share-btn {
all: initial;
color: #ffffff;
font-weight: 600;
cursor:pointer;
font-family: 'IBM Plex Sans', sans-serif;
margin-left: 0.5rem !important;
padding-top: 0.5rem !important;
padding-bottom: 0.5rem !important;
right:0;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
#share-btn-container.hidden {
display: none!important;
}
img[src*='#center'] {
display: inline-block;
margin: unset;
}
.footer {
margin-bottom: 45px;
margin-top: 10px;
text-align: center;
border-bottom: 1px solid #e5e5e5;
}
.footer>p {
font-size: .8rem;
display: inline-block;
padding: 0 10px;
transform: translateY(10px);
background: white;
}
.dark .footer {
border-color: #303030;
}
.dark .footer>p {
background: #0b0f19;
}
"""
def mode_update(mode):
    """Toggle the demo-preset row and the DIY-keyframe row for the chosen trajectory mode.

    Returns visibility updates for ``[row_demo, row_diy]``. 'demo' shows the
    first row, 'diy' shows the second, and any other value (e.g. 'ori') hides both.
    """
    show_demo = mode == 'demo'
    show_diy = not show_demo and mode == 'diy'
    return [gr.Row(visible=show_demo), gr.Row(visible=show_diy)]
def keyframe_update(num):
    """Show the first ``num`` keyframe rows and hide the rest.

    Anything that is not exactly an ``int`` (e.g. an unset dropdown) hides every
    row. For ``0 <= num <= MAX_KEYS`` the result has ``MAX_KEYS`` row updates.
    """
    shown = num if type(num) is int else 0
    visible_rows = [gr.Row(visible=True) for _ in range(shown)]
    hidden_rows = [gr.Row(visible=False) for _ in range(MAX_KEYS - shown)]
    return visible_rows + hidden_rows
# Keyframes of every demo trajectory as (frame index, height position, width position).
# Positions are slider values in [0, 1]. This single table drives the keyframe count
# and the frame/height/width widget updates, so they cannot disagree.
_DEMO_PRESETS = {
    'topleft->bottomright': [
        (0, 0.1, 0.1), (15, 0.9, 0.9),
    ],
    'bottomleft->topright': [
        (0, 0.9, 0.1), (15, 0.1, 0.9),
    ],
    'topleft->bottomleft->bottomright': [
        (0, 0.1, 0.1), (9, 0.9, 0.1), (15, 0.9, 0.9),
    ],
    'bottomright->topright->topleft': [
        (0, 0.9, 0.9), (6, 0.1, 0.9), (15, 0.1, 0.1),
    ],
    '"V"': [
        (0, 0.1, 0.1), (7, 0.9, 0.8/15*7 + 0.1), (8, 0.9, 0.8/15*8 + 0.1), (15, 0.1, 0.9),
    ],
    '"^"': [
        (0, 0.9, 0.9), (7, 0.1, 0.8/15*8 + 0.1), (8, 0.1, 0.8/15*7 + 0.1), (15, 0.9, 0.1),
    ],
    'left->right->left->right': [
        (0, 0.5, 0.1), (5, 0.5, 0.9), (10, 0.5, 0.1), (15, 0.5, 0.9),
    ],
    'triangle': [
        (0, 0.1, 0.5), (5, 0.9, 0.9), (10, 0.9, 0.1), (15, 0.1, 0.5),
    ],
}


def _demo_keyframes(mode):
    """Return the preset keyframes for ``mode``, or an empty tuple for an unknown/unset mode."""
    return _DEMO_PRESETS.get(mode, ())


def demo_update(mode):
    """Return the number of keyframes the demo trajectory ``mode`` uses (0 if unknown)."""
    return len(_demo_keyframes(mode))


def demo_update_frame(mode):
    """Fill the frame-index boxes for demo ``mode``; unused boxes get a bare gr.Text() update."""
    keys = _demo_keyframes(mode)
    return (
        [gr.Text(value=frame) for frame, _, _ in keys]
        + [gr.Text() for _ in range(MAX_KEYS - len(keys))]
    )


def demo_update_h(mode):
    """Set the height sliders for demo ``mode``; unused sliders get a bare gr.Slider() update."""
    keys = _demo_keyframes(mode)
    return (
        [gr.Slider(value=h_pos) for _, h_pos, _ in keys]
        + [gr.Slider() for _ in range(MAX_KEYS - len(keys))]
    )


def demo_update_w(mode):
    """Set the width sliders for demo ``mode``; unused sliders get a bare gr.Slider() update."""
    keys = _demo_keyframes(mode)
    return (
        [gr.Slider(value=w_pos) for _, _, w_pos in keys]
        + [gr.Slider() for _ in range(MAX_KEYS - len(keys))]
    )
def plot_update(*positions):
    """Redraw the keyframe trajectory preview.

    ``positions`` holds ``MAX_KEYS`` frame indices, ``MAX_KEYS`` height
    positions, ``MAX_KEYS`` width positions, and finally the number of active
    keyframes. An empty plot is returned while fewer than two keyframes are
    active or while an active frame-index box does not hold an integer.
    """
    key_length = positions[-1]
    if not isinstance(key_length, int) or key_length < 2:
        return gr.Plot(label="Trajectory")
    # Never read past the frame-index slots into the height/width values.
    key_length = min(key_length, MAX_KEYS)
    try:
        frame_indices = [int(i) for i in positions[:key_length]]
    except (TypeError, ValueError):
        # A frame-index box is still empty or being edited; wait for valid input.
        return gr.Plot(label="Trajectory")
    h_positions = positions[MAX_KEYS:MAX_KEYS + key_length]
    w_positions = positions[2 * MAX_KEYS:2 * MAX_KEYS + key_length]
    frame_indices, h_positions, w_positions = zip(*sorted(zip(frame_indices, h_positions, w_positions)))

    plt.cla()
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    # Image-style axes: origin at the top-left, matching height/width slider semantics.
    plt.gca().invert_yaxis()
    plt.gca().xaxis.tick_top()
    plt.plot(w_positions, h_positions, linestyle='-', marker='o', markerfacecolor='r')
    return gr.Plot(label="Trajectory", value=plt)
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
<h1 style="text-align: center;">FreeTraj</h1>
<p style="text-align: center;">
Tuning-Free Trajectory Control in Video Diffusion Models
</p>
<p style="text-align: center;">
<a href="https://arxiv.org/abs/2406.16863" target="_blank"><b>[arXiv]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
<a href="http://haonanqiu.com/projects/FreeTraj.html" target="_blank"><b>[Project Page]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
<a href="https://github.com/arthur-qiu/FreeTraj" target="_blank"><b>[Code]</b></a>
</p>
"""
        )

        # Component handles for the MAX_KEYS keyframe slots, filled by the loop below.
        keyframes = []
        frame_indices = []
        h_positions = []
        w_positions = []

        with gr.Row():
            video_result = gr.Video(label="Video Output")
            video_result_bbox = gr.Video(label="Video Output with BBox")

        with gr.Group():
            with gr.Row():
                prompt_in = gr.Textbox(
                    label="Prompt",
                    placeholder="A corgi running on the grassland on the grassland.",
                    scale=5,
                )
                target_indices = gr.Textbox(
                    label="Target Indices (1 for the first word, necessary!)",
                    placeholder="1,2",
                    scale=2,
                )
            with gr.Row():
                radio_mode = gr.Radio(label='Trajectory Mode', choices=['demo', 'diy', 'ori'], scale=1)
                height_ratio = gr.Slider(
                    label='Height Ratio of BBox',
                    minimum=0.2, maximum=0.4, step=0.01, value=0.3, scale=1,
                )
                width_ratio = gr.Slider(
                    label='Width Ratio of BBox',
                    minimum=0.2, maximum=0.4, step=0.01, value=0.3, scale=1,
                )
            with gr.Row(visible=False) as row_demo:
                dropdown_demo = gr.Dropdown(
                    label="Demo Trajectory",
                    choices=[
                        'topleft->bottomright',
                        'bottomleft->topright',
                        'topleft->bottomleft->bottomright',
                        'bottomright->topright->topleft',
                        '"V"',
                        '"^"',
                        'left->right->left->right',
                        'triangle',
                    ],
                )
            with gr.Row(visible=False) as row_diy:
                dropdown_diy = gr.Dropdown(label="Number of Keyframes", choices=range(2, MAX_KEYS + 1))

            # One hidden row per keyframe slot; keyframe_update toggles their visibility.
            for key_idx in range(MAX_KEYS):
                with gr.Row(visible=False) as row:
                    gr.Textbox(
                        value=f"Keyframe #{key_idx}",
                        interactive=False,
                        container=False,
                        lines=3,
                        scale=1,
                    )
                    frame_ids = gr.Textbox(
                        value=None,
                        label=f"Frame Indices #{key_idx}",
                        interactive=True,
                        scale=2,
                    )
                    h_position = gr.Slider(
                        label='Position in Height', minimum=0.0, maximum=1.0, step=0.01, scale=2,
                    )
                    w_position = gr.Slider(
                        label='Position in Width', minimum=0.0, maximum=1.0, step=0.01, scale=2,
                    )
                frame_indices.append(frame_ids)
                h_positions.append(h_position)
                w_positions.append(w_position)
                keyframes.append(row)

            # Picking a demo sets the keyframe count, then fills frames, heights and widths.
            dropdown_demo.change(demo_update, dropdown_demo, dropdown_diy)
            dropdown_diy.change(keyframe_update, dropdown_diy, keyframes)
            for fill_fn, targets in (
                (demo_update_frame, frame_indices),
                (demo_update_h, h_positions),
                (demo_update_w, w_positions),
            ):
                dropdown_demo.change(fill_fn, dropdown_demo, targets)
            radio_mode.change(mode_update, radio_mode, [row_demo, row_diy])

            traj_plot = gr.Plot(label="Trajectory")
            # Moving any height/width slider re-renders the trajectory preview.
            for position_slider in h_positions + w_positions:
                position_slider.change(
                    plot_update,
                    frame_indices + h_positions + w_positions + [dropdown_diy],
                    traj_plot,
                )

        with gr.Row():
            with gr.Accordion('Useful FreeTraj Parameters (feel free to adjust these parameters based on your prompt): ', open=True):
                with gr.Row():
                    ddim_edit = gr.Slider(
                        label='Editing Steps (larger for better control while losing some quality)',
                        minimum=0, maximum=12, step=1, value=6,
                    )
                    seed = gr.Slider(label='Random Seed', minimum=0, maximum=10000, step=1, value=123)
        with gr.Row():
            with gr.Accordion('Useless FreeTraj Parameters (mostly no need to adjust): ', open=False):
                with gr.Row():
                    ddim_steps = gr.Slider(label='DDIM Steps', minimum=5, maximum=50, step=1, value=50)
                    unconditional_guidance_scale = gr.Slider(
                        label='Unconditional Guidance Scale',
                        minimum=1.0, maximum=20.0, step=0.1, value=12.0,
                    )
                with gr.Row():
                    video_fps = gr.Slider(
                        label='Video FPS (larger for quicker motion)',
                        minimum=8, maximum=36, step=4, value=16,
                    )
                    save_fps = gr.Slider(label='Save FPS', minimum=1, maximum=30, step=1, value=10)

        with gr.Row():
            submit_btn = gr.Button("Generate", variant='primary')
        with gr.Row():
            check_btn = gr.Button("Check Existing Results (in case of the connection lost)", variant='secondary')

        # Must match the positional order that infer(*user_args) unpacks.
        infer_inputs = [
            prompt_in, target_indices, ddim_edit, seed, ddim_steps,
            unconditional_guidance_scale, video_fps, save_fps,
            height_ratio, width_ratio, radio_mode, dropdown_diy,
            *frame_indices, *h_positions, *w_positions,
        ]
        with gr.Row():
            gr.Examples(label='Sample Prompts', examples=examples, inputs=list(infer_inputs))

        demo_list = ['0026_0_0.4_0.4.gif', '0047_1_0.4_0.3.gif', '0051_1_0.4_0.4.gif']
        # NOTE(review): demo_pick looks unused in this file; confirm before removing.
        demo_pick = random.randint(0, len(demo_list) - 1)
        with gr.Row():
            for demo_file in demo_list:
                gr.Image(show_label=False, show_download_button=False, value='assets/' + demo_file)

        with gr.Row():
            gr.Markdown(
                """
<h2 style="text-align: center;">Hints</h2>
<p style="text-align: center;">
1. Choose trajectory mode <b>"ori"</b> to see whether the prompt works on the pre-trained model.
</p>
<p style="text-align: center;">
2. Adjust the prompt or random seed to get a qualified video.
</p>
<p style="text-align: center;">
3. Choose trajectory mode <b>"demo"</b> to see whether <b>FreeTraj</b> works or not.
</p>
<p style="text-align: center;">
4. Choose trajectory mode <b>"diy"</b> to plan new trajectory. It may fail in some extreme cases.
</p>
"""
            )

    submit_btn.click(
        fn=infer,
        inputs=list(infer_inputs),
        outputs=[video_result, video_result_bbox],
        api_name="generate",
    )
    check_btn.click(
        fn=check,
        inputs=[radio_mode],
        outputs=[video_result, video_result_bbox],
        api_name="check",
    )

demo.queue(max_size=8).launch(show_api=True)