File size: 5,950 Bytes
dab9336
4cc5cdc
 
 
466d650
4cc5cdc
 
2a4546a
 
4cc5cdc
 
 
 
466d650
 
 
aa3849a
466d650
 
 
 
 
 
 
 
 
 
 
 
 
4cc5cdc
 
 
 
 
 
 
 
 
 
 
 
316a9be
4cc5cdc
 
 
 
 
 
 
8dd73c6
4cc5cdc
 
 
 
 
 
 
acf6483
4cc5cdc
 
 
 
 
 
 
 
 
 
 
 
 
c310219
4cc5cdc
 
c310219
 
 
4cc5cdc
 
 
 
 
 
 
 
 
 
c310219
 
4cc5cdc
 
 
 
c310219
4cc5cdc
3fd8c57
4cc5cdc
dab9336
4cc5cdc
 
 
 
 
 
6a70142
df16240
4cc5cdc
 
 
e9310bd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import subprocess
import time
from functools import partial
from typing import Optional

import gradio as gr
import huggingface_hub
import torchvision
import transformers

from inference import LiveInfer
# Module-level logger on the same channel name the inference backend uses.
logger = transformers.logging.get_logger('liveinfer')

# Authenticate to the Hugging Face Hub; requires HF_TOKEN in the environment
# (login(None) would fall back to an interactive prompt).
huggingface_hub.login(os.getenv('HF_TOKEN'))

# python -m demo.app --resume_from_checkpoint ... 

# Single global inference engine shared by every Gradio callback below —
# all connected users drive the same streaming state.
liveinfer = LiveInfer()

def ffmpeg_once(src_path: str, dst_path: str, *, fps: Optional[int] = None, resolution: Optional[int] = None, pad: str = '#000000', mode: str = 'bicubic'):
    """Transcode ``src_path`` to ``dst_path`` with the bundled ffmpeg binary.

    Audio is stripped. When ``fps`` is given the output frame rate is fixed;
    when ``resolution`` is given the video is scaled so its longer side equals
    ``resolution`` (the shorter side auto-computed to preserve aspect) and then
    padded to a centered ``resolution`` x ``resolution`` square.

    Args:
        src_path: input video file path.
        dst_path: output path; parent directories are created as needed.
        fps: target frame rate, or None to keep the source rate.
        resolution: square output side length in pixels, or None to keep size.
        pad: padding color in ffmpeg color syntax.
        mode: scaling algorithm passed to ``-sws_flags``.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
    """
    # Only create parent dirs when dst_path actually has a directory component;
    # os.makedirs('') raises FileNotFoundError.
    dst_dir = os.path.dirname(dst_path)
    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)
    command = [
        './ffmpeg/ffmpeg',   # bundled binary, resolved relative to the CWD
        '-y',                # overwrite dst_path without prompting
        '-sws_flags', mode,
        '-i', src_path,
        '-an',               # drop the audio stream
        '-threads', '10',
    ]
    if fps is not None:
        command += ['-r', str(fps)]
    if resolution is not None:
        # Scale the longer side to `resolution` (-2 lets ffmpeg pick an even
        # value for the other side), then pad to a centered square.
        command += ['-vf', f"scale='if(gt(iw\\,ih)\\,{resolution}\\,-2)':'if(gt(iw\\,ih)\\,-2\\,{resolution})',pad={resolution}:{resolution}:(ow-iw)/2:(oh-ih)/2:color='{pad}'"]
    command += [dst_path]
    subprocess.run(command, check=True)
    
# Inline CSS for the Blocks layout: center the page title and cap the video
# and chatbot panels at 480px so the two columns stay side by side.
css = """
    #gr_title {text-align: center;}
    #gr_video {max-height: 480px;}
    #gr_chatbot {max-height: 480px;}
"""

