VideoLLaMA3 / app.py
lixin4ever's picture
Update demo (#4)
c03cca7 verified
raw
history blame
7.48 kB
import os
import os.path as osp
import gradio as gr
import spaces
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
# HTML banner rendered at the top of the page via gr.Markdown.
# The markup is kept flush-left on purpose: the string is passed to a Markdown
# renderer, where indented lines could be treated as a code block.
HEADER = ("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://github.com/DAMO-NLP-SG/VideoLLaMA3" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
<img src="https://github.com/DAMO-NLP-SG/VideoLLaMA3/blob/main/assets/logo.png?raw=true" alt="VideoLLaMA 3 🔥🚀🔥" style="max-width: 120px; height: auto;">
</a>
<div>
<h1>VideoLLaMA 3: Frontier Multimodal Foundation Models for Video Understanding</h1>
<h5 style="margin: 0;">If this demo pleases you, please give us a star ⭐ on Github or 💖 on this space.</h5>
</div>
</div>
<div style="display: flex; justify-content: center; margin-top: 10px;">
<a href="https://github.com/DAMO-NLP-SG/VideoLLaMA3"><img src='https://img.shields.io/badge/Github-VideoLLaMA3-9C276A' style="margin-right: 5px;"></a>
<a href="https://arxiv.org/pdf/2501.13106"><img src="https://img.shields.io/badge/Arxiv-2501.13106-AD1C18" style="margin-right: 5px;"></a>
<a href="https://huggingface.co./collections/DAMO-NLP-SG/videollama3-678cdda9281a0e32fe79af15"><img src="https://img.shields.io/badge/🤗-Checkpoints-ED5A22.svg" style="margin-right: 5px;"></a>
<a href="https://github.com/DAMO-NLP-SG/VideoLLaMA3/stargazers"><img src="https://img.shields.io/github/stars/DAMO-NLP-SG/VideoLLaMA3.svg?style=social"></a>
</div>
""")
# Device that the model weights and every input tensor are moved to.
device = "cuda"

# Hub id shared by the model weights and the matching processor.
_checkpoint = "DAMO-NLP-SG/VideoLLaMA3-7B"

# trust_remote_code=True lets transformers execute the modelling code that
# ships with the checkpoint repository.
model = AutoModelForCausalLM.from_pretrained(
    _checkpoint,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
model.to(device)

processor = AutoProcessor.from_pretrained(_checkpoint, trust_remote_code=True)
# Folder scanned once at start-up for media offered as clickable examples.
example_dir = "./examples"

# Suffix tuples are handed directly to str.endswith, so they must stay tuples.
image_formats = ("png", "jpg", "jpeg")
video_formats = ("mp4",)

# Each entry is a one-element list: one row with a single input value.
image_examples, video_examples = [], []

# os.listdir raises when the folder is absent, so only scan a directory that
# actually exists; without it the demo simply starts with no example rows.
if example_dir is not None and osp.isdir(example_dir):
    # Directory listing order is arbitrary; sort it so the example rows appear
    # in the same order on every start.
    example_files = [
        osp.join(example_dir, f) for f in sorted(os.listdir(example_dir))
    ]
    for example_file in example_files:
        if example_file.endswith(image_formats):
            image_examples.append([example_file])
        elif example_file.endswith(video_formats):
            video_examples.append([example_file])
def _on_video_upload(messages, video):
if video is not None:
# messages.append({"role": "user", "content": gr.Video(video)})
messages.append({"role": "user", "content": {"path": video}})
return messages, None
def _on_image_upload(messages, image):
if image is not None:
# messages.append({"role": "user", "content": gr.Image(image)})
messages.append({"role": "user", "content": {"path": image}})
return messages, None
def _on_text_submit(messages, text):
messages.append({"role": "user", "content": text})
return messages, ""
# NOTE(review): presumably reserves a GPU for up to 120 s per call on the
# hosting platform - confirm against the `spaces` package documentation.
@spaces.GPU(duration=120)
def _predict(messages, input_text, do_sample, temperature, top_p, max_new_tokens,
             fps, max_frames):
    """Stream the model's reply to the pending user turn into the chat history.

    Consecutive user entries (text and media) are merged into one multi-part
    user turn, the conversation is run through the processor, and generation
    happens on a worker thread while this generator relays the decoded text.

    Args:
        messages: Chat history as a list of ``{"role", "content"}`` dicts.
            ``content`` is either a string or a media entry. The list is
            extended in place.
        input_text: Text currently in the textbox; if non-blank it is appended
            as a user turn before generation.
        do_sample: Whether to sample instead of decoding greedily.
        temperature: Sampling temperature; a value of 0 or below selects
            greedy decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.
        max_new_tokens: Upper bound on the number of generated tokens.
        fps: Frame rate forwarded to the processor for video entries.
        max_frames: Frame cap forwarded to the processor for video entries.

    Yields:
        ``messages`` after each streamed chunk, with the last (assistant)
        entry growing. If there is no pending user turn the history is
        yielded once, unchanged.

    Raises:
        ValueError: If a media entry has an unsupported file extension.
        Exception: Any error raised by ``model.generate`` on the worker thread
            is re-raised here once the stream has ended.
    """
    # Text typed but not yet submitted with Enter counts as the latest turn.
    if input_text and input_text.strip():
        messages.append({"role": "user", "content": input_text})

    # Rebuild the conversation: every run of user entries becomes a single
    # user turn whose content is a list of text strings and media dicts.
    new_messages = []
    contents = []
    for message in messages:
        role = message["role"]
        if role == "assistant":
            if contents:
                new_messages.append({"role": "user", "content": contents})
                contents = []
            new_messages.append(message)
        elif role == "user":
            content = message["content"]
            if isinstance(content, str):
                contents.append(content)
                continue
            # The upload handlers in this file store media as {"path": ...},
            # while the first element of a sequence is what was read here
            # originally - presumably the chatbot component hands media back
            # as a tuple after a round trip (TODO confirm). Accept both.
            media_path = content["path"] if isinstance(content, dict) else content[0]
            # Compare extensions case-insensitively so "clip.MP4" is accepted.
            lowered = media_path.lower()
            if lowered.endswith(video_formats):
                contents.append({"type": "video", "video": {"video_path": media_path, "fps": fps, "max_frames": max_frames}})
            elif lowered.endswith(image_formats):
                contents.append({"type": "image", "image": {"image_path": media_path}})
            else:
                raise ValueError(f"Unsupported media type: {media_path}")

    if contents:
        new_messages.append({"role": "user", "content": contents})

    # Nothing to answer: a plain `return value` inside a generator would be
    # discarded, so hand the unchanged history back explicitly.
    if not new_messages or new_messages[-1]["role"] != "user":
        yield messages
        return

    generation_config = {"max_new_tokens": max_new_tokens}
    # Sampling needs a strictly positive temperature, but the UI slider
    # starts at 0.0; treat that setting as a request for greedy decoding.
    if do_sample and temperature > 0:
        generation_config.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        generation_config["do_sample"] = False

    inputs = processor(
        conversation=new_messages,
        add_system_prompt=True,
        add_generation_prompt=True,
        return_tensors="pt"
    )
    inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
    if "pixel_values" in inputs:
        # Cast to bfloat16, the dtype the model weights are loaded in.
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)

    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        **generation_config,
        "streamer": streamer,
    }

    failures = []

    def _generate():
        # Run generation; on failure, close the stream so the consumer loop
        # below stops waiting, and keep the error for the caller's thread.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised after the stream is drained
            failures.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    messages.append({"role": "assistant", "content": ""})
    for token in streamer:
        messages[-1]["content"] += token
        yield messages

    thread.join()
    if failures:
        raise failures[0]
# NOTE(review): the nesting below was reconstructed from a copy of this file
# whose indentation had been lost - confirm against the deployed layout.
with gr.Blocks() as interface:
    gr.Markdown(HEADER)

    with gr.Row():
        chatbot = gr.Chatbot(type="messages", elem_id="chatbot", height=710)

        with gr.Column():
            # Tab 1: media uploads, optional example rows, prompt and trigger.
            with gr.Tab(label="Input"):
                with gr.Row():
                    input_video = gr.Video(sources=["upload"], label="Upload Video")
                    input_image = gr.Image(sources=["upload"], type="filepath", label="Upload Image")

                # Example rows are only created when example files were found.
                if image_examples:
                    gr.Examples(image_examples, inputs=[input_image], label="Example Images")
                if video_examples:
                    gr.Examples(video_examples, inputs=[input_video], label="Example Videos")

                input_text = gr.Textbox(label="Input Text", placeholder="Type your message here and press enter to submit")
                submit_button = gr.Button("Generate")

            # Tab 2: decoding and video-sampling settings passed to _predict.
            with gr.Tab(label="Configure"):
                with gr.Accordion("Generation Config", open=True):
                    do_sample = gr.Checkbox(value=True, label="Do Sample")
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, label="Temperature")
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top P")
                    max_new_tokens = gr.Slider(minimum=0, maximum=4096, value=2048, step=1, label="Max New Tokens")

                with gr.Accordion("Video Config", open=True):
                    fps = gr.Slider(minimum=0.0, maximum=10.0, value=1, label="FPS")
                    max_frames = gr.Slider(minimum=0, maximum=256, value=180, step=1, label="Max Frames")

    # Each input handler writes the updated history to the chatbot and a
    # reset value back to the widget that triggered it.
    input_video.change(
        _on_video_upload,
        inputs=[chatbot, input_video],
        outputs=[chatbot, input_video],
    )
    input_image.change(
        _on_image_upload,
        inputs=[chatbot, input_image],
        outputs=[chatbot, input_image],
    )
    input_text.submit(
        _on_text_submit,
        inputs=[chatbot, input_text],
        outputs=[chatbot, input_text],
    )

    # Argument order must match the parameter order of _predict.
    generation_inputs = [
        chatbot,
        input_text,
        do_sample,
        temperature,
        top_p,
        max_new_tokens,
        fps,
        max_frames,
    ]
    submit_button.click(_predict, inputs=generation_inputs, outputs=[chatbot])
# Start the Gradio server only when this file is executed as a script, so
# importing the module does not launch it.
if __name__ == "__main__":
    interface.launch()