Not Appropriately Made for API call

#9 by shaikhmohammedmujammil - opened

import os
import os.path as osp
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer

HEADER = """
VideoGPT: Frontier Multimodal Foundation Models for Video Understanding
"""

device = "cuda"
model = AutoModelForCausalLM.from_pretrained(
"DAMO-NLP-SG/VideoLLaMA3-7B",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
model.to(device)
processor = AutoProcessor.from_pretrained("DAMO-NLP-SG/VideoLLaMA3-7B", trust_remote_code=True)

example_dir = "./examples"
image_formats = ("png", "jpg", "jpeg")
video_formats = ("mp4",)

image_examples, video_examples = [], []
if osp.isdir(example_dir):  # guard against a missing examples directory
    example_files = [
        osp.join(example_dir, f) for f in os.listdir(example_dir)
    ]
    for example_file in example_files:
        if example_file.endswith(image_formats):
            image_examples.append([example_file])
        elif example_file.endswith(video_formats):
            video_examples.append([example_file])

################################################
# FIX #1: Make sure we store the file path in a list
################################################
def _on_video_upload(messages, video):
    """
    Depending on the Gradio version, `video` is either a plain file path
    string or a dict like:
        {
            "video": "/tmp/path_to_uploaded.mp4",
            "subtitles": None
        }
    We store the path inside a list, so _predict can tell media apart from text.
    """
    if video is not None:
        # Handle both the dict payload and the plain-string payload.
        video_path = video["video"] if isinstance(video, dict) else video
        messages.append({"role": "user", "content": [video_path]})
    return messages, None

################################################
# FIX #2: Same logic for images; store in a list
################################################
def _on_image_upload(messages, image):
    """
    With type="filepath", `image` is usually a plain path string; some
    Gradio versions instead pass a dict like:
        {
            "path": "/tmp/path_to_uploaded.png",
            "name": "some_filename.png",
            "orig_name": "some_filename.png",
            ...
        }
    We store the path inside a list.
    """
    if image is not None:
        # Handle both the dict payload and the plain-string payload.
        image_path = image["path"] if isinstance(image, dict) else image
        messages.append({"role": "user", "content": [image_path]})
    return messages, None

def _on_text_submit(messages, text):
    messages.append({"role": "user", "content": text})
    return messages, ""
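
# For reference: after a video upload followed by a typed question, the
# chatbot state built by the handlers above looks like this (paths are
# illustrative):
#
#     [
#         {"role": "user", "content": ["/tmp/gradio/abc123/video.mp4"]},
#         {"role": "user", "content": "What is unusual in the video?"},
#     ]
#
# _predict below folds consecutive user entries into a single multimodal
# turn in the format the VideoLLaMA3 processor expects:
#
#     [
#         {"role": "user", "content": [
#             {"type": "video", "video": {"video_path": "...", "fps": 1, "max_frames": 180}},
#             "What is unusual in the video?",
#         ]},
#     ]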

@spaces.GPU(duration=120)
def _predict(messages, input_text, do_sample, temperature, top_p, max_new_tokens,
             fps, max_frames):
    # If there's fresh text from the user, append it first
    if len(input_text) > 0:
        messages.append({"role": "user", "content": input_text})

    new_messages = []
    contents = []

    for message in messages:
        if message["role"] == "assistant":
            # If we have collected some user content, add it before the assistant turn
            if len(contents):
                new_messages.append({"role": "user", "content": contents})
                contents = []
            new_messages.append(message)

        elif message["role"] == "user":
            # If the message is pure text, accumulate it
            if isinstance(message["content"], str):
                contents.append(message["content"])
            else:
                # Expect a list of file paths from _on_video_upload/_on_image_upload
                for media_path in message["content"]:
                    if media_path.endswith(video_formats):
                        contents.append({
                            "type": "video",
                            "video": {
                                "video_path": media_path,
                                "fps": fps,
                                "max_frames": max_frames
                            }
                        })
                    elif media_path.endswith(image_formats):
                        contents.append({
                            "type": "image",
                            "image": {
                                "image_path": media_path
                            }
                        })
                    else:
                        raise ValueError(f"Unsupported media type: {media_path}")

    if len(contents):
        new_messages.append({"role": "user", "content": contents})

    # If for some reason there are no user messages, echo the state back and stop.
    # (This is a generator, so a bare `return messages` would be silently dropped.)
    if len(new_messages) == 0 or new_messages[-1]["role"] != "user":
        yield messages
        return

    generation_config = {
        "do_sample": do_sample,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens
    }

    inputs = processor(
        conversation=new_messages,
        add_system_prompt=True,
        add_generation_prompt=True,
        return_tensors="pt"
    )
    inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)

    # Stream tokens back to the UI as they are generated.
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        **generation_config,
        "streamer": streamer,
    }

    # Run generation in a background thread so we can consume the streamer here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    messages.append({"role": "assistant", "content": ""})
    for token in streamer:
        messages[-1]["content"] += token
        yield messages

with gr.Blocks() as interface:
    gr.HTML(HEADER)

    with gr.Row():
        chatbot = gr.Chatbot(type="messages", elem_id="chatbot", height=835)

        with gr.Column():
            with gr.Tab(label="Input"):
                with gr.Row():
                    input_video = gr.Video(sources=["upload"], label="Upload Video")
                    input_image = gr.Image(sources=["upload"], type="filepath", label="Upload Image")

                input_text = gr.Textbox(
                    label="Input Text",
                    placeholder="Type your message here and press enter to submit"
                )

                submit_button = gr.Button("Generate")

                gr.Examples(
                    examples=[
                        ["examples/bear.mp4", "What is unusual in the video?"],
                        ["examples/dog.mp4", "Please describe the video in detail."],
                        ["examples/exercise.mp4", "What is the man doing in the video?"],
                    ],
                    inputs=[input_video, input_text],
                    label="Video examples"
                )

            with gr.Tab(label="Configure"):
                with gr.Accordion("Generation Config", open=True):
                    do_sample = gr.Checkbox(value=True, label="Do Sample")
                    temperature = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.2, label="Temperature"
                    )
                    top_p = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.9, label="Top P"
                    )
                    max_new_tokens = gr.Slider(
                        minimum=0, maximum=4096, value=2048, step=1, label="Max New Tokens"
                    )

                with gr.Accordion("Video Config", open=True):
                    fps = gr.Slider(
                        minimum=0.0, maximum=10.0, value=1, label="FPS"
                    )
                    max_frames = gr.Slider(
                        minimum=0, maximum=256, value=180, step=1, label="Max Frames"
                    )

    # NOTE: the client will call these four endpoints
    # (see the gradio_client sketch after the script):
    #  1) /_on_video_upload
    #  2) /_on_image_upload
    #  3) /_on_text_submit
    #  4) /_predict
    # They are triggered by the Gradio events below:

    # Trigger for uploading a video.
    # api_name is given without a leading slash, per Gradio convention;
    # clients still address the endpoint as "/_on_video_upload".
    input_video.change(
        _on_video_upload,
        [chatbot, input_video],
        [chatbot, input_video],
        api_name="_on_video_upload"
    )

    # Trigger for uploading an image
    input_image.change(
        _on_image_upload,
        [chatbot, input_image],
        [chatbot, input_image],
        api_name="_on_image_upload"
    )

    # Trigger for text submission (Enter key in the textbox)
    input_text.submit(
        _on_text_submit,
        [chatbot, input_text],
        [chatbot, input_text],
        api_name="_on_text_submit"
    )

    # Trigger for the "Generate" button => calls _predict
    submit_button.click(
        _predict,
        [
            chatbot,
            input_text,
            do_sample,
            temperature,
            top_p,
            max_new_tokens,
            fps,
            max_frames
        ],
        [chatbot],
        api_name="_predict"
    )

if __name__ == "__main__":
    interface.launch(show_error=True)  # show_error=True surfaces server errors in the UI
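
For anyone calling this Space over the API: below is a minimal client sketch that exercises the four endpoints in order. It is a sketch under stated assumptions, not a tested implementation: the Space ID is a placeholder, it assumes gradio_client >= 1.0 (for handle_file), that the Video component accepts the {"video": ...} payload shape, and that the Chatbot state round-trips the list-valued media entries unchanged.

from gradio_client import Client, handle_file

client = Client("your-username/your-videollama3-space")  # placeholder Space ID

# 1) Register the uploaded video in the chat state.
messages, _ = client.predict(
    [],                                           # empty chatbot history
    {"video": handle_file("examples/bear.mp4")},  # local file to upload
    api_name="/_on_video_upload",
)

# 2) Append the text question to the chat state.
messages, _ = client.predict(
    messages,
    "What is unusual in the video?",
    api_name="/_on_text_submit",
)

# 3) Run generation; for a streaming endpoint, client.predict
#    returns the final yielded chatbot state.
messages = client.predict(
    messages,
    "",      # input_text (already submitted above)
    True,    # do_sample
    0.2,     # temperature
    0.9,     # top_p
    2048,    # max_new_tokens
    1,       # fps
    180,     # max_frames
    api_name="/_predict",
)
print(messages[-1]["content"])  # the assistant's reply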
