Vadhwid / app.py
QinOwen
change-weights-path
869596d
import gradio as gr
import os
import spaces
import sys
from copy import deepcopy
sys.path.append('./VADER-VideoCrafter/scripts/main')
sys.path.append('./VADER-VideoCrafter/scripts')
sys.path.append('./VADER-VideoCrafter')
from train_t2v_lora import main_fn, setup_model
# Rows shown in the gr.Examples gallery. Each row follows the order of the
# `inputs` list given to gr.Examples: prompt, lora_model, lora_rank, seed,
# height, width, guidance scale, DDIM steps, DDIM eta, frames, save fps.
# Only the first four values differ between rows; the remaining seven
# (384, 512, 12.0, 25, 1.0, 24, 10) are identical for every example.
examples = [
    [prompt_text, reward_lora, rank, rng_seed, 384, 512, 12.0, 25, 1.0, 24, 10]
    for prompt_text, reward_lora, rank, rng_seed in (
        ("Fairy and Magical Flowers: A fairy tends to enchanted, glowing flowers.",
         'huggingface-hps-aesthetic', 8, 901),
        ("A cat playing an electric guitar in a loft with industrial-style decor and soft, multicolored lights.",
         'huggingface-hps-aesthetic', 8, 208),
        ("A raccoon playing a guitar under a blossoming cherry tree.",
         'huggingface-hps-aesthetic', 8, 180),
        ("A raccoon playing an electric bass in a garage band setting.",
         'huggingface-hps-aesthetic', 8, 400),
        ("A talking bird with shimmering feathers and a melodious voice finds a legendary treasure, guiding through enchanted forests, ancient ruins, and mystical challenges.",
         "huggingface-pickscore", 16, 200),
        ("A snow princess stands on the balcony of her ice castle, her hair adorned with delicate snowflakes, overlooking her serene realm.",
         "huggingface-pickscore", 16, 400),
        ("A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.",
         "huggingface-pickscore", 16, 800),
    )
]
# Build the model once at import time with setup_model() (imported from
# train_t2v_lora). gradio_main_fn passes a deep copy of this object to
# main_fn on every request rather than the object itself.
model = setup_model()
@spaces.GPU(duration=180)
def gradio_main_fn(prompt, lora_model, lora_rank, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
                   frames, savefps):
    """Run one text-to-video inference and return the generated video path.

    This is the callback for the "Run Inference" button and for gr.Examples.
    Slider values are coerced to the numeric types ``main_fn`` is called
    with here (ints for sizes/counts, floats for guidance scale and eta).

    Args:
        prompt: Text prompt describing the video.
        lora_model: Name of the selected VADER LoRA model.
        lora_rank: LoRA rank (converted with ``int``).
        seed: Random seed (converted with ``int``).
        height: Video height (converted with ``int``).
        width: Video width (converted with ``int``).
        unconditional_guidance_scale: Guidance scale (converted with ``float``).
        ddim_steps: Number of DDIM steps (converted with ``int``).
        ddim_eta: DDIM eta (converted with ``float``).
        frames: Number of frames (converted with ``int``).
        savefps: Frames per second of the saved video (converted with ``int``).

    Returns:
        Whatever ``main_fn`` returns; it is wired to the ``gr.Video`` output,
        so it is expected to be a video file path.

    Raises:
        gr.Error: If the module-level ``model`` is ``None``.
    """
    # The output component is a gr.Video, so returning a plain message string
    # would be interpreted as a (non-existent) video path. Raising gr.Error
    # surfaces the message to the user instead.
    if model is None:
        raise gr.Error("Model is not loaded. Please load the model first.")

    # main_fn gets a deep copy, never the shared module-level model.
    # NOTE(review): presumably main_fn mutates the model it receives
    # (e.g. when applying LoRA weights) — confirm against train_t2v_lora.
    return main_fn(
        prompt=prompt,
        lora_model=lora_model,
        lora_rank=int(lora_rank),
        seed=int(seed),
        height=int(height),
        width=int(width),
        unconditional_guidance_scale=float(unconditional_guidance_scale),
        ddim_steps=int(ddim_steps),
        ddim_eta=float(ddim_eta),
        frames=int(frames),
        savefps=int(savefps),
        model=deepcopy(model),
    )