# Build and launch the Gradio UI. Every handler below closes over the
# module-level `liveinfer` singleton, so concurrent sessions share one
# streaming-inference state.
with gr.Blocks(title="VideoLLM-online", css=css) as demo:
    gr.Markdown("# VideoLLM-online: Online Video Large Language Model for Streaming Video", elem_id='gr_title')
    with gr.Row():
        # Left column: video player, clickable examples, and usage tips.
        with gr.Column():
            gr_video = gr.Video(label="video stream", elem_id="gr_video", visible=True, sources=['upload'], autoplay=True)
            gr_examples = gr.Examples(
                examples=[["cooking.mp4"], ["bicycle.mp4"]],
                inputs=gr_video,
                outputs=gr_video,
                label="Examples"
            )
            gr.Markdown("## Tips:")
            gr.Markdown("- When you upload/click a video, the model starts processing the video stream. You can input a query before or after that, at any point during the video as you like.")
            gr.Markdown("- **Gradio refreshes the chatbot box to update the answer, which will delay the program. If you want to enjoy faster demo as we show in teaser video, please use https://github.com/showlab/videollm-online/blob/main/demo/cli.py.**")
            gr.Markdown("- This work is primarily done at a university, and our resources are limited. Our model is trained with limited data, so it cannot solve very complicated questions and may produce hallucination. However, we have seen the potential of 'learning in streaming'. We are working on new data method to scale streaming dialogue data to our next model.")
        
        # Right column: chat interface wired to the model, plus the streaming
        # threshold slider.
        with gr.Column():
            gr_chat_interface = gr.ChatInterface(
                fn=liveinfer.input_query_stream,
                chatbot=gr.Chatbot(
                    elem_id="gr_chatbot",
                    label='chatbot',
                    avatar_images=('user_avatar.png', 'assistant_avatar.png'),
                    render=False
                ),
                examples=['Please narrate the video in real time.', 'Please describe what I am doing.', 'Could you summarize what have been done?', 'Hi, guide me the next step.'],
            )
            
            # Slider callback: pushes the chosen threshold straight onto the
            # shared inference engine; no output component is updated.
            def gr_frame_token_interval_threshold_change(frame_token_interval_threshold):
                liveinfer.frame_token_interval_threshold = frame_token_interval_threshold
            gr_frame_token_interval_threshold = gr.Slider(minimum=0, maximum=1, step=0.05, value=liveinfer.frame_token_interval_threshold, interactive=True, label="Streaming Threshold")
            gr_frame_token_interval_threshold.change(gr_frame_token_interval_threshold_change, inputs=[gr_frame_token_interval_threshold])

        # Hidden state components used as change-event triggers:
        # gr_video_time tracks the current playback position (seconds);
        # gr_liveinfer_queue_refresher is a boolean toggled to re-fire the
        # response-draining generator below.
        gr_video_time = gr.Number(value=0, visible=False)
        gr_liveinfer_queue_refresher = gr.Number(value=False, visible=False)

        # Fired when a video is uploaded/selected: transcode it once to the
        # demo's fps/resolution cache, load it into the engine, feed frame 0,
        # and run one inference step.
        def gr_video_change(src_video_path, history, video_time, gate):
            name, ext = os.path.splitext(src_video_path)
            ffmpeg_video_path = os.path.join('demo/assets/cache', name + f'_{liveinfer.frame_fps}fps_{liveinfer.frame_resolution}' + ext)
            if ffmpeg_video_path == liveinfer.video_path:
                # NOTE(review): bare `return` yields None while this event
                # declares three outputs — presumably relying on Gradio to
                # skip the update; confirm this doesn't raise on newer Gradio.
                return 
            liveinfer.video_path = ffmpeg_video_path
            if not os.path.exists(ffmpeg_video_path):
                # NOTE(review): redundant with the makedirs inside ffmpeg_once.
                os.makedirs(os.path.dirname(ffmpeg_video_path), exist_ok=True)
                ffmpeg_once(src_video_path, ffmpeg_video_path, fps=liveinfer.frame_fps, resolution=liveinfer.frame_resolution)
                logger.warning(f'{src_video_path} -> {ffmpeg_video_path}, {liveinfer.frame_fps} FPS, {liveinfer.frame_resolution} Resolution')
            liveinfer.load_video(ffmpeg_video_path)
            liveinfer.input_video_stream(0)
            query, response = liveinfer()
            if query or response:
                history.append((query, response))
            # Advancing video_time and flipping `gate` fires the two change
            # handlers below, which keep the stream ticking.
            return history, video_time + 1 / liveinfer.frame_fps, not gate
        gr_video.change(
            gr_video_change, inputs=[gr_video, gr_chat_interface.chatbot, gr_video_time, gr_liveinfer_queue_refresher], 
            outputs=[gr_chat_interface.chatbot, gr_video_time, gr_liveinfer_queue_refresher]
        )
        
        # Advances the playback clock by one frame interval and feeds the new
        # timestamp to the engine; returning video_time re-triggers this same
        # change event, forming the frame-feeding loop. First input (the video
        # component) is unused.
        def gr_video_time_change(_, video_time):
            video_time += 1 / liveinfer.frame_fps
            liveinfer.input_video_stream(video_time)
            print(video_time)
            return video_time
        gr_video_time.change(gr_video_time_change, [gr_video, gr_video_time], [gr_video_time])

        # Generator handler: drains model responses forever, yielding updated
        # chat history so Gradio streams it to the client.
        # NOTE(review): `history[-1]` assumes a prior turn exists, and the
        # response is appended even when only `query` is truthy; also the
        # `while True` loop never sleeps or exits — verify CPU usage and
        # behavior on an empty history.
        def gr_liveinfer_queue_refresher_change(history):
            while True:
                query, response = liveinfer()
                if query or response:
                    history[-1][1] += f'\n{response}'
                    print(history)
                    yield history
        gr_liveinfer_queue_refresher.change(gr_liveinfer_queue_refresher_change, inputs=[gr_chat_interface.chatbot], outputs=[gr_chat_interface.chatbot])
    
    # Enable request queuing (required for generator handlers), then serve
    # locally without a public share link.
    demo.queue()
    demo.launch(share=False, debug=True)