File size: 5,202 Bytes
11dbf82
734ce72
bf3bbe6
b158f70
 
1c89dfa
734ce72
bf3bbe6
c1cd135
734ce72
 
2619be1
11dbf82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2619be1
 
 
 
 
 
 
 
 
 
 
b158f70
1c89dfa
b158f70
 
 
 
 
 
 
734ce72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf3bbe6
013bae8
11dbf82
013bae8
11dbf82
 
b158f70
 
 
758e6c4
c1cd135
efac2d0
 
 
734ce72
efac2d0
 
 
 
 
 
c1cd135
734ce72
 
 
 
c1cd135
 
11dbf82
1c89dfa
 
 
758e6c4
 
 
371ea96
11dbf82
 
 
 
 
371ea96
11dbf82
 
 
 
 
 
 
 
 
 
758e6c4
11dbf82
 
 
 
 
 
 
 
 
 
 
bf3bbe6
11dbf82
758e6c4
11dbf82
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import av
import os
import torch
import tempfile
import shutil
import atexit
import subprocess
import gradio as gr

from convert import convert_video


def get_video_length_av(video_path):
    """Return the duration of the video at *video_path* in seconds, as a float.

    Sources are tried in order:
    1. the container-level duration, expressed in ``av.time_base`` units;
    2. the first video stream's duration scaled by its time base;
    3. the stream's frame count divided by its average frame rate.

    Raises:
        ValueError: if none of these sources provides a duration.
    """
    with av.open(video_path) as container:
        stream = container.streams.video[0]
        if container.duration is not None:
            return float(container.duration) / av.time_base
        if stream.duration is not None and stream.time_base is not None:
            # duration * time_base is a fractions.Fraction; convert so every
            # branch returns the same type.
            return float(stream.duration * stream.time_base)
        if stream.frames and stream.average_rate:
            return float(stream.frames / stream.average_rate)

    raise ValueError(f"Could not determine the duration of {video_path}")


def get_video_dimensions(video_path):
    """Return ``(width, height)``, in pixels, of the first video stream in *video_path*."""
    with av.open(video_path) as container:
        first_video = container.streams.video[0]
        dims = (first_video.width, first_video.height)
    return dims


def get_free_memory_gb():
    """Return the free memory of the current CUDA device, in GiB.

    Uses the driver-reported free memory from ``torch.cuda.mem_get_info``.
    This accounts for memory held by other processes, the CUDA context and
    PyTorch's caching allocator. ``total_memory - memory_allocated()``
    subtracts only this process's live tensors and so overestimates the
    memory that is actually free.
    """
    gpu_index = torch.cuda.current_device()
    free_bytes, _total_bytes = torch.cuda.mem_get_info(gpu_index)
    return free_bytes / 1024**3


def cleanup_temp_directories():
    """Delete every directory listed in the module-level ``temp_directories``.

    Each directory is removed independently. A failure (already gone,
    permission denied, not a directory, ...) is reported but does not stop
    the remaining directories from being cleaned up.
    """
    print("Deleting temporary files")
    for temp_dir in temp_directories:
        try:
            shutil.rmtree(temp_dir)
        except OSError as err:
            # FileNotFoundError is a subclass of OSError. Catching the base
            # class also covers PermissionError/NotADirectoryError, which would
            # otherwise escape the loop and skip later directories.
            print(f"Could not delete directory {temp_dir}: {err}")


def ffmpeg_remux_audio(source_video_path, dest_video_path, output_path):
    """Mux the audio of *source_video_path* into the video of *dest_video_path*.

    Streams are copied without re-encoding:
    - the first video stream comes from *dest_video_path*;
    - the first audio stream, if any, comes from *source_video_path*.

    The result is written to *output_path*. The audio mapping is optional, so
    a source without audio still produces a (silent) output instead of
    failing.

    Returns:
        *output_path* on success. *dest_video_path* unchanged if ffmpeg is
        not installed or exits with an error.
    """
    command = [
        "ffmpeg",
        "-nostdin",  # Never wait for interactive input (e.g. an overwrite prompt)
        "-i",
        dest_video_path,  # Input destination video file
        "-i",
        source_video_path,  # Input source video file (for the audio)
        "-c:v",
        "copy",  # Copy the video stream as is
        "-c:a",
        "copy",  # Copy the audio stream as is
        "-map",
        "0:v:0",  # Map the video stream from the destination file
        "-map",
        "1:a:0?",  # Map the source audio if present; "?" makes it optional
        output_path,  # Specify the output file path
    ]

    try:
        subprocess.run(command, check=True)
    except FileNotFoundError as e:
        # The ffmpeg executable is not on PATH; fall back to the silent video.
        print(f"An error occurred: {e}")
        return dest_video_path
    except subprocess.CalledProcessError as e:
        # ffmpeg ran but failed (e.g. audio codec not allowed in the container).
        print(f"An error occurred: {e}")
        return dest_video_path

    return output_path