def reset_fn():
    """Return the default value for every control the Reset button restores.

    Returns:
        An 11-tuple ordered like the ``outputs`` list of ``reset_btn.click``:
        prompt, seed, height, width, guidance scale, DDIM steps, DDIM eta,
        frames, LoRA rank, save fps, LoRA model name.
    """
    default_prompt = "A brown dog eagerly eats from a bowl in a kitchen."
    default_seed = 200
    default_height, default_width = 384, 512
    default_guidance, default_steps, default_eta = 12.0, 25, 1.0
    default_frames = 24
    default_rank = 16
    default_fps = 10
    default_model = "huggingface-pickscore"
    return (
        default_prompt,
        default_seed,
        default_height,
        default_width,
        default_guidance,
        default_steps,
        default_eta,
        default_frames,
        default_rank,
        default_fps,
        default_model,
    )
def update_lora_rank(lora_model):
    """Pick the LoRA rank that goes with the selected VADER model.

    Args:
        lora_model: Current value of the model dropdown.

    Returns:
        A ``gr.update`` whose value is 16 for "huggingface-pickscore" and 8
        for anything else ("huggingface-hps-aesthetic" or "Base Model").
    """
    # Only the PickScore entry maps to rank 16; every other selection uses 8.
    matched_rank = 16 if lora_model == "huggingface-pickscore" else 8
    return gr.update(value=matched_rank)
def update_dropdown(lora_rank):
    """Pick the VADER model name that goes with the selected LoRA rank.

    Args:
        lora_rank: Current value of the LoRA rank slider.

    Returns:
        A ``gr.update`` whose value is "huggingface-pickscore" for rank 16,
        "huggingface-hps-aesthetic" for rank 8, and "Base Model" otherwise.
    """
    rank_to_model = (
        (16, "huggingface-pickscore"),
        (8, "huggingface-hps-aesthetic"),
    )
    for known_rank, model_name in rank_to_model:
        if lora_rank == known_rank:
            return gr.update(value=model_name)
    # Any other rank (the original comment mentions 0) means no LoRA.
    return gr.update(value="Base Model")
