# Spaces: Running on Zero
# Not Appropriately Made for API call
import os
import os.path as osp
import gradio as gr
import spaces
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
# HTML snippet rendered at the top of the page via gr.HTML(HEADER).
# NOTE(review): the string currently contains only a newline; the banner
# markup appears to be missing - confirm against the intended page header.
HEADER = ("""
""")

# Device the model and the processed inputs are moved to.
device = "cuda"
# Hub repository that provides both the model weights and the processor.
_CHECKPOINT = "DAMO-NLP-SG/VideoLLaMA3-7B"

# Load the causal LM in bfloat16 with FlashAttention-2, using the custom
# modelling code shipped with the checkpoint (trust_remote_code=True).
model = AutoModelForCausalLM.from_pretrained(
    _CHECKPOINT,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
model.to(device)

# The processor turns a chat-style conversation into model inputs.
processor = AutoProcessor.from_pretrained(_CHECKPOINT, trust_remote_code=True)
# Directory scanned for bundled example media.
example_dir = "./examples"

# File extensions recognised as images / videos. These tuples are passed
# directly to str.endswith, here and in _predict.
image_formats = ("png", "jpg", "jpeg")
video_formats = ("mp4",)

# Each entry is a one-element list holding the path of one example file.
image_examples, video_examples = [], []

example_files = []
# Guard with isdir so a missing examples folder does not crash the app at
# import time; an absent folder simply yields no examples.
if example_dir is not None and osp.isdir(example_dir):
    # os.listdir returns entries in arbitrary order; sort for a stable listing.
    example_files = [
        osp.join(example_dir, f) for f in sorted(os.listdir(example_dir))
    ]

for example_file in example_files:
    if example_file.endswith(image_formats):
        image_examples.append([example_file])
    elif example_file.endswith(video_formats):
        video_examples.append([example_file])
################################################
# FIX #1: Make sure we store the file path in a list
################################################
def _on_video_upload(messages, video):
"""
video will be a dict like:
{
"video": "/tmp/path_to_uploaded.mp4",
"subtitles": None
}
We'll store video["video"] inside a list, so _predict can read it.
"""
if video is not None:
# store the path as a list
video_path = video["video"]
messages.append({"role": "user", "content": [video_path]})
return messages, None
################################################
# FIX #2: Same logic for images; store in a list
################################################
def _on_image_upload(messages, image):
"""
image will be a dict like:
{
"path": "/tmp/path_to_uploaded.png",
"name": "some_filename.png",
"orig_name": "some_filename.png",
...
}
We'll store image["path"] inside a list.
"""
if image is not None:
image_path = image["path"]
messages.append({"role": "user", "content": [image_path]})
return messages, None
def _on_text_submit(messages, text):
messages.append({"role": "user", "content": text})
return messages, ""
def _build_conversation(messages, fps, max_frames):
    """Convert the chat history into the conversation passed to the processor.

    Consecutive user messages are merged into a single user turn whose
    content is a list. Plain-text messages contribute their string; media
    messages (an iterable of file paths) contribute one
    ``{"type": "video", ...}`` or ``{"type": "image", ...}`` dict per path.
    Assistant messages are passed through unchanged. Messages with any other
    role are ignored.

    Args:
        messages: Chat history, a list of ``{"role", "content"}`` dicts.
        fps: Value stored under ``"fps"`` for every video entry.
        max_frames: Value stored under ``"max_frames"`` for every video entry.

    Returns:
        The list of merged conversation turns.

    Raises:
        ValueError: If a media path has neither a known video nor a known
            image extension.
    """
    conversation = []
    pending = []  # user content collected since the last assistant turn
    for message in messages:
        role = message["role"]
        if role == "assistant":
            # Flush the collected user content before the assistant turn.
            if pending:
                conversation.append({"role": "user", "content": pending})
                pending = []
            conversation.append(message)
        elif role == "user":
            content = message["content"]
            if isinstance(content, str):
                pending.append(content)
                continue
            # Otherwise expect file paths from the upload handlers.
            for media_path in content:
                # Compare case-insensitively so e.g. "clip.MP4" is accepted.
                lowered = media_path.lower()
                if lowered.endswith(video_formats):
                    pending.append({
                        "type": "video",
                        "video": {
                            "video_path": media_path,
                            "fps": fps,
                            "max_frames": max_frames,
                        },
                    })
                elif lowered.endswith(image_formats):
                    pending.append({
                        "type": "image",
                        "image": {
                            "image_path": media_path,
                        },
                    })
                else:
                    raise ValueError(f"Unsupported media type: {media_path}")
    if pending:
        conversation.append({"role": "user", "content": pending})
    return conversation


@spaces.GPU(duration=120)
def _predict(messages, input_text, do_sample, temperature, top_p, max_new_tokens,
             fps, max_frames):
    """Stream the model's reply to the current chat history.

    This is a generator: it yields the updated ``messages`` list after each
    piece of decoded text so the chatbot can render the reply incrementally.

    Args:
        messages: Chat history, a list of ``{"role", "content"}`` dicts.
            Mutated in place (pending text and the assistant reply are
            appended).
        input_text: Text still sitting in the textbox; appended as a user
            message when non-empty.
        do_sample: Whether to sample instead of decoding greedily.
        temperature: Sampling temperature. A value of 0 or below falls back
            to greedy decoding, since sampling needs a positive temperature.
        top_p: Nucleus sampling threshold, used only when sampling.
        max_new_tokens: Upper bound on the number of generated tokens.
        fps: Frame rate forwarded with every video entry.
        max_frames: Frame cap forwarded with every video entry.

    Yields:
        The ``messages`` list, with the last (assistant) entry growing as
        text is streamed.

    Raises:
        ValueError: If the history contains a media path with an
            unsupported extension.
    """
    # If there's fresh text from the user, add it to the history first.
    if input_text:
        messages.append({"role": "user", "content": input_text})

    new_messages = _build_conversation(messages, fps, max_frames)

    # Nothing for the model to answer: emit the history unchanged. A bare
    # `return messages` would be discarded because this is a generator.
    if not new_messages or new_messages[-1]["role"] != "user":
        yield messages
        return

    # Only sample with a positive temperature; otherwise decode greedily and
    # leave the sampling-only options out.
    sampling = bool(do_sample) and temperature > 0
    generation_config = {
        "do_sample": sampling,
        "max_new_tokens": int(max_new_tokens),
    }
    if sampling:
        generation_config["temperature"] = temperature
        generation_config["top_p"] = top_p

    inputs = processor(
        conversation=new_messages,
        add_system_prompt=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
    if "pixel_values" in inputs:
        # Match the dtype the model was loaded with.
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)

    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        **generation_config,
        "streamer": streamer,
    }

    # Run generation in a background thread so the streamer can be consumed
    # here while tokens are still being produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    messages.append({"role": "assistant", "content": ""})
    for token in streamer:
        messages[-1]["content"] += token
        yield messages

    # The streamer is exhausted; wait for the worker thread to finish.
    thread.join()
