SkalskiP commited on
Commit
44b03d2
1 Parent(s): 5bd8d5b

drop ZERO GPU

Browse files
Files changed (3) hide show
  1. app.py +3 -3
  2. utils/__init__.py +0 -0
  3. utils/models.py +6 -0
app.py CHANGED
@@ -2,7 +2,6 @@ from typing import Optional
2
 
3
  import gradio as gr
4
  import numpy as np
5
- import spaces
6
  import supervision as sv
7
  import torch
8
  from PIL import Image
@@ -18,14 +17,15 @@ video by treating images as single-frame videos. Its design, a simple transforme
18
  architecture with streaming memory, enables real-time video processing.
19
  """
20
 
21
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
22
  MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
23
  CHECKPOINT = "checkpoints/sam2_hiera_large.pt"
24
  CONFIG = "sam2_hiera_l.yaml"
25
 
26
  sam2_model = build_sam2(CONFIG, CHECKPOINT, device=DEVICE, apply_postprocessing=False)
27
 
28
- @spaces.GPU
29
  def process(image_input) -> Optional[Image.Image]:
30
  mask_generator = SAM2AutomaticMaskGenerator(sam2_model)
31
  image = np.array(image_input.convert("RGB"))
 
2
 
3
  import gradio as gr
4
  import numpy as np
 
5
  import supervision as sv
6
  import torch
7
  from PIL import Image
 
17
  architecture with streaming memory, enables real-time video processing.
18
  """
19
 
20
+ # DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
+ DEVICE = torch.device('cuda')
22
+
23
  MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
24
  CHECKPOINT = "checkpoints/sam2_hiera_large.pt"
25
  CONFIG = "sam2_hiera_l.yaml"
26
 
27
  sam2_model = build_sam2(CONFIG, CHECKPOINT, device=DEVICE, apply_postprocessing=False)
28
 
 
29
  def process(image_input) -> Optional[Image.Image]:
30
  mask_generator = SAM2AutomaticMaskGenerator(sam2_model)
31
  image = np.array(image_input.convert("RGB"))
utils/__init__.py ADDED
File without changes
utils/models.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ CHECKPOINTS = {
2
+ "tiny": ["sam2_hiera_t.yaml", "checkpoints/sam2_hiera_tiny.pt"],
3
+ "small": ["sam2_hiera_s.yaml", "checkpoints/sam2_hiera_small.pt"],
4
+ "base_plus": ["sam2_hiera_b+.yaml", "checkpoints/sam2_hiera_base_plus.pt"],
5
+ "large": ["sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt"],
6
+ }