# Stylesheet handed to gr.Blocks(css=...) below.
#   #centered      - elem_id of the two main rows: flex-centered, 60% wide.
#   #params        - elem_id of the left-hand column holding the model
#                    dropdown, rank slider, prompt box and run button; the
#                    rules let its children grow to fill the column.
#   #image-upload  - elem_id of the output gr.Video component.
#   .column-centered is defined here, but no component visible in this file
#   assigns that class.
custom_css = """
#centered {
display: flex;
justify-content: center;
width: 60%;
margin: 0 auto;
}
.column-centered {
display: flex;
flex-direction: column;
align-items: center;
width: 60%;
}
#image-upload {
flex-grow: 1;
}
#params .tabs {
display: flex;
flex-direction: column;
flex-grow: 1;
}
#params .tabitem[style="display: block;"] {
flex-grow: 1;
display: flex !important;
}
#params .gap {
flex-grow: 1;
}
#params .form {
flex-grow: 1 !important;
}
#params .form > :last-child{
flex-grow: 1;
}
"""
# Build the demo page: header (title, authors, links), the prompt/output row,
# the sampling-parameter row, event wiring, and the examples gallery.
# The HTML string bodies are kept flush-left so their content is unchanged.
with gr.Blocks(css=custom_css) as demo:
    # ---- Header -----------------------------------------------------------
    with gr.Row():
        with gr.Column():
            # Page title.
            gr.HTML(
                """
<h1 style='text-align: center; font-size: 3.2em; margin-bottom: 0.5em; font-family: Arial, sans-serif; margin: 20px;'>
Video Diffusion Alignment via Reward Gradient
</h1>
"""
            )
            # Author list and affiliation.
            # NOTE(review): the mailto: link below holds what looks like a
            # scrubbed placeholder address rather than a real one — confirm
            # and restore the intended e-mail.
            gr.HTML(
                """
<style>
body {
font-family: Arial, sans-serif;
text-align: center;
margin: 50px;
}
a {
text-decoration: none !important;
color: black !important;
}
</style>
<body>
<div style="font-size: 1.4em; margin-bottom: 0.5em; ">
<a href="https://mihirp1998.github.io">Mihir Prabhudesai</a><sup>*</sup>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
<a href="https://russellmendonca.github.io/">Russell Mendonca</a><sup>*</sup>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
<a href="mailto: [email protected]">Zheyang Qin</a><sup>*</sup>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
<a href="https://www.cs.cmu.edu/~katef/">Katerina Fragkiadaki</a><sup></sup>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;
<a href="https://www.cs.cmu.edu/~dpathak/">Deepak Pathak</a><sup></sup>
</div>
<div style="font-size: 1.3em; font-style: italic;">
Carnegie Mellon University
</div>
</body>
"""
            )
            # Paper / Website / Code link bar. The "Website" anchor was
            # missing its closing </a>, which nested the "Code" link inside
            # it; the closing tag is added below.
            gr.HTML(
                """
<head>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css">
<style>
.button-container {
display: flex;
justify-content: center;
gap: 10px;
margin-top: 10px;
}
.button-container a {
display: inline-flex;
align-items: center;
padding: 10px 20px;
border-radius: 30px;
border: 1px solid #ccc;
text-decoration: none;
color: #333 !important;
font-size: 16px;
text-decoration: none !important;
}
.button-container a i {
margin-right: 8px;
}
</style>
</head>
<div class="button-container">
<a href="https://arxiv.org/abs/2407.08737" class="btn btn-outline-primary">
<i class="fa-solid fa-file-pdf"></i> Paper
</a>
<a href="https://vader-vid.github.io/" class="btn btn-outline-danger">
<i class="fa-solid fa-video"></i> Website
</a>
<a href="https://github.com/mihirp1998/VADER" class="btn btn-outline-secondary">
<i class="fa-brands fa-github"></i> Code
</a>
</div>
"""
            )

    # ---- Model choice, prompt, and output video ----------------------------
    with gr.Row(elem_id="centered"):
        with gr.Column(elem_id="params"):
            lora_model = gr.Dropdown(
                label="VADER Model",
                choices=["huggingface-pickscore", "huggingface-hps-aesthetic"],
                value="huggingface-pickscore"
            )
            # step == (maximum - minimum), so the slider offers only 8 or 16.
            lora_rank = gr.Slider(minimum=8, maximum=16, label="LoRA Rank", step=8, value=16)
            prompt = gr.Textbox(placeholder="Enter prompt text here", lines=4, label="Text Prompt",
                                value="A brown dog eagerly eats from a bowl in a kitchen.")
            run_btn = gr.Button("Run Inference")

        with gr.Column():
            output_video = gr.Video(elem_id="image-upload")

    # ---- Sampling parameters ------------------------------------------------
    # NOTE(review): several sliders below allow 0 (height, width, frames,
    # save fps, DDIM steps) — confirm main_fn copes with those values or
    # raise the minimums.
    with gr.Row(elem_id="centered"):
        with gr.Column():
            seed = gr.Slider(minimum=0, maximum=65536, label="Seed", step=1, value=200)

            with gr.Row():
                height = gr.Slider(minimum=0, maximum=512, label="Height", step=16, value=384)
                width = gr.Slider(minimum=0, maximum=512, label="Width", step=16, value=512)

            with gr.Row():
                frames = gr.Slider(minimum=0, maximum=50, label="Frames", step=1, value=24)
                savefps = gr.Slider(minimum=0, maximum=30, label="Save FPS", step=1, value=10)

            with gr.Row():
                DDIM_Steps = gr.Slider(minimum=0, maximum=50, label="DDIM Steps", step=1, value=25)
                unconditional_guidance_scale = gr.Slider(minimum=0, maximum=50, label="Guidance Scale", step=0.1, value=12.0)
                DDIM_Eta = gr.Slider(minimum=0, maximum=1, label="DDIM Eta", step=0.01, value=1.0)

            # reset button
            reset_btn = gr.Button("Reset")

            # The outputs order must match the tuple returned by reset_fn.
            reset_btn.click(fn=reset_fn, outputs=[prompt, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, lora_rank, savefps, lora_model])

    # ---- Event wiring -------------------------------------------------------
    # The inputs order must match gradio_main_fn's positional parameters.
    run_btn.click(fn=gradio_main_fn,
                  inputs=[prompt, lora_model, lora_rank,
                          seed, height, width, unconditional_guidance_scale,
                          DDIM_Steps, DDIM_Eta, frames, savefps],
                  outputs=output_video
                  )

    # Keep the model dropdown and the rank slider in sync in both directions.
    lora_model.change(fn=update_lora_rank, inputs=lora_model, outputs=lora_rank)
    lora_rank.change(fn=update_dropdown, inputs=lora_rank, outputs=lora_model)

    # Example gallery; each row of `examples` follows the `inputs` order here.
    gr.Examples(examples=examples,
                inputs=[prompt, lora_model, lora_rank, seed,
                        height, width, unconditional_guidance_scale,
                        DDIM_Steps, DDIM_Eta, frames, savefps],
                outputs=output_video,
                fn=gradio_main_fn,
                run_on_click=False,
                cache_examples="lazy",
                )

demo.launch(share=True)