# Build the Gradio UI: chat window on the left, inputs and settings on the right.
with gr.Blocks() as interface:
    gr.HTML(HEADER)

    with gr.Row():
        chatbot = gr.Chatbot(type="messages", elem_id="chatbot", height=835)

        with gr.Column():
            with gr.Tab(label="Input"):
                with gr.Row():
                    input_video = gr.Video(sources=["upload"], label="Upload Video")
                    input_image = gr.Image(sources=["upload"], type="filepath", label="Upload Image")

                input_text = gr.Textbox(
                    label="Input Text",
                    placeholder="Type your message here and press enter to submit",
                )
                submit_button = gr.Button("Generate")

                # Preset (video, prompt) pairs that fill the inputs when clicked.
                video_example_rows = [
                    ["examples/bear.mp4", "What is unusual in the video?"],
                    ["examples/dog.mp4", "Please describe the video in detail."],
                    ["examples/exercise.mp4", "What is the man doing in the video?"],
                ]
                gr.Examples(
                    examples=video_example_rows,
                    inputs=[input_video, input_text],
                    label="Video examples",
                )

            with gr.Tab(label="Configure"):
                with gr.Accordion("Generation Config", open=True):
                    do_sample = gr.Checkbox(value=True, label="Do Sample")
                    temperature = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.2, label="Temperature",
                    )
                    top_p = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.9, label="Top P",
                    )
                    max_new_tokens = gr.Slider(
                        minimum=0, maximum=4096, value=2048, step=1, label="Max New Tokens",
                    )

                with gr.Accordion("Video Config", open=True):
                    fps = gr.Slider(
                        minimum=0.0, maximum=10.0, value=1, label="FPS",
                    )
                    max_frames = gr.Slider(
                        minimum=0, maximum=256, value=180, step=1, label="Max Frames",
                    )

    # NOTE: the client will call these 4 endpoints:
    #   1) /_on_video_upload
    #   2) /_on_image_upload
    #   3) /_on_text_submit
    #   4) /_predict
    # They are registered by the Gradio events below.
    # NOTE(review): the api_name values carry a leading "/"; confirm that the
    # resulting endpoint names match what the client actually calls.

    # Uploading a video appends it to the chat and clears the video input.
    input_video.change(
        _on_video_upload,
        [chatbot, input_video],
        [chatbot, input_video],
        api_name="/_on_video_upload",
    )

    # Uploading an image appends it to the chat and clears the image input.
    input_image.change(
        _on_image_upload,
        [chatbot, input_image],
        [chatbot, input_image],
        api_name="/_on_image_upload",
    )

    # Pressing Enter in the textbox appends the text and clears the textbox.
    input_text.submit(
        _on_text_submit,
        [chatbot, input_text],
        [chatbot, input_text],
        api_name="/_on_text_submit",
    )

    # The "Generate" button streams the model reply into the chatbot.
    predict_inputs = [
        chatbot,
        input_text,
        do_sample,
        temperature,
        top_p,
        max_new_tokens,
        fps,
        max_frames,
    ]
    submit_button.click(
        _predict,
        predict_inputs,
        [chatbot],
        api_name="/_predict",
    )
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    # show_error=True lets you see any server errors.
    interface.launch(show_error=True)