File size: 3,627 Bytes
a17aefb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr

from demo import VideoCLSModel


# Candidate clips for each multiple-choice question: one inner list per query,
# in the same order as ``sample_text`` (the examples table pairs them by index).
sample_videos = [
    [
        "data/svitt-ego-demo/0/video/2d560d56-dc47-4c76-8d41-889c8aa55d66-converted.mp4",
        "data/svitt-ego-demo/0/video/eb5cb2b0-59e6-45da-af1b-ba86c7ab0b54-converted.mp4",
        "data/svitt-ego-demo/0/video/0a3097fc-baed-4d11-a4c9-30f07eb91af6-converted.mp4",
        "data/svitt-ego-demo/0/video/1a870d5d-5787-4098-ad8d-fe7343c43698-converted.mp4",
        "data/svitt-ego-demo/0/video/014b473f-aec0-49c7-b394-abc7309ca3c7-converted.mp4",
    ],
    [
        "data/svitt-ego-demo/1/video/029eeb9a-8853-48a4-a1dc-e8868b58adf3-converted.mp4",
        "data/svitt-ego-demo/1/video/968139e2-987e-4615-a2d4-fa2e683bae8a-converted.mp4",
        "data/svitt-ego-demo/1/video/fb9fda68-f264-465d-9208-19876f5ef90f-converted.mp4",
        "data/svitt-ego-demo/1/video/53da674a-089d-428a-a719-e322b2de002b-converted.mp4",
        "data/svitt-ego-demo/1/video/060e07d8-e818-4f9c-9d6b-6504f5fd42a3-converted.mp4",
    ],
    [
        "data/svitt-ego-demo/2/video/fa2f1291-3796-41a6-8f7b-6e7c1491b9b2-converted.mp4",
        "data/svitt-ego-demo/2/video/8d83478f-c5d2-4ac3-a823-e1b2ac7594d7-converted.mp4",
        "data/svitt-ego-demo/2/video/5f6f87ea-e1c3-4868-bb60-22c9e874d056-converted.mp4",
        "data/svitt-ego-demo/2/video/77718528-2de9-48b4-b6b8-e7c602032afb-converted.mp4",
        "data/svitt-ego-demo/2/video/9abbf7f4-68f0-4f52-812f-df2a3df48f7b-converted.mp4",
    ],
    [
        "data/svitt-ego-demo/3/video/2a6b3d10-8da9-4f0e-a681-59ba48a55dbf-converted.mp4",
        "data/svitt-ego-demo/3/video/5afd7421-fb6b-4c65-a09a-716f79a7a935-converted.mp4",
        "data/svitt-ego-demo/3/video/f7aec252-bd4f-4696-8de5-ef7b871e2194-converted.mp4",
        "data/svitt-ego-demo/3/video/84d6855a-242b-44a6-b48d-2db302b5ea7a-converted.mp4",
        "data/svitt-ego-demo/3/video/81fff27c-97c0-483a-ad42-47fa947977a9-converted.mp4",
    ],
]

# Text queries; ``sample_text[i]`` is shown alongside ``sample_videos[i]``.
sample_text = [
    "drops the palm fronds on the ground",
    "stands up",
    "throws nuts in a bowl",
    "puts the speaker and notepad in both hands on a seat",
]

# Query text -> index into ``sample_text`` / ``sample_videos``. Derived from
# ``sample_text`` so the lookup table cannot drift out of sync with the list.
sample_text_dict = {query: index for index, query in enumerate(sample_text)}

# Number of candidate clips per query, taken from the first group.
num_samples = len(sample_videos[0])

# Display names for the candidate clips: "video-0", "video-1", ...
labels = [f"video-{i}" for i in range(num_samples)]

def main():
    """Build the SViTT-Ego multiple-choice demo page and launch it.

    Instantiates ``VideoCLSModel`` with ``configs/ego_mcq/svitt.yml`` and the
    sample videos, lays out a Gradio Blocks page (one video player per
    candidate clip, a hidden query box, and two result boxes), registers the
    sample queries as clickable examples, and calls ``launch(share=True)``.
    """
    svitt = VideoCLSModel(
        "configs/ego_mcq/svitt.yml",
        sample_videos,
    )

    def predict(text):
        """Return ``(ground-truth label, predicted label)`` for a sample query.

        Args:
            text: Content of the hidden query box; expected to be one of the
                strings in ``sample_text``.

        Raises:
            gr.Error: If ``text`` is not a known sample query, e.g. when
                Predict is clicked before any example row has been chosen.
        """
        idx = sample_text_dict.get(text)
        if idx is None:
            # The query box is hidden, so an unknown/empty value means no
            # example was selected; show a readable message instead of
            # failing with a bare KeyError.
            raise gr.Error(
                "Please choose one of the sample queries before clicking Predict."
            )
        ft_action, gt_action = svitt.predict(idx, text)
        return labels[gt_action], labels[ft_action]

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # SViTT-Ego for Multiple Choice Question
            Choose a sample query and click predict to view the results.
            """
        )
        with gr.Row():
            with gr.Column():
                videos = [
                    gr.Video(label=name, format='mp4', height=256, min_width=340)
                    for name in labels
                ]
            with gr.Column():
                # Hidden: only filled in by selecting a row from the examples.
                text = gr.Text(label="Query", visible=False)
                label = gr.Text(label="Ground Truth")
                ours = gr.Text(label="SViTT-Ego prediction")
        btn = gr.Button("Predict", variant="primary")
        btn.click(predict, inputs=[text], outputs=[label, ours])

        # Each example row is the query followed by all of its candidate
        # clips, matching the component order [text, video-0, video-1, ...].
        gr.Examples(
            examples=[
                [sample_text[i], *clips] for i, clips in enumerate(sample_videos)
            ],
            inputs=[text, *videos],
        )

    demo.launch(share=True)


# Launch the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()