sanghan committed on
Commit
1c89dfa
·
1 Parent(s): b158f70

load models outside functions

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -2,10 +2,9 @@ import av
2
  import torch
3
  import tempfile
4
  import shutil
 
5
  import gradio as gr
6
 
7
- temp_directories = []
8
-
9
 
10
  def get_video_length_av(video_path):
11
  with av.open(video_path) as container:
@@ -39,12 +38,12 @@ def get_free_memory_gb():
39
 
40
 
41
  def cleanup_temp_directories():
 
42
  for temp_dir in temp_directories:
43
  try:
44
  shutil.rmtree(temp_dir)
45
  except FileNotFoundError:
46
  print(f"Could not delete directory {temp_dir}")
47
- print(f"Temporary directory {temp_dir} has been removed")
48
 
49
 
50
  def inference(video):
@@ -53,14 +52,9 @@ def inference(video):
53
  if get_video_dimensions(video) > (1920, 1920):
54
  raise gr.Error("Video resolution must not be higher than 1920x1080")
55
 
56
- model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3")
57
- if torch.cuda.is_available():
58
- model = model.cuda()
59
-
60
  temp_dir = tempfile.mkdtemp()
61
  temp_directories.append(temp_dir)
62
 
63
- convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
64
  convert_video(
65
  model, # The loaded model, can be on any device (cpu or cuda).
66
  input_source=video, # A video file or an image sequence directory.
@@ -80,11 +74,18 @@ def inference(video):
80
 
81
 
82
  if __name__ == "__main__":
 
 
 
 
 
 
83
  if torch.cuda.is_available():
84
  free_memory = get_free_memory_gb()
85
  concurrency_count = int(free_memory // 7)
86
  print(f"Using GPU with concurrency: {concurrency_count}")
87
  print(f"Available video memory: {free_memory} GB")
 
88
  else:
89
  print("Using CPU")
90
  concurrency_count = 1
 
2
  import torch
3
  import tempfile
4
  import shutil
5
+ import atexit
6
  import gradio as gr
7
 
 
 
8
 
9
  def get_video_length_av(video_path):
10
  with av.open(video_path) as container:
 
38
 
39
 
def cleanup_temp_directories():
    """Remove every temporary directory recorded in ``temp_directories``.

    Registered with ``atexit`` in the ``__main__`` block, so it must never
    raise: any directory that cannot be removed is reported and skipped.
    """
    print("Deleting temporary files")
    for temp_dir in temp_directories:
        try:
            shutil.rmtree(temp_dir)
        except OSError:
            # Catching only FileNotFoundError would let PermissionError and
            # other removal failures escape the exit hook; swallow them all
            # and report, since cleanup is best-effort at interpreter exit.
            print(f"Could not delete directory {temp_dir}")
 
47
 
48
 
49
  def inference(video):
 
52
  if get_video_dimensions(video) > (1920, 1920):
53
  raise gr.Error("Video resolution must not be higher than 1920x1080")
54
 
 
 
 
 
55
  temp_dir = tempfile.mkdtemp()
56
  temp_directories.append(temp_dir)
57
 
 
58
  convert_video(
59
  model, # The loaded model, can be on any device (cpu or cuda).
60
  input_source=video, # A video file or an image sequence directory.
 
74
 
75
 
76
  if __name__ == "__main__":
77
+ temp_directories = []
78
+ atexit.register(cleanup_temp_directories)
79
+
80
+ model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3")
81
+ convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
82
+
83
  if torch.cuda.is_available():
84
  free_memory = get_free_memory_gb()
85
  concurrency_count = int(free_memory // 7)
86
  print(f"Using GPU with concurrency: {concurrency_count}")
87
  print(f"Available video memory: {free_memory} GB")
88
+ model = model.cuda()
89
  else:
90
  print("Using CPU")
91
  concurrency_count = 1