sanghan committed on
Commit
2619be1
·
1 Parent(s): 417d801

run with concurrency

Browse files
Files changed (1) hide show
  1. app.py +22 -3
app.py CHANGED
@@ -1,12 +1,31 @@
1
- import os
2
  import torch
3
  import gradio as gr
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3")
6
 
7
  if torch.cuda.is_available():
8
- print("Using GPU")
 
9
  model = model.cuda()
 
 
 
 
10
 
11
  convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
12
 
@@ -47,4 +66,4 @@ with gr.Blocks(title="Robust Video Matting") as block:
47
  "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.11515'>Robust High-Resolution Video Matting with Temporal Guidance</a> | <a href='https://github.com/PeterL1n/RobustVideoMatting'>Github Repo</a></p>"
48
  )
49
 
50
- block.queue(api_open=False, max_size=5).launch()
 
 
1
  import torch
2
  import gradio as gr
3
 
4
+
5
def get_free_memory_gb():
    """Return the free memory on the current CUDA device, in GiB.

    Uses ``torch.cuda.mem_get_info`` so the figure reflects what the
    driver actually has available — including memory held by the CUDA
    caching allocator (reserved but not allocated) and by other
    processes. The previous ``total_memory - memory_allocated``
    arithmetic ignored both and therefore overstated free memory.

    Returns:
        float: free device memory in GiB.
    """
    gpu_index = torch.cuda.current_device()
    # mem_get_info returns (free_bytes, total_bytes) for the device.
    free_bytes, _total_bytes = torch.cuda.mem_get_info(gpu_index)
    # Convert bytes -> GiB.
    return free_bytes / 1024**3
17
+
18
+
19
# Load the matting model once at module import; move it to the GPU when
# one is available and size the request concurrency to fit in memory.
model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3")

if torch.cuda.is_available():
    free_memory = get_free_memory_gb()
    # Each concurrent job needs roughly 7.4 GiB of GPU memory
    # (empirical constant from the original code — TODO confirm).
    # Clamp to at least 1: on a card with < 7.4 GiB free the floor
    # division yields 0, and gradio rejects concurrency_count=0.
    concurrency_count = max(1, int(free_memory // 7.4))
    model = model.cuda()
    print(f"Using GPU with concurrency: {concurrency_count}")
else:
    print("Using CPU")
    concurrency_count = 1

# Helper that runs the model over a whole video file.
convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
31
 
 
66
  "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.11515'>Robust High-Resolution Video Matting with Temporal Guidance</a> | <a href='https://github.com/PeterL1n/RobustVideoMatting'>Github Repo</a></p>"
67
  )
68
 
69
# Queue incoming requests so at most `concurrency_count` jobs run at
# once; keep the queue endpoint out of the public API and cap the
# number of pending jobs at 5 before starting the server.
queued_app = block.queue(api_open=False, max_size=5, concurrency_count=concurrency_count)
queued_app.launch()