# Vista / app.py
# Leonard Bruns
# Add Vista example
# d323598
"""Gradio interface for Vista model."""
from __future__ import annotations
import glob
import os
import queue
import threading
import gradio as gr
import gradio_rerun
import rerun as rr
import spaces
import vista
@spaces.GPU(duration=400)
@rr.thread_local_stream("Vista")
def generate_gradio(
    first_frame_file_name: str,
    n_rounds: float = 3,
    n_steps: float = 10,
    height=576,
    width=1024,
    n_frames=25,
    cfg_scale=2.5,
    cond_aug=0.0,
):
    """Run Vista sampling on a worker thread and stream Rerun data to Gradio.

    This is a generator. Each ``yield`` returns the bytes read from the Rerun
    binary stream (``rr.binary_stream()``) since the previous read. The UI
    wires the output to a ``gradio_rerun.Rerun(streaming=True)`` viewer.

    Args:
        first_frame_file_name: Path of the conditioning image (the UI's
            ``gr.Image`` uses ``type="filepath"``).
        n_rounds: Number of segments to generate. Cast to ``int`` because the
            UI passes slider values.
        n_steps: Diffusion steps per segment. Cast to ``int`` likewise.
        height, width, n_frames, cfg_scale, cond_aug: Forwarded unchanged to
            ``vista.run_sampling``.

    Yields:
        Chunks of the Rerun binary stream: first the blueprint, then one chunk
        per logged entity.

    Raises:
        Exception: Whatever ``vista.run_sampling`` raised on the worker thread.
            It is re-raised here after the worker finishes, instead of leaving
            this generator blocked on the queue forever.
    """
    # ``model`` is the module-level model created at import time. It is only
    # read here, so no ``global`` declaration is needed.
    n_rounds = int(n_rounds)
    n_steps = int(n_steps)

    # The worker pushes ``(entity_path, entity, times)`` tuples onto this queue
    # and the string "done" when finished. Logging happens on this thread,
    # presumably because the recording set up by ``rr.thread_local_stream`` is
    # bound to the calling thread, so the worker cannot log into it directly.
    log_queue = queue.SimpleQueue()
    worker_errors = []

    stream = rr.binary_stream()
    rr.send_blueprint(vista.generate_blueprint(n_rounds))
    yield stream.read()

    def _sample():
        # Always post the end sentinel, even if sampling fails. Otherwise the
        # consumer loop below would block on ``log_queue.get()`` forever. An
        # extra "done" (if run_sampling already sent one) is harmless because
        # the loop stops at the first one and the queue is discarded.
        try:
            vista.run_sampling(
                log_queue,
                first_frame_file_name,
                height,
                width,
                n_rounds,
                n_frames,
                n_steps,
                cfg_scale,
                cond_aug,
                model,
            )
        except Exception as exc:
            worker_errors.append(exc)
        finally:
            log_queue.put("done")

    worker = threading.Thread(target=_sample)
    worker.start()

    while True:
        msg = log_queue.get()
        if msg == "done":
            break
        entity_path, entity, times = msg
        # Each message carries its full set of timeline values, so clear any
        # values left over from the previous message before applying them.
        rr.reset_time()
        for timeline, t in times:
            if isinstance(t, int):
                rr.set_time_sequence(timeline, t)
            else:
                rr.set_time_seconds(timeline, t)
        rr.log(entity_path, entity)
        yield stream.read()

    worker.join()
    if worker_errors:
        raise worker_errors[0]
# Build the model once at import time. ``generate_gradio`` reads this global
# on every request, so it is not reloaded per click.
model = vista.create_model()

# Resolve bundled assets (stylesheet, example images) relative to this file
# rather than the current working directory, so the app behaves the same no
# matter where it is launched from.
_APP_DIR = os.path.dirname(__file__)

with gr.Blocks(css=os.path.join(_APP_DIR, "style.css")) as demo:
    gr.Markdown(
        """
# Vista: A Generalizable Driving World Model with High Fidelity and Versatile Controllability

[Shenyuan Gao](https://github.com/Little-Podi), [Jiazhi Yang](https://scholar.google.com/citations?user=Ju7nGX8AAAAJ&hl=en), [Li Chen](https://scholar.google.com/citations?user=ulZxvY0AAAAJ&hl=en), [Kashyap Chitta](https://kashyap7x.github.io/), [Yihang Qiu](https://scholar.google.com/citations?user=qgRUOdIAAAAJ&hl=en), [Andreas Geiger](https://www.cvlibs.net/), [Jun Zhang](https://eejzhang.people.ust.hk/), [Hongyang Li](https://lihongyang.info/)

This is a demo of the [Vista model](https://github.com/OpenDriveLab/Vista), a driving world model that can be used to simulate a variety of driving scenarios. This demo uses [Rerun](https://rerun.io/)'s custom [gradio component](https://www.gradio.app/custom-components/gallery?id=radames%2Fgradio_rerun) to livestream the model's output and show intermediate results.

[📜technical report](https://arxiv.org/abs/2405.17398), [🎬video demos](https://vista-demo.github.io/), [🤗model weights](https://huggingface.co/OpenDriveLab/Vista)

Note that the GPU time is limited to 400 seconds per run. If you need more time, you can run the model locally or on your own server.
"""
    )

    # Conditioning image, passed to generate_gradio as a file path.
    first_frame = gr.Image(sources="upload", type="filepath")

    example_dir_path = os.path.join(_APP_DIR, "example_images")
    example_file_paths = sorted(glob.glob(os.path.join(example_dir_path, "*.*")))
    example_gallery = gr.Examples(
        examples=example_file_paths,
        inputs=first_frame,
        cache_examples=False,
    )

    btn = gr.Button("Generate video")
    num_rounds = gr.Slider(
        label="Segments",
        info="Number of 25 frame segments to generate. Higher values lead to longer videos. Try to keep the product of segments and steps below 30 to avoid running out of time.",
        minimum=1,
        maximum=5,
        value=2,
        step=1,
    )
    num_steps = gr.Slider(
        label="Diffusion Steps",
        info="Number of diffusion steps per segment. Higher values lead to more detailed videos. Try to keep the product of segments and steps below 30 to avoid running out of time.",
        minimum=1,
        maximum=50,
        value=15,
        step=1,
    )

    with gr.Row():
        # streaming=True lets the viewer consume the chunks yielded by
        # generate_gradio incrementally.
        viewer = gradio_rerun.Rerun(streaming=True)

    btn.click(
        generate_gradio,
        inputs=[first_frame, num_rounds, num_steps],
        outputs=[viewer],
    )

if __name__ == "__main__":
    demo.launch()