sanghan commited on
Commit
b158f70
·
1 Parent(s): 51dd778

return from temporary directories

Browse files
Files changed (1) hide show
  1. app.py +20 -2
app.py CHANGED
@@ -1,7 +1,11 @@
1
  import av
2
  import torch
 
 
3
  import gradio as gr
4
 
 
 
5
 
6
  def get_video_length_av(video_path):
7
  with av.open(video_path) as container:
@@ -34,6 +38,15 @@ def get_free_memory_gb():
34
  return free_memory / 1024**3
35
 
36
 
 
 
 
 
 
 
 
 
 
37
  def inference(video):
38
  if get_video_length_av(video) > 30:
39
  raise gr.Error("Length of video cannot be over 30 seconds")
@@ -44,13 +57,18 @@ def inference(video):
44
  if torch.cuda.is_available():
45
  model = model.cuda()
46
 
 
 
 
47
  convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
48
  convert_video(
49
  model, # The loaded model, can be on any device (cpu or cuda).
50
  input_source=video, # A video file or an image sequence directory.
51
  downsample_ratio=0.25, # [Optional] If None, make downsampled max size be 512px.
52
  output_type="video", # Choose "video" or "png_sequence"
53
- output_composition="com.mp4", # File path if video; directory path if png sequence.
 
 
54
  output_alpha=None, # [Optional] Output the raw alpha prediction.
55
  output_foreground=None, # [Optional] Output the raw foreground prediction.
56
  output_video_mbps=4, # Output video mbps. Not needed for png sequence.
@@ -58,7 +76,7 @@ def inference(video):
58
  num_workers=1, # Only for image sequence input. Reader threads.
59
  progress=True, # Print conversion progress.
60
  )
61
- return "com.mp4"
62
 
63
 
64
  if __name__ == "__main__":
 
1
  import av
2
  import torch
3
+ import tempfile
4
+ import shutil
5
  import gradio as gr
6
 
7
+ temp_directories = []
8
+
9
 
10
  def get_video_length_av(video_path):
11
  with av.open(video_path) as container:
 
38
  return free_memory / 1024**3
39
 
40
 
41
+ def cleanup_temp_directories():
42
+ for temp_dir in temp_directories:
43
+ try:
44
+ shutil.rmtree(temp_dir)
45
+ except FileNotFoundError:
46
+ print(f"Could not delete directory {temp_dir}")
47
+ print(f"Temporary directory {temp_dir} has been removed")
48
+
49
+
50
  def inference(video):
51
  if get_video_length_av(video) > 30:
52
  raise gr.Error("Length of video cannot be over 30 seconds")
 
57
  if torch.cuda.is_available():
58
  model = model.cuda()
59
 
60
+ temp_dir = tempfile.mkdtemp()
61
+ temp_directories.append(temp_dir)
62
+
63
  convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
64
  convert_video(
65
  model, # The loaded model, can be on any device (cpu or cuda).
66
  input_source=video, # A video file or an image sequence directory.
67
  downsample_ratio=0.25, # [Optional] If None, make downsampled max size be 512px.
68
  output_type="video", # Choose "video" or "png_sequence"
69
+ output_composition=(
70
+ temp_dir + "/matted_video.mp4"
71
+ ), # File path if video; directory path if png sequence.
72
  output_alpha=None, # [Optional] Output the raw alpha prediction.
73
  output_foreground=None, # [Optional] Output the raw foreground prediction.
74
  output_video_mbps=4, # Output video mbps. Not needed for png sequence.
 
76
  num_workers=1, # Only for image sequence input. Reader threads.
77
  progress=True, # Print conversion progress.
78
  )
79
+ return temp_dir + "/matted_video.mp4"
80
 
81
 
82
  if __name__ == "__main__":