def inference(video):
    """Matte the uploaded *video* and return the path of the composited result.

    Raises ``gr.Error`` when no video is given, when it is longer than 30
    seconds, or when it is larger than 1920x1080 in either orientation. The
    output is written to a fresh temporary directory. That directory is
    appended to ``temp_directories`` so ``cleanup_temp_directories`` can
    remove it.

    Returns:
        Path of the matted video with the source audio muxed back in. If
        remuxing fails, the path of the audio-less composition.
    """
    if video is None:
        raise gr.Error("Please upload a video first")
    if get_video_length_av(video) > 30:
        raise gr.Error("Length of video cannot be over 30 seconds")
    # Compare long and short sides instead of the (width, height) tuple.
    # Tuple comparison is lexicographic: 1000x5000 would pass, 1921x10 would
    # not. This accepts 1920x1080 landscape and 1080x1920 portrait.
    width, height = get_video_dimensions(video)
    if max(width, height) > 1920 or min(width, height) > 1080:
        raise gr.Error("Video resolution must not be higher than 1920x1080")

    temp_dir = tempfile.mkdtemp()
    temp_directories.append(temp_dir)

    # Intermediate file name deliberately lacks the "matted_" prefix. Otherwise
    # an input named "video.*" would map the final output onto this same path,
    # and ffmpeg would read and write one file.
    output_composition = os.path.join(temp_dir, "composition.mp4")
    convert_video(
        model,  # The loaded model, can be on any device (cpu or cuda).
        input_source=video,  # A video file or an image sequence directory.
        downsample_ratio=0.25,  # [Optional] If None, make downsampled max size be 512px.
        output_composition=output_composition,  # File path if video; directory path if png sequence.
        output_alpha=None,  # [Optional] Output the raw alpha prediction.
        output_foreground=None,  # [Optional] Output the raw foreground prediction.
        output_video_mbps=4,  # Output video mbps. Not needed for png sequence.
        seq_chunk=12,  # Process n frames at once for better parallelism.
        num_workers=1,  # Only for image sequence input. Reader threads.
        progress=True,  # Print conversion progress.
    )

    # The video stream is stream-copied from an .mp4 composition, and ffmpeg
    # picks the output container from the extension. Keep the input's stem but
    # force ".mp4" so e.g. a .webm/.avi upload does not produce an
    # incompatible container.
    stem = os.path.splitext(os.path.basename(video))[0]
    resulting_video = os.path.join(temp_dir, f"matted_{stem}.mp4")

    return ffmpeg_remux_audio(video, output_composition, resulting_video)


if __name__ == "__main__":
    # Directories created by inference(); removed at interpreter exit.
    temp_directories = []
    atexit.register(cleanup_temp_directories)

    model = torch.hub.load(
        "PeterL1n/RobustVideoMatting", "mobilenetv3", trust_repo=True
    )

    if torch.cuda.is_available():
        free_memory = get_free_memory_gb()
        # Allow one concurrent job per ~7 GB of free VRAM (assumed per-job
        # budget), but never fewer than one. A limit of 0 would leave the queue
        # unable to process any request on smaller GPUs.
        concurrency_count = max(1, int(free_memory // 7))
        print(f"Using GPU with concurrency: {concurrency_count}")
        print(f"Available video memory: {free_memory:.2f} GB")
        model = model.cuda()
    else:
        print("Using CPU")
        concurrency_count = 1

    with gr.Blocks(title="Robust Video Matting") as block:
        gr.Markdown("# Robust Video Matting")
        gr.Markdown(
            "Gradio demo for Robust Video Matting. To use it, simply upload your video, or click one of the examples to load them. Read more at the links below."
        )
        with gr.Row():
            inp = gr.Video(label="Input Video", sources=["upload"], include_audio=True)
            out = gr.Video(label="Output Video")
        btn = gr.Button("Run")
        btn.click(inference, inputs=inp, outputs=out)

        gr.Examples(
            examples=[["example.mp4"]],
            inputs=[inp],
        )
        gr.HTML(
            "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.11515'>Robust High-Resolution Video Matting with Temporal Guidance</a> | <a href='https://github.com/PeterL1n/RobustVideoMatting'>Github Repo</a></p>"
        )

    block.queue(
        api_open=False, max_size=5, default_concurrency_limit=concurrency_count
    ).